package com.xebialabs.deployit.engine.tasker.distribution

import java.util.concurrent.atomic.AtomicInteger

import akka.actor.{Actor, ActorRef, ActorSelection, Cancellable, ExtendedActorSystem, PoisonPill, Props, Terminated}
import com.xebialabs.deployit.engine.api.distribution.TaskExecutionWorkerRepository
import com.xebialabs.deployit.engine.api.dto.Worker
import com.xebialabs.deployit.engine.tasker.{ActorContextCreationSupport, TaskRegistryEmpty}
import WorkerManager.messages._
import com.xebialabs.deployit.engine.tasker.distribution.messages.{WorkerConfigurationMismatch, WorkerRegistered}
import com.xebialabs.deployit.engine.tasker.distribution.util.workerAddress
import com.xebialabs.xlplatform.settings.CommonSettings
import grizzled.slf4j.Logging

import scala.concurrent.duration.{Duration, FiniteDuration}
import scala.concurrent.duration._
import scala.language.postfixOps

/** Helpers shared by the master- and worker-side actors in this file. */
private object util {

  /** Address of the actor system hosting `worker`, rendered as a string; workers are stored and looked up in the repository by this value. */
  def workerAddress(worker: ActorRef): String = {
    val systemAddress = worker.path.address
    systemAddress.toString
  }

  /** Whether `worker`'s address is local-scoped (has no host), i.e. it is not a remote reference. */
  def isLocal(worker: ActorRef): Boolean = {
    val systemAddress = worker.path.address
    systemAddress.hasLocalScope
  }
}

/** Companion of the master-side `WorkerManager` actor: its actor name, `Props` factory and public protocol. */
object WorkerManager {
  /** Name for the worker manager actor. */
  val name = "worker-manager"

  /** Props for a `WorkerManager`; `shouldNotChunk` defaults to `true`. */
  def props(repository: TaskExecutionWorkerRepository, configurationHash: String, shouldNotChunk: Boolean = true): Props = Props(new WorkerManager(repository, configurationHash, shouldNotChunk))

  /** Messages exchanged with the worker manager, its aggregators and the workers. */
  object messages {

    /**
     * A message that needs to know which worker it is routed to.
     * For a `SelectOne`, `WorkerManager` calls `withWorkerId` with the repository id of the
     * selected worker and forwards the returned value instead of the original message.
     */
    abstract class WorkerAware {
      type Self <: WorkerAware
      def withWorkerId(workerId: Integer): Self
    }

    /** Route `msg` to one active worker (chosen by a rotating counter), to be delivered to `/user/$path` there. */
    case class SelectOne(path: String, msg: Any)
    /** Hand all known workers to `aggregator`, which sends the message to `path` on each and collects the replies. */
    case class Publish(path: String, aggregator: ActorRef)
    /** Reply to `SelectOne` when no active worker is available. */
    case object NoWorkersError

    /**
     * Optional single result from a worker, matched with the `Found(opt)` extractor.
     * NOTE(review): not a case class, presumably so concrete result messages can subclass it — confirm.
     */
    class Found[T](private val result: Option[T])
    object Found {
      def unapply[T](arg: Found[T]): Option[Option[T]] = Some(arg.result)
    }
    /** Sequence of results from a worker, matched with the `FoundAll(seq)` extractor (same note as `Found`). */
    class FoundAll[T](private val result: Seq[T])
    object FoundAll {
      def unapply[T](arg: FoundAll[T]): Option[Seq[T]] = Some(arg.result)
    }

    /** Shut down the workers at the given addresses (an empty list means every known worker); `force` skips waiting for running tasks. */
    case class ShutdownWorker(address: List[String], force: Boolean = false)
    /** Addresses of the workers that acknowledged a `ShutdownWorker`. */
    case class WorkerShutdownStarted(workers: List[String])
    /** Sent by a worker to its masters while it waits for running tasks before shutting down. */
    case class WorkerDraining()

    /** Ask the worker manager for its current worker lists. */
    case class FetchWorkers()
    /** Addresses of the active, configuration-incompatible and draining workers. */
    case class WorkersFetched(healthy: List[String], incompatible: List[String], draining: List[String])
  }
}

