package com.xebialabs.database.anonymizer

import ai.digital.configuration.central.deploy.db.{DatabaseProperties, MainDatabase, ReportingProperties}
import com.xebialabs.deployit.ServerConfigFile
import ai.digital.deploy.core.common.XldServerPaths.DEFAULT_CONFIGURATION_FILE
import com.xebialabs.deployit.util.{DeployitKeys, PasswordEncrypter}
import org.dbunit.database.{DatabaseConfig, DatabaseConnection, ForwardOnlyResultSetTableFactory}
import org.dbunit.dataset.datatype.DefaultDataTypeFactory
import org.dbunit.ext.db2.Db2DataTypeFactory
import org.dbunit.ext.mssql.{InsertIdentityOperation, MsSqlDataTypeFactory}
import org.dbunit.ext.mysql.MySqlDataTypeFactory
import org.dbunit.ext.oracle.OracleDataTypeFactory
import org.dbunit.ext.postgresql.PostgresqlDataTypeFactory
import org.dbunit.operation.DatabaseOperation
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.stereotype.Repository

import java.sql.{Connection, Driver, SQLException}
import java.util.{Locale, Properties}
import scala.util.control.NonFatal

/**
 * Opens JDBC connections to the main or the reporting database and wraps them in DBUnit
 * [[DatabaseConnection]]s configured for the detected database product.
 */
@Repository
class DatabaseRepository {

  @Autowired(required = true)
  var mainDatabase: MainDatabase = _

  @Autowired(required = true)
  var reportingProperties: ReportingProperties = _

  /**
   * Opens a raw JDBC connection with auto-commit enabled.
   *
   * @param isReportingDb whether to target the reporting database (falls back to the main one
   *                      when no reporting database is configured)
   * @throws SQLException if the configured driver does not accept the configured URL
   */
  private def getDriverConnection(isReportingDb: Boolean): Connection = {
    // Must run before the credentials are read; presumably so that an encrypted dbPassword
    // can be decrypted by PasswordEncrypter — confirm.
    initEncryptionKey()
    val driver = dbDriver(isReportingDb)
    val dbUrl = databaseProperties(isReportingDb).dbUrl
    // Per the JDBC contract, Driver.connect returns null (rather than throwing) when the
    // driver does not understand the URL; fail with a clear message instead of an NPE.
    // The URL itself is not included because it may embed credentials.
    val driverConnection = Option(driver.connect(dbUrl, dbCredentials(isReportingDb))).getOrElse(
      throw new SQLException(s"JDBC driver ${driver.getClass.getName} did not accept the configured database URL")
    )
    closeOnFailure(driverConnection) {
      driverConnection.setAutoCommit(true)
      driverConnection
    }
  }

  /** Instantiates the JDBC driver class configured for the selected database. */
  private def dbDriver(isReportingDb: Boolean): Driver =
    Class.forName(databaseProperties(isReportingDb).dbDriverClassname)
      .getDeclaredConstructor()
      .newInstance()
      .asInstanceOf[Driver]

  /**
   * Builds the `user` / `password` connection properties for the selected database.
   * `Properties` rejects null values, so only configured (non-null) values are set.
   */
  private def dbCredentials(isReportingDb: Boolean): Properties = {
    val props = databaseProperties(isReportingDb)
    val credentials = new Properties
    Option(props.dbUsername).foreach(credentials.put("user", _))
    Option(props.dbPassword).foreach(credentials.put("password", _))
    credentials
  }

  /** Reporting database settings when requested and configured, otherwise the main database settings. */
  private def databaseProperties(isReportingDb: Boolean): DatabaseProperties =
    if (isReportingDb && reportingProperties.database.hasConfigured)
      reportingProperties.database else mainDatabase.database

  /**
   * Initialises [[PasswordEncrypter]] with a key derived from the repository keystore password
   * found in the default server configuration file.
   */
  private def initEncryptionKey(): Unit = {
    val serverConfigFile = new ServerConfigFile(DEFAULT_CONFIGURATION_FILE)
    val serverConfiguration = serverConfigFile.loadConfig(false, false, true)
    val repositoryKeyStorePassword = serverConfiguration.getRepositoryKeyStorePassword
    val passwordEncryptionKey = DeployitKeys.getPasswordEncryptionKey(repositoryKeyStorePassword)
    PasswordEncrypter.init(passwordEncryptionKey)
  }

