package com.xebialabs.xlrelease.server.jetty

import com.xebialabs.deployit.ServerConfiguration
import com.xebialabs.deployit.plumbing.CSPFilter.POLICY_DIRECTIVES_PARAM
import com.xebialabs.deployit.plumbing._
import com.xebialabs.deployit.security.PermissionEnforcer
import com.xebialabs.xlplatform.endpoints.servlet.{PekkoStreamServlet, PekkoStreamServletInitializer}
import com.xebialabs.xlrelease.Environment
import com.xebialabs.xlrelease.config.XlrConfig
import com.xebialabs.xlrelease.metrics.XlrMetricRegistry
import com.xebialabs.xlrelease.security.HttpSessionListenerWithTimeout
import io.dropwizard.metrics.jetty11.InstrumentedHandler
import org.eclipse.jetty.http.HttpCookie.SameSite
import org.eclipse.jetty.security.ConstraintSecurityHandler
import org.eclipse.jetty.server.handler.RequestLogHandler
import org.eclipse.jetty.server.handler.gzip.GzipHandler
import org.eclipse.jetty.server.session.SessionHandler
import org.eclipse.jetty.servlet.{FilterHolder, ServletHolder}
import org.eclipse.jetty.servlets.CrossOriginFilter
import org.eclipse.jetty.webapp.WebAppContext
import org.springframework.boot.web.embedded.jetty.JettyServletWebServerFactory
import org.springframework.security.web.context.AbstractSecurityWebApplicationInitializer
import org.springframework.web.context.ContextLoader
import org.springframework.web.context.support.XmlWebApplicationContext
import org.springframework.web.filter.{CharacterEncodingFilter, DelegatingFilterProxy}
import jakarta.servlet.{DispatcherType, SessionTrackingMode}
import org.eclipse.jetty.util.compression.{CompressionPool, DeflaterPool}

import java.util
import scala.jdk.CollectionConverters._

/**
 * Spring Boot [[JettyServletWebServerFactory]] that customises the embedded Jetty web application
 * context for XL Release.
 *
 * Jetty's default servlet is not registered. Once Spring Boot has created the [[WebAppContext]],
 * [[postProcessWebAppContext]] installs the session and security handlers, the optional metrics,
 * gzip and static-resource caching support, the request log, the extension API servlet and the
 * servlet filter chain.
 *
 * @param serverConfiguration server settings (key/trust stores, SSL and secure-cookie flags)
 * @param xlrConfig           XL Release configuration that switches the optional features on/off
 * @param permissionEnforcer  handed to the maintenance-mode access control filter
 */
