package com.xebialabs.deployit.engine.tasker

import java.util.Objects

import akka.actor._
import akka.pattern._
import akka.util.Timeout
import com.github.nscala_time.time.Imports
import com.xebialabs.deployit.engine.api.distribution.TaskExecutionWorkerRepository
import com.xebialabs.deployit.engine.api.execution.{FetchMode, TaskWithBlock}
import com.xebialabs.deployit.engine.tasker.TaskManagingActor.messages._
import com.xebialabs.deployit.engine.tasker.distribution.TaskDistributor.messages._
import com.xebialabs.deployit.engine.tasker.distribution.WorkerManager.messages.Publish
import com.xebialabs.deployit.engine.tasker.distribution._
import com.xebialabs.deployit.engine.tasker.messages._
import com.xebialabs.deployit.engine.tasker.repository.{ActiveTaskRepository, PendingTask, PendingTaskRepository}
import com.xebialabs.xlplatform.settings.CommonSettings
import grizzled.slf4j.Logging
import org.joda.{time => jt}
import org.springframework.security.core.Authentication

import scala.collection.JavaConverters._
import scala.concurrent.duration._
import scala.concurrent.{Await, TimeoutException}
import scala.language.postfixOps
import scala.util.{Failure, Success, Try}

/**
 * [[IEngine]] implementation that routes task operations either to the pending-task
 * repository / task scheduler (for tasks that are still pending) or to the task actors
 * (for every other task).
 *
 * On construction two actors are created in `system`: a `TaskDistributor` and a
 * `TaskScheduler`.
 *
 * @param taskRepository        repository of active tasks
 * @param pendingTaskRepository repository of pending tasks; consulted first by most operations
 * @param workerRepository      passed on to the ghost-task re-registration aggregator
 * @param taskQueueService      passed on to the `TaskScheduler`
 * @param system                actor system in which all actors of this engine are created
 * @param workerManager         actor the distributor is built on and that receives `Publish` requests
 * @param taskFinalizer         performs cancel and archive operations
 */
