package com.xebialabs.xlrelease.server.jetty;

import java.lang.management.ManagementFactory;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
import org.eclipse.jetty.alpn.server.ALPNServerConnectionFactory;
import org.eclipse.jetty.http2.server.HTTP2CServerConnectionFactory;
import org.eclipse.jetty.http2.server.HTTP2ServerConnectionFactory;
import org.eclipse.jetty.jmx.MBeanContainer;
import org.eclipse.jetty.server.*;
import org.eclipse.jetty.util.ssl.SslContextFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.web.embedded.jetty.JettyServerCustomizer;
import com.codahale.metrics.Timer;

import com.xebialabs.deployit.ServerConfiguration;
import com.xebialabs.xlrelease.config.XlrConfig;
import com.xebialabs.xlrelease.metrics.XlrMetricRegistry;

import io.dropwizard.metrics.jetty11.InstrumentedConnectionFactory;

import static com.google.common.base.Strings.nullToEmpty;
import static com.xebialabs.deployit.ServerConfiguration.DEFAULT_IDLE_TIMEOUT;

/**
 * Builds the {@link JettyServerCustomizer} used to configure Release's embedded Jetty server.
 * <p>
 * The customizer installs a single {@link ServerConnector}:
 * <ul>
 *   <li>plain HTTP or TLS,</li>
 *   <li>optionally with HTTP/2,</li>
 *   <li>optionally wrapped with Dropwizard connection metrics.</li>
 * </ul>
 * It also registers Jetty's components with JMX.
 * <p>
 * This is a static utility class and cannot be instantiated.
 */
public class ReleaseJettyServerCustomizer {

    private static final Logger logger = LoggerFactory.getLogger(ReleaseJettyServerCustomizer.class);

    /** Name of the shared metrics timer handed to every {@link InstrumentedConnectionFactory} wrapper. */
    private static final String CONNECTIONS_TIMER = "connections";

    private ReleaseJettyServerCustomizer() {
        // Utility class: only static members, never instantiated.
    }

    /**
     * Creates a customizer that replaces the server's connectors with one {@link ServerConnector}.
     * <ul>
     *   <li>If {@code serverConfiguration.isSsl()} is true, the connector speaks TLS. If
     *       {@code xlrConfig.server_http2_enabled()} is also true, TLS is followed by ALPN, which
     *       chooses between HTTP/2 and HTTP/1.1.</li>
     *   <li>Otherwise the connector speaks plain HTTP/1.1. When HTTP/2 is enabled it additionally
     *       offers cleartext HTTP/2 (h2c).</li>
     * </ul>
     * The connector binds to {@code getHttpBindAddress()}:{@code getHttpPort()} and uses
     * {@code DEFAULT_IDLE_TIMEOUT}. The customizer also sets the server stop timeout to 0 and
     * registers an {@link MBeanContainer} backed by the platform MBean server.
     *
     * @param serverConfiguration supplies bind address, port and TLS (key store / trust store) settings
     * @param xlrConfig           supplies the HTTP/2 and metrics switches
     * @return the customizer to register with Spring Boot's Jetty server factory
     */
    public static JettyServerCustomizer create(final ServerConfiguration serverConfiguration, final XlrConfig xlrConfig) {
        return jettyServer -> {
            ServerConnector connector;
            HttpConfiguration httpConfiguration = getHttpConfiguration();

            if (serverConfiguration.isSsl()) {
                SslContextFactory.Server sslContextFactory = getSslContextFactory(serverConfiguration);
                if (xlrConfig.server_http2_enabled()) {
                    ConnectionFactory http11 = createHttpConnectionFactory(httpConfiguration, xlrConfig, HttpConnectionFactoryType.HTTP11);
                    ConnectionFactory http2 = createHttpConnectionFactory(httpConfiguration, xlrConfig, HttpConnectionFactoryType.HTTP2);
                    ConnectionFactory alpn = createHttpConnectionFactory(httpConfiguration, xlrConfig, HttpConnectionFactoryType.ALPN);
                    // TLS hands the decrypted stream to ALPN, which then picks h2 or http/1.1.
                    ConnectionFactory ssl = createSslConnectionFactory(sslContextFactory, alpn.getProtocol(), xlrConfig);
                    connector = new ServerConnector(jettyServer, sslContextFactory, ssl, alpn, http2, http11);
                } else {
                    connector = new ServerConnector(jettyServer, sslContextFactory, createHttpConnectionFactory(httpConfiguration, xlrConfig, HttpConnectionFactoryType.HTTP11));
                }
            } else {
                List<ConnectionFactory> connectionFactories = new ArrayList<>();
                // HTTP/1.1 comes first, so it is the protocol every plain connection starts with.
                connectionFactories.add(createHttpConnectionFactory(httpConfiguration, xlrConfig, HttpConnectionFactoryType.HTTP11));
                if (xlrConfig.server_http2_enabled()) {
                    connectionFactories.add(createHttpConnectionFactory(httpConfiguration, xlrConfig, HttpConnectionFactoryType.HTTP2C));
                }
                connector = new ServerConnector(jettyServer, connectionFactories.toArray(new ConnectionFactory[0]));
            }

            connector.setIdleTimeout(DEFAULT_IDLE_TIMEOUT);
            connector.setHost(serverConfiguration.getHttpBindAddress());
            connector.setPort(serverConfiguration.getHttpPort());
            jettyServer.setConnectors(new Connector[] {connector});

            logger.info("Server listens on {}:{} ({})", connector.getHost(), connector.getPort(), serverConfiguration.isSsl() ? "secure" : "not secure");
            // See https://github.com/spring-projects/spring-boot/issues/22689 for why the stop timeout is forced to 0.
            jettyServer.setStopTimeout(0);

            MBeanContainer mbContainer = new MBeanContainer(ManagementFactory.getPlatformMBeanServer());
            jettyServer.addEventListener(mbContainer);
            jettyServer.addBean(mbContainer);
        };
    }

