package com.xebialabs.xlrelease.upgrade

import com.xebialabs.xlrelease.config.XlrConfig
import grizzled.slf4j.Logging
import org.springframework.transaction.support.TransactionTemplate

import java.util.concurrent.TimeUnit
import scala.concurrent.duration.Duration
import scala.concurrent.{Await, Future, blocking}

object UpgradeSupport {

  /**
   * Mixin that applies a per-item action either concurrently or sequentially, depending on the
   * `upgrader.parallel` setting read from [[XlrConfig]].
   */
  trait ParallelSupport {
    self: Logging =>

    private lazy val xlrConfig = XlrConfig.getInstance
    private lazy val parallel = xlrConfig.upgrader.parallel

    /** Execution context used by the parallel branch of [[doInParallel]] (Scala's global pool). */
    implicit val ec: scala.concurrent.ExecutionContextExecutor =
      scala.concurrent.ExecutionContext.Implicits.global

    /**
     * Applies `block` to every element of `items`.
     *
     * When parallel upgrades are enabled, each item runs in its own `Future`. The work is wrapped
     * in `blocking` so the global pool is told it may block. The caller then waits up to `timeout`
     * for all of the futures. Otherwise the items are processed one after another on the calling
     * thread.
     *
     * @param items   the items to process
     * @param timeout maximum time to wait for all parallel tasks; unused in sequential mode
     * @param block   the action to run for each item
     * @throws java.util.concurrent.TimeoutException if, in parallel mode, the tasks do not all
     *         finish within `timeout`
     * @note In parallel mode, a failure (or a timeout) is propagated to the caller, but tasks that
     *       are still running are not cancelled and may keep running in the background.
     */
    def doInParallel[T](
        items: Iterable[T],
        timeout: Duration = Duration(10, TimeUnit.MINUTES)
    )(block: T => Unit): Unit = {
      if (parallel) {
        // traverse creates every Future up front, so all items are submitted immediately.
        val allDone = Future.traverse(items)(item => Future(blocking(block(item))))
        Await.result(allDone, timeout)
      } else {
        items.foreach(block)
      }
    }
  }

  /**
   * Mixin that splits a collection into fixed-size batches and runs an action per batch. Before
   * each batch it logs a progress line such as `Batch[03/12]: Upgrading 500 items`.
   */
  trait BatchSupport {
    self: Logging =>

    private lazy val xlrConfig = XlrConfig.getInstance
    // Named differently from doInBatch's `batchSize` parameter so that the default argument
    // clearly refers to the configured value.
    private lazy val defaultBatchSize = xlrConfig.upgrader.batchSize

    /**
     * A single batch handed to the [[doInBatch]] callback.
     *
     * @param items     the items in this batch
     * @param group     zero-based index of this batch
     * @param logPrefix progress label used in log messages, e.g. `Batch[03/12]`
     */
    case class Batch[T](items: Iterable[T], group: Int, logPrefix: String)

    /**
     * Splits `items` into consecutive groups of at most `batchSize` elements and calls `block`
     * once per group, in order. Before each call it logs an info line with the batch counter and
     * the batch's item count. Empty `items` results in no calls and no log lines.
     *
     * @param items     the items to process
     * @param batchSize maximum number of items per batch; defaults to `upgrader.batchSize`
     * @param itemsName noun used for the items in the log line
     * @param block     the action to run for each batch
     * @throws IllegalArgumentException if `batchSize` is not positive
     */
    def doInBatch[T](
        items: Iterable[T],
        batchSize: Int = defaultBatchSize,
        itemsName: String = "items"
    )(block: Batch[T] => Unit): Unit = {
      require(batchSize > 0, s"batchSize must be positive, but was $batchSize")

      val totalItemsNum = items.size
      val totalBatchNum = Math.max(1, (totalItemsNum.toDouble / batchSize.toDouble).ceil.toInt)
      val batchDigits = digits(totalBatchNum)

      def padBatchNum: Int => String = pad(batchDigits)

      // Both numbers are padded to the width of the total, e.g. "Batch[03/12]".
      def batchCounter(batchNum: Int): String =
        s"Batch[${padBatchNum(batchNum)}/${padBatchNum(totalBatchNum)}]"

      items.grouped(batchSize).zipWithIndex.foreach { case (batchItems, batchGroup) =>
        val logPrefix = batchCounter(batchGroup + 1)
        logger.info(s"$logPrefix: Upgrading ${batchItems.size} $itemsName")
        block(Batch(batchItems, batchGroup, logPrefix))
      }
    }

    /**
     * Number of decimal digits needed to print `num`, ignoring the sign; always at least 1.
     * Counting characters avoids the off-by-one that `ceil(log10(n))` has at exact powers of
     * ten (it yields 1 for 10 and 2 for 100).
     */
    protected def digits(num: Int): Int = Math.abs(num.toLong).toString.length

    /**
     * Formats `num` right-aligned in a field of `size` characters (at least 1).
     * With `leadingZeros` the field is zero-filled (`pad(3)(7) == "007"`); otherwise it is
     * space-filled (`"  7"`).
     */
    protected def pad(size: Int, leadingZeros: Boolean = true)(num: Int): String =
      s"%${if (leadingZeros) "0" else ""}${Math.max(1, size)}d".format(num)
  }

  /** Mixin that runs code inside a transaction demarcated by a Spring `TransactionTemplate`. */
  trait TransactionSupport {
    self: Logging =>

    /** Template used to run [[doInTransaction]] blocks; provided by the implementing class. */
    def transactionTemplate: TransactionTemplate

    /**
     * Runs `block` as the callback of `transactionTemplate.execute`.
     *
     * Any `Exception` thrown by `block` is logged at error level and rethrown, wrapped in an
     * (unchecked) `IllegalStateException`, so the template rolls the transaction back.
     * Exceptions raised by the template itself, for example during commit, are not intercepted
     * here.
     *
     * @throws IllegalStateException wrapping the original exception if `block` throws
     */
    def doInTransaction(block: => Unit): Unit = {
      transactionTemplate.execute { _ =>
        try {
          block
        } catch {
          case e: Exception =>
            val errorMsg = "Unable to successfully finish database transaction"
            logger.error(errorMsg, e)
            throw new IllegalStateException(errorMsg, e)
        }
      }
    }
  }

}