  /**
   * Returns the lower-cased database product name reported by the connection's metadata.
   * Lower-casing uses [[Locale.ROOT]] so the result does not depend on the JVM default locale
   * (e.g. Turkish would turn `I` into a dotless `ı`).
   */
  def getDatabaseName(connection: Connection): String =
    connection.getMetaData.getDatabaseProductName.toLowerCase(Locale.ROOT)

  /**
   * Applies the common DBUnit settings plus the product-specific data type factory and
   * identifier escape pattern (DBUnit substitutes the table/column name for `?`).
   */
  private def updateDatabaseConfig(connection: DatabaseConnection, databaseName: String): Unit = {
    val config = connection.getConfig
    config.setProperty(DatabaseConfig.FEATURE_CASE_SENSITIVE_TABLE_NAMES, true)
    config.setProperty(DatabaseConfig.FEATURE_QUALIFIED_TABLE_NAMES, true)
    config.setProperty(DatabaseConfig.FEATURE_ALLOW_EMPTY_FIELDS, true)
    config.setProperty(DatabaseConfig.PROPERTY_RESULTSET_TABLE_FACTORY, new ForwardOnlyResultSetTableFactory)

    val doubleQuoted = "\"?\""
    // Only MySQL quotes identifiers with backticks; every other supported product uses double quotes.
    val (dataTypeFactory, escapePattern) = DatabaseName.toValue(databaseName) match {
      case DatabaseName.Oracle   => (new OracleDataTypeFactory, doubleQuoted)
      case DatabaseName.MSSQL    => (new MsSqlDataTypeFactory, doubleQuoted)
      case DatabaseName.POSTGRES => (new PostgresqlDataTypeFactory, doubleQuoted)
      case DatabaseName.MYSQL    => (new MySqlDataTypeFactory, "`?`")
      case DatabaseName.DB2      => (new Db2DataTypeFactory, doubleQuoted)
      case DatabaseName.DEFAULT  => (new DefaultDataTypeFactory, doubleQuoted)
    }
    config.setProperty(DatabaseConfig.PROPERTY_DATATYPE_FACTORY, dataTypeFactory)
    config.setProperty(DatabaseConfig.PROPERTY_ESCAPE_PATTERN, escapePattern)
  }

  /**
   * Opens a DBUnit connection to the main or reporting database.
   *
   * Oracle connections are scoped to the connecting user's schema and MSSQL connections to
   * `dbo`; other products use DBUnit's default. If anything fails after the JDBC connection
   * has been opened, that connection is closed before the error is rethrown.
   */
  def getDatabaseConnection(isReportingDb: Boolean): DatabaseConnection = {
    val driverConnection = getDriverConnection(isReportingDb)
    closeOnFailure(driverConnection) {
      val databaseName = getDatabaseName(driverConnection)
      val connection = DatabaseName.toValue(databaseName) match {
        case DatabaseName.Oracle => new DatabaseConnection(driverConnection, driverConnection.getMetaData.getUserName)
        case DatabaseName.MSSQL => new DatabaseConnection(driverConnection, "dbo")
        case _ => new DatabaseConnection(driverConnection)
      }
      updateDatabaseConfig(connection, databaseName)
      connection
    }
  }

  /**
   * Chooses the DBUnit operation used to write data: REFRESH when `isDbRefresh`, INSERT otherwise.
   * MSSQL uses the [[InsertIdentityOperation]] variants so explicit identity-column values can be written.
   */
  def getDatabaseOperation(databaseName: String, isDbRefresh: Boolean): DatabaseOperation =
    DatabaseName.toValue(databaseName) match {
      case DatabaseName.MSSQL => if (isDbRefresh) InsertIdentityOperation.REFRESH else InsertIdentityOperation.INSERT
      case _ => if (isDbRefresh) DatabaseOperation.REFRESH else DatabaseOperation.INSERT
    }

  /** Enables DBUnit batched statements on `connection` with the given batch size. */
  def enableBatchProcessing(connection: DatabaseConnection, batchSize: Int): Unit = {
    connection.getConfig.setProperty(DatabaseConfig.PROPERTY_BATCH_SIZE, batchSize)
    connection.getConfig.setProperty(DatabaseConfig.FEATURE_BATCHED_STATEMENTS, true)
  }

  /**
   * Evaluates `body`. If it throws a non-fatal exception, `connection` is closed before the
   * exception is rethrown; a failure while closing is attached as a suppressed exception.
   */
  private def closeOnFailure[A](connection: Connection)(body: => A): A =
    try body
    catch {
      case NonFatal(e) =>
        try connection.close()
        catch { case NonFatal(closeError) => e.addSuppressed(closeError) }
        throw e
    }
}
