package com.xebialabs.xlrelease.scm.connector

import com.xebialabs.deployit.util.PasswordEncrypter
import com.xebialabs.xlrelease.config.XlrConfig
import com.xebialabs.xlrelease.domain.scm.connector._
import com.xebialabs.xlrelease.scm.connector.HttpClientRequest.{MULTIPART_BODY_FORM, RequestMediaType}
import org.apache.http.impl.client.{CloseableHttpClient, HttpClients}
import org.springframework.core.io.ByteArrayResource
import org.springframework.http._
import org.springframework.http.client.{ClientHttpResponse, HttpComponentsClientHttpRequestFactory}
import org.springframework.util.LinkedMultiValueMap
import org.springframework.web.client.{ResponseErrorHandler, RestTemplate}
import org.springframework.web.util.UriComponentsBuilder

import java.net.URI
import java.nio.charset.StandardCharsets
import java.util.Base64
import java.util.concurrent.TimeUnit
import scala.reflect.ClassTag
import scala.util.{Failure, Success, Try}

/**
 * Entry point and response helpers for HTTP calls made by SCM connectors.
 *
 * Holds one shared [[RestTemplate]] backed by Apache HttpClient and exposes extension methods that turn a
 * [[ResponseEntity]] into a `Try`.
 */
object HttpClientRequest {

  /** Body encoding selector passed to `ScmHttpClientWrapper.setType`. */
  sealed trait RequestMediaType

  /** Send the wrapper's accumulated form data as `multipart/form-data`. */
  case class MULTIPART_BODY_FORM() extends RequestMediaType

  // Single RestTemplate shared by every wrapper created through `newRequest`. Its error handler
  // (ScmResponseErrorHandler) never reports an error, so non-2xx responses are returned rather than thrown;
  // callers inspect them with `tried` / `checkResponse`.
  private val restTemplate: RestTemplate = {
    val scmConnectorRequestTimeout = XlrConfig.getInstance.timeouts.scmConnectorRequestTimeout
    val httpClient: CloseableHttpClient = HttpClients.custom
      .setConnectionTimeToLive(scmConnectorRequestTimeout.toSeconds, TimeUnit.SECONDS)
      .build()
    val httpComponentsClientHttpRequestFactory = new HttpComponentsClientHttpRequestFactory(httpClient)
    // The connection TTL only limits how long a pooled connection may be reused. On its own it does not stop
    // a call from blocking indefinitely on an unresponsive server. Apply the same limit to connecting, to
    // leasing a pooled connection and to reading the response. Non-positive values keep the client defaults.
    // The setters take milliseconds as Int, so larger values are clamped.
    val timeoutMillis = math.min(scmConnectorRequestTimeout.toMillis, Int.MaxValue.toLong).toInt
    if (timeoutMillis > 0) {
      httpComponentsClientHttpRequestFactory.setConnectTimeout(timeoutMillis)
      httpComponentsClientHttpRequestFactory.setConnectionRequestTimeout(timeoutMillis)
      httpComponentsClientHttpRequestFactory.setReadTimeout(timeoutMillis)
    }
    val restTemplate = new RestTemplate(httpComponentsClientHttpRequestFactory)
    restTemplate.setErrorHandler(ScmResponseErrorHandler)
    restTemplate
  }

  /** Starts building a request against `restApiUrl` using the shared RestTemplate. */
  def newRequest(restApiUrl: String): ScmHttpClientWrapper = {
    new ScmHttpClientWrapper(restApiUrl, restTemplate)
  }

