#
import base64
import re
import sys
import jarray
from sets import Set

from javax.net.ssl import HostnameVerifier, X509TrustManager, TrustManager, TrustManagerFactory, SSLContext
from java.net import URL, ConnectException
from java.security import KeyStore
from java import util
from java.io import DataOutputStream, FileInputStream

from wlp.modules.utility import Logger


class AllHostsVerifier(HostnameVerifier):
    """Java HostnameVerifier that accepts every hostname.

    WARNING: installing this disables hostname verification for the
    connection, so any certificate is accepted for any host name.
    """

    def verify(self, urlHostname, session):
        # Neither the requested host name nor the SSL session is inspected.
        accepted = True
        return accepted


class TrustAllX509TrustManager(X509TrustManager):
    """Java X509TrustManager that accepts every certificate chain without validation.

    WARNING: this disables certificate validation entirely; it is meant only
    for explicitly opted-in "trust all certificates" setups.
    """

    def checkClientTrusted(self, chain, auth):
        # Returning normally (no CertificateException) means "trusted".
        pass

    def checkServerTrusted(self, chain, auth):
        # Returning normally (no CertificateException) means "trusted".
        pass

    def getAcceptedIssuers(self):
        """Return an empty ``X509Certificate[]``.

        The X509TrustManager contract requires a non-null (possibly empty)
        array; returning None can break JSSE callers that iterate the result.
        """
        from java.security.cert import X509Certificate
        return jarray.zeros(0, X509Certificate)


# Shared instances reused by every connection; neither class keeps any
# per-instance state, so one object of each is sufficient.
TRUST_ALL_HOSTNAME_VERIFIER, TRUST_ALL_X509_MANAGER = AllHostsVerifier(), TrustAllX509TrustManager()


class ConnectionFactory(object):
    """Creates configured URL connections to the Liberty REST JMX connector.

    All settings (host, port, timeouts, SSL options) are read from the
    ``container`` object each time :meth:`get_connection` is called.
    """

    def __init__(self, container, use_ssl):
        """
        :param container: settings object; this class reads ``hostname``,
            ``httpsPort``, ``connectTimeout``, ``readTimeout``, ``sslProtocol``,
            ``trustAllHostnames``, ``trustAllCertificates``, ``trustStorePath``
            and ``trustStorePassword`` from it.
        :param use_ssl: when true, connections use ``https`` and get an
            SSLContext-based socket factory.

        Reports an error through ``Logger.log_and_raise_error`` when the
        hostname or httpsPort is missing.
        """
        if not container.hostname or not container.httpsPort:
            Logger.log_and_raise_error("This operation requires communication with the Liberty REST JMX connector. "
                                       "Connection data for the connector is missing, please specify it on the wlp.Server instance.")
        self.container = container
        self.use_ssl = use_ssl

    def get_connection(self, url):
        """Create a connection object for ``url`` on the configured host.

        :param url: the file/path part of the URL.
        :returns: the connection from ``URL.openConnection()`` with caching and
            user interaction disabled, timeouts applied and redirects followed;
            for SSL also the hostname verifier (if requested) and socket factory.
        """
        protocol = "https" if self.use_ssl else "http"
        # NOTE: the httpsPort setting is used for both http and https.
        connection = URL(protocol, self.container.hostname, self.container.httpsPort, url).openConnection()
        connection.setUseCaches(False)
        connection.setConnectTimeout(self.container.connectTimeout)
        connection.setReadTimeout(self.container.readTimeout)
        connection.setAllowUserInteraction(False)
        connection.setInstanceFollowRedirects(True)
        if self.use_ssl:
            if self.container.trustAllHostnames:
                connection.setHostnameVerifier(TRUST_ALL_HOSTNAME_VERIFIER)
            connection.setSSLSocketFactory(self._create_ssl_context().getSocketFactory())
        return connection

    # private

    def _create_ssl_context(self):
        """Return an SSLContext for ``container.sslProtocol`` with the configured trust managers."""
        ssl_context = SSLContext.getInstance(self.container.sslProtocol)
        if self.container.trustAllCertificates:
            trust_managers = jarray.array([TRUST_ALL_X509_MANAGER], TrustManager)
        elif self.container.trustStorePath:
            trust_managers = self._load_trust_managers()
        else:
            # None lets the installed security providers supply default trust managers.
            trust_managers = None
        ssl_context.init(None, trust_managers, None)
        return ssl_context

    def _load_trust_managers(self):
        """Load the key store at ``container.trustStorePath`` and return its trust managers."""
        # TODO: benchmark if this could be slow and cache accordingly
        key_store = KeyStore.getInstance(KeyStore.getDefaultType())
        trust_store_stream = FileInputStream(self.container.trustStorePath)
        try:
            key_store.load(trust_store_stream, self.container.trustStorePassword)
        finally:
            # Close explicitly so each new connection does not leak a file handle.
            trust_store_stream.close()
        trust_manager_factory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
        trust_manager_factory.init(key_store)
        return trust_manager_factory.getTrustManagers()


