package com.xebialabs.deployit.core.sql

import java.util.Locale
import javax.sql.DataSource

import scala.util.matching.Regex

object SqlDialect {

  /**
   * Detects the SQL dialect of the database behind `dataSource`.
   *
   * Borrows one connection, reads the JDBC database product name, lower-cases it and matches it against
   * each known dialect's (unanchored) pattern in turn. The connection is closed whether or not detection
   * succeeds.
   *
   * @param dataSource data source to probe
   * @return the first matching known dialect, or [[UnknownDialect]] carrying the lower-cased product name
   */
  def initializeDialect(dataSource: DataSource): SqlDialect = {
    val connection = dataSource.getConnection
    try {
      // Lower-case with Locale.ROOT: the default-locale overload applies language-specific case
      // mappings (e.g. Turkish turns 'I' into a dotless 'ı'), which would make dialect detection
      // and the reported product name depend on the JVM's locale.
      val productName = connection.getMetaData.getDatabaseProductName.toLowerCase(Locale.ROOT)
      productName match {
        case DerbyDialect.regex() => DerbyDialect
        case PostgresqlDialect.regex() => PostgresqlDialect
        case MysqlDialect.regex() => MysqlDialect
        case OracleDialect.regex() => OracleDialect
        case MssqlDialect.regex() => MssqlDialect
        case Db2Dialect.regex() => Db2Dialect
        case name => UnknownDialect(name)
      }
    } finally {
      connection.close()
    }
  }

}

/**
 * Database-specific SQL generation.
 *
 * The implementations in this file are the known dialect objects (Derby, PostgreSQL, Oracle, MySQL,
 * MS SQL, DB2) plus [[UnknownDialect]]; `SqlDialect.initializeDialect` picks one from the JDBC
 * product name.
 */
sealed trait SqlDialect {
  /** Pattern matched against the lower-cased JDBC product name by `SqlDialect.initializeDialect`. */
  protected[sql] val regex: Regex
  /** Whether generated SQL may use `PARTITION BY`; `false` for Derby and MySQL in this file. */
  def supportPartitionBy: Boolean
  /** Qualifies `table` with `schema`; the implementations here produce `schema.table`. */
  def tableWithSchemaName(table: String, schema: String): String
  /** Qualifies `table` with `schema` and aliases it using this dialect's [[aliasTable]] syntax. */
  def aliasTableWithSchemaName(table: String, schema: String, alias: String): String
  /** Aliases a table: `table as alias` by default, `table alias` on Oracle. */
  def aliasTable(table: String, alias: String): String
  /** Aliases a column expression: `column as alias`. */
  def aliasColumn(column: String, alias: String): String
  /** Quotes an identifier: double quotes by default, backticks on MySQL, square brackets on MS SQL. */
  def quote(name: String): String
  /**
   * Wraps or extends `query` so that it returns a single page of rows.
   *
   * @param query    the SQL query to page
   * @param page     1-based page number; every implementation treats values <= 1 as the first page
   * @param pageSize number of rows per page
   * @return the paged SQL text
   */
  def addPaging(query: String, page: Int, pageSize: Int): String
  /** Expression extracting the month from a date/time column (`month`, `date_part` or `extract`). */
  def month(columnName: ColumnName): SqlFunction
  /** Expression extracting the year from a date/time column (`year`, `date_part` or `extract`). */
  def year(columnName: ColumnName): SqlFunction
  /** Standard-deviation aggregate: `stddev` by default, `stddev_samp` on Derby, `stdev` on MS SQL. */
  def standardDeviation(columnName: ColumnName): SqlFunction
  /** Concatenates two expressions with `||` or `concat(...)`, depending on the dialect. */
  def concat(s1: Selectable, s2: Selectable): SqlFunction
  /** Substring of a column starting at a position supplied as a `?` bind parameter. */
  def substr(columnName: ColumnName): SqlFunction
  /** String length of a column: `length`, `char_length` or `len`, depending on the dialect. */
  def strLength(columnName: ColumnName): SqlFunction
  /**
   * `true` only for PostgreSQL in this file. NOTE(review): presumably means a failed statement leaves the
   * surrounding transaction unusable, so callers must not keep using it after an error. Confirm at call sites.
   */
  def abortsTransactionOnException: Boolean
  /**
   * Escape clause for patterns containing a backslash: `{escape '\'}` by default, empty on MySQL.
   * NOTE(review): assumed to be appended after a LIKE pattern. Confirm at call sites.
   */
  def escapeBackSlash(): String
  /** Bind placeholder for a string parameter; MS SQL appends `collate Latin1_General_CS_AS`. */
  def paramWithCollation(): Selectable
  /** Normalises a possibly-null string: Oracle maps `null` to `""`, other dialects return it unchanged. */
  def enforceNotNull(value: String): String
  /** `CAST(... as <string type>)`: `VARCHAR(255)` by default, `CHAR(255)` on MySQL. */
  def castToString(columnName: Selectable): SqlFunction
  /** `CAST(... as <integer type>)`: `int` by default, `unsigned` on MySQL. */
  def castToInt(columnName: Selectable): SqlFunction
  /** Builds the locking SELECT on `tableName` (default `XLD_LOCK_TABLE`); MS SQL uses its own builder. */
  def lockSelectBuilder(tableName: TableName = TableName("XLD_LOCK_TABLE"))(implicit schemaInfo: SchemaInfo): SimpleSelectBuilder
  /** Maximum number of values for one `IN (...)` list. NOTE(review): presumably callers split longer lists. Verify. */
  def inClauseLimit: Int
}

