package com.xebialabs.deployit.repository.sql.batch

import com.xebialabs.deployit.core.sql.batch.BatchCommand.Args
import com.xebialabs.deployit.core.sql.batch.{BatchCommandWithArgs, BatchCommandWithSetter, BatchExecutorRepository}
import com.xebialabs.deployit.core.sql.spring.{DeployJdbcTemplate, Setter}
import org.slf4j.{Logger, LoggerFactory}
import org.springframework.beans.factory.annotation.{Autowired, Qualifier}
import org.springframework.jdbc.core.{BatchPreparedStatementSetter, JdbcTemplate}
import org.springframework.stereotype.Repository

import java.sql.PreparedStatement
import java.util
import javax.sql.DataSource
import scala.jdk.CollectionConverters._

/**
 * [[BatchExecutorRepository]] that sends statements to the database as JDBC batches through a
 * [[JdbcTemplate]].
 *
 * The rows of one statement are split into chunks of at most `maxBatch` rows. Each chunk is one
 * `jdbcTemplate.batchUpdate` call, and the update counts of all chunks are concatenated in chunk order.
 *
 * @param jdbcTemplate template that executes the batch updates
 * @param maxBatch     maximum number of rows sent in a single JDBC batch; must be positive
 */
class BatchExecutorRepositoryImpl(private val jdbcTemplate: JdbcTemplate,
                                  val maxBatch: Int = 1000) extends BatchExecutorRepository {

  // Fail fast: `grouped` rejects a non-positive size, so such a value would otherwise break every
  // non-empty execute call at run time instead of at construction.
  require(maxBatch > 0, s"maxBatch must be positive, but was $maxBatch")

  private[BatchExecutorRepositoryImpl] val logger: Logger = LoggerFactory.getLogger(classOf[BatchExecutorRepositoryImpl])

  /**
   * Executes `commands` grouped by SQL text: one batch (split by `maxBatch`) per distinct statement.
   *
   * Statements run in order of their first appearance in `commands`. The rows of each statement
   * keep their iteration order.
   *
   * @return one array of update counts per distinct statement, in execution order
   */
  def execute(commands: Iterable[BatchCommandWithArgs]): Iterable[Array[Int]] =
    groupBySql(commands)(_.sql, _.args).map { case (sql, rows) => execute(sql, rows) }

  /**
   * Executes [[BatchCommandWithSetter]] commands grouped by SQL text, with the same ordering
   * guarantees as the [[BatchCommandWithArgs]] variant.
   *
   * @return one array of update counts per distinct statement, in execution order
   */
  def executeWithSetter(commands: Iterable[BatchCommandWithSetter]): Iterable[Array[Int]] =
    groupBySql(commands)(_.sql, _.args).map { case (sql, rows) => execute(sql, rows) }

  /**
   * Executes every entry of `commands` as a batch. The key is the SQL text and the value holds
   * its rows. Update counts are discarded.
   */
  def execute(commands: util.Map[String, Seq[Args]]): Unit =
    commands
      .entrySet()
      .forEach(batchCommands =>
        execute(batchCommands.getKey, batchCommands.getValue))

  /**
   * Executes `sql` once per row of `args`, in JDBC batches of at most `maxBatch` rows.
   *
   * @param sql      statement with positional parameters
   * @param args     one parameter array per row
   * @param argTypes SQL types of the parameters, passed through to `JdbcTemplate.batchUpdate`
   * @return the concatenated update counts of every chunk; empty when `args` is empty
   */
  def execute(sql: String, args: Seq[Args], argTypes: Array[Int] = Array.empty): Array[Int] = {
    // Spring's List-based batchUpdate reads row i through java.util.List#get(i). On a wrapped
    // linked List that lookup is O(i), which makes a batch quadratic, so work on an indexed copy
    // (a no-op when the caller already passes a Vector).
    val rows = args.toIndexedSeq
    val rowCount = rows.length

    if (logger.isDebugEnabled())
      logger.debug(s"Executing batched sql $sql for $rowCount rows")

    if (rows.isEmpty)
      Array.empty
    else if (rowCount > maxBatch)
      rows
        .grouped(maxBatch)
        .flatMap(chunk => jdbcTemplate.batchUpdate(sql, chunk.asJava, argTypes))
        .toArray
    else
      jdbcTemplate.batchUpdate(sql, rows.asJava, argTypes)
  }

  /**
   * Executes `sql` once per element of `batchSetters`, in JDBC batches of at most `maxBatch`
   * statements. Each setter binds the parameters of one statement.
   *
   * @return the concatenated update counts of every chunk; empty when `batchSetters` is empty
   */
  def execute(sql: String, batchSetters: Seq[Setter]): Array[Int] = {
    // Materialize once: gives O(1) length and indexed access for the statement setter below.
    val setters = batchSetters.toArray
    val setterCount = setters.length

    if (logger.isDebugEnabled())
      logger.debug(s"Executing batched sql $sql for $setterCount rows")

    if (setters.isEmpty)
      Array.empty
    else if (setterCount > maxBatch)
      setters
        .grouped(maxBatch)
        .flatMap(chunk => batchUpdateBatchPreparedStatementSetter(sql, chunk))
        .toArray
    else
      batchUpdateBatchPreparedStatementSetter(sql, setters)
  }

  /** Runs one JDBC batch of `batchSetters.length` statements; statement `i` is bound by `batchSetters(i)`. */
  private def batchUpdateBatchPreparedStatementSetter(sql: String, batchSetters: Array[Setter]): Array[Int] =
    jdbcTemplate.batchUpdate(sql, new BatchPreparedStatementSetter() {
      override def setValues(ps: PreparedStatement, i: Int): Unit = {
        val setter = batchSetters(i)
        setter.setValues(ps)
      }

      override def getBatchSize: Int = batchSetters.length
    })

  /**
   * Groups `commands` by SQL text. Statements are listed in order of first appearance, and each
   * statement's rows keep their iteration order. `commands` is traversed only once.
   */
  private def groupBySql[C, R](commands: Iterable[C])(sqlOf: C => String, rowOf: C => R): Vector[(String, Vector[R])] = {
    val all = commands.toVector
    val rowsBySql = all.groupMap(sqlOf)(rowOf)
    // A hash Map alone would iterate statements in an arbitrary order once there are more than
    // four of them, so drive the output order from the input instead.
    all.map(sqlOf).distinct.map(sql => sql -> rowsBySql(sql))
  }
}

