package com.xebialabs.license

import java.util.concurrent.atomic.AtomicInteger

import com.xebialabs.deployit.plugin.api.reflect.Type
import com.xebialabs.deployit.plugin.api.udm.ConfigurationItem
import com.xebialabs.deployit.repository.{SearchParameters, RepositoryService}
import com.xebialabs.license.LicenseCiCounter.LicensedCiUse
import com.xebialabs.license.service.LicenseTransaction

import collection.convert.wrapAll._
import scala.collection.concurrent.TrieMap

object LicenseCiCounter {

  /**
   * Usage snapshot for a single license-restricted CI type.
   *
   * @param `type`        the restricted CI type
   * @param allowedAmount maximum number of CIs of this type permitted by the license
   * @param actualAmount  number of CIs of this type currently counted
   */
  case class LicensedCiUse(`type`: Type, allowedAmount: Int, actualAmount: Int) {
    override def toString: String = s"Your license is limited to $allowedAmount ${`type`} CIs and you currently have $actualAmount."
  }

  /**
   * Builds a counter from the license's `MAX_NUMBER_OF_CIS` entries and seeds it with the number of
   * matching CIs currently in the repository.
   *
   * Each license entry key is parsed with `Type.valueOf` and each value with `toInt`; a malformed
   * entry surfaces as whatever exception those calls throw.
   */
  def readFrom(license: License, repositoryService: RepositoryService): LicenseCiCounter = {
    val limits: Map[Type, Int] = license.getMapValue(LicenseProperty.MAX_NUMBER_OF_CIS).map {
      case (typeName, limit) => Type.valueOf(typeName) -> limit.toInt
    }.toMap

    val result = new LicenseCiCounter(limits)
    countRestrictedCisInRepo(repositoryService, result)
    result
  }

  /**
   * For every restricted type of `counter`, lists the repository CIs of that type and stores the
   * resulting size as the type's counter, replacing any previous value.
   */
  def countRestrictedCisInRepo(repositoryService: RepositoryService, counter: LicenseCiCounter): Unit =
    for (restrictedType <- counter.allowedCiAmounts.keySet) {
      val search = new SearchParameters()
      search.setType(restrictedType)
      val existing = repositoryService.list(search).size()
      counter.typeCounter.put(restrictedType, new AtomicInteger(existing))
    }
}

/**
 * Tracks how many configuration items of each license-restricted type exist and enforces the
 * per-type maximums taken from the license.
 *
 * A type passed to the register methods counts toward every restricted type it is an
 * `instanceOf`, so one CI may be counted under several restricted supertypes.
 *
 * NOTE(review): the map/atomic types suggest concurrent use, but the limit check and the counter
 * update in `registerTypesCreation` are separate steps, not one atomic operation. Confirm that
 * callers serialize access if strict enforcement is required.
 *
 * @param allowedCiAmounts maximum number of CIs allowed per restricted type; types absent from this
 *                         map are ignored by the register methods and by `validate`
 */
class LicenseCiCounter(val allowedCiAmounts: Map[Type, Int]) {

  // This will be used from java
  def this() = this(Map.empty)

  /**
   * Current number of CIs per restricted type.
   *
   * The default is a function rather than `withDefaultValue(instance)`, so each `apply` on an absent
   * key yields its own fresh zero counter. A single shared default instance would let an increment
   * made through one missing key leak into every other missing key.
   */
  val typeCounter = new TrieMap[Type, AtomicInteger]().withDefault(_ => new AtomicInteger(0))

  /**
   * Returns the stored counter for `t`. If `t` has not been counted yet, returns a detached zero
   * counter that is not stored, so changes made to it are lost.
   */
  def getAtomicCiCount(t: Type): AtomicInteger = typeCounter.getOrElse(t, new AtomicInteger(0))

  /** Current count for `t`, or 0 when `t` has not been counted yet. */
  def getCiCount(t: Type): Int = getAtomicCiCount(t).get()

