package com.xebialabs.deployit.core.rest.api

import ai.digital.configuration.central.deploy.converter.HoconDurationConverter
import ai.digital.deploy.sql.model.WorkerInfo
import com.xebialabs.deployit.core.api.WorkersService
import com.xebialabs.deployit.core.rest.api.support.PaginationSupport
import com.xebialabs.deployit.core.rest.secured.AbstractSecuredResource
import com.xebialabs.deployit.engine.api.distribution.TaskExecutionWorkerRepository
import com.xebialabs.deployit.engine.api.dto.{Ordering, Paging, Worker}
import com.xebialabs.deployit.engine.tasker.TaskExecutionEngine
import com.xebialabs.deployit.engine.tasker.distribution.WorkerManager.messages._
import com.xebialabs.deployit.security.permission.PlatformPermissions.ADMIN
import com.xebialabs.deployit.spring.BeanWrapper
import grizzled.slf4j.Logging
import jakarta.ws.rs.NotFoundException
import jakarta.ws.rs.core.Context
import org.apache.pekko.actor.{ActorRef, AddressFromURIString}
import org.apache.pekko.pattern._
import org.apache.pekko.util.Timeout
import org.jboss.resteasy.spi.HttpResponse
import org.springframework.beans.factory.annotation.{Autowired, Qualifier, Value}
import org.springframework.stereotype.Controller

import java.util
import java.util.stream.Collectors
import scala.concurrent.Await
import scala.concurrent.duration._
import scala.jdk.CollectionConverters._

/**
 * REST resource for listing and managing task-execution workers.
 *
 * Every endpoint requires the `ADMIN` platform permission. Live worker state is obtained by asking the
 * `workerManager` actor and blocking for its reply; registered workers and their task counts come from the
 * [[TaskExecutionWorkerRepository]].
 *
 * @param workerManager    wrapper around the worker manager actor, asked for worker state and shutdowns
 * @param workerRepository repository of registered workers and their tasks
 * @param engine           task execution engine, used to re-register ghost tasks
 * @param askTimeout       HOCON duration (property `xl.tasker.askTimeout`, default `10 seconds`) bounding every
 *                         request sent to the worker manager
 */
