package com.xebialabs.xlrelease.scheduler.workers

import com.xebialabs.deployit.ServerConfiguration
import com.xebialabs.xlrelease.actors.ReleaseActorService
import com.xebialabs.xlrelease.config.XlrConfig
import com.xebialabs.xlrelease.domain.ScriptTask.JYTHON_ENGINE
import com.xebialabs.xlrelease.domain.recover.TaskRecoverOp.{RESTART_PHASE, RUN_SCRIPT, SKIP_TASK}
import com.xebialabs.xlrelease.domain.status.TaskStatus
import com.xebialabs.xlrelease.domain.{ParallelGroup, SequentialGroup, Task, TaskGroup}
import com.xebialabs.xlrelease.repository.IdType.DOMAIN
import com.xebialabs.xlrelease.repository.{PhaseVersion, TaskRepository}
import com.xebialabs.xlrelease.scheduler.FailureHandlerJob
import com.xebialabs.xlrelease.scheduler.workers.Worker.{ExecuteJob, FailureHandlerExecutionResult, ProcessJobResult}
import com.xebialabs.xlrelease.script.DefaultScriptService.ScriptTaskResults
import com.xebialabs.xlrelease.script._
import com.xebialabs.xlrelease.serialization.json.repository.ResolveOptions
import com.xebialabs.xlrelease.user.User
import com.xebialabs.xlrelease.user.User.SYSTEM
import grizzled.slf4j.Logging
import org.springframework.stereotype.Component

import scala.jdk.CollectionConverters._
import scala.util.{Failure, Success, Try}

/**
 * Worker that runs the failure handler of a task and applies its outcome.
 *
 * `execute` runs the handler through the Jython script service and wraps the outcome in a
 * `FailureHandlerExecutionResult`; `processResult` then either recovers the task, falls back
 * to failing it, or delegates to `handleFailure` when anything went wrong along the way.
 */
@Component
class FailureHandlerWorker(val releaseActorService: ReleaseActorService,
                           val taskRepository: TaskRepository,
                           val scriptLifeCycle: ScriptLifeCycle,
                           val scriptServicesList: java.util.List[ScriptService],
                           val serverConfiguration: ServerConfiguration
                          )
  extends Worker with ScriptServiceSupport with TaskWorkerFailureLogic with ExecuteRecoverCallback with Logging {

  /**
   * Resolves the task referenced by the job and runs its failure handler with the Jython
   * script service. Any non-fatal exception thrown by the handler is captured in the `Try`
   * carried by the returned result instead of being propagated.
   */
  override def execute: ExecuteJob = {
    case FailureHandlerJob(taskRef) =>
      val failedTask = taskRef.get()
      val handlerOutcome = Try(scriptServices.get(JYTHON_ENGINE).executeFailureHandler(failedTask))
      FailureHandlerExecutionResult(failedTask.getId, failedTask.getExecutionId, handlerOutcome)
  }

  /**
   * Applies the outcome produced by `execute`:
   *  - a successful handler result triggers `recover`;
   *  - a failed handler result triggers `fallbackCallback`;
   *  - a failed execution, or any `Exception` raised while applying the outcome, is passed
   *    to `handleFailure` together with the task id and execution id of the job.
   */
  override def processResult: ProcessJobResult = {
    case FailureHandlerExecutionResult(taskId, executionId, executionResult) =>
      try {
        executionResult match {
          case Success(SuccessFailureHandlerResult(handledTaskId, scriptTaskResults)) =>
            recover(handledTaskId, scriptTaskResults)
          case Success(FailureFailureHandlerResult(handledTaskId, scriptTaskResults, cause)) =>
            fallbackCallback(handledTaskId, scriptTaskResults, cause)
          case Failure(error) =>
            handleFailure(taskId, executionId, error)
        }
      } catch {
        case error: Exception => handleFailure(taskId, executionId, error)
      }
  }
}

/**
 * Applies the outcome of a task failure handler: depending on the task's recover operation
 * the task is failed, skipped, left as-is (only the script results are saved) or its phase
 * is restarted. All state changes are delegated to the [[ReleaseActorService]].
 */