class TaskExecutionEngine(val taskRepository: ActiveTaskRepository,
                          val pendingTaskRepository: PendingTaskRepository,
                          workerRepository: TaskExecutionWorkerRepository,
                          val taskQueueService: TaskQueueService,
                          system: ActorSystem,
                          workerManager: ActorRef,
                          taskFinalizer: TaskFinalizer) extends IEngine with TaskActorSupport with Logging {

  implicit val actorSystem: ActorSystem = system
  protected val commonSettings: CommonSettings = CommonSettings(system)
  protected val distributor: ActorRef = system.actorOf(TaskDistributor.props(workerManager, commonSettings.inProcessTaskEngine), TaskDistributor.name)
  protected val taskScheduler: ActorRef = system.actorOf(TaskScheduler.props(taskRepository, pendingTaskRepository, taskQueueService, distributor), TaskScheduler.name)

  /**
   * Cancels a task. A pending task is cancelled through the task scheduler, any other
   * task through its task actor.
   *
   * @param taskId id of the task to cancel
   * @param force  converted to a run mode with `RunMode.withForce`
   */
  override def cancel(taskId: String, force: Boolean = false): Unit = {
    doWithTaskSummary(taskId) {
      taskFinalizer.cancelPendingTask(taskScheduler, taskId, RunMode.withForce(force))
    } {
      taskFinalizer.cancel(lookupTaskActor(taskId), taskId, RunMode.withForce(force))
    }
  }

  /**
   * Archives a non-pending task through the task finalizer.
   *
   * @param taskId id of the task to archive
   * @throws TaskerException if the task is still pending
   */
  override def archive(taskId: String): Unit = {
    doWithTaskSummary(taskId) {
      throw new TaskerException(s"Pending Task [$taskId] cannot be archived.")
    } {
      taskFinalizer.archive(lookupTaskActor(taskId), taskId)
    }
  }

  /**
   * Loads the pending task with its full specification and runs `doIfPending` on it;
   * runs `doIfActive` when no pending task with this id exists.
   */
  private def doWithTask(taskId: TaskId)(doIfPending: PendingTask => Unit)(doIfActive: => Unit): Unit =
    pendingTaskRepository.task(taskId, loadFullSpec = true).fold(doIfActive)(doIfPending)

  /**
   * Same dispatch as [[doWithTask]], but the pending task is looked up without
   * `loadFullSpec = true` and is not handed to the pending branch.
   */
  private def doWithTaskSummary(taskId: TaskId)(doIfPending: => Unit)(doIfActive: => Unit): Unit =
    if (pendingTaskRepository.task(taskId).isDefined) doIfPending else doIfActive

  /**
   * Sends a step-modification message to the task actor and blocks for the reply
   * (at most `taskerSettings.askTimeout`).
   *
   * @throws TaskerException  when the reply is a `StepModificationError` or `PathsNotFound`
   * @throws RuntimeException when the reply is an `ActorNotFound`
   */
  private def manageStep(taskId: String, message: ModifySteps): Unit = {
    implicit val timeout: Timeout = taskerSettings.askTimeout

    val reply = Await.result(lookupTaskActor(taskId) ? message, timeout.duration)
    reply match {
      case ex: StepModificationError => throw new TaskerException(ex.msg)
      case PathsNotFound(paths) => throw new TaskerException(s"Cannot find a path to add a pause step at the position $paths")
      case _: ActorNotFound => throw new RuntimeException("Probably connection to satellite is lost")
      case _ => // any other reply is treated as success
    }
  }

  /** Maps a plain step number onto the block path `Root / 1 / 1 / oldPosition`. */
  private[this] def int2BlockPath(oldPosition: Integer): BlockPath = BlockPath.Root / 1 / 1 / oldPosition

  /** Operation applied to a pending task's step block when skipping steps. */
  private[this] def skipOperation: (StepBlock, Seq[BlockPath]) => Unit = (block, steps) => block.skip(steps)

  /** Builds the message sent to a task actor when skipping steps. */
  private[this] def skipStepsMsg: (String, Seq[BlockPath]) => ModifySteps = (id, steps) => SkipSteps(id, steps)

  /** Operation applied to a pending task's step block when un-skipping steps. */
  private[this] def unskipOperation: (StepBlock, Seq[BlockPath]) => Unit = (block, steps) => block.unskip(steps)

  /** Builds the message sent to a task actor when un-skipping steps. */
  private[this] def unskipStepsMsg: (String, Seq[BlockPath]) => ModifySteps = (id, steps) => UnSkipSteps(id, steps)

  /** Skips the steps with the given numbers; numbers are translated with [[int2BlockPath]]. */
  override def skipSteps(taskId: String, stepNrs: java.util.List[Integer]): Unit = {
    modifyStepPaths(taskId, stepNrs.asScala.map(int2BlockPath), skipOperation, skipStepsMsg)
  }

  /** Skips the steps at the given block paths. */
  override def skipStepPaths(taskId: String, stepNrs: java.util.List[BlockPath]): Unit = {
    modifyStepPaths(taskId, stepNrs.asScala, skipOperation, skipStepsMsg)
  }

  /** Un-skips the steps with the given numbers; numbers are translated with [[int2BlockPath]]. */
  override def unskipSteps(taskId: String, stepNrs: java.util.List[Integer]): Unit = {
    modifyStepPaths(taskId, stepNrs.asScala.map(int2BlockPath), unskipOperation, unskipStepsMsg)
  }

  /** Un-skips the steps at the given block paths. */
  override def unskipStepPaths(taskId: String, stepNrs: java.util.List[BlockPath]): Unit = {
    modifyStepPaths(taskId, stepNrs.asScala, unskipOperation, unskipStepsMsg)
  }

  /**
   * Applies a step modification. For a pending task the specification (if loaded) is
   * modified in place with `modifyOperation`; for any other task the message built by
   * `modifyStepsMsg` is sent to the task actor via [[manageStep]].
   */
  private[this] def modifyStepPaths(taskId: String, stepNrs: Seq[BlockPath], modifyOperation: (StepBlock, Seq[BlockPath]) => Unit, modifyStepsMsg: (String, Seq[BlockPath]) => ModifySteps): Unit = {
    doWithTask(taskId) { task =>
      task.spec.foreach { spec => modifyPendingTaskSteps(spec, stepNrs)(modifyOperation(_, stepNrs)) }
    }(manageStep(taskId, modifyStepsMsg(taskId, stepNrs)))
  }

  /** Adds a pause step at the given step number; the number is translated with [[int2BlockPath]]. */
  override def addPauseStep(taskId: String, position: Integer): Unit = addPauseStep(taskId, int2BlockPath(position))

  /**
   * Adds a pause step at the given block path, either directly on the pending task's
   * specification or by asking the task actor.
   */
  override def addPauseStep(taskId: String, position: BlockPath): Unit = {
    doWithTask(taskId) { task =>
      task.spec.foreach { spec => modifyPendingTaskSteps(spec, Seq(position))(_.addPause(position)) }
    }(manageStep(taskId, AddPauseStep(taskId, position)))
  }

  /**
   * Runs `tryOperation` on the step block that contains the given paths and then stores
   * the specification back into the pending-task repository.
   *
   * A failing operation is logged and rethrown. A missing block, or a block that is not
   * a `StepBlock`, is only logged.
   *
   * NOTE(review): only ONE parent block (the first element of the set of parent paths)
   * is processed; paths belonging to other parent blocks are not visited here. Also the
   * specification is persisted even when the path was not found. Both look unintended
   * but are kept as-is - confirm against the `StepBlock` operations before changing.
   */
  private def modifyPendingTaskSteps[T](taskSpecification: TaskSpecification, blockPaths: Seq[BlockPath])(tryOperation: StepBlock => T): Unit = {
    def findBlockByPath(path: BlockPath): Option[Block] = taskSpecification.getBlock.getBlock(path.tail)

    // Logs the failure before propagating it to the caller unchanged.
    def doTry(block: StepBlock): Unit = {
      Try(tryOperation(block)) match {
        case Success(_) =>
        case Failure(ex) =>
          logger.error(StepModificationError(ex.getMessage).msg)
          throw ex
      }
    }

    def doForSingleBlockPath(blockPath: BlockPath): Unit = {
      findBlockByPath(blockPath) match {
        case Some(sb: StepBlock) => doTry(sb)
        case _ => logger.error(PathsNotFound(Seq(blockPath)).msg)
      }
    }

    // Parent path of each step path, de-duplicated; only the first one is used.
    val parentPaths = blockPaths.map(_.init).toSet[BlockPath]
    parentPaths.headOption match {
      case Some(parentPath) => doForSingleBlockPath(parentPath)
      case None => logger.error(PathsNotFound(blockPaths).msg)
    }
    pendingTaskRepository.update(taskSpecification.getId, taskSpecification)
  }

  /** Sends `Stop` to the task actor; only logs a warning when the task is still pending. */
  override def stop(taskId: String): Unit = sendCommand(taskId, Stop(taskId))

  /** Sends `Abort` to the task actor; only logs a warning when the task is still pending. */
  override def abort(taskId: String): Unit = sendCommand(taskId, Abort(taskId))

  /**
   * Fire-and-forget delivery of `message` to the task actor. For a pending task the
   * command is not delivered and a warning is logged instead.
   */
  private def sendCommand(taskId: TaskId, message: Any): Unit = doWithTask(taskId) { task =>
    logger.warn(s"Wrong command [$message] for task [${taskId}] in state [${task.getState}].")
  }(lookupTaskActor(taskId) ! message)

  /**
   * Starts execution of a task. A pending task is handed to `prepareAndEnqueueTask`;
   * for any other task an `Enqueue` message is sent to the task actor and the reply is
   * awaited. A failure of that request is logged, never thrown.
   */
  override def execute(taskId: String): Unit = {
    doWithTaskSummary(taskId)(prepareAndEnqueueTask(taskId)) {
      implicit val askTimeout: Timeout = taskerSettings.askTimeout
      // Await.result (not Await.ready) so that a failed ask - e.g. the ask timeout, which
      // fires before the slightly longer await timeout - ends up in the Failure branch
      // instead of being reported as a successful enqueue.
      val outcome = Try(Await.result(lookupTaskActor(taskId) ? Enqueue(taskId), askTimeout.duration + 100.millis))
      outcome match {
        case Failure(cause) =>
          logger.warn(s"Enqueue task $taskId failed due to the ${cause.getMessage}")
        case _ =>
          logger.debug(s"Task $taskId enqueued")
      }
    }
  }

  /**
   * Returns the task with the given id: built from the pending task's specification when
   * one is available, otherwise requested from the task actor's parent.
   *
   * Throws the exception produced by `taskNotFoundException` when the actor reply
   * carries no task or the request times out.
   */
  override def retrieve(taskId: String): Task = {
    taskFromPendingTask(taskId) getOrElse {
      askActor[TaskFound](chunkReceiver(lookupTaskActorParent(taskId)), RetrieveTask(taskId))(taskNotFoundOnError(taskId))
        .task.getOrElse(throw taskNotFoundException(taskId))
    }
  }

  /**
   * Builds a [[Task]] from a pending task's specification, copying the state and the
   * scheduled date. `None` when there is no pending task or it carries no specification.
   */
  private def taskFromPendingTask(taskId: String): Option[Task] =
    for {
      pendingTask <- pendingTaskRepository.task(taskId, loadFullSpec = true)
      spec <- pendingTask.spec
    } yield {
      val task = TaskCreator.taskFromSpec(spec, None)
      task.setState(pendingTask.getState)
      task.setScheduledDate(pendingTask.getScheduledDate)
      task
    }

  /**
   * Changes the owner of a task and returns the task. For a pending task the owner is
   * changed in the pending-task repository; otherwise an `AssignTask` request is sent to
   * the task actor's parent and the owner is then updated in the active-task repository.
   */
  override def assign(taskId: String, owner: Authentication): Task = {
    taskFromPendingTask(taskId).map { task =>
      pendingTaskRepository.changeOwner(taskId, owner.getName)
      task.setOwner(owner)
      task
    }.getOrElse {
      val task = askActor[TaskFound](chunkReceiver(lookupTaskActorParent(taskId)), AssignTask(taskId, owner))(taskNotFoundOnError(taskId))
        .task.getOrElse(throw taskNotFoundException(taskId))
      taskRepository.changeOwner(taskId, task.getOwner)
      task
    }
  }

  /**
   * Creates a new `ChunkReceivingForwarder` actor in front of `parent`.
   *
   * NOTE(review): a new actor is created on every call and nothing here stops it -
   * presumably the forwarder stops itself once the reply is complete; verify.
   */
  private def chunkReceiver(parent: ActorSelection) = {
    system.actorOf(Props(new ChunkReceivingForwarder(parent)))
  }

  /**
   * Error handler for actor requests: a `TimeoutException` is translated into the
   * task-not-found exception, everything else is rethrown unchanged.
   */
  private def taskNotFoundOnError[T](taskId: TaskId): PartialFunction[Throwable, T] = {
    case _: TimeoutException => throw taskNotFoundException(taskId)
    case t: Throwable => throw t
  }

  /**
   * Schedules a task for execution at `scheduleAt`.
   *
   * @param taskId     id of the task to schedule
   * @param scheduleAt moment at which the task should run
   * @throws TaskerException if `scheduleAt` lies in the past, or further ahead than
   *                         `Int.MaxValue` ticks of `taskerSettings.tickDuration`
   */
  def schedule(taskId: String, scheduleAt: com.github.nscala_time.time.Imports.DateTime): Unit = {
    import com.github.nscala_time.time.Imports._
    def p(d: DateTime) = d.toString("yyyy-MM-dd HH:mm:ss Z")

    // Read the clock once: re-reading it between the check and the interval computation
    // could make the interval end before its start when scheduleAt is only just ahead.
    val now: DateTime = DateTime.now
    if (scheduleAt.isBefore(now)) {
      throw new TaskerException(s"Cannot schedule a task for the past, date entered was [${p(scheduleAt)}], now is [${p(now)}]")
    }
    val delayMillis: Long = (now to scheduleAt).millis
    val tickMillis: Long = taskerSettings.tickDuration.toMillis
    if (delayMillis > Int.MaxValue.toLong * tickMillis) {
      val time: Imports.DateTime = new DateTime(now.millis.addToCopy(tickMillis * Int.MaxValue))
      throw new TaskerException(s"Cannot schedule task [$taskId] at [${p(scheduleAt)}], because it is too far into the future. Can only schedule to [${p(time)}]")
    }
    doSchedule(taskId, scheduleAt)
  }

  /**
   * Delivers the `Schedule` message: a pending task is first updated in the pending-task
   * repository and then announced to the task scheduler; any other task receives the
   * message on its task actor.
   */
  private def doSchedule(taskId: TaskId, scheduleAt: jt.DateTime): Unit = {
    val message = Schedule(taskId, scheduleAt)
    doWithTaskSummary(taskId) {
      pendingTaskRepository.schedule(taskId, scheduleAt)
      taskScheduler ! message
    }(lookupTaskActor(taskId) ! message)
  }

  /** Stores the specification as a pending task and returns the value produced by the repository. */
  override def register(spec: TaskSpecification): String = {
    pendingTaskRepository.store(spec)
  }

  /**
   * Asks the task actor's parent to prepare a rollback for the task.
   *
   * NOTE(review): `rollbackSpec` is not used by this implementation - confirm that is intended.
   */
  override def prepareRollback(taskId: String, rollbackSpec: TaskSpecification): Unit = {
    val parent = lookupTaskActorParent(taskId)
    askActor[Any](parent, PrepareRollbackTask(taskId))(rethrowOnError)
  }

  /**
   * Lists the tasks of the active and the pending repository, replacing each entry by
   * the version reported by the distributor when one with the same id is found. Tasks
   * reported by the distributor but unknown to both repositories are only logged.
   *
   * @param fetchMode passed on to the distributor; the pending repository is queried
   *                  with `true` exactly when it equals `FetchMode.FULL`
   */
  override def getAllIncompleteTasks(fetchMode: FetchMode = FetchMode.FULL): java.util.List[TaskWithBlock] = {
    val tasks = taskRepository.tasks() ++ pendingTaskRepository.tasks(Objects.equals(fetchMode, FetchMode.FULL)).asScala
    val result = new java.util.ArrayList[TaskWithBlock](tasks.asJava)
    // Position of every known task id in `result`, used to overwrite entries in place.
    val idxs: Map[String, Int] = tasks.map(_.getId).zipWithIndex.toMap
    val details = askDistributor[TasksFound](RetrieveAllTasks(fetchMode))(rethrowOnError).tasks
    details.foreach { t =>
      idxs.get(t.getId) match {
        case Some(idx) => result.set(idx, t)
        case None => logger.warn(s"Found ghost task [${t.getId}]. This task is not registered in the system, but found on a worker.")
      }
    }
    result
  }

  /**
   * Publishes a freshly created `ReRegisterGhostTasksAggregator` through the worker
   * manager, blocks for the `GhostTasksReRegistered` reply (at most
   * `taskerSettings.askTimeout`) and returns the ids of the tasks it contains.
   */
  override def reregisterGhostTasks(): java.util.List[TaskId] = {
    implicit val timeout: Timeout = taskerSettings.askTimeout
    val props = Props(classOf[ReRegisterGhostTasksAggregator], taskRepository, workerRepository, commonSettings.inProcessTaskEngine)
    val response = Await.result(
      workerManager ?
        Publish(TasksManager.name, system.actorOf(props)), timeout.duration).asInstanceOf[GhostTasksReRegistered]
    new java.util.ArrayList[TaskId](response.tasks.map(_.getId).asJava)
  }

}
