#
# Copyright (c) 2018. All rights reserved.
#
# This software and all trademarks, trade names, and logos included herein are the property of XebiaLabs, Inc. and its affiliates, subsidiaries, and licensors.
#

import os
import base64
import tempfile

# hashlib: used only to derive a stable temp-file name (MD5 of the kubernetes client host), not for security
import hashlib
import botocore.awsrequest as awsreq
import com.xebialabs.deployit.plugin.kubernetes.BotoLoader as BotoLoader
import xld.kubernetes.eks as eks

from java.nio.file import Files, Paths, StandardCopyOption
from botocore.session import Session as BotocoreSession
from botocore.signers import RequestSigner
from botocore.model import ServiceModel
from boto3.session import Session

# to access the kubernetes client configuration
from kubernetes.client.configuration import Configuration

# Constants
# Text encoding used to encode the pre-signed URL and decode the base64 token in EKSHelper.get_k8s_token.
default_encoding = 'utf-8'
# STS API action that is pre-signed to build the Kubernetes token.
action_name = 'GetCallerIdentity'
# Request header that carries the EKS cluster name in the signed STS request.
cluster_id_header = "x-k8s-aws-id"
# Prefix prepended to the base64url-encoded pre-signed URL to form the token.
token_prefix = 'k8s-aws-v1.'
# URL scheme used when building the STS endpoint URL.
aws_protocol = 'https'