    /**
     * Builds the server-side TLS context from {@code serverConfiguration}.
     * <p>
     * Key store path, key store password and key password are always applied. The TLS protocol is
     * only overridden when {@code getSslProtocol()} is non-blank. When {@code isMutualSsl()} is true,
     * client certificates are required and validated against the configured trust store.
     *
     * @param serverConfiguration source of key store, trust store and protocol settings
     * @return a new, not yet started, TLS context factory
     */
    private static SslContextFactory.Server getSslContextFactory(ServerConfiguration serverConfiguration) {
        logger.debug("Setting up Jetty to use SSL");
        SslContextFactory.Server sslContextFactory = new SslContextFactory.Server();
        sslContextFactory.setKeyStorePath(serverConfiguration.getKeyStorePath());
        sslContextFactory.setKeyStorePassword(serverConfiguration.getKeyStorePassword());
        sslContextFactory.setKeyManagerPassword(serverConfiguration.getKeyStoreKeyPassword());
        String protocol = serverConfiguration.getSslProtocol();
        if (!nullToEmpty(protocol).trim().isEmpty()) {
            sslContextFactory.setProtocol(protocol);
        }

        if (serverConfiguration.isMutualSsl()) {
            logger.debug("Setting up Jetty to use mutual SSL");
            sslContextFactory.setNeedClientAuth(true);
            sslContextFactory.setTrustStorePath(serverConfiguration.getTrustStorePath());
            sslContextFactory.setTrustStorePassword(serverConfiguration.getTrustStorePassword());
        }
        return sslContextFactory;
    }

    /**
     * Creates the {@link HttpConfiguration} shared by all HTTP connection factories.
     * <p>
     * It does three things:
     * <ul>
     *   <li>Adds a {@link SecureRequestCustomizer} with SNI host checking enabled.</li>
     *   <li>Sets an HSTS max-age of 365 days that includes subdomains.</li>
     *   <li>Suppresses the {@code Server} version and {@code X-Powered-By} response headers.</li>
     * </ul>
     *
     * @return a new HTTP configuration
     */
    private static HttpConfiguration getHttpConfiguration() {
        HttpConfiguration httpConfiguration = new HttpConfiguration();
        httpConfiguration.addCustomizer(new SecureRequestCustomizer(
                true,
                TimeUnit.SECONDS.convert(365, TimeUnit.DAYS),
                true
        ));
        httpConfiguration.setSendServerVersion(false);
        httpConfiguration.setSendXPoweredBy(false);
        return httpConfiguration;
    }

