package com.xebialabs.xlrelease.upgrade


import com.google.common.io.ByteStreams
import com.xebialabs.deployit.server.api.upgrade.{Upgrade, Version}
import com.xebialabs.xlrelease.db.ArchivedReleases.{REPORT_RELEASES_ID_COLUMN, REPORT_RELEASES_RELEASEJSON_COLUMN, REPORT_RELEASES_TABLE_NAME}
import com.xebialabs.xlrelease.db.sql.SqlBuilder.{CommonDialect, Dialect}
import com.xebialabs.xlrelease.upgrade.UpgradeSupport.TransactionSupport
import grizzled.slf4j.Logging
import org.codehaus.jettison.json.JSONObject
import org.springframework.beans.factory.annotation.{Autowired, Qualifier}
import org.springframework.jdbc.core.{BatchPreparedStatementSetter, ConnectionCallback, JdbcTemplate}
import org.springframework.transaction.support.TransactionTemplate

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.charset.StandardCharsets.UTF_8
import java.sql.ResultSet._
import java.sql.{Connection, PreparedStatement, ResultSet}
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.util.Using

/**
 * Base class for upgrades that must be applied both to imported templates AND to archived releases.
 *
 * The archived-release side is handled here: [[doUpgrade]] walks every row of the reporting releases table,
 * hands each release JSON to `performUpgrade`, and writes back (in batches of `BATCH_SIZE`) the rows whose
 * result reports `isUpgraded`. `performUpgrade` and `upgradeVersion()` come from the mixed-in traits, which
 * are not visible here.
 * NOTE(review): the template-import side is presumably driven through `ImportUpgrade`; confirm there.
 */
abstract class JsonUpgrade extends Upgrade with ImportUpgrade with Logging with TransactionSupport {

  private val BATCH_SIZE = 20
  private val LOG_BATCH_PROGRESS_EVERY_ROWS = 1000

  @Autowired
  @Qualifier("reportingJdbcTemplate")
  private[upgrade] var jdbcTemplate: JdbcTemplate = _

  @Autowired
  @Qualifier("reportTransactionTemplate")
  var transactionTemplate: TransactionTemplate = _

  @Autowired
  @Qualifier("reportingSqlDialect")
  var reportingSqlDialect: Dialect = _

  override def getUpgradeVersion: Version = upgradeVersion()

  /**
   * Upgrades every archived release stored in the reporting releases table.
   *
   * The row count taken up front is used for progress logging and as an upper bound for paging. It can
   * overestimate the rows actually present (e.g. rows deleted while the upgrade runs), so:
   *  - the loop also stops as soon as a query returns no rows, and
   *  - dialects without keyset pagination are scanned exactly once, so no release is upgraded twice.
   *
   * @return always `true`; failures propagate as exceptions
   */
  override def doUpgrade(): Boolean = {
    val total: Long = jdbcTemplate.queryForObject(s"select count(1) from $REPORT_RELEASES_TABLE_NAME", classOf[Long])
    logger.info(s"About to upgrade $total rows")

    val paginated = supportsKeysetPagination
    val rowsToUpdate: ArrayBuffer[ReleaseRow] = mutable.ArrayBuffer()
    var processedRows = 0L
    var upgradedRows = 0L
    var lastProcessedId: Option[String] = None
    var hasMoreRows = total > 0

    while (hasMoreRows) {
      val (sql, args) = query(lastProcessedId, paginated)
      var rowsInPass = 0L
      withCursor(sql, args) { rs: ResultSet =>
        val row = ReleaseRow.from(rs)
        lastProcessedId = Some(row.releaseId)
        // The same JSONObject handed to performUpgrade is the one executeUpdates serialises back.
        if (performUpgrade(row.releaseJson).isUpgraded) {
          rowsToUpdate += row
        }

        if (rowsToUpdate.size >= BATCH_SIZE) {
          // We could have taken a different approach and upgraded the row in place, but then the cursor
          // would have to be created differently and it would lock the table.
          upgradedRows += executeUpdates(rowsToUpdate)
        }

        processedRows += 1
        rowsInPass += 1
        if (processedRows % LOG_BATCH_PROGRESS_EVERY_ROWS == 0) logger.info(s"Processed $processedRows / $total rows")
      }
      // An empty page means the table is exhausted even if `total` said otherwise; a non-paginated query
      // already returned the whole table in one pass.
      hasMoreRows = paginated && rowsInPass > 0 && processedRows < total
    }

    // left-overs smaller than one batch
    upgradedRows += executeUpdates(rowsToUpdate)

    logger.info(s"Processed $processedRows and upgraded $upgradedRows out of total $total rows")

    true
  }

  /** Whether rows are read page by page using keyset pagination on the id column (PostgreSQL and MySQL only). */
  private def supportsKeysetPagination: Boolean = reportingSqlDialect match {
    case CommonDialect(dbName) =>
      val name = dbName.toLowerCase()
      name.contains("postgres") || name.contains("mysql")
    case _ =>
      false
  }

