package com.xebialabs.deployit.engine.tasker

import com.xebialabs.deployit.engine.api.execution._
import java.util
import org.joda.time.DateTime
import org.joda.time.DateTimeZone.UTC
import collection.convert.wrapAll._
import akka.actor.ActorSystem
import grizzled.slf4j.Logger
import scala.collection.mutable
import com.xebialabs.deployit.engine.tasker.step.PauseStep
import java.io._
import javassist.util.proxy.ProxyObjectInputStream
import com.xebialabs.deployit.engine.tasker.StateChangeEventListenerActor.StepStateEvent
import com.xebialabs.deployit.engine.tasker.StateChangeEventListenerActor.TaskStateEvent
import com.google.common.io.Closeables
import java.util.UUID
import com.xebialabs.deployit.engine.api.execution.{StepExecutionState => SES}
import com.xebialabs.deployit.engine.api.execution.{TaskExecutionState => TES}
import scala.util.{Failure, Success, Try}
import org.springframework.security.core.Authentication

object Task {
  val logger = Logger.apply(classOf[Task])

  /**
   * Reads a serialized [[Task]] back from `file` and prepares it for reuse.
   *
   * The deserialized task is passed through `recovered()` before it is returned. That method maps
   * in-flight task states (EXECUTING, STOPPING, FAILING, ABORTING) to FAILED.
   *
   * The underlying file stream is tracked separately from the object stream. The
   * `ObjectInputStream` constructor reads the stream header and can throw (for example on an empty
   * or truncated file) before any wrapper exists to close the file. Both streams are therefore
   * closed in `finally`, quietly.
   *
   * @param file the recovery file to read
   * @return `Some(task)` on success; `None` when the file cannot be read, references a class that is
   *         not on the classpath, or does not contain a `Task` (the cause is logged at error level)
   */
  def apply(file: File): Option[Task] = {
    logger.info(s"Recovering task [$file]")
    var fileStream: FileInputStream = null
    var is: ObjectInputStream = null
    try {
      fileStream = new FileInputStream(file)
      is = new ProxyObjectInputStream(fileStream)
      val t: Task = is.readObject.asInstanceOf[Task]
      t.recovered()
      Some(t)
    } catch {
      case e: ClassNotFoundException =>
        logger.error(s"Could not find serialized class in recovery file [$file]", e)
        None
      case e: IOException =>
        logger.error(s"Could not read recovery file [$file]", e)
        None
      case e: RuntimeException =>
        // Also covers the ClassCastException thrown when the file holds something other than a Task.
        logger.error(s"Could not read recovery file [$file]", e)
        None
    } finally {
      // Closing the object stream also closes fileStream. The second close is a no-op in that case.
      // It only matters when the ObjectInputStream constructor failed and `is` is still null.
      Try(Closeables.close(is, true))
      Try(Closeables.close(fileStream, true))
    }
  }
}

// Open question: should Task itself be an actor?
/**
 * A task: a tree of [[Block]]s whose leaf step blocks hold [[StepState]]s. It also carries the
 * task-level execution state, owner, failure count, scheduling date and start/completion timestamps.
 *
 * Steps can be addressed in two ways:
 *  - by 1-based step number across the flattened step sequence (see `stepNrToIndex`);
 *  - by [[BlockPath]]. The first path element is dropped (`path.tail`) before lookup in the root
 *    block. NOTE(review): presumably the head identifies the root/task itself; confirm.
 *
 * NOTE(review): the class is `Serializable` and is read back by `Task.apply(file)`. No
 * `@SerialVersionUID` is declared here, so structural changes to this class may make recovery files
 * written by an older build unreadable.
 */
class Task(id: TaskId, spec: TaskSpecification) extends TaskWithSteps with TaskWithBlock with Serializable {
  import Task.logger

  /** Creates a task for `block` with a minimal specification (`new TaskSpecification("", null, block)`). */
  def this(id: TaskId, block: Block) = this(id, new TaskSpecification("", null, block))

  private[this] var startDate: DateTime = null
  private[this] var completionDate: DateTime = null
  private[this] var scheduledDate: DateTime = null
  private[this] var owner = spec.getOwner
  private[this] var failureCount = 0
  private[this] var state = TES.UNREGISTERED
  var block: Block = spec.getBlock
  // A def, not a val: the step sequence is rebuilt from `block` on every access.
  // It therefore reflects structural changes such as inserted pause steps or moved steps.
  private[this] def steps: IndexedSeq[((BlockPath, StepState), Int)] = fetchStepSeq
  val context = new TaskExecutionContext(None, spec)

  // NOTE(review): this leaks `this` from the constructor. Vals declared further down
  // (skippableTaskStates, activeBlockStates) are still null at this point.
  if (spec.getInspectionContext != null) {
    spec.getInspectionContext.registerTask(this)
  }

