package com.xebialabs.deployit.engine.tasker.distribution

import akka.actor.{Actor, ActorRef, ActorSelection, Cancellable, PoisonPill}
import akka.serialization.SerializationExtension
import com.xebialabs.deployit.engine.tasker.distribution.MessageSequence._
import com.xebialabs.deployit.engine.tasker.distribution.MessageSequenceReceiver.ChunkingTimeoutPropName
import grizzled.slf4j.Logging

import java.util.UUID
import java.util.concurrent.TimeUnit
import scala.concurrent.ExecutionContext
import scala.concurrent.duration._
import scala.util.{Failure, Success, Try}

/** Messages exchanged by the chunked Message Sequence protocol. */
object MessageSequence {

  /**
    * Announces a chunking session. It is sent by the sender before the first [[Chunk]] of the session.
    *
    * @param id         ID of the chunking session, shared by all chunks of the session.
    * @param sender     Original sender of the message. The reassembled message is delivered with this sender.
    * @param totalSize  Length in bytes of the serialized message.
    * @param chunks     Number of [[Chunk]]s that make up the session.
    * @param serializer Identifier of the Akka serializer that produced the bytes.
    * @param clazz      Class of the original message. Its name is used as the manifest when deserializing.
    */
  case class Start(id: String, sender: ActorRef, totalSize: Int, chunks: Int, serializer: Int, clazz: Class[_ <: AnyRef])

  /**
    * Receiver-side state of an in-flight chunking session.
    *
    * @param start   The [[Start]] that opened the session.
    * @param bytes   Buffer of `start.totalSize` bytes, filled in place as chunks arrive.
    * @param timeout The currently scheduled [[ChunkTimeout]] for this session.
    */
  case class Progress(start: Start, bytes: Array[Byte], timeout: Cancellable)

  /**
    * One slice of a serialized message.
    *
    * @param id          ID of the chunking session this chunk belongs to.
    * @param chunkNumber Zero-based index of this chunk within the session.
    * @param chunkSize   Nominal chunk size used by the sender. The receiver copies `bytes` to offset
    *                    `chunkNumber * chunkSize`.
    * @param bytes       Payload of this chunk. The last chunk may be shorter than `chunkSize`.
    */
  case class Chunk(id: String, chunkNumber: Int, chunkSize: Int, bytes: Array[Byte])

  /**
    * Message the receiver schedules to itself. If the session is still in flight when it arrives,
    * the session is dropped and a [[Timeout]] is sent.
    *
    * @param id The ID of the chunking session.
    */
  case class ChunkTimeout(id: String)

  /**
    * Sent if a partially completed chunking session times out.
    *
    * @param id The ID of the chunking session.  This will match the value returned by sendChunked().
    */
  case class Timeout(id: String)

}

/**
 * [[MessageSequenceSender]] whose chunk size is derived from the Artery maximum frame size, so that
 * each chunk of payload plus the fixed overhead of a serialized Chunk stays below that frame size.
 */
trait MaxSizeMessageSequenceSender extends MessageSequenceSender {
  this: Actor with Logging =>

  /** Config path of the Artery maximum frame size, read as a byte size. */
  val MaxFrameSizePropName = "akka.remote.artery.advanced.maximum-frame-size"

  override val chunkSize: Int = calcMaxChunkSize

  /**
   * Computes the largest payload size a single Chunk may carry.
   *
   * The result is the configured frame size minus a fixed safety buffer and minus the serialized
   * size of an empty Chunk. If that leaves no positive room, chunking cannot work. In that case a
   * warning is logged and Integer.MAX_VALUE is returned, which effectively disables chunking.
   * A zero or negative chunk size would otherwise reach the sender's chunk-count arithmetic.
   */
  private def calcMaxChunkSize: Int = {
    // Fixed per-chunk overhead: the serialized size of a Chunk with an empty payload.
    val (baseChunkBytes, _) = serialize(Chunk(generateId, 0, 0, new Array[Byte](0)))
    val baseChunkSize = baseChunkBytes.length
    val maxFrameSize = context.system.settings.config.getBytes(MaxFrameSizePropName).toInt
    val buffer = 1000 // to ensure that chunk fits to the max frame size after sending the message
    val available = maxFrameSize - buffer - baseChunkSize
    if (available <= 0) {
      logger.warn(s"Chunking is not possible. Base chunk size[$baseChunkSize] plus buffer[$buffer] is not less than $MaxFrameSizePropName[$maxFrameSize].")
      Integer.MAX_VALUE
    } else
      available
  }
}