/**
 * Factories for the SQL function expressions used by the dialects in this file.
 *
 * Every member is a constructor whose argument list is left open with `_`, so it is applied like a
 * function, e.g. `DialectFunctions.month(columnName)`. The `StaticSqlFunction` strings look like format
 * templates: `%s` for an argument expression, `%d` for an integer. A literal `?` is a bind parameter the
 * caller must supply. NOTE(review): template semantics are inferred from the strings. Confirm in
 * StaticSqlFunction.
 */
private object DialectFunctions {
  // Standard deviation: `stddev` (default), `stddev_samp` (Derby), `stdev` (MS SQL).
  val stddev = new SimpleSqlFunction("stddev")(_)
  val stddev_samp = new SimpleSqlFunction("stddev_samp")(_)
  val stdev = new SimpleSqlFunction("stdev")(_)

  // Month extraction: `month(x)` (default), `date_part` cast to int (PostgreSQL), `extract` (Oracle).
  val month = new SimpleSqlFunction("month")(_)
  val date_part_month = new StaticSqlFunction("date_part('month', %s)::int")(_: Selectable)
  val extract_month = new StaticSqlFunction("extract(month from %s)")(_: Selectable)

  // Year extraction: same three variants as month.
  val year = new SimpleSqlFunction("year")(_)
  val date_part_year = new StaticSqlFunction("date_part('year', %s)::int")(_: Selectable)
  val extract_year = new StaticSqlFunction("extract(year from %s)")(_: Selectable)

  // Concatenation: `concat(a,b)` (MySQL, MS SQL, DB2) or the `||` operator (default).
  val concat = new StaticSqlFunction("concat(%s,%s)")(_: Selectable, _: Selectable)
  val bars = new StaticSqlFunction("%s || %s")(_: Selectable, _: Selectable)

  // Substring from a bound start position: open-ended `substr(x,?)`, or `substring(x, ?, len)`
  // with an explicit length (MS SQL passes 2000).
  val substr = new StaticSqlFunction("substr(%s,?)")(_: Selectable)
  val substr2 = new StaticSqlFunction("substring(%s, ?, %d)")(_: Selectable, _: Int)

  // String length: `length` (Derby, Oracle, DB2, unknown), `char_length` (PostgreSQL, MySQL), `len` (MS SQL).
  val length = new SimpleSqlFunction("length")(_)
  val char_length = new SimpleSqlFunction("char_length")(_)
  val len = new SimpleSqlFunction("len")(_)

  // `CAST(x as <type>)`; the target type is passed as a plain string.
  val cast = new StaticSqlFunction("CAST(%s as %s)")(_: Selectable, _: String)

}