/**
 * Spring repository bean, registered under the qualifier "mainBatchExecutorRepository", that runs
 * batches through the `JdbcTemplate` bean qualified as "mainJdbcTemplate".
 *
 * No `maxBatch` is passed, so the `BatchExecutorRepositoryImpl` default batch size applies.
 *
 * @param jdbcTemplate the injected `JdbcTemplate` qualified as "mainJdbcTemplate"
 */
@Repository
@Qualifier("mainBatchExecutorRepository")
class MainBatchExecutorRepository(@Autowired @Qualifier("mainJdbcTemplate") jdbcTemplate: JdbcTemplate)
  extends BatchExecutorRepositoryImpl(jdbcTemplate)

/**
 * Spring repository bean, registered under the qualifier "reportingBatchExecutorRepository", that
 * runs batches against the `DataSource` bean qualified as "reportingDataSource".
 *
 * The data source is wrapped in a new `DeployJdbcTemplate` created when this bean is constructed.
 * No `maxBatch` is passed, so the `BatchExecutorRepositoryImpl` default batch size applies.
 *
 * NOTE(review): the meaning of the `false` constructor argument is defined by `DeployJdbcTemplate`,
 * which is not visible here. Confirm it, and consider passing it as a named argument for readability.
 *
 * @param reportingDataSource the injected reporting `DataSource`; also exposed as a public val
 */
@Repository
@Qualifier("reportingBatchExecutorRepository")
class ReportingBatchExecutorRepository(@Autowired @Qualifier("reportingDataSource") val reportingDataSource: DataSource)
  extends BatchExecutorRepositoryImpl(new DeployJdbcTemplate(reportingDataSource, false))
