package com.xebialabs.xlrelease.webhooks.endpoint

import com.xebialabs.deployit.exception.NotFoundException
import com.xebialabs.xlplatform.cluster.NodeState
import com.xebialabs.xlplatform.webhooks.authentication.RequestAuthenticationMethod
import com.xebialabs.xlplatform.webhooks.domain.{Endpoint, HttpRequestEvent}
import com.xebialabs.xlplatform.webhooks.endpoint.EndpointProvider
import com.xebialabs.xlplatform.webhooks.events.handlers.EventSourceHandler
import com.xebialabs.xlrelease.events.EventBus
import com.xebialabs.xlrelease.webhooks.consumers.logging.{WebHookRequestAcceptedEvent, WebHookRequestIgnoredEvent, WebHookRequestUnauthorizedEvent}
import com.xebialabs.xlrelease.webhooks.endpoint.exceptions._
import grizzled.slf4j.Logging
import org.apache.commons.io.IOUtils

import jakarta.servlet.http.HttpServletRequest
import jakarta.ws.rs.core.Response
import scala.jdk.CollectionConverters._
import scala.util.{Failure, Success, Try}


/**
 * Base handler that turns an incoming webhook HTTP request into an [[HttpRequestEvent]].
 *
 * The request goes through these steps, in order:
 *  1. The endpoint is resolved by path.
 *  2. It must be enabled.
 *  3. The HTTP method must match.
 *  4. The request must authenticate.
 *  5. The event is handed to `publish`.
 *
 * Every outcome is mapped to a JAX-RS [[Response]]. It is also reported on the [[EventBus]] as
 * one of `WebHookRequestAcceptedEvent`, `WebHookRequestIgnoredEvent` or
 * `WebHookRequestUnauthorizedEvent`.
 *
 * @param endpointProvider resolves an [[Endpoint]] from the request path
 * @param eventBus         receives the WebHook request logging events
 */
