package com.xebialabs.deployit.engine.tasker

import akka.actor.Status.Failure
import akka.actor._
import com.github.nscala_time.time.Imports._
import com.xebialabs.deployit.engine.api.execution.TaskExecutionState._
import com.xebialabs.deployit.engine.api.execution.{BlockExecutionState, TaskExecutionState}
import com.xebialabs.deployit.engine.tasker.ArchiveActor.messages.SendToArchive
import com.xebialabs.deployit.engine.tasker.BlockExecutingActor.{BlockDone, BlockStateChanged}
import com.xebialabs.deployit.engine.tasker.StateChangeEventListenerActor.{StepStateEvent, TaskStateEvent}
import com.xebialabs.deployit.engine.tasker.TaskManagingActor.messages.{ArchiveTask, Cancel, GetTask, Recovered, Register, Schedule, ScheduledStart}
import com.xebialabs.deployit.engine.tasker.messages._
import org.joda.time.DateTime

import scala.concurrent.duration._

object TaskManagingActor {

  /** `Props` used to create a [[TaskManagingActor]]. */
  def props: Props = Props(classOf[TaskManagingActor])

  /** Returns the [[messages]] object holding this actor's protocol. */
  def getMessages: messages.type = messages

  /** Commands understood by [[TaskManagingActor]]. */
  object messages {

    /** Hands a new task to the actor; it is registered and moved to PENDING. */
    case class Register(task: Task)

    /** Hands an already existing task to the actor; it resumes according to the task's current state. */
    case class Recovered(task: Task)

    /** Asks for the managed task instance. */
    case class GetTask(taskId: TaskId)

    /** Requests the task to be queued at `scheduleAt`. */
    case class Schedule(taskId: TaskId, scheduleAt: DateTime)

    /** Starts a scheduled task right away instead of waiting for its schedule. */
    case class ScheduledStart(taskId: TaskId)

    /** Archives a finished task through `archiveActor`; `notificationActor` receives the outcome. */
    case class ArchiveTask(taskId: TaskId, archiveActor: ActorRef, notificationActor: ActorRef)

    /** Cancels the task; `notificationActor` receives the outcome. */
    case class Cancel(taskId: TaskId, archiveActor: ActorRef, notificationActor: ActorRef)
  }
}

class TaskManagingActor extends BaseExecutionActor with Stash {
  import context._

  /** Task states that end an execution run (succeeded, failed, aborted or stopped). */
  val terminationStates = Set(FAILED, ABORTED, EXECUTED, STOPPED)

  /** States from which a recovered task can be resumed; a task in any other state is rejected. */
  private[this] val recoverableStates: Set[TaskExecutionState] =
    Set(PENDING, QUEUED, SCHEDULED, EXECUTED, STOPPED, FAILED, ABORTED)

  /**
   * Moves `t` to `newState` and publishes the change through `Task.setTaskStateAndNotify`.
   *
   * @return the same task instance, so the call can be used inline in `become(...)`.
   */
  def updateStateAndNotify(t: Task, newState: TaskExecutionState): Task = {
    debug(s"Sending TaskStateEvent(${t.getState}->$newState) message for [${t.getId}]")
    implicit val system = context.system
    t.setTaskStateAndNotify(newState)
    t
  }

  /** Initial behaviour: waits for a new task (`Register`) or an existing task (`Recovered`). */
  def receive: Actor.Receive = ReceiveWithMdc()({
    case Register(task) =>
      info(s"Received [Register] message for task [${task.getId}]")
      registerStateListeners(task)
      become(pending(task.getId, updateStateAndNotify(task, PENDING)))
      TaskRegistryExtension(system).register(task)
      sender() ! Registered(task.getId)
    case Recovered(task) =>
      info(s"Received [Recovered] message for task [${task.getId}]")
      if (recoverableStates.contains(task.getState)) {
        registerStateListeners(task)
        resumeRecoveredTask(task)
        TaskRegistryExtension(system).register(task)
        sender() ! Registered(task.getId)
      } else {
        // Reject before creating listeners or registering: the caller gets exactly one reply and
        // an unrecoverable task is not left in the registry.
        sender() ! Failure(new IllegalStateException(s"Cannot recover a task which is in state [${task.getState}]."))
      }
    case m => sender() ! Failure(new IllegalStateException(s"Wrong command [$m] in for this actor"))
  })

