package com.xebialabs.xlrelease.runner.actors

import akka.actor.ActorRef
import akka.cluster.sharding.{ClusterSharding, ClusterShardingSettings, ShardRegion}
import com.typesafe.config.Config
import com.xebialabs.xlplatform.cluster.ClusterMode
import com.xebialabs.xlrelease.actors.ActorSystemHolder
import com.xebialabs.xlrelease.config.XlrConfig
import com.xebialabs.xlrelease.runner.actors.JobRunnerActor.{RunnerCommand, actorName}
import com.xebialabs.xlrelease.runner.domain._
import com.xebialabs.xlrelease.support.akka.spring.SpringExtension
import org.springframework.context.annotation.Profile
import org.springframework.stereotype.Component

import javax.annotation.PostConstruct

/**
 * [[ActorFactory]] specialisation for job runners.
 *
 * The factory input is fixed to `Unit`, so callers ask for the job runner actor without
 * passing any arguments. The concrete implementations in this file are
 * profile-specific (`ClusterMode.STANDALONE` vs `ClusterMode.FULL`), and each decides
 * which actor reference is handed out.
 */
trait JobRunnerActorFactory extends ActorFactory {
  // No per-request input is needed to locate the job runner actor.
  override type FactoryInput = Unit
}

/**
 * [[JobRunnerActorFactory]] used when the application runs in standalone (non-clustered) mode.
 *
 * Every request is answered with the same local `JobRunnerProcessingActor`. That actor is
 * created through `springExtension.actorOf` the first time it is needed and reused for every
 * later request.
 */
@Profile(Array(ClusterMode.STANDALONE))
@Component
class NonClusteredJobRunnerActorFactory(springExtension: SpringExtension) extends JobRunnerActorFactory {

  /** Fixed name under which the single processing actor is registered. */
  private val processingActorName = "job-runner-processing-actor"

  // Deferred so that the actor is only spawned once the factory is first asked for it.
  private lazy val processingActor: ActorRef =
    springExtension.actorOf(classOf[JobRunnerProcessingActor], name = processingActorName)

  /** Always yields the shared local processing actor. */
  override def create: FactoryMethod = {
    case () => processingActor
  }
}

/**
 * [[JobRunnerActorFactory]] used when the application runs in full cluster mode.
 *
 * Starts an Akka cluster sharding region for [[JobRunnerActor]] entities, with type name
 * `JobRunnerActor.SHARDING_TYPE_NAME`, and hands out that region's [[ActorRef]] for every
 * request. Each [[RunnerCommand]] is routed to the entity whose id is
 * `actorName(command.runnerId)`.
 *
 * Sharding settings are read from `job-runner.akka.cluster.sharding` in the XL configuration.
 * Any key not set there falls back to the actor system's `akka.cluster.sharding` section.
 */
@Profile(Array(ClusterMode.FULL))
@Component
class ClusteredJobRunnerActorFactory(xlrConfig: XlrConfig, systemHolder: ActorSystemHolder, springExtension: SpringExtension) extends JobRunnerActorFactory {
  private def actorSystem = systemHolder.actorSystem

  /** Forces the shard region to start during bean initialisation instead of on first use. */
  @PostConstruct
  def init(): Unit = {
    shardRegion
  }

  /**
   * Shard region for job runner entities.
   *
   * The region is started the first time this lazy val is evaluated; later accesses return
   * the same reference.
   */
  lazy val shardRegion: ActorRef = {
    val sharding = ClusterSharding(actorSystem)
    val defaultShardingConfig = actorSystem.settings.config.getConfig("akka.cluster.sharding")
    // Job-runner-specific overrides win; anything they do not set comes from the system defaults.
    val jobRunnerShardingConfig: Config =
      xlrConfig.xl.getConfig("job-runner.akka.cluster.sharding").withFallback(defaultShardingConfig)
    val shardingSettings = ClusterShardingSettings(jobRunnerShardingConfig)

    val actorProps = springExtension.props(classOf[JobRunnerActor])
    sharding.start(
      typeName = JobRunnerActor.SHARDING_TYPE_NAME,
      entityProps = actorProps,
      settings = shardingSettings,
      extractEntityId = extractEntityId,
      extractShardId = extractShardId
    )
  }

  /** Always yields the shard region, which routes each command to the right entity. */
  override def create: FactoryMethod = {
    case () => shardRegion
  }

  /**
   * Entity id for a command, derived from its runner id.
   *
   * This is the single source of truth for both entity and shard resolution.
   */
  private def entityIdOf(command: RunnerCommand): ShardRegion.EntityId =
    actorName(command.runnerId)

  /** Pairs each [[RunnerCommand]] with its entity id; the message itself is forwarded unchanged. */
  private def extractEntityId: ShardRegion.ExtractEntityId = {
    case msg: RunnerCommand => (entityIdOf(msg), msg)
  }

  /**
   * Maps a message to its shard id.
   *
   * `shardId()` is an extension on the entity id defined elsewhere (NOTE(review): presumably
   * in an implicitly imported helper — confirm).
   */
  private def extractShardId: ShardRegion.ExtractShardId = {
    case msg: RunnerCommand => entityIdOf(msg).shardId()
    // Akka may send StartEntity, for example when restarting remembered entities.
    // Derive its shard the same way as for commands so the entity lands on the same shard.
    case ShardRegion.StartEntity(entityId) => entityId.shardId()
  }

}
