package com.xebialabs.platform.script.jython

import java.io.Writer

import com.xebialabs.deployit.plugin.api.flow.ExecutionContext
import com.xebialabs.platform.script.jython.JythonSupport._
import com.xebialabs.xlplatform.script.jython.JythonSugarDiscovery
import grizzled.slf4j.Logging
import javax.script.{ScriptContext, ScriptException}

object JythonSupport {

  /** Rewrites the source text of an expression before it is handed to the engine. */
  type PreprocessExpression = String => String

  /** Post-processes the value a script evaluated to, given the script context it ran in. */
  type ResultProcessor = (ScriptContext, AnyRef) => AnyRef

  /** Shared decorator swapped in as a script context's standard-output writer during evaluation. */
  val outWriterDecorator: ThreadLocalWriterDecorator = new ThreadLocalWriterDecorator

  /** Shared decorator swapped in as a script context's error writer during evaluation. */
  val errorWriterDecorator: ThreadLocalWriterDecorator = new ThreadLocalWriterDecorator

  /** Default preprocessor: returns the expression text unchanged. */
  private val doNotPreprocess: PreprocessExpression = identity

  /** Default result processor: ignores the context and yields the raw evaluated value. */
  private val identityResultProcessor: ResultProcessor =
    (_: ScriptContext, rawValue: AnyRef) => rawValue
}

/**
 * Mixin with helpers to run Jython scripts and evaluate Jython expressions.
 *
 * While a script is being evaluated, the writers of its `ScriptContext` are temporarily
 * replaced by the shared `outWriterDecorator` / `errorWriterDecorator` from the companion
 * object. The original writers are restored once evaluation finishes, including when it fails.
 */
trait JythonSupport extends Logging {

  import com.xebialabs.platform.script.jython.EngineInstance._
  import com.xebialabs.platform.script.jython.ScriptSource.byResource

  /**
   * Runs the script loaded from the given resource path.
   *
   * The script is bound to:
   *  - `Bindings.xlDeployApiServices`;
   *  - the execution context, under the name `context`;
   *  - `variables`.
   *
   * The maps are merged with `++`, so an entry in `variables` overrides a same-named binding.
   * The script's standard output and error are forwarded to `context.logOutputRaw` and
   * `context.logErrorRaw` respectively.
   *
   * @param scriptPath resource path of the script to run
   * @param variables  additional variables bound in the script
   * @param context    execution context used for logging; defaults to a new `JythonExecutionContext`
   */
  def executeScriptWithVariables(scriptPath: String, variables: Map[String, Any], context: ExecutionContext = new JythonExecutionContext): Unit = {
    implicit val jythonContext: JythonContext =
      createJythonContext(context, Bindings.xlDeployApiServices ++ Map("context" -> context) ++ variables)
    executeScript(byResource(scriptPath))
  }

  /**
   * Builds a [[JythonContext]] for the given variables.
   *
   * Libraries, in order:
   *  - the logger library;
   *  - wrapper code generated for the variable names;
   *  - any discovered extension resources.
   *
   * The script-context factory binds `variables`. It routes the script's output and error
   * writers to `executionContext.logOutputRaw` and `executionContext.logErrorRaw`.
   */
  private def createJythonContext(executionContext: ExecutionContext, variables: Map[String, Any]) = {
    JythonContext.withLibrariesAndFactory(
      (Syntactic.loggerLib +: Syntactic.wrapperCodeWithLib(variables.keys)) ++ JythonSugarDiscovery.getExtensionResources
    ) {
      val scriptContext = variables.toScriptContext
      scriptContext.setWriter(new ConsumerWriter((text) => executionContext.logOutputRaw(text)))
      scriptContext.setErrorWriter(new ConsumerWriter((text) => executionContext.logErrorRaw(text)))
      scriptContext
    }
  }