/**
 * Internal protocol between `WorkerManager` (master side), `WorkerReceiver` (worker side),
 * the connect/reconnect actors and the aggregators in this file.
 */
protected object messages {

  /**
   * Sent by a connect actor to a master on behalf of its parent `worker`.
   * NOTE(review): as a case class, equals/hashCode compare `workerPublicKey` by array reference, not by content.
   */
  case class RegisterWorker(worker: ActorRef, workerName: String, configHash: String, workerPublicKey: Array[Byte])
  /** Reply to a registration whose configuration hash matched the master's. */
  case class WorkerRegistered()
  /** Reply to a registration whose configuration hash did not match; the master quarantines the worker. */
  case class WorkerConfigurationMismatch()

  /** Sent by a worker to its masters when it shuts down; the master un-watches and forgets it. */
  case class UnregisterWorker()

  /** Envelope delivered to a `WorkerReceiver`, which forwards `msg` to `/user/$path`. */
  case class ToWorker(path: String, msg: Any)
  /** Tells an aggregator to contact `workers` and eventually reply to `origSender`. */
  case class Distribute(workers: Set[ActorRef], path: String, origSender: ActorRef)

  /** Starts the repeating registration attempts of a `WorkerConnectActor`. */
  case object Connect
  /** Starts the repeating reconnection attempts of a `WorkerReconnectActor`. */
  case object ScheduleReconnect
  /** Like `RegisterWorker`, but sent after the worker saw the master terminate. */
  case class ReconnectWorker(worker: ActorRef, workerName: String, configHash: String, workerPublicKey: Array[Byte])

  /** Sent by aggregators to themselves when their aggregation timeout expires. */
  case object Timeout

}

/**
 * Master-side registry and router for task execution workers.
 *
 * Workers register (or reconnect) presenting a configuration hash. A worker whose hash equals
 * `configurationHash` becomes active. Any other worker is quarantined and is never chosen for a
 * `SelectOne`. Workers that announce they are draining are quarantined as well. Registered
 * workers are persisted through `workerRepository`.
 *
 * The worker state lives in the current behavior (`context.become`), not in fields. A restart
 * of this actor would therefore reset it to empty, and workers only re-register when they see
 * the master terminate. The handlers below must not throw for that reason.
 *
 * @param workerRepository  store that registered workers are written to and looked up in
 * @param configurationHash hash a worker must present to be treated as compatible
 * @param shouldNotChunk    overrides `MaxSizeMessageSequenceSender.shouldNotChunk`
 */
class WorkerManager(workerRepository: TaskExecutionWorkerRepository, configurationHash: String, override val shouldNotChunk: Boolean) extends  Actor with Logging with MaxSizeMessageSequenceSender {

  import messages._

  override def receive: Receive = manageWorkers(WorkersState(Set(), Set(), Set()))

  /**
   * Immutable snapshot of the workers this manager knows about.
   *
   * @param workers    every known worker, whatever its status
   * @param quarantine workers that must not be selected for new work (incompatible or draining)
   * @param draining   workers that announced they are draining; always also in `quarantine`
   */
  private case class WorkersState(workers: Set[ActorRef], quarantine: Set[ActorRef], draining: Set[ActorRef]) {
    def withActive(worker: ActorRef) : WorkersState = WorkersState(workers + worker, quarantine, draining)
    def withQuarantined(worker: ActorRef) : WorkersState = WorkersState(workers + worker, quarantine + worker, draining)
    def withDraining(worker: ActorRef) : WorkersState = WorkersState(workers + worker, quarantine + worker, draining + worker)
    def withoutWorker(worker: ActorRef) : WorkersState = WorkersState(workers - worker, quarantine - worker, draining - worker)
    /** Workers eligible for new work. */
    def active: Set[ActorRef] = workers.diff(quarantine)
    /** Quarantined workers that are not draining, i.e. those registered with a mismatching configuration. */
    def incompatible: Set[ActorRef] = quarantine.diff(draining)
  }

  private def manageWorkers(workersState: WorkersState): Receive = manageSubscriptions(workersState) orElse handleShutdown(workersState) orElse fetchWorkers(workersState) orElse forwardMessages(workersState)

