package com.xebialabs.xlrelease.repository

import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.datatype.joda.JodaModule
import com.fasterxml.jackson.module.scala.DefaultScalaModule
import grizzled.slf4j.Logging
import org.jboss.resteasy.plugins.providers.sse.OutboundSseEventImpl
import org.springframework.stereotype.Repository

import java.util.concurrent.{CompletableFuture, CompletionStage, ConcurrentHashMap}
import jakarta.ws.rs.core.MediaType
import jakarta.ws.rs.sse.{OutboundSseEvent, SseEventSink}
import scala.collection.{concurrent, mutable}
import scala.jdk.CollectionConverters.{ConcurrentMapHasAsScala, SetHasAsScala}
import scala.jdk.FutureConverters.CompletionStageOps


/**
 * No-op [[SseEventSink]] registered for a user whose real SSE connection lives on another node.
 *
 * It exists so topics can be attached to such a user. It always reports itself as open, and it
 * accepts every event by returning a stage that is already completed with an empty string.
 * Closing it does nothing.
 */
class FakeSseEventSink extends SseEventSink {

  override def isClosed: Boolean = false

  override def send(event: OutboundSseEvent): CompletionStage[_] =
    CompletableFuture.completedFuture("")

  // Nothing is held, so there is nothing to release.
  override def close(): Unit = ()
}
/**
 * Registry of server-sent-event (SSE) subscriptions.
 *
 * It keeps two associations:
 *  - user name → the SSE sinks registered for that user;
 *  - topic → the user names subscribed to it.
 */
trait SSERepository {

  /** Returns a snapshot of the SSE sinks registered for `username` (empty if there are none). */
  def getSinks(username: String): Set[SseEventSink]

  /** Returns a snapshot of the user names subscribed to `topic` (empty if there are none). */
  def getUsers(topic: String): Set[String]

  /** Registers `sink` as one of the SSE sinks of `username`. */
  def addUserToSink(username: String, sink: SseEventSink): Unit

  /** Subscribes `username` to `topic`. */
  def addTopicToUser(topic: String, username: String): Unit

  /** Unsubscribes `username` from `topic`. */
  def removeTopicToUser(topic: String, username: String): Unit

  /** Drops every user subscription to `topic`. */
  def removeAllUsersFromTopic(topic: String): Unit

  /**
   * Publishes an event named `eventName` with `payload` on `topic`, delivering it to the sinks of
   * the users subscribed to that topic.
   */
  def sendEventToSink(topic: String, eventName: String, payload: String): Unit

  /** Creates a fresh RESTEasy builder for an outbound SSE event. */
  def newEventBuilder(): OutboundSseEventImpl.BuilderImpl = new OutboundSseEventImpl.BuilderImpl()

}

/**
 * Body of an outbound SSE event, serialized to JSON by DefaultSSERepository.sendEventToSink.
 *
 * @param eventName name of the event
 * @param origin    topic the event was published on
 * @param payload   caller-supplied payload string
 */
private final case class CustomSseEvent(eventName: String, origin: String, payload: String)

/**
 * In-memory [[SSERepository]] holding two indexes: user name → SSE sinks, and topic → user names.
 *
 * Both indexes are `ConcurrentHashMap`s whose values are concurrent sets.
 *  - Every per-key modification of a set goes through `ConcurrentHashMap.compute` or
 *    `computeIfPresent`. The read-modify-write of one key is therefore atomic, so concurrent
 *    subscriptions for the same user or topic cannot overwrite each other.
 *  - A set that becomes empty is removed in the same atomic step.
 *  - Reads go through Scala views of the same maps and return immutable snapshots.
 */
@Repository
class DefaultSSERepository extends SSERepository with Logging {

  private val mapper = new ObjectMapper()
  mapper.registerModule(DefaultScalaModule)
  mapper.registerModule(new JodaModule)

  // The Java maps are used for atomic per-key updates (compute / computeIfPresent).
  // The Scala views over the same maps are used for Option-returning reads.
  private val sinksByUser = new ConcurrentHashMap[String, mutable.Set[SseEventSink]]
  private val usersByTopic = new ConcurrentHashMap[String, mutable.Set[String]]
  private val userSinks: concurrent.Map[String, mutable.Set[SseEventSink]] = sinksByUser.asScala
  private val topicUsers: concurrent.Map[String, mutable.Set[String]] = usersByTopic.asScala

  /** Creates a thread-safe mutable set backed by a `ConcurrentHashMap`. */
  private def createConcurrentSet[T](): mutable.Set[T] =
    java.util.Collections.newSetFromMap(new ConcurrentHashMap[T, java.lang.Boolean]).asScala