  /**
   * Status helpers for responses returned by the shared RestTemplate.
   *
   * The raw numeric status is used throughout. `getStatusCode` resolves the code to the `HttpStatus` enum and
   * can throw for non-standard codes, and `tried` is not guarded by a `Try`.
   */
  implicit class ResponseEntityExtension[T](response: ResponseEntity[T]) {

    /** True when the status code is in the 2xx range. */
    def isSuccessful: Boolean = {
      val statusCode = response.getStatusCodeValue
      statusCode >= 200 && statusCode < 300
    }

    /**
     * Returns the body as a `Success` for 2xx responses.
     *
     * Any other status gives a `Failure` holding an [[ScmException]] that carries the status code. Its
     * message depends on the status: 401, 404, or anything else.
     */
    def tried: Try[T] = {
      if (isSuccessful) {
        Success(response.getBody)
      } else {
        val statusCode = response.getStatusCodeValue
        val message =
          if (statusCode == HttpStatus.UNAUTHORIZED.value()) "Invalid credentials"
          else if (statusCode == HttpStatus.NOT_FOUND.value()) "Repository not found"
          else "Unable to connect to the repository."
        Failure(ScmException(message, statusCode = statusCode))
      }
    }

    /**
     * Returns the response itself for 2xx statuses.
     *
     * Otherwise it fails with an [[ScmException]] carrying the response status code. The exception message is
     * `errMsg` followed by the response's string form.
     */
    def checkResponse(errMsg: String): Try[ResponseEntity[T]] = {
      if (isSuccessful) {
        Success(response)
      } else {
        Failure(ScmException(s"$errMsg ${response.toString}", statusCode = response.getStatusCodeValue))
      }
    }
  }

}

/**
 * A fully prepared HTTP call (target URI, method and entity) that is executed through `restTemplate`.
 */
class HttpClientRequest(restTemplate: RestTemplate, url: URI, httpMethod: HttpMethod, entity: HttpEntity[_]) {

  /**
   * Sends the request and converts the response body to `T`.
   *
   * Anything thrown while executing or converting is captured as a `Failure`. Only the erased runtime class of
   * `T` reaches Spring. If T is a generic Java collection, use a ParameterizedTypeReference instead.
   * see: https://stackoverflow.com/questions/19463372/classcastexception-resttemplate-returning-listlinkedhashmap-instead-of-listm
   */
  def doRequest[T: ClassTag](): Try[ResponseEntity[T]] = Try {
    val responseType: Class[_] = implicitly[ClassTag[T]].runtimeClass
    val rawResponse: ResponseEntity[_] = restTemplate.exchange(url, httpMethod, entity, responseType)
    rawResponse.asInstanceOf[ResponseEntity[T]]
  }
}


/**
 * Fluent builder for an SCM REST call against `restApiUrl`.
 *
 * Collect query parameters, form data/files and credentials, then call `get`, `post` or `head` to obtain an
 * executable [[HttpClientRequest]]. The builder is mutable and unsynchronized.
 */
class ScmHttpClientWrapper(restApiUrl: String, restTemplate: RestTemplate) {
  private val formData = new LinkedMultiValueMap[String, AnyRef]
  private val uriComponentsBuilder = UriComponentsBuilder.fromHttpUrl(restApiUrl)
  private val headers = new HttpHeaders()
  // Left unset (null) unless setType is called; anything other than MULTIPART_BODY_FORM means a JSON body.
  private var mediaType: RequestMediaType = _
  private var uriEncoded: Boolean = false

  /** Selects the body encoding. When not set to MULTIPART_BODY_FORM, the body is sent as JSON. */
  def setType(mediaType: RequestMediaType): ScmHttpClientWrapper = {
    this.mediaType = mediaType
    this
  }

  /** Whether the URI components are already encoded; passed to `UriComponentsBuilder.build`. */
  def setUriEncoded(uriEncoded: Boolean): ScmHttpClientWrapper = {
    this.uriEncoded = uriEncoded
    this
  }

  /** Adds a text part to the form data, which is sent only for MULTIPART_BODY_FORM requests. */
  def addFormData(name: String, value: String): ScmHttpClientWrapper = {
    formData.add(name, value)
    this
  }

  /**
   * Adds a file part under the form field `path`, reported with file name `fileName`.
   *
   * `content` is invoked immediately and its bytes are held in memory.
   */
  def addFile(path: String, fileName: String, content: () => Array[Byte]): ScmHttpClientWrapper = {
    val resource = new ByteArrayResource(content()) {
      override def getFilename: String = fileName
    }
    formData.add(path, resource)
    this
  }

  /** Appends a query parameter to the request URI. */
  def addQueryParam(name: String, value: String): ScmHttpClientWrapper = {
    uriComponentsBuilder.queryParam(name, value)
    this
  }

