package com.xebialabs.xlrelease.delivery.actors

import com.xebialabs.xlplatform.cluster.ClusterMode.{FULL, HOT_STANDBY, STANDALONE}
import com.xebialabs.xlrelease.actors.ActorSystemHolder
import com.xebialabs.xlrelease.actors.initializer.ActorInitializer
import com.xebialabs.xlrelease.config.XlrConfig
import com.xebialabs.xlrelease.delivery.actors.DeliveryActor.DeliveryAction
import com.xebialabs.xlrelease.delivery.service.{DeliveryExecutionService, DeliveryPatternService, DeliveryService}
import grizzled.slf4j.Logging
import org.apache.pekko.actor.ActorRef
import org.apache.pekko.cluster.sharding.{ClusterSharding, ClusterShardingSettings, ShardRegion}
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.context.annotation.Profile
import org.springframework.stereotype.Component

import scala.concurrent.{Await, Promise}

/**
 * Common parent of the delivery-related [[ActorInitializer]]s.
 *
 * It adds nothing of its own beyond mixing in grizzled [[Logging]], which gives
 * implementations a `logger`.
 */
trait DeliveryActorInitializer extends ActorInitializer with Logging {}

/**
 * Hand-off point for the delivery processing [[ActorRef]].
 *
 * Exactly one initializer supplies the reference by calling [[resolveActorRef]]. Consumers
 * call [[actorRef]], which waits until that has happened.
 */
@Component
class DeliveryActorHolder @Autowired()(xlrConfig: XlrConfig) {

  // Completed once, when the active initializer hands over the actor reference.
  private val refPromise: Promise[ActorRef] = Promise()

  /**
   * Returns the delivery processing actor reference, blocking the calling thread until it is available.
   *
   * Waits at most `xlrConfig.timeouts.systemInitialization`. `Await.result` throws a
   * `TimeoutException` if the reference has not been resolved by then.
   */
  def actorRef(): ActorRef = {
    val pending = refPromise.future
    Await.result(pending, xlrConfig.timeouts.systemInitialization)
  }

  /**
   * Publishes the actor reference to every current and future caller of [[actorRef]].
   *
   * `Promise.success` throws an `IllegalStateException` if a reference was already
   * published, so a second call fails fast instead of silently replacing the reference.
   */
  def resolveActorRef(actorRef: ActorRef): Unit = {
    refPromise.success(actorRef)
    ()
  }
}

/**
 * Creates a single, local (non-sharded) delivery processing actor for the STANDALONE and
 * HOT_STANDBY cluster modes, and publishes it through [[DeliveryActorHolder]].
 *
 * Each individual delivery actor is created as a child (via the supplied actor context)
 * with `clustered = false`, and is named after its delivery's actor name.
 */
@Component
@Profile(Array(STANDALONE, HOT_STANDBY))
class NonClusteredDeliveryProcessingActorInitializer @Autowired()(systemHolder: ActorSystemHolder,
                                                                  deliveryService: DeliveryService,
                                                                  deliveryPatternService: DeliveryPatternService,
                                                                  deliveryExecutionService: DeliveryExecutionService,
                                                                  deliveryActorHolder: DeliveryActorHolder) extends DeliveryActorInitializer {

  // Factory handed to the processing actor: builds one child DeliveryActor per delivery.
  private def localDeliveryActorMaker: DeliveryActorMaker = { (context, deliveryId) =>
    val childProps = DeliveryActor.props(clustered = false, deliveryService, deliveryPatternService, deliveryExecutionService)
    context.actorOf(childProps, deliveryId.deliveryActorName)
  }

  // Created on first access only, i.e. the first time initialize() runs.
  private lazy val localProcessingActor: ActorRef = {
    val actorSystem = systemHolder.actorSystem
    val processingProps = DeliveryProcessingActor.props(localDeliveryActorMaker)
    actorSystem.actorOf(processingProps, DeliveryProcessingActor.name)
  }

  /** Starts the local processing actor (if not yet started) and publishes its reference. */
  override def initialize(): Unit = {
    logger.debug("Initializing non-clustered delivery processing actor...")
    val processingRef = localProcessingActor
    deliveryActorHolder.resolveActorRef(processingRef)
  }
}

/**
 * Starts Pekko Cluster Sharding for delivery actors in the FULL cluster mode, and publishes the
 * resulting shard region through [[DeliveryActorHolder]].
 *
 * Routing:
 *  - The entity id of every [[DeliveryAction]] is `deliveryId.deliveryActorName`.
 *  - The shard id is derived from the entity id's hash code, modulo
 *    `xlrConfig.sharding.numberOfReleaseShards`.
 */
@Component
@Profile(Array(FULL))
class ClusteredDeliveryProcessingActorInitializer @Autowired()(xlrConfig: XlrConfig,
                                                               systemHolder: ActorSystemHolder,
                                                               deliveryService: DeliveryService,
                                                               deliveryPatternService: DeliveryPatternService,
                                                               deliveryExecutionService: DeliveryExecutionService,
                                                               deliveryActorHolder: DeliveryActorHolder) extends DeliveryActorInitializer {

  private def system = systemHolder.actorSystem

  private val numberOfShards = xlrConfig.sharding.numberOfReleaseShards
  private val shardingSettings = ClusterShardingSettings(system)

  /** Starts the "Delivery" sharding region and returns its [[ActorRef]]. */
  def startDeliverySharding(): ActorRef = {
    ClusterSharding(system).start(
      typeName = "Delivery",
      entityProps = DeliveryActor.props(clustered = true, deliveryService, deliveryPatternService, deliveryExecutionService),
      settings = shardingSettings,
      extractEntityId = extractDeliveryId,
      extractShardId = extractDeliveryShardId
    )
  }

  private def extractDeliveryId: ShardRegion.ExtractEntityId = {
    case msg: DeliveryAction => (entityIdOf(msg), msg)
  }

  private def extractDeliveryShardId: ShardRegion.ExtractShardId = {
    case msg: DeliveryAction => shardIdForEntity(entityIdOf(msg))
    // StartEntity carries an entity id exactly as produced by extractDeliveryId (already an actor
    // name). Hash it as-is. Re-applying deliveryActorName could send the entity to a different
    // shard than its regular messages, unless that conversion happens to be idempotent.
    case ShardRegion.StartEntity(entityId) => shardIdForEntity(entityId)
  }

  // Single source of truth for the entity id, shared by entity and shard extraction.
  private def entityIdOf(msg: DeliveryAction): String = msg.deliveryId.deliveryActorName

  // Take the remainder before abs. math.abs(Int.MinValue) is still negative, which would produce
  // a negative shard id. For every other hash, the result equals math.abs(hash) % numberOfShards,
  // so existing shard allocations are unchanged.
  private def shardIdForEntity(entityId: String): String =
    math.abs(entityId.hashCode % numberOfShards).toString

  /** Starts the sharding region and publishes it as the delivery processing actor. */
  override def initialize(): Unit = {
    logger.debug("Initializing clustered delivery processing actor...")
    val shardRegion = startDeliverySharding()
    deliveryActorHolder.resolveActorRef(shardRegion)
  }

}