  /** Prunes closed sinks from every user, then returns a snapshot of `username`'s sinks. */
  override def getSinks(username: String): Set[SseEventSink] = {
    removeClosedSinks()
    currentSinks(username)
  }

  override def getUsers(topic: String): Set[String] =
    topicUsers.get(topic).fold(Set.empty[String])(_.toSet)

  /**
   * Serializes a [[CustomSseEvent]] as JSON and sends it to every open sink of every user
   * subscribed to `topic`.
   *
   * A sink whose send fails is closed. Failures of one sink do not prevent delivery to the others.
   */
  override def sendEventToSink(topic: String, eventName: String, payload: String): Unit = {
    val xlrSseEvent = CustomSseEvent(eventName = eventName, origin = topic, payload = payload)
    val event = newEventBuilder()
      .data(mapper.writeValueAsString(xlrSseEvent))
      .mediaType(MediaType.APPLICATION_JSON_TYPE)
      .build()
    // Prune once per event instead of once per subscribed user: pruning scans every user.
    removeClosedSinks()
    for {
      user <- getUsers(topic)
      sink <- currentSinks(user)
      if !sink.isClosed
    } sendOrClose(sink, event)
  }

  override def addUserToSink(username: String, sink: SseEventSink): Unit = {
    removeClosedSinks()
    addToIndex(sinksByUser, username, sink)
  }

  /**
   * Subscribes `username` to `topic`.
   *
   * A user with no registered sink (e.g. connected to another node) first gets a
   * [[FakeSseEventSink]], so that the user appears in the sink index.
   */
  override def addTopicToUser(topic: String, username: String): Unit = {
    if (!userSinks.contains(username)) {
      logger.debug(s"Trying to subscribe user [$username] for topic [$topic] but user has no SSE sink")
      addUserToSink(username, new FakeSseEventSink())
    }
    addToIndex(usersByTopic, topic, username)
  }

  /** Unsubscribes `username` from `topic`; the topic entry is dropped once no users remain. */
  override def removeTopicToUser(topic: String, username: String): Unit =
    removeFromIndex(usersByTopic, topic, username)

  override def removeAllUsersFromTopic(topic: String): Unit =
    usersByTopic.remove(topic)

  /** Snapshot of the sinks currently registered for `username`, without pruning closed ones. */
  private def currentSinks(username: String): Set[SseEventSink] =
    userSinks.get(username).fold(Set.empty[SseEventSink])(_.toSet)

  /**
   * Sends `event` to `sink`.
   *
   * The send can fail by throwing synchronously or by completing its stage exceptionally. In
   * either case the failure is logged at debug level and the sink is closed. Non-fatal
   * exceptions from `send` or `close` are not propagated to the caller.
   */
  private def sendOrClose(sink: SseEventSink, event: OutboundSseEvent): Unit = {
    import scala.util.control.NonFatal

    def closeAfterFailure(error: Throwable): Unit = {
      logger.debug("can't send SSE event", error)
      try sink.close()
      catch { case NonFatal(closeError) => logger.debug("can't close SSE sink", closeError) }
    }

    try {
      sink.send(event).asScala.failed.foreach(closeAfterFailure)(scala.concurrent.ExecutionContext.parasitic)
    } catch {
      // e.g. the sink was closed between the isClosed check and send
      case NonFatal(error) => closeAfterFailure(error)
    }
  }

  /** Atomically adds `value` to the set stored under `key`, creating the set when absent. */
  private def addToIndex[V](index: ConcurrentHashMap[String, mutable.Set[V]], key: String, value: V): Unit =
    index.compute(key, (_: String, existing: mutable.Set[V]) => {
      // `existing` is null when the key is absent (ConcurrentHashMap.compute contract).
      val values = Option(existing).getOrElse(createConcurrentSet[V]())
      values += value
      values
    })

  /** Atomically removes `value` from the set under `key`, dropping the key once its set is empty. */
  private def removeFromIndex[V](index: ConcurrentHashMap[String, mutable.Set[V]], key: String, value: V): Unit =
    index.computeIfPresent(key, (_: String, existing: mutable.Set[V]) => {
      existing -= value
      // Returning null from computeIfPresent removes the mapping.
      if (existing.isEmpty) null else existing
    })

  /** Drops closed sinks from every user, and users that end up with no sinks at all. */
  private def removeClosedSinks(): Unit =
    userSinks.keys.foreach { username =>
      sinksByUser.computeIfPresent(username, (_: String, sinks: mutable.Set[SseEventSink]) => {
        sinks.filterInPlace(!_.isClosed)
        // Returning null from computeIfPresent removes the mapping.
        if (sinks.isEmpty) null else sinks
      })
    }
}