/**
 * Shared defaults for the known database dialects.
 *
 * Each concrete dialect passes a lower-case fragment of its JDBC product name. The fragment becomes an
 * unanchored [[Regex]] (it may match anywhere in the product name) and is also the dialect's `toString`.
 * Subclasses override only what differs from the defaults here. `addPaging` and `strLength` have no
 * default and are left to each dialect.
 *
 * @param name lower-case product-name fragment identifying this dialect
 */
sealed abstract class AbstractDialect(val name: String) extends SqlDialect {
  override protected[sql] val regex: Regex = new Regex(name).unanchored
  override def toString: String = name

  // ---- Identifiers and aliases ----
  override def quote(name: String): String = s""""$name""""
  override def tableWithSchemaName(table: String, schema: String): String = s"$schema.$table"
  override def aliasTable(table: String, alias: String): String = s"$table as $alias"
  // Goes through aliasTable so a dialect with different alias syntax (e.g. Oracle) is respected.
  override def aliasTableWithSchemaName(table: String, schema: String, alias: String): String = {
    val aliased = aliasTable(table, alias)
    s"$schema.$aliased"
  }
  override def aliasColumn(column: String, alias: String): String = s"$column as $alias"

  // ---- SQL functions ----
  override def month(columnName: ColumnName): SqlFunction = DialectFunctions.month(columnName)
  override def year(columnName: ColumnName): SqlFunction = DialectFunctions.year(columnName)
  override def standardDeviation(columnName: ColumnName): SqlFunction = DialectFunctions.stddev(columnName)
  override def concat(s1: Selectable, s2: Selectable): SqlFunction = DialectFunctions.bars(s1, s2)
  override def substr(columnName: ColumnName): SqlFunction = DialectFunctions.substr(columnName)
  override def castToString(columnName: Selectable): SqlFunction = castTo(columnName, "VARCHAR(255)")
  override def castToInt(columnName: Selectable): SqlFunction = castTo(columnName, "int")

  // ---- Parameters and values ----
  override def paramWithCollation(): Selectable = new SqlLiteral("?")
  override def escapeBackSlash(): String = "{escape '\\'}"
  override def enforceNotNull(value: String): String = value

  // ---- Capabilities and limits ----
  override def supportPartitionBy: Boolean = true
  override def abortsTransactionOnException: Boolean = false
  override def inClauseLimit: Int = Short.MaxValue
  override def lockSelectBuilder(tableName: TableName)(implicit schemaInfo: SchemaInfo): SimpleSelectBuilder =
    new DefaultLockSelectBuilder(tableName)

  /** `CAST(column as sqlType)` built from the shared cast template. */
  private def castTo(column: Selectable, sqlType: String): SqlFunction = DialectFunctions.cast(column, sqlType)
}

/**
 * Paging with the SQL:2008 `OFFSET n ROWS FETCH FIRST m ROWS ONLY` clauses.
 *
 * Pages are 1-based. For page 1 (or any lower value) only the `FETCH FIRST` clause is appended.
 * Later pages also skip `(page - 1) * pageSize` rows.
 */
trait Sql2008 {
  self: SqlDialect =>
  override def addPaging(query: String, page: Int, pageSize: Int): String = {
    val fetchClause = s"FETCH FIRST $pageSize ROWS ONLY"
    if (page <= 1) s"$query $fetchClause"
    else s"$query OFFSET ${(page - 1) * pageSize} ROWS $fetchClause"
  }
}

/**
 * Paging with `LIMIT m OFFSET n`.
 *
 * Pages are 1-based. Page 1 (or any lower value) gets only `LIMIT`. Later pages append
 * `OFFSET (page - 1) * pageSize`.
 */
trait LimitOffset {
  self: SqlDialect =>
  override def addPaging(query: String, page: Int, pageSize: Int): String = {
    val limited = s"$query LIMIT $pageSize"
    if (page <= 1) limited else s"$limited OFFSET ${(page - 1) * pageSize}"
  }
}

/**
 * Apache Derby: SQL:2008 paging, no `PARTITION BY`, `stddev_samp` for the standard deviation
 * and a small `IN` list limit.
 */
object DerbyDialect extends AbstractDialect("derby") with Sql2008 {
  override def inClauseLimit: Int = 100
  override def supportPartitionBy: Boolean = false

