package com.xebialabs.deployit.engine.tasker.distribution

import akka.actor.{Actor, ActorRef, ActorRefFactory, Cancellable, PoisonPill, Props}
import com.xebialabs.deployit.engine.api.distribution.TaskExecutionWorkerRepository
import com.xebialabs.deployit.engine.api.execution.{FetchMode, TaskWithBlock}
import com.xebialabs.deployit.engine.spi.execution.IgnoreOnRollbackExecutionStateListener
import com.xebialabs.deployit.engine.tasker.TaskManagingActor.messages.Recovered
import com.xebialabs.deployit.engine.tasker._
import com.xebialabs.deployit.engine.tasker.distribution.TaskDistributor.messages._
import com.xebialabs.deployit.engine.tasker.distribution.WorkerManager.messages._
import com.xebialabs.deployit.engine.tasker.distribution.messages.Timeout
import com.xebialabs.deployit.engine.tasker.distribution.util.workerAddress
import com.xebialabs.deployit.engine.tasker.repository.CrudTaskRepository
import grizzled.slf4j.Logging
import org.springframework.security.core.Authentication

/** Factory, name and message protocol of the [[TaskDistributor]] actor. */
object TaskDistributor {

  /**
   * Props for a [[TaskDistributor]].
   *
   * @param workerManager  actor that the distributor forwards `SelectOne` / `Publish` requests to
   * @param shouldNotChunk flag handed to the aggregators the distributor creates; defaults to `true`
   */
  def props(workerManager: ActorRef, shouldNotChunk: Boolean = true): Props =
    Props(new TaskDistributor(workerManager, shouldNotChunk))

  /** Name associated with the distributor actor. */
  def name: String = "task-distributor"

  object messages {

    /**
     * Request to create a task actor for `specification`.
     * `workerId` defaults to `null` and can be set through [[withWorkerId]].
     */
    case class CreateTaskActor(specification: TaskSpecification, workerId: Integer = null) extends WorkerAware {
      type Self = CreateTaskActor

      override def withWorkerId(workerId: Integer): CreateTaskActor = copy(workerId = workerId)
    }

    /** Reply carrying the newly created task actor. */
    case class TaskActorCreated(taskActor: ActorRef)

    /** Look up a single task by id. */
    case class RetrieveTask(taskId: TaskId)

    /** Reply to [[RetrieveTask]] / [[AssignTask]]; empty when the task is not in the registry. */
    case class TaskFound(task: Option[Task]) extends Found(task)

    /** Look up all tasks using the given fetch mode. */
    case class RetrieveAllTasks(fetchMode: FetchMode)

    /** Reply to [[RetrieveAllTasks]]. */
    case class TasksFound(tasks: Seq[TaskWithBlock]) extends FoundAll(tasks)

    /** Set `owner` as the owner of the task `taskId`. */
    case class AssignTask(taskId: TaskId, owner: Authentication)

    /** Disable the rollback-ignoring execution listeners of the task `taskId`. */
    case class PrepareRollbackTask(taskId: TaskId)

    /** Reply to [[PrepareRollbackTask]]. */
    case class RollbackPrepared()

    /** Ask for tasks that are in the task registry but missing from the task repository. */
    case class ReRegisterGhostTasks()

    /** Reply to [[ReRegisterGhostTasks]] with the ghost tasks found. */
    case class GhostTasksReRegistered(tasks: List[TaskWithBlock])
  }

}

/**
 * Front actor for task requests.
 *
 * It hands `CreateTaskActor` requests to the worker manager as a `SelectOne` targeting the tasks manager.
 * For `RetrieveAllTasks`, it creates a [[RetrieveAllTasksAggregator]] child and hands that to the worker
 * manager as a `Publish`. Both are forwarded, so the original sender is preserved. Other messages are not
 * handled by this receive.
 *
 * @param workerManager  actor the requests are forwarded to
 * @param shouldNotChunk flag handed to each aggregator created here
 */
class TaskDistributor(workerManager: ActorRef, shouldNotChunk: Boolean) extends Actor with Logging {

  override def receive: Receive = handleCreation orElse forwardMessage

  /** Routes task-actor creation to a single tasks manager. */
  private def handleCreation: Receive = {
    case creation: CreateTaskActor =>
      val selection = SelectOne(TasksManager.name, creation)
      workerManager forward selection
  }

  /** Spawns an aggregator per `RetrieveAllTasks` request and publishes it. */
  private def forwardMessage: Receive = {
    case RetrieveAllTasks(mode) =>
      val aggregatorProps = Props(classOf[RetrieveAllTasksAggregator], mode, shouldNotChunk)
      val aggregator = context.actorOf(aggregatorProps)
      publish(aggregator)
  }

  /** Forwards a `Publish` of `aggregator` for the tasks manager to the worker manager. */
  private def publish(aggregator: ActorRef): Unit =
    workerManager forward Publish(TasksManager.name, aggregator)
}

