package com.xebialabs.xlrelease.scheduler.logs

import akka.actor.{Actor, ReceiveTimeout}
import akka.event.LoggingReceive
import akka.event.slf4j.SLF4JLogging
import com.xebialabs.xlrelease.scheduler.logs.ExecutionLogWatchActor._
import com.xebialabs.xlrelease.scheduler.storage.spring.StorageConfiguration.URI_SCHEME_LOCAL_STORAGE
import com.xebialabs.xlrelease.storage.domain.JobEntryRef
import com.xebialabs.xlrelease.storage.service.StorageService
import com.xebialabs.xlrelease.support.akka.spring.SpringActor
import org.apache.commons.io.IOUtils

import java.net.URI
import java.time.{Duration, Instant}
import javax.ws.rs.sse.{OutboundSseEvent, Sse, SseEventSink}
import scala.collection.immutable.HashSet
import scala.concurrent.duration.DurationInt
import scala.util.{Try, Using}

/**
 * Actor that pushes the log of a task execution to a set of SSE sinks.
 *
 * A sink is registered with [[StartWatch]]: it immediately receives the current execution status and the
 * last stored log chunk, and afterwards every chunk announced through [[NewEntry]].
 * [[Check]] is sent explicitly and whenever no message arrived for 5 seconds; it forgets closed sinks,
 * stops the actor when no open sink is left, asks to stop watching when the execution has an end date
 * (or cannot be found) and pings the sinks when no chunk arrived for a while.
 *
 * The mutable fields below are actor state: they are only read and written while a message is being
 * processed, never from callbacks that run on other threads.
 */
