package com.xebialabs.xlrelease.plugin.classloading

import com.xebialabs.plugin.manager.PluginId
import com.xebialabs.plugin.manager.config.ConfigWrapper
import com.xebialabs.plugin.manager.startup.PluginSynchronizer
import com.xebialabs.plugin.manager.util.PluginFileUtils
import com.xebialabs.plugin.zip.PluginScanner
import com.xebialabs.xlplatform.utils.PerformanceLogging
import com.xebialabs.xlrelease.plugin
import com.xebialabs.xlrelease.plugin.{JarPlugin, Plugin, ZipPlugin}
import de.schlichtherle.truezip.file.{TArchiveDetector, TConfig}
import de.schlichtherle.truezip.fs.archive.zip.JarDriver
import de.schlichtherle.truezip.socket.sl.IOPoolLocator
import org.slf4j.LoggerFactory
import org.springframework.util.StreamUtils

import java.io.{File, IOException, InputStream}
import java.net.URL
import java.nio.file.{Files, Path, Paths}
import java.util
import java.util.concurrent.CopyOnWriteArrayList
import java.util.concurrent.locks.ReentrantReadWriteLock
import scala.jdk.CollectionConverters._

object XlrPluginClassLoader {
  val PLUGINS_LOCAL_FOLDER: File = new File("plugins/__local__")
  val PLUGINS_OFFICIAL_FOLDER: File = new File("plugins/xlr-official")
  val HOTFIX_PLUGINS_FOLDER: File = new File("hotfix/plugins")

  // Dedicated logger named "hotfix"; the class loader warns through it when it serves a URL containing "hotfix".
  private[XlrPluginClassLoader] val hotfixLogger = LoggerFactory.getLogger("hotfix")

  /** Extension method that turns platform-specific separators in a path string into '/'. */
  implicit class NormalizePath(val path: String) extends AnyVal {
    def toUnixPath: String = path.replace(File.separatorChar, '/')
  }

  /**
   * Singleton loader, built on first access.
   *
   * It first registers TrueZIP's jar driver globally, then makes sure every plugin folder
   * exists, and finally constructs the loader (which scans those folders) with this
   * object's class loader as parent.
   */
  private lazy val xlrPluginClassLoader = {
    TConfig.get().setArchiveDetector(new TArchiveDetector(TArchiveDetector.ALL, "jar", new JarDriver(IOPoolLocator.SINGLETON)))
    val folders = resolvePluginDirectories()
    folders.foreach(createDirectoryOrFail)
    new XlrPluginClassLoader(folders, getClass.getClassLoader)
  }

  /**
   * Returns the hotfix, local and official plugin folders, in that order.
   *
   * Each folder is placed under `xl.base.dir` (falling back to `user.dir`), and the resulting
   * path is expressed relative to `user.dir`.
   */
  private def resolvePluginDirectories(): List[File] = {
    val workingDirectory: String = System.getProperty("user.dir")
    val baseDirectory: String = System.getProperty("xl.base.dir", workingDirectory)
    val basePrefix: Path = Paths.get(workingDirectory).toAbsolutePath.relativize(Paths.get(baseDirectory).toAbsolutePath)
    for (folder <- List(HOTFIX_PLUGINS_FOLDER, PLUGINS_LOCAL_FOLDER, PLUGINS_OFFICIAL_FOLDER))
      yield basePrefix.resolve(folder.toPath).toFile
  }

  /** Creates `dir` and any missing parents; an IOException is rethrown wrapped in a RuntimeException. */
  private def createDirectoryOrFail(dir: File): Unit =
    try {
      Files.createDirectories(dir.toPath)
      ()
    } catch {
      case e: IOException => throw new RuntimeException(e)
    }

  /** Returns the shared plugin class loader, creating it on first call. */
  def apply(): XlrPluginClassLoader = xlrPluginClassLoader
}


/**
 * Class loader that serves classes and resources from the `.jar` and `.zip` plugins found in
 * `pluginDirectories`.
 *
 * Only `findClass`/`findResource(s)` are overridden, so the standard parent-first delegation of
 * `ClassLoader` applies. Plugins are searched in load order (directory order, then jars before
 * zips), and the first match wins for single lookups.
 *
 * Lookups take the read lock; operations that close, replace or reload plugins take the write lock.
 */