  /** Registration lifecycle: (re)register, drain, unregister and termination of workers. */
  private def manageSubscriptions(workersState: WorkersState): Receive = {
    case RegisterWorker(worker, workerName, workerConfigurationHash, workerPublicKey) =>
      info(s"Worker registered $worker using configuration hash [$workerConfigurationHash].")
      context.watch(worker)
      registerWorker(workersState, worker, workerName, workerConfigurationHash, workerPublicKey)
    case ReconnectWorker(worker, workerName, workerConfigurationHash, workerPublicKey) =>
      info(s"Worker reconnected $worker using configuration hash [$workerConfigurationHash].")
      context.watch(worker)
      registerWorker(workersState, worker, workerName, workerConfigurationHash, workerPublicKey)
    case WorkerDraining() =>
      val worker = sender()
      info(s"Worker $worker started draining.")
      context.become(manageWorkers(workersState.withDraining(worker)))
    case UnregisterWorker() =>
      val worker = sender()
      info(s"Worker unregistered $worker.")
      context.unwatch(worker)
      context.become(manageWorkers(workersState.withoutWorker(worker)))
      workerRepository.removeWorker(workerAddress(worker))
    case Terminated(worker) =>
      info(s"Worker terminated $worker.")
      context.become(manageWorkers(workersState.withoutWorker(worker)))
      // Keep the repository entry but drop the public key of the gone worker.
      workerRepository.getWorkerByAddress(workerAddress(worker)).foreach(w => workerRepository.storeWorker(w.copy(publicKey = null)))
  }

  /** Persists the worker, answers the registration and marks it active or quarantined depending on its configuration hash. */
  private def registerWorker(workersState: WorkersState, worker: ActorRef, workerName: String, workerConfigurationHash: String, workerPublicKey: Array[Byte]): Unit = {
    workerRepository.storeWorker(Worker(address = workerAddress(worker), publicKey = workerPublicKey, configurationHash = workerConfigurationHash, name = workerName))
    if (configurationHash == workerConfigurationHash) {
      sender() ! WorkerRegistered()
      context.become(manageWorkers(workersState.withActive(worker)))
    }
    else {
      info(s"Worker $worker tried to register using non matching configuration hash [$workerConfigurationHash].")
      sender() ! WorkerConfigurationMismatch()
      context.become(manageWorkers(workersState.withQuarantined(worker)))
    }
  }

  /** Starts a `ShutdownAggregator` for the requested (non-local) workers; an empty address list selects all of them. */
  private def handleShutdown(workersState: WorkersState): Receive = {
    case ShutdownWorker(addresses, force) =>
      val actors: Set[ActorRef] = (addresses match {
        case Nil =>
          info("Shutting down all workers.")
          workersState.workers
        case list =>
          info(s"Shutting down ${list.size} workers (${list.mkString("[",", ","]")}).")
          workersState.workers.filter { a =>
            list.contains(a.path.address.toString)
          }
      }).filterNot(util.isLocal)
      context.actorOf(Props(new ShutdownAggregator(force))) ! Distribute(actors, "", sender())
  }

  private def addressList(workers: Set[ActorRef]): List[String] = workers.map(workerAddress).toList

  private def fetchWorkers(workersState: WorkersState): Receive = {
    case FetchWorkers() =>
      sender() ! WorkersFetched(addressList(workersState.active), addressList(workersState.incompatible), addressList(workersState.draining))
  }

  /** Routes `SelectOne` to one active worker and hands `Publish` requests to the given aggregator. */
  private def forwardMessages(workersState: WorkersState): Receive = {
    case SelectOne(path, msg) =>
      select(workersState.active) match {
        case Some(worker) =>
          debug(s"Received message, forwarding to $worker.")
          // None only when a WorkerAware message targets a worker missing from the repository.
          val toSend: Option[Any] = msg match {
            case m: WorkerAware =>
              workerRepository.getWorkerByAddress(workerAddress(worker)).map(w => m.withWorkerId(w.id))
            case o => Some(o)
          }
          toSend match {
            case Some(payload) =>
              sendChunked(sender(), worker, ToWorker(path, payload))
            case None =>
              // Fail only this request: throwing would restart the actor and lose every known worker.
              val failure = new IllegalStateException(s"Unknown worker: $worker.")
              error(s"Cannot forward message: ${failure.getMessage}")
              sender() ! akka.actor.Status.Failure(failure)
          }
        case None =>
          sender() ! NoWorkersError
      }
    case Publish(path, aggregator) =>
      debug(s"Distributing message to all workers.")
      aggregator ! Distribute(workersState.workers, path, sender())
  }

