package com.xebialabs.satellite.streaming

import java.nio.file.Path

import akka.stream.OverflowStrategy
import akka.stream.scaladsl._
import akka.stream.stage.{Context, PushStage, SyncDirective, TerminationDirective}
import akka.util.{ByteString, ByteStringBuilder}
import com.xebialabs.satellite.protocol.UploadId
import grizzled.slf4j.Logging

object DownloadStages extends Logging {

  import concurrent.duration._

  /**
   * Optional rate limiting for a stream of `ByteString`s.
   *
   * With `StreamConfig.throttleSpeed = Some(limit)` the bytes are re-cut into packets, at most one
   * packet is released per `StreamConfig.tickDuration`, and released packets are batched before
   * being emitted downstream. With `None` the stream passes through unchanged.
   */
  object Throttling {

    /**
     * Builds the throttling flow for the implicit configuration.
     *
     * @param streamConfig `throttleSpeed` is the limit in KB per second (the per-tick packet size in
     *                     KB is derived from it) and `tickDuration` is the packet release interval
     * @return a rate-limited flow, or an identity flow when no throttle speed is configured
     */
    def flow(implicit streamConfig: StreamConfig): Flow[ByteString, ByteString, Unit] =
      streamConfig.throttleSpeed match {
        case Some(limit) =>
          Flow() { implicit builder =>
            import akka.stream.scaladsl.FlowGraph.Implicits._
            val packetSize = computePacketSize(limit, streamConfig.tickDuration)

            // Zip only emits once it holds both a tick and a packet, so packets are released
            // no faster than one per tick.
            val zipped = builder.add(Zip[Unit, ByteString]())
            val ticker = builder.add(Source(0.second, streamConfig.tickDuration, ()))
            val backPressure = builder.add(Flow[ByteString].buffer(2, OverflowStrategy.backpressure))
            val bufferedBeforeDiskFlow = builder.add(bufferBeforeDisk(packetsPerDiskWrite))
            val createPacketFlow = builder.add(createPacket(packetSize))

            ticker ~> zipped.in0
            createPacketFlow ~> zipped.in1
            zipped.out ~> builder.add(mapToBytes) ~> backPressure ~> bufferedBeforeDiskFlow

            (createPacketFlow.inlet, bufferedBeforeDiskFlow.outlet)
          }

        case None =>
          Flow[ByteString]
      }

    /**
     * Concatenates every `packetsPerWrite` consecutive packets into a single `ByteString`, so the
     * consumer receives fewer, larger elements. A shorter final group is emitted on completion.
     *
     * @param packetsPerWrite number of packets joined into one emitted element (must be > 0)
     */
    private def bufferBeforeDisk(packetsPerWrite: Int): Flow[ByteString, ByteString, Unit] =
      // Each incoming element is already a whole packet produced by `createPacket`,
      // so the group size is a packet count, not a KB count.
      Flow[ByteString].grouped(packetsPerWrite).map(packets => concat(packets))

    /**
     * Re-cuts the byte stream into packets: every incoming `ByteString` is split into 1 KB slices
     * and each run of `sizeInKb` consecutive slices is concatenated into one packet.
     *
     * NOTE(review): slices are counted rather than bytes, so incoming chunks whose length is not a
     * multiple of 1 KB yield short slices and therefore packets smaller than `sizeInKb` KB.
     *
     * @param sizeInKb number of 1 KB slices per packet
     */
    private def createPacket(sizeInKb: Int): Flow[ByteString, ByteString, Unit] =
      Flow[ByteString]
        .map(bytes => Source(() => bytes.grouped(kbytes)))
        .flatten(FlattenStrategy.concat[ByteString])
        .grouped(sizeInKb)
        .map(slices => concat(slices))

    /**
     * Size in KB of the packet released per tick: `limit` (KB/s) times the tick length in
     * seconds, rounded down and never below 1 KB.
     */
    private def computePacketSize(limit: Int, tickDuration: FiniteDuration): Int = {
      // Exact integer arithmetic on nanoseconds: ticks longer than one second, or ticks that do
      // not divide one second evenly, must neither divide by zero nor overshoot the limit.
      // BigInt keeps `limit * nanos` from overflowing a Long.
      val kbPerTick = BigInt(limit) * BigInt(tickDuration.toNanos) / BigInt(1.second.toNanos)
      kbPerTick.min(BigInt(Int.MaxValue)).toInt.max(1)
    }

