package com.xebialabs.platform.sso.oidc.web;

import java.io.IOException;
import java.security.PublicKey;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.slf4j.Logger;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.jwt.Jwt;
import org.springframework.security.jwt.JwtHelper;
import org.springframework.security.jwt.crypto.sign.InvalidSignatureException;
import org.springframework.security.jwt.crypto.sign.SignatureVerifier;
import org.springframework.security.oauth2.client.OAuth2RestTemplate;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.common.exceptions.InvalidTokenException;
import org.springframework.security.oauth2.common.exceptions.OAuth2Exception;
import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
import org.springframework.util.Assert;
import com.fasterxml.jackson.databind.ObjectMapper;

import com.xebialabs.platform.sso.crypto.KeyRetriever;
import com.xebialabs.platform.sso.crypto.RS256SignatureVerifier;
import com.xebialabs.platform.sso.oidc.policy.ClaimsToGrantedAuthoritiesPolicy;
import com.xebialabs.platform.sso.oidc.policy.ClaimsToUserNamePolicy;
import com.xebialabs.platform.sso.oidc.userdetails.OpenIdConnectUserDetails;

import static com.xebialabs.platform.sso.oidc.OpenIdConnectProperty.AUDIENCE;
import static com.xebialabs.platform.sso.oidc.OpenIdConnectProperty.EXPIRATION;
import static com.xebialabs.platform.sso.oidc.OpenIdConnectProperty.ID_TOKEN;
import static com.xebialabs.platform.sso.oidc.OpenIdConnectProperty.ISSUER;
import static org.slf4j.LoggerFactory.getLogger;

/**
 * Spring Security authentication filter that completes an OpenID Connect login.
 * <p>
 * For each authentication attempt the filter obtains the OAuth2 access token through the configured
 * {@link OAuth2RestTemplate}, takes the ID token from the token's additional information, verifies its
 * signature with the public key that the {@link KeyRetriever} returns for the token's {@code kid} header,
 * and validates the expiration, audience and issuer claims. The user name and granted authorities are
 * derived from the claims by the configured policies.
 * <p>
 * Background reading:
 * <ul>
 * <li>http://www.baeldung.com/spring-security-openid-connect</li>
 * <li>https://connect2id.com/learn/openid-connect</li>
 * <li>https://connect2id.com/blog/how-to-validate-an-openid-connect-id-token</li>
 * <li>https://developer.okta.com/standards/OIDC/</li>
 * </ul>
 * <p>
 * TODO: check nonce should be same as the one used in the request
 * TODO: the state variable should also be the same between request and response, was unable to test that.
 */
public class OpenIdConnectFilter extends AbstractAuthenticationProcessingFilter {
    private static final Logger log = getLogger(OpenIdConnectFilter.class);

    private final ObjectMapper objectMapper = new ObjectMapper();
    private final String audience;
    private final String issuer;
    private final OAuth2RestTemplate restTemplate;
    private final ClaimsToUserNamePolicy claimsToUserNamePolicy;
    private final ClaimsToGrantedAuthoritiesPolicy claimsToGrantedAuthoritiesPolicy;
    private final KeyRetriever keyRetriever;

    /**
     * Creates the filter.
     *
     * @param defaultFilterProcessesUrl        passed to the superclass constructor
     * @param audience                         expected value of the audience claim (the client id); must have text
     * @param issuer                           expected value of the issuer claim; must have text
     * @param restTemplate                     template used to obtain the access token
     * @param keyRetriever                     source of the public keys used to verify the ID token signature
     * @param claimsToUserNamePolicy           derives the user name from the verified claims
     * @param claimsToGrantedAuthoritiesPolicy derives the granted authorities from the verified claims
     * @throws IllegalArgumentException if {@code audience} or {@code issuer} is empty
     */
    public OpenIdConnectFilter(
            String defaultFilterProcessesUrl,
            String audience,
            String issuer,
            OAuth2RestTemplate restTemplate,
            KeyRetriever keyRetriever,
            ClaimsToUserNamePolicy claimsToUserNamePolicy,
            ClaimsToGrantedAuthoritiesPolicy claimsToGrantedAuthoritiesPolicy
    ) {
        super(defaultFilterProcessesUrl);
        Assert.hasText(audience, "audience/clientId cannot be empty");
        Assert.hasText(issuer, "issuer cannot be empty");
        this.audience = audience;
        this.issuer = issuer;
        this.restTemplate = restTemplate;
        this.claimsToGrantedAuthoritiesPolicy = claimsToGrantedAuthoritiesPolicy;
        this.claimsToUserNamePolicy = claimsToUserNamePolicy;
        this.keyRetriever = keyRetriever;
    }

