package com.xebialabs.satellite.streaming

import java.nio.file.Path

import akka.NotUsed
import akka.stream._
import akka.stream.scaladsl._
import akka.stream.stage._
import akka.util.{ByteString, ByteStringBuilder}
import com.xebialabs.satellite.protocol.UploadId
import grizzled.slf4j.Logging

object DownloadStages extends Logging {

  import concurrent.duration._

  object Throttling {

    /**
     * Builds a flow that limits the throughput of a byte stream according to `streamConfig.throttleSpeed`.
     *
     * When a limit is configured:
     *  - incoming bytes are re-chunked into packets of `computePacketSize(limit, tickDuration)` KB;
     *  - each packet is zipped with a tick from a `Source.tick` of period `streamConfig.tickDuration`. The Zip only
     *    emits when both a tick and a packet are present, so packets are released at most once per tick;
     *  - every `packetsPerDiskWrite` released packets are concatenated into one output element.
     *
     * Without a limit the flow passes elements through unchanged.
     *
     * @param streamConfig supplies the optional throttle limit and the tick interval
     */
    def flow(implicit streamConfig: StreamConfig): Flow[ByteString, ByteString, NotUsed] =
      streamConfig.throttleSpeed match {
        case Some(limit) =>
          Flow.fromGraph(GraphDSL.create() { implicit builder =>
            import GraphDSL.Implicits._
            val packetSize = computePacketSize(limit, streamConfig.tickDuration)

            val zipped = builder.add(Zip[Unit, ByteString]())
            val ticker = builder.add(Source.tick(0.second, streamConfig.tickDuration, ()))
            // Small back-pressured buffer between the throttle and the disk-side grouping.
            val backPressure = builder.add(Flow[ByteString].buffer(2, OverflowStrategy.backpressure))
            val bufferedBeforeDiskFlow = builder.add(bufferBeforeDisk(packetsPerDiskWrite))
            val createPacketFlow = builder.add(createPacket(packetSize))

            ticker ~> zipped.in0
            createPacketFlow ~> zipped.in1
            zipped.out ~> builder.add(mapToBytes) ~> backPressure ~> bufferedBeforeDiskFlow

            FlowShape(createPacketFlow.in, bufferedBeforeDiskFlow.out)
          })

        case None =>
          Flow[ByteString]
      }

    /**
     * Concatenates every `packetCount` consecutive packets into a single element, so the downstream writer receives
     * fewer, larger chunks. The last element of the stream may contain fewer packets.
     *
     * The input elements are whole packets (not 1 KB pieces), so the group size is a packet count. Grouping by
     * `packetCount * packetSize` would hold `packetSize` times more data than intended before the first write.
     */
    private def bufferBeforeDisk(packetCount: Int): Flow[ByteString, ByteString, NotUsed] =
      Flow[ByteString].grouped(packetCount).map(concat)

    /**
     * Re-chunks the byte stream into packets made of `sizeInKb` consecutive pieces of at most 1 KB.
     *
     * Each incoming element is cut into 1 KB pieces, and the last piece of an element may be shorter. A packet can
     * therefore hold less than `sizeInKb` KB and may combine pieces from several incoming elements. The final packet
     * of the stream may contain fewer pieces.
     */
    private def createPacket(sizeInKb: Int): Flow[ByteString, ByteString, NotUsed] =
      Flow[ByteString]
        .mapConcat(_.grouped(kbytes).toVector)
        .grouped(sizeInKb)
        .map(concat)

    /**
     * Joins `parts` into one ByteString by appending each part as a whole.
     *
     * No sizeHint is used: the parts are appended as whole ByteStrings, and the former hint computation
     * (count * KB * 1024 in Int) could overflow for large packet sizes.
     */
    private def concat(parts: Seq[ByteString]): ByteString =
      parts.foldLeft(new ByteStringBuilder)(_ ++= _).result()