  /**
   * Builds the select statement for the next pass over the releases table.
   *
   * @param lastProcessedId id of the last row already processed; the keyset lower bound when paginating
   * @param paginated       page by id (`ORDER BY ... LIMIT`) when true, otherwise select the whole table
   * @return the SQL and its positional string parameters. The last processed id is bound as a parameter
   *         rather than spliced into the SQL, so ids containing quotes cannot break the statement.
   */
  private def query(lastProcessedId: Option[String], paginated: Boolean): (String, Seq[String]) =
    if (paginated) {
      // Filter on the same column we order by and read the id from; otherwise keyset pages would not line up.
      val (where, args) = lastProcessedId match {
        case Some(id) => (s"WHERE $REPORT_RELEASES_ID_COLUMN > ?", Seq(id))
        case None => ("", Seq.empty[String])
      }
      (s"select * from $REPORT_RELEASES_TABLE_NAME $where ORDER BY $REPORT_RELEASES_ID_COLUMN LIMIT $BATCH_SIZE", args)
    } else {
      (s"select * from $REPORT_RELEASES_TABLE_NAME", Seq.empty[String])
    }

  /**
   * Runs `sql` with a forward-only, read-only cursor and calls `block` for every row.
   *
   * Auto-commit is switched off for the duration of the query. Presumably this is so drivers such as
   * PostgreSQL's honour the fetch size instead of buffering the whole result. The original auto-commit mode
   * is restored in a `finally`, so a failing `block` does not leave the pooled connection in manual-commit mode.
   *
   * @param sql   the select statement to run
   * @param args  positional string parameters bound to `sql`, in order
   * @param block callback invoked once per row, with the result set positioned on that row
   */
  private def withCursor(sql: String, args: Seq[String])(block: ResultSet => Unit): Unit = {
    jdbcTemplate.execute(new ConnectionCallback[Unit] {
      override def doInConnection(conn: Connection): Unit = {
        val originalAutoCommit = conn.getAutoCommit
        conn.setAutoCommit(false)
        try {
          Using.resource(conn.prepareStatement(sql, TYPE_FORWARD_ONLY, CONCUR_READ_ONLY, HOLD_CURSORS_OVER_COMMIT)) { stmt =>
            args.zipWithIndex.foreach { case (arg, idx) => stmt.setString(idx + 1, arg) }
            stmt.setFetchSize(BATCH_SIZE)
            Using.resource(stmt.executeQuery()) { rs =>
              while (rs.next()) {
                block(rs)
              }
            }
          }
        } finally {
          conn.setAutoCommit(originalAutoCommit)
        }
      }
    })
  }

  /**
   * Writes the queued rows' JSON back as one JDBC batch inside `doInTransaction`, then empties the buffer.
   *
   * @param updates rows whose JSON should be persisted; cleared once the batch has run
   * @return the number of rows submitted (0 when `updates` is empty)
   */
  private def executeUpdates(updates: ArrayBuffer[ReleaseRow]): Int = {
    var submitted = 0
    if (updates.nonEmpty) {
      doInTransaction {
        jdbcTemplate.batchUpdate(s"update $REPORT_RELEASES_TABLE_NAME set $REPORT_RELEASES_RELEASEJSON_COLUMN = ? where $REPORT_RELEASES_ID_COLUMN = ?",
          new BatchPreparedStatementSetter {
            override def setValues(ps: PreparedStatement, i: Int): Unit = {
              val item = updates(i)
              ps.setBinaryStream(1, new ByteArrayInputStream(item.releaseContent()))
              ps.setString(2, item.releaseId)
            }

            override def getBatchSize: Int = updates.size
          }
        )
        submitted = updates.size
        updates.clear()
      }
    }
    submitted
  }

  /** One archived release: its id and its parsed JSON document. */
  case class ReleaseRow(releaseId: String, releaseJson: JSONObject) {
    /** The release JSON serialised as UTF-8 bytes, as written back to the JSON column. */
    def releaseContent(): Array[Byte] = releaseJson.toString.getBytes(UTF_8)
  }

  object ReleaseRow {
    /**
     * Reads the row the result set is currently positioned on.
     *
     * The JSON column is decoded as UTF-8, the same charset `releaseContent` encodes with, rather than with
     * the platform default. This keeps non-ASCII content intact across the read/upgrade/write round trip on
     * any JVM.
     */
    def from(rs: ResultSet): ReleaseRow = {
      val releaseId = rs.getString(REPORT_RELEASES_ID_COLUMN)

      val outputStream = new ByteArrayOutputStream()
      Using.resource(rs.getBinaryStream(REPORT_RELEASES_RELEASEJSON_COLUMN)) { binaryStream =>
        ByteStreams.copy(binaryStream, outputStream)
      }
      val releaseContent = new String(outputStream.toByteArray, UTF_8)
      val releaseJson = new JSONObject(releaseContent)

      ReleaseRow(releaseId, releaseJson)
    }

  }
}
