package com.xebialabs.platform.script.jython

import java.io.Writer
import javax.script.{ScriptContext, ScriptException}

import com.xebialabs.platform.script.jython.JythonSupport._
import grizzled.slf4j.Logging

/**
 * Shared definitions for [[JythonSupport]]: the expression-preprocessor type and the two writer
 * decorators that the trait installs around a script context's output and error writers.
 * Both decorator instances are shared by every user of this object (they are plain `val`s here).
 */
object JythonSupport {
  /** A source-to-source rewrite applied to an expression before it is evaluated. */
  type PreprocessExpression = Function1[String, String]

  /** Decorator installed in place of a script context's standard output writer. */
  val outWriterDecorator: ThreadLocalWriterDecorator = new ThreadLocalWriterDecorator

  /** Decorator installed in place of a script context's error writer. */
  val errorWriterDecorator: ThreadLocalWriterDecorator = new ThreadLocalWriterDecorator
}

/**
 * Mixin that evaluates Jython scripts and expressions with the `jython` engine imported from
 * `EngineInstance`. While a script runs, the script context's output and error writers are
 * replaced by the shared decorators from the `JythonSupport` companion object, and they are
 * restored afterwards.
 */
trait JythonSupport extends Logging {

  import com.xebialabs.platform.script.jython.EngineInstance._

  /** Identity preprocessor: the expression is evaluated exactly as given. */
  private val doNotPreprocess: PreprocessExpression = expr => expr

  /**
   * Evaluates a single expression and returns its result as `T`.
   *
   * @param expression the expression source; must be non-null and non-empty
   * @param preprocess rewrite applied to `expression` before evaluation (identity by default)
   * @return the engine's result, cast unchecked to `T`
   * @throws IllegalArgumentException if `expression` is null or empty
   * @throws JythonException if the engine raises a `ScriptException`
   */
  def evaluateExpression[T](expression: String, preprocess: PreprocessExpression = doNotPreprocess)(implicit jythonContext: JythonContext): T = {
    require(Option(expression).exists(_.nonEmpty), "Expression must be defined")
    val result = executeScript(ScriptSource.byContent(preprocess(expression)))
    // Unchecked cast: T is erased, so a wrong T only surfaces at the caller's use site.
    result.asInstanceOf[T]
  }

  /**
   * Runs each of `jythonContext.libraries` in order, then `scriptSource`.
   *
   * All of them run in the single script context returned by `jythonContext.buildScriptContext`,
   * so the main script sees whatever state the libraries left in that context.
   *
   * @return the value produced by evaluating `scriptSource`
   * @throws JythonException if any of the scripts raises a `ScriptException`
   */
  def executeScript(scriptSource: ScriptSource)(implicit jythonContext: JythonContext): AnyRef = {
    val scriptContext = jythonContext.buildScriptContext

    jythonContext.libraries.foreach(library => executeScript(library, scriptContext))
    executeScript(scriptSource, scriptContext)
  }

  /**
   * Evaluates `scriptSource` in `scriptContext`.
   *
   * While the script runs, the context's writers are routed through the shared writer
   * decorators.
   *
   * @return the value returned by the engine's `eval`
   * @throws JythonException wrapping any `ScriptException` raised by the engine
   */
  def executeScript(scriptSource: ScriptSource, scriptContext: ScriptContext): AnyRef = {
    trace(s"Evaluating script\n${scriptSource.scriptContent}")
    withThreadLocalWriter(scriptContext) {
      try {
        jython.eval(scriptSource.scriptContent, scriptContext)
      } catch {
        case ex: ScriptException => throw JythonException(scriptSource, ex)
      }
    }
  }

  /**
   * Runs `fn` with the context's writers decorated.
   *
   * The decoration is undone in a `finally`, so it is also removed when `fn` throws (for
   * example a failing script). Otherwise the context would stay decorated and the writer
   * registered with the decorator would never be removed.
   */
  private def withThreadLocalWriter(scriptContext: ScriptContext)(fn: => AnyRef): AnyRef = {
    addLoggerDecoration(scriptContext)
    try fn
    finally removeLoggerDecoration(scriptContext)
  }

  /**
   * Handles the context's output and error writers independently.
   *
   * For each writer that is non-null and not already a `ThreadLocalWriterDecorator`, registers
   * it with the matching decorator and installs that decorator in its place.
   */
  private def addLoggerDecoration(scriptContext: ScriptContext): Unit = {
    def add(writer: Writer, decorator: ThreadLocalWriterDecorator, setWriter: Writer => Unit): Unit =
      writer match {
        // Already decorated, or nothing to decorate: leave the context as it is.
        case null | _: ThreadLocalWriterDecorator =>
        case w =>
          decorator.registerWriter(w)
          setWriter(decorator)
      }

    add(scriptContext.getWriter, outWriterDecorator, w => scriptContext.setWriter(w))
    add(scriptContext.getErrorWriter, errorWriterDecorator, w => scriptContext.setErrorWriter(w))
  }

  /**
   * Undoes `addLoggerDecoration`.
   *
   * For each of the context's output and error writers that is currently a
   * `ThreadLocalWriterDecorator`, puts back the writer registered with the matching decorator,
   * then removes that registration.
   */
  private def removeLoggerDecoration(scriptContext: ScriptContext): Unit = {
    // NOTE(review): this unwraps any decorator it finds, including one that was already in place
    // before this call (addLoggerDecoration skipped it), e.g. a nested evaluation on the same
    // context. Confirm that this is intended.
    def remove(writer: Writer, decorator: ThreadLocalWriterDecorator, setWriter: Writer => Unit): Unit =
      writer match {
        case _: ThreadLocalWriterDecorator =>
          setWriter(decorator.getWriter)
          decorator.removeWriter()
        case _ =>
      }

    remove(scriptContext.getWriter, outWriterDecorator, w => scriptContext.setWriter(w))
    remove(scriptContext.getErrorWriter, errorWriterDecorator, w => scriptContext.setErrorWriter(w))
  }
}
