package com.xebialabs.xlrelease.notifications.email

import com.google.common.base.Strings.isNullOrEmpty
import com.nimbusds.oauth2.sdk._
import com.nimbusds.oauth2.sdk.auth.{ClientSecretPost, Secret}
import com.nimbusds.oauth2.sdk.id.ClientID
import com.nimbusds.oauth2.sdk.token.{AccessToken, RefreshToken}
import com.xebialabs.xlrelease.notifications.configuration.{OAuth2SmtpAuthentication, SmtpServer}
import com.xebialabs.xlrelease.script.EncryptionHelper
import com.xebialabs.xlrelease.service.ConfigurationService
import grizzled.slf4j.Logging
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.stereotype.Component

import java.net.URI
import java.time.{Clock, Duration, Instant}
import java.util.Date

//TODO: Remove this and move it to a generic class once the complete OAuth 2.0 flow is implemented in the backend
/**
 * Keeps the OAuth 2.0 access token of an [[SmtpServer]] configuration usable.
 *
 * When the server's authentication is an [[OAuth2SmtpAuthentication]] and its access token is missing or
 * (about to be) expired, the stored refresh token is exchanged at the configured token endpoint
 * (refresh_token grant, client authenticated with `client_secret_post`). The resulting tokens are written
 * back onto the authentication object and, if the SMTP configuration already exists, persisted.
 *
 * @param configurationService used to check for and persist the SMTP server configuration
 */
@Component
class RefreshTokenOAuth2Helper @Autowired()(configurationService: ConfigurationService) extends Logging {

  /** Safety margin: a token is treated as expired this long before its recorded expiry time. */
  private val clockSkew: Duration = Duration.ofSeconds(60)
  private val clock: Clock = Clock.systemUTC()

  /**
   * Refreshes the access token of `smtpServer` if it uses OAuth 2.0 authentication and the token is
   * missing or expired. Other authentication types are left untouched.
   *
   * Note: the authentication object is decrypted in place via `EncryptionHelper.decrypt` and mutated
   * with the new token values.
   *
   * @param smtpServer the SMTP server configuration whose authentication may be refreshed
   * @throws IllegalArgumentException if OAuth 2.0 authentication has no refresh token
   * @throws RuntimeException         if the token endpoint returns an error response
   */
  def refreshTokenIfRequired(smtpServer: SmtpServer): Unit = {
    smtpServer.authentication match {
      case oauth2SmtpAuthentication: OAuth2SmtpAuthentication =>
        EncryptionHelper.decrypt(oauth2SmtpAuthentication)

        if (isNullOrEmpty(oauth2SmtpAuthentication.getRefreshToken)) {
          throw new IllegalArgumentException("The refresh token must not be null or empty")
        }

        if (isNullOrEmpty(oauth2SmtpAuthentication.getAccessToken) || hasTokenExpired(oauth2SmtpAuthentication)) {
          logger.debug("Going to generate a new access token")
          applyTokens(oauth2SmtpAuthentication, doRefresh(oauth2SmtpAuthentication))

          //Update only if smtp configuration exists
          if (configurationService.exists(SmtpServer.SMTP_SERVER_ID)) {
            configurationService.createOrUpdate(smtpServer)
          }
        }
      case _ =>
      // Non-OAuth2 authentication needs no token refresh.
    }
  }

  /**
   * True when no expiry is recorded, or when the current time is past `expiresAt - clockSkew`.
   */
  private def hasTokenExpired(oauth2SmtpAuthentication: OAuth2SmtpAuthentication): Boolean =
    Option(oauth2SmtpAuthentication.getExpiresAt)
      .forall(expiresAt => clock.instant().isAfter(expiresAt.toInstant.minus(clockSkew)))

  /**
   * Writes the tokens from a successful token response onto `oauth2SmtpAuthentication`:
   * the access token, its expiry time, and a new refresh token if the server issued one.
   */
  private def applyTokens(oauth2SmtpAuthentication: OAuth2SmtpAuthentication, tokenResponse: AccessTokenResponse): Unit = {
    val tokens = tokenResponse.getTokens
    val accessToken: AccessToken = tokens.getAccessToken
    val now: Instant = clock.instant()

    oauth2SmtpAuthentication.setAccessToken(accessToken.getValue)
    // A non-positive lifetime means the server did not state one. Record a 1-second expiry so that,
    // given the clock skew, the token is treated as expired on the next check.
    val lifetimeSeconds: Long = if (accessToken.getLifetime > 0) accessToken.getLifetime else 1L
    oauth2SmtpAuthentication.setExpiresAt(Date.from(now.plusSeconds(lifetimeSeconds)))

    // RFC 6749 section 6: if the server issues a new refresh token, the old one must be replaced.
    // Otherwise a rotating server would invalidate the stored token after this refresh.
    Option(tokens.getRefreshToken)
      .map(_.getValue)
      .filter(value => !isNullOrEmpty(value))
      .foreach(value => oauth2SmtpAuthentication.setRefreshToken(value))
  }

  /**
   * Performs the refresh_token grant against the configured token endpoint.
   *
   * @return the successful token response
   * @throws RuntimeException if the endpoint answers with an OAuth 2.0 error response
   */
  private def doRefresh(oauth2SmtpAuthentication: OAuth2SmtpAuthentication): AccessTokenResponse = {
    val refreshToken: RefreshToken = new RefreshToken(oauth2SmtpAuthentication.getRefreshToken)
    val refreshTokenGrant: RefreshTokenGrant = new RefreshTokenGrant(refreshToken)

    // The credentials to authenticate the client at the token endpoint
    val clientID = new ClientID(oauth2SmtpAuthentication.getClientId)
    val clientSecret = new Secret(oauth2SmtpAuthentication.getClientSecret)
    val clientAuth = new ClientSecretPost(clientID, clientSecret)

    // The token endpoint
    val tokenEndpoint = new URI(oauth2SmtpAuthentication.getAccessTokenUrl)

    // Make the token request
    val request: TokenRequest = new TokenRequest(tokenEndpoint, clientAuth, refreshTokenGrant)
    val response: TokenResponse = TokenResponse.parse(request.toHTTPRequest.send)

    // Never log the raw success body: it contains the access token and possibly a refresh token.
    logger.debug(s"Refresh token request to $tokenEndpoint completed, success: ${response.indicatesSuccess}")

    if (!response.indicatesSuccess) {
      val errorResponse: TokenErrorResponse = response.toErrorResponse
      logger.debug(s"Refresh token error response from server: ${errorResponse.toHTTPResponse.getContent}")
      throw new RuntimeException(s"An error occurred while attempting to get the new OAuth 2.0 Access Token: [${errorResponse.getErrorObject.getCode}] ${errorResponse.getErrorObject.getDescription}")
    }

    response.toSuccessResponse
  }

}
