package com.xebialabs.satellite.streaming

import akka.actor.ActorSystem

import java.io.{FileInputStream, InputStream}
import java.security.KeyStore
import akka.stream.TLSProtocol._
import akka.stream._
import akka.stream.scaladsl._
import akka.util.ByteString
import com.xebialabs.xlplatform.settings.SecuritySettings
import com.xebialabs.xlplatform.utils.{ClassLoaderUtils, SecureRandomHolder}
import grizzled.slf4j.Logging

import javax.net.ssl._

object SslStreamingSupport extends Logging {

  /** Factory methods for `SslConfig`. */
  object SslConfig extends Logging {

    /** Builds the SSL configuration for a streamed connection.
      *
      * @param useSsl   whether the caller wants the connection encrypted
      * @param settings security settings; must be `enabled` when `useSsl` is true
      * @return `Disabled` when `useSsl` is false, otherwise an enabled config carrying a freshly
      *         initialised `SSLContext` (see `initSslContext`)
      * @throws SecurityException if SSL is requested but `settings` is not enabled
      */
    def apply(useSsl: Boolean, settings: SecuritySettings): SslConfig =
      if (!useSsl) {
        debug("SSL disabled")
        Disabled
      } else if (settings.enabled) {
        debug("SSL enabled")
        Enabled(settings)
      } else {
        warn("Requested ssl encryption but there is no configuration provided.")
        throw new SecurityException("Requested ssl encryption but there is no configuration provided.")
      }

    /** Configuration used when SSL is off; only `enabled` is meaningful, every other field is null. */
    lazy val Disabled: SslConfig =
      SslConfig(enabled = false, sslContext = null, role = null, closing = null, protocol = null)

    /** Enabled configuration built from `settings`; role and closing mode are chosen later by the caller. */
    private def Enabled(settings: SecuritySettings): SslConfig = SslConfig(
      enabled = true,
      sslContext = initSslContext(settings),
      enabledAlgorithms = settings.enabledAlgorithms,
      protocol = settings.protocol,
      closing = null
    )
  }

  /** SSL settings for one side of a streamed connection.
    *
    * @param enabled           whether TLS is applied at all; when false the connection uses a placebo stage
    * @param sslContext        context from which an `SSLEngine` is created for each materialization
    * @param role              client or server side; set with `asClient` / `asServer`
    * @param closing           TLS closing behaviour; set with `ignoreCancel`, `ignoreComplete`, `ignoreBoth`
    *                          or `eagerClose` (falls back to `IgnoreComplete` when left null)
    * @param protocol          protocol name from the security settings
    * @param enabledAlgorithms cipher suites to enable on the engine
    */
  case class SslConfig(enabled: Boolean, sslContext: SSLContext, role: TLSRole = null, closing: TLSClosing,
                       protocol: String, enabledAlgorithms: Seq[String] = Nil) {
    def asClient: SslConfig = copy(role = Client)

    def asServer: SslConfig = copy(role = Server)

    def ignoreCancel: SslConfig = copy(closing = IgnoreCancel)

    def ignoreComplete: SslConfig = copy(closing = IgnoreComplete)

    def ignoreBoth: SslConfig = copy(closing = IgnoreBoth)

    def eagerClose: SslConfig = copy(closing = EagerClose)
  }

  /** Common shape of the real TLS stage and the placebo stage. */
  type SslFlow = BidiFlow[SslTlsOutbound, ByteString, ByteString, SslTlsInbound, _]

  // NOTE: the "Sting" spelling is part of the public API and is kept for source compatibility.
  type ByteStingFlow[Mat] = Flow[ByteString, ByteString, Mat]

  /** Wraps a raw byte transport so that plaintext written into the returned flow is sent through
    * the TLS stage into `tcpConnection`, and only decrypted session bytes are emitted downstream.
    * When `sslConfig.enabled` is false a placebo stage is used and bytes pass through unencrypted.
    *
    * @param sslConfig     TLS settings, including the role (client/server) and closing mode
    * @param tcpConnection the underlying transport flow
    * @param system        not required by the TLS stage; kept in the signature for source compatibility
    * @return a plaintext byte flow whose materialized value is that of `tcpConnection`
    */
  def wrapWithSsl[Mat](sslConfig: SslConfig, tcpConnection: ByteStingFlow[Mat])(implicit system: ActorSystem): ByteStingFlow[Mat] =
    sslWrapper[Mat](sslFlow(sslConfig), tcpConnection)

  /** Real TLS stage (engines built by `createSslEngine`) when SSL is enabled, otherwise a placebo stage. */
  private def sslFlow(sslConfig: SslConfig): SslFlow =
    if (sslConfig.enabled) {
      debug("Real ssl config")
      // Configs from SslConfig.apply start with a null closing mode; use IgnoreComplete rather than
      // handing null to the TLS stage.
      val closing = Option(sslConfig.closing).getOrElse(IgnoreComplete)
      val engineFactory: () => SSLEngine = () => createSslEngine(sslConfig)
      TLS(engineFactory, closing)
    } else {
      debug("Placebo ssl config")
      TLSPlacebo()
    }

