package com.xebialabs.deployit.repository.sql.artifacts

import java.io.{File, InputStream}
import java.sql.PreparedStatement
import com.xebialabs.deployit.checksum.ChecksumAlgorithmProvider
import com.xebialabs.deployit.core.sql.spring.Setter
import com.xebialabs.deployit.core.sql.{ColumnName, Queries, SchemaInfo, TableName}
import com.xebialabs.deployit.exception.NotFoundException
import com.xebialabs.deployit.repository.sql.base._
import com.xebialabs.deployit.sql.base.schema.CIS
import org.springframework.beans.factory.annotation.{Autowired, Qualifier}
import org.springframework.jdbc.core.JdbcTemplate
import org.springframework.jdbc.datasource.DataSourceUtils
import org.springframework.jdbc.support.JdbcUtils
import org.springframework.stereotype.Component

@Component
class DbArtifactRepository(@Autowired @Qualifier("mainJdbcTemplate") val jdbcTemplate: JdbcTemplate)
                          (@Autowired @Qualifier("mainSchema") implicit val schemaInfo: SchemaInfo)
  extends DbArtifactQueries {

  /**
   * Inserts a new artifact row with the given id and binary content.
   *
   * @param id          the artifact id (callers in this file pass the content checksum)
   * @param inputStream the content to store; it is bound as a binary stream and not closed here
   * @return the number of rows affected, as reported by `JdbcTemplate.update`
   */
  def insertArtifact(id: String, inputStream: InputStream): Int = {
    jdbcTemplate.update(INSERT_ARTIFACT, (ps: PreparedStatement) => {
      ps.setString(1, id)
      ps.setBinaryStream(2, inputStream)
    })
  }

  /** Returns true when an artifact row with the given id exists. */
  def existsArtifact(id: String): Boolean =
    jdbcTemplate.queryForObject(COUNT_ARTIFACT_BY_ID, classOf[Number], id).intValue() > 0

  /** Deletes the artifact row with the given id (no-op if it does not exist). */
  def removeArtifact(artifactId: String): Unit = jdbcTemplate.update(DELETE_ARTIFACT, artifactId)

  /** Total number of artifact rows stored. */
  def countArtifacts: Int = jdbcTemplate.queryForObject(COUNT_ALL_ARTIFACTS, classOf[Number]).intValue()

  /**
   * Opens a stream over the artifact content used by the CI identified by `id`.
   *
   * The returned stream keeps a JDBC connection, statement and result set open; the caller
   * MUST close it to release them. If anything fails before the stream is handed back
   * (statement preparation, query execution, no matching row, NULL content) all JDBC
   * resources acquired so far are released before the exception propagates.
   *
   * @param id the CI id; it is converted with `idToPath` and matched against the CI path
   * @throws NotFoundException when no artifact content is linked to that CI
   */
  def getArtifactInputStream(id: String): InputStream = {
    val dataSource = jdbcTemplate.getDataSource
    val con = DataSourceUtils.getConnection(dataSource)
    var ps: PreparedStatement = null
    var rs: java.sql.ResultSet = null
    // Ownership of the JDBC resources passes to the returned stream only on success.
    var handedOff = false
    try {
      ps = con.prepareStatement(SELECT_ARTIFACT_BY_PATH)
      Setter.setString(ps, 1, idToPath(id))
      rs = ps.executeQuery()
      if (!rs.next()) throw new NotFoundException(s"Artifact data not found for: $id")
      // A SQL NULL data column yields a null stream; treat it as missing data instead of
      // returning a stream that would NPE on the first read.
      val delegate = Option(rs.getBinaryStream(1))
        .getOrElse(throw new NotFoundException(s"Artifact data not found for: $id"))
      val stream = new ResultSetBackedInputStream(delegate, rs, ps, con, dataSource)
      handedOff = true
      stream
    } finally {
      if (!handedOff) releaseJdbcResources(rs, ps, con, dataSource)
    }
  }

  /** Closes the result set and statement and releases the connection; each step is null-safe. */
  private def releaseJdbcResources(rs: java.sql.ResultSet,
                                   ps: java.sql.Statement,
                                   con: java.sql.Connection,
                                   dataSource: javax.sql.DataSource): Unit = {
    JdbcUtils.closeResultSet(rs)
    JdbcUtils.closeStatement(ps)
    DataSourceUtils.releaseConnection(con, dataSource)
  }

  /**
   * Input stream that reads from a result-set column and releases the owning JDBC resources
   * on close. Closing more than once is a no-op.
   */
  private final class ResultSetBackedInputStream(delegate: InputStream,
                                                 rs: java.sql.ResultSet,
                                                 ps: java.sql.Statement,
                                                 con: java.sql.Connection,
                                                 dataSource: javax.sql.DataSource) extends InputStream {

    private var closed = false

    override def read(): Int = delegate.read()

    override def read(b: Array[Byte], off: Int, len: Int): Int = delegate.read(b, off, len)

    override def available(): Int = delegate.available()

    override def close(): Unit = if (!closed) {
      closed = true
      // Close the column stream first; always release JDBC resources even if that fails.
      try delegate.close()
      finally releaseJdbcResources(rs, ps, con, dataSource)
    }
  }
}