trait ExecuteRecoverCallback extends Logging {

  /** Used to abort the script execution when the failure handler timed out. */
  def scriptLifeCycle: ScriptLifeCycle

  /** Receives every task state change (fail, skip, restart phase, save results). */
  def releaseActorService: ReleaseActorService

  /** Used to load the task by id before the recover operation is applied. */
  def taskRepository: TaskRepository

  /** Provides the server URL used to build task links in the generated messages. */
  def serverConfiguration: ServerConfiguration

  // FIXME: Copy-paste from non-importable ContextHelper which lives in notification module.
  private def taskUrl(task: Task): String = s"${serverConfiguration.getServerUrl}#/tasks/${DOMAIN.convertToViewId(task.getId)}?showDetails=true"

  // FIXME: copied from DefaultScriptService
  /**
   * Replaces every occurrence of the server working directory (the `user.dir` system
   * property) in `message` with the `{ServerWorkingDirectory}` placeholder.
   *
   * @param message text to sanitize, may be `null`
   * @return the sanitized text, or `null` when `message` is `null`
   */
  private def sanitizeServerPath(message: String): String =
    // String.replace returns a new string; the result has to be returned, otherwise the
    // working directory is never masked.
    Option(message)
      .map(_.replace(System.getProperty("user.dir"), "{ServerWorkingDirectory}"))
      .orNull

  /** Builds the message attached to the task, e.g. "Skipped by failure handler on task [title](url)". */
  private def reason(action: String, task: Task): String =
    s"$action by failure handler on task [${task.getTitle}](${taskUrl(task)})"

  /** Formats the titles of the given tasks as a comma separated list of quoted titles. */
  private def quotedTitles(tasks: List[Task]): String =
    tasks.map(t => s"'${t.getTitle}'").mkString(", ")

  /**
   * Loads the task and applies its recover operation. If applying the operation throws a
   * non-fatal exception, the task goes through the fallback path instead; an exception
   * thrown by the fallback itself is propagated to the caller.
   *
   * @param taskId  id of the task whose failure handler finished
   * @param results script results produced by the failure handler
   */
  def recover(taskId: String, results: ScriptTaskResults): Unit = {
    val task: Task = taskRepository.findById[Task](taskId, ResolveOptions.WITHOUT_DECORATORS)
    Try(executeRecoverCallback(task, results)).recover {
      case t: Throwable =>
        fallbackCallback(task, results, t)
    }.get // rethrows whatever the fallback itself may have thrown
  }

  /**
   * Loads the task and fails it because the failure handler (or the recovery) raised `exception`.
   *
   * @param taskId    id of the task whose failure handler finished
   * @param r         script results produced by the failure handler
   * @param exception the error that prevented a regular recovery
   */
  def fallbackCallback(taskId: String, r: ScriptTaskResults, exception: Throwable): Unit = {
    val task: Task = taskRepository.findById[Task](taskId, ResolveOptions.WITHOUT_DECORATORS)
    fallbackCallback(task, r, exception)
  }

  /**
   * Fails the task after an unsuccessful failure handler run.
   *
   * On a [[ScriptTimeoutException]] the running script is aborted and the task is failed with
   * a message that mentions the configured timeout (the script results are not passed on in
   * that case). For any other exception the task is failed with the sanitized exception message.
   */
  private def fallbackCallback(task: Task, r: ScriptTaskResults, exception: Throwable): Unit = exception match {
    case _: ScriptTimeoutException =>
      logger.debug(s"Timeout happened when executing failure handler for ${task.getId}")
      scriptLifeCycle.tryAborting(task.getExecutionId)
      val timeout = XlrConfig.getInstance.timeoutSettings.failureHandlerTimeout.toSeconds
      val msg = s"Failure handler of task [${task.getId}] with script execution [${task.getExecutionId}] " +
        s"was terminated due to timeout of [$timeout] seconds. " +
        s"Consider increasing 'xl.timeouts.failureHandlerTimeout' property"
      releaseActorService.failTaskWithRetry(task.getId, msg, SYSTEM, Option.empty, task.getExecutionId)
    case t =>
      // The message ends up on the task, so server paths are masked; the log keeps the raw one.
      val executionLog = Option(t.getMessage).map(sanitizeServerPath).orNull
      logger.warn(s"Exception during task recovery: ${t.getMessage}", t)
      failTask(s"Task recovery failed: $executionLog")(task, r)
  }

