package com.xebialabs.xldeploy.jms.factory

import ai.digital.configuration.central.deploy.task.SslProperties
import ai.digital.configuration.central.deploy.{ExternalConfig, InProcessConfig}
import com.google.common.base.Strings
import com.xebialabs.xldeploy.jms.adapter.{UnsupportedDriverException, UnsupportedJmsConfigurationException}
import org.springframework.jms.connection.UserCredentialsConnectionFactoryAdapter
import org.springframework.stereotype.Component

import java.io.FileInputStream
import java.lang.reflect.{InvocationTargetException, Method}
import java.net.URI
import java.security.KeyStore
import jakarta.jms.ConnectionFactory
import javax.net.ssl.{KeyManager, KeyManagerFactory, SSLContext, TrustManager, TrustManagerFactory}
import scala.util.control.NonFatal

/**
 * Spring component that builds the JMS [[ConnectionFactory]] used for the task queue.
 *
 * Depending on `taskQueueBrokerConfig` and `inProcessTaskEngine` (not defined in this class;
 * presumably inherited from [[BaseJmsConnectionFactory]]) it either creates a factory for an
 * external broker from a reflectively loaded JMS driver, or delegates to
 * `EmbeddedBrokerConnectionFactory.initialize()`.
 */
@Component
class DefaultJmsConnectionFactory extends BaseJmsConnectionFactory() {

  /**
   * Builds the connection factory for the configured broker.
   *
   * For an [[ExternalConfig]] (only when the task engine is not in-process):
   *   - the class named by `jmsDriverClassname` is instantiated through its no-arg constructor;
   *   - SSL is configured when the URL scheme is `amqps` (case-insensitive);
   *   - the broker URL is set reflectively;
   *   - the result is wrapped with the configured username/password.
   *
   * @return the connection factory to use
   * @throws UnsupportedJmsConfigurationException if the configuration type does not match the
   *                                              in-process / external worker mode
   * @throws UnsupportedDriverException           if the driver exposes no known broker-URL setter
   */
  override protected def factory(): ConnectionFactory = taskQueueBrokerConfig match {
    case cfg: ExternalConfig if !inProcessTaskEngine =>
      val driverCF = Class
        .forName(cfg.jmsDriverClassname)
        .asSubclass(classOf[ConnectionFactory])
        .getDeclaredConstructor()
        .newInstance()
      // URI.getScheme returns null for scheme-less URLs; treat those as non-SSL instead of throwing an NPE.
      val useSsl = Option(new URI(cfg.url).getScheme).exists(_.equalsIgnoreCase("amqps"))
      val cfWithSsl = if (useSsl) setSslContext(driverCF, cfg.ssl) else driverCF
      val cfWithAddress = setBrokerUrl(cfWithSsl, cfg.url)
      setUserCredentials(cfWithAddress, cfg.username, cfg.password)
    case _: InProcessConfig if inProcessTaskEngine =>
      EmbeddedBrokerConnectionFactory.initialize()
    case _ =>
      throw UnsupportedJmsConfigurationException(
        "Unsupported configuration: deploy.task.in-process-worker should be set to 'false' to use external broker."
      )
  }

  /**
   * Sets the broker URL on `cf` using the first available `String` setter, `setUri` or `setBrokerURL`.
   *
   * @return the same `cf` instance, mutated
   * @throws UnsupportedDriverException if neither setter exists on the driver class
   */
  private def setBrokerUrl(cf: ConnectionFactory,
                           url: String): ConnectionFactory = {
    val setter = findUrlSetter(cf, Seq("setUri", "setBrokerURL")).getOrElse(
      throw UnsupportedDriverException(
        "Cannot set JMS broker URL. Driver does not provide known mechanics for broker URL configuration."
      )
    )
    setter.invoke(cf, url)
    cf
  }

  /**
   * Finds the first name in `methodNames` that the driver class exposes as a public
   * single-`String`-argument method.
   *
   * Lookup stops at the first match. Each missing name is logged at debug level.
   *
   * @throws UnsupportedJmsConfigurationException if reflective lookup is denied by a SecurityException
   */
  private def findUrlSetter(cf: ConnectionFactory,
                            methodNames: Seq[String]): Option[Method] =
    methodNames.iterator
      .map { name =>
        try Some(cf.getClass.getMethod(name, classOf[String]))
        catch {
          case e: SecurityException =>
            logger.error(e)
            throw UnsupportedJmsConfigurationException(e.getMessage)
          case _: NoSuchMethodException =>
            logger.debug(s"JMS driver does not support setting broker URL using [$name].")
            None
        }
      }
      .collectFirst { case Some(method) => method }