  private val next = new AtomicInteger()

  /** Picks a worker by a rotating counter; a negative remainder (after counter overflow) is wrapped back into range. */
  private def select(healthyWorkers: Set[ActorRef]): Option[ActorRef] = {
    val size = healthyWorkers.size
    if (size > 0) {
      val index = next.getAndIncrement % size
      Some(healthyWorkers.toList(if (index < 0) size + index else index))
    }
    else {
      None
    }
  }

}

/** Companion of the worker-side `WorkerReceiver` actor: its actor name and `Props` factory. */
object WorkerReceiver {
  /** Name for the worker receiver actor. */
  val name = "worker-receiver"

  /** Props for a `WorkerReceiver`; every argument is passed through unchanged. */
  def props(registrars: Seq[String], workerName: String, workerPublicKey: Array[Byte], configurationHash: String, onShutdown: () => Unit = () => {}, canShutdown: () => Boolean): Props =
    Props(new WorkerReceiver(
      registrars = registrars,
      workerName = workerName,
      workerPublicKey = workerPublicKey,
      configurationHash = configurationHash,
      onShutdown = onShutdown,
      canShutdown = canShutdown
    ))
}

/**
 * Worker-side endpoint of the master/worker protocol.
 *
 * On start it creates one connect actor per entry in `registrars`. It forwards `ToWorker`
 * payloads to local actors under `/user` and tracks which masters reported a configuration
 * mismatch. It also handles shutdown: immediately when `force` is set or `canShutdown()` holds,
 * otherwise after a `TaskRegistryEmpty` event.
 *
 * @param registrars        paths (as accepted by `actorSelection`) of the masters to register with
 * @param workerName        display name sent to the masters
 * @param workerPublicKey   public key sent to the masters
 * @param configurationHash hash sent to the masters for compatibility checking
 * @param onShutdown        run asynchronously after this worker has sent `UnregisterWorker` to its masters
 * @param canShutdown       whether the worker may shut down right away instead of waiting for an empty task registry
 */