abstract class WebhooksEndpointEventSourceHandler(endpointProvider: EndpointProvider,
                                                  eventBus: EventBus)
  extends EventSourceHandler[HttpRequestEvent, Endpoint]
    with Logging {

  /**
   * Handles a webhook request for the given endpoint path.
   *
   * @param path    endpoint path (reported as `webhooks/<path>` in the logging events)
   * @param request the raw servlet request; its body is consumed by this call
   * @return 503 if this node is not active. Otherwise, the outcome of processing, with any
   *         non-fatal failure translated into an error response by `handleError`.
   */
  def accept(path: String, request: HttpServletRequest): Response = {
    if (NodeState.isActive) {
      // handleError covers every non-fatal failure, so `get` only rethrows fatal errors.
      processRequest(path, request).recover(handleError).get
    } else {
      Response.status(Response.Status.SERVICE_UNAVAILABLE).entity("Node is not active").build()
    }
  }

  /**
   * Validates the request, publishes the resulting event and reports the outcome on the event bus.
   *
   * @return `Success` with 200 when `publish` returned true, or 406 when it returned false.
   *         Otherwise, the first failure encountered.
   */
  private def processRequest(path: String, request: HttpServletRequest): Try[Response] = {
    // The Servlet API allows getHeaderNames to return null when the container denies header
    // access; treat that as "no headers" instead of failing with an NPE.
    // getHeader keeps only the first value of a repeated header.
    def getHeaders: Map[String, String] =
      Option(request.getHeaderNames)
        .map(_.asScala.map(name => name -> request.getHeader(name)).toMap)
        .getOrElse(Map.empty[String, String])

    def getParams: Map[String, Array[String]] = request.getParameterMap.asScala.toMap

    val result = for {
      endpoint <- getEndpoint(path)
      _ <- checkEnabled(endpoint)
      _ <- checkMethod(request, endpoint)
      payload <- Try(IOUtils.toString(request.getReader))
      reqHeaders = getHeaders
      reqParams = getParams
      _ <- authenticateRequest(endpoint, reqHeaders, reqParams, payload)
      event = HttpRequestEvent(endpoint, reqHeaders.asJava, reqParams.asJava, payload)
      published <- Try(publish(endpoint, event))
      response = if (published) Response.ok() else Response.notAcceptable(List.empty.asJava)
    } yield {
      logger.trace(s"${requestPrefix(request)}: Published event for $endpoint")
      response.build()
    }
    result match {
      case Success(_) =>
        // NOTE(review): this is also sent when `publish` returned false (406) — confirm intended.
        publishQuietly(eventBus.publish(WebHookRequestAcceptedEvent(s"Accepted WebHook request for path: webhooks/$path")))
      case Failure(ex) =>
        publishErrorEvents(ex)
    }
    result
  }

  /**
   * Maps a processing failure to an HTTP response and logs it.
   *
   * Request-validation failures use their own status and message. A [[NotFoundException]]
   * becomes 404. Any other non-fatal failure becomes 500 with a generic message.
   */
  private def handleError: PartialFunction[Throwable, Response] = {
    case e: WebhookEndpointControllerException =>
      if (e.logStackTrace) logger.warn(e.getMessage, e) else logger.warn(e.getMessage)
      Response.status(e.status).entity(e.getMessage).build()

    case e: NotFoundException =>
      logger.warn(e.getMessage)
      Response.status(Response.Status.NOT_FOUND).entity(e.getMessage).build()

    // Fatal errors (OOM, interrupts, ...) are deliberately not turned into a response.
    case scala.util.control.NonFatal(e) =>
      logger.warn(e.getMessage, e)
      Response.status(Response.Status.INTERNAL_SERVER_ERROR)
        .entity("Exception happened when trying to process your request, check XLRelease log for details").build()
  }

  /**
   * Fails with [[exceptions.EndpointDisabled]] if the node is no longer active or the endpoint's
   * source is disabled. The node state is re-checked because it may change after `accept`.
   */
  private def checkEnabled(endpoint: Endpoint): Try[Unit] =
    if (NodeState.isActive && endpoint.sourceEnabled) Success(())
    else Failure(EndpointDisabled(endpoint))

  /** Fails with [[exceptions.WrongMethod]] unless the request's HTTP method equals the endpoint's. */
  private def checkMethod(request: HttpServletRequest, endpoint: Endpoint): Try[Unit] =
    if (endpoint.method.name() != request.getMethod) Failure(exceptions.WrongMethod(request, endpoint))
    else Success(())

  /** Resolves the endpoint for `path`, translating a [[NotFoundException]] into [[exceptions.EndpointNotFound]]. */
  private def getEndpoint(path: String): Try[Endpoint] = {
    endpointProvider.findEndpointByPath(path).recoverWith {
      case _: NotFoundException => Failure(EndpointNotFound(path))
    }
  }

  /**
   * Returns the endpoint's request authentication method. Fails with
   * [[exceptions.EndpointAuthenticationMethodNotFound]] if the authentication or its request
   * authentication is null.
   */
  private def getRequestAuthenticationMethod(endpoint: Endpoint): Try[RequestAuthenticationMethod] = {
    val authenticationMethod = for {
      auth <- Option(endpoint.authentication)
      method <- Option(auth.requestAuthentication)
    } yield method
    authenticationMethod
      .toRight(exceptions.EndpointAuthenticationMethodNotFound(endpoint))
      .toTry
  }

  /** Succeeds when the endpoint's authentication method accepts the request; otherwise fails with [[exceptions.UnauthorizedRequest]]. */
  private def authenticateRequest(endpoint: Endpoint,
                                  headers: Map[String, String],
                                  params: Map[String, Array[String]],
                                  payload: String): Try[Unit] = {
    getRequestAuthenticationMethod(endpoint).flatMap { authenticationMethod =>
      if (authenticationMethod.authenticateScala(endpoint, headers, params, payload)) {
        Success(())
      } else {
        Failure(exceptions.UnauthorizedRequest(endpoint, authenticationMethod))
      }
    }
  }

  /** Log prefix of the form `<requestURI> from <remoteAddr>:<remotePort>`. */
  private def requestPrefix(request: HttpServletRequest) = s"${request.getRequestURI} from ${request.getRemoteAddr}:${request.getRemotePort}"

  /**
   * Reports a failed request on the event bus. Unauthorized requests are reported as such;
   * every other failure, including an unknown endpoint, is reported as ignored.
   */
  private def publishErrorEvents(ex: Throwable): Unit = {
    // Some exceptions (e.g. a bare NullPointerException) carry no message.
    val message = Option(ex.getMessage).getOrElse(ex.toString)
    publishQuietly {
      ex match {
        case _: UnauthorizedRequest => eventBus.publish(WebHookRequestUnauthorizedEvent(message))
        case _ => eventBus.publish(WebHookRequestIgnoredEvent(message))
      }
    }
  }

  /**
   * Runs a logging-event publication. Any non-fatal failure is logged instead of propagated.
   * By the time these events are sent, the webhook itself has already been handled, so a broken
   * event consumer must not turn that result into an unhandled exception.
   */
  private def publishQuietly(publication: => Any): Unit =
    Try(publication).failed.foreach { e =>
      logger.warn(s"Failed to publish WebHook request event: ${e.getMessage}", e)
    }
}

