package com.xebialabs.deployit.repository.sql.artifacts

import java.io.{File, FileInputStream, InputStream}
import java.sql.Connection
import java.util
import com.xebialabs.deployit.checksum.ChecksumAlgorithmProvider
import com.xebialabs.deployit.core.sql.{ColumnName, Queries, SchemaInfo, TableName, SqlCondition => cond}
import com.xebialabs.deployit.repository.sql.base._
import com.xebialabs.deployit.sql.base.schema.CIS
import grizzled.slf4j.Logging
import org.springframework.beans.factory.annotation.{Autowired, Qualifier}
import org.springframework.jdbc.core.{ConnectionCallback, JdbcTemplate}
import org.springframework.jdbc.support.GeneratedKeyHolder
import org.springframework.stereotype.Component
import org.springframework.transaction.annotation.Transactional
import org.springframework.transaction.support.TransactionSynchronization.STATUS_COMMITTED
import org.springframework.transaction.support.{TransactionSynchronization, TransactionSynchronizationManager}

import java.nio.file.Files
import scala.util.{Failure, Success, Try}

/**
 * File-system backed [[FileArtifactDataRepository]].
 *
 * Artifact content lives under `root` at a path derived from its checksum (see `pathFromChecksum`), so storing
 * the same content twice resolves to the same location. The `XLD_FILE_ARTIFACTS` table holds one row per stored
 * location and `XLD_FILE_ART_USAGE` records which CIs use which artifact. File-system changes are deferred to
 * transaction synchronizations: new files are moved into place only after commit, and files of removed artifacts
 * are deleted only after commit.
 *
 * @param jdbcTemplate              template used for all SQL access
 * @param root                      directory under which artifact files are stored; created if it is non-null
 * @param checksumAlgorithmProvider checksum algorithm source (presumably consumed implicitly by
 *                                  `storeToTempWithChecksumCalc` — TODO confirm)
 * @param schemaInfo                schema/dialect information used to build the locking select queries
 */