/**
 * Collects the task lists returned for a `RetrieveAllTasks(fetchMode)` request.
 *
 * While waiting, chunked message sequences (`receiveChunks`) are tried before the inherited response
 * handling. `create` wraps the combined sequence in a `TasksFound`.
 */
class RetrieveAllTasksAggregator(fetchMode: FetchMode, shouldNotChunk: Boolean)
  extends FindAllAggregator[TaskWithBlock](RetrieveAllTasks(fetchMode), shouldNotChunk) with MessageSequenceReceiver {

  override protected def create(arg: Seq[TaskWithBlock]): FoundAll[TaskWithBlock] =
    TasksFound(tasks = arg)

  override def waitForResponses(workers: Set[String], origSender: ActorRef, timeoutCancel: Cancellable,
                                all: Seq[TaskWithBlock]): Receive = {
    // Chunk reassembly takes precedence over the regular response handling.
    val chunkHandler = receiveChunks
    val responseHandler = super.waitForResponses(workers, origSender, timeoutCancel, all)
    chunkHandler orElse responseHandler
  }
}

/** Factory and name of the [[TasksManager]] actor. */
object TasksManager {

  /**
   * Props for a [[TasksManager]].
   *
   * @param taskActorCreator creates the task actors
   * @param taskRegistry     registry tasks are looked up in
   * @param taskRepository   repository used when detecting ghost tasks
   * @param shouldNotChunk   chunking flag; defaults to `true`
   */
  def props(taskActorCreator: TaskActorCreator,
            taskRegistry: TaskRegistryExtension,
            taskRepository: CrudTaskRepository,
            shouldNotChunk: Boolean = true): Props =
    Props(new TasksManager(taskActorCreator, taskRegistry, taskRepository, shouldNotChunk))

  /** Name the distributor targets in its `SelectOne` / `Publish` requests. */
  def name: String = "tasks-manager"
}

/**
 * Actor that serves task requests against a `TaskRegistryExtension`.
 *
 * It creates and recovers task actors, looks tasks up, assigns owners, prepares rollbacks and reports
 * ghost tasks (tasks present in the registry but absent from the task repository). Query replies go
 * through `sendChunked`; incoming chunked sequences are handled by `receiveChunks` first.
 *
 * @param taskActorCreator creates task actors, given this actor's `context` as the `ActorRefFactory`
 * @param taskRegistry     registry that tasks are looked up in
 * @param taskRepository   repository whose task ids are compared against the registry to find ghost tasks
 * @param shouldNotChunk   overrides the chunking flag declared by the mixed-in message-sequence traits
 */
class TasksManager(taskActorCreator: TaskActorCreator, taskRegistry: TaskRegistryExtension,
                   taskRepository: CrudTaskRepository, override val shouldNotChunk: Boolean)
  extends Actor with Logging with MaxSizeMessageSequenceSender with MessageSequenceReceiver {

  def receive: Receive = receiveChunks orElse {
    case CreateTaskActor(specification, workerId) =>
      info("Received task creation request.")
      val taskActor = taskActorCreator.create(specification, workerId, context)
      sender() ! TaskActorCreated(taskActor)

    case msg@Recovered(task) =>
      taskActorCreator.create(task, context) forward msg

    case RetrieveTask(taskId) =>
      sendChunked(sender(), TaskFound(taskRegistry.getTask(taskId)))

    case RetrieveAllTasks(fetchMode) =>
      sendChunked(sender(), TasksFound(taskRegistry.getTasks(fetchMode)))

    case AssignTask(taskId, owner) =>
      val task = taskRegistry.getTask(taskId)
      task.foreach(_.setOwner(owner))
      sendChunked(sender(), TaskFound(task))

    case PrepareRollbackTask(taskId) =>
      taskRegistry.getTask(taskId) match {
        case Some(task) =>
          // Named so it does not shadow the actor's own `context`.
          val executionContext: TaskExecutionContext = task.getContext
          executionContext.allListeners.foreach {
            case listener: IgnoreOnRollbackExecutionStateListener => listener.disable()
            case _ =>
          }
          sender() ! RollbackPrepared()
        case None =>
          // No reply is sent in this case, as before; log it so the requester's timeout can be diagnosed.
          warn(s"Cannot prepare rollback: task $taskId was not found in the task registry.")
      }

    // ReRegisterGhostTasksAggregator passes the `ReRegisterGhostTasks` companion object itself as its
    // request, while `ReRegisterGhostTasks()` is a case-class instance that the bare stable-identifier
    // pattern would not match. Accept both forms.
    case ReRegisterGhostTasks | ReRegisterGhostTasks() =>
      val allRepositoryTaskIds = taskRepository.tasks().map(_.getId).toSet
      // Ghost tasks: known to the registry but missing from the repository.
      val ghosts: List[TaskWithBlock] = taskRegistry.getTasks(FetchMode.SUMMARY).filterNot(
        task => allRepositoryTaskIds.contains(task.getId)
      ).toList

      sender() ! GhostTasksReRegistered(ghosts)
  }
}