  /** All steps with their block paths, paired with their 0-based position in the flattened sequence. */
  private[this] def fetchStepSeq: IndexedSeq[((BlockPath, StepState), Int)] = block.getStepsWithPaths().zipWithIndex.toIndexedSeq

  def getSpecification = spec

  /** All steps of the task, in execution order. */
  def getSteps: util.List[StepState] = steps.map(_._1._2).toList

  /**
   * Returns the step with 1-based number `nr`.
   *
   * @throws IllegalArgumentException if `nr` is out of range
   */
  def getStep(nr: Int): StepState = steps(stepNrToIndex(nr))._1._2

  /** Returns the step at `path`; the head of the path is dropped before lookup in the root block. */
  def getStep(path: BlockPath): StepState = block.getStep(path.tail)

  def getState: TaskExecutionState = state

  def getId: String = id

  def getDescription: String = spec.getDescription

  def getStartDate: DateTime = startDate

  def getCompletionDate: DateTime = completionDate

  def getNrSteps: Int = steps.size

  /**
   * The first entry of `getCurrentStepNrs`.
   *
   * If that list is empty, returns the total number of steps when the task is EXECUTED and 0 otherwise.
   */
  def getCurrentStepNr: Int = getCurrentStepNrs.headOption.map(Integer2int).getOrElse(if (state == TES.EXECUTED) steps.size else 0)

  /** Paths of all steps that are currently EXECUTING, FAILED or PAUSED. */
  def getCurrentStepPaths: util.List[BlockPath] = block.getStepsWithPaths().collect({
    case (path, step) if Set(SES.EXECUTING, SES.FAILED, SES.PAUSED).contains(step.getState) => path
  })

  /**
   * 1-based numbers of the steps that are currently EXECUTING, FAILED or PAUSED.
   *
   * When no step is in one of those states:
   *  - an EXECUTED task yields the total number of steps;
   *  - a STOPPED or ABORTED task yields `idx + 1` for every 0-based index `idx` whose step is DONE
   *    and whose successor is PENDING;
   *  - any other task yields `0`.
   */
  def getCurrentStepNrs: util.List[Integer] = {
    // One snapshot for the whole method, so that all indices and sizes refer to the same sequence.
    val mySteps = steps
    // Indices come from zipWithIndex, so an element's index equals its position and lift is an O(1) lookup.
    def stepByIdxIs(idx: Int, s: SES) = mySteps.lift(idx).exists(_._1._2.getState == s)

    val activeStates = Set(SES.EXECUTING, SES.FAILED, SES.PAUSED)

    // Can't use collect here. Step states change while we iterate, and the element being considered might change
    // between the pf.isDefinedAt and the pf.apply, resulting in a MatchError.
    val list: List[Integer] = mySteps.view.filter(t => activeStates.contains(t._1._2.getState)).map(t => int2Integer(t._2 + 1)).toList

    val res: util.List[Integer] = list match {
      case Nil if state == TES.EXECUTED => List(int2Integer(mySteps.size))
      case Nil if Set(TES.STOPPED, TES.ABORTED).contains(state) => mySteps.map(_._2).collect {
        case idx if stepByIdxIs(idx, SES.DONE) && stepByIdxIs(idx + 1, SES.PENDING) => int2Integer(idx + 1)
      }
      case Nil => List(int2Integer(0))
      case active => active
    }

    res
  }

  def getMetadata: util.Map[String, String] = spec.getMetadata

  def getFailureCount: Int = failureCount

  def getOwner: String = owner.getName

  def getAuthentication: Authentication = owner

  def setOwner(newOwner: Authentication) {
    this.owner = newOwner
  }

  def getTempWorkDir: File = spec.getTempWorkDir

  /** Sets the task state without publishing an event (see `setTaskStateAndNotify`). */
  private[tasker] def setState(state: TaskExecutionState) {
    this.state = state
  }

  private[tasker] def recordFailure() {
    failureCount += 1
  }

  /** Records the start time (UTC) on the first call only; later calls keep the original start date. */
  private[tasker] def recordStart() {
    if (startDate == null) {
      startDate = new DateTime(UTC)
    }
  }

  /** Records the completion time (UTC), overwriting any previous value. */
  private[tasker] def recordCompletion() {
    completionDate = new DateTime(UTC)
  }

  def getContext: TaskExecutionContext = context

  private[tasker] def getBlock = block

  /** Returns the block at `path`; the head of the path is dropped before lookup in the root block. */
  def getBlock(path: BlockPath) = block.getBlock(path.tail)