    /**
     * Obtains the access token from the configured {@link OAuth2RestTemplate}.
     *
     * @return the access token
     * @throws BadCredentialsException if the template fails with an {@link OAuth2Exception}
     */
    protected OAuth2AccessToken getAccessToken() {
        try {
            return restTemplate.getAccessToken();
        } catch (OAuth2Exception e) {
            throw new BadCredentialsException("Could not obtain access token", e);
        }
    }

    /**
     * Authenticates the caller from the ID token that accompanies the access token.
     *
     * @param request  the current request (not read by this implementation)
     * @param response the current response (not written by this implementation)
     * @return an authenticated token whose principal is the derived user name and whose details are the
     * {@link OpenIdConnectUserDetails}
     * @throws BadCredentialsException if the token is missing, malformed, signed by an unknown key, has an
     *                                 invalid signature, or fails the expiration, audience or issuer checks
     * @throws IOException             if the claims JSON cannot be parsed
     */
    @Override
    public Authentication attemptAuthentication(
            HttpServletRequest request,
            HttpServletResponse response
    ) throws AuthenticationException, IOException, ServletException {
        try {
            log.debug("Starting authentication attempt: getting OIDC access token");
            OAuth2AccessToken accessToken = getAccessToken();

            log.debug("Retrieving {} from access token", ID_TOKEN.getPropertyName());
            String idToken = extractIdToken(accessToken);
            String kid = readKeyId(idToken);

            PublicKey keyById = keyRetriever.getKeyById(kid);
            if (keyById == null) {
                // The provider may have rotated its keys since they were last loaded: refresh once and retry.
                log.warn("Got OIDC token signed by unknown key ID {}, refreshing and retrying...", kid);
                keyRetriever.refreshKeys();
                keyById = keyRetriever.getKeyById(kid);
            }
            if (keyById == null) {
                // Still unknown after the refresh: reject instead of handing a null key to the verifier.
                throw new BadCredentialsException("Received JWT token is signed by an unknown key");
            }

            SignatureVerifier verifier = new RS256SignatureVerifier(keyById);
            Jwt tokenDecoded = JwtHelper.decodeAndVerify(idToken, verifier);
            @SuppressWarnings("unchecked") // a JSON object bound to Map.class has String keys
            Map<String, Object> claims = objectMapper.readValue(tokenDecoded.getClaims(), Map.class);
            log.debug("Received claims: [{}]", claims);

            verifyTimestamps(claims);
            verifyAudience(claims);
            verifyIssuer(claims);

            String userName = claimsToUserNamePolicy.claimsToUserName(claims);
            log.debug("User name derived from claims: {}", userName);

            List<GrantedAuthority> authorities = claimsToGrantedAuthoritiesPolicy.claimsToGrantedAuthorities(claims);

            OpenIdConnectUserDetails user = new OpenIdConnectUserDetails(userName, claims, authorities, idToken);

            log.debug("User {} successfully authenticated via OIDC", user);
            UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken(userName, null, user.getAuthorities());
            token.setDetails(user);

            return token;
        } catch (InvalidSignatureException e) {
            log.error("Invalid signature on JWT token. Is the property xl.security.auth.providers.oidc.publicKey in the configuration file configured properly?", e);
            throw new BadCredentialsException("Unable to verify the validity of the received JWT token", e);
        } catch (InvalidTokenException e) {
            throw new BadCredentialsException("Could not obtain user details from token", e);
        }
    }

