import com.xhaus.jyson.JysonCodec as Json
import urllib
from java.io import IOException
from java.lang import Integer, NumberFormatException
from org.apache.http import Consts
from org.apache.http import HttpHost
from org.apache.http.auth import AuthScope, UsernamePasswordCredentials
from org.apache.http.client.config import RequestConfig
from org.apache.http.client.methods import HttpPost
from org.apache.http.entity import StringEntity
from org.apache.http.impl.client import BasicCredentialsProvider
from org.apache.http.impl.client import HttpClients
from org.apache.http.util import EntityUtils

import time
from xlrelease.HttpResponse import HttpResponse


class OAuthHttpClient:
    """Minimal Apache HttpClient wrapper used to call an OAuth 2.0 token endpoint.

    ``params`` is the connection configuration mapping; the keys read here are
    ``proxyHost``, ``proxyPort``, ``proxyUsername`` and ``proxyPassword``.
    A proxy is used only when ``proxyHost`` is a non-empty value.
    """

    def __init__(self, params):
        self._params = params
        # An empty proxyHost (e.g. a cleared form field) means "no proxy", same as None.
        self._use_proxy = bool(params.get('proxyHost'))

    def post(self, url, body, content_type='application/x-www-form-urlencoded', accept='application/json'):
        """POST ``body`` to ``url`` and return an ``HttpResponse``.

        :param url: absolute target URL
        :param body: request payload, already encoded as a string
        :param content_type: value sent in the Content-Type header
        :param accept: value sent in the Accept header
        :return: ``HttpResponse`` built from status code, body text (or None) and response headers
        """
        request = HttpPost(url)
        request.setEntity(StringEntity(body))
        request.addHeader('Content-Type', content_type)
        request.addHeader('Accept', accept)
        return self._do_request(request)

    def _do_request(self, request):
        """Execute ``request`` with a fresh client; response and client are always closed."""
        response = None
        http_client = None
        try:
            self._set_proxy(request)
            http_client = self._get_http_client()
            response = http_client.execute(request)
            try:
                entity = response.getEntity()
                result = EntityUtils.toString(entity, Consts.UTF_8.name()) if entity is not None else None
                EntityUtils.consume(entity)
            except IOException:
                # An unreadable body is reported as "no body"; status and headers are still returned.
                result = None
            status = response.getStatusLine().getStatusCode()
            headers = response.getAllHeaders()
            return HttpResponse(status, result, headers)
        finally:
            # Nested so that a failure closing the response does not leak the client.
            try:
                if response is not None:
                    response.close()
            finally:
                if http_client is not None:
                    http_client.close()

    def _get_http_client(self):
        """Build a client, registering proxy credentials only when a proxy is configured."""
        builder = HttpClients.custom()
        username = self._params.get('proxyUsername')
        password = self._params.get('proxyPassword')
        # Without a proxy there is nothing to authenticate against; an AuthScope with a
        # None host would otherwise offer these credentials to any host.
        if self._use_proxy and username is not None and password is not None:
            creds_provider = BasicCredentialsProvider()
            creds_provider.setCredentials(
                AuthScope(self._params.get('proxyHost'), self._proxy_port()),
                UsernamePasswordCredentials(username, password))
            builder.setDefaultCredentialsProvider(creds_provider)
        return builder.build()

    def _proxy_port(self):
        """Return the configured proxy port as an int, or -1 (default/any port) when unset."""
        port = self._params.get('proxyPort')
        if port is None or port == '':
            return -1
        return int(port)

    def _set_proxy(self, request):
        """Route ``request`` through the configured proxy, if any."""
        if self._use_proxy:
            proxy = HttpHost(self._params.get('proxyHost'), self._proxy_port())
            request.setConfig(RequestConfig.custom().setProxy(proxy).build())