  /** Creates and configures a fresh `SSLEngine` from `sslConfig.sslContext`. */
  private def createSslEngine(sslConfig: SslConfig): SSLEngine = {
    val sslContext = sslConfig.sslContext
    val engine = sslContext.createSSLEngine()
    engine.setSSLParameters(sslContext.getDefaultSSLParameters)

    // `protocol` is also used as the SSLContext algorithm name (see initSslContext), which may be a
    // family name such as "TLS" that setEnabledProtocols rejects. Only pin concrete, supported versions.
    if (engine.getSupportedProtocols.contains(sslConfig.protocol))
      engine.setEnabledProtocols(Array(sslConfig.protocol))
    else
      debug(s"Protocol '${sslConfig.protocol}' is not a concrete protocol version; keeping the SSL context defaults")

    engine.setUseClientMode(sslConfig.role == Client)
    // Restricts the engine to the configured cipher suites.
    applySessionParameters(engine, newSessionNegotiation(sslConfig.enabledAlgorithms))
    engine
  }

  /** Applies the cipher suites, protocols, client-auth mode and SSL parameters present in `sessionParameters`. */
  private def applySessionParameters(engine: SSLEngine, sessionParameters: NegotiateNewSession): Unit = {
    sessionParameters.enabledCipherSuites.foreach(cs => engine.setEnabledCipherSuites(cs.toArray))
    sessionParameters.enabledProtocols.foreach(p => engine.setEnabledProtocols(p.toArray))
    sessionParameters.clientAuth match {
      case Some(TLSClientAuth.None) => engine.setNeedClientAuth(false)
      case Some(TLSClientAuth.Want) => engine.setWantClientAuth(true)
      case Some(TLSClientAuth.Need) => engine.setNeedClientAuth(true)
      case _                        => // do nothing
    }

    sessionParameters.sslParameters.foreach(engine.setSSLParameters)
  }

  /** Session negotiation parameters restricting the cipher suites to `cipherSuites`. */
  private def newSessionNegotiation(cipherSuites: Seq[String]): NegotiateNewSession =
    NegotiateNewSession.withCipherSuites(cipherSuites: _*)

  /** Joins the TLS stage and the transport into one plaintext flow.
    * Outgoing bytes are wrapped as `SendBytes` and fed to the TLS stage. The stage's encrypted output goes
    * through `tcpConnection` and back into the stage. Only `SessionBytes` payloads are emitted.
    * The materialized value is that of `tcpConnection`.
    */
  private def sslWrapper[Mat](tls: SslFlow, tcpConnection: ByteStingFlow[Mat]): ByteStingFlow[Mat] =
    Flow.fromGraph(GraphDSL.create(tls, tcpConnection)((_, connMat) => connMat) {
      implicit builder =>
        (tlsShape, conn) =>
          import GraphDSL.Implicits._
          val sendBytes = builder.add(Flow[ByteString].map(bs => SendBytes(bs)))
          sendBytes.outlet ~> tlsShape.in1
          tlsShape.out1 ~> conn ~> tlsShape.in2
          val inboundFlow = tlsShape.out2.collect { case SessionBytes(_, bytes) => bytes }
          FlowShape(sendBytes.in, inboundFlow.outlet)
    })

  /** Builds an `SSLContext` for `settings.protocol`, initialised from the configured key and trust stores. */
  def initSslContext(settings: SecuritySettings): SSLContext = {
    val context = SSLContext.getInstance(settings.protocol)
    context.init(createKeyManagersIfPossible(settings), createTrustManagers(settings), SecureRandomHolder.get())
    context
  }

  /** Key managers for `settings.keyStore`, or an empty array when no key store is configured. */
  private def createKeyManagersIfPossible(settings: SecuritySettings): Array[KeyManager] = {
    def createKeyManager(keyStoreResource: String, keyStorePassword: Array[Char], keyPassword: Array[Char]) = {
      val keyStore = loadKeyStore(keyStoreResource, keyStorePassword)
      val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm)
      keyManagerFactory.init(keyStore, keyPassword)
      keyManagerFactory.getKeyManagers
    }
    settings.keyStore
      .map(ks => createKeyManager(ks, settings.keyStorePassword.toCharArray, settings.keyPassword.toCharArray))
      .toArray.flatten
  }

  /** Trust managers backed by the trust store configured in `settings`. */
  private def createTrustManagers(settings: SecuritySettings): Array[TrustManager] = {
    val trustStore = loadKeyStore(settings.trustStore, settings.trustStorePassword.toCharArray)

    val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm)
    trustManagerFactory.init(trustStore)

    trustManagerFactory.getTrustManagers
  }

  /** Loads a key store of the JVM default type from `resource`.
    * The stream is closed afterwards because `KeyStore.load` leaves it open.
    */
  private def loadKeyStore(resource: String, password: Array[Char]): KeyStore = {
    val keyStore = KeyStore.getInstance(KeyStore.getDefaultType)
    val in = loadResource(resource)
    try keyStore.load(in, password)
    finally in.close()
    keyStore
  }

  /** Opens `resource` from the class loader, falling back to treating it as a file-system path. */
  private def loadResource(resource: String): InputStream = {
    Option(ClassLoaderUtils.classLoader.getResourceAsStream(resource)).getOrElse(new FileInputStream(resource))
  }
}
