package com.xebialabs.deployit.engine.tasker

import java.util.UUID

import ai.digital.deploy.task.steplog.TaskStepLogStoreHolder
import akka.actor._
import akka.event.LoggingReceive
import com.xebialabs.deployit.core.events.TaskStepLogEvent
import com.xebialabs.deployit.engine.api.execution.BlockExecutionState.{DONE, FAILED}
import com.xebialabs.deployit.engine.api.execution.SatelliteConnectionState._
import com.xebialabs.deployit.engine.api.execution.{BlockExecutionState, SatelliteConnectionState, StepState}
import com.xebialabs.deployit.engine.tasker.BlockExecutingActor._
import com.xebialabs.deployit.engine.tasker.StateChangeEventListenerActor.{SatelliteStepStateEvent, StepStateEvent}
import com.xebialabs.deployit.engine.tasker.distribution.MessageSequenceReceiver
import com.xebialabs.deployit.engine.tasker.messages._
import com.xebialabs.deployit.engine.tasker.satellite.{ActorLocator, Paths}

import scala.jdk.CollectionConverters._

/**
 * Local stand-in for an [[ExecutableBlock]] that executes on a remote satellite.
 *
 * Lifecycle, as implemented below:
 *  - starts `disconnected`; the first message triggers a connection attempt and is stashed;
 *  - `identifyRemoteActor`: sends `Identify` to the satellite's tasks actor selection and waits for the matching
 *    `ActorIdentity`;
 *  - `inquireBlockState`: asks the found actor for the block state (`ReportBlockState`) and waits for the
 *    `BlockStateReport`;
 *  - connected: replays stashed messages, forwards `Start`/`Stop`/`Abort`/`ModifySteps` to the satellite and mirrors
 *    the `BlockDone`/`BlockStateChanged` updates it sends back into the local `block`.
 * Death of the remote actor moves back to `disconnected`; a `TaskNotFound` from the satellite moves to `broken`.
 *
 * @param task              the task owning the block
 * @param block             local copy of the block; mutated in place with the state reported by the satellite
 * @param actorLocator      locates the satellite's tasks actor (`Paths.tasks`)
 * @param notificationActor receives `BlockStateChanged` and `BlockDone` notifications for this block
 */