@Controller
class WorkersResource(@Autowired @Qualifier("workerManager") workerManager: BeanWrapper[ActorRef],
                      @Autowired workerRepository: TaskExecutionWorkerRepository,
                      @Autowired engine: BeanWrapper[TaskExecutionEngine],
                      @Value("${xl.tasker.askTimeout:10 seconds}") askTimeout: String)
  extends AbstractSecuredResource with WorkersService with Logging {

  implicit private val timeout: Timeout = Timeout(HoconDurationConverter.convert(askTimeout))

  // Injected by the JAX-RS runtime; used to publish the total-count header of ordered/paginated listings.
  @Context val httpResponse: HttpResponse = null

  /**
   * Force-shuts down the given workers, or all workers when no addresses are given.
   *
   * Tasks of the targeted workers are removed from the repository before the shutdown request is sent.
   *
   * @param workers addresses of the workers to shut down; an empty (or null) list means all workers
   * @throws IllegalArgumentException if one of the addresses denotes a local worker
   * @throws NotFoundException        if some requested workers were not reported as shut down
   */
  override def shutdownWorkers(workers: util.List[String]): Unit = {
    checkPermission(ADMIN)

    // A missing list is treated like an empty one, i.e. "all workers".
    val requested = Option(workers).map(_.asScala.toList).getOrElse(Nil)
    val shutdownAll = requested.isEmpty

    if (shutdownAll) {
      logger.info("Shutting down all workers.")
    } else {
      logger.info(s"Shutting down workers $workers.")
    }

    requested.find(isLocalWorker).foreach(throwLocalWorkerError)

    // Drop the tasks of every targeted worker before the (forced) shutdown is requested.
    workerRepository.listWorkers.map(_._1).foreach { worker =>
      if (shutdownAll || requested.contains(worker.address)) {
        workerRepository.removeTasks(worker.id)
      }
    }

    val response = askWorkerManager[WorkerShutdownStarted](ShutdownWorker(requested, force = true))

    if (!shutdownAll) {
      val stopped = response.workers
      logger.info(s"Workers ${bracketed(stopped)} shutdown.")
      val notFound = requested.filterNot(address => stopped.contains(address))
      if (notFound.nonEmpty) {
        throw new NotFoundException(s"Workers ${bracketed(notFound)} not found. (Workers ${bracketed(stopped)} shutdown.)")
      }
    }
  }

  /**
   * Gracefully (non-forced) shuts down the registered worker with the given id.
   *
   * @param workerId repository id of the worker
   * @throws NotFoundException        if no worker with this id is registered, or the worker manager did not
   *                                  report it as shut down
   * @throws IllegalArgumentException if the worker is a local worker
   */
  override def shutdownWorker(workerId: Integer): Unit = {
    checkPermission(ADMIN)

    val worker = workerRepository.getWorker(workerId).getOrElse(throw new NotFoundException(s"Worker for id $workerId not found."))
    val address = worker.address

    if (isLocalWorker(address)) {
      throwLocalWorkerError(address)
    }

    val response = askWorkerManager[WorkerShutdownStarted](ShutdownWorker(List(address), force = false))

    logger.info(s"Workers ${bracketed(response.workers)} shutdown.")
    if (!response.workers.contains(address)) {
      throw new NotFoundException(s"Worker $address not found.")
    }
  }

  /** Returns whether `address` parses (via `AddressFromURIString`) to an actor address with local scope. */
  private def isLocalWorker(address: String): Boolean =
    AddressFromURIString(address).hasLocalScope

  /** Always throws: local workers are rejected by the shutdown endpoints. */
  private def throwLocalWorkerError(s: String): Nothing =
    throw new IllegalArgumentException(s"Local worker ($s) cannot be shutdown.")

  /**
   * Removes a registered worker together with its tasks.
   *
   * @param workerId repository id of the worker
   * @throws NotFoundException        if no worker with this id is registered
   * @throws IllegalArgumentException if the worker manager still reports the worker as healthy or draining
   */
  override def removeWorker(workerId: Integer): Unit = {
    checkPermission(ADMIN)

    val worker = workerRepository.getWorker(workerId).getOrElse(throw new NotFoundException(s"Worker for id $workerId not found."))
    val fetched = askWorkerManager[WorkersFetched](FetchWorkers())
    val active = fetched.draining ++ fetched.healthy

    if (active.contains(worker.address)) {
      throw new IllegalArgumentException(s"Worker for id $workerId is still active, cannot be removed.")
    }

    workerRepository.removeTasks(workerId)
    workerRepository.removeWorker(workerId)
  }

  /** Delegates to the task execution engine and returns its result unchanged. */
  override def reregisterGhostTasks(): util.List[String] = {
    checkPermission(ADMIN)
    engine.get().reregisterGhostTasks()
  }

  /**
   * Lists registered workers with their connection state and task counts.
   *
   * Without paging (null or `resultsPerPage == -1`) and without ordering, all workers are returned sorted by id.
   * Otherwise the total count is added as a response header, the workers are sorted (by `order`, or by id
   * when `order` is null) and, when paging is requested, the requested page is sliced out.
   *
   * @param paging optional paging (1-based `page`, `resultsPerPage`); may be null
   * @param order  optional sort field and direction; may be null
   */
  override def listWorkers(paging: Paging, order: Ordering): util.List[WorkerInfo] = {
    checkPermission(ADMIN)

    val fetched = askWorkerManager[WorkersFetched](FetchWorkers())

    // A worker the manager knows in several lists is classified by the first match in this order.
    def stateOf(w: Worker): WorkerState.Value =
      if (fetched.healthy.contains(w.address)) WorkerState.CONNECTED
      else if (fetched.incompatible.contains(w.address)) WorkerState.INCOMPATIBLE
      else if (fetched.draining.contains(w.address)) WorkerState.DRAINING
      else WorkerState.DISCONNECTED

    val workers = workerRepository.listWorkers.map { case (w, deploymentTasks, controlTasks) =>
      new WorkerInfo(w.id, w.name, w.address, stateOf(w).toString, deploymentTasks, controlTasks)
    }

    val pagingRequested = paging != null && paging.resultsPerPage != -1

    if (!pagingRequested && order == null) {
      workers.sortBy(_.getId).asJava.stream().collect(Collectors.toList())
    } else {
      PaginationSupport.addTotalCountHeader(workers.size, httpResponse)
      val sorted = if (order == null) workers.sortBy(_.getId) else WorkersSorting.sort(workers, order)
      val page = if (pagingRequested) {
        val from = (paging.page - 1) * paging.resultsPerPage
        // slice clamps both bounds to the list, so the last page may be shorter.
        sorted.slice(from, from + paging.resultsPerPage)
      } else {
        sorted
      }
      page.asJava
    }
  }

  /**
   * Sends `message` to the worker manager and blocks (up to the configured ask timeout) for a reply of type `T`.
   * A reply of another type fails with a `ClassCastException`.
   */
  private def askWorkerManager[T](message: Any)(implicit tag: scala.reflect.ClassTag[T]): T =
    Await.result((workerManager.get() ? message).mapTo[T], timeout.duration)

  /** Renders items as `[a, b, c]` for log and error messages. */
  private def bracketed(items: Iterable[Any]): String = items.mkString("[", ", ", "]")
}

/**
 * Connection state of a worker as reported by `listWorkers`; the value's name is what ends up in `WorkerInfo`.
 * Declaration order defines the ids, and therefore the order used when sorting by state.
 */
object WorkerState extends Enumeration {
  val CONNECTED: Value = Value("CONNECTED")
  val INCOMPATIBLE: Value = Value("INCOMPATIBLE")
  val DRAINING: Value = Value("DRAINING")
  val DISCONNECTED: Value = Value("DISCONNECTED")
}

/** Sorting of worker listings by a client-requested field. */
object WorkersSorting {

  import WorkerState.withName

  /**
   * Sorts `workers` by `ordering.field`, reversing the result when the ordering is descending.
   *
   * Supported fields: `id`, `name`, `address`, `state`, `deploymentTasks` and `controlTasks`. The `state` and
   * task-count fields use the worker id as a tie-breaker; `state` sorts by `WorkerState` declaration order.
   * A null `ordering` sorts by id ascending.
   *
   * @throws IllegalArgumentException if `ordering.field` is not a supported field
   */
  def sort(workers: List[WorkerInfo], ordering: Ordering): List[WorkerInfo] =
    if (ordering == null) {
      workers.sortBy(_.getId)
    } else {
      val sorted = ordering.field match {
        case "id" => workers.sortBy(_.getId)
        case "name" => workers.sortBy(_.getName)
        case "address" => workers.sortBy(_.getAddress)
        case "state" => workers.sortBy(w => (withName(w.getState), w.getId))
        case "deploymentTasks" => workers.sortBy(w => (w.getDeploymentTasks, w.getId))
        case "controlTasks" => workers.sortBy(w => (w.getControlTasks, w.getId))
        case other =>
          // Previously a MatchError; report the bad client input explicitly instead.
          throw new IllegalArgumentException(
            s"Unsupported sort field '$other' for workers. Supported fields: id, name, address, state, deploymentTasks, controlTasks.")
      }
      if (ordering.isAscending) sorted else sorted.reverse
    }
}