class ReleaseJettyServletWebServerFactory(val serverConfiguration: ServerConfiguration,
                                          val xlrConfig: XlrConfig,
                                          val permissionEnforcer: PermissionEnforcer)
  extends JettyServletWebServerFactory {

  /** Always `false`: Jetty's default servlet is not registered for this application. */
  override def isRegisterDefaultServlet: Boolean = false

  /**
   * Applies all XL Release customisations to the freshly created web application context.
   *
   * Filters matching the same request run in the order their mappings were registered, so the
   * order of the calls below (and of the registrations inside them) defines the filter chain.
   */
  override protected def postProcessWebAppContext(contextRoot: WebAppContext): Unit = {
    setupJetty(contextRoot)
    setupLogbackFilter(contextRoot)
    setupSpring(contextRoot)
    setupAccessLog(contextRoot)
    setupCorsFilter(contextRoot)
    setupXssSecurityHeadersFilter(contextRoot)
    setupUrlFilter(contextRoot)
    setupCSPFilter(contextRoot)
    setupSlowdownFilter(contextRoot)
    setupExtensionApi(contextRoot)
    setupVersionEnforcer(contextRoot)
    setupRequestLocal(contextRoot)
    setupMaintenanceModeAccessControlFilter(contextRoot)
    setupMultipartFilter(contextRoot)
  }

  /** Dispatcher set used by every mapping in this class; a fresh set is built on each call. */
  private def requestOnly: util.EnumSet[DispatcherType] = util.EnumSet.of(DispatcherType.REQUEST)

  /**
   * Installs the session handler and a constraint security handler, then optionally adds the
   * metrics handler, gzip compression and (in production only) static-resource caching.
   */
  private def setupJetty(contextRoot: WebAppContext): Unit = {
    contextRoot.setSessionHandler(newSessionHandler)
    contextRoot.setSecurityHandler(new ConstraintSecurityHandler)

    if (xlrConfig.metrics.enabled) {
      enableInstrumentedHandler(contextRoot)
    }
    if (xlrConfig.server_http_gzip_enabled) {
      enableGzipHandler(contextRoot)
    }
    if (xlrConfig.server_http_cache_enabled && Environment.isProduction) {
      enableResourcesCache(contextRoot)
    }
  }

  /** Registers the async-capable [[LogbackFilter]] for every request path. */
  private def setupLogbackFilter(contextRoot: WebAppContext): Unit = {
    contextRoot.addFilter(classOf[LogbackFilter], "/*", requestOnly).setAsyncSupported(true)
  }

  /**
   * Prepares the context for Spring.
   *
   * This method:
   *  - exposes the configured key/trust stores as JVM system properties (never overriding ones
   *    that are already set);
   *  - selects [[XmlWebApplicationContext]] as the root context class;
   *  - registers, in this order, the URL-rewrite filter, the Spring Security delegating filter
   *    proxy and a forced UTF-8 character encoding filter;
   *  - sets the resource base to the working directory.
   *
   * The session timeout listener is only added when session storage is disabled.
   */
  private def setupSpring(contextRoot: WebAppContext): Unit = {
    //Set system properties Spring Security uses (via the JRE) for LDAPS
    setPropertyIfNull(SSLConstants.KEYSTORE_PROPERTY, serverConfiguration.getKeyStorePath)
    setPropertyIfNull(SSLConstants.KEYSTORE_PASSWORD_PROPERTY, serverConfiguration.getKeyStorePassword)
    setPropertyIfNull(SSLConstants.TRUSTSTORE_PROPERTY, serverConfiguration.getTrustStorePath)
    setPropertyIfNull(SSLConstants.TRUSTSTORE_PASSWORD_PROPERTY, serverConfiguration.getTrustStorePassword)

    contextRoot.setInitParameter(ContextLoader.CONTEXT_CLASS_PARAM, classOf[XmlWebApplicationContext].getName)
    contextRoot.addFilter(classOf[UrlRedirectRewriteFilter], "/*", requestOnly)

    // The filter name must match Spring Security's default filter-chain bean name so the proxy
    // can look it up.
    val securityFilter = new FilterHolder(classOf[DelegatingFilterProxy])
    securityFilter.setName(AbstractSecurityWebApplicationInitializer.DEFAULT_FILTER_NAME)
    securityFilter.setAsyncSupported(true)
    contextRoot.addFilter(securityFilter, "/*", requestOnly)

    val encodingFilter = new FilterHolder(classOf[CharacterEncodingFilter])
    encodingFilter.setName("encodingFilter")
    encodingFilter.setInitParameter("encoding", "UTF-8")
    encodingFilter.setInitParameter("forceEncoding", "true")
    encodingFilter.setAsyncSupported(true)
    contextRoot.addFilter(encodingFilter, "/*", requestOnly)

    contextRoot.setResourceBase(".")

    if (!xlrConfig.server_session_storage_enabled) {
      contextRoot.addEventListener(new HttpSessionListenerWithTimeout)
    }
  }

  /** Inserts a request-log handler backed by [[XLRequestLogImpl]] using `conf/logback-access.xml`. */
  private def setupAccessLog(contextRoot: WebAppContext): Unit = {
    val requestLogHandler = new RequestLogHandler
    requestLogHandler.setRequestLog(new XLRequestLogImpl("conf/logback-access.xml"))
    contextRoot.insertHandler(requestLogHandler)
  }

  /**
   * When API CORS support is enabled, registers Jetty's [[CrossOriginFilter]] on `/api/v1/ *`.
   *
   * The filter is configured through its `*_PARAM` init-parameter names. The
   * `ACCESS_CONTROL_*_HEADER` constants name response headers; the filter does not read
   * init parameters keyed by them, so they must not be used here.
   */
  private def setupCorsFilter(contextRoot: WebAppContext): Unit = {
    if (xlrConfig.api_corsEnabled) {
      val holder = new FilterHolder(new CrossOriginFilter)
      holder.setInitParameter(CrossOriginFilter.ALLOWED_ORIGINS_PARAM, "*")
      holder.setInitParameter(CrossOriginFilter.ALLOWED_METHODS_PARAM, "POST, GET, OPTIONS, DELETE, PUT")
      holder.setInitParameter(CrossOriginFilter.ALLOWED_HEADERS_PARAM, "x-requested-with, accept, origin, content-type")
      // Preflight cache lifetime in seconds.
      holder.setInitParameter(CrossOriginFilter.PREFLIGHT_MAX_AGE_PARAM, "3600")
      holder.setAsyncSupported(true)
      contextRoot.addFilter(holder, "/api/v1/*", requestOnly)
    }
  }

  /** Registers the async-capable [[XssSecurityHeadersFilter]] for every request path. */
  private def setupXssSecurityHeadersFilter(contextRoot: WebAppContext): Unit = {
    val holder = new FilterHolder(new XssSecurityHeadersFilter)
    holder.setAsyncSupported(true)
    contextRoot.addFilter(holder, "/*", requestOnly)
  }

  /** Registers [[IgnoreUrlFilter]] for every request path. */
  private def setupUrlFilter(contextRoot: WebAppContext): Unit = {
    // NOTE(review): unlike most filters here this holder is not marked async-supported. Jetty
    // disables async processing for a request whose chain contains a non-async filter, so
    // confirm this is intended.
    val holder = new FilterHolder(new IgnoreUrlFilter)
    contextRoot.addFilter(holder, "/*", requestOnly)
  }

  /**
   * Development-only: when a positive REST slow-down delay is configured, registers
   * [[SlowdownFilter]] with that delay (in milliseconds) for every request path.
   */
  private def setupSlowdownFilter(contextRoot: WebAppContext): Unit = {
    if (Environment.isDevelopment) {
      val slowDownInMs = xlrConfig.development_restSlowDownDelay.toMillis
      if (slowDownInMs > 0) {
        logger.warn(s"Configuring slow down filter with $slowDownInMs milliseconds.")
        val slowdownFilter = new FilterHolder(classOf[SlowdownFilter])
        slowdownFilter.setInitParameter(SlowdownFilter.MS, slowDownInMs.toString)
        slowdownFilter.setAsyncSupported(true)
        contextRoot.addFilter(slowdownFilter, "/*", requestOnly)
      }
    }
  }

  /**
   * When CSP is enabled, registers [[CSPFilter]] for every request path, passing the configured
   * policy directives as an init parameter.
   */
  private def setupCSPFilter(contextRoot: WebAppContext): Unit = {
    if (xlrConfig.server_http_csp_enabled) {
      val holder = new FilterHolder(new CSPFilter)
      holder.setAsyncSupported(true)
      holder.setInitParameter(POLICY_DIRECTIVES_PARAM, xlrConfig.server_http_csp_policyDirectives)
      contextRoot.addFilter(holder, "/*", requestOnly)
    }
  }

  /**
   * Mounts the async [[PekkoStreamServlet]] under `<serverExtension_rootPath>/ *` and registers
   * its [[PekkoStreamServletInitializer]] listener.
   */
  private def setupExtensionApi(contextRoot: WebAppContext): Unit = {
    val apiExtensionContext = s"${xlrConfig.serverExtension_rootPath}/*"
    val extensionApiServletHolder = new ServletHolder(classOf[PekkoStreamServlet])
    extensionApiServletHolder.setDisplayName("ExtensionApiConnectorServlet")
    extensionApiServletHolder.setAsyncSupported(true)
    contextRoot.addEventListener(new PekkoStreamServletInitializer())
    contextRoot.addServlet(extensionApiServletHolder, apiExtensionContext)
  }

  /** Outside development, registers the async-capable [[VersionEnforcerFilter]] for every path. */
  private def setupVersionEnforcer(contextRoot: WebAppContext): Unit = {
    if (!Environment.isDevelopment) {
      contextRoot.addFilter(classOf[VersionEnforcerFilter], "/*", requestOnly).setAsyncSupported(true)
    }
  }

  /**
   * Registers the async-capable [[RequestLocal]] filter for every request path.
   *
   * `/ *` (without the space) is the servlet-spec pattern for "all paths"; a bare `*` is not a
   * valid url-pattern.
   */
  private def setupRequestLocal(contextRoot: WebAppContext): Unit = {
    contextRoot.addFilter(classOf[RequestLocal], "/*", requestOnly).setAsyncSupported(true)
  }

  /**
   * Registers one shared [[ModifiedResourceFilter]] on the static assets and the UI-extension
   * bundles. The filter receives the time this context was configured as its `startTimestamp`
   * init parameter.
   */
  private def enableResourcesCache(contextRoot: WebAppContext): Unit = {
    val holder = new FilterHolder(new ModifiedResourceFilter)
    holder.setInitParameter("startTimestamp", System.currentTimeMillis.toString)
    holder.setAsyncSupported(true)
    val cachedPaths = Seq(
      "/static/*",
      "/ui-extensions/xlrelease-module.js",
      "/ui-extensions/xlrelease-plugins.js"
    )
    cachedPaths.foreach(path => contextRoot.addFilter(holder, path, requestOnly))
  }

  /**
   * Inserts a [[GzipHandler]] and configures it from `xlrConfig`: minimum size, excluded paths
   * and compression level. It compresses GET/POST/PUT responses with sync-flush enabled.
   */
  private def enableGzipHandler(contextRoot: WebAppContext): Unit = {
    val gzipHandler = new GzipHandler
    gzipHandler.setSyncFlush(true)
    gzipHandler.setIncludedMethods("GET", "POST", "PUT")
    gzipHandler.setMinGzipSize(xlrConfig.server_http_gzip_minSize.intValue)
    gzipHandler.setExcludedPaths(xlrConfig.server_http_gzip_excludedPaths.asScala.toSeq: _*)
    gzipHandler.setDeflaterPool(new DeflaterPool(CompressionPool.DEFAULT_CAPACITY, xlrConfig.server_http_gzip_compression, true))
    contextRoot.insertHandler(gzipHandler)
  }

  /**
   * Builds the session handler.
   *
   * The handler tracks sessions via an HttpOnly cookie only and applies the configured SameSite
   * attribute. When SSL is not enabled on this server, the cookie's Secure flag is taken from
   * `serverConfiguration.isSecureCookieEnabled`.
   */
  private def newSessionHandler: SessionHandler = {
    val sessionHandler = new SessionHandler
    sessionHandler.setSessionTrackingModes(util.EnumSet.of(SessionTrackingMode.COOKIE))
    sessionHandler.setHttpOnly(true)
    if (!serverConfiguration.isSsl) {
      logger.debug(s"Setting up Secure Cookie Enabled to - ${serverConfiguration.isSecureCookieEnabled}")
      sessionHandler.getSessionCookieConfig.setSecure(serverConfiguration.isSecureCookieEnabled)
    }
    // Locale.ROOT: default-locale upper-casing is locale-sensitive (e.g. Turkish maps "i" to "İ",
    // turning "Strict" into "STRİCT"), which would make SameSite.valueOf throw.
    val sameSiteName = xlrConfig.server_http_cookie_sameSite.getAttributeValue.toUpperCase(util.Locale.ROOT)
    sessionHandler.setSameSite(SameSite.valueOf(sameSiteName))
    sessionHandler
  }

  /** Inserts a Dropwizard [[InstrumentedHandler]] reporting to the XL Release metric registry. */
  private def enableInstrumentedHandler(contextRoot: WebAppContext): Unit = {
    contextRoot.insertHandler(new InstrumentedHandler(XlrMetricRegistry.metricRegistry))
  }

  /**
   * Sets JVM system property `key` to `value` unless `value` is null or the property already has
   * a value; explicitly specified properties always win.
   */
  private def setPropertyIfNull(key: String, value: String): Unit =
    Option(value).foreach { newValue =>
      if (System.getProperty(key) == null) {
        if (logger.isDebugEnabled) {
          logger.debug(s"Setting system property - $key")
        }
        System.setProperty(key, newValue)
      } else if (logger.isDebugEnabled) {
        logger.debug(s"Not overriding already specified system property - $key")
      }
    }

  /**
   * When maintenance mode is enabled, registers [[MaintenanceModeAccessControlFilter]] for every
   * request path, configured with the permission enforcer and the API-restriction flag.
   */
  private def setupMaintenanceModeAccessControlFilter(contextRoot: WebAppContext): Unit = {
    if (xlrConfig.maintenanceModeEnabled) {
      val restrictApiAccess = xlrConfig.maintenanceModeRestrictApiAccess
      val filterHolder = new FilterHolder(new MaintenanceModeAccessControlFilter(permissionEnforcer, restrictApiAccess))
      contextRoot.addFilter(filterHolder, "/*", requestOnly)
    }
  }

  /** Registers [[MultipartFilter]] (named `multipartFilter`) for every request path. */
  private def setupMultipartFilter(contextRoot: WebAppContext): Unit = {
    val multipartFilter = new FilterHolder(classOf[MultipartFilter])
    multipartFilter.setName("multipartFilter")
    contextRoot.addFilter(multipartFilter, "/*", requestOnly)
  }
}