// TODO: avoid polluting the logs with stack traces when rejecting an incoming payload.
/**
 * Failures raised while validating an incoming webhook request.
 * Each one carries the HTTP status to answer the request with.
 */
object exceptions {

  /**
   * Base class for webhook request-validation failures.
   *
   * Extends [[Exception]] rather than `Throwable` directly, so generic `Exception` handlers
   * and mappers recognise it. Java also does not treat it as a checked exception.
   *
   * @param message       human-readable reason; also returned by `getMessage`
   * @param status        HTTP status to answer the request with
   * @param logStackTrace whether the stack trace should be included when the failure is logged
   */
  abstract class WebhookEndpointControllerException(val message: String,
                                                    val status: Response.Status,
                                                    val logStackTrace: Boolean = false)
    extends Exception(message)

  /** No endpoint is registered for `path` (404). */
  case class EndpointNotFound(path: String)
    extends WebhookEndpointControllerException(
      message = s"Endpoint not found for path '$path'",
      status = Response.Status.NOT_FOUND
    )

  /**
   * The request's HTTP method differs from the endpoint's configured method (400).
   * The message is built eagerly from `req` at construction time.
   */
  case class WrongMethod(req: HttpServletRequest, endpoint: Endpoint)
    extends WebhookEndpointControllerException(
      message = s"Wrong HTTP method for '$endpoint': expected ${endpoint.method}, got ${req.getMethod}.",
      status = Response.Status.BAD_REQUEST
    )

  /**
   * The endpoint has no request authentication method configured.
   * NOTE(review): this looks like a server-side misconfiguration, but it answers 404; kept as-is
   * for compatibility — confirm intended.
   */
  case class EndpointAuthenticationMethodNotFound(endpoint: Endpoint)
    extends WebhookEndpointControllerException(
      message = s"Authentication method not found for '$endpoint'.",
      status = Response.Status.NOT_FOUND
    )

  /** The endpoint's authentication method rejected the request (401). */
  case class UnauthorizedRequest(endpoint: Endpoint, authenticationMethod: RequestAuthenticationMethod)
    extends WebhookEndpointControllerException(
      message = s"Unauthorized request for '$endpoint' (authentication method: ${authenticationMethod.getClass.getName})",
      status = Response.Status.UNAUTHORIZED
    )

  /** The endpoint is disabled, or the node is not active (404). */
  case class EndpointDisabled(endpoint: Endpoint)
    extends WebhookEndpointControllerException(
      message = s"Endpoint '$endpoint' is disabled",
      status = Response.Status.NOT_FOUND
    )

}