/**
 * One-shot request/reply forwarder that can reassemble a chunked reply.
 *
 * The first message received is forwarded to `target`. The actor then waits for the reply:
 * - Chunking protocol messages are handled by `receiveChunks`.
 * - The first other message is relayed to the original requester, after which the actor stops
 *   itself. That message may be a reassembled reply, a plain reply, or a `Timeout` raised by
 *   `receiveChunks`.
 */
class ChunkReceivingForwarder(target: ActorSelection) extends Actor with Logging with MessageSequenceReceiver {

  override def receive: Receive = {
    case request =>
      val requester = sender()
      target ! request
      context.become(waitForResponse(requester))
  }

  /** Behavior while waiting for the reply that is destined for `origSender`. */
  def waitForResponse(origSender: ActorRef): Receive = {
    val relayAndStop: Receive = {
      case response =>
        origSender ! response
        self ! PoisonPill
    }
    receiveChunks orElse relayAndStop
  }
}

/**
 * One-shot request/reply forwarder that sends its request in chunks when necessary.
 *
 * The first message received is passed to `sendChunked` towards `target`, with this actor as the
 * sender. The first message that comes back afterwards is relayed to the original requester, and
 * the actor then stops itself.
 *
 * @param target         Destination of the request.
 * @param shouldNotChunk When true, the request is always sent without chunking.
 */
class ChunkSendingForwarder(target: ActorSelection, override val shouldNotChunk: Boolean)
  extends Actor with Logging with MaxSizeMessageSequenceSender {

  override def receive: Receive = {
    case request: AnyRef =>
      val requester = sender()
      sendChunked(target, request)
      context.become(waitForResponse(requester))
  }

  /** Behavior while waiting for the reply that is destined for `origSender`. */
  def waitForResponse(origSender: ActorRef): Receive = {
    case response =>
      origSender ! response
      self ! PoisonPill
  }

}

/**
 * Trait for sending a message as a sequence of chunks.
 *
 * The mixing actor provides `chunkSize` and `shouldNotChunk`. Sending works as follows:
 * - When `shouldNotChunk` is true, every message is sent unchanged.
 * - Otherwise the message is serialized. If the bytes fit into a single chunk (including the
 *   empty case), the original message is sent unchanged.
 * - Otherwise a [[MessageSequence.Start]] is sent, followed by one [[MessageSequence.Chunk]] per
 *   slice of the serialized bytes, both with this actor as the sender.
 */
trait MessageSequenceSender {
  this: Actor with Logging =>

  private val serialization = SerializationExtension(context.system)

  /** Maximum number of payload bytes per chunk. Must be positive, since it is used as a divisor. */
  val chunkSize: Int

  /** When true, chunking is disabled and messages are always sent unchanged. */
  val shouldNotChunk: Boolean

  /** Creates a fresh chunking-session ID. */
  def generateId: String = UUID.randomUUID.toString

  /**
   * Sends `msg` to `to`, chunking it if necessary.
   *
   * @return the chunking-session ID, or `null` if the message was sent without chunking.
   */
  def sendChunked(to: ActorRef, msg: AnyRef): String = sendChunked(self, to, msg)
  def sendChunked(to: ActorSelection, msg: AnyRef): String = sendChunked(self, to, msg)
  def sendChunked(snd: ActorRef, to: ActorRef, msg: AnyRef): String = sendChunked(snd, to.tell _, msg)
  def sendChunked(snd: ActorRef, to: ActorSelection, msg: AnyRef): String = sendChunked(snd, to.tell _, msg)