  /** Sets the state of `step` and publishes a [[StepStateEvent]] (old -> new state) on the actor system's event stream. */
  private[this] def setStateAndNotify(state: StepExecutionState, step: TaskStep)(implicit system: ActorSystem) {
    val oldState = step.getState
    step.setState(state)
    system.eventStream.publish(StepStateEvent(getId, UUID.randomUUID().toString, this, step, oldState, state, None))
  }

  /** Sets the task state and publishes a [[TaskStateEvent]] (old -> new state) on the actor system's event stream. */
  private[tasker] def setTaskStateAndNotify(newState: TaskExecutionState)(implicit system: ActorSystem) {
    val oldState = state
    state = newState
    logger.info(s"Publishing state change $oldState -> $newState")
    system.eventStream.publish(TaskStateEvent(getId, this, oldState, newState))
  }

  /** Task states in which steps may be skipped or unskipped. */
  val skippableTaskStates = Set(TES.PENDING, TES.STOPPED, TES.FAILED, TES.ABORTED)

  /** Throws a [[TaskerException]] unless the task is in one of the `skippableTaskStates`. */
  private[this] def requireSkippableState() {
    if (!skippableTaskStates.contains(state))
      throw new TaskerException(s"Task [$id] should be PENDING, STOPPED, FAILED or ABORTED, but was [$state]")
  }

  /**
   * Applies `action` to every element of `items` in order.
   *
   * A non-fatal failure stops processing and is rethrown as a [[TaskerException]] carrying
   * `failureMessage`. Elements processed before the failure keep their changes; there is no rollback.
   */
  private[this] def forEachOrFail[A](items: List[A], failureMessage: String)(action: A => Unit) {
    Try(items.foreach(action)) match {
      case Success(_) =>
      case Failure(ex) => throw new TaskerException(ex, failureMessage)
    }
  }

  /**
   * Marks the steps with the given 1-based numbers as SKIP and publishes an event per changed step.
   * Steps already marked for skip are left untouched.
   *
   * @throws TaskerException if the task is not in a skippable state, or if any step is out of range
   *                         or cannot be skipped
   */
  private[tasker] def skip(stepNrs: List[Int])(implicit system: ActorSystem) {
    requireSkippableState()
    forEachOrFail(stepNrs, "Could not skip all steps") { stepNr =>
      val step: ((BlockPath, StepState), Int) = steps(stepNr - 1)
      logger.info(s"Trying to skip step ${step._1}")
      val taskStep: TaskStep = step._1._2.asInstanceOf[TaskStep]
      if (!taskStep.isMarkedForSkip) {
        if (!taskStep.canSkip) throw new IllegalArgumentException(s"Step [${step._1}] cannot be skipped")
        setStateAndNotify(StepExecutionState.SKIP, taskStep)
      }
    }
  }

  /**
   * Puts the steps with the given 1-based numbers back to PENDING and publishes an event per step.
   *
   * @throws TaskerException if the task is not in a skippable state, or if any step is out of range
   *                         or has already executed
   */
  private[tasker] def unskip(stepNrs: List[Int])(implicit system: ActorSystem) {
    requireSkippableState()
    forEachOrFail(stepNrs, "Could not unskip all steps") { stepNr =>
      val step: ((BlockPath, StepState), Int) = steps(stepNr - 1)
      logger.info(s"Trying to unskip step ${step._1}")
      val taskStep: TaskStep = step._1._2.asInstanceOf[TaskStep]
      if (taskStep.hasExecuted) throw new IllegalArgumentException(s"Step [${step._1}] cannot be unskipped")
      setStateAndNotify(StepExecutionState.PENDING, taskStep)
    }
  }

  /**
   * Path-based variant of `skip`: marks the steps at the given paths as SKIP.
   * NOTE(review): unlike `skip`, this does not check `isMarkedForSkip`, so an already-skipped step is re-notified.
   *
   * @throws TaskerException if the task is not in a skippable state, or if any step cannot be skipped
   */
  private[tasker] def skipPaths(stepNrs: List[BlockPath])(implicit system: ActorSystem) {
    requireSkippableState()
    forEachOrFail(stepNrs, "Could not skip all steps") { path =>
      val step: TaskStep = block.getStep(path.tail).asInstanceOf[TaskStep]
      logger.info(s"Trying to skip step $step")
      if (!step.canSkip) throw new IllegalArgumentException(s"Step [$step] cannot be skipped")
      setStateAndNotify(StepExecutionState.SKIP, step)
    }
  }

  /**
   * Path-based variant of `unskip`: puts the steps at the given paths back to PENDING.
   *
   * @throws TaskerException if the task is not in a skippable state, or if any step has already executed
   */
  private[tasker] def unskipPaths(stepNrs: List[BlockPath])(implicit system: ActorSystem) {
    requireSkippableState()
    forEachOrFail(stepNrs, "Could not unskip all steps") { path =>
      val step: TaskStep = block.getStep(path.tail).asInstanceOf[TaskStep]
      logger.info(s"Trying to unskip step $step")
      if (step.hasExecuted) throw new IllegalArgumentException(s"Step [$path] cannot be unskipped")
      setStateAndNotify(StepExecutionState.PENDING, step)
    }
  }