@Component
class DbArtifactDataRepository(@Autowired @Qualifier("mainJdbcTemplate") val jdbcTemplate: JdbcTemplate,
                               @Autowired implicit val artifactRepository: DbArtifactRepository,
                               @Autowired @Qualifier("artifactRepositoryRoot") override val root: File,
                               @Autowired implicit val checksumAlgorithmProvider: ChecksumAlgorithmProvider)
                              (@Autowired @Qualifier("mainSchema") implicit val schemaInfo: SchemaInfo)
  extends ArtifactDataRepository with DbArtifactUsageQueries with TmpFileArtifactHandler {

  /** Artifacts are stored under their checksum, so the reference is the checksum itself. */
  override def refFromChecksum(checksum: String): String = checksum

  /**
   * Stores the content for the given CI. The artifact row is keyed by the content checksum
   * and only inserted when absent; a usage row linking `pk` to that checksum is always added.
   */
  override def store(pk: CiPKType, inputStream: InputStream): Unit = {
    storeToTempWithEval(inputStream) { (tempInputStream, checksum) =>
      if (!artifactRepository.existsArtifact(checksum)) {
        artifactRepository.insertArtifact(checksum, tempInputStream)
      }
      // insert usage
      insert(pk, checksum)
    }
  }

  /** Links CI `pk` to the artifact identified by `reference`. */
  override def insert(pk: CiPKType, reference: String): Unit =
    jdbcTemplate.update(INSERT_USAGE, pk, reference)

  /** Makes CI `toPk` use the same artifact(s) as the CI whose id is `fromId`. */
  override def copy(fromId: String, toPk: CiPKType): Unit =
    jdbcTemplate.update(COPY_USAGE, toPk, idToPath(fromId))

  /** Opens the artifact content used by the CI `id`; the caller must close the stream. */
  override def retrieve(id: String): InputStream = artifactRepository.getArtifactInputStream(id)

  /**
   * Removes all usage rows of CI `pk`, then deletes every artifact that was referenced by
   * `pk` and is no longer used by any CI.
   *
   * Usages are deleted once up front (DELETE_USAGE removes every row for `pk`), and the
   * remaining-usage count is checked afterwards per distinct artifact. This avoids deleting
   * an artifact still used by another CI, or orphaning one, when `pk` references several
   * artifacts or the same artifact more than once.
   */
  override def delete(pk: CiPKType): Unit = {
    val artifactIds = new java.util.LinkedHashSet[String](
      jdbcTemplate.queryForList(SELECT_ARTIFACT_BY_USAGE, classOf[String], pk))
    if (!artifactIds.isEmpty) {
      jdbcTemplate.update(DELETE_USAGE, pk)
      artifactIds.forEach { artifactId =>
        val remaining = jdbcTemplate.queryForObject(COUNT_USAGE, classOf[Number], artifactId).intValue()
        if (remaining == 0) {
          artifactRepository.removeArtifact(artifactId)
        }
      }
    }
  }

  /** Total number of stored artifacts (not usages). */
  override def countArtifacts: Int = artifactRepository.countArtifacts
}