  override def strLength(columnName: ColumnName): SqlFunction = DialectFunctions.length(columnName)
  override def standardDeviation(columnName: ColumnName): SqlFunction = DialectFunctions.stddev_samp(columnName)
}

/**
 * PostgreSQL: `LIMIT`/`OFFSET` paging, `date_part(...)::int` for month/year, `char_length`,
 * and the only dialect here that reports `abortsTransactionOnException`.
 */
object PostgresqlDialect extends AbstractDialect("postgres") with LimitOffset {
  override def inClauseLimit: Int = 16000
  override def abortsTransactionOnException: Boolean = true

  override def year(columnName: ColumnName): SqlFunction = DialectFunctions.date_part_year(columnName)
  override def month(columnName: ColumnName): SqlFunction = DialectFunctions.date_part_month(columnName)
  override def strLength(columnName: ColumnName): SqlFunction = DialectFunctions.char_length(columnName)
}

/**
 * Oracle: ROWNUM-based paging, `extract(... from ...)` for month/year, table aliases without `as`,
 * and `null` strings normalised to `""`.
 */
object OracleDialect extends AbstractDialect("oracle") {
  override def aliasTable(table: String, alias: String): String = s"$table $alias"

  /**
   * Pages via ROWNUM; both forms add an `rn` column to the result.
   * The first page (page <= 1) is a single ROWNUM filter. Later pages keep rows up to
   * `page * pageSize` in an inner query and drop the first `(page - 1) * pageSize` via `rn`.
   */
  override def addPaging(query: String, page: Int, pageSize: Int): String =
    if (page <= 1) {
      s"SELECT tmp.*, rownum rn FROM ($query) tmp WHERE rownum <= $pageSize"
    } else {
      val lastRow = page * pageSize
      val rowsBefore = (page - 1) * pageSize
      s"""
         |SELECT *
         |FROM (
         |  SELECT tmp.*, rownum rn
         |  FROM ($query) tmp
         |  WHERE rownum <= $lastRow)
         |WHERE rn > $rowsBefore""".stripMargin
    }

  override def month(columnName: ColumnName): SqlFunction = DialectFunctions.extract_month(columnName)
  override def year(columnName: ColumnName): SqlFunction = DialectFunctions.extract_year(columnName)
  override def strLength(columnName: ColumnName): SqlFunction = DialectFunctions.length(columnName)

  /** Returns `value`, or `""` when it is `null`. */
  override def enforceNotNull(value: String): String = Option(value).getOrElse("")

  override def inClauseLimit: Int = 1000
}

/**
 * MySQL: `LIMIT`/`OFFSET` paging, backtick-quoted identifiers, `concat(...)` and `char_length`,
 * MySQL-specific CAST target types, no backslash escape clause and no `PARTITION BY`.
 */
object MysqlDialect extends AbstractDialect("mysql") with LimitOffset {
  // Identifiers and string functions
  override def quote(name: String): String = s"`$name`"
  override def concat(s1: Selectable, s2: Selectable): SqlFunction = DialectFunctions.concat(s1, s2)
  override def strLength(columnName: ColumnName): SqlFunction = DialectFunctions.char_length(columnName)

  // CAST target types
  override def castToInt(columnName: Selectable): SqlFunction = DialectFunctions.cast(columnName, "unsigned")
  override def castToString(columnName: Selectable): SqlFunction = DialectFunctions.cast(columnName, "CHAR(255)")

  // Capabilities
  override def escapeBackSlash(): String = ""
  override def supportPartitionBy: Boolean = false
}

/**
 * Microsoft SQL Server: bracket-quoted identifiers, `OFFSET`/`FETCH` paging that always carries an
 * offset, `stdev`/`len`/`concat`, a case-sensitive collation on string parameters and its own
 * lock-select builder.
 */
object MssqlDialect extends AbstractDialect("microsoft sql") with Sql2008 {
  override def quote(name: String): String = s"[$name]"