  /** Wraps `cf` in a Spring adapter that applies the given username and password to every connection. */
  private def setUserCredentials(cf: ConnectionFactory,
                                 username: String,
                                 password: String): ConnectionFactory = {
    val tcf = new UserCredentialsConnectionFactoryAdapter()
    tcf.setUsername(username)
    tcf.setPassword(password)
    tcf.setTargetConnectionFactory(cf)
    tcf
  }

  /**
   * Configures SSL on `cf`.
   *
   * Uses a custom SSL context built from the keystore/truststore when all of the SSL properties are
   * set. Otherwise (including when `ssl` is null) it falls back to the driver's default SSL context.
   */
  private def setSslContext(cf: ConnectionFactory, ssl: SslProperties): ConnectionFactory =
    if (Option(ssl).exists(checkSslProperties)) setCustomSslContext(cf, ssl) else setDefaultSslContext(cf)

  /** Calls `useDefaultSslContext(true)` on the driver's connection factory and returns `cf`. */
  private def setDefaultSslContext(cf: ConnectionFactory): ConnectionFactory = {
    logger.warn("SSL keystore/trustore is not configured in deploy-task.yaml. Using default SSL context")
    invokeDriverMethod(
      cf,
      "useDefaultSslContext",
      classOf[Boolean],
      java.lang.Boolean.TRUE,
      s"JMS driver does not support setting useDefaultSslContext."
    )
    cf
  }

  /**
   * Builds an [[SSLContext]] for `ssl.protocol` from the configured keystore/truststore and passes it
   * to the driver's `useSslProtocol(SSLContext)`.
   *
   * @return the same `cf` instance, mutated
   */
  private def setCustomSslContext(cf: ConnectionFactory, ssl: SslProperties): ConnectionFactory = {
    val keyManagers = createKeyManagers(ssl)
    val trustManagers = createTrustManagers(ssl)
    val context = SSLContext.getInstance(ssl.protocol)
    // A null SecureRandom makes SSLContext.init use the default source of randomness.
    context.init(keyManagers, trustManagers, null)
    invokeDriverMethod(
      cf,
      "useSslProtocol",
      classOf[SSLContext],
      context,
      s"JMS driver does not support setting SSL protocol using useSslProtocol"
    )
    cf
  }

  /**
   * Reflectively invokes the public single-argument method `methodName(paramType)` on `cf` with `arg`.
   *
   * Only non-fatal errors are translated; fatal errors (e.g. VirtualMachineError, InterruptedException)
   * propagate unchanged.
   *
   * @param unsupportedMessage message used when the method is missing or cannot be invoked
   * @throws UnsupportedJmsConfigurationException if lookup is denied, the method is missing or
   *                                              inaccessible, or the driver method itself throws
   */
  private def invokeDriverMethod(cf: ConnectionFactory,
                                 methodName: String,
                                 paramType: Class[_],
                                 arg: AnyRef,
                                 unsupportedMessage: String): Unit =
    try {
      val method = cf.getClass.getMethod(methodName, paramType)
      method.invoke(cf, arg)
    } catch {
      case e: SecurityException =>
        logger.error(e)
        throw UnsupportedJmsConfigurationException(e.getMessage)
      case e: InvocationTargetException =>
        // The method exists but the driver rejected the call; report that instead of "does not support".
        val cause = Option(e.getCause).getOrElse(e)
        throw UnsupportedJmsConfigurationException(s"JMS driver failed to apply $methodName: ${cause.getMessage}")
      case NonFatal(_) =>
        throw UnsupportedJmsConfigurationException(unsupportedMessage)
    }

  /** Creates key managers from the keystore at `ssl.keyStore`, using `ssl.keyPassword` for the keys. */
  private def createKeyManagers(ssl: SslProperties): Array[KeyManager] = {
    val ks = loadKeyStore(ssl.keyStore, ssl.keyStorePassword)
    val kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm)
    kmf.init(ks, ssl.keyPassword.toCharArray)
    kmf.getKeyManagers
  }

  /** Creates trust managers from the truststore at `ssl.trustStore`. */
  private def createTrustManagers(ssl: SslProperties): Array[TrustManager] = {
    val tks = loadKeyStore(ssl.trustStore, ssl.trustStorePassword)
    val tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm)
    tmf.init(tks)
    tmf.getTrustManagers
  }

  /** Loads a keystore of the JVM default type from `path`, always closing the underlying file stream. */
  private def loadKeyStore(path: String, password: String): KeyStore = {
    val store = KeyStore.getInstance(KeyStore.getDefaultType)
    val in = new FileInputStream(path)
    try store.load(in, password.toCharArray)
    finally in.close()
    store
  }

  /** Returns true only when every keystore/truststore path and password property is non-null and non-empty. */
  private def checkSslProperties(ssl: SslProperties): Boolean =
    Seq(ssl.keyStore, ssl.keyPassword, ssl.keyStorePassword, ssl.trustStore, ssl.trustStorePassword)
      .forall(value => !Strings.isNullOrEmpty(value))
}