  /** Switches to the behaviour matching a recovered task's state (only called for `recoverableStates`). */
  private[this] def resumeRecoveredTask(task: Task): Unit = task.getState match {
    case PENDING | QUEUED => become(pending(task.getId, task))
    case SCHEDULED if task.getScheduledDate.isAfterNow =>
      doSchedule(task.getId, task, task.getScheduledDate, createOrLookupChildForTaskBlock(task))
    case SCHEDULED => doEnqueue(task.getId, task, createOrLookupChildForTaskBlock(task))
    case EXECUTED => become(canBeArchived(task.getId, task))
    case STOPPED | FAILED | ABORTED => become(readyForRestart(task.getId, task, createOrLookupChildForTaskBlock(task)))
    case other => throw new IllegalStateException(s"Cannot recover a task which is in state [$other].")
  }

  /**
   * Creates the children that react to this task's state changes and subscribes them to
   * `TaskStateEvent` and `StepStateEvent` on the system event stream: always a
   * `StateChangeEventListenerActor`, plus a `TaskRecoveryActor` when the task is recoverable.
   */
  def registerStateListeners(task: Task): Unit = {
    def subscribeToStateEvents(listener: ActorRef): Unit = {
      context.system.eventStream.subscribe(listener, classOf[TaskStateEvent])
      context.system.eventStream.subscribe(listener, classOf[StepStateEvent])
    }

    subscribeToStateEvents(createChild(StateChangeEventListenerActor.props(task.getId), "state-listener"))
    if (task.getSpecification.isRecoverable) {
      subscribeToStateEvents(createChild(TaskRecoveryActor.props(task.getId, TaskerSettings(system).recoveryDir), "recovery-listener"))
    }
  }

  /**
   * Number of `TaskStateEventHandled` acknowledgements to await after a state change. It matches
   * the number of listeners created by `registerStateListeners` (presumably one acknowledgement each).
   */
  private[this] def expectedStateHandledEvents(task: Task): Int =
    if (task.getSpecification.isRecoverable) 2 else 1

  /**
   * Time from now until `time`; negative when `time` is in the past. This avoids building a Joda
   * `Interval`, whose constructor throws when the end lies before the start.
   */
  private[this] def calculateDelay(time: DateTime): FiniteDuration =
    FiniteDuration(time.getMillis - System.currentTimeMillis(), MILLISECONDS)

  /** Replies with a `Status.Failure` to a message that is not valid in the task's current state. */
  private[this] def rejectCommand(command: Any, task: Task): Unit =
    sender() ! Failure(new IllegalStateException(s"Wrong command [$command] for task [${task.getId}] in state [${task.getState}]."))

  /** Registered but not yet queued or scheduled. */
  def pending(taskId: TaskId, task: Task): Actor.Receive = ReceiveWithMdc(task) {
    case Enqueue(`taskId`) =>
      doEnqueue(taskId, task, createOrLookupChildForTaskBlock(task))
    case Schedule(`taskId`, scheduleAt) =>
      doSchedule(taskId, task, scheduleAt, createOrLookupChildForTaskBlock(task))
    case Cancel(`taskId`, _, notificationActor) =>
      doCancelWhenPending(taskId, task, notificationActor)
    case GetTask(`taskId`) => sender() ! task
    case m => rejectCommand(m, task)
  }

  /** Marks the task QUEUED and immediately sends itself `Start`. */
  private[this] def doEnqueue(taskId: TaskId, task: Task, blockActor: ActorRef): Unit = {
    info(s"Received [Enqueue] message for task [${task.getId}]")
    updateStateAndNotify(task, QUEUED)
    become(queued(taskId, task, blockActor))
    self ! Start(taskId)
  }

  /** Waiting for the `Start` that `doEnqueue` sent to self. */
  def queued(taskId: TaskId, task: Task, blockActor: ActorRef): Actor.Receive = ReceiveWithMdc(task) {
    case Start(`taskId`) =>
      doStart(taskId, task, blockActor)
    case m => rejectCommand(m, task)
  }