  /**
   * Dispatches on the task's recover operation. The task is still in progress at this point.
   */
  private def executeRecoverCallback(task: Task, taskResults: ScriptTaskResults): Unit = {
    logger.trace(s"execRecoverCallback on ${task.getId} '${task.getTitle}'")
    val taskRecoverOp = task.getTaskRecoverOp
    logger.trace(s"$taskRecoverOp")

    taskRecoverOp match {
      // Original author's (unconfirmed) reading of the RUN_SCRIPT branching: if the recovery
      // script already acted on the task itself (e.g. skip or retry), the task status has
      // changed and only the results need saving; if the task is still in
      // FAILURE_HANDLER_IN_PROGRESS, nothing handled it and it has to be failed.
      case RUN_SCRIPT if task.getStatus == TaskStatus.FAILURE_HANDLER_IN_PROGRESS =>
        fail(task, taskResults)
      case RUN_SCRIPT =>
        nop(task, taskResults)

      case SKIP_TASK =>
        skip(task, taskResults)

      case RESTART_PHASE =>
        restartPhase(task, taskResults)

      // Any other (or missing) recover operation fails the task.
      case _ =>
        fail(task, taskResults)
    }
  }

  /** Leaves the task state untouched and only saves the script results. */
  def nop(task: Task, taskResults: ScriptTaskResults): Unit = {
    releaseActorService.saveScriptResults(task.getId, taskResults)
  }

  /** Fails the task with the standard "Failed by failure handler" reason. */
  def fail(task: Task, taskResults: ScriptTaskResults): Unit =
    failTask(reason("Failed", task))(task, taskResults)

  /** Skips the task with the standard "Skipped by failure handler" reason. */
  def skip(task: Task, taskResults: ScriptTaskResults): Unit =
    skipTask(reason("Skipped", task))(task, taskResults)

  /** Restarts the phase of the task with the standard "Restarted phase by failure handler" reason. */
  def restartPhase(task: Task, taskResults: ScriptTaskResults): Unit = {
    restartPhase(reason("Restarted phase", task))(task, taskResults)
  }

  /**
   * Fails the task as the system user, attaching `message` and the script results.
   *
   * @param message     text recorded with the failure
   * @param task        task to fail
   * @param taskResults script results to store with the task, may be `null`
   */
  def failTask(message: String)(task: Task, taskResults: ScriptTaskResults): Unit = {
    logger.debug(s"failTask(${task.getTitle}/${task.getTaskType}) [${task.getStatus}]")
    releaseActorService.failTaskWithRetry(task.getId, message, User.SYSTEM, Option(taskResults), task.getExecutionId)
  }

  /** Message listing the locked tasks that prevent a skip. */
  def cannotSkipLocked(subTasks: List[Task]): String =
    s"Cannot skip locked tasks: ${quotedTitles(subTasks)}"

  /** Message listing the unassigned tasks that prevent a skip. */
  def cannotSkipUnassigned(subTasks: List[Task]): String =
    s"Cannot skip unassigned tasks: ${quotedTitles(subTasks)}"

  /**
   * Checks whether any of the given tasks prevents a skip: locked tasks, and planned or
   * failed tasks that have neither an owner nor a team.
   *
   * @return `None` when nothing blocks the skip, otherwise a message naming the blocking tasks
   */
  def cannotSkipReason(allSubTasks: List[Task]): Option[String] = {
    val locked = allSubTasks.filter(_.isLocked)
    val unassigned = allSubTasks.filter(t => (t.isPlanned || t.isFailed) && (!t.hasOwner && !t.hasTeam))

    (locked, unassigned) match {
      case (Nil, Nil) => None
      case (l, Nil) => Some(cannotSkipLocked(l))
      case (Nil, u) => Some(cannotSkipUnassigned(u))
      case (l, u) => Some(s"${cannotSkipLocked(l)} and ${cannotSkipUnassigned(u)}")
    }
  }