  private def sendChunked(snd: ActorRef, send: (Any, ActorRef) => Unit, msg: AnyRef): String =
    if (shouldNotChunk) {
      logger.trace(s"Sending not chunked ${msg.getClass.getName}, configured off.")
      send(msg, snd)
      null
    } else {
      startChunked(snd, send, msg).map { case (bytes, start) =>
        sendRange(send, start, bytes, 0, start.chunks)
        start.id
      }.orNull
    }

  /**
   * Serializes `msg` and decides whether chunking is needed.
   *
   * If it is not needed, the original message is sent directly and None is returned. Otherwise
   * the Start message is sent and the serialized bytes are returned together with that Start.
   */
  private def startChunked(snd: ActorRef, send: (Any, ActorRef) => Unit, msg: AnyRef): Option[(Array[Byte], Start)] = {
    val (bytes, serializerId) = serialize(msg)

    val totalSize = bytes.length
    val numChunks = totalSize / chunkSize + (if (totalSize % chunkSize == 0) 0 else 1)

    // numChunks == 0 means the message serialized to zero bytes. A Start announcing zero chunks
    // would never be completed by the receiver (it only completes when a chunk arrives), so such
    // a message is sent directly as well.
    if (numChunks <= 1) {
      logger.debug(s"Sending not chunked ${msg.getClass.getName}, size=$totalSize, chunk size=$chunkSize.")
      send(msg, snd)
      None
    } else {
      val id = generateId
      logger.debug(s"Sending chunked ${msg.getClass.getName}, size=$totalSize, id=$id.")
      val start = Start(id, snd, totalSize, numChunks, serializerId, msg.getClass)
      send(start, self)
      Some((bytes, start))
    }
  }

  /** Sends the chunks numbered `beginChunk` (inclusive) to `endChunk` (exclusive) of `bytes`. */
  private def sendRange(send: (AnyRef, ActorRef) => Unit, start: Start, bytes: Array[Byte], beginChunk: Int, endChunk: Int): Unit =
    (beginChunk until endChunk).foreach { chunkNum =>
      val begin = chunkNum * chunkSize
      // Written as begin + min(...) so it cannot overflow when chunkSize is close to Int.MaxValue.
      val end = begin + Math.min(chunkSize, start.totalSize - begin)
      val chunkBytes = java.util.Arrays.copyOfRange(bytes, begin, end)
      send(Chunk(start.id, chunkNum, chunkSize, chunkBytes), self)
    }

  /** Serializes `obj`, returning its bytes and the identifier of the serializer that was used. */
  protected def serialize(obj: AnyRef): (Array[Byte], Int) = {
    val serializer = serialization.findSerializerFor(obj)
    (serializer.toBinary(obj), serializer.identifier)
  }
}

/** Configuration keys used by the MessageSequenceReceiver trait. */
object MessageSequenceReceiver {

  /** Config path of the duration after which a partially received chunking session is abandoned. */
  val ChunkingTimeoutPropName: String = "akka.remote.chunking.timeout"
}

/**
  * Trait for receiving messages as a sequence of chunks.
  *
  * Put `receiveChunks` in front of the actor's own behavior.
  * - A completed session is deserialized and re-delivered to the actor itself, with the original
  *   sender taken from the session's Start.
  * - A session that receives no chunk within `chunkingTimeout` is dropped, and a
  *   `MessageSequence.Timeout` is sent to the actor itself.
  */
trait MessageSequenceReceiver {
  this: Actor with Logging =>

  private implicit val ec: ExecutionContext = context.system.dispatcher
  private val serialization = SerializationExtension(context.system)
  // Sessions that have started but have not yet completed or timed out, keyed by session ID.
  private var inFlight = Map.empty[String, Progress]

  /**
    * A partially finished chunking session will timeout after this interval. Override as required.
    */
  val chunkingTimeout: FiniteDuration =
    FiniteDuration(context.system.settings.config.getDuration(ChunkingTimeoutPropName).toMillis, TimeUnit.MILLISECONDS)