  /** Records the start, marks the task EXECUTING and starts the top-level block actor. */
  private[this] def doStart(taskId: TaskId, task: Task, blockActor: ActorRef): Unit = {
    info(s"Received [Start] message for task [${task.getId}]")
    task.recordStart
    updateStateAndNotify(task, EXECUTING)
    become(executing(taskId, task, blockActor))
    blockActor ! Start(taskId)
  }

  /**
   * Marks the task SCHEDULED for `scheduleAt`. A schedule that is not in the future is enqueued
   * right away; otherwise a timer sends `Enqueue` to self when the time comes.
   */
  private[this] def doSchedule(taskId: TaskId, task: Task, scheduleAt: DateTime, blockActor: ActorRef): Unit = {
    info(s"Received [Schedule] message for task [${task.getId}]")
    task.setScheduledDate(scheduleAt)
    val scheduledTask = updateStateAndNotify(task, SCHEDULED)
    val delay: FiniteDuration = calculateDelay(scheduleAt)
    info(s"Going to schedule task [$taskId] at [${scheduleAt.toString("yyyy-MM-dd HH:mm:ss Z")}] which is [$delay] from now")
    if (delay.length <= 0) {
      doEnqueue(taskId, task, blockActor)
    } else {
      val scheduleHandle: Cancellable = context.system.scheduler.scheduleOnce(delay, self, Enqueue(taskId))
      become(scheduled(taskId, scheduledTask, scheduleHandle, blockActor))
    }
  }

  /**
   * Cancels a task that never started. It then waits for the state listeners to acknowledge
   * CANCELLED before telling `notificationActor`. Nothing is archived.
   */
  def doCancelWhenPending(taskId: TaskId, task: Task, notificationActor: ActorRef): Unit = {
    info(s"Received [Cancel] message for task [${task.getId}]")
    context.system.eventStream.subscribe(self, classOf[TaskStateEventHandled])
    notifyTaskDone(updateStateAndNotify(task, CANCELLED))
    TaskRegistryExtension(system).deleteTask(task.getId)
    become(awaitCancelHandled(taskId, task, notificationActor, expectedStateHandledEvents(task)))
  }

  /** Counts down CANCELLED acknowledgements; on the last one notifies and leaves the event stream. */
  private[this] def awaitCancelHandled(taskId: TaskId, task: Task, notificationActor: ActorRef, remaining: Int): Actor.Receive = ReceiveWithMdc(task) {
    case TaskStateEventHandled(`taskId`, _, CANCELLED) if remaining == 1 =>
      notificationActor ! Cancelled(taskId)
      context.system.eventStream.unsubscribe(self)
    case TaskStateEventHandled(`taskId`, _, CANCELLED) =>
      become(awaitCancelHandled(taskId, task, notificationActor, remaining - 1))
    case _ =>
  }

  /** Waiting for the schedule timer (or an explicit `Enqueue`/`ScheduledStart`). */
  def scheduled(taskId: TaskId, task: Task, scheduleHandle: Cancellable, blockActor: ActorRef): Actor.Receive = ReceiveWithMdc(task) {
    case Enqueue(`taskId`) =>
      scheduleHandle.cancel()
      doEnqueue(taskId, task, blockActor)
    case Schedule(`taskId`, scheduleAt) =>
      scheduleHandle.cancel()
      doSchedule(taskId, task, scheduleAt, blockActor)
    case ScheduledStart(`taskId`) =>
      // Cancel the timer, otherwise its Enqueue later arrives in a state that rejects it.
      scheduleHandle.cancel()
      doEnqueue(taskId, task, blockActor)
    case Cancel(`taskId`, archiveActor, notificationActor) =>
      scheduleHandle.cancel()
      // A task that has run before (start date set) is archived on cancel; a never-started one is not.
      if (Option(task.getStartDate).isEmpty) doCancelWhenPending(taskId, task, notificationActor)
      else doCancel(taskId, task, archiveActor, notificationActor)
    case GetTask(`taskId`) => sender() ! task
    // Reply instead of throwing: an exception would restart the actor and lose the task's state.
    case m => rejectCommand(m, task)
  }