  /**
   * Sets the authentication header for the given credential.
   *
   * Bitbucket Cloud credentials use HTTP Basic; GitLab API keys use a `Private-Token` header. Secrets are
   * passed through `PasswordEncrypter.ensureDecrypted` first.
   *
   * @throws ScmException for any other credential type
   */
  def withAuth(authType: ScmCredential): ScmHttpClientWrapper = {
    authType match {
      case c: BitBucketCloudUsernamePasswordCredential => headers.set(HttpHeaders.AUTHORIZATION, getBasicAuthorization(c.username, c.password))
      case c: BitBucketCloudAppPassword => headers.set(HttpHeaders.AUTHORIZATION, getBasicAuthorization(c.username, c.appPassword))
      case c: BitBucketCloudAPIKey => headers.set(HttpHeaders.AUTHORIZATION, getBasicAuthorization(c.teamName, c.apiKey))
      case c: GitLabAPIKey => headers.set("Private-Token", decryptPassword(c.token))
      case _ => throw ScmException(s"Unsupported authentication method ${authType.getClass.getSimpleName}")
    }
    this
  }

  /** Builds a POST to `path` carrying `body`, or the form data when the type is MULTIPART_BODY_FORM. */
  def post[T](path: String, body: T = null): HttpClientRequest =
    new HttpClientRequest(restTemplate, buildUri(path), HttpMethod.POST, createHttpEntity(body))

  /** Builds a GET to `path`, appended to the base URL. */
  def get(path: String): HttpClientRequest =
    new HttpClientRequest(restTemplate, buildUri(path), HttpMethod.GET, createHttpEntity())

  /** Builds a HEAD to `path`, appended to the base URL. */
  def head(path: String): HttpClientRequest =
    new HttpClientRequest(restTemplate, buildUri(path), HttpMethod.HEAD, createHttpEntity())

  // Appends `path` to a copy of the base builder. This keeps several requests built from the same wrapper from
  // accumulating the path segments of earlier calls.
  private def buildUri(path: String): URI =
    uriComponentsBuilder.cloneBuilder().path(path).build(uriEncoded).toUri

  // Uses a per-request copy of the headers. HttpEntity keeps a read-only view of the headers it is given, so
  // setting Content-Type on the shared instance would silently change entities created earlier.
  private def createHttpEntity[B](body: B = null): HttpEntity[_ >: LinkedMultiValueMap[String, AnyRef] with B] = {
    val requestHeaders = new HttpHeaders()
    requestHeaders.putAll(headers)
    val entity = mediaType match {
      case MULTIPART_BODY_FORM() =>
        requestHeaders.setContentType(MediaType.MULTIPART_FORM_DATA)
        new HttpEntity(formData, requestHeaders)
      case _ =>
        requestHeaders.setContentType(MediaType.APPLICATION_JSON_UTF8)
        new HttpEntity(body, requestHeaders)
    }
    entity
  }

  // HTTP Basic credentials, encoded as UTF-8 explicitly rather than with the platform default charset.
  private def getBasicAuthorization(username: String, password: String): String = {
    val credentials = s"$username:${decryptPassword(password)}"
    s"Basic ${Base64.getEncoder.encodeToString(credentials.getBytes(StandardCharsets.UTF_8))}"
  }

  private def decryptPassword(password: String): String = PasswordEncrypter.getInstance().ensureDecrypted(password)
}

/**
 * Error handler that never classifies a response as an error.
 *
 * RestTemplate therefore returns every response, including 4xx/5xx, instead of throwing. Callers inspect the
 * status themselves, e.g. via `HttpClientRequest.ResponseEntityExtension`.
 */
object ScmResponseErrorHandler extends ResponseErrorHandler {
  override def hasError(response: ClientHttpResponse): Boolean = false

  // RestTemplate only calls this when `hasError` returns true, which never happens here. It is a no-op
  // instead of `???`, so an unexpected invocation cannot fail with NotImplementedError.
  override def handleError(response: ClientHttpResponse): Unit = ()
}
