package com.xebialabs.deployit.core.rest.websockets

import ai.digital.configuration.central.deploy.{ClientProperties, WebSocketsProperties}
import com.xebialabs.deployit.core.rest.converters.CiIdsMessageConverter
import com.xebialabs.deployit.core.rest.json.CiRefsJsonWriter
import jakarta.servlet.ServletContext
import org.eclipse.jetty.ee10.websocket.server.JettyWebSocketServerContainer
import org.springframework.beans.factory.annotation.{Autowired, Qualifier}
import org.springframework.context.annotation.{Bean, Configuration}
import org.springframework.http.server.{ServerHttpRequest, ServerHttpResponse}
import org.springframework.lang.Nullable
import org.springframework.messaging.converter.MessageConverter
import org.springframework.messaging.simp.config.{ChannelRegistration, MessageBrokerRegistry}
import org.springframework.messaging.support.ChannelInterceptor
import org.springframework.scheduling.TaskScheduler
import org.springframework.web.context.ServletContextAware
import org.springframework.web.socket.WebSocketExtension
import org.springframework.web.socket.config.annotation.{EnableWebSocketMessageBroker, StompEndpointRegistry, WebSocketMessageBrokerConfigurer}
import org.springframework.web.socket.server.jetty.JettyRequestUpgradeStrategy
import org.springframework.web.socket.server.support.DefaultHandshakeHandler

import java.security.Principal
import java.util.concurrent.TimeUnit
import java.util.{List => JList}

/**
 * Spring configuration that enables STOMP messaging over WebSocket, with a SockJS fallback.
 *
 *  - Messages addressed to the application use the `/app` destination prefix. `/topic` is
 *    served by Spring's simple broker, which uses the `websocketsScheduler` task scheduler.
 *  - Clients connect on `/ws`. The handshake goes through a Jetty upgrade strategy built from
 *    the configured broker input buffer size and an idle timeout. The timeout is derived from
 *    the client session timeout (minutes, converted to milliseconds).
 *  - Inbound client messages pass through the `userInterceptor` channel interceptor.
 *  - A [[CiIdsMessageConverter]] is registered in addition to Spring's default converters.
 */
@Configuration
@EnableWebSocketMessageBroker
class WebSocketMessageBrokerConfig(@Autowired ciRefsJsonWriter: CiRefsJsonWriter,
                                   @Autowired @Qualifier("websocketsScheduler") taskScheduler: TaskScheduler,
                                   @Autowired @Qualifier("userInterceptor") userInterceptor: ChannelInterceptor,
                                   @Autowired webSocketsConfiguration: WebSocketsProperties,
                                   @Autowired clientConfiguration: ClientProperties
                                  ) extends WebSocketMessageBrokerConfigurer {

  // Captured once, when the configuration object is constructed.
  private val brokerInputBufferSize: Int = webSocketsConfiguration.messageBroker.inputBufferSize

  /** Registers the `/app` application prefix and a simple broker for `/topic`. */
  override def configureMessageBroker(registry: MessageBrokerRegistry): Unit = {
    registry.setApplicationDestinationPrefixes("/app")
    val simpleBroker = registry.enableSimpleBroker("/topic")
    simpleBroker.setTaskScheduler(taskScheduler)
  }

  /**
   * Handshake handler for the `/ws` endpoint.
   *
   * NOTE(review): the upgrade strategy created here is not a Spring bean. It presumably
   * receives its ServletContext through DefaultHandshakeHandler's ServletContextAware
   * forwarding. Confirm this if the Spring version changes.
   */
  @Bean
  def handshakeHandler(): DefaultHandshakeHandler = {
    val sessionTimeoutMinutes = clientConfiguration.getSession.getTimeoutMinute
    val idleTimeoutMillis = TimeUnit.MINUTES.toMillis(sessionTimeoutMinutes)
    val upgradeStrategy = new ServletContextAwareJettyRequestUpgradeStrategy(brokerInputBufferSize, idleTimeoutMillis)
    new DefaultHandshakeHandler(upgradeStrategy)
  }

  /**
   * Exposes the `/ws` STOMP endpoint with SockJS fallback.
   *
   * NOTE(review): the `*` origin pattern accepts handshakes from any origin. Confirm this is
   * intended.
   */
  override def registerStompEndpoints(registry: StompEndpointRegistry): Unit = {
    // Inside a @Configuration class this call presumably resolves to the singleton bean above.
    val endpoint = registry.addEndpoint("/ws").setAllowedOriginPatterns("*")
    endpoint.setHandshakeHandler(handshakeHandler()).withSockJS()
  }

  /** Routes every inbound client message through the user interceptor. */
  override def configureClientInboundChannel(registry: ChannelRegistration): Unit =
    registry.interceptors(userInterceptor)

  /**
   * Adds the CI-id converter.
   *
   * @return `true`. Under the Spring configurer contract, this means the default converters
   *         are registered as well.
   */
  override def configureMessageConverters(messageConverters: JList[MessageConverter]): Boolean = {
    val ciIdsConverter = new CiIdsMessageConverter(ciRefsJsonWriter)
    messageConverters.add(ciIdsConverter)
    true
  }
}

/**
 * A [[JettyRequestUpgradeStrategy]] that applies WebSocket settings to the Jetty container
 * before delegating each handshake upgrade to the parent strategy.
 *
 * On every upgrade, both of the following must hold:
 *  - a [[ServletContext]] has been supplied, and
 *  - a [[JettyWebSocketServerContainer]] is registered on that context.
 * If so, the container's input buffer size and idle timeout are overwritten with the values
 * given here. Otherwise the upgrade proceeds unchanged. The values are constant, so writing
 * them again on each request is idempotent.
 *
 * @param inputBufferSize value passed to `JettyWebSocketServerContainer.setInputBufferSize`
 *                        (presumably bytes, per Jetty's API; confirm)
 * @param idleTimeout     idle timeout in milliseconds, applied through `setIdleTimeout`
 */
class ServletContextAwareJettyRequestUpgradeStrategy(inputBufferSize: Int, idleTimeout: Long)
  extends JettyRequestUpgradeStrategy with ServletContextAware {

  // Converted once here instead of on every handshake.
  private val idleTimeoutDuration: java.time.Duration = java.time.Duration.ofMillis(idleTimeout)

  // Empty until a ServletContext is handed over through setServletContext.
  private var servletContext: Option[ServletContext] = None

  /** Records the servlet context used to look up the Jetty WebSocket container. `null` clears it. */
  override def setServletContext(servletContext: ServletContext): Unit =
    this.servletContext = Option(servletContext)

  /**
   * Applies the configured buffer size and idle timeout to the Jetty container, when one is
   * reachable, and then performs the standard Jetty upgrade.
   */
  override def upgrade(request: ServerHttpRequest,
                       response: ServerHttpResponse,
                       @Nullable selectedProtocol: String,
                       selectedExtensions: java.util.List[WebSocketExtension],
                       @Nullable user: Principal,
                       handler: org.springframework.web.socket.WebSocketHandler,
                       attributes: java.util.Map[String, Object]): Unit = {
    for {
      context <- servletContext
      // getContainer may return null when no container is registered on the context.
      container <- Option(JettyWebSocketServerContainer.getContainer(context))
    } {
      container.setInputBufferSize(inputBufferSize)
      container.setIdleTimeout(idleTimeoutDuration)
    }
    super.upgrade(request, response, selectedProtocol, selectedExtensions, user, handler, attributes)
  }
}