  /** Handles Start, Chunk and ChunkTimeout messages of the chunking protocol. */
  def receiveChunks: Receive = {
    case s: Start if s.chunks == 0 =>
      // No Chunk will follow, and only the arrival of a chunk completes a session. Without this
      // case, the session would always end in a Timeout.
      completeChunking(s, new Array[Byte](s.totalSize))

    case s: Start =>
      inFlight += s.id -> Progress(s, new Array[Byte](s.totalSize), scheduleTimeout(s.id))

    case Chunk(id, chunkNumber, chunkSize, bytes) =>
      inFlight.get(id) match {
        case None => logMissingInFlight(id)
        case Some(progress) =>
          progress.timeout.cancel()
          logger.debug(s"Received #${chunkNumber + 1}/${progress.start.chunks} in $id")
          val offset = chunkNumber * chunkSize
          System.arraycopy(bytes, 0, progress.bytes, offset, bytes.length)

          // Completion is triggered by the highest-numbered chunk, so chunks are assumed to arrive in order.
          if (chunkNumber == progress.start.chunks - 1) {
            completeChunking(progress.start, progress.bytes)
          } else {
            inFlight += id -> progress.copy(timeout = scheduleTimeout(id))
          }
      }

    case ChunkTimeout(id) =>
      inFlight.get(id) match {
        case None =>
          // The scheduled ChunkTimeout may already be in the mailbox when cancel() is called, so
          // it can arrive after the session has completed. This is not an error.
          logger.debug(s"Ignoring timeout of chunking $id, it is no longer in flight.")
        case Some(progress) =>
          logger.warn(s"Chunk timed out: ${progress.start.id}")
          inFlight -= id
          self ! Timeout(id)
      }
  }

  /**
    * Deserializes the bytes of a finished session and re-delivers the result to this actor with the
    * original sender. Deserialization failures are logged. The session is removed from `inFlight`.
    */
  private def completeChunking(start: Start, bytes: Array[Byte]): Unit = {
    deserialize(bytes, start.serializer, start.clazz) match {
      case Failure(ex) => logger.error(s"Exception deserializing ${start.id}: ${ex.getMessage}", ex)
      case Success(obj) => self.tell(obj, start.sender)
    }
    inFlight -= start.id
  }

  /** Deserializes `bytes` with the given serializer, using the class name as the manifest. */
  protected def deserialize(bytes: Array[Byte], serializerId: Int, clazz: Class[_ <: AnyRef]): Try[AnyRef] = {
    serialization.deserialize(bytes, serializerId, clazz.getName)
  }

  private def logMissingInFlight(id: String): Unit =
    logger.error(s"Unexpected chunking $id.  Currently in-flight: ${inFlight.keys}")

  /** Schedules a ChunkTimeout for session `id` to this actor after `chunkingTimeout`. */
  private def scheduleTimeout(id: String): Cancellable =
    context.system.scheduler.scheduleOnce(chunkingTimeout, self, ChunkTimeout(id))
}

/**
 * Implementation of the Message Sequence pattern from EIP
 * (https://www.enterpriseintegrationpatterns.com/patterns/messaging/MessageSequence.html) for sending
 * messages larger than akka.remote.artery.advanced.maximum-frame-size.
 *
 * Mix this trait into the actors sending and receiving the large messages. The mixing actor must
 * also define `chunkSize` and `shouldNotChunk`, which are abstract in MessageSequenceSender.
 *
 * To send, use `sendChunked(to, message)` instead of `to ! message`. It returns the ID of the
 * chunking session, or `null` when the message was sent without chunking.
 *
 * To receive, put `receiveChunks` in front of the actor's own behavior. A reassembled message is
 * re-delivered to the actor itself with the original sender:
 * {{{
 * def receive = receiveChunks orElse {
 *   case msg: MyLargeMessage          => ... // reassembled (or never chunked) message
 *   case MessageSequence.Timeout(id)  => ... // a partially received session timed out
 * }
 * }}}
 */
trait MessageSequence extends MessageSequenceSender with MessageSequenceReceiver {
  this: Actor with Logging =>
}