class WorkerReceiver(registrars: Seq[String], workerName: String, workerPublicKey: Array[Byte], configurationHash : String, onShutdown: () => Unit = () => {}, canShutdown: () => Boolean)
  extends Actor with Logging with ActorContextCreationSupport with MessageSequenceReceiver {

  import messages._

  override def preStart(): Unit = {
    super.preStart()
    registrars.foreach { r =>
      info(s"Registering with '$r' using configuration hash [$configurationHash].")
      createConnectActor(r) ! Connect
    }
  }

  protected def createConnectActor(r: String): ActorRef = {
    createChild(WorkerConnectionActor.connectActor(context.actorSelection(r), workerName, workerPublicKey, configurationHash))
  }

  override def receive: Receive = doReceive(Set())

  /** @param inactive paths of masters that answered with a configuration mismatch */
  private def doReceive(inactive: Set[String]): Receive  = receiveChunks orElse handleMasters(inactive) orElse {
    case ToWorker(path, msg) =>
      logger.debug(s"Forwarding message $msg to '/user/$path'.")
      context.actorSelection(s"/user/$path").forward(msg)
    case ShutdownWorker(_, force) =>
      logger.info("Received shutdown message.")
      sender() ! WorkerShutdownStarted(List(myPublicAddress))
      if (force || canShutdown()) {
        shutdownNow()
      } else {
        waitForEmptyTaskRegistry()
      }
    case TaskRegistryEmpty() =>
      // One notification is enough; further ones would repeat the whole shutdown sequence.
      context.system.eventStream.unsubscribe(self, classOf[TaskRegistryEmpty])
      // The delayed block runs on a scheduler thread where the ActorContext must not be used,
      // so everything the shutdown needs is resolved now, on the actor's thread.
      val shutdown = prepareShutdown()
      context.system.scheduler.scheduleOnce(1.second) {
        shutdown()
      }(context.system.dispatcher)
  }

  /** Reacts to registration answers from masters and to their termination. */
  private def handleMasters(inactive: Set[String]): Receive = {
    case WorkerRegistered() =>
      val master = sender()
      logger.info(s"Registered successfully with $master.")
      context.watch(master)
      context.become(doReceive(inactive - master.path.toString))
    case WorkerConfigurationMismatch() =>
      val master = sender()
      logger.warn(s"Registered with $master - worker config mismatch detected - this node will not be used to run new tasks by this master.")
      context.watch(master)
      if (hasNoActiveMasters(inactive + master.path.toString)) {
        if (canShutdown()) {
          shutdownNow()
        } else {
          waitForEmptyTaskRegistry()
        }
      }
      context.become(doReceive(inactive + master.path.toString))
    case Terminated(master) =>
      logger.info(s"Master terminated $master.")
      context.unwatch(master)
      context.become(doReceive(inactive - master.path.toString))
      createReconnectActor(master) ! ScheduleReconnect
  }

  private def hasNoActiveMasters(inactive: Set[String]): Boolean = {
    registrars.forall(inactive.contains)
  }

  private def shutdownNow(): Unit = prepareShutdown()()

  /**
   * Resolves, on the actor's thread, what the shutdown sequence needs from the `ActorContext`,
   * and returns that sequence as a thunk that is safe to run from any thread: unregister from
   * every master, then run `onShutdown` asynchronously.
   */
  private def prepareShutdown(): () => Unit = {
    val masters = registrars.map(context.actorSelection)
    val scheduler = context.system.scheduler
    val dispatcher = context.system.dispatcher
    () => {
      logger.info("Shutting down now.")
      masters.foreach(_ ! UnregisterWorker())
      scheduler.scheduleOnce(Duration.Zero) { onShutdown() } (dispatcher)
    }
  }

  /** Tells the masters this worker is draining and subscribes to `TaskRegistryEmpty` to finish the shutdown later. */
  private def waitForEmptyTaskRegistry(): Unit = {
    logger.info("Waiting for tasks to complete before shutdown.")
    registrars.map(context.actorSelection).foreach(_ ! WorkerDraining())
    context.system.eventStream.subscribe(self, classOf[TaskRegistryEmpty])
  }

  private def myPublicAddress = context.system.asInstanceOf[ExtendedActorSystem].provider.getDefaultAddress.toString

  protected def createReconnectActor(master: ActorRef): ActorRef = {
    createChild(WorkerConnectionActor.reconnectActor(context.actorSelection(master.path), workerName, workerPublicKey, configurationHash))
  }
}

/** Factories for the child actors a `WorkerReceiver` uses to register with a master. */
object WorkerConnectionActor {

  /** Props for a `WorkerConnectActor` that performs the initial registration with `master`. */
  def connectActor(master: ActorSelection, workerName: String, workerPublicKey: Array[Byte], configurationHash: String) : Props =
    Props(new WorkerConnectActor(
      master = master,
      workerName = workerName,
      workerPublicKey = workerPublicKey,
      configurationHash = configurationHash
    ))

  /** Props for a `WorkerReconnectActor` that re-registers with `master` after it terminated. */
  def reconnectActor(master: ActorSelection, workerName: String, workerPublicKey: Array[Byte], configurationHash: String) : Props =
    Props(new WorkerReconnectActor(
      master = master,
      workerName = workerName,
      workerPublicKey = workerPublicKey,
      configurationHash = configurationHash
    ))
}

/**
 * Base class for the short-lived child actors a `WorkerReceiver` uses to (re)register with one master.
 *
 * A subclass starts a repeating task through `createScheduledTask` and stores it in
 * `scheduledTask`. The first `WorkerRegistered` or `WorkerConfigurationMismatch` reply is
 * forwarded to the parent, the repeating task is cancelled and this actor stops. The task is
 * also cancelled whenever the actor stops for another reason, so it never outlives the actor.
 *
 * @param master          not referenced by this base class; subclasses keep their own copy
 * @param workerPublicKey not referenced by this base class; subclasses keep their own copy
 */