  /**
   * Moves step number `stepNr` so that it ends up at 1-based position `newPosition`.
   * Only allowed on a PENDING task that has not run yet and whose root block is a single [[StepBlock]].
   *
   * @throws TaskerException          if the task has run or consists of multiple blocks
   * @throws IllegalArgumentException if either number is out of range
   */
  private[tasker] def moveStep(stepNr: Int, newPosition: Int) {
    if (state != TES.PENDING || getCurrentStepNr != 0) throw new TaskerException(s"Task [$id] should not have run when moving steps")
    if (!block.isInstanceOf[StepBlock]) throw new TaskerException("Cannot move steps when the task consists of multiple blocks")
    val oldIdx: Int = stepNrToIndex(stepNr)
    val newIdx: Int = stepNrToIndex(newPosition)

    val stepBuffer: mutable.Buffer[StepState] = block.asInstanceOf[StepBlock].steps
    val removed: StepState = stepBuffer.remove(oldIdx)
    stepBuffer.insert(newIdx, removed)
  }

  /**
   * Inserts a pause step so that it becomes step number `position`.
   * Only supported when the root block is a single [[StepBlock]], and only after the current execution position.
   *
   * @throws TaskerException          if the task consists of multiple blocks
   * @throws IllegalArgumentException if `position` is not after the current step or is out of range
   */
  def addPause(position: Integer) {
    if (!block.isInstanceOf[StepBlock]) throw new TaskerException("Cannot add pause step using old Integer index when the task consists of multiple blocks")
    if (getCurrentStepNr >= position) throw new IllegalArgumentException(s"Can only add pause steps after the current execution position (currently: [$getCurrentStepNr], requested: [$position])")
    val stepBuffer = block.asInstanceOf[StepBlock].steps
    stepBuffer.insert(stepNrToIndex(position), new TaskStep(new PauseStep))
  }

  /**
   * Inserts a pause step at `position` (head of the path dropped before delegating to the root block).
   *
   * @throws TaskerException if the task cannot be queued (see `canBeQueued`)
   */
  def addPause(position: BlockPath) {
    if (!canBeQueued) throw new TaskerException(s"Task [$id] should be PENDING, STOPPED, ABORTED or FAILED but was [$state]")
    block.addPause(position.tail)
  }

  /**
   * Normalizes a freshly deserialized task.
   * A task that was EXECUTING, STOPPING, FAILING or ABORTING is marked FAILED, and the block tree
   * is replaced by `block.recovered()`.
   */
  def recovered() {
    state = state match {
      case TES.EXECUTING | TES.STOPPING | TES.FAILING | TES.ABORTING => TES.FAILED
      case other => other
    }

    block = block.recovered()
  }

  /**
   * Converts a 1-based step number into a 0-based index.
   *
   * @throws IllegalArgumentException if `stepNr` is not in `1..getNrSteps`
   */
  private[tasker] def stepNrToIndex(stepNr: Int): Int = {
    if (stepNr <= 0 || stepNr > steps.size) throw new IllegalArgumentException(s"Not a valid step number [$stepNr]")
    stepNr - 1
  }

  /** True when the task is PENDING, STOPPED, ABORTED or FAILED. */
  def canBeQueued = util.EnumSet.of(TES.PENDING, TES.STOPPED, TES.ABORTED, TES.FAILED).contains(state)

  override def toString: String = s"Task[$id, $state]"

  def getScheduledDate: DateTime = scheduledDate

  def setScheduledDate(scheduledDate: DateTime) {
    this.scheduledDate = scheduledDate
  }

  /** Ids of the blocks that are currently EXECUTING, FAILING, STOPPING or ABORTING. */
  def getActiveBlocks: util.List[String] = activeBlocks(block)

  val activeBlockStates = Set(BlockExecutionState.EXECUTING, BlockExecutionState.FAILING, BlockExecutionState.STOPPING, BlockExecutionState.ABORTING)

  /**
   * Collects the ids of active blocks, depth-first.
   * The children of a composite block are only visited when the composite itself is active.
   */
  private[this] def activeBlocks(block: Block): List[String] = block match {
    case cb: CompositeBlock if activeBlockStates.contains(cb.state) => cb.getId() :: cb.blocks.flatMap(activeBlocks).toList
    case cb: CompositeBlock => Nil
    case b: Block if activeBlockStates.contains(b.state) => b.getId() :: Nil
    case _ => Nil
  }
}
