package com.xebialabs.deployit.repository.sql

import java.sql.Connection
import com.xebialabs.deployit.repository.RepositoryDatabaseMetaDataService

import javax.sql.DataSource
import org.springframework.beans.factory.annotation.{Autowired, Qualifier}
import org.springframework.jdbc.core.JdbcTemplate
import org.springframework.stereotype.Repository

/**
 * Spring repository that reports which database products back the main store and the reporting store.
 *
 * Every lookup probes the relevant data source through [[DatabaseInfo]] and returns that value's
 * `toString` rendering (`name:version`).
 *
 * @param jdbcTemplate        template for the main database; its underlying data source is probed
 * @param reportingDataSource data source of the reporting database
 */
@Repository
class DatabaseMetaDataServiceRepository(@Autowired @Qualifier("mainJdbcTemplate") val jdbcTemplate: JdbcTemplate,
                                        @Autowired @Qualifier("reportingDataSource") val reportingDataSource: DataSource)
  extends RepositoryDatabaseMetaDataService {

  /** Describes the database behind the main JDBC template. */
  override def findMainDbImplementation: String = describe(jdbcTemplate.getDataSource)

  /** Describes the database behind the reporting data source. */
  override def findReportDbImplementation: String = describe(reportingDataSource)

  /** Probes `dataSource` and renders the resulting [[DatabaseInfo]] as text. */
  private def describe(dataSource: DataSource): String = {
    val info: DatabaseInfo = DatabaseInfo(dataSource)
    info.toString
  }
}

/**
 * Plain holder for a database product name and version.
 *
 * Inside this package the name shadows `java.sql.DatabaseMetaData`. Code that needs the JDBC type
 * must refer to it by its fully qualified name.
 *
 * @param dbName    product name, stored exactly as given
 * @param dbVersion product version, stored exactly as given
 */
class DatabaseMetaData(val dbName: String, val dbVersion: String)

/**
 * Identifies the database product behind a data source as a (lower-cased name, version) pair.
 *
 * The concrete variants are declared in the companion object and chosen by `DatabaseInfo.apply`.
 */
sealed trait DatabaseInfo {
  import java.util.Locale

  /** Raw product name and version this info was built from. */
  def metadata: DatabaseMetaData

  /**
   * Product name, lower-cased with `Locale.ROOT` so the result does not depend on the JVM default
   * locale. For example, under a Turkish default locale an upper-case `I` would otherwise become a
   * dotless `ı`.
   *
   * NOTE(review): this is evaluated eagerly during trait initialisation from the abstract `metadata`.
   * An implementation must therefore have `metadata` available before trait init, as a case-class
   * constructor parameter does. Kept as `var` for source compatibility with existing callers.
   */
  var dbName: String = metadata.dbName.toLowerCase(Locale.ROOT)

  /** Product version, copied unchanged from `metadata`. */
  var dbVersion: String = metadata.dbVersion

  /** Renders as `name:version`, e.g. `postgresql:13.4`. */
  override def toString: String = s"$dbName:$dbVersion"
}

/**
 * Factory for [[DatabaseInfo]] values and home of the database families this module recognises.
 */
object DatabaseInfo {

  import java.util.Locale

  import scala.util.control.NonFatal

  /**
   * Probes `ds` and classifies the database product behind it.
   *
   * A connection is borrowed only long enough to read the JDBC product name and version, and it is
   * closed afterwards. The probe is best-effort. If the connection cannot be obtained, the metadata
   * cannot be read, or classification fails, the result is built from name and version "Unknown",
   * which classifies as [[Unknown]]. A failure while closing the connection does not discard an
   * otherwise successful result. Fatal throwables (as defined by `NonFatal`) are not swallowed.
   *
   * @param ds data source to inspect
   * @return the classified database info, or an `Unknown` fallback on any non-fatal failure
   */
  def apply(ds: DataSource): DatabaseInfo =
    try {
      val connection = ds.getConnection
      try {
        val jdbcMetaData: java.sql.DatabaseMetaData = connection.getMetaData
        DatabaseInfo(new DatabaseMetaData(jdbcMetaData.getDatabaseProductName, jdbcMetaData.getDatabaseProductVersion))
      } finally closeQuietly(connection)
    } catch {
      case NonFatal(_) => DatabaseInfo(new DatabaseMetaData("Unknown", "Unknown"))
    }

  /**
   * Classifies a product name by case-insensitive substring match.
   *
   * The checks run in the order listed and the first hit wins. Any name that matches nothing becomes
   * [[Unknown]]. Lower-casing uses `Locale.ROOT` so the outcome does not depend on the JVM default locale.
   *
   * @param metadata raw product name and version
   * @return the matching [[DatabaseInfo]] variant wrapping `metadata`
   */
  def apply(metadata: DatabaseMetaData): DatabaseInfo =
    metadata.dbName.toLowerCase(Locale.ROOT) match {
      case dbNameLower if dbNameLower.contains("db2") => Db2(metadata)
      case dbNameLower if dbNameLower.contains("derby") => Derby(metadata)
      case dbNameLower if dbNameLower.contains("h2") => H2(metadata)
      case dbNameLower if dbNameLower.contains("sql server") => MsSqlServer(metadata)
      case dbNameLower if dbNameLower.contains("mysql") => MySql(metadata)
      case dbNameLower if dbNameLower.contains("oracle") => Oracle(metadata)
      case dbNameLower if dbNameLower.contains("postgresql") => PostgreSql(metadata)
      case _ => Unknown(metadata)
    }

  /**
   * Closes `connection` if it is non-null. Non-fatal close failures are ignored so that they cannot
   * mask the probe result or the probe's own failure.
   */
  private def closeQuietly(connection: Connection): Unit =
    if (connection != null) {
      try connection.close()
      catch {
        case NonFatal(_) => ()
      }
    }

  /** Fallback for product names no other variant matches, including the failed-probe case. */
  case class Unknown(metadata: DatabaseMetaData) extends DatabaseInfo

  /** Product name contains "db2". */
  case class Db2(metadata: DatabaseMetaData) extends DatabaseInfo

  /** Product name contains "derby". */
  case class Derby(metadata: DatabaseMetaData) extends DatabaseInfo

  /** Product name contains "h2". */
  case class H2(metadata: DatabaseMetaData) extends DatabaseInfo

  /** Product name contains "sql server". */
  case class MsSqlServer(metadata: DatabaseMetaData) extends DatabaseInfo

  /** Product name contains "mysql". */
  case class MySql(metadata: DatabaseMetaData) extends DatabaseInfo

  /** Product name contains "oracle". */
  case class Oracle(metadata: DatabaseMetaData) extends DatabaseInfo

  /** Product name contains "postgresql". */
  case class PostgreSql(metadata: DatabaseMetaData) extends DatabaseInfo

}
