package com.xebialabs.xlrelease.runner.crypto.encrypt

import com.xebialabs.xlrelease.runner.domain.{EncryptedContextData, PlainContextData}

import java.security.spec.X509EncodedKeySpec
import java.security.{Key, KeyFactory, PublicKey}
import java.util.Base64
import javax.crypto.Cipher

/**
 * Serializes RSA-encrypted bytes as strings of the form `{rsa:<keyVersion>}<base64>` and parses them back.
 *
 * The key version is limited to lowercase ASCII letters and digits. `encode` enforces this, so every
 * string it produces can be parsed by `decode`.
 */
object RSACoder {
  private final lazy val encoder = Base64.getEncoder
  private final lazy val decoder = Base64.getDecoder
  // Allowed key-version characters; shared by the validation in `encode` and the parser in `decode`
  // so both directions always agree on what a valid version looks like.
  private final val KEY_VERSION_PATTERN = "[a-z0-9]+"
  private final val ENCODED_REGEX = s"\\{rsa:($KEY_VERSION_PATTERN)}(.+)".r

  private def getPrefix(keyVersion: String): String = s"{rsa:$keyVersion}"

  /**
   * Base64-encodes `toEncode` and prefixes it with `{rsa:<keyVersion>}`.
   *
   * @param toEncode   raw bytes to encode
   * @param keyVersion key version tag; must be non-null and match `[a-z0-9]+`
   * @return the prefixed Base64 string
   * @throws IllegalArgumentException if `keyVersion` does not match the allowed pattern (such output
   *                                  could never be decoded by [[decode]])
   */
  def encode(toEncode: Array[Byte], keyVersion: String): String = {
    require(
      keyVersion != null && keyVersion.matches(KEY_VERSION_PATTERN),
      s"Invalid key version '$keyVersion': must match $KEY_VERSION_PATTERN to be decodable"
    )
    s"${getPrefix(keyVersion)}${encoder.encodeToString(toEncode)}"
  }

  /**
   * Parses a string produced by [[encode]].
   *
   * @param toDecode string of the form `{rsa:<keyVersion>}<base64>`
   * @return the decoded bytes paired with the key version
   * @throws RuntimeException         if the string does not start with `{rsa:` or does not match the format
   * @throws IllegalArgumentException if the payload is not valid Base64 (from `Base64.Decoder.decode`)
   */
  def decode(toDecode: String): (Array[Byte], String) = {
    if (!toDecode.startsWith("{rsa:")) {
      throw new RuntimeException("Unable to decode provided string")
    }
    toDecode match {
      case ENCODED_REGEX(keyVersion, bytesToDecode) =>
        (decoder.decode(bytesToDecode), keyVersion)
      case _ =>
        throw new RuntimeException("Provided string does not follow encoding criteria. Unable to decode.")
    }
  }
}

/**
 * Hybrid encryption helpers for job context data.
 *
 * Content is encrypted with a freshly generated session key produced by `JobDataEncryptor`. That session key
 * is then RSA-encrypted and serialized via [[RSACoder]] as `{rsa:<keyVersion>}<base64>`.
 */
object JobRSAUtils {
  private final val ALGORITHM: String = "RSA"
  final val DEFAULT_KEY_VERSION: String = "v0"

  /**
   * Encrypts `content` with a new session key (documented as AES-256; TODO confirm in JobDataEncryptor).
   * The session key itself is encrypted with the provided public key.
   *
   * @param content    Content to encrypt
   * @param publicKey  Base-64 encoded X.509 (SubjectPublicKeyInfo) RSA public key
   * @param keyVersion The version identifier for the public key, embedded in the encrypted session key
   * @return The encrypted context data
   */
  def encryptContextData(content: String, publicKey: String, keyVersion: String): EncryptedContextData =
    encryptContextData(content, decodePublicKey(publicKey), keyVersion)