  /** Mirrors block state changes into the task state; forwards Stop/Abort to the block actor. */
  def executing(taskId: TaskId, task: Task, blockActor: ActorRef): Actor.Receive = ReceiveWithMdc(task) {
    case BlockStateChanged(`taskId`, _, _, state) =>
      state match {
        case BlockExecutionState.PENDING | BlockExecutionState.EXECUTING =>
        case BlockExecutionState.DONE => become(executed(taskId, updateStateAndNotify(task, EXECUTED)))
        case BlockExecutionState.FAILING => become(failing(taskId, updateStateAndNotify(task, FAILING), blockActor))
        case BlockExecutionState.FAILED => become(failed(taskId, updateStateAndNotify(task, FAILED), blockActor))
        case BlockExecutionState.STOPPING => become(stopping(taskId, updateStateAndNotify(task, STOPPING), blockActor))
        case BlockExecutionState.STOPPED => become(stopped(taskId, updateStateAndNotify(task, STOPPED), blockActor))
        case BlockExecutionState.ABORTING => become(aborting(taskId, updateStateAndNotify(task, ABORTING), blockActor))
        case BlockExecutionState.ABORTED => become(aborted(taskId, updateStateAndNotify(task, ABORTED), blockActor))
      }
    case s @ Stop(`taskId`) =>
      debug(s"Received [$s], now stopping execution")
      blockActor ! s
    case m @ Abort(`taskId`) =>
      debug(s"Received [$m], now aborting execution")
      blockActor ! m
    case GetTask(`taskId`) => sender() ! task
    case m => rejectCommand(m, task)
  }

  /** EXECUTED; waits for `BlockDone` before allowing archiving (early `ArchiveTask` is stashed). */
  def executed(taskId: TaskId, task: Task): Actor.Receive = ReceiveWithMdc(task) {
    case BlockDone(`taskId`, _) =>
      task.recordCompletion
      notifyTaskDone(task)
      become(canBeArchived(taskId, task))
      unstashAll()
    case GetTask(`taskId`) => sender() ! task
    case ArchiveTask(`taskId`, _, _) => stash()
    case m => rejectCommand(m, task)
  }

  /** Finished successfully; `ArchiveTask` moves it to DONE and hands it to the archive once acknowledged. */
  def canBeArchived(taskId: TaskId, task: Task): Actor.Receive = ReceiveWithMdc(task) {
    case ArchiveTask(`taskId`, archiveActor: ActorRef, notificationActor: ActorRef) =>
      info(s"Received [Archive] message for task [${task.getId}]")
      context.system.eventStream.subscribe(self, classOf[TaskStateEventHandled])
      val doneTask: Task = updateStateAndNotify(task, DONE)
      TaskRegistryExtension(system).deleteTask(doneTask.getId)
      become(waitForStateHandledThenArchive(taskId, task, DONE, archiveActor, notificationActor, expectedStateHandledEvents(task)))
    case GetTask(`taskId`) => sender() ! task
    case m => rejectCommand(m, task)
  }

  /** Waiting for the archive actor's verdict; on success the actor stops itself. */
  def archiving(taskId: TaskId, task: Task, notificationActor: ActorRef): Actor.Receive = ReceiveWithMdc(task) {
    case Archived(`taskId`) =>
      notifyTaskDone(task)
      task.getState match {
        case DONE => notificationActor ! Archived(taskId)
        case CANCELLED => notificationActor ! Cancelled(taskId)
        case _ =>
      }
      harakiri(s"Done with task [$taskId]")
    case fta @ FailedToArchive(`taskId`, _) =>
      info(s"Task [$taskId] failed to archive, going back to previous state.")
      // The block has already finished, so no BlockDone will arrive any more: go straight to the
      // states that accept follow-up commands (retrying the archive, restart or cancel).
      task.getState match {
        case DONE =>
          updateStateAndNotify(task, EXECUTED)
          become(canBeArchived(taskId, task))
        case CANCELLED =>
          updateStateAndNotify(task, FAILED)
          become(readyForRestart(taskId, task, createOrLookupChildForTaskBlock(task)))
        case _ =>
      }
      notifyTaskDone(task)
      notificationActor ! fta
    case _ =>
  }