@Component("fileArtifactDataRepository")
@Transactional("mainTransactionManager")
class FileArtifactDataRepositoryImpl(@Autowired @Qualifier("mainJdbcTemplate") override val jdbcTemplate: JdbcTemplate,
                                     @Autowired @Qualifier("artifactRepositoryRoot") override val root: File,
                                     @Autowired implicit val checksumAlgorithmProvider: ChecksumAlgorithmProvider)
                                    (@Autowired @Qualifier("mainSchema") override implicit val schemaInfo: SchemaInfo)
  extends FileArtifactDataRepository with FileArtifactQueries with FileArtifactUsageQueries with TmpFileArtifactHandler with Logging {

  import FileArtifactDataRepositoryImpl._

  /** Delay before the single retry of a failed lock query. */
  private val lockRetryDelayMillis: Long = 100L

  // Make sure the storage root exists; a null root is tolerated.
  if (root != null) root.mkdirs()

  /** Resolves a repository-relative artifact path against `root`. */
  override def fileFromRelativePath(path: String): File = new File(root, path)

  /** Returns the repository-relative path under which content with the given checksum is stored. */
  override def refFromChecksum(checksum: String): String = pathFromChecksum(checksum)

  /**
   * Stores the content of `inputStream` and records it as used by the CI `pk`.
   *
   * The content is first written to a temporary file via `storeToTempWithChecksumCalc`. That file is moved to
   * its checksum-derived location only after the surrounding transaction commits; it is discarded in any case
   * once the transaction completes.
   */
  override def store(pk: CiPKType, inputStream: InputStream): Unit = {
    val (tempFile, checksum) = storeToTempWithChecksumCalc(inputStream)
    val path = pathFromChecksum(checksum)
    val file = fileFromRelativePath(path)
    TransactionSynchronizationManager.registerSynchronization(new MoveTempFileAfterCommit(tempFile, file))
    insert(pk, path)
  }

  /**
   * Evaluates `action`. If it fails with a non-fatal exception, logs the failure, waits `lockRetryDelayMillis`
   * and evaluates it exactly once more, letting a second failure propagate to the caller.
   *
   * @param description short text describing the action, used only for logging
   */
  private def withSingleRetry[T](description: => String)(action: => T): T =
    Try(action) match {
      case Success(value) => value
      case Failure(e) =>
        // Previously the first failure was discarded without a trace; keep it visible for troubleshooting.
        logger.debug(s"Failed to $description, retrying once.", e)
        Thread.sleep(lockRetryDelayMillis)
        action
    }

  /** Selects (using the dialect's locking select) the ids of artifact rows stored at `location`. */
  private def executeLockQueryForLocation(location: String): util.List[CiPKType] = {
    val query = schemaInfo.sqlDialect.lockSelectBuilder(FileArtifactSchema.tableName).select(FileArtifactSchema.ID)
      .where(cond.equals(FileArtifactSchema.location, location)).query
    withSingleRetry(s"lock artifact at location $location") {
      jdbcTemplate.queryForList(query, classOf[CiPKType], location)
    }
  }

  /** Selects (using the dialect's locking select) the artifact row with the given id; empty if it does not exist. */
  private def executeLockQueryForId(id: CiPKType): util.List[CiPKType] = {
    val query = schemaInfo.sqlDialect.lockSelectBuilder(FileArtifactSchema.tableName).select(FileArtifactSchema.ID)
      .where(cond.equals(FileArtifactSchema.ID, id)).query
    withSingleRetry(s"lock artifact with id $id") {
      jdbcTemplate.queryForList(query, classOf[CiPKType], id)
    }
  }

  /**
   * Records that the CI `pk` uses the artifact stored at `reference`.
   *
   * An artifact row for `reference` is reused if present and created otherwise. Either way the row is
   * selected through the locking query before the usage row is inserted.
   */
  override def insert(pk: CiPKType, reference: String): Unit = {
    def getOrCreate(location: String): CiPKType = {
      val ids = executeLockQueryForLocation(location)
      if (ids.isEmpty) {
        debug("ids were empty - creating new artifact")
        val id: CiPKType = insert(location)
        debug(s"ids were empty - created new artifact $id")
        executeLockQueryForId(id)
        debug(s"ids were empty - fetched artifact $id with a lock")
        id
      } else {
        debug(s"ids were not empty returning ${ids.get(0)}")
        ids.get(0)
      }
    }
    jdbcTemplate.execute(new ConnectionCallback[Unit] {
      override def doInConnection(con: Connection): Unit = {
        val id = getOrCreate(reference)
        debug(s"Artifact id: ${id} location: ${reference} used for usage ${pk}")
        insertUsage(id, pk)
      }
    })
  }

  /**
   * Inserts a new artifact row for `location`.
   *
   * @return the generated artifact id
   * @throws IllegalStateException if the driver reports no generated key
   */
  override def insert(location: String): Int = {
    val keyHolder = new GeneratedKeyHolder()
    jdbcTemplate.update((con: Connection) => {
      val preparedStatement = con.prepareStatement(INSERT_ARTIFACT, Array(FileArtifactSchema.ID.name))
      preparedStatement.setString(1, location)
      preparedStatement
    }, keyHolder)
    // GeneratedKeyHolder.getKey returns null when no key was generated; fail with a clear message rather than an NPE.
    Option(keyHolder.getKey)
      .getOrElse(throw new IllegalStateException(s"No generated key returned when inserting artifact at location $location"))
      .intValue()
  }

  /** Inserts a usage row linking the CI `pk` to the artifact `id`. */
  override def insertUsage(id: Number, pk: CiPKType): Unit = {
    jdbcTemplate.update(INSERT_USAGE, pk, id)
  }

  /** Makes the CI `toPk` use the same artifact(s) as the CI whose path is derived from `fromId`. */
  override def copy(fromId: String, toPk: CiPKType): Unit = {
    jdbcTemplate.update(COPY_USAGE, toPk, idToPath(fromId))
  }

  /**
   * Opens the artifact file used by the CI identified by `id`.
   * The caller is responsible for closing the returned stream.
   */
  override def retrieve(id: String): InputStream = {
    val location = jdbcTemplate.queryForObject(SELECT_LOCATION_BY_PATH, classOf[String], idToPath(id))
    new FileInputStream(fileFromRelativePath(location))
  }

  /**
   * Removes every usage recorded for the CI `pk`, and removes each referenced artifact that no CI uses any more.
   *
   * Each artifact is first selected through the locking query. After `pk`'s usages are deleted, an artifact with
   * no remaining usages has its row deleted, and its file is scheduled for deletion after commit.
   */
  override def delete(pk: CiPKType): Unit = {
    // De-duplicate: the same artifact may be referenced by more than one usage row of this CI.
    val artifactIds = new util.LinkedHashSet[CiPKType](jdbcTemplate.queryForList(SELECT_ARTIFACT_BY_USAGE, classOf[CiPKType], pk))
    artifactIds.forEach { artifactId =>
      jdbcTemplate.execute(new ConnectionCallback[Unit] {
        override def doInConnection(con: Connection): Unit = {
          val lockedRows = executeLockQueryForId(artifactId)
          // DELETE_USAGE removes all of pk's usages at once. Counting the usages that remain afterwards, rather
          // than checking for exactly one usage beforehand, keeps later iterations from seeing a count of 0 and
          // leaking artifacts that only this CI used.
          jdbcTemplate.update(DELETE_USAGE, pk)
          val remaining = jdbcTemplate.queryForObject(COUNT_USAGE, classOf[Number], artifactId).intValue()
          // An empty lock result means the artifact row is already gone (e.g. removed by a concurrent delete).
          if (remaining == 0 && !lockedRows.isEmpty) {
            val file = getLocation(artifactId)
            logger.debug(s"Registering for after transaction deletion $artifactId.")
            TransactionSynchronizationManager.registerSynchronization(new AfterTransactionFileDeletion(file))
            remove(artifactId)
          }
        }
      })
    }
  }

  /** Resolves the file of the artifact with the given id. */
  override def getLocation(id: Number): File = {
    fileFromRelativePath(jdbcTemplate.queryForObject(SELECT_LOCATION_BY_ID, classOf[String], id))
  }

  /** Deletes the artifact row with the given id and returns the number of rows affected. */
  override def remove(artifactId: Number): Int = {
    jdbcTemplate.update(DELETE_ARTIFACT, artifactId)
  }

  /** Returns the total number of artifact rows. */
  override def countArtifacts: Int = {
    jdbcTemplate.queryForObject(COUNT_ALL_ARTIFACTS, classOf[Number]).intValue()
  }
}