class XlrPluginClassLoader private(pluginDirectories: Iterable[File], parentClassLoader: ClassLoader)
  extends ClassLoader(parentClassLoader) with PluginScanner with PerformanceLogging {

  private val plugins: CopyOnWriteArrayList[Plugin] = new CopyOnWriteArrayList()

  private val lock = new ReentrantReadWriteLock()
  private val writeLock = lock.writeLock()
  private val readLock = lock.readLock()

  loadPlugins(plugins)

  /**
   * Defines class `name` from the first plugin that contains its `.class` file.
   * A warning is logged on the hotfix logger if that file comes from a hotfix location.
   *
   * @throws ClassNotFoundException if no loaded plugin provides the class
   */
  override def findClass(name: String): Class[_] = logWithTime(s"Loading class $name") {
    findResourceUrl(convertClassName(name))
      .map(classUrl => loadClassFromUrl(name, logHotfix(classUrl)))
      .getOrElse(throw new ClassNotFoundException(missingClassMessage(name)))
  }

  /** Returns the first plugin resource named `name`, or null if none exists; hotfix hits are logged. */
  override def findResource(name: String): URL = logWithTime(s"Loading resource $name")(logHotfix(findResourceUrl(name).orNull))

  /** Returns every plugin resource named `name`, in plugin order. */
  override def findResources(name: String): util.Enumeration[URL] = logWithTime(s"Loading resources $name")({
    resourcesByName(name).map(u => {
      logger.trace(s"Found $u for $name")
      u
    }).iterator.asJavaEnumeration
  })

  /**
   * Closes all plugins, lets `pluginSynchronizer` synchronize the plugin folders, and reloads them.
   * Expected to run once at startup; the write lock still keeps concurrent lookups out while
   * plugins are swapped.
   */
  def syncPlugins(pluginSynchronizer: PluginSynchronizer): Unit = {
    withWriteLock {
      closePlugins()
      pluginSynchronizer.syncPlugins()
      loadPlugins(plugins)
    }
  }

  /** Closes all plugins, runs `pluginSynchronizer.verifyPlugins()`, and reloads them under the write lock. */
  def verifyPlugins(pluginSynchronizer: PluginSynchronizer): Unit = {
    withWriteLock {
      closePlugins()
      pluginSynchronizer.verifyPlugins()
      loadPlugins(plugins)
    }
  }

  /**
   * Returns an immutable snapshot of the currently loaded plugins.
   * Later reloads do not change the returned collection.
   */
  def getPlugins(): Iterable[Plugin] = {
    withReadLock(plugins.asScala.toList)
  }

  /** Closes and removes every loaded plugin whose name equals `pluginId.name`. */
  def closePlugin(pluginId: PluginId): Unit = {
    withWriteLock {
      val oldPlugins = plugins.asScala.filter(plugin => plugin.name() == pluginId.name)
      oldPlugins.foreach(_.close())
      plugins.removeAll(oldPlugins.asJava)
    }
  }

  /** Runs `block` while holding the (reentrant) write lock. */
  def withWriteLock[R](block: => R): R = {
    // Acquire outside `try`: if acquisition fails, we must not unlock a lock we never held.
    writeLock.lock()
    try block
    finally writeLock.unlock()
  }

  /** Runs `block` while holding the (reentrant) read lock. */
  def withReadLock[R](block: => R): R = {
    readLock.lock()
    try block
    finally readLock.unlock()
  }

  /**
   * Replaces the plugin identified by `pluginId` with a new plugin built from `pluginFile`.
   *
   * @throws IllegalArgumentException if `pluginFile` is neither a zip nor a jar plugin
   */
  def reload(pluginId: PluginId, pluginFile: File): Unit = {
    withWriteLock {
      // Build the replacement first so an unsupported file type leaves the existing plugin untouched.
      val pluginExtension = PluginFileUtils.getPluginExtension(pluginFile.getName)
      val newPlugin = pluginExtension match {
        case ConfigWrapper.EXTENSION_ZIP => plugin.ZipPlugin(pluginFile)
        case ConfigWrapper.EXTENSION_JAR => plugin.JarPlugin(pluginFile)
        case _ => throw new IllegalArgumentException(s"Unhandled plugin type for ${pluginId.id()}")
      }
      // closePlugin re-acquires the write lock; ReentrantReadWriteLock allows this.
      closePlugin(pluginId)
      plugins.add(newPlugin)
    }
  }

  /** Closes every plugin and rescans all plugin directories. */
  def reload(): Unit = {
    withWriteLock {
      closePlugins()
      loadPlugins(plugins)
    }
  }

  /** Warns on the hotfix logger when `url` points into a hotfix location; returns `url` unchanged (may be null). */
  private def logHotfix(url: URL): URL = {
    if (url != null && url.toString.contains("hotfix")) {
      XlrPluginClassLoader.hotfixLogger.warn(s"Loading class/resource from hotfix: $url")
    }
    url
  }

  /** Builds the single-line message for a class that no plugin provides. */
  private def missingClassMessage(name: String): String =
    // Lines are joined with a space so that sentences do not run into each other.
    s"""A plugin could not be loaded due to a missing class ($name). Please remove the offending plugin to successfully start the server.
       |Classes related to JCR were removed from the server because of the migration from JCR to SQL.
       |If the plugin depends on these classes and its functionality is required, please contact support to fix your configuration.
       |$name not found""".stripMargin.replaceAll("\n", " ")

  /**
   * Reads the bytes at `resourceUrl` and defines and resolves class `className` from them.
   *
   * @throws ClassFormatError if the stream is empty
   */
  private def loadClassFromUrl(className: String, resourceUrl: URL): Class[_] = {
    logger.trace(s"Loading class from url $resourceUrl")
    import com.xebialabs.xlplatform.utils.ResourceManagement._
    using(resourceUrl.openStream()) { classInputStream =>
      val bytes: Array[Byte] = readFully(classInputStream)
      if (bytes.isEmpty) {
        throw new ClassFormatError("Could not load class. Empty stream returned")
      }
      definePackageIfNeeded(className)
      val clazz = defineClass(className, bytes, 0, bytes.length)
      resolveClass(clazz)
      clazz
    }
  }

  /**
   * Defines the package of `className` with this loader, unless it is already defined.
   * Classes in the default package are skipped.
   */
  private def definePackageIfNeeded(className: String): Unit = {
    val lastDot = className.lastIndexOf('.')
    if (lastDot > 0) {
      val packageName = className.substring(0, lastDot)
      if (getDefinedPackage(packageName) == null) {
        try definePackage(packageName, null, null, null, null, null, null, null)
        catch {
          // definePackage throws IAE when the package is already defined, i.e. another
          // caller defined it between our check and this call; that outcome is what we want.
          case _: IllegalArgumentException =>
        }
      }
    }
  }

  private def findResourceUrl(name: String): Option[URL] = resourceByName(name)

  /** Maps a binary class name such as `a.b.C` to its resource path `a/b/C.class`. */
  private def convertClassName(className: String) = className.replace('.', '/').concat(".class")

  /** Returns the first match in plugin order, without querying the plugins after the first hit. */
  private def resourceByName(resourcePath: String): Option[URL] = {
    withReadLock(plugins.asScala.iterator.flatMap(_.getResources(resourcePath)).nextOption())
  }

  /** Returns all matches across all plugins, in plugin order. */
  private def resourcesByName(resourcePath: String): Seq[URL] = {
    withReadLock(plugins.asScala.flatMap(_.getResources(resourcePath)).toSeq)
  }

  private def readFully(is: InputStream): Array[Byte] = {
    StreamUtils.copyToByteArray(is)
  }

  /** Closes and forgets every plugin; all callers invoke this while holding the write lock. */
  private def closePlugins(): Unit = {
    plugins.asScala.foreach(_.close())
    plugins.clear()
  }

  private def getJarPlugins(dir: File): Array[JarPlugin] = {
    findAllPluginFiles(dir, ConfigWrapper.EXTENSION_JAR).map(JarPlugin)
  }

  private def getZipPlugins(dir: File): Array[ZipPlugin] = {
    findAllPluginFiles(dir, ConfigWrapper.EXTENSION_ZIP).map(ZipPlugin)
  }

  /** Appends the jar and then zip plugins of each directory, in directory order, to `pluginList`. */
  private def loadPlugins(pluginList: CopyOnWriteArrayList[Plugin]): Unit = {
    val normalPlugins = pluginDirectories.flatMap { dir =>
      getJarPlugins(dir) ++ getZipPlugins(dir)
    }
    val allPlugins = normalPlugins.toList
    pluginList.addAll(allPlugins.asJava)
  }

}