  /** Allowed vs. actual amounts for every restricted type; returned as an array for Java interop. */
  def licensedCisInUse: Array[LicensedCiUse] = allowedCiAmounts.map {
    case (ciType, licensed) => LicensedCiUse(ciType, licensed, getCiCount(ciType))
  }.toArray // for Java interop

  /**
   * Validates current counters against the license limitations.
   *
   * @throws AmountOfCisExceededException if the actual amount of CIs of any restricted type exceeds
   *                                      its allowed value
   */
  def validate(): Unit = {
    val violations = findViolations()
    if (violations.nonEmpty)
      throw new AmountOfCisExceededException(s"The system reached the maximum allowed number of Configuration Items: ${violations.mkString(", ")}. Please check your license.")
  }

  /**
   * Registers the creation of the given configuration items; see `registerTypesCreation`.
   */
  def registerCisCreation(cis: Seq[ConfigurationItem], transaction: LicenseTransaction): Unit =
    registerTypesCreation(cis.map(_.getType), transaction)

  /**
   * Increases the counters for the given types and records the increments on `transaction`.
   *
   * Every restricted type is checked against its remaining allowance before any counter is touched.
   * If one of them would exceed the limit, nothing is updated or registered and an exception is
   * thrown. After the update, `validate()` is run. If it fails (e.g. some other type was already
   * over its limit), the increments stay applied and registered. Presumably the transaction is
   * responsible for undoing them; confirm against `LicenseTransaction`.
   *
   * @throws AmountOfCisExceededException if a restricted type would exceed its allowed amount, or if
   *                                      `validate()` finds a violation afterwards
   */
  def registerTypesCreation(tt: Seq[Type], transaction: LicenseTransaction): Unit = {
    val requested = countRestrictedTypes(tt)

    // Check all types first so that a rejected request leaves no partial counter/transaction update.
    requested.find { case (t, cnt) => allowedCiAmounts(t) - getCiCount(t) < cnt }.foreach {
      case (t, _) =>
        throw new AmountOfCisExceededException(s"Unable to create ${t}. Your license is limited to ${allowedCiAmounts(t)} ${t} CIs and you currently have ${getCiCount(t)}.")
    }

    requested.foreach {
      case (t, cnt) =>
        updateGeneralCounter(t, cnt)
        transaction.registerCreate(t, cnt)
    }

    validate()
  }

  /**
   * Decrements the counters according to the passed types and records the decrements on
   * `transaction`. No lower-bound check is performed.
   */
  def registerTypesRemoval(tt: Seq[Type], transaction: LicenseTransaction): Unit =
    countRestrictedTypes(tt).foreach {
      case (t, cnt) =>
        updateGeneralCounter(t, -cnt)
        transaction.registerDelete(t, cnt)
    }

  /** Adds `cnt` (may be negative) to the counter for `t`, creating a zero counter first if absent. */
  def updateGeneralCounter(t: Type, cnt: Int): Unit = {
    typeCounter.getOrElseUpdate(t, new AtomicInteger(0)).addAndGet(cnt)
  }

  /**
   * Decrements the counter for the passed type.
   */
  def registerTypeRemoval(t: Type, transaction: LicenseTransaction): Unit = registerTypesRemoval(Seq(t), transaction)

  /** Restricted types whose stored counter is strictly greater than the allowed amount. */
  private def findViolations(): Seq[LicensedCiUse] =
    allowedCiAmounts.collect {
      case (t, allowedAmount) if typeCounter.get(t).exists(_.get() > allowedAmount) => LicensedCiUse(t, allowedAmount, getCiCount(t))
    }.toSeq

  /**
   * For each restricted type, counts how many of `types` are `instanceOf` it. Restricted types with
   * no matches are omitted from the result.
   */
  private def countRestrictedTypes(types: Seq[Type]): Map[Type, Int] =
    allowedCiAmounts.keys.iterator
      .map(restricted => restricted -> types.count(_.instanceOf(restricted)))
      .filter { case (_, cnt) => cnt != 0 }
      .toMap
}
