package com.xebialabs.deployit.engine.tasker

import java.util

import akka.actor._
import akka.pattern._
import akka.util.Timeout
import com.github.nscala_time.time.Imports
import com.xebialabs.deployit.engine.api.execution.{FetchMode, TaskWithBlock}
import com.xebialabs.deployit.engine.tasker.TaskManagingActor.messages._
import com.xebialabs.deployit.engine.tasker.distribution.TaskDistributor.messages._
import com.xebialabs.deployit.engine.tasker.distribution.WorkerManager.messages.NoWorkersError
import com.xebialabs.deployit.engine.tasker.distribution.{ChunkReceivingForwarder, ChunkSendingForwarder, NoWorkersException, TaskDistributor, TasksManager}
import com.xebialabs.deployit.engine.tasker.messages._
import com.xebialabs.deployit.engine.tasker.repository.ActiveTaskRepository
import com.xebialabs.xlplatform.settings.{CommonSettings, TaskerSettings}
import grizzled.slf4j.Logging
import org.springframework.security.core.Authentication

import scala.collection.JavaConverters._
import scala.concurrent.duration._
import scala.concurrent.{Await, Awaitable, TimeoutException}
import scala.language.postfixOps
import scala.util.{Failure, Try}

/**
 * [[IEngine]] implementation that drives tasks through Akka actors.
 *
 * Per-task operations are routed to the task's actor. The actor's path is built from the worker
 * address that [[ActiveTaskRepository]] holds for the task. Creation of new task actors is
 * delegated to a [[TaskDistributor]] actor, which this engine starts when it is constructed.
 *
 * Request/response operations block the calling thread while waiting for a reply. The wait is
 * bounded by `tasker.askTimeout` from [[CommonSettings]].
 *
 * @param taskRepository repository of active tasks; supplies worker addresses and stores registrations
 * @param system         actor system used to create actors and to resolve actor selections
 * @param workerManager  actor handed to the [[TaskDistributor]] on construction
 * @param taskFinalizer  performs cancel and archive on behalf of this engine
 */
