package com.xebialabs.xlrelease.script.jython

import org.apache.commons.lang3.StringUtils.isNotBlank
import org.python.antlr.ast._
import org.python.antlr.runtime.ANTLRStringStream
import org.python.antlr.{AnalyzingParser, PythonTree}

import java.util
import javax.script.ScriptException
import scala.jdk.CollectionConverters._

/**
 * Pre-execution check for Jython script source.
 *
 * A call that returns normally means the script was accepted. The ANTLR-based implementation
 * in this file signals a rejected script by throwing `javax.script.ScriptException`.
 */
trait JythonScriptValidator {

  /**
   * Validates `script` before it is executed.
   *
   * @param script   the Jython source text to inspect
   * @param encoding the source encoding handed to the parser
   */
  def validate(script: String, encoding: String): Unit
}

/**
 * Deny-lists that configure the Jython script sandbox.
 *
 * @param restrictedModules    module names that scripts may not import
 * @param restrictedFunctions  identifiers (e.g. builtin names) that scripts may not reference
 * @param restrictedAttributes attribute names that scripts may not access via `obj.attr`
 */
case class JythonSandboxConfiguration(
  restrictedModules: java.util.List[String],
  restrictedFunctions: java.util.List[String],
  restrictedAttributes: java.util.List[String]
)

/**
 * [[JythonScriptValidator]] that parses the script with Jython's ANTLR `AnalyzingParser` and walks
 * the resulting tree with a [[JythonANTLRSecurityVisitor]] built from `jythonSandbox`.
 *
 * Blank scripts are accepted without being parsed. Any exception raised by the visitor (for a
 * sandbox violation, a `javax.script.ScriptException`) propagates to the caller unchanged.
 *
 * @param jythonSandbox the deny-lists the visitor enforces
 */
class JythonANTLRScriptValidator(jythonSandbox: JythonSandboxConfiguration) extends JythonScriptValidator {

  def validate(script: String, encoding: String): Unit =
    if (isNotBlank(script)) {
      // NOTE(review): AnalyzingParser is used with an empty file name; any syntax errors it records
      // are not inspected here — confirm whether rejecting unparsable scripts is expected.
      val tree = new AnalyzingParser(new ANTLRStringStream(script), "", encoding).parseModule
      val visitor = new JythonANTLRSecurityVisitor(jythonSandbox)

      // Only the top-level nodes are handed to the visitor; descending further is the visitor's job.
      for {
        statements <- Option(tree.getChildren)
        statement <- statements.asScala
      } visitor.visit(statement)
    }
}


/**
 * Jython AST visitor that rejects scripts using sandbox-restricted features.
 *
 * Each check throws `javax.script.ScriptException` on the first violation it meets:
 *  - a `Name` whose text is listed in `restrictedFunctions`;
 *  - an `import` or `from ... import` of a module listed in `restrictedModules`. Parent packages
 *    count too, so `import os.path` (which also binds `os`) is rejected when `os` is restricted.
 *    For `from m import x` the qualified name `m.x` is checked as well, so listing
 *    `java.lang.Runtime` also blocks `from java.lang import Runtime`;
 *  - an attribute access whose attribute name is listed in `restrictedAttributes`.
 *
 * `traverse` is overridden to delegate to the node's own `traverse`. The inherited `visitX`
 * methods are presumably what descend into children through it (VisitorBase behaviour — not
 * visible in this file).
 *
 * @param jythonSandbox the deny-lists to enforce
 */
class JythonANTLRSecurityVisitor(jythonSandbox: JythonSandboxConfiguration) extends VisitorBase[Object] {

  override def visitName(node: Name): Object = {
    Option(node.getText)
      .filter(name => jythonSandbox.restrictedFunctions.contains(name))
      .foreach(name => raiseException(s"Using the builtin $name statement is not allowed."))
    super.visitName(node)
  }

  override def visitImport(node: Import): Object = {
    importedNames(node.getNames).foreach(rejectRestrictedModule)
    super.visitImport(node)
  }

  /**
   * Applies the module deny-list to `from m import a, b`.
   *
   * Without this override, `from os import system` would bypass `restrictedModules` entirely.
   */
  override def visitImportFrom(node: ImportFrom): Object = {
    // The module name may be null (presumably for relative imports such as `from . import x`).
    val moduleName = Option(node.getInternalModule).getOrElse("")
    if (moduleName.nonEmpty) rejectRestrictedModule(moduleName)
    importedNames(node.getNames).foreach { name =>
      rejectRestrictedModule(if (moduleName.isEmpty) name else s"$moduleName.$name")
    }
    super.visitImportFrom(node)
  }

  override def visitAttribute(node: Attribute): Object = {
    Option(node.getAttr)
      .map(_.toString)
      .filter(attr => jythonSandbox.restrictedAttributes.contains(attr))
      .foreach(attr => raiseException(s"Using the attribute $attr statement is not allowed."))
    super.visitAttribute(node)
  }

  /** Throws if `moduleName` itself or any of its parent packages is in `restrictedModules`. */
  private def rejectRestrictedModule(moduleName: String): Unit =
    if (moduleAndParents(moduleName).exists(m => jythonSandbox.restrictedModules.contains(m))) {
      raiseException(s"Using the $moduleName module is not allowed.")
    }

  /** Dotted-name prefixes, outermost first: "a.b.c" yields Seq("a", "a.b", "a.b.c"). */
  private def moduleAndParents(moduleName: String): Seq[String] = {
    val segments = moduleName.split('.').toSeq
    segments.indices.map(i => segments.take(i + 1).mkString("."))
  }

  /** Extracts the imported names from an import node's `names` list; non-alias entries are ignored. */
  private def importedNames(names: Object): Seq[String] = names match {
    case aliases: util.List[_] => aliases.asScala.toSeq.collect { case a: alias => a.getName.toString }
    case _ => Seq.empty
  }

  private def raiseException(message: String): Nothing =
    throw new ScriptException(message)

  @throws[Exception]
  override def traverse(node: PythonTree): Unit = node.traverse(this)

  @throws[Exception]
  def visit(nodes: Array[PythonTree]): Unit = nodes.foreach(visit)

  @throws[Exception]
  def visit(node: PythonTree): Object = node.accept(this)

  /** Node types without a dedicated `visitX` override need no check of their own. */
  @throws[Exception]
  override protected def unhandled_node(node: PythonTree): Object = this
}