/** Creates task actors on behalf of [[TasksManager]]. */
trait TaskActorCreator {

  /**
   * Creates a task actor for an existing task.
   * [[TasksManager]] uses this for `Recovered` messages, which it then forwards to the new actor.
   *
   * @param task    the task to create an actor for
   * @param factory factory used to create the actor
   * @return the created actor
   */
  def create(task: Task, factory: ActorRefFactory): ActorRef

  /**
   * Creates a task actor from a task specification.
   *
   * @param taskSpecification specification of the task
   * @param workerId          worker id carried by the `CreateTaskActor` request (may be `null`)
   * @param factory           factory used to create the actor
   * @return the created actor
   */
  def create(taskSpecification: TaskSpecification, workerId: Integer, factory: ActorRefFactory): ActorRef
}

/**
 * Collects `GhostTasksReRegistered` replies from the workers and stores each reported ghost task in the
 * active task repository, attributed to the worker that reported it. It then replies to the original
 * sender with the combined list.
 *
 * Note: the request given to `Aggregator` is the `ReRegisterGhostTasks` companion object, not a
 * `ReRegisterGhostTasks()` instance, so receivers must match the object.
 *
 * @param activeTaskRepository repository the ghost tasks are stored in
 * @param workerRepository     used to resolve a worker address to a worker id
 * @param shouldNotChunk       chunking flag passed to `Aggregator`
 */
class ReRegisterGhostTasksAggregator(activeTaskRepository: CrudTaskRepository,
                                     workerRepository: TaskExecutionWorkerRepository, shouldNotChunk: Boolean)
  extends Aggregator(ReRegisterGhostTasks, shouldNotChunk) {

  override protected def createEmpty: AnyRef = GhostTasksReRegistered(Nil)

  /** Starts collecting replies with nothing reported yet. */
  override def waitForResponses(workers: Set[String], origSender: ActorRef, timeoutCancel: Cancellable): Receive =
    waitForResponses(workers, origSender, timeoutCancel, Map[String, List[TaskWithBlock]]())

  /**
   * Receive behaviour while replies are outstanding.
   *
   * On `Timeout`, or on the reply from the last outstanding worker, the ghost tasks collected so far are
   * restored and reported to `origSender`, and this actor stops itself.
   *
   * @param workers       addresses of the workers that have not replied yet
   * @param origSender    actor that receives the combined result
   * @param timeoutCancel cancelled when the last reply arrives
   * @param taskMap       ghost tasks reported so far, keyed by worker address
   */
  protected def waitForResponses(workers: Set[String], origSender: ActorRef, timeoutCancel: Cancellable,
                                 taskMap: Map[String, List[TaskWithBlock]]): Receive = {
    case Timeout =>
      info("Timeout fetching ghost tasks before workers responded.")
      restoreTasks(taskMap)
      origSender ! create(taskMap)
      self ! PoisonPill
    case GhostTasksReRegistered(result: List[TaskWithBlock]) if workers.size == 1 =>
      timeoutCancel.cancel()
      val address: String = workerAddress(sender())
      info(s"Last response from $address.")
      val allTasks = taskMap + (address -> result)
      restoreTasks(allTasks)
      origSender ! create(allTasks)
      self ! PoisonPill
    case GhostTasksReRegistered(result: List[TaskWithBlock]) =>
      val address = workerAddress(sender())
      info(s"Response ${workers.size} from $address.")
      context.become(
        waitForResponses(workers.filterNot(_ equals address), origSender, timeoutCancel, taskMap + (address -> result))
      )
  }

  /**
   * Stores every ghost task in the active task repository, attributed to the worker that reported it.
   *
   * @param taskMap ghost tasks keyed by the reporting worker's address
   * @throws IllegalStateException if a worker that reported at least one task is unknown to the worker repository
   */
  private def restoreTasks(taskMap: Map[String, List[TaskWithBlock]]): Unit =
    taskMap.foreach { case (address, tasks) =>
      // Resolve the worker once per address rather than once per task. The lookup is skipped when the
      // worker reported nothing, so an unknown worker with no ghost tasks is still not an error.
      if (tasks.nonEmpty) {
        val workerId = workerIdFor(address)
        tasks.foreach { t =>
          info(s"Restoring ghost task ${t.getId}.")
          activeTaskRepository.store(t.getId, t.getDescription, t.getOwner, workerId, t.getMetadata)
        }
      }
    }

  /** Looks up the id of the worker registered at `address`, failing if there is none. */
  private def workerIdFor(address: String) =
    workerRepository.getWorkerByAddress(address)
      .getOrElse(throw new IllegalStateException(s"Unknown worker address: $address.")).id

  /** Flattens the per-worker ghost tasks into a single reply. */
  protected def create(arg: Map[String, List[TaskWithBlock]]): GhostTasksReRegistered = GhostTasksReRegistered(arg.values.flatten.toList)

}