  /**
   * Evaluates a single Jython expression and returns its (processed) value.
   *
   * @param expression      expression text; must be non-null and non-empty
   * @param preprocess      applied to `expression` before evaluation (identity by default)
   * @param resultProcessor applied to the evaluated value (identity by default)
   * @tparam T expected result type. The cast is unchecked because of erasure, so a type mismatch
   *           surfaces as a `ClassCastException` at the caller's use site, not here.
   * @throws IllegalArgumentException if `expression` is null or empty
   */
  def evaluateExpression[T](expression: String, preprocess: PreprocessExpression = doNotPreprocess, resultProcessor: ResultProcessor = identityResultProcessor)(implicit jythonContext: JythonContext): T = {
    require(Option(expression).exists(_.nonEmpty), "Expression must be defined")
    val result = executeScript(ScriptSource.byContent(preprocess(expression)), resultProcessor)
    result.asInstanceOf[T]
  }

  /**
   * Runs `scriptSource` in a script context obtained from `jythonContext.buildScriptContext`.
   *
   * Every library of `jythonContext` is executed first, in the same context. Their results are
   * discarded, although `resultProcessor` is still applied to each one.
   *
   * @return the processed result of `scriptSource`
   */
  def executeScript(scriptSource: ScriptSource, resultProcessor: ResultProcessor = identityResultProcessor)(implicit jythonContext: JythonContext): AnyRef = {
    val scriptContext = jythonContext.buildScriptContext

    jythonContext.libraries.foreach(runtimeScript =>
      executeScript(runtimeScript, scriptContext, resultProcessor)
    )
    executeScript(scriptSource, scriptContext, resultProcessor)
  }

  /**
   * Evaluates `scriptSource` in the given `scriptContext` and applies `resultProcessor` to its value.
   *
   * @throws JythonException wrapping any `ScriptException` raised by the engine
   */
  def executeScript(scriptSource: ScriptSource, scriptContext: ScriptContext, resultProcessor: ResultProcessor): AnyRef = {
    trace(s"Evaluating script\n${scriptSource.scriptContent}")
    withThreadLocalWriter(scriptContext) {
      try {
        val result = jython.eval(scriptSource.scriptContent, scriptContext)
        resultProcessor.apply(scriptContext, result)
      } catch {
        case ex: ScriptException => throw JythonException(scriptSource, ex)
      }
    }
  }

  /**
   * Runs `fn` with the context's writers wrapped in the shared thread-local decorators.
   *
   * The decoration is removed in a `finally`. A failing script would otherwise leave its writer
   * registered in the shared decorators and leave the context pointing at the decorator.
   */
  private def withThreadLocalWriter(scriptContext: ScriptContext)(fn: => AnyRef): AnyRef = {
    addLoggerDecoration(scriptContext)
    try fn
    finally removeLoggerDecoration(scriptContext)
  }

  /**
   * Replaces the context's output and error writers with the shared decorators.
   *
   * Each original writer is registered with its decorator first. A writer that is already a
   * decorator, or is null, is left as is.
   */
  private def addLoggerDecoration(scriptContext: ScriptContext): Unit = {
    add(scriptContext.getWriter, outWriterDecorator, scriptContext.setWriter)
    add(scriptContext.getErrorWriter, errorWriterDecorator, scriptContext.setErrorWriter)

    def add(writer: Writer, decorator: ThreadLocalWriterDecorator, setWriter: Writer => Unit): Unit = {
      writer match {
        case _: ThreadLocalWriterDecorator => // already decorated: nothing to do
        case w if w != null =>
          decorator.registerWriter(w)
          setWriter(decorator)
        case _ => // no writer configured: nothing to decorate
      }
    }
  }

  /**
   * Undoes [[addLoggerDecoration]].
   *
   * A decorated writer is reset to the writer held by the decorator, and that registration is
   * then removed. Writers that are not decorators are left untouched.
   */
  private def removeLoggerDecoration(scriptContext: ScriptContext): Unit = {
    remove(scriptContext.getWriter, outWriterDecorator, scriptContext.setWriter)
    remove(scriptContext.getErrorWriter, errorWriterDecorator, scriptContext.setErrorWriter)

    def remove(writer: Writer, decorator: ThreadLocalWriterDecorator, resetWriter: Writer => Unit): Unit = {
      writer match {
        case _: ThreadLocalWriterDecorator =>
          // Read the underlying writer before dropping the registration that holds it.
          resetWriter(decorator.getWriter)
          decorator.removeWriter()
        case _ =>
      }
    }
  }
}