@SpringActor
class ExecutionLogWatchActor(taskExecutionLogService: TaskExecutionLogService, storageService: StorageService)
  extends Actor with SLF4JLogging {
  var executionId: String = _
  var taskId: String = _
  // if this actor is stopped just close sse and do not restart it
  var sinks: Set[SseEventSink] = HashSet[SseEventSink]()
  var sse: Sse = _
  var lastEventTimestamp: Instant = Instant.now()

  // a ping is sent when no new entry arrived for more than this many seconds ...
  private val pingAfterSeconds = 10L
  // ... and watching is given up once this many seconds passed without a new entry
  private val giveUpAfterSeconds = 60L

  context.setReceiveTimeout(5.seconds)

  override def receive: Receive = LoggingReceive {
    case m: WatcherMsg => handleWatchMsgs(m)
    // no message for 5 seconds: run the periodic housekeeping
    case ReceiveTimeout => self ! Check(executionId)
  }

  /** Closes every sink that is still registered when the actor stops. */
  override def postStop(): Unit = closeAllSinks()

  /** Handles one protocol message; see the individual cases for details. */
  //noinspection ScalaStyle
  private def handleWatchMsgs(watcherMsg: WatcherMsg): Unit = watcherMsg match {
    case StartWatch(watchedTaskId, watchedExecutionId, sseEventSink, sseContext) =>
      this.executionId = watchedExecutionId
      this.taskId = watchedTaskId
      this.sinks += sseEventSink
      this.sse = sseContext
      // bring the new sink up to date before live entries start to arrive
      streamCurrentState(sseEventSink)
      streamLastChunk(sseEventSink)
      self ! Check(watchedExecutionId)
    case StopWatching(_) =>
      // Best effort, like postStop: a sink that fails to close must not keep the others open.
      closeAllSinks()
      // Forget the sinks even if a close failed, so that the Check below always stops the actor.
      this.sinks = HashSet.empty[SseEventSink]
      self ! Check(executionId)
    case Check(id) =>
      // forget sinks that were closed by the client or after a failed send
      sinks = sinks.filterNot(_.isClosed)
      // should unsubscribe... any msg that arrives will be sent to dead letter
      if (sinks.isEmpty) {
        context.stop(self)
      } else {
        // a missing entry is treated the same way as a finished execution
        val executionIsCompletedOrMissing =
          taskExecutionLogService.getTaskExecutionEntry(taskId, id).forall(_.endDate != null)
        if (executionIsCompletedOrMissing) {
          self ! StopWatching(id)
        } else {
          // if there was no new entry for certain amount of time ... send ping so we would close closed sinks
          val secondsSinceLastEvent = Duration.between(lastEventTimestamp, Instant.now()).toSeconds
          if (secondsSinceLastEvent >= giveUpAfterSeconds) {
            self ! StopWatching(id)
          } else if (secondsSinceLastEvent > pingAfterSeconds) {
            sendEventToAllSinks(this.sse.newEvent(PING_SSE_EVENT, ":\n\n"))
          }
        }
      }
    case NewEntry(_, newEntryUri) =>
      // remember when the last entry arrived; Check uses it to decide about pings and giving up
      this.lastEventTimestamp = Instant.now()
      sendEventToAllSinks(logChunkEvent(newEntryUri))
  }

  /** Closes all registered sinks, ignoring failures of individual `close()` calls. */
  private def closeAllSinks(): Unit = this.sinks.foreach(sink => Try(sink.close()))

  /** Reads the entry stored under `entryUri` and wraps its content into a log chunk event. */
  private def logChunkEvent(entryUri: URI): OutboundSseEvent = {
    val payload = Using.resource(storageService.get(JobEntryRef(entryUri))) { content =>
      IOUtils.toByteArray(content)
    }
    // NOTE(review): the bytes are decoded with the platform default charset - confirm the stored encoding
    this.sse.newEvent(LOG_CHUNK_ENTRY_CREATED_EVENT, new String(payload))
  }

  /** Sends `event` to every registered sink that is not closed. */
  private def sendEventToAllSinks(event: OutboundSseEvent): Unit = {
    this.sinks.foreach(sendEventToSink(event))
  }

  /**
   * Sends `event` to `sink` unless the sink is already closed.
   *
   * The failure callback runs on the thread that completes the send, not while the actor processes a
   * message, so it must not touch actor state. It only closes the failed sink and sends a [[Check]];
   * handling that message removes the closed sink from `sinks` and stops the actor when none is left.
   */
  private def sendEventToSink(event: OutboundSseEvent)(sink: SseEventSink): Unit = {
    if (!sink.isClosed) {
      import scala.jdk.FutureConverters._
      // read the field here, while a message is being processed, instead of inside the callback
      val watchedExecutionId = executionId
      sink.send(event).asScala.recover { case t =>
        log.debug("Event send failed", t)
        Try(sink.close())
        self ! Check(watchedExecutionId)
      }(scala.concurrent.ExecutionContext.parasitic)
    }
  }

  /**
   * Sends the execution status event to `sink`. The payload is `"<status>, <last job>, <last chunk>"`
   * where status is `in_progress`, `closed` or, when no execution entry is found, `unknown`.
   */
  private def streamCurrentState(sink: SseEventSink): Unit = {
    val (status, jobId, chunk) = taskExecutionLogService.getTaskExecutionEntry(taskId, executionId) match {
      case Some(row) => (if (row.endDate == null) "in_progress" else "closed", row.lastJob, row.lastChunk)
      case None => ("unknown", -1, -1)
    }
    val statusEvent = this.sse.newEvent(EXECUTION_LOG_STATUS_EVENT, s"$status, $jobId, $chunk")
    sendEventToSink(statusEvent)(sink)
  }

  /** Sends the last chunk recorded for the watched execution to `sink`; does nothing when there is no entry. */
  private def streamLastChunk(sink: SseEventSink): Unit = {
    taskExecutionLogService.getTaskExecutionEntry(taskId, executionId).foreach { row =>
      // row -> log entry ref
      val uriPath = s"/jobs/${row.taskIdHash}/${row.executionId}/${row.lastJob}/${row.lastChunk}"
      val entryUri = URI.create(s"$URI_SCHEME_LOCAL_STORAGE://$uriPath")
      sendEventToSink(logChunkEvent(entryUri))(sink)
    }
  }

}

/** Message protocol and SSE event names of the execution log watch actor. */
object ExecutionLogWatchActor {

  /** Base type of all messages handled by the actor; each message carries an execution id. */
  sealed trait WatcherMsg {
    def executionId: String
  }

  /** Registers `sink` as a watcher of the given task execution; `sse` is used to build the outgoing events. */
  case class StartWatch(taskId: String, executionId: String, sink: SseEventSink, sse: Sse) extends WatcherMsg

  /** Announces that a new log entry was stored under `newEntryUri`. */
  case class NewEntry(executionId: String, newEntryUri: URI) extends WatcherMsg

  /** Triggers the housekeeping of the actor (closed sinks, finished executions, pings). */
  case class Check(executionId: String) extends WatcherMsg

  /** Asks the actor to close its sinks. */
  case class StopWatching(executionId: String) extends WatcherMsg

  /** Name of the event that carries the content of a log chunk. */
  final val LOG_CHUNK_ENTRY_CREATED_EVENT: String = "log-chunk-entry-created"

  /** Name of the keep-alive event. */
  final val PING_SSE_EVENT: String = "ping"

  /** Name of the status event; its payload holds the status (e.g. in_progress or closed), the last job id and the last chunk. */
  final val EXECUTION_LOG_STATUS_EVENT: String = "execution-status"
}