# https://tools.ietf.org/html/rfc6749#section-4.3
class OAuthPasswordGrantTypeIssuer:
    """Obtain and cache an access token using the OAuth 2.0 password grant.

    Reads ``accessTokenUrl``, ``clientId``, ``clientSecret``, ``scope``,
    ``username`` and ``password`` from ``params``. The token is cached until
    its ``expires_in`` elapses. An expired token is renewed with the refresh
    token when one was issued. Otherwise, or if renewal is rejected, a new
    token is requested with the password grant.
    """

    def __init__(self, params):
        self._oauth_client = OAuthHttpClient(params)
        self._oauth_access_token_url = params.get('accessTokenUrl')
        self._client_id = params.get('clientId')
        self._client_secret = params.get('clientSecret')
        self._scope = params.get('scope')
        self._username = params.get('username')
        self._password = params.get('password')
        self._is_token_issued = False
        self._token_data = {}
        self._expiration_time = None  # epoch seconds, or None when no expiry is known

    def get_access_token(self):
        """Return a usable access token, requesting or renewing it when needed.

        :raises Exception: if the token endpoint answers with a non-200 status
        """
        status_code = None
        if not self._is_token_issued:
            self._token_data, status_code = self._request_token()
        elif self._is_token_expired():
            if self._token_data.get('refresh_token'):
                self._token_data, status_code = self._renew_token()
                if status_code != 200:
                    # The refresh token itself may be expired or revoked; the stored
                    # credentials still allow a fresh password-grant request.
                    self._token_data, status_code = self._request_token()
            else:
                # No refresh token was issued: re-authenticate rather than hand
                # back the expired access token.
                self._token_data, status_code = self._request_token()

        if status_code is not None and status_code != 200:
            raise Exception("HTTP response code {}".format(status_code))
        return self._token_data['access_token']

    def _request_token(self):
        """Request a new token with the resource owner's username and password."""
        body = dict(
            grant_type="password",
            client_id=self._client_id,
            client_secret=self._client_secret,
            username=self._username,
            password=self._password
        )
        if self._scope is not None:
            body['scope'] = self._scope

        return self._get_token(body)

    def _renew_token(self):
        """Exchange the stored refresh token for a new access token.

        If the server does not return a new refresh token, the previous one is
        kept so that later renewals remain possible.
        """
        refresh_token = self._token_data['refresh_token']
        body = dict(
            grant_type='refresh_token',
            client_id=self._client_id,
            client_secret=self._client_secret,
            refresh_token=refresh_token
        )

        token_data, status_code = self._get_token(body)
        if status_code == 200 and not token_data.get('refresh_token'):
            token_data['refresh_token'] = refresh_token
        return token_data, status_code

    def _get_token(self, body):
        """POST a token request and record issuance/expiry state.

        :param body: form fields of the token request; fields whose value is None are omitted
        :return: ``(token_data, status_code)``; ``token_data`` is the parsed JSON body on
                 a 200 response and an empty dict otherwise
        """
        # urlencode would send None as the literal string "None" (e.g. no client secret).
        form = dict((key, value) for key, value in body.items() if value is not None)
        response = self._oauth_client.post(self._oauth_access_token_url, body=urllib.urlencode(form))
        status = response.getStatus()
        if status == 200:
            token_data = Json.loads(response.getResponse())
            self._is_token_issued = True
            self._set_expiration_time(token_data)
        else:
            # Error bodies may be empty or not JSON (proxy/gateway pages), so they are
            # not parsed; the caller reports the status code.
            token_data = {}
            self._is_token_issued = False
            self._expiration_time = None
        return token_data, status

    def _is_token_expired(self):
        """Return True only when an expiry time is known and has passed."""
        return self._expiration_time is not None and time.time() > self._expiration_time

    def _set_expiration_time(self, body):
        """Record when the current token expires, from its ``expires_in`` (seconds) field."""
        try:
            self._expiration_time = time.time() + int(body.get('expires_in'))
        except (TypeError, ValueError):
            # Missing or non-numeric expires_in: no expiry is tracked for this token.
            self._expiration_time = None


class OAuthProvider:
    """Pick the token issuer matching ``params['oauth2GrantType']`` and delegate to it.

    Only the "Password" grant type is supported; any other value raises.
    """

    def __init__(self, params):
        self._token_issuer = None
        grant_type = params.get('oauth2GrantType')
        if str(grant_type) != "Password":
            raise Exception("OAuth grant_type='{}' is not supported".format(grant_type))
        self._token_issuer = OAuthPasswordGrantTypeIssuer(params)

    def get_access_token(self):
        """Return the access token produced by the configured issuer."""
        issuer = self._token_issuer
        return issuer.get_access_token()


class OAuthSupport(object):
    """Add an OAuth 2.0 bearer token to ``doRequest`` calls.

    Intended to be combined, through inheritance, with a class that provides
    ``doRequest``. When ``params['authenticationMethod']`` is "OAuth2", every
    access to ``doRequest`` returns a wrapper. The wrapper fetches a token from
    ``OAuthProvider`` and sets ``Authorization: Bearer <token>`` in the
    ``headers`` keyword argument (creating the dict if absent) before calling
    the real method. Otherwise the real method is called unchanged.
    """

    def __init__(self, params):
        if str(params.get('authenticationMethod')) == "OAuth2":
            self.use_oauth = True
            self._oauth_provider = OAuthProvider(params)
        else:
            self.use_oauth = False
            self._oauth_provider = None

    # doRequest interceptor
    def __getattribute__(self, name):
        # Normal MRO lookup, so doRequest is found (and intercepted) even when it
        # is inherited rather than defined directly on type(self).
        attr = super(OAuthSupport, self).__getattribute__(name)
        if name != 'doRequest' or not callable(attr):
            return attr

        callback = attr

        def func(*args, **kwargs):
            # set oauth token
            if self.use_oauth:
                token = self._oauth_provider.get_access_token()
                headers = kwargs.get('headers', None)
                if headers is None:
                    kwargs['headers'] = {
                        "Authorization": "Bearer {}".format(token)
                    }
                else:
                    # Note: this updates the caller's headers mapping in place.
                    headers['Authorization'] = "Bearer {}".format(token)
            return callback(*args, **kwargs)

        return func
