package com.xebialabs.deployit.repository.sql

import com.xebialabs.deployit.plugin.api.reflect.Type
import com.xebialabs.deployit.plugin.api.udm.ConfigurationItem
import com.xebialabs.license.LicenseCiCounter.LicensedCiUse
import com.xebialabs.license.{AmountOfCisExceededException, LicenseCiCounter}
import com.xebialabs.license.service.LicenseTransaction

/**
 * [[LicenseCiCounter]] that obtains the current CI usage from a [[SqlLicenseUsageProvider]].
 *
 * @param allowedCiAmounts maximum number of CIs the license permits, keyed by restricted type
 * @param sqlProvider      provider queried for the current CI counts
 */
class SqlLicenseCiCounter(val allowedCiAmounts: Map[Type, Int], val sqlProvider: SqlLicenseUsageProvider) extends LicenseCiCounter {

  /** Returns the current CI count for `ciType` as reported by the provider. */
  override def getCiCount(ciType: Type): Int = sqlProvider.getCiCount(ciType)

  /** Returns the provider's usage report for the types listed in `allowedCiAmounts`. */
  override def licensedCisInUse(): Array[LicensedCiUse] = sqlProvider.licensedCisInUse(allowedCiAmounts)

  /** Registers the creation of the given CIs by delegating to [[registerTypesCreation]] with their types. */
  override def registerCisCreation(cis: Seq[ConfigurationItem], transaction: LicenseTransaction): Unit = {
    registerTypesCreation(cis.map(_.getType), transaction)
  }

  /**
   * Registers the creation of one CI per entry in `tt` on the given transaction.
   *
   * @param tt          types of the CIs about to be created (one entry per CI)
   * @param transaction license transaction receiving the registrations
   * @throws AmountOfCisExceededException when the requested number of CIs of a restricted type
   *                                      does not fit in the remaining license allowance
   */
  override def registerTypesCreation(tt: Seq[Type], transaction: LicenseTransaction): Unit = {
    countRestrictedTypes(tt).foreach { case (t, requested) =>
      // Read both values once so the check and the error message report the same numbers.
      // NOTE(review): assumes countRestrictedTypes only yields keys of allowedCiAmounts - confirm in the base class.
      val allowed = allowedCiAmounts(t)
      val current = getCiCount(t)
      if (allowed - current >= requested) {
        transaction.registerCreate(t, requested)
      } else {
        throw new AmountOfCisExceededException(
          s"Unable to create ${t}. Your license is limited to ${allowed} ${t} CIs and you currently have ${current}.")
      }
    }
    validate()
  }

  /** Registers the removal of a single CI of type `t` by delegating to [[registerTypesRemoval]]. */
  override def registerTypeRemoval(t: Type, transaction: LicenseTransaction): Unit = {
    registerTypesRemoval(Seq(t), transaction)
  }

  /**
   * Registers the removal of one CI per entry in `tt` on the given transaction.
   *
   * Only the types passed in `tt` are registered, with the number of times each occurs,
   * mirroring how [[registerTypesCreation]] counts its input.
   *
   * @param tt          types of the CIs being removed (one entry per CI)
   * @param transaction license transaction receiving the registrations
   */
  override def registerTypesRemoval(tt: Seq[Type], transaction: LicenseTransaction): Unit = {
    countRestrictedTypes(tt).foreach { case (t, removed) =>
      transaction.registerDelete(t, removed)
    }
  }

  /** No-op in this implementation. */
  override def rollbackTransaction(tt: Type, transaction: LicenseTransaction): Unit = {}

  /** The restricted types are exactly the keys of `allowedCiAmounts`. */
  override def restrictedTypes: Set[Type] = allowedCiAmounts.keys.toSet

  /**
   * Finds the types whose in-use amount, as reported by the provider, is greater than the allowed amount.
   *
   * @return one [[LicensedCiUse]] (type, allowed amount, actual amount) per exceeded type
   */
  override protected def findViolations(): Seq[LicensedCiUse] = {
    val usedByType = sqlProvider.licensedCisInUse(allowedCiAmounts).map(use => (use.`type`, use.actualAmount)).toMap
    allowedCiAmounts.collect {
      // Types absent from the provider's report are not treated as violations.
      case (ciType: Type, maxAllowed: Int) if usedByType.get(ciType).exists(_ > maxAllowed) =>
        LicensedCiUse(ciType, maxAllowed, usedByType(ciType))
    }.toSeq
  }

  /** Delegates unchanged to the base-class implementation. */
  override protected def countRestrictedTypes(types: Seq[Type]): Map[Type, Int] = super.countRestrictedTypes(types)
}
