package com.xebialabs.deployit.engine.tasker

import akka.actor.ActorRef
import com.typesafe.config.Config
import com.xebialabs.deployit.engine.tasker.TaskReceiver.messages._
import grizzled.slf4j.Logging
import javax.jms.{Message, MessageListener}
import org.springframework.jms.config.{JmsListenerContainerFactory, SimpleJmsListenerEndpoint}
import org.springframework.jms.listener.DefaultMessageListenerContainer
import org.springframework.util.ErrorHandler

object EnqueueTaskListener extends Logging {

  /**
   * Containers created by the most recent `registerTaskListenerContainers` call: the task-queue
   * container followed by the rollback container. `Nil` until the first registration.
   *
   * Marked `@volatile` because the list is assigned on the registering thread but read from JMS
   * consumer threads (the `EnqueueTaskListener` message listener calls `stopListeners()` from
   * `onMessage`).
   */
  @volatile var listeners: List[DefaultMessageListenerContainer] = Nil

  /**
   * Creates the task-queue and rollback listener containers for this worker, calls `initialize()`
   * on each (but not `start()` — use [[startListeners]] for that) and stores them in [[listeners]].
   *
   * Containers from a previous registration are shut down once the new ones are in place; otherwise
   * they would become unreachable and could never be stopped or shut down.
   *
   * @param workerAddress    address of this worker; used as the task endpoint id, as the rollback
   *                         endpoint's destination and in the rollback message selector
   * @param taskReceiver     actor that receives `ExecuteTask` for each consumed message
   * @param containerFactory factory used to build each container
   * @param config           configuration; `deploy.task.queue.name` names the task queue
   */
  def registerTaskListenerContainers(workerAddress: String, taskReceiver: ActorRef,
                                     containerFactory: JmsListenerContainerFactory[DefaultMessageListenerContainer],
                                     config: Config): Unit = {
    val taskQueueName = config.getString("deploy.task.queue.name")
    val created = List(
      initListener(containerFactory, createEndpoint(taskQueueName, workerAddress, taskReceiver)),
      initListener(containerFactory, createRollbackEndpoint(workerAddress, taskReceiver), isRollback = true)
    )
    // Build the new containers first so a failure above leaves the previous registration intact,
    // then release the containers being replaced.
    val replaced = listeners
    listeners = created
    replaced.foreach(shutdownListener)
  }

  /** Starts every registered container that is not already running. */
  def startListeners(): Unit = listeners.foreach(startListener)

  /** Stops every registered container that is currently running. */
  def stopListeners(): Unit = listeners.foreach(stopListener)

  /** Shuts down every registered container that is still active. */
  def shutdownListeners(): Unit = listeners.foreach(shutdownListener)

  /**
   * Builds a container for `endpoint`, installs an [[EnqueueTaskErrorHandler]], selects pub/sub
   * (topic) mode for the rollback endpoint and queue mode otherwise, and initializes it.
   *
   * @param containerFactory factory used to create the container
   * @param endpoint         endpoint describing id, destination, selector and listener
   * @param isRollback       `true` for the rollback endpoint
   * @return the initialized (not started) container
   */
  private def initListener(containerFactory: JmsListenerContainerFactory[DefaultMessageListenerContainer],
                           endpoint: SimpleJmsListenerEndpoint, isRollback: Boolean = false): DefaultMessageListenerContainer = {
    logger.debug(s"Creating for endpoint [${endpoint.getId}]")
    val container: DefaultMessageListenerContainer = containerFactory.createListenerContainer(endpoint)
    container.setErrorHandler(new EnqueueTaskErrorHandler(endpoint.getDestination, isRollback))
    container.setPubSubDomain(isRollback)
    container.initialize()
    container
  }

  /** Task endpoint: id `workerAddress`, destination `taskQueueName`, no selector. */
  private def createEndpoint(taskQueueName: String, workerAddress: String, taskReceiver: ActorRef) =
    newEndpoint(taskReceiver, workerAddress, taskQueueName)