  /**
   * Unlike the inherited SQL:2008 paging, always emits an `OFFSET` clause (0 for page <= 1).
   * NOTE(review): presumably because SQL Server only accepts `FETCH` after `OFFSET`.
   */
  override def addPaging(query: String, page: Int, pageSize: Int): String = {
    val rowsToSkip = if (page > 1) (page - 1) * pageSize else 0
    s"$query OFFSET $rowsToSkip ROWS FETCH FIRST $pageSize ROWS ONLY"
  }

  // SQL functions
  override def standardDeviation(columnName: ColumnName): SqlFunction = DialectFunctions.stdev(columnName)
  override def concat(s1: Selectable, s2: Selectable): SqlFunction = DialectFunctions.concat(s1, s2)
  override def substr(columnName: ColumnName): SqlFunction = DialectFunctions.substr2(columnName, 2000)
  override def strLength(columnName: ColumnName): SqlFunction = DialectFunctions.len(columnName)

  // Parameters, locking and limits
  override def paramWithCollation(): Selectable = new SqlLiteral("? collate Latin1_General_CS_AS")
  override def lockSelectBuilder(tableName: TableName)(implicit schemaInfo: SchemaInfo): SimpleSelectBuilder =
    new MssqlLockSelectBuilder(tableName)
  override def inClauseLimit: Int = 1000
}

/** IBM DB2: `LIMIT`/`OFFSET` paging, `concat(...)`, `length` and an `IN` list limit of 1000. */
object Db2Dialect extends AbstractDialect("db2") with LimitOffset {
  override def inClauseLimit: Int = 1000
  override def strLength(columnName: ColumnName): SqlFunction = DialectFunctions.length(columnName)
  override def concat(s1: Selectable, s2: Selectable): SqlFunction = DialectFunctions.concat(s1, s2)
}

/**
 * Fallback used when the JDBC product name matches none of the known dialects.
 *
 * Uses SQL:2008 paging and the same generic defaults as the known dialects, with a conservative
 * `IN` list limit of 1000.
 *
 * @param name the lower-cased database product name that was not recognised
 */
case class UnknownDialect(name: String) extends SqlDialect with Sql2008 {
  // A pattern that can never match (empty negative lookahead), used instead of `null` so that any code
  // matching product names against dialect regexes cannot hit a NullPointerException on this dialect.
  override val regex: Regex = "(?!)".r

  // Note: the parameter `name` shadows the case-class field here; the identifier being quoted is used.
  override def quote(name: String): String = s""""$name""""
  override def tableWithSchemaName(table: String, schema: String): String = s"$schema.$table"
  override def aliasTableWithSchemaName(table: String, schema: String, alias: String): String = s"$schema.${aliasTable(table, alias)}"
  override def aliasTable(table: String, alias: String): String = s"$table as $alias"
  override def aliasColumn(column: String, alias: String): String = s"$column as $alias"
  override def supportPartitionBy: Boolean = true
  override def toString: String = s"unknown '$name' (unsupported)"
  override def month(columnName: ColumnName): SimpleSqlFunction = DialectFunctions.month(columnName)
  override def year(columnName: ColumnName): SimpleSqlFunction = DialectFunctions.year(columnName)
  override def standardDeviation(columnName: ColumnName): SimpleSqlFunction = DialectFunctions.stddev(columnName)
  override def concat(s1: Selectable, s2: Selectable): SqlFunction = DialectFunctions.bars(s1, s2)
  override def substr(columnName: ColumnName): SqlFunction = DialectFunctions.substr(columnName)
  override def strLength(columnName: ColumnName): SqlFunction = DialectFunctions.length(columnName)
  override def abortsTransactionOnException: Boolean = false
  override def escapeBackSlash(): String = "{escape '\\'}"
  override def paramWithCollation(): Selectable = new SqlLiteral("?")
  override def enforceNotNull(value: String): String = value
  override def castToString(columnName: Selectable): SqlFunction = DialectFunctions.cast(columnName, "VARCHAR(255)")
  override def castToInt(columnName: Selectable): SqlFunction = DialectFunctions.cast(columnName, "int")
  override def lockSelectBuilder(tableName: TableName)(implicit schemaInfo: SchemaInfo): SimpleSelectBuilder = new DefaultLockSelectBuilder(tableName)
  // Conservative limit for an unrecognised database.
  override def inClauseLimit: Int = 1000
}