class JMXSSLConnector(object):
    """HTTP(S) client for the Liberty REST JMX connector.

    Every :meth:`send` opens a fresh connection through ConnectionFactory,
    applies the default and per-request headers, writes a body for POST
    requests, and wraps the response in a Result.
    """

    def __init__(self, container, use_ssl=True, default_headers=None):
        """
        :param container: settings passed to ConnectionFactory; its ``username``
            and ``password`` build the Basic Authorization header.
        :param use_ssl: use https (default) or plain http.
        :param default_headers: headers sent with every request; defaults to
            ``{'Content-Type': 'application/json'}``. The mapping is copied, so
            the caller's dict is not modified when Authorization is added.
        """
        self.connection_factory = ConnectionFactory(container, use_ssl)
        if default_headers:
            # Copy so the Authorization header added below never leaks into the caller's dict.
            self.default_headers = dict(default_headers)
        else:
            self.default_headers = {'Content-Type': 'application/json'}
        if 'Authorization' not in self.default_headers:
            credentials = '%s:%s' % (container.username, container.password)
            # b64encode emits no line breaks, unlike the deprecated encodestring.
            self.default_headers['Authorization'] = "Basic %s" % base64.b64encode(credentials)

    def send(self, method, url, body="{}", headers=None, throw_on_failure=True):
        """Perform one HTTP request and return its Result.

        :param method: HTTP method name; a body is written only for 'POST'.
        :param url: path passed to ``ConnectionFactory.get_connection``.
        :param body: request body used for POST requests.
        :param headers: extra headers; they override ``default_headers`` on key clashes.
        :param throw_on_failure: when true, ``Result.throw_on_failure`` is called
            before returning.
        :returns: Result holding the status code, reason phrase and body text.
        """
        request_headers = dict(self.default_headers)
        if headers:
            request_headers.update(headers)
        connection = self.connection_factory.get_connection(url)
        try:
            connection.setRequestMethod(method)
            self._set_headers(connection, request_headers)
            if self._is_post(method):
                self._set_request_body(connection, body)

            response = self._get_response(connection)
            result = Result(connection.getResponseCode(), connection.getResponseMessage(), response)
            if throw_on_failure:
                result.throw_on_failure()
            return result
        finally:
            if connection is not None:
                connection.disconnect()

    # private

    def _set_headers(self, connection, request_headers):
        """Set every entry of ``request_headers`` as a request property."""
        for key, value in request_headers.items():
            connection.setRequestProperty(key, value)

    def _is_post(self, method):
        return method == 'POST'

    def _set_request_body(self, connection, body):
        """Write ``body`` to the connection's output stream, closing it afterwards."""
        connection.setDoOutput(True)
        if isinstance(body, unicode):
            # DataOutputStream.writeBytes keeps only the low byte of each char,
            # which corrupts non-ASCII text; UTF-8 encoding first yields one
            # char per byte, so the exact UTF-8 bytes are written.
            body = body.encode('utf-8')
        wr = DataOutputStream(connection.getOutputStream())
        try:
            wr.writeBytes(body)
            wr.flush()
        finally:
            wr.close()

    def _get_response(self, connection):
        """Read the response body: the input stream for 2xx/3xx, otherwise the error stream."""
        try:
            if 200 <= connection.getResponseCode() < 400:
                stream = connection.getInputStream()
            else:
                stream = connection.getErrorStream()
            return self._get_response_from_stream(stream)
        except ConnectException:
            Logger.log_and_raise_error("Unable to communicate with server, please ensure that server instance is running.")

    def _get_response_from_stream(self, stream):
        """Return the whole stream decoded as UTF-8, or '' for a missing/empty stream."""
        if not stream:
            return ''
        # "\A" (start of input) as delimiter makes the Scanner yield the entire stream as one token.
        scanner = util.Scanner(stream, 'UTF-8').useDelimiter("\\A")
        try:
            if scanner.hasNext():
                return scanner.next()
            return ''
        finally:
            # Closing the scanner also closes the underlying stream.
            scanner.close()