  /**
   * Skips the task, or fails it when one of the tasks returned by `task.getAllTasks` cannot
   * be skipped (see [[cannotSkipReason]]). For a task group the sub tasks are skipped via
   * [[skipSubTasks]]; any other task is marked as SKIPPED together with the script results.
   */
  def skipTask(message: String)(task: Task, taskResults: ScriptTaskResults): Unit = {
    logger.debug(s"skipTask(${task.getTitle}/${task.getTaskType}) [${task.getStatus}]")
    val allSubTasks = task.getAllTasks.asScala.toList

    cannotSkipReason(allSubTasks) match {
      case Some(msg) =>
        failTask(s"${reason("Failed", task)}: $msg")(task, taskResults)

      case None =>
        task match {
          case g: TaskGroup =>
            skipSubTasks(message)(g, taskResults)
          case _ =>
            releaseActorService.markTaskAsDoneWithScriptResults(TaskStatus.SKIPPED, task.getId, message, None, Option(taskResults), User.SYSTEM, task.getExecutionId)
        }
    }
  }

  /**
   * Restarts the phase that contains the task, starting from the first task of that phase
   * and using the latest phase version.
   *
   * NOTE(review): neither `message` nor `taskResults` is used here, so the reason text and
   * the script results are not passed on - confirm whether that is intended.
   */
  def restartPhase(message: String)(task: Task, taskResults: ScriptTaskResults): Unit = {
    logger.debug(s"restartPhase(${task.getTitle}/${task.getTaskType}) [${task.getStatus}]")
    val release = task.getRelease
    val phase = task.getPhase
    releaseActorService.restartPhase(release.getId, phase.getId, phase.getTask(0).getId, PhaseVersion.LATEST, true)
  }

  /**
   * Skips a task ahead of its execution: a task group goes through [[skipTask]], any other
   * task is marked as SKIPPED_IN_ADVANCE together with the script results.
   */
  def skipInAdvance(message: String)(task: Task, taskResults: ScriptTaskResults): Unit = {
    logger.debug(s"skipInAdvance(${task.getTitle}/${task.getTaskType}) [${task.getStatus}]")
    if (task.isTaskGroup) {
      skipTask(message)(task, taskResults)
    } else {
      releaseActorService.markTaskAsDoneWithScriptResults(TaskStatus.SKIPPED_IN_ADVANCE, task.getId, message, None, Option(taskResults), User.SYSTEM, task.getExecutionId)
    }
  }

  // Original author's note: not exercised for the moment, because in the MVP version a task
  // group cannot have a task failure handler.
  /**
   * Returns a function that skips the direct sub tasks of a group.
   *
   *  - Sequential group: planned sub tasks are skipped in advance in reverse order, then the
   *    failed sub tasks are skipped.
   *  - Parallel group: only failed sub tasks are considered; taken in reverse order, the first
   *    one is skipped via [[skipTask]] after all the others have been skipped in advance.
   *
   * Any other kind of group is not handled by the match below.
   */
  def skipSubTasks(message: String): (TaskGroup, ScriptTaskResults) => Unit = (group, scriptResult) => {
    logger.debug(s"skipSubTasks(${group.getTitle}/${group.getTaskType}) [${group.getStatus}]")
    val subTasks = group.getTasks.asScala.toList
    group match {
      case _: SequentialGroup =>
        subTasks
          .filter(t => t.isPlanned || t.isFailed)
          .partition(_.isFailed) match {
          case (failed, planned) =>
            planned.reverse.foreach(t => skipInAdvance(message)(t, scriptResult))
            failed.foreach(t => skipTask(message)(t, scriptResult))
        }
      case _: ParallelGroup =>
        subTasks
          .filter(_.isFailed)
          .reverse match {
          case Nil =>
          // nothing to do. Also an empty group cannot fail.
          case last :: rest =>
            rest.foreach(t => skipInAdvance(message)(t, scriptResult))
            skipTask(message)(last, scriptResult)
        }
    }
  }

}