abstract class WorkerConnectionActor(master: ActorSelection, workerPublicKey: Array[Byte])
  extends Actor with Logging
{

  /** Delay before the first attempt. */
  val delay: FiniteDuration
  /** Delay between consecutive attempts. */
  val interval: FiniteDuration

  /** The repeating attempt task; `null` until a subclass starts it. */
  protected var scheduledTask: Cancellable = _

  /** Schedules `block` after `delay` and then every `interval`, on this actor's dispatcher. */
  protected def createScheduledTask(block: => Unit): Cancellable = {
    context.system.scheduler.schedule(delay, interval)(block)(context.dispatcher)
  }

  /** Forwards the master's registration answer to the parent, then stops this actor. */
  def handleWorkerRegistered(successMsg: String): Receive = {
    case msg @ (_: WorkerRegistered | _: WorkerConfigurationMismatch) =>
      context.parent.forward(msg)
      connectionRestored(successMsg)
  }

  override def postStop(): Unit = {
    // Without this the repeating task would keep messaging the master after the actor is gone.
    cancelScheduledTask()
    super.postStop()
  }

  private def cancelScheduledTask(): Unit =
    Option(scheduledTask).foreach(_.cancel())

  private def connectionRestored(successMsg: String): Unit = {
    cancelScheduledTask()
    logger.debug(successMsg)
    // Unlike PoisonPill, stop immediately so duplicate answers already queued are not forwarded again.
    context.stop(self)
  }

}

/**
 * Child of a `WorkerReceiver` that performs the initial registration with one master.
 *
 * A `Connect` message starts a repeating task: the first attempt is immediate, then one follows
 * every `worker.connectInterval`. Each attempt sends `RegisterWorker` to `master` on behalf of
 * the parent. The master's answer is handled by the inherited `handleWorkerRegistered`.
 */
class WorkerConnectActor(master: ActorSelection, workerName: String, workerPublicKey: Array[Byte], configurationHash: String)
  extends WorkerConnectionActor(master, workerPublicKey) {

  import concurrent.duration._
  import messages.{Connect, RegisterWorker}

  override val delay: FiniteDuration = 0.seconds
  override val interval: FiniteDuration = CommonSettings(context.system).worker.connectInterval

  override def receive: Receive =
    receiveConnect.orElse(handleWorkerRegistered(s"Connected to $master."))

  private def receiveConnect: Receive = {
    case Connect =>
      // Resolve the parent here, on the actor's thread; the scheduled body runs on another thread.
      val registeringWorker = context.parent
      val task = createScheduledTask {
        logger.info(s"Try connecting to $master using configuration hash [$configurationHash] and display name [$workerName]")
        master ! RegisterWorker(registeringWorker, workerName, configurationHash, workerPublicKey)
      }
      scheduledTask = task
  }
}

/**
 * Child of a `WorkerReceiver` that re-registers with a master after it terminated.
 *
 * A `ScheduleReconnect` message starts a repeating task: the first attempt follows
 * `worker.reconnectDelay`, then one follows every `worker.reconnectInterval`. Each attempt sends
 * `ReconnectWorker` to `master` on behalf of the parent. The master's answer is handled by the
 * inherited `handleWorkerRegistered`.
 */
class WorkerReconnectActor(master: ActorSelection, workerName: String, workerPublicKey: Array[Byte], configurationHash: String) extends WorkerConnectionActor(master, workerPublicKey) {

  import messages.{ReconnectWorker, ScheduleReconnect}

  override val delay: FiniteDuration = CommonSettings(context.system).worker.reconnectDelay
  override val interval: FiniteDuration = CommonSettings(context.system).worker.reconnectInterval

  override def receive: Receive =
    receiveReconnect.orElse(handleWorkerRegistered(s"Connection restored to $master."))

  private def receiveReconnect: Receive = {
    case ScheduleReconnect =>
      logger.info(s"Schedule reconnect to $master.")
      // Resolve the parent here, on the actor's thread; the scheduled body runs on another thread.
      val reconnectingWorker = context.parent
      val task = createScheduledTask {
        logger.info(s"Try reconnecting to $master using configuration hash [$configurationHash] and display name [$workerName]")
        master ! ReconnectWorker(reconnectingWorker, workerName, configurationHash, workerPublicKey)
      }
      scheduledTask = task
  }
}