object FileArtifactDataRepositoryImpl {

  /**
   * Builds the repository-relative storage path for a checksum.
   *
   * The path is sharded by the first three two-character slices of the checksum, followed by the full checksum,
   * e.g. `abcdef01` becomes `ab/cd/ef/abcdef01`. Checksums shorter than six characters are rejected by
   * `String.substring`.
   */
  def pathFromChecksum(checksum: String): String = {
    val shardDirs = Seq(0, 2, 4).map(start => checksum.substring(start, start + 2))
    (shardDirs :+ checksum).mkString("/")
  }

  /**
   * Moves `tempFile` to `artifactFile`, creating the target's parent directories first.
   *
   * A plain rename is tried first; if it fails (for example across file systems), `Files.move` is used,
   * which throws on failure.
   */
  def moveTempFileToArtifactFile(tempFile: File, artifactFile: File): Unit = {
    artifactFile.getParentFile.mkdirs()
    val renamed = tempFile.renameTo(artifactFile)
    if (!renamed) Files.move(tempFile.toPath, artifactFile.toPath)
  }
}

/** Names of the file-artifact table and its columns; one row per stored artifact location. */
object FileArtifactSchema {
  val tableName: TableName = TableName("XLD_FILE_ARTIFACTS")

  /** Generated key of an artifact row. */
  val ID: ColumnName = ColumnName("ID")
  /** Repository-relative path of the artifact file. */
  val location: ColumnName = ColumnName("location")
}

/** Names of the usage table, which links CIs to the artifacts they use, and its columns. */
object FileArtifactUsageSchema {
  val tableName: TableName = TableName("XLD_FILE_ART_USAGE")

  /** Id of the CI using the artifact (joined against the CIS table's ID column). */
  val ci_id: ColumnName = ColumnName("ci_id")
  /** Id of the used artifact (joined against `FileArtifactSchema.ID`). */
  val artifact_id: ColumnName = ColumnName("artifact_id")
}

/** SQL statements on the file-artifact table (`FileArtifactSchema`). */
trait FileArtifactQueries extends Queries {
  import FileArtifactSchema._