  /**
   * Decrypts the session key with the provided public key, then decrypts the data with that session key.
   *
   * NOTE(review): RSA decryption with a public key only succeeds if the session key was encrypted with the
   * matching private key. Confirm callers use this overload that way.
   *
   * @param content   Encrypted context data
   * @param publicKey Base-64 encoded X.509 (SubjectPublicKeyInfo) RSA public key
   * @return The decrypted plain data
   */
  def decryptContextData(content: EncryptedContextData, publicKey: String): PlainContextData =
    decryptContextData(content, decodePublicKey(publicKey))

  /**
   * Encrypts `content` with a new session key (documented as AES-256; TODO confirm in JobDataEncryptor).
   * The session key itself is encrypted with `key`.
   *
   * @param content    Content to encrypt
   * @param key        The RSA key used to encrypt the session key
   * @param keyVersion The version identifier for the key, embedded in the encrypted session key
   * @return The encrypted context data
   */
  def encryptContextData(content: String, key: Key, keyVersion: String): EncryptedContextData = {
    val jobDataEncryptor = JobDataEncryptor(true)
    val encryptedContent = jobDataEncryptor.encrypt(content)
    // The session key is read after `encrypt`, so it is the key that protected `encryptedContent`.
    val sessionKeyBytes = jobDataEncryptor.getEncodedSecretKey
    val encryptedSessionKey = encryptSessionKey(sessionKeyBytes, key, keyVersion)

    EncryptedContextData(
      data = encryptedContent,
      sessionKey = encryptedSessionKey
    )
  }

  /**
   * Decrypts the session key with `key`, then decrypts the data with that session key.
   *
   * @param encryptedContent Encrypted context data
   * @param key              The RSA key used to decrypt the session key
   * @return The decrypted plain data
   */
  def decryptContextData(encryptedContent: EncryptedContextData, key: Key): PlainContextData = {
    val sessionKeyBytes = decryptSessionKey(encryptedContent.sessionKey, key)
    val jobDataEncryptor = JobDataEncryptor(true)
    jobDataEncryptor.setSecretKey(sessionKeyBytes)

    val decryptedContent = jobDataEncryptor.decrypt(encryptedContent.data)
    PlainContextData(data = decryptedContent)
  }

  /**
   * Parses a Base-64 encoded X.509 RSA public key.
   *
   * @throws IllegalArgumentException                  if the input is not valid Base64
   * @throws java.security.spec.InvalidKeySpecException if the bytes are not a valid RSA public key
   */
  private def decodePublicKey(b64PublicKey: String): PublicKey = {
    val publicKeySpec = new X509EncodedKeySpec(Base64.getDecoder.decode(b64PublicKey))
    KeyFactory.getInstance(ALGORITHM).generatePublic(publicKeySpec)
  }

  /** Creates an RSA cipher initialised for `mode` (`Cipher.ENCRYPT_MODE` or `Cipher.DECRYPT_MODE`) with `key`. */
  private def rsaCipher(mode: Int, key: Key): Cipher = {
    // The bare "RSA" transformation leaves mode/padding to the provider default (RSA/ECB/PKCS1Padding on
    // SunJCE). It is kept as-is so session keys stay compatible with data produced by existing peers.
    val cipher = Cipher.getInstance(ALGORITHM)
    cipher.init(mode, key)
    cipher
  }

  /** RSA-encrypts the raw session key and serializes it with its key-version prefix. */
  private def encryptSessionKey(sessionKeyBytes: Array[Byte], key: Key, keyVersion: String): String = {
    val encryptedSessionKeyBytes = rsaCipher(Cipher.ENCRYPT_MODE, key).doFinal(sessionKeyBytes)
    RSACoder.encode(encryptedSessionKeyBytes, keyVersion)
  }

  /** Parses a serialized session key and RSA-decrypts it back to raw session-key bytes. */
  private def decryptSessionKey(sessionKey: String, key: Key): Array[Byte] = {
    // The embedded key version is ignored here; the caller supplies the decryption key explicitly.
    val (encryptedSessionKeyBytes, _) = RSACoder.decode(sessionKey)
    rsaCipher(Cipher.DECRYPT_MODE, key).doFinal(encryptedSessionKeyBytes)
  }
}