class BlockOnSatellite private(val task: Task, block: ExecutableBlock, actorLocator: ActorLocator, notificationActor: ActorRef)
  extends Actor with ModifyStepsSupport with Stash with BecomeWithMdc with MessageSequenceReceiver {

  // Selection of the satellite's tasks actor; resolved to a concrete ActorRef via Identify on every
  // connection attempt (see tryConnect).
  private val remoteActor = actorLocator.locate(Paths.tasks)(context.system)

  val taskId: TaskId = task.getId

  val blockId: BlockId = block.getId()

  override def receive: Receive = disconnected(false)

  /** Satellite description for log messages; never throws, even if the block carries no satellite. */
  private def describeSatellite: String = block.satellite.map(_.toString).getOrElse("<unknown>")

  /**
   * Behavior while there is no live connection to the satellite.
   *
   * @param whileCancelling the value of `task.cancelling` at the moment of disconnection. When true no reconnect is
   *                        attempted: every incoming message marks the block DISCONNECTED/DONE and reports it as done.
   *                        Otherwise the incoming message triggers a reconnection attempt and is stashed for replay.
   */
  def disconnected(whileCancelling: Boolean): Receive = {
    case _ if whileCancelling =>
      warn(s"Not reconnecting to satellite $describeSatellite in cancelling mode. Setting block ${block.id} to DISCONNECTED/DONE")
      updateBlockSatelliteState(DISCONNECTED, DONE)
      notificationActor ! BlockDone(taskId, block)
    case _ =>
      tryConnect()
      // Keep the triggering message; it is replayed (unstashAll) once the block state has been received.
      stash()
  }

  /**
   * Starts a connection attempt: flags a DISCONNECTED block as RECONNECTING (keeping its execution state),
   * sends `Identify` to the remote selection and waits for the `ActorIdentity` carrying the same correlation id.
   */
  private def tryConnect(): Unit = {
    if (block.getSatelliteConnectionState == DISCONNECTED) {
      updateBlockSatelliteState(RECONNECTING, block.getState())
    }
    val uuid = UUID.randomUUID()
    remoteActor ! Identify(uuid)
    becomeWithMdc(identifyRemoteActor(uuid, sender()))
  }

  /**
   * Waits for the answer to the `Identify(uuid)` sent by [[tryConnect]]. Unrelated messages are stashed.
   *
   * @param uuid           correlation id of the pending `Identify`
   * @param originalSender sender of the message that triggered the connection attempt; notified with
   *                       `ActorNotFound` if the remote actor cannot be found
   */
  private def identifyRemoteActor(uuid: UUID, originalSender: ActorRef): Receive = {
    case ActorIdentity(`uuid`, Some(actorRef)) =>
      debug(s"Remote actor $remoteActor found")
      context.watch(actorRef)
      actorRef ! BlockEnvelope(taskId, block.id, ReportBlockState)
      // The remote actor is watched from here on, so it may die before answering with a BlockStateReport.
      // Terminated must be handled in this state too; otherwise inquireBlockState would stash it and wait forever.
      becomeWithMdc(handleTaskNotFound() orElse handleDeathOfRemote() orElse inquireBlockState(actorRef))
    case ActorIdentity(`uuid`, None) =>
      disconnect(s"Could not connect to $remoteActor. Probably, connection to a satellite is broken", Option(originalSender))
    case _ =>
      stash()
  }

  /**
   * Waits for the satellite's `BlockStateReport` for this block. On arrival the block is marked CONNECTED with the
   * reported state, stashed messages are replayed and the connected behavior is installed. Anything else that is
   * not matched here (or by `receiveChunks`) is stashed.
   */
  private def inquireBlockState(tasksActor: ActorRef): Receive = receiveChunks orElse {
    case BlockStateReport(`taskId`, `blockId`, blockState) =>
      updateBlockSatelliteState(CONNECTED, blockState)
      unstashAll()
      becomeWithMdc(sendModifyStepsAndUpdateBlock(tasksActor) orElse sendToSatellite(tasksActor) orElse
        handleBlockDoneOrChanged orElse handleDeathOfRemote() orElse handleTaskNotFound())
    case _ =>
      stash()
  }

  /**
   * Handles `Terminated` of the watched remote actor by disconnecting.
   *
   * @param originalSender if present, receives `ActorNotFound` (see [[disconnect]])
   */
  private def handleDeathOfRemote(originalSender: Option[ActorRef] = None): Receive = {
    case Terminated(actorRef) =>
      context.unwatch(actorRef)
      disconnect(s"Remote actor $remoteActor is terminated. Probably, connection to satellite $describeSatellite is broken", originalSender)
  }

  /**
   * Sets the block's satellite connection state and, only when the execution state actually changes, applies the
   * new execution state and notifies `notificationActor` with `BlockStateChanged`.
   */
  def updateBlockSatelliteState(satelliteState: SatelliteConnectionState, newBlockState: BlockExecutionState): Unit = {
    block.setSatelliteState(satelliteState)
    val oldState: BlockExecutionState = block.getState()
    if (oldState != newBlockState) {
      block.newState(newBlockState)
      notificationActor ! BlockStateChanged(taskId, block, oldState, block.getState())
    }
  }

  /** The satellite does not know this task: fail the block (UNKNOWN_TASK), report it as done and become [[broken]]. */
  private[this] def handleTaskNotFound(): Receive = {
    case m@TaskNotFound(`taskId`) =>
      error(s"Satellite $describeSatellite reported: ${m.msg}")
      updateBlockSatelliteState(UNKNOWN_TASK, FAILED)
      notificationActor ! BlockDone(taskId, block)
      becomeWithMdc(broken)
  }

  /** Terminal behavior: every message is logged as an error and answered with another `BlockDone` notification. */
  private[this] def broken: Receive = {
    case m =>
      error(s"Cannot handle $m, the Satellite $describeSatellite is broken")
      notificationActor ! BlockDone(taskId, block)
  }

  /**
   * Drops the connection: logs `msg`, switches to `disconnected` (no reconnect if the task is cancelling), replies
   * `ActorNotFound` to `originalSender` if given, marks an unfinished block FAILED/DISCONNECTED and reports it as done.
   */
  private def disconnect(msg: => String, originalSender: Option[ActorRef] = None): Unit = {
    warn(msg)
    becomeWithMdc(disconnected(task.cancelling))
    originalSender.foreach(_.tell(ActorNotFound(remoteActor), self))
    markBlockDisconnected(block)
    notificationActor ! BlockDone(taskId, block)
  }

  /** Translates task-level `Start`/`Stop`/`Abort` for this task into block-level messages for the satellite. */
  private def sendToSatellite(tasksActor: ActorRef): Receive = LoggingReceive {
    case msg@Start(`taskId`, runMode) =>
      debug(s"sending $msg to remote: $tasksActor")
      tasksActor ! StartBlock(taskId, block.id, runMode)
    case msg@Stop(`taskId`) =>
      debug(s"sending $msg to remote: $tasksActor")
      tasksActor ! StopBlock(taskId, block.id)
    case msg@Abort(`taskId`) =>
      debug(s"sending $msg to remote: $tasksActor")
      tasksActor ! AbortBlock(taskId, block.id)
  }

  /**
   * Mirrors updates sent back by the satellite into the local block and forwards them to `notificationActor`.
   * On `BlockDone` the step logs carried by the remote block are also published via the step-log JMS bean.
   */
  private def handleBlockDoneOrChanged: Receive = receiveChunks orElse {
    case BlockDone(`taskId`, updatedBlock: ExecutableBlock) =>
      updateState(block, updatedBlock)
      block.setSatelliteState(updatedBlock.getSatelliteConnectionState)
      sendLogs(updatedBlock)
      notificationActor ! BlockDone(taskId, block)
    case BlockStateChanged(`taskId`, updatedBlock: ExecutableBlock, oldState, newState) =>
      updateState(block, updatedBlock)
      block.setSatelliteState(CONNECTED)
      if (oldState != newState) {
        notificationActor ! BlockStateChanged(taskId, block, oldState, newState)
      }
  }

  /**
   * Publishes every log line of every step of `updatedBlock` as a `TaskStepLogEvent`.
   * Each log entry (attempt, 0-based index in `getLogs`) is split on "\n"; lines are numbered from 0 and lines
   * starting with "[ERROR]: " get level "ERROR", all others "INFO".
   */
  private def sendLogs(updatedBlock: ExecutableBlock): Unit = {
    // Looked up once per call, and only if there is at least one line to send.
    lazy val stepLogSender = TaskStepLogStoreHolder.getTaskStepLogJmsBean
    for {
      stepState <- updatedBlock.getStepList().asScala
      (logText, logAttempt) <- stepState.getLogs.asScala.zipWithIndex
      (logLine, lineNumber) <- logText.split("\n").zipWithIndex
    } {
      stepLogSender.sendStepLogEvent(
        TaskStepLogEvent(
          taskId,
          stepState.getMetadata.get("blockPath"),
          lineNumber,
          if (logLine.startsWith("[ERROR]: ")) "ERROR" else "INFO",
          logAttempt,
          logLine
        ))
    }
  }

  /**
   * Copies the state of `remoteBlock` into `localBlock`, recursing into composite blocks (children paired by
   * position). For step blocks the local steps are replaced by the remote ones and a `SatelliteStepStateEvent`
   * is published on the event stream for every step whose state changed. Mismatched block kinds are only logged.
   */
  private def updateState(localBlock: ExecutableBlock, remoteBlock: ExecutableBlock): Unit = {
    info(s"Changing state of local ${localBlock.id} and remote ${remoteBlock.id} from ${localBlock.getState()} -> ${remoteBlock.getState()}")
    (localBlock, remoteBlock) match {
      case (local: StepBlock, remote: StepBlock) =>
        local.newState(remote.getState())
        val events = createEvents(local.steps.toList, remote.steps.toList)
        local.steps = remote.steps
        events.foreach { event =>
          context.system.eventStream.publish(SatelliteStepStateEvent(event))
        }

      case (local: CompositeBlock, remote: CompositeBlock) =>
        // Set the composite's own state once (previously repeated per child and skipped for empty composites).
        local.newState(remote.getState())
        local.blocks.zip(remote.blocks).foreach {
          case (localSubBlock, remoteSubBlock) => updateState(localSubBlock, remoteSubBlock)
        }
      case _ =>
        error(s"local block ${localBlock.getId()}, ${localBlock.getClass} and remote block ${remoteBlock.getId()}, " +
          s"${remoteBlock.getClass} are not both StepBlocks or both CompositeBlocks")
    }
  }

  /** Pairs steps by position (zip: extra steps on either side are ignored) and keeps only the state changes. */
  private def createEvents(localSteps: List[StepState], satelliteSteps: List[StepState]): List[StepStateEvent] =
    localSteps.zip(satelliteSteps).flatMap { case (local, remote) => createEvent(local, remote) }

  /** A `StepStateEvent` (with a freshly generated random step id) if the step state differs, otherwise `None`. */
  private def createEvent(localStep: StepState, satelliteStep: StepState): Option[StepStateEvent] =
    Option.when(localStep.getState != satelliteStep.getState) {
      val stepId: String = UUID.randomUUID().toString
      StepStateEvent(taskId, stepId, task, satelliteStep, localStep.getState, satelliteStep.getState, None, None)
    }

  /**
   * Marks a composite or step block FAILED/DISCONNECTED unless it already finished. No `BlockStateChanged`
   * is emitted here.
   */
  private def markBlockDisconnected(localBlock: ExecutableBlock): Unit = localBlock match {
    case b @ (_: CompositeBlock | _: StepBlock) if !b.getState().isFinished =>
      b.newState(FAILED)
      b.setSatelliteState(DISCONNECTED)
    case _ =>
  }

  /**
   * Forwards a `ModifySteps` for this task to the satellite and temporarily pushes a behavior (keeping the current
   * one underneath) that waits for the satellite's answer while still reacting to the remote's death.
   */
  private def sendModifyStepsAndUpdateBlock(actorRef: ActorRef): Receive = {
    case msg: ModifySteps if msg.taskId == taskId =>
      debug(s"sending $msg to the remote actor")
      val requester = sender()
      becomeWithMdc(
        handleDeathOfRemote(Option(requester)) orElse waitForStepModified(taskId, block, requester, msg),
        discardOld = false)
      actorRef ! BlockEnvelope(taskId = msg.taskId, blockPath = block.id, message = msg)
  }

  /**
   * Waits for the satellite's answer to `originalMessage`. An `ErrorMessage` is relayed to `originalSender`; on a
   * `SuccessMessage` the same modification is applied locally via `modifySteps`. Either way the pushed behavior is
   * popped and stashed messages are replayed. Other messages are stashed meanwhile.
   */
  private def waitForStepModified(taskId: TaskId, block: ExecutableBlock, originalSender: ActorRef,
                                  originalMessage: ModifySteps): Receive = {
    case msg: ErrorMessage =>
      originalSender ! msg
      context.unbecome()
      unstashAll()
    case _: SuccessMessage =>
      modifySteps(taskId, path => path.relative(block.id).flatMap(block.getBlock), originalSender).apply(originalMessage)
      context.unbecome()
      unstashAll()
    case _ => stash()
  }
}

object BlockOnSatellite {

  /**
   * Builds the [[Props]] for a [[BlockOnSatellite]] actor, running on the state-management dispatcher.
   *
   * The reflective `Props(classOf[...], args...)` form is intentional: invoking the constructor directly
   * (`Props(new BlockOnSatellite(...))`) breaks SatelliteBlockRouterSpec.
   */
  def props(task: Task, block: ExecutableBlock, actorLocator: ActorLocator, notificationActor: ActorRef): Props = {
    val reflectiveProps = Props(classOf[BlockOnSatellite], task, block, actorLocator, notificationActor)
    reflectiveProps.withDispatcher(stateManagementDispatcher)
  }
}