/**
 * One-shot actor that asks a set of workers to shut down and reports which ones acknowledged.
 *
 * It expects a single `Distribute` message and replies exactly once to its `origSender` with
 * `WorkerShutdownStarted`, listing the addresses that acknowledged. The reply is sent when every
 * worker has answered or when `server.aggregationTimeout` expires, after which the actor stops.
 *
 * @param force forwarded in the `ShutdownWorker` message sent to each worker
 */
class ShutdownAggregator(force: Boolean) extends Actor with Logging {
  import concurrent.duration._
  private val timeout: FiniteDuration = CommonSettings(context.system).server.aggregationTimeout
  import messages.{Distribute, Timeout}
  override def receive: Receive = {
    case Distribute(workers, _, origSender) if workers.isEmpty =>
      logger.warn(s"Received shutdown message, but no requested workers found.")
      origSender ! WorkerShutdownStarted(Nil)
      // Nothing to wait for: stop instead of lingering as an orphaned child.
      context.stop(self)
    case Distribute(workers, _, origSender) =>
      workers.foreach(_ ! ShutdownWorker(Nil, force))
      val cancellable = context.system.scheduler.scheduleOnce(timeout, self, Timeout)(context.system.dispatcher)
      context.become(waitForResponses(workers.map(util.workerAddress), origSender, Nil, cancellable))
  }

  /**
   * @param workers       addresses of workers that have not acknowledged yet
   * @param origSender    requester to reply to
   * @param shutdown      addresses that acknowledged so far
   * @param timeoutCancel handle of the pending `Timeout`
   */
  private def waitForResponses(workers: Set[String], origSender: ActorRef, shutdown: List[String], timeoutCancel: Cancellable): Receive = {
    case Timeout =>
      logger.debug("Timeout before workers responded.")
      origSender ! WorkerShutdownStarted(shutdown)
      context.stop(self)
    case WorkerShutdownStarted(_) =>
      val address = workerAddress(sender())
      val remaining = workers - address
      if (remaining.isEmpty) {
        logger.debug(s"Last worker ($address) shutdown.")
        origSender ! WorkerShutdownStarted(address :: shutdown)
        timeoutCancel.cancel()
        // context.stop (not PoisonPill): a Timeout already queued must not produce a second reply.
        context.stop(self)
      } else {
        logger.debug(s"Worker ($address) shutdown.")
        context.become(waitForResponses(remaining, origSender, address :: shutdown, timeoutCancel))
      }
  }
}

/**
 * One-shot actor that sends `message` to a set of workers and aggregates their replies.
 *
 * It expects a single `Distribute` message. Every listed worker is sent `ToWorker(path, message)`
 * via `sendChunked`. The subclass's `waitForResponses` behavior then takes over, with a `Timeout`
 * scheduled after `server.aggregationTimeout`. When there are no workers, `createEmpty` is sent to
 * the original sender right away and the actor stops.
 *
 * @param message        payload delivered to every worker
 * @param shouldNotChunk overrides `MaxSizeMessageSequenceSender.shouldNotChunk`
 */
abstract class Aggregator(message: AnyRef, override val shouldNotChunk: Boolean) extends Actor with Logging with MaxSizeMessageSequenceSender {
  import concurrent.duration._
  private val timeout: FiniteDuration = CommonSettings(context.system).server.aggregationTimeout
  import messages.{Distribute, Timeout, ToWorker}
  override def receive: Receive = {
    case Distribute(workers, _, origSender) if workers.isEmpty =>
      logger.warn(s"Received message, but no workers available.")
      origSender ! createEmpty
      // Nothing to wait for: stop instead of lingering forever.
      context.stop(self)
    case Distribute(workers, path, origSender) =>
      logger.debug(s"Forwarding message to all (${workers.size}) workers: $workers.")
      workers.foreach(sendChunked(_, ToWorker(path, message)))
      val cancellable = context.system.scheduler.scheduleOnce(timeout, self, Timeout)(context.system.dispatcher)
      context.become(waitForResponses(workers.map(util.workerAddress), origSender, cancellable))
  }

