package com.xebialabs.deployit.engine.tasker

import java.io.Serializable
import java.util

import ai.digital.deploy.task.steplog.TaskStepLogStoreHolder
import com.xebialabs.deployit.core.events.TaskStepLogEvent
import com.xebialabs.deployit.engine.tasker.log.StepLog
import com.xebialabs.deployit.local.message.ProductName
import com.xebialabs.deployit.plugin.api.flow.{ExecutionContext, ITask}
import com.xebialabs.deployit.plugin.api.inspection.InspectionContext
import com.xebialabs.deployit.plugin.api.services.Repository
import grizzled.slf4j.Logger

/**
 * [[ExecutionContext]] handed to a single [[TaskStep]] while it runs.
 *
 * Attributes and the repository are delegated to the enclosing
 * [[TaskExecutionContext]]. Every log call is written to three places: an
 * slf4j logger named after the step implementation class, the supplied
 * [[StepLog]], and a [[TaskStepLogEvent]] sent through the bean returned by
 * `TaskStepLogStoreHolder.getTaskStepLogJmsBean`.
 *
 * @param taskExecutionContext context that stores attributes and provides the repository
 * @param step                 step whose output is being logged
 * @param task                 task that owns the step
 * @param spec                 task specification; may be `null`, in which case no
 *                             "config" attribute is registered and
 *                             [[getInspectionContext]] returns `null`
 * @param stepLog              log that receives the raw output and error text
 */
class StepExecutionContext(taskExecutionContext: TaskExecutionContext,
                           step: TaskStep, task: Task, spec: TaskSpecification,
                           stepLog: StepLog)
  extends ExecutionContext with Serializable {

  /** Logger named after the concrete class of the step implementation. */
  val stepLogger: Logger = Logger.apply(step.getImplementation.getClass)

  // Expose the specification's config as the "config" attribute when a specification exists.
  if (spec != null) {
    setAttribute("config", spec.getConfig)
  }

  /** Stores the attribute on the enclosing task execution context. */
  override def setAttribute(name: String, value: scala.AnyRef): Unit = {
    taskExecutionContext.setAttribute(name, value)
  }

  /** Logs `output` followed by a newline at INFO level. */
  override def logOutput(output: String): Unit = {
    logOutputRaw(output + "\n")
  }

  /** Resolves the message for `key` via `getMessage` and logs it, followed by a newline, at INFO level. */
  override def logMsgOutput(productName: ProductName, key: String, args: Object*): Unit = {
    val output = getMessage(productName, key, args: _*)
    logOutputRaw(output + "\n")
  }

  /**
   * Logs `output` as-is at INFO level. A single trailing newline is stripped
   * for the slf4j logger only; the step log and the published event receive
   * the text unchanged.
   */
  override def logOutputRaw(output: String): Unit = {
    stepLogger.info(output.stripSuffix("\n"))
    stepLog.logOutputRaw(output)
    publishStepLogEvent("INFO", output)
  }

  /** Logs `error` followed by a newline at ERROR level. */
  override def logError(error: String): Unit = {
    logErrorRaw(error + "\n")
  }

  /** Logs `error` as-is at ERROR level, without an associated throwable. */
  override def logErrorRaw(error: String): Unit = {
    logErrorRaw(error, None)
  }

  /** Logs `error` followed by a newline at ERROR level; `t` may be `null`. */
  override def logError(error: String, t: Throwable): Unit = {
    logErrorRaw(error + "\n", Option(t))
  }

  /** Delegates to `stepLog.markNextAttempt()`. */
  def markNextAttempt(): Unit = {
    stepLog.markNextAttempt()
  }

  /** Shared implementation of the error logging variants. */
  private def logErrorRaw(error: String, t: Option[Throwable]): Unit = {
    val message = error.stripSuffix("\n")
    t match {
      case Some(cause) => stepLogger.error(message, cause)
      case None => stepLogger.error(message)
    }
    stepLog.logErrorRaw(error, t)
    publishStepLogEvent("ERROR", error)
  }

  /**
   * Sends a [[TaskStepLogEvent]] for `message` and then touches the step.
   *
   * Must be called after the message has been handed to `stepLog`: the line
   * number is the count of newline-separated segments in `step.getLog` at
   * this point (presumably `stepLog` feeds that log - confirm).
   *
   * @param level   level string carried by the event ("INFO" or "ERROR")
   * @param message raw text carried by the event
   */
  private def publishStepLogEvent(level: String, message: String): Unit = {
    val lineNumber = step.getLog.split("\n").length
    TaskStepLogStoreHolder.getTaskStepLogJmsBean.
      sendStepLogApplicationEvent(TaskStepLogEvent(task.getId, step.getMetadata.get("blockPath"), lineNumber, level, step.getFailureCount + 1, message))
    step.touch()
  }

  /**
   * Looks up an attribute on the enclosing task execution context.
   *
   * @return the stored value, the wrapped listener when the value is an
   *         [[OldExecutionContextListenerCleanupTrigger]], or `null` when absent
   */
  override def getAttribute(name: String): AnyRef = taskExecutionContext.getAttributes.getOrElse(name, null) match {
    case trigger: OldExecutionContextListenerCleanupTrigger => trigger.getWrappedListener
    case other => other
  }

  /**
   * @return the inspection context of the task specification, or `null` when
   *         this context was created without a specification
   */
  override def getInspectionContext: InspectionContext =
    Option(spec).map(_.getInspectionContext).orNull

  /** @return the repository of the enclosing task execution context */
  override def getRepository: Repository = taskExecutionContext.getRepository

  /** @return a new [[ITask]] view exposing the id, owner and metadata of the underlying task */
  override def getTask: ITask = new ITask {

    override def getMetadata: util.Map[String, String] = task.getMetadata

    override def getUsername: String = task.getOwner

    override def getId: String = task.getId

  }

  /** @return the failure count reported by the underlying task */
  def getTaskFailureCount : Int = task.getFailureCount
}