class TaskExecutionEngine(taskRepository: ActiveTaskRepository,
                          system: ActorSystem,
                          workerManager: ActorRef,
                          taskFinalizer: TaskFinalizer) extends IEngine with Logging {

  private val commonSettings: CommonSettings = CommonSettings(system)
  private val taskerSettings: TaskerSettings = commonSettings.tasker

  /** Actor that handles task-actor creation and "retrieve all tasks" requests (see `askDistributor`). */
  private val distributor: ActorRef = system.actorOf(TaskDistributor.props(workerManager, commonSettings.inProcessTaskEngine), TaskDistributor.name)

  /** Cancels the task via the [[TaskFinalizer]]; `force` selects the run mode passed along. */
  override def cancel(taskid: String, force: Boolean = false): Unit = {
    taskFinalizer.cancel(lookupTaskActor(taskid), taskid, RunMode.withForce(force))
  }

  /** Archives the task via the [[TaskFinalizer]]. */
  override def archive(taskid: String): Unit = {
    taskFinalizer.archive(lookupTaskActor(taskid), taskid)
  }

  /** Skips steps addressed by legacy flat step numbers (see `int2BlockPath`). */
  override def skipSteps(taskid: String, stepNrs: util.List[Integer]): Unit = {
    skipStepPaths(taskid, stepNrs.asScala.map(int2BlockPath))
  }

  private[this] def skipStepPaths(taskid: String, stepNrs: Seq[BlockPath]): Unit = {
    manageStep(taskid, SkipSteps(taskid, stepNrs))
  }

  /**
   * Sends a step-modification message to the task actor and blocks for the reply.
   * Error replies are translated into exceptions; any other reply counts as success.
   *
   * @throws TaskerException  when the actor reports a modification error or unknown paths
   * @throws RuntimeException when the reply is an `ActorNotFound`
   */
  private def manageStep(taskid: String, message: ModifySteps): Unit = {
    implicit val timeout: Timeout = taskerSettings.askTimeout

    Await.result(lookupTaskActor(taskid) ? message, timeout.duration) match {
      case ex: StepModificationError => throw new TaskerException(ex.msg)
      case PathsNotFound(paths) => throw new TaskerException(s"Cannot find a path to add a pause step at the position $paths")
      // Only matches when ActorNotFound arrives as a reply *value*; a failed ask throws from Await instead.
      case _: ActorNotFound => throw new RuntimeException("Probably connection to satellite is lost")
      case _ =>
    }
  }

  /** Maps a legacy flat step number `n` onto the block path `Root / 1 / 1 / n`. */
  private[this] def int2BlockPath(oldPosition: Integer): BlockPath = BlockPath.Root / 1 / 1 / oldPosition

  /** Un-skips steps addressed by legacy flat step numbers (see `int2BlockPath`). */
  override def unskipSteps(taskid: String, stepNrs: util.List[Integer]): Unit = {
    unskipStepPaths(taskid, stepNrs.asScala.map(int2BlockPath))
  }

  private[this] def unskipStepPaths(taskid: String, stepNrs: Seq[BlockPath]): Unit = {
    manageStep(taskid, UnSkipSteps(taskid, stepNrs))
  }

  /** Adds a pause step at a legacy flat step number (see `int2BlockPath`). */
  override def addPauseStep(taskid: String, position: Integer): Unit = {
    addPauseStep(taskid, int2BlockPath(position))
  }

  /** Adds a pause step at the given block path. */
  override def addPauseStep(taskid: String, position: BlockPath): Unit = {
    manageStep(taskid, AddPauseStep(taskid, position))
  }

  override def unskipStepPaths(taskid: String, stepNrs: util.List[BlockPath]): Unit = {
    unskipStepPaths(taskid, stepNrs.asScala)
  }

  override def skipStepPaths(taskid: String, stepNrs: util.List[BlockPath]): Unit = {
    skipStepPaths(taskid, stepNrs.asScala)
  }

  /** Sends `Stop` to the task actor (fire-and-forget). */
  override def stop(taskid: String): Unit = {
    lookupTaskActor(taskid) ! Stop(taskid)
  }

  /** Sends `Abort` to the task actor (fire-and-forget). */
  override def abort(taskid: String): Unit = {
    lookupTaskActor(taskid) ! Abort(taskid)
  }

  /**
   * Asks the task actor to enqueue the task and waits for the acknowledgement.
   * Failures (including an unknown task id) are logged as warnings, not thrown.
   */
  override def execute(taskid: String): Unit = {
    implicit val askTimeout: Timeout = taskerSettings.askTimeout
    // Await.result (not Await.ready): Await.ready returns normally for a *failed* future, so an
    // AskTimeoutException or failure reply was previously logged as a successful enqueue.
    // The extra 100 ms lets the ask future fail with its own timeout before Await gives up.
    Try(Await.result(lookupTaskActor(taskid) ? Enqueue(taskid), askTimeout.duration + 100.millis)) match {
      case Failure(cause) => warn(s"Enqueue task $taskid failed due the ${cause.getMessage}")
      case _ => debug(s"Task $taskid enqued")
    }
  }

  /**
   * Fetches the task from the parent of its task actor, via a chunk-receiving forwarder.
   *
   * @throws TaskNotFoundException when the task is unknown, the ask times out, or no task is returned
   */
  override def retrieve(taskid: String): Task = {
    askActor[TaskFound](chunkReceiver(lookupTaskActorParent(taskid)), RetrieveTask(taskid))(taskNotFoundOnError(taskid))
      .task.getOrElse(throw taskNotFoundException(taskid))
  }

  /**
   * Assigns the task to `owner` and records the resulting owner in the repository.
   *
   * @throws TaskNotFoundException when the task is unknown, the ask times out, or no task is returned
   */
  override def assign(taskid: String, owner: Authentication): Task = {
    val task = askActor[TaskFound](chunkReceiver(lookupTaskActorParent(taskid)), AssignTask(taskid, owner))(taskNotFoundOnError(taskid))
      .task.getOrElse(throw taskNotFoundException(taskid))
    taskRepository.changeOwner(taskid, task.getOwner)
    task
  }

  /**
   * Creates a new [[ChunkReceivingForwarder]] in front of `parent`.
   *
   * NOTE(review): a fresh top-level actor is created on every call and is not stopped here.
   * Confirm that ChunkReceivingForwarder stops itself after replying; otherwise these actors leak.
   */
  private def chunkReceiver(parent: ActorSelection) = {
    system.actorOf(Props(new ChunkReceivingForwarder(parent)))
  }

  /**
   * Resolves the selection of the task's actor from its worker address in the repository.
   *
   * @throws TaskNotFoundException when the repository has no worker address for the task
   */
  private[tasker] def lookupTaskActor(taskid: String): ActorSelection = {
    getActorPathFromRepository(taskid)
      .map(system.actorSelection)
      .getOrElse(throw taskNotFoundException(taskid))
  }

  /**
   * Resolves the selection of the parent (`..`) of the task's actor.
   *
   * @throws TaskNotFoundException when the repository has no worker address for the task
   */
  private[tasker] def lookupTaskActorParent(taskid: String): ActorSelection = {
    getActorPathFromRepository(taskid)
      .map(_ / "..")
      .map(system.actorSelection)
      .getOrElse(throw taskNotFoundException(taskid))
  }

  /** Builds `<workerAddress>/user/<TasksManager.name>/<taskId>`, if a worker address is stored. */
  private def getActorPathFromRepository(taskId: String): Option[ActorPath] = {
    taskRepository.workerAddress(taskId)
      .map(wa => ActorPath.fromString(s"$wa/user/${TasksManager.name}/$taskId"))
  }

  /** Error handler that turns timeouts into [[TaskNotFoundException]] and rethrows anything else. */
  private def taskNotFoundOnError[T](taskId: TaskId): PartialFunction[Throwable, T] = {
    case _: TimeoutException => throw taskNotFoundException(taskId)
    case t: Throwable => throw t
  }

  private def taskNotFoundException(taskId: TaskId): TaskNotFoundException = {
    new TaskNotFoundException("registry", taskId)
  }

  /**
   * Schedules the task to start at `scheduleAt`.
   *
   * @throws TaskerException when `scheduleAt` is in the past, or is further away than
   *                         `Int.MaxValue` ticks of `tasker.tickDuration`
   */
  def schedule(taskid: String, scheduleAt: com.github.nscala_time.time.Imports.DateTime): Unit = {
    import com.github.nscala_time.time.Imports._
    def p(d: DateTime) = d.toString("yyyy-MM-dd HH:mm:ss Z")

    if (scheduleAt.isBeforeNow) {
      throw new TaskerException(s"Cannot schedule a task for the past, date entered was [${p(scheduleAt)}], now is [${p(DateTime.now)}]")
    }
    val delayMillis: Long = (DateTime.now to scheduleAt).millis
    val tickMillis: Long = taskerSettings.tickDuration.toMillis
    // The delay must fit within Int.MaxValue ticks.
    if (delayMillis > Int.MaxValue.toLong * tickMillis) {
      val time: Imports.DateTime = new DateTime(DateTime.now.millis.addToCopy(tickMillis * Int.MaxValue))
      throw new TaskerException(s"Cannot schedule task [$taskid] at [${p(scheduleAt)}], because it is too far into the future. Can only schedule to [${p(time)}]")
    }
    lookupTaskActor(taskid) ! Schedule(taskid, scheduleAt)
  }

  /** Creates a task actor through the distributor, registers it, and returns the new task id. */
  override def register(spec: TaskSpecification): String = {
    doRegister(createTaskActor(spec), spec)
  }

  /**
   * Prepares rollback of `taskid`, then creates and registers the rollback task.
   *
   * The rollback task actor is created through the same parent as the original task actor.
   *
   * @return the id of the registered rollback task
   */
  override def prepareRollbackAndRegister(taskid: String, rollbackSpec: TaskSpecification): String = {
    val parent = lookupTaskActorParent(taskid)
    askActor[Any](parent, PrepareRollbackTask(taskid))(rethrowOnError)
    val forwarder: ActorRef = system.actorOf(Props(new ChunkSendingForwarder(parent, commonSettings.inProcessTaskEngine)))
    doRegister(askActor[TaskActorCreated](forwarder, CreateTaskActor(rollbackSpec))(rethrowOnError).taskActor, rollbackSpec)
  }

  /** Sends `Register()` to the task actor and stores the returned id with the spec and actor path. */
  private def doRegister(taskActor: ActorRef, spec: TaskSpecification): String = {
    val taskId = askActor[Registered](taskActor, Register())(rethrowOnError).taskId
    taskRepository.store(taskId, spec, taskActor.path)
    taskId
  }

  /** Asks the distributor to create a task actor for `specification`. */
  private[this] def createTaskActor(specification: TaskSpecification): ActorRef = {
    askDistributor[TaskActorCreated](CreateTaskActor(specification))(rethrowOnError).taskActor
  }

  /**
   * Returns all tasks from the repository, in repository order.
   *
   * Each entry is replaced by the detailed version reported by the distributor when one with a
   * matching id is available.
   */
  override def getAllIncompleteTasks(fetchMode: FetchMode = FetchMode.FULL): util.List[TaskWithBlock] = {
    val tasks = taskRepository.tasks()
    val result = new util.ArrayList[TaskWithBlock](tasks.asJava)
    val idxs: Map[String, Int] = tasks.map(_.getId).zipWithIndex.toMap
    val details = askDistributor[TasksFound](RetrieveAllTasks(fetchMode))(rethrowOnError).tasks
    details.foreach { t =>
      idxs.get(t.getId).foreach(result.set(_, t))
    }
    result
  }

  /**
   * Asks the distributor and blocks for the reply.
   *
   * A `NoWorkersError` reply becomes a [[NoWorkersException]]. Despite its name, `onTimeout`
   * receives every Throwable raised in the try, including that exception.
   * The reply is cast to `T` unchecked, because `T` is erased at runtime.
   */
  private[this] def askDistributor[T](msg: Any)(onTimeout: PartialFunction[Throwable, T]): T = {
    implicit val timeout: Timeout = taskerSettings.askTimeout
    try {
      Await.result(distributor ? msg, timeout.duration) match {
        case NoWorkersError =>
          throw NoWorkersException("There are currently no workers available, tasks cannot be executed. " +
            "Please contact your system administrator to correct the system.")
        case m => m.asInstanceOf[T]
      }
    } catch onTimeout
  }

  private[this] def askActor[T](actor: ActorRef, msg: Any)(onTimeout: PartialFunction[Throwable, T]): T =
    doAsk { implicit t: Timeout => actor ? msg }(onTimeout)

  private[this] def askActor[T](actor: ActorSelection, msg: Any)(onTimeout: PartialFunction[Throwable, T]): T =
    doAsk { implicit t: Timeout => actor ? msg }(onTimeout)

  /**
   * Sends via `send` using the configured ask timeout and blocks for the reply.
   * The reply is cast to `T` unchecked; failures are passed to `onTimeout`.
   */
  private[this] def doAsk[T](send: Timeout => Awaitable[Any])(onTimeout: PartialFunction[Throwable, T]): T = {
    val timeout: Timeout = taskerSettings.askTimeout
    try {
      Await.result(send(timeout), timeout.duration).asInstanceOf[T]
    } catch onTimeout
  }

  /** Error handler that simply rethrows every Throwable. */
  def rethrowOnError[T]: PartialFunction[Throwable, T] = {
    case t: Throwable => throw t
  }
}
