package com.xebialabs.deployit.engine.tasker

import akka.actor.{Actor, ActorRef, Stash}
import com.xebialabs.deployit.engine.tasker.SplitModifySteps.LookUpActor
import com.xebialabs.deployit.engine.tasker.messages._
import grizzled.slf4j.Logging

object SplitModifySteps {

  /**
   * Resolves a [[BlockPath]] to the actor responsible for that block.
   * A `None` result means the path could not be resolved.
   */
  type LookUpActor = Function1[BlockPath, Option[ActorRef]]
}

/**
 * Mix-in for actors that fan a [[ModifySteps]] request out to the block actors it concerns and
 * answer the original sender with a single aggregated reply.
 *
 * While replies are outstanding, the actor runs a waiting behaviour. That behaviour collects
 * `ErrorMessage`, `SuccessMessage` and `Exception` replies and stashes every other message. Once
 * all replies have arrived, the previous behaviour is restored and the stash is replayed.
 */
trait SplitModifySteps extends Logging with Stash {

  /**
   * Builds a behaviour that handles [[ModifySteps]] messages addressed to `taskId`.
   *
   * The request is split into one message per target block (see `splitMessage`).
   *  - If any target path cannot be resolved through `actorLookUp`, the sender receives
   *    [[PathsNotFound]] with the unresolved paths, and nothing is sent to any block.
   *  - Otherwise each block actor receives its part, and the actor waits for all replies before
   *    answering.
   *
   * @param taskId      only `ModifySteps` messages whose `taskId` equals this value are handled
   * @param actorLookUp resolves a block path to the actor responsible for it
   */
  def sendModifyStepsToBlocks(taskId: TaskId, actorLookUp: LookUpActor): Actor.Receive = {
    case msg: ModifySteps if msg.taskId == taskId =>
      // Capture the requester now; later replies arrive from the block actors instead.
      val originalSender = context.sender()
      val messages = splitMessage(msg)
      val allActors = messages.keys.map(path => path -> actorLookUp(path)).toMap
      val notFoundPaths = allActors.collect { case (path, None) => path }.toSeq
      if (notFoundPaths.nonEmpty) {
        debug(s"Path not found $notFoundPaths")
        originalSender ! PathsNotFound(notFoundPaths)
      } else {
        val targets = allActors.collect { case (path, Some(actor)) => path -> actor }
        targets.foreach { case (path, actor) => actor ! messages(path) }
        waitOrReply(Vector.empty, targets.size, originalSender)
      }
  }

  /**
   * Records progress towards `expectedSize` replies and answers once all have arrived.
   *
   * The waiting behaviour is pushed exactly once, on the initial call with no answers. Later calls,
   * made from inside that behaviour, replace it in place. On completion it is therefore popped
   * exactly once, leaving the actor in the behaviour it had before the request arrived.
   */
  private def waitOrReply(answersSoFar: Seq[AnyRef], expectedSize: Int, originalSender: ActorRef): Unit = {
    if (answersSoFar.size == expectedSize) {
      // A waiting behaviour was pushed only if at least one reply was awaited.
      if (answersSoFar.nonEmpty) context.unbecome()
      unstashAll()
      val result = aggregateResults(answersSoFar)
      debug(s"Returning after aggregation $result")
      originalSender ! result
    } else {
      // First call: push on top of the current behaviour. Subsequent calls: replace our own.
      context.become(
        waitForResults(answersSoFar, expectedSize, originalSender),
        discardOld = answersSoFar.nonEmpty)
    }
  }

  /**
   * Reduces the collected replies to one answer.
   *
   * Precedence is: the first exception, then the first error message, then the first success
   * message. Callers pass at least one reply; an empty sequence would not match any case.
   */
  private def aggregateResults(answers: Seq[AnyRef]): AnyRef = answers match {
    case ExceptionExtractor(ex) => ex
    case ErrorMessageExtractor(error) => error
    case SuccessMessageExtractor(msg) => msg
  }

  /** Waiting behaviour: collects block replies and stashes everything else until done. */
  private def waitForResults(answersSoFar: Seq[AnyRef], expectedSize: Int, originalSender: ActorRef): Receive = {
    case msg: ErrorMessage => waitOrReply(answersSoFar :+ msg, expectedSize, originalSender)
    case msg: SuccessMessage => waitOrReply(answersSoFar :+ msg, expectedSize, originalSender)
    case ex: Exception => waitOrReply(answersSoFar :+ ex, expectedSize, originalSender)
    case _ => stash()
  }

  /**
   * Splits a request into one message per block path that must handle it.
   *
   * An `AddPauseStep` goes to the parent of its step path (`blockPath.init`). Skip and unskip
   * requests are grouped per common ancestor (see `findCommonAncestors`).
   */
  private def splitMessage(msg: ModifySteps): Map[BlockPath, ModifySteps] = msg match {
    case pause: AddPauseStep => Map(pause.blockPath.init -> pause)
    case SkipSteps(taskId, paths) => findCommonAncestors(paths).map {
      case (path, children) => path -> SkipSteps(taskId, children)
    }
    case UnSkipSteps(taskId, paths) => findCommonAncestors(paths).map {
      case (path, children) => path -> UnSkipSteps(taskId, children)
    }
  }

  /**
   * Groups `blockPaths` by their prefix (`BlockPath.take(n)`).
   *
   * The chosen length `n` is the smallest in `1 to (minimum depth - 1)` at which the paths fall
   * into more than one group. If no such `n` exists, the paths are grouped by the prefix of length
   * `minimum depth - 1`.
   *
   * NOTE(review): `blockPaths` must be non-empty; `.min` throws on an empty sequence.
   */
  private def findCommonAncestors(blockPaths: Seq[BlockPath]): Map[BlockPath, Seq[BlockPath]] = {
    def groupByAncestor(n: Int) = blockPaths.groupBy(_.take(n))

    val minPathDepth = blockPaths.map(_.depth).min - 1

    // Lazily try increasing prefix lengths and stop at the first one where the paths diverge.
    (1 to minPathDepth).iterator
      .map(groupByAncestor)
      .find(_.size > 1)
      .getOrElse(groupByAncestor(minPathDepth))
  }
}

/** Extractor matching a sequence of replies that contains an [[Exception]]; yields the first one. */
object ExceptionExtractor {
  def unapply(messages: Seq[AnyRef]): Option[Exception] =
    messages.collectFirst { case ex: Exception => ex }
}

/** Extractor matching a sequence of replies that contains an `ErrorMessage`; yields the first one. */
object ErrorMessageExtractor {
  def unapply(messages: Seq[AnyRef]): Option[ErrorMessage] =
    messages.collectFirst { case error: ErrorMessage => error }
}

/** Extractor matching a sequence of replies that contains a `SuccessMessage`; yields the first one. */
object SuccessMessageExtractor {
  def unapply(messages: Seq[AnyRef]): Option[SuccessMessage] =
    messages.collectFirst { case msg: SuccessMessage => msg }
}