/** Table and column names of the table holding artifact binary content. */
object DbArtifactSchema {
  /** Table storing one row per distinct artifact. */
  val tableName: TableName = TableName("XLD_DB_ARTIFACTS")

  /** Artifact id column (the content checksum, as used by the repositories in this file). */
  val ID: ColumnName = ColumnName("ID")
  /** Length column; not referenced by the queries in this file. */
  val length: ColumnName = ColumnName("length")
  /** Binary content column. */
  val data: ColumnName = ColumnName("data")
}

/** Table and column names of the table linking CIs to the artifacts they use. */
object DbArtifactUsageSchema {
  /** Usage table: one row per (CI, artifact) link. */
  val tableName: TableName = TableName("XLD_DB_ART_USAGE")

  /** CI primary key column, joined against the CIS table id. */
  val ci_id: ColumnName = ColumnName("ci_id")
  /** Artifact id column, joined against the artifact table id. */
  val artifact_id: ColumnName = ColumnName("artifact_id")
}

/** SQL statements operating on the artifact content table. */
trait DbArtifactQueries extends Queries {

  import DbArtifactSchema._

  // --- counting ---
  val COUNT_ALL_ARTIFACTS: String = sqlb"select count(*) from $tableName"
  val COUNT_ARTIFACT_BY_ID: String = sqlb"select count(*) from $tableName where $ID = ?"

  // --- modification: parameters are documented in placeholder order ---
  /** Params: artifact id, binary content. */
  val INSERT_ARTIFACT: String = sqlb"insert into $tableName ($ID, $data) values (?,?)"
  /** Params: new id, old id. */
  val UPDATE_ARTIFACT_ID: String = sqlb"update $tableName set $ID = ? where $ID = ?"
  /** Params: artifact id. */
  val DELETE_ARTIFACT: String = sqlb"delete from $tableName where $ID = ?"

  // --- lookup ---
  /** Selects the content of the artifact used by the CI whose path equals the single parameter. */
  val SELECT_ARTIFACT_BY_PATH: String =
    sqlb"""select artifact.$data
          |from $tableName artifact
          |inner join ${DbArtifactUsageSchema.tableName} usg
          |on (artifact.$ID = usg.${DbArtifactUsageSchema.artifact_id})
          |inner join ${CIS.tableName} ci
          |on (usg.${DbArtifactUsageSchema.ci_id} = ci.${CIS.ID})
          |where ci.${CIS.path} = ?"""
}

/** SQL statements operating on the CI-to-artifact usage table. */
trait DbArtifactUsageQueries extends Queries {

  import DbArtifactUsageSchema._

  /** Params: CI pk, artifact id. */
  val INSERT_USAGE: String = sqlb"insert into $tableName ($ci_id, $artifact_id) values (?, ?)"
  /**
   * Copies the usages of the CI with the given path to another CI.
   * Params: target CI pk, source CI path.
   * The join uses this table's own `ci_id` column. The selected artifact column is
   * qualified with the `usg` alias so it cannot clash with a CIS column.
   */
  val COPY_USAGE: String =
    sqlb"""insert into $tableName ($ci_id, $artifact_id)
          |select ?, usg.$artifact_id from $tableName usg
          |inner join ${CIS.tableName} ci
          |on (usg.$ci_id = ci.${CIS.ID})
          |where ci.${CIS.path} = ?"""
  /** Params: CI pk. Returns the artifact ids used by that CI. */
  val SELECT_ARTIFACT_BY_USAGE: String = sqlb"select $artifact_id from $tableName where $ci_id = ?"
  /** Params: artifact id. Counts how many usage rows reference it. */
  val COUNT_USAGE: String = sqlb"select count(*) from $tableName where $artifact_id = ?"
  /** Params: CI pk. Deletes every usage row of that CI. */
  val DELETE_USAGE: String = sqlb"delete from $tableName where $ci_id = ?"
}