  /** A step failed while others may still run; waits for the block to settle. */
  def failing(taskId: TaskId, task: Task, blockActor: ActorRef): Actor.Receive = ReceiveWithMdc(task) {
    case BlockStateChanged(`taskId`, _, _, state) =>
      state match {
        case BlockExecutionState.PENDING | BlockExecutionState.EXECUTING | BlockExecutionState.DONE | BlockExecutionState.FAILING | BlockExecutionState.STOPPING | BlockExecutionState.STOPPED =>
        case BlockExecutionState.FAILED => become(failed(taskId, updateStateAndNotify(task, FAILED), blockActor))
        case BlockExecutionState.ABORTING => become(aborting(taskId, updateStateAndNotify(task, ABORTING), blockActor))
        case BlockExecutionState.ABORTED => become(aborted(taskId, updateStateAndNotify(task, ABORTED), blockActor))
      }
    case s @ Stop(`taskId`) =>
      debug(s"Received [$s], now stopping execution")
      blockActor ! s
    case m @ Abort(`taskId`) =>
      debug(s"Received [$m], now aborting execution")
      blockActor ! m
    case GetTask(`taskId`) => sender() ! task
    case m => rejectCommand(m, task)
  }

  /** FAILED; waits for `BlockDone`, stashing restart/cancel requests until then. */
  def failed(taskId: TaskId, task: Task, blockActor: ActorRef): Actor.Receive = ReceiveWithMdc(task) {
    case BlockDone(`taskId`, _) =>
      task.recordFailure
      task.recordCompletion
      notifyTaskDone(task)
      become(readyForRestart(taskId, task, blockActor))
      unstashAll()
    case Enqueue(`taskId`) | Schedule(`taskId`, _) | Cancel(`taskId`, _, _) => stash()
    case GetTask(`taskId`) => sender() ! task
    case m => rejectCommand(m, task)
  }

  /** Stopped, failed or aborted: can be re-queued, rescheduled or cancelled (archived). */
  def readyForRestart(taskId: TaskId, task: Task, blockActor: ActorRef): Actor.Receive = ReceiveWithMdc(task) {
    case Enqueue(`taskId`) =>
      doEnqueue(taskId, task, blockActor)
    case Schedule(`taskId`, scheduleAt) =>
      doSchedule(taskId, task, scheduleAt, blockActor)
    case Cancel(`taskId`, archiveActor, notificationActor) =>
      doCancel(taskId, task, archiveActor, notificationActor)
    case GetTask(`taskId`) => sender() ! task
    case m => rejectCommand(m, task)
  }

  /** Cancels a task that has run before and archives it once the state change is acknowledged. */
  def doCancel(taskId: TaskId, task: Task, archiveActor: ActorRef, notificationActor: ActorRef): Unit = {
    info(s"Received [Cancel] message for task [${task.getId}]")
    context.system.eventStream.subscribe(self, classOf[TaskStateEventHandled])
    notifyTaskDone(updateStateAndNotify(task, CANCELLED))
    TaskRegistryExtension(system).deleteTask(task.getId)
    become(waitForStateHandledThenArchive(taskId, task, CANCELLED, archiveActor, notificationActor, expectedStateHandledEvents(task)))
  }

  /**
   * Counts down `TaskStateEventHandled` acknowledgements for `state`. After the last one it leaves
   * the event stream and sends the task to `archiveActor`.
   */
  def waitForStateHandledThenArchive(taskId: TaskId, task: Task, state: TaskExecutionState, archiveActor: ActorRef, notificationActor: ActorRef, handledMessages: Int): Actor.Receive = ReceiveWithMdc(task) {
    case TaskStateEventHandled(`taskId`, _, `state`) if handledMessages == 1 =>
      // Unsubscribe so that, if archiving fails, later states do not see other tasks' events.
      context.system.eventStream.unsubscribe(self, classOf[TaskStateEventHandled])
      become(archiving(taskId, task, notificationActor))
      archiveActor ! SendToArchive(task, self)
    case TaskStateEventHandled(`taskId`, _, `state`) =>
      become(waitForStateHandledThenArchive(taskId, task, state, archiveActor, notificationActor, handledMessages - 1))
    case _ =>
  }