class Result(object):
    """Outcome of a REST request: HTTP status code, reason phrase and body text."""

    def __init__(self, status, reason, text):
        self.status = status
        self.reason = reason
        self.text = text

    def is_success(self):
        """Return True for 2xx and 3xx status codes."""
        return 200 <= self.status < 400

    def get(self, key):
        """Return the string value written as ``"key":"value"`` in the body text.

        This is a lightweight regex lookup, not a JSON parser: it only matches a
        quoted value directly after ``"key":`` (no whitespace), stops at the
        first double quote, and returns the first occurrence. Returns None when
        nothing matches or there is no text.
        """
        if not self.text:
            return None
        # Escape the key so characters such as '.' or '[' are matched literally.
        pattern = r'"%s":"(.*?)"' % re.escape('%s' % key)
        match = re.search(pattern, self.text)
        if match:
            return match.group(1)
        return None

    def throw_on_failure(self):
        """Dump the response and report an error via Logger when the status is not a success."""
        if not self.is_success():
            self.dump()
            Logger.log_and_raise_error("Request failed with response code: %s" % self.status)

    def dump(self):
        """Print status, reason and body text to stdout."""
        print("Error executing request, please see response for more details.\nStatus: %s\nReason: %s\nText: %s" % (self.status, self.reason, self.text))


class MethodCall(object):
    """Builder for the JSON payload of a JMX operation invocation.

    Collects the operation ``signature`` (parameter type names) and ``params``
    (each a ``{"value": ..., "type": ...}`` map), then serializes both with
    Jackson's ObjectMapper in :meth:`build`.

    NOTE: :meth:`build` serializes ``self.__dict__``, so every instance
    attribute becomes a JSON key; do not add other attributes.
    """

    def __init__(self):
        self.signature = util.ArrayList()
        self.params = util.ArrayList()

    def with_signature(self, *types):
        """Append parameter type names to the signature and return ``self``.

        Accepts names as separate arguments and/or iterables of names, e.g.
        ``with_signature('java.lang.String', 'int')`` or
        ``with_signature(['java.lang.String', 'int'])``.
        """
        for entry in types:
            if isinstance(entry, basestring):
                self.signature.add(entry)
                continue
            try:
                items = iter(entry)
            except TypeError:
                # Not iterable: treat it as a single signature entry.
                self.signature.add(entry)
            else:
                for item in items:
                    self.signature.add(item)
        return self

    def with_array_param(self, values, collection_type, item_type=None):
        """Add ``values`` as a parameter typed as ``collection_type`` with one ``item_type`` per element."""
        return self.with_param(values, dict(className=collection_type, items=[item_type] * len(values)))

    def with_param(self, value, type):
        """Add one ``{"value": value, "type": type}`` parameter and return ``self``."""
        self.params.add(self._to_java(dict(value=value, type=type)))
        return self

    def build(self):
        """Serialize the signature and params to a JSON string.

        Prints the error to stderr and raises Exception when the number of
        signature entries differs from the number of parameters.
        """
        if len(self.signature) != len(self.params):
            msg = "ERROR: Number of types in method signature (%s) does not match the number of parameters (%s)." % (len(self.signature), len(self.params))
            print >> sys.stderr, msg
            raise Exception(msg)
        from org.codehaus.jackson.map import ObjectMapper

        return ObjectMapper().writeValueAsString(self._to_java(self.__dict__))

    def _to_java(self, proxy):
        """Recursively convert Python lists/tuples, sets and dicts to Java collections.

        Other values (including Java objects) are returned unchanged.
        """
        if isinstance(proxy, (list, tuple)):
            return util.ArrayList([self._to_java(item) for item in proxy])
        if isinstance(proxy, (Set, set, frozenset)):
            return util.HashSet([self._to_java(item) for item in proxy])
        if isinstance(proxy, dict):
            java_map = util.HashMap()
            for key, value in proxy.items():
                java_map.put(key, self._to_java(value))
            return java_map
        return proxy