    /** Joins `parts`, in order, into one `ByteString`. */
    private def concat(parts: Seq[ByteString]): ByteString = {
      // No size hint: a hint computed as packets * KB * 1024 can overflow Int for large
      // throttle limits. NOTE(review): Akka's builder appears to keep appended ByteStrings as
      // segments rather than copying them into a pre-sized array, so a hint buys little anyway.
      val builder = new ByteStringBuilder
      parts.foreach(builder ++= _)
      builder.result()
    }

    /** Drops the tick from each (tick, packet) pair produced by the zip. */
    private def mapToBytes: Flow[(Unit, ByteString), ByteString, Unit] =
      Flow[(Unit, ByteString)].map { case (_, packet) => packet }

    /** Bytes per KB; also the slice size used when re-cutting the stream. */
    private val kbytes = 1024

    /** Number of throttled packets concatenated into one element emitted downstream. */
    private val packetsPerDiskWrite = 10
  }

  /** Flow stage that writes every element of a `ByteString` stream to a file and re-emits it. */
  object WriteToFile {

    import scala.util.{Failure, Success, Try}

    /**
     * Builds a flow that writes each element to `filePath` and then passes it downstream unchanged.
     *
     * @param id       upload identifier, used in log messages
     * @param filePath destination file
     * @param writer   creates the channel for `filePath`; one channel is created per stage
     *                 instance, i.e. per materialization of the returned flow
     */
    def flow(id: UploadId, filePath: Path, writer: ChannelFactory): Flow[ByteString, ByteString, Unit] =
      Flow[ByteString].transform(() => new FileWriteStage(id, filePath, writer))

    /**
     * Writes each pushed element to the channel, then pushes it on. The channel is closed when
     * upstream finishes or fails, when downstream cancels, or when a write fails.
     */
    private class FileWriteStage(uploadId: UploadId, filePath: Path, writer: ChannelFactory)
      extends PushStage[ByteString, ByteString] with Logging {

      val channel = writer.createFileChannel(filePath)

      override def onPush(elem: ByteString, ctx: Context[ByteString]): SyncDirective =
        Try(writeFully(elem)) match {
          case Success(_) =>
            logger.trace(s"[$uploadId] written ${elem.length} bytes to $filePath")
            ctx.push(elem)
          case Failure(cause) =>
            // Close before failing so a failed write never leaves the channel open, and fail
            // explicitly so the stream cannot be resumed past a dropped chunk (corrupted file).
            logger.error(s"[$uploadId] Closing file channel $filePath because writing failed. File might be corrupted", cause)
            close()
            ctx.fail(cause)
        }

      override def onUpstreamFinish(ctx: Context[ByteString]): TerminationDirective = {
        logger.debug(s"Closing file channel $filePath onUpstreamFinish")
        close()
        super.onUpstreamFinish(ctx)
      }

      override def onUpstreamFailure(cause: Throwable, ctx: Context[ByteString]): TerminationDirective = {
        logger.error(s"Closing file channel $filePath because of error. File might be corrupted", cause)
        close()
        super.onUpstreamFailure(cause, ctx)
      }

      override def onDownstreamFinish(ctx: Context[ByteString]): TerminationDirective = {
        logger.debug(s"Closing file channel $filePath onDownstreamFinish")
        close()
        super.onDownstreamFinish(ctx)
      }

      /** Closes the channel if it is still open; safe to call more than once. */
      def close(): Unit =
        if (channel.isOpen) {
          channel.close()
        }

      /**
       * Writes every byte of `bytes`. A single `write` call may consume only part of the buffer,
       * so keep writing until it is drained.
       * NOTE(review): assumes a blocking (file) channel; a non-blocking channel that keeps
       * returning 0 would spin here.
       */
      private def writeFully(bytes: ByteString): Unit = {
        val buffer = bytes.asByteBuffer
        while (buffer.hasRemaining) {
          channel.write(buffer)
        }
      }
    }
  }
}