  /**
   * Rollback endpoint: id `<workerAddress>-rollback`, destination `workerAddress`, and a selector
   * that only accepts messages whose `workerAddress` property equals this worker's address.
   */
  private def createRollbackEndpoint(workerAddress: String, taskReceiver: ActorRef) = {
    // JMS selector string literals escape an embedded single quote by doubling it.
    val quotedAddress = workerAddress.replace("'", "''")
    newEndpoint(taskReceiver, s"$workerAddress-rollback", workerAddress, Some(s"workerAddress = '$quotedAddress'"))
  }

  /**
   * Creates an endpoint whose message listener forwards messages to `taskReceiver`.
   *
   * @param taskReceiver actor handed to the new [[EnqueueTaskListener]]
   * @param id           endpoint id
   * @param destination  destination name to consume from
   * @param selector     optional JMS message selector; left unset when `None`
   */
  private def newEndpoint(taskReceiver: ActorRef, id: String, destination: String, selector: Option[String] = None) = {
    val endpoint: SimpleJmsListenerEndpoint = new SimpleJmsListenerEndpoint()
    endpoint.setId(id)
    endpoint.setDestination(destination)
    endpoint.setMessageListener(new EnqueueTaskListener(taskReceiver))
    selector.foreach(endpoint.setSelector)
    endpoint
  }

  /** Calls `start()` on `listener` unless it already reports `isRunning`. */
  private def startListener(listener: DefaultMessageListenerContainer): Unit = {
    if (!listener.isRunning) {
      logger.debug(s"Starting the container listener for destination [${listener.getDestinationName}]")
      listener.start()
    }
  }

  /** Calls `stop()` on `listener` only if it reports `isRunning`. */
  private def stopListener(listener: DefaultMessageListenerContainer): Unit = {
    if (listener.isRunning) {
      logger.debug(s"Stopping the container listener for destination [${listener.getDestinationName}]")
      listener.stop()
    }
  }

  /** Calls `shutdown()` on `listener` only if it reports `isActive`, so repeated calls are no-ops. */
  private def shutdownListener(listener: DefaultMessageListenerContainer): Unit = {
    if (listener.isActive) {
      logger.debug(s"Shutting down the container listener for destination [${listener.getDestinationName}]")
      listener.shutdown()
    }
  }
}

/**
 * JMS message listener that turns each consumed message into an `ExecuteTask` for `taskReceiver`.
 *
 * On every message it stops all containers registered in the companion object before dispatching
 * the task (presumably so the worker takes one task at a time — confirm with the task receiver).
 *
 * @param taskReceiver actor that receives `ExecuteTask(taskId)`
 */
class EnqueueTaskListener(taskReceiver: ActorRef) extends MessageListener with Logging {
  override def onMessage(message: Message): Unit = {
    logger.debug(s"Received message [$message]")
    // Read the task id before stopping the containers. If reading the property throws, the
    // exception reaches the container while the listeners are still running. Otherwise the worker
    // would be left with every listener stopped and no task dispatched.
    val taskId = message.getStringProperty("taskId")
    EnqueueTaskListener.stopListeners()
    // NOTE(review): getStringProperty returns null when the message has no "taskId" property;
    // the value is forwarded as-is — confirm the receiver copes with that.
    taskReceiver ! ExecuteTask(taskId)
  }
}

/**
 * Spring `ErrorHandler` for the task listener containers. It logs errors raised while consuming
 * messages.
 *
 * @param workerAddress destination name of the endpoint this handler reports on (the companion's
 *                      `initListener` passes `endpoint.getDestination`, which for the task endpoint
 *                      is the task queue name)
 * @param isRollback    `true` when the handler belongs to the rollback endpoint
 */
class EnqueueTaskErrorHandler(workerAddress: String, isRollback: Boolean) extends ErrorHandler with Logging {

  /** Name used in log output: the destination, suffixed with `-rollback` for the rollback endpoint. */
  def queueName: String = if (isRollback) s"$workerAddress-rollback" else workerAddress

  override def handleError(t: Throwable): Unit =
    // Pass the throwable itself so the logging backend records its type, message and stack trace.
    // An Array[StackTraceElement] passed as a log message only renders as the array's identity string.
    logger.error(s"Error happened during message reading from [$queueName]: ${t.getMessage}", t)
}