    /**
     * Creates the connection factory for {@code factoryType}. When metrics are enabled, the factory is
     * wrapped with connection instrumentation; see {@link #instrument(ConnectionFactory, XlrConfig)}.
     *
     * @param httpConfiguration configuration for the HTTP/1.1 and HTTP/2 factories (ignored for ALPN)
     * @param xlrConfig         supplies the metrics switch
     * @param factoryType       which protocol factory to build
     * @return the (possibly instrumented) connection factory
     * @throws IllegalStateException if {@code factoryType} is not handled
     */
    private static ConnectionFactory createHttpConnectionFactory(HttpConfiguration httpConfiguration, XlrConfig xlrConfig, HttpConnectionFactoryType factoryType) {
        ConnectionFactory connectionFactory;
        switch (factoryType) {
            case HTTP11:
                connectionFactory = new HttpConnectionFactory(httpConfiguration);
                break;
            case HTTP2:
                connectionFactory = new HTTP2ServerConnectionFactory(httpConfiguration);
                break;
            case HTTP2C:
                connectionFactory = new HTTP2CServerConnectionFactory(httpConfiguration);
                break;
            case ALPN:
                // No explicit protocol list: ALPN negotiates among the connector's other protocols.
                connectionFactory = new ALPNServerConnectionFactory();
                break;
            default:
                throw new IllegalStateException("Unsupported connection factory type: " + factoryType);
        }
        return instrument(connectionFactory, xlrConfig);
    }

    /**
     * Creates the TLS connection factory. It decrypts traffic with {@code sslContextFactory} and hands
     * the stream to {@code protocol}. When metrics are enabled, the factory is wrapped with connection
     * instrumentation.
     *
     * @param sslContextFactory server TLS context
     * @param protocol          protocol name of the factory that follows TLS on the connector
     * @param xlrConfig         supplies the metrics switch
     * @return the (possibly instrumented) TLS connection factory
     */
    private static ConnectionFactory createSslConnectionFactory(SslContextFactory.Server sslContextFactory, String protocol, XlrConfig xlrConfig) {
        return instrument(new SslConnectionFactory(sslContextFactory, protocol), xlrConfig);
    }

    /**
     * Wraps {@code connectionFactory} with an {@link InstrumentedConnectionFactory} that feeds the
     * shared {@value #CONNECTIONS_TIMER} timer, if metrics are enabled in {@code xlrConfig}.
     * <p>
     * Factories implementing {@link ConnectionFactory.Upgrading} (e.g. the h2c factory) are returned
     * unwrapped. Jetty looks up the target of an HTTP/1.1 upgrade, including the h2c prior-knowledge
     * preface, with {@code instanceof ConnectionFactory.Upgrading}. The metrics wrapper only
     * implements {@link ConnectionFactory}, so wrapping would silently disable the upgrade.
     *
     * @param connectionFactory factory to wrap
     * @param xlrConfig         supplies the metrics switch
     * @return the wrapped factory, or {@code connectionFactory} itself when no wrapping applies
     */
    private static ConnectionFactory instrument(ConnectionFactory connectionFactory, XlrConfig xlrConfig) {
        if (!xlrConfig.metrics().enabled() || connectionFactory instanceof ConnectionFactory.Upgrading) {
            return connectionFactory;
        }
        Timer timer = XlrMetricRegistry.metricRegistry().timer(CONNECTIONS_TIMER);
        return new InstrumentedConnectionFactory(connectionFactory, timer);
    }

    /** Protocol factories that {@link #createHttpConnectionFactory} can build. */
    private enum HttpConnectionFactoryType {
        /** HTTP/1.1. */
        HTTP11,
        /** HTTP/2 over TLS (selected through ALPN). */
        HTTP2,
        /** Cleartext HTTP/2 (h2c). */
        HTTP2C,
        /** TLS application-layer protocol negotiation. */
        ALPN
    }
}
