package com.xebialabs.deployit.deployment.stager

import com.xebialabs.deployit.deployment.planner.PipedPlanner.{PlanTransformer, TransformerContext}
import com.xebialabs.deployit.deployment.planner.PlanSugar._
import com.xebialabs.deployit.deployment.planner._
import com.xebialabs.deployit.deployment.stager.DeploymentStager._StagingContext
import com.xebialabs.deployit.engine.spi.execution.ExecutionStateListener
import com.xebialabs.deployit.plugin.api.flow._
import com.xebialabs.deployit.plugin.api.udm.artifact.{Artifact, DerivedArtifact, SourceArtifact}
import grizzled.slf4j.Logging

import java.util.{HashMap => JHashMap, List => JList, Map => JMap}
import scala.collection.mutable
import scala.collection.mutable.{Map => MMap}
import scala.jdk.CollectionConverters._
import scala.language.reflectiveCalls

/**
 * Plan transformer that surrounds the phases of a plan with artifact staging and clean-up.
 *
 * Every [[StageableStep]] found in the plan is asked to request staging. The staging steps
 * collected that way are placed in a "Stage artifacts" phase in front of the existing phases,
 * and an always-executed "Clean up staged artifacts" phase is appended behind them. Either
 * phase is left out when there is nothing to put in it.
 */
class DeploymentStager extends PlanTransformer with Logging with PlanSugar {
  // Structural type: any object exposing `getStagingTarget`. Calls through it are reflective,
  // which is why the file imports `scala.language.reflectiveCalls`.
  type GetStagingTarget = {def getStagingTarget: StagingTarget}

  /**
   * Builds a copy of the plan in `context` with the staging phase prepended and the clean-up
   * phase appended.
   *
   * @param context transformer context carrying the plan to transform
   * @return a copy of the plan whose phases are: staging phase (if any), the original phases,
   *         clean-up phase (if any)
   */
  def transform(context: TransformerContext): PhasedPlan = {
    val originalPlan = context.plan
    logger.info(s"Staging artifacts for plan: [${originalPlan.getDescription}]")

    val originalPhases = originalPlan.phases.asScala.toList
    val (stagedSteps, stagedTargets) = stagePhases(originalPhases)
    val listeners = originalPlan.getListeners

    val stagingPlans = stepsToPlan(stagedSteps.toList, "Stage artifacts on", listeners)
    val cleaningSteps = stagedTargets.toList.map(new StagedFileCleaningStep(_))
    val cleaningPlans = stepsToPlan(cleaningSteps, "Clean up staged files on", listeners)

    // Staging runs only as part of normal execution; clean-up is flagged as always executed.
    val stagePhase = generatePhase(stagingPlans, description = "Stage artifacts", alwaysExecuted = false)
    val cleanupPhase = generatePhase(cleaningPlans, description = "Clean up staged artifacts", alwaysExecuted = true)

    originalPlan.copy(phases = (stagePhase ::: originalPhases ::: cleanupPhase).asJava)
  }

  /**
   * Groups `steps` by staging target and creates one [[StepPlan]] per target, described as
   * "`descriptionPrefix` `target name`".
   *
   * The resulting plans are ordered by the first occurrence of their target in `steps`.
   */
  private def stepsToPlan(steps: List[Step with GetStagingTarget],
                          descriptionPrefix: String,
                          listeners: JList[ExecutionStateListener]) = {
    val planByTarget = steps.groupBy(_.getStagingTarget).map { case (target, targetSteps) =>
      // Widen the element type so the Java view is a java.util.List of Step.
      val plainSteps: List[Step] = targetSteps
      target -> new StepPlan(s"$descriptionPrefix ${target.getName}", plainSteps.asJava, listeners)
    }
    // groupBy does not keep encounter order, so walk the targets in their original order.
    val targetsInOrder = steps.map(_.getStagingTarget).distinct
    targetsInOrder.flatMap(planByTarget.get)
  }

  /**
   * Wraps `underlyingPlans` in at most one phase.
   *
   * No plans yield no phase; a single plan is wrapped directly; several plans are combined in
   * a [[ParallelPlan]] that takes the listeners of the first plan.
   */
  private def generatePhase(underlyingPlans: List[ExecutablePlan],
                            description: String,
                            alwaysExecuted: Boolean): List[PlanPhase] = underlyingPlans match {
    case List() =>
      List.empty
    case List(only) =>
      List(new PlanPhase(only, description, only.getListeners, alwaysExecuted))
    case first :: _ =>
      val combined = new ParallelPlan(description, underlyingPlans.asJava, first.getListeners)
      List(new PlanPhase(combined, description, combined.getListeners, alwaysExecuted))
  }