  /** A stop was requested; waits for the block to settle. */
  def stopping(taskId: TaskId, task: Task, blockActor: ActorRef): Actor.Receive = ReceiveWithMdc(task) {
    case BlockStateChanged(`taskId`, _, _, state) =>
      state match {
        case BlockExecutionState.PENDING | BlockExecutionState.EXECUTING | BlockExecutionState.STOPPING =>
        case BlockExecutionState.DONE => become(executed(taskId, updateStateAndNotify(task, EXECUTED)))
        case BlockExecutionState.FAILING => become(failing(taskId, updateStateAndNotify(task, FAILING), blockActor))
        case BlockExecutionState.FAILED => become(failed(taskId, updateStateAndNotify(task, FAILED), blockActor))
        case BlockExecutionState.STOPPED => become(stopped(taskId, updateStateAndNotify(task, STOPPED), blockActor))
        case BlockExecutionState.ABORTING => become(aborting(taskId, updateStateAndNotify(task, ABORTING), blockActor))
        case BlockExecutionState.ABORTED => become(aborted(taskId, updateStateAndNotify(task, ABORTED), blockActor))
      }
    case s @ Stop(`taskId`) =>
      debug(s"Received [$s], now stopping execution")
      blockActor ! s
    case m @ Abort(`taskId`) =>
      debug(s"Received [$m], now aborting execution")
      blockActor ! m
    case GetTask(`taskId`) => sender() ! task
    case m => rejectCommand(m, task)
  }

  /** STOPPED; waits for `BlockDone`, stashing restart/cancel requests until then. */
  def stopped(taskId: TaskId, task: Task, blockActor: ActorRef): Actor.Receive = ReceiveWithMdc(task) {
    case BlockDone(`taskId`, _) =>
      task.recordCompletion
      notifyTaskDone(task)
      become(readyForRestart(taskId, task, blockActor))
      unstashAll()
    case Enqueue(`taskId`) | Schedule(`taskId`, _) | Cancel(`taskId`, _, _) => stash()
    case GetTask(`taskId`) => sender() ! task
    case m => rejectCommand(m, task)
  }

  /** An abort was requested; only a block ABORTED state moves the task on. */
  def aborting(taskId: TaskId, task: Task, blockActor: ActorRef): Actor.Receive = ReceiveWithMdc(task) {
    case BlockStateChanged(`taskId`, _, _, state) =>
      state match {
        case BlockExecutionState.PENDING | BlockExecutionState.EXECUTING | BlockExecutionState.DONE | BlockExecutionState.FAILING | BlockExecutionState.STOPPING | BlockExecutionState.STOPPED | BlockExecutionState.FAILED | BlockExecutionState.ABORTING =>
        case BlockExecutionState.ABORTED => become(aborted(taskId, updateStateAndNotify(task, ABORTED), blockActor))
      }
    case GetTask(`taskId`) => sender() ! task
    case m => rejectCommand(m, task)
  }

  /** ABORTED; waits for `BlockDone`, stashing restart/cancel requests until then. */
  def aborted(taskId: TaskId, task: Task, blockActor: ActorRef): Actor.Receive = ReceiveWithMdc(task) {
    case BlockDone(`taskId`, _) =>
      task.recordFailure
      task.recordCompletion
      notifyTaskDone(task)
      become(readyForRestart(taskId, task, blockActor))
      unstashAll()
    case Enqueue(`taskId`) | Schedule(`taskId`, _) | Cancel(`taskId`, _, _) => stash()
    case GetTask(`taskId`) => sender() ! task
    case m => rejectCommand(m, task)
  }

  /** Returns the child executing the task's top-level block, creating it (named by block id) if absent. */
  def createOrLookupChildForTaskBlock(task: Task): ActorRef = child(task.getBlock.getId()) match {
    case Some(ref) => ref
    case None => task.getBlock match {
      case sb: StepBlock => createChild(StepBlockExecutingActor.props(task, sb, task.getContext), sb.getId())
      case b: CompositeBlock => createChild(BlockExecutingActor.props(task, b, task.getContext), b.getId())
    }
  }

  /** Logs the task's current state and publishes `TaskDone(task)` on the system event stream. */
  def notifyTaskDone(task: Task): Unit = {
    info(s"Task [${task.getId}] is completed with state [${task.getState}]")
    context.system.eventStream.publish(TaskDone(task))
  }
}