class EKSHelper(object):
    """Helper around boto3/botocore for Amazon EKS access.

    Builds a boto3 session from static credentials, exposes AWS clients,
    produces Kubernetes bearer tokens from a pre-signed STS
    ``GetCallerIdentity`` URL, and offers small utilities for response
    checking, dict cleanup and retry bookkeeping.
    """

    # Case-insensitive prefix tested by is_starts_with_name and stripped by get_property_name.
    _NAME_PREFIX = 'name:'

    def __init__(self, access_key, access_secret):
        """Create botocore/boto3 sessions for the given static credentials.

        :param access_key: AWS access key id.
        :param access_secret: AWS secret access key.

        Side effect: may set ``REQUESTS_CA_BUNDLE`` in ``os.environ``
        (see :meth:`set_ca_bundle_path`).
        """
        # Botocore session whose 'data_loader' component comes from eks.create_loader();
        # presumably so service models resolve from the plugin's bundled resources - confirm.
        self.core_session = BotocoreSession()
        self.core_session.lazy_register_component('data_loader', lambda: eks.create_loader())
        session_keywords = {'botocore_session': self.core_session, 'aws_access_key_id': access_key, 'aws_secret_access_key': access_secret}
        self.session = Session(**session_keywords)
        EKSHelper.set_ca_bundle_path()

        # Load the STS service model; get_k8s_token reads its api_version,
        # endpoint metadata and signing details.
        data_loader = self.core_session.get_component('data_loader')
        json_model = data_loader.load_service_model('sts', 'service-2')
        self.sts_service_model = ServiceModel(json_model, service_name='sts')

    def get_eks_client(self, resource_name='eks'):
        """Return a boto3 client for ``resource_name`` (default ``'eks'``).

        NOTE(review): ``verify=False`` disables TLS certificate verification for
        this client even though a CA bundle is configured in __init__ - confirm
        this is intended.
        """
        return self.session.client(resource_name, verify=False)

    def get_k8s_token(self, cluster_name, use_global, region_name="us-east-1"):
        """Build a Kubernetes bearer token for an EKS cluster.

        The token is ``token_prefix`` followed by the unpadded URL-safe base64
        encoding of a pre-signed STS ``GetCallerIdentity`` GET URL (expires in
        60 seconds). The request that gets signed carries the
        ``x-k8s-aws-id: <cluster_name>`` header.

        :param cluster_name: EKS cluster name, sent in the ``x-k8s-aws-id`` header.
        :param use_global: if truthy, use the STS model's global endpoint and sign
            for ``us-east-1`` (``region_name`` is ignored); otherwise use
            ``<endpointPrefix>.<region_name>.amazonaws.com``.
        :param region_name: region for the regional endpoint and the signature.
        :return: the token string.
        """
        # Prepare AWS request parameters
        request_dict = {
            'url_path': '/',
            'query_string': '',
            'headers': {cluster_id_header: cluster_name},
            'body': {
                'Action': action_name,
                'Version': self.sts_service_model.api_version
            },
            'method': 'GET'
        }
        if use_global:
            print("Using global STS endpoint")
            url_endpoint = self.sts_service_model.metadata['globalEndpoint']
            region_name = "us-east-1"
        else:
            print("Using regional STS endpoint")
            url_endpoint = "{0}.{1}.amazonaws.com".format(self.sts_service_model.metadata['endpointPrefix'], region_name)
        awsreq.prepare_request_dict(
            request_dict,
            endpoint_url="{0}://{1}".format(aws_protocol, url_endpoint)
        )

        # Pre-sign the AWS request with the session's credentials.
        request_signer = RequestSigner(
            service_name=self.sts_service_model.service_name,
            region_name=region_name,
            signing_name=self.sts_service_model.signing_name,
            signature_version=self.sts_service_model.signature_version,
            credentials=self.session.get_credentials(),
            event_emitter=self.session.events
        )
        signed_url = request_signer.generate_presigned_url(request_dict, action_name, expires_in=60)
        # Base64url-encode the URL and strip '=' padding.
        return token_prefix + base64.urlsafe_b64encode(signed_url.encode(default_encoding)).decode(default_encoding).rstrip('=')

    @staticmethod
    def extract_file_from_jar(config_file):
        """Copy the classpath resource ``config_file`` to a temp file.

        :param config_file: resource path resolved through
            ``BotoLoader.getResourceBySelfClassLoader``.
        :return: path of the written file, or None if the resource is not found.
        """
        file_url = BotoLoader.getResourceBySelfClassLoader(config_file)
        if not file_url:
            return None

        # Rather than creating a new temp file on every call, reuse one file per
        # kubernetes client host. The name is 'eks_cacert-' plus the MD5 hex
        # digest of Configuration().host, in the system temp directory. MD5 is
        # used only for naming, not for security. An existing file is overwritten.
        tmp_file_name = 'eks_cacert-{0}'.format(hashlib.md5(Configuration().host).hexdigest())
        tmp_file_path = os.path.join(tempfile.gettempdir(), tmp_file_name)
        with open(file_url.openStream()) as src:
            with open(tmp_file_path, "w") as tmp_file:
                for line in src:
                    tmp_file.write(line)
        return tmp_file_path

    @staticmethod
    def set_ca_bundle_path():
        """Point ``REQUESTS_CA_BUNDLE`` at a temp copy of botocore's vendored ``cacert.pem``.

        If the resource cannot be found, the environment is left unchanged
        instead of assigning ``None`` (not a valid environment value).
        """
        ca_bundle_path = EKSHelper.extract_file_from_jar("botocore/vendored/requests/cacert.pem")
        if ca_bundle_path:
            os.environ['REQUESTS_CA_BUNDLE'] = ca_bundle_path

    @staticmethod
    def is_starts_with_name(property_value):
        """Return True if ``property_value`` starts with 'name:' (case-insensitive); False for empty/None."""
        return property_value.lower().startswith(EKSHelper._NAME_PREFIX) if property_value else False

    @staticmethod
    def get_property_name(property_name):
        """Return ``property_name`` with its leading 'name:' prefix (5 characters) removed."""
        return property_name[len(EKSHelper._NAME_PREFIX):]

    @staticmethod
    def is_success(response):
        """Return True when ``response['ResponseMetadata']['HTTPStatusCode']`` is in 200..299."""
        return 299 >= response['ResponseMetadata']['HTTPStatusCode'] >= 200

    @staticmethod
    def remove_none_keys(dict):
        """Return a new dict without the entries whose value is None."""
        # The parameter name shadows the builtin but is kept for keyword-argument callers.
        return {k: v for k, v in dict.items() if v is not None}

    @staticmethod
    def remove_empty_and_none_values(dict):
        """Return a new dict without None values and without empty lists/sets.

        Other falsy values (``''``, ``0``, ``{}``, empty tuples) are kept.
        """
        return {k: v for k, v in dict.items() if (bool(v) if isinstance(v, (list, set)) else v is not None)}

    @staticmethod
    def _retry_counter_key(counter_name_suffix):
        """Return the context attribute name that stores the retry counter for ``counter_name_suffix``."""
        return "current_retry_{0}".format(counter_name_suffix)

    @staticmethod
    def get_current_retry_count(context, counter_name_suffix):
        """Return the stored retry count for ``counter_name_suffix``.

        ``context`` must provide ``getAttribute(name)``. A missing or falsy
        stored value is reported as 1.
        """
        current_retry_count = context.getAttribute(EKSHelper._retry_counter_key(counter_name_suffix))
        return current_retry_count if current_retry_count else 1

    @staticmethod
    def increment_retry_counter(context, counter_name_suffix):
        """Store the current retry count for ``counter_name_suffix`` plus one."""
        current_retry_count = EKSHelper.get_current_retry_count(context, counter_name_suffix)
        EKSHelper.set_current_retry_count(context, counter_name_suffix, current_retry_count + 1)

    @staticmethod
    def set_current_retry_count(context, counter_name_suffix, current_retry_count):
        """Store ``current_retry_count`` via ``context.setAttribute`` under the counter's key."""
        context.setAttribute(EKSHelper._retry_counter_key(counter_name_suffix), current_retry_count)

    def retry_or_fail(self, context, subject, max_retry_count, fail_message, wait_message):
        """Request another retry, or fail once ``max_retry_count`` is exceeded.

        The counter is keyed by ``"<subject>_stopped"`` and starts at 1.

        :return: ``"RETRY"`` after printing ``wait_message`` and incrementing the counter.
        :raises RuntimeError: when the current count is greater than ``max_retry_count``.
        """
        counter_name_suffix = "{0}_stopped".format(subject)
        retry_count = self.get_current_retry_count(context, counter_name_suffix)
        if retry_count > max_retry_count:
            raise RuntimeError("Reached maximum limit of {0} retries. {1}"
                               .format(max_retry_count, fail_message))
        # Call form of print: same output as the Python 2 print statement.
        print("{0} Done with retry {1}".format(wait_message, retry_count))
        self.increment_retry_counter(context, counter_name_suffix)
        return "RETRY"