  /**
   * Visits the plan of every phase with a fresh staging context and returns the staging steps
   * and the targets needing clean-up that the context collected.
   */
  private def stagePhases(phases: List[PlanPhase]) = {
    val collector = new _StagingContext
    for (phase <- phases) doStage(phase.plan, collector)
    (collector.stagingSteps, collector.cleanupHosts)
  }

  /**
   * Walks a plan tree: composite plans are descended into, step plans are staged.
   * The match has no fallback case, so any other kind of plan ends in a `MatchError`.
   */
  private def doStage(plan: Plan, stagingContext: _StagingContext): Unit = {
    logger.debug(s"Staging for [${plan.getClass.getSimpleName}(${plan.getDescription})]")
    plan match {
      case composite: CompositePlan =>
        for (subPlan <- composite.getSubPlans.asScala) doStage(subPlan, stagingContext)
      case leaf: StepPlan =>
        doStage(leaf, stagingContext)
    }
  }

  /** Requests staging for every step of `stepPlan` that is a [[StageableStep]]. */
  private def doStage(stepPlan: StepPlan, stagingContext: _StagingContext): Unit = {
    stepPlan.getSteps.asScala.foreach {
      case stageable: StageableStep => doStage(stageable, stagingContext)
      case _ => // not stageable: nothing to do
    }
  }

  /** Lets a single stageable step register the artifacts it wants staged. */
  private def doStage(step: StageableStep, stagingContext: _StagingContext): Unit = {
    logger.debug(s"Preparing stage of artifacts for step [${step.getDescription}]")
    step.requestStaging(stagingContext)
  }
}

object DeploymentStager extends Logging {

  /**
   * Staging context that records which files have to be staged on which target and which
   * targets have to be cleaned up afterwards.
   */
  private[stager] class _StagingContext extends StagingContext {

    /**
     * Identifies one staged file: the checksum of the (source) artifact, the placeholders
     * applied to it, and the target it is staged on.
     */
    case class StagingKey(checksum: String, placeholders: JMap[String, String], target: StagingTarget)

    // Insertion-ordered, so staging steps are produced in the order artifacts were requested.
    val stagingFiles: MMap[StagingKey, StagingFile] = new mutable.LinkedHashMap[StagingKey, StagingFile]()
    // Insertion-ordered as well, so clean-up targets come out in the order they were first
    // staged on instead of in hash order.
    val cleanupHosts: mutable.Set[StagingTarget] = mutable.LinkedHashSet.empty[StagingTarget]

    /**
     * Returns the file a step should use for `artifact` on `target`.
     *
     * When the target's staging directory path is null or blank, a [[JustInTimeFile]] is
     * returned and nothing is recorded. Otherwise the artifact is recorded for staging under
     * its [[StagingKey]]; requests with an equal key share the same [[StagingFile]], and the
     * target is registered for clean-up the first time a key is recorded.
     *
     * The key match has no fallback case: an artifact that is neither a `DerivedArtifact` nor
     * a `SourceArtifact` results in a `MatchError`.
     *
     * @param artifact the artifact to stage
     * @param target   the target to stage the artifact on
     * @return the just-in-time file or the (possibly shared) staging file
     */
    @SuppressWarnings(Array("ComparingUnrelatedTypes"))
    def stageArtifact(artifact: Artifact, target: StagingTarget): StagedFile = {
      if (Option(target.getStagingDirectoryPath).forall(_.trim.isEmpty)) {
        new JustInTimeFile(artifact)
      } else {
        val key = artifact match {
          case bda: DerivedArtifact[_] if bda.getSourceArtifact != null => StagingKey(bda.getSourceArtifact.getChecksum, bda.getPlaceholders, target)
          // Derived artifact without a source: there is no checksum to key on.
          case bda: DerivedArtifact[_] if bda.getSourceArtifact == null => StagingKey(null, bda.getPlaceholders, target)
          case sa: SourceArtifact => StagingKey(sa.getChecksum, new JHashMap(), target)
        }

        // Single lookup: the block runs only the first time this key is seen.
        stagingFiles.getOrElseUpdate(key, {
          val stagingFile = new StagingFile(artifact, target)
          cleanupHosts += target
          stagingFile
        })
      }
    }

    /** One [[StagingStep]] per recorded staging file, in the order the files were recorded. */
    def stagingSteps: Iterable[StagingStep] = stagingFiles.values.map(new StagingStep(_))
  }

}