    /**
     * Reads the ID token from the additional information of the access token.
     *
     * @throws BadCredentialsException if the access token carries no ID token
     */
    private String extractIdToken(OAuth2AccessToken accessToken) {
        Map<String, Object> additionalInformation = accessToken.getAdditionalInformation();
        Object idToken = additionalInformation == null
                ? null
                : additionalInformation.get(ID_TOKEN.getPropertyName());
        if (idToken == null) {
            throw new BadCredentialsException("Token response does not contain an '" + ID_TOKEN.getPropertyName() + "'");
        }
        return idToken.toString();
    }

    /**
     * Reads the {@code kid} header of the ID token.
     *
     * @return the key id, or {@code null} when the header is absent
     * @throws BadCredentialsException if the token header cannot be parsed
     */
    private String readKeyId(String idToken) {
        try {
            return JwtHelper.headers(idToken).get("kid");
        } catch (IllegalArgumentException e) {
            // JwtHelper reports structurally invalid tokens with IllegalArgumentException; treat that as
            // an authentication failure so the filter's failure handling applies.
            throw new BadCredentialsException("Unable to parse the header of the received JWT token", e);
        }
    }

    /**
     * Verifies that the token carries a numeric expiration claim that lies in the future.
     *
     * @param claims the claims of the verified ID token
     * @throws BadCredentialsException if the claim is missing, not numeric, or not later than the current time
     */
    protected void verifyTimestamps(Map<String, Object> claims) {
        Object rawExpiration = claims.get(EXPIRATION.getPropertyName());
        if (rawExpiration == null) {
            throw new BadCredentialsException("Invalid OIDC access token: no '" + EXPIRATION.getPropertyName() + "' claim");
        }
        // Accept any Number since Jackson may return an int/long mix.
        if (!(rawExpiration instanceof Number)) {
            throw new BadCredentialsException("Invalid OIDC access token it has a non numeric '" + EXPIRATION.getPropertyName() + "' claim");
        }
        // Note: exp is in seconds
        long exp = ((Number) rawExpiration).longValue();
        long now = System.currentTimeMillis() / 1000L;
        if (now >= exp) {
            log.error("Invalid OIDC access token: it has expired. exp={}, now={}. Check if system time is in sync, or if the expiry setting is too low", exp, now);
            throw new BadCredentialsException("OIDC access token has expired");
        }
    }

    /**
     * Verifies that the issuer claim equals the configured issuer.
     *
     * @throws BadCredentialsException if the claim is missing or different
     */
    private void verifyIssuer(Map<String, Object> claims) {
        Object receivedIssuer = claims.get(ISSUER.getPropertyName());
        if (!issuer.equals(receivedIssuer)) {
            log.error("JWT token has wrong issuer. Is xl.security.auth.providers.oidc.issuer in the configuration file configured properly? Expected=[{}] actual=[{}]", issuer, receivedIssuer);
            throw new BadCredentialsException("Received JWT token is from wrong issuer");
        }
    }

    /**
     * Verifies that the audience claim names the configured audience. The claim may be a single string or
     * an array of strings; in the latter case the configured audience must be one of the elements.
     *
     * @throws BadCredentialsException if the claim is missing or does not name the configured audience
     */
    private void verifyAudience(Map<String, Object> claims) {
        Object receivedAudience = claims.get(AUDIENCE.getPropertyName());
        boolean audienceMatches;
        if (receivedAudience instanceof Collection) {
            // NOTE(review): for multi-audience tokens the OIDC spec also recommends checking 'azp' - not done here.
            audienceMatches = ((Collection<?>) receivedAudience).contains(audience);
        } else {
            audienceMatches = audience.equals(receivedAudience);
        }
        if (!audienceMatches) {
            log.error("JWT token has wrong audience. Is xl.security.auth.providers.oidc.clientId in the configuration file configured properly? Expected=[{}] actual=[{}]", audience, receivedAudience);
            throw new BadCredentialsException("Received JWT token is for wrong audience");
        }
    }
}