    /**
     * Size, in KB, of the packet released on each tick, so that `ticksPerSecond * packetSize` stays within `limit`.
     * Presumably `limit` is KB per second, since the result is used as a packet size in KB. TODO confirm against
     * StreamConfig.
     *
     * The result is `floor(limit * tickDuration / 1 second)`, computed in BigInt so it cannot overflow or divide by
     * zero, then clamped to `[1, Int.MaxValue]`. This also handles ticks longer than one second and ticks that do not
     * divide a second evenly.
     */
    private def computePacketSize(limit: Int, tickDuration: FiniteDuration): Int = {
      val perTick = BigInt(limit) * BigInt(tickDuration.toNanos) / BigInt(1.second.toNanos)
      perTick.min(BigInt(Int.MaxValue)).max(BigInt(1)).toInt
    }

    /** Drops the tick from each (tick, packet) pair emitted by the Zip. */
    private def mapToBytes: Flow[(Unit, ByteString), ByteString, NotUsed] =
      Flow[(Unit, ByteString)].map(_._2)

    // Size of the pieces that packets are built from.
    private val kbytes = 1024

    // Number of throttled packets concatenated into one element before it is passed on towards the disk writer.
    private val packetsPerDiskWrite = 10
  }

  object WriteToFile {

    /**
     * Builds a flow that writes every incoming chunk to the file at `filePath` and then emits the chunk unchanged.
     *
     * @param id       upload identifier, only used in log messages
     * @param filePath path of the file being written
     * @param writer   factory used to open the channel for `filePath`
     */
    def flow(id: UploadId, filePath: Path, writer: ChannelFactory): Flow[ByteString, ByteString, NotUsed] =
      Flow[ByteString].via(completeOn(id, filePath, writer))

    /**
     * Graph stage that writes each element to the channel returned by `writer.createFileChannel(filePath)`, then
     * pushes the element downstream.
     *
     * The channel is opened when the stage is materialized, so each run gets its own channel and nothing is opened
     * for a flow that is never run. It is closed on every termination path: upstream completion, upstream failure
     * and downstream cancellation.
     *
     * @param uploadId upload identifier, only used in log messages
     * @param filePath path of the file being written
     * @param writer   factory used to open the channel for `filePath`
     */
    def completeOn(uploadId: UploadId, filePath: Path, writer: ChannelFactory): GraphStage[FlowShape[ByteString, ByteString]] =
      new GraphStage[FlowShape[ByteString, ByteString]] with Logging {

        // Ports belong to this stage instance rather than being shared by every stage created by this object.
        private val in: Inlet[ByteString] = Inlet[ByteString]("WriteToFile.in")
        private val out: Outlet[ByteString] = Outlet[ByteString]("WriteToFile.out")

        override val shape: FlowShape[ByteString, ByteString] = FlowShape(in, out)

        override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
          new GraphStageLogic(shape) with InHandler with OutHandler {

            // Opened per materialization, not when the stage blueprint is built.
            private val channel = writer.createFileChannel(filePath)

            override def onPush(): Unit = {
              val elem: ByteString = grab(in)
              val buffer = elem.asByteBuffer
              // A single write() is not guaranteed to consume the whole buffer, so keep writing until it is drained.
              while (buffer.hasRemaining) channel.write(buffer)
              logger.trace(s"[$uploadId] written ${elem.length} bytes to $filePath")
              push(out, elem)
            }

            override def onPull(): Unit = pull(in)

            override def onUpstreamFinish(): Unit = {
              logger.debug(s"Closing file channel $filePath onUpstreamFinish")
              closeChannel()
              super.onUpstreamFinish()
            }

            override def onUpstreamFailure(cause: Throwable): Unit = {
              logger.error(s"Closing file channel $filePath because of error. File might be corrupted", cause)
              closeChannel()
              super.onUpstreamFailure(cause)
            }

            // Covers downstream cancellation, which the upstream handlers above never see. Closing is idempotent.
            override def postStop(): Unit = closeChannel()

            /** Closes the channel if it is still open. */
            private def closeChannel(): Unit =
              if (channel.isOpen) channel.close()

            setHandlers(in, out, this)
          }
      }
  }

}