  /** Counts all artifact rows. */
  val COUNT_ALL_ARTIFACTS = sqlb"select count(*) from $tableName"
  /** Inserts an artifact row. Parameter: location. */
  val INSERT_ARTIFACT = sqlb"insert into $tableName ($location) values (?)"
  /** Selects the artifact location(s) used by the CI with the given path. Parameter: CI path. */
  val SELECT_LOCATION_BY_PATH =
    sqlb"""select artifact.$location
          |from $tableName artifact
          |inner join ${FileArtifactUsageSchema.tableName} usg
          |on (artifact.$ID = usg.${FileArtifactUsageSchema.artifact_id})
          |inner join ${CIS.tableName} ci
          |on (usg.${FileArtifactUsageSchema.ci_id} = ci.${CIS.ID})
          |where ci.${CIS.path} = ?"""
  /** Selects the location of an artifact. Parameter: artifact id. */
  val SELECT_LOCATION_BY_ID = sqlb"select $location from $tableName where $ID = ?"
  /** Deletes an artifact row. Parameter: artifact id. */
  val DELETE_ARTIFACT = sqlb"delete from $tableName where $ID = ?"
}

/** SQL statements on the artifact usage table (`FileArtifactUsageSchema`). */
trait FileArtifactUsageQueries extends Queries {
  import FileArtifactUsageSchema._

  /** Inserts a usage row. Parameters: CI id, artifact id. */
  val INSERT_USAGE = sqlb"insert into $tableName ($ci_id, $artifact_id) values (?, ?)"
  /**
   * Copies the artifact usages of the CI with a given path to another CI.
   * Parameters: target CI id, source CI path.
   */
  val COPY_USAGE =
    sqlb"""insert into $tableName ($ci_id, $artifact_id)
          |select ?, $artifact_id from $tableName usg
          |inner join ${CIS.tableName} ci
          |on (usg.${FileArtifactUsageSchema.ci_id} = ci.${CIS.ID})
          |where ci.${CIS.path} = ?"""
  /** Selects the ids of artifacts used by a CI. Parameter: CI id. */
  val SELECT_ARTIFACT_BY_USAGE = sqlb"select $artifact_id from $tableName where $ci_id = ?"
  /** Counts the usages of an artifact across all CIs. Parameter: artifact id. */
  val COUNT_USAGE = sqlb"select count(*) from $tableName where $artifact_id = ?"
  /** Deletes ALL usage rows of a CI, regardless of artifact. Parameter: CI id. */
  val DELETE_USAGE = sqlb"delete from $tableName where $ci_id = ?"
}

/**
 * Transaction synchronization that deletes an artifact file once the surrounding transaction has committed.
 * Nothing is deleted if the transaction rolls back.
 *
 * @param file the artifact file to delete
 */
class AfterTransactionFileDeletion(file: File) extends TransactionSynchronization with Logging {
  override def afterCommit(): Unit = {
    logger.debug(s"Deleting artifact file $file.")
    // File.delete reports failure only through its return value; log it so leaked files do not go unnoticed.
    if (!file.delete() && file.exists()) {
      logger.warn(s"Could not delete artifact file $file.")
    }
  }
}

/**
 * Transaction synchronization that publishes a freshly stored artifact.
 *
 * On commit, `tempFile` is moved to `artifactFile` unless that file already exists. Because the artifact path is
 * derived from the content checksum, an existing file is expected to hold the same content. Whatever the outcome,
 * the temporary file is removed once the transaction completes.
 *
 * @param tempFile     temporary file holding the uploaded content
 * @param artifactFile final, checksum-derived location of the artifact
 */
class MoveTempFileAfterCommit(tempFile: File, artifactFile: File) extends TransactionSynchronization with Logging {
  override def afterCompletion(status: Int): Unit = {
    try {
      if (status == STATUS_COMMITTED && !artifactFile.exists()) {
        logger.debug(s"Moving artifact file from $tempFile to $artifactFile.")
        FileArtifactDataRepositoryImpl.moveTempFileToArtifactFile(tempFile, artifactFile)
      }
    } catch {
      // Another transaction stored the same checksum-derived file between the exists() check and the move.
      case _: java.nio.file.FileAlreadyExistsException =>
        logger.debug(s"Artifact file $artifactFile already exists; discarding $tempFile.")
    } finally {
      // Always clean up, even if the move failed, so temporary files do not leak.
      tempFile.delete()
    }
  }
}