  /**
   * Behavior while waiting for the workers (given by address) to answer.
   * The implementations in this file reply to `origSender` and then stop.
   */
  protected def waitForResponses(workers: Set[String], origSender: ActorRef, timeoutCancel: Cancellable): Receive

  /** Reply sent when there is nothing to aggregate. */
  protected def createEmpty: AnyRef
}

/**
 * Aggregator that asks every worker for an optional result and replies with the first non-empty
 * one, or with an empty result if no worker has one.
 *
 * It replies exactly once: as soon as a worker answers with `Some`, or otherwise when the last
 * worker has answered or the timeout fires. After a hit it keeps consuming answers only so that
 * it can stop as soon as every worker has replied.
 *
 * @tparam T type of the looked-up value
 */
abstract class FindOneAggregator[T](message: AnyRef, shouldNotChunk: Boolean) extends Aggregator(message, shouldNotChunk) {

  import WorkerManager.messages.Found
  import messages.Timeout

  override protected def waitForResponses(workers: Set[String], ref: ActorRef, timeoutCancel: Cancellable): Receive =
    waitForResponses(workers, ref, timeoutCancel, found = false)

  /**
   * @param workers       addresses of workers that have not answered yet
   * @param origSender    requester to reply to
   * @param timeoutCancel handle of the pending `Timeout`
   * @param found         whether a non-empty result has already been sent to `origSender`
   */
  protected def waitForResponses(workers: Set[String], origSender: ActorRef, timeoutCancel: Cancellable, found: Boolean): Receive = {
    case Timeout =>
      logger.debug("Timeout before workers responded.")
      if (!found)
        origSender ! createEmpty
      context.stop(self)
    case Found(opt: Option[T]) =>
      val address = workerAddress(sender())
      val remaining = workers - address
      if (remaining.isEmpty) {
        logger.debug(s"Last response from $address.")
        if (!found)
          origSender ! create(opt)
        timeoutCancel.cancel()
        // context.stop (not PoisonPill): a Timeout already queued must not be processed.
        context.stop(self)
      } else {
        logger.debug(s"Response ${workers.size} from $address.")
        // Reply on the first hit only; later hits are ignored but still shrink `remaining`.
        val firstHit = !found && opt.isDefined
        if (firstHit)
          origSender ! create(opt)
        context.become(waitForResponses(remaining, origSender, timeoutCancel, found || firstHit))
      }
  }
  protected def create(arg: Option[T]): Found[T]
  override protected def createEmpty: AnyRef = create(None)
}

/**
 * Aggregator that asks every worker for a sequence of results and replies once with their
 * concatenation. The reply is sent when the last worker has answered or, with whatever arrived
 * so far, when the timeout fires.
 *
 * @tparam T element type of the results
 */
abstract class FindAllAggregator[T](message: AnyRef, shouldNotChunk: Boolean) extends Aggregator(message, shouldNotChunk) {

  import WorkerManager.messages.FoundAll
  import messages.Timeout

  override protected def waitForResponses(workers: Set[String], origSender: ActorRef, timeoutCancel: Cancellable): Receive =
    waitForResponses(workers, origSender, timeoutCancel, Seq())

  /**
   * @param workers       addresses of workers that have not answered yet
   * @param origSender    requester to reply to
   * @param timeoutCancel handle of the pending `Timeout`
   * @param all           results collected so far
   */
  protected def waitForResponses(workers: Set[String], origSender: ActorRef, timeoutCancel: Cancellable, all: Seq[T]): Receive = {
    case Timeout =>
      logger.debug("Timeout before workers responded.")
      origSender ! create(all)
      context.stop(self)
    case FoundAll(result: Seq[T]) =>
      val address = workerAddress(sender())
      val remaining = workers - address
      if (remaining.isEmpty) {
        logger.debug(s"Last response from $address.")
        origSender ! create(all ++ result)
        timeoutCancel.cancel()
        // context.stop (not PoisonPill): a Timeout already queued must not produce a second reply.
        context.stop(self)
      } else {
        logger.debug(s"Response ${workers.size} from $address.")
        context.become(waitForResponses(remaining, origSender, timeoutCancel, all ++ result))
      }
  }
  protected def create(result: Seq[T]): FoundAll[T]
  override protected def createEmpty: AnyRef = create(Seq())
}
