package com.xebialabs.xlrelease.triggers.actors

import akka.actor.ActorRef
import akka.cluster.sharding.{ClusterSharding, ClusterShardingSettings, ShardRegion}
import com.xebialabs.xlplatform.cluster.ClusterMode.{FULL, HOT_STANDBY, STANDALONE}
import com.xebialabs.xlrelease.actors.ActorSystemHolder
import com.xebialabs.xlrelease.actors.initializer.ActorInitializer
import com.xebialabs.xlrelease.config.XlrConfig
import com.xebialabs.xlrelease.repository.TriggerRepository
import com.xebialabs.xlrelease.triggers.actors.TriggerActor.TriggerAction
import grizzled.slf4j.Logging
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.context.annotation.Profile
import org.springframework.stereotype.Component

import scala.concurrent.{Await, Promise}

/**
 * Common base type for the initializers in this file that set up trigger
 * processing actors. It combines [[ActorInitializer]] with [[Logging]] so
 * implementations get a `logger` without mixing it in themselves.
 */
trait TriggerActorInitializer
  extends ActorInitializer
    with Logging

/**
 * Spring-managed holder that hands out the trigger processing [[ActorRef]]
 * once one of the initializers has published it.
 *
 * The reference is kept in a [[scala.concurrent.Promise]], so readers that
 * arrive before initialization block until it is available.
 *
 * @param xlrConfig configuration supplying the `timeouts.systemInitialization`
 *                  wait limit used by [[actorRef]]
 */
@Component
class TriggerActorHolder @Autowired()(xlrConfig: XlrConfig) {

  // Completed exactly once by resolveActorRef.
  private val resolvedActor = Promise[ActorRef]()

  /**
   * Blocks until the actor reference has been published.
   *
   * @return the published actor reference
   * @note `Await.result` throws (e.g. `TimeoutException`) if the reference is
   *       not published within `xlrConfig.timeouts.systemInitialization`.
   */
  def actorRef(): ActorRef = {
    val pending = resolvedActor.future
    Await.result(pending, xlrConfig.timeouts.systemInitialization)
  }

  /**
   * Publishes the actor reference, releasing any caller waiting in [[actorRef]].
   *
   * @param actorRef the reference to publish
   * @note `Promise.success` throws `IllegalStateException` if the promise has
   *       already been completed, so this must be called at most once.
   */
  def resolveActorRef(actorRef: ActorRef): Unit = {
    resolvedActor.success(actorRef)
  }
}

/**
 * Initializer used under the `STANDALONE` and `HOT_STANDBY` profiles.
 *
 * Creates a single `TriggerProcessingActor` on the local actor system and
 * publishes it through the [[TriggerActorHolder]]. Per-trigger actors are
 * created as plain children through the supplied `TriggerActorMaker`.
 */
@Component
@Profile(Array(STANDALONE, HOT_STANDBY))
class NonClusteredTriggerProcessingActorInitializer @Autowired()(xlrConfig: XlrConfig,
                                                                 systemHolder: ActorSystemHolder,
                                                                 triggerRepository: TriggerRepository,
                                                                 triggerOperations: TriggerOperations,
                                                                 triggerActorHolder: TriggerActorHolder)
  extends TriggerActorInitializer {

  /** Factory that spawns a non-clustered `TriggerActor` named after the trigger id. */
  private def triggerActorMaker: TriggerActorMaker = (context, triggerId) => {
    val triggerProps = TriggerActor.props(clustered = false, triggerOperations, xlrConfig)
    context.actorOf(triggerProps, triggerId.triggerActorName)
  }

  // Lazy so the actor is only created on first use, i.e. when initialize() runs.
  private lazy val triggerProcessingActor: ActorRef = {
    val actorSystem = systemHolder.actorSystem
    actorSystem.actorOf(
      TriggerProcessingActor.props(triggerActorMaker),
      TriggerProcessingActor.name
    )
  }

  /** Creates the processing actor (on first call) and publishes it via the holder. */
  override def initialize(): Unit = {
    logger.debug("Initializing non-clustered trigger processing actor...")
    val processingActor = triggerProcessingActor
    triggerActorHolder.resolveActorRef(processingActor)
  }
}

/**
 * Initializer used under the `FULL` (clustered) profile.
 *
 * Starts an Akka Cluster Sharding region of type `"Trigger"` whose entities
 * are clustered `TriggerActor`s, and publishes the shard region through the
 * [[TriggerActorHolder]].
 */
@Component
@Profile(Array(FULL))
class ClusteredTriggerProcessingActorInitializer @Autowired()(xlrConfig: XlrConfig,
                                                              systemHolder: ActorSystemHolder,
                                                              triggerOperations: TriggerOperations,
                                                              triggerActorHolder: TriggerActorHolder)
  extends TriggerActorInitializer {

  private def system = systemHolder.actorSystem

  // Shard count is read from the release sharding setting in XlrConfig.
  private val numberOfShards = xlrConfig.sharding.numberOfReleaseShards
  private val shardingSettings = ClusterShardingSettings(system)

  /**
   * Starts the `"Trigger"` shard region on this node.
   *
   * @return the shard region actor to which `TriggerAction` messages are sent
   */
  def startTriggerSharding(): ActorRef = {
    val sharding = ClusterSharding(system)
    sharding.start(
      typeName = "Trigger",
      entityProps = TriggerActor.props(clustered = true, triggerOperations, xlrConfig),
      settings = shardingSettings,
      extractEntityId = extractTriggerId,
      extractShardId = extractTriggerShardId
    )
  }

  /** Maps a `TriggerAction` to its entity id (the trigger's actor name) and payload. */
  private def extractTriggerId: ShardRegion.ExtractEntityId = {
    case msg: TriggerAction => (msg.triggerId.triggerActorName, msg)
  }

  /** Maps a message to the shard id of the trigger it addresses. */
  private def extractTriggerShardId: ShardRegion.ExtractShardId = {
    case msg: TriggerAction => extractShardId(msg.triggerId)
    // NOTE(review): `id` here is the entity id, which extractTriggerId already
    // derived via `triggerActorName`; extractShardId applies `triggerActorName`
    // again. This yields the same shard only if that conversion is idempotent —
    // TODO confirm.
    case ShardRegion.StartEntity(id) => extractShardId(id)
  }

  /**
   * Computes the shard id for a trigger id as a value in `[0, numberOfShards)`.
   *
   * The remainder is taken before `math.abs`: `math.abs(Int.MinValue)` is still
   * negative, so taking the absolute value first could produce a negative shard
   * id for a name whose hash is `Int.MinValue`. For every other hash value the
   * result is identical to `math.abs(hash) % numberOfShards`, so existing shard
   * assignments are unchanged.
   */
  private def extractShardId(id: String): String =
    math.abs(id.triggerActorName.hashCode % numberOfShards).toString

  /** Starts trigger sharding and publishes the shard region via the holder. */
  override def initialize(): Unit = {
    logger.debug("Initializing clustered trigger processing actor...")
    val shardRegion = startTriggerSharding()
    triggerActorHolder.resolveActorRef(shardRegion)
  }
}
