#
# Copyright (c) 2018. All rights reserved.
#
# This software and all trademarks, trade names, and logos included herein are the property of XebiaLabs, Inc. and its affiliates, subsidiaries, and licensors.
#

from kubernetes import client
from kubernetes.client.api_client import ApiClient
from xld.kubernetes.pod.pod_helper import PodHelper
from xld.kubernetes.persistent_volume_claim.helper import PVCHelper
import time


class StatefulSetHelper(object):
    """Helpers that translate a deployed StatefulSet item into Kubernetes API data.

    The deployed items passed to these methods are plain attribute holders
    (``name``, ``labels``, ``matchLabels``, ...); only the attributes read by
    each method are required.
    """

    def __init__(self):
        self.__pod_helper = PodHelper()

    def read_statefulset(self, deployed_statefulset):
        """Build the serialized (dict) StatefulSet for ``deployed_statefulset``.

        The pod template is produced by ``PodHelper.read_pod``; its result is
        indexed like a mapping here, so it is expected to already be in
        serialized form. The spec and the top-level object are serialized with
        ``ApiClient.sanitize_for_serialization`` and then joined.

        :param deployed_statefulset: deployed item describing the StatefulSet.
        :return: dict with the StatefulSet metadata and a ``spec`` entry.
        """
        statefulset = client.V1beta2StatefulSet()
        statefulset.metadata = client.V1ObjectMeta(name=self.get_statefulset_name(deployed_statefulset))

        template = self.__pod_helper.read_pod(deployed_pod=deployed_statefulset, pod=client.V1PodTemplateSpec())
        # The pod template carries the same name as the StatefulSet itself.
        template['metadata']['name'] = statefulset.metadata.name
        spec = client.V1beta2StatefulSetSpec(
            template=template,
            service_name=deployed_statefulset.serviceName,
            replicas=deployed_statefulset.replicasCount,
            pod_management_policy=deployed_statefulset.podManagementPolicy,
            revision_history_limit=deployed_statefulset.revisionHistoryLimit,
            volume_claim_templates=self.__read_volume_claim_templates(deployed_statefulset),
            selector=self.__read_selector(deployed_statefulset),
        )
        # The update strategy is only set when a strategy type was configured.
        if deployed_statefulset.strategyType:
            spec.update_strategy = self.__read_statefulset_strategy(deployed_statefulset)

        # One client instance is enough for both serialization calls.
        api_client = ApiClient()
        serialized_spec = api_client.sanitize_for_serialization(spec)
        serialized_statefulset = api_client.sanitize_for_serialization(statefulset)
        serialized_statefulset['spec'] = serialized_spec
        return serialized_statefulset

    @staticmethod
    def update_modifytime_label(statefulset):
        """Stamp the pod template labels with the current epoch time (seconds).

        :param statefulset: serialized StatefulSet dict; must contain
            ``spec.template.metadata``. A missing ``labels`` mapping is created.
        """
        template_metadata = statefulset['spec']['template']['metadata']
        # Create the labels mapping on demand instead of failing with KeyError
        # when the template was serialized without any labels.
        labels = template_metadata.setdefault('labels', {})
        labels['modifytime'] = str(int(time.time()))

    @staticmethod
    def get_statefulset_name(deployed):
        """Return ``deployed.statefulSetName`` if set, otherwise ``deployed.name``."""
        return deployed.statefulSetName if deployed.statefulSetName else deployed.name

    @staticmethod
    def enrich_app_selectors(deployed_statefulset):
        """Ensure both ``labels`` and ``matchLabels`` contain an ``app`` entry.

        When the ``app`` key is absent (or the mapping is empty/``None``) the
        attribute is replaced by a new dict holding the existing entries plus
        ``app`` set to the resolved StatefulSet name. Existing ``app`` values
        are left untouched.
        """
        app_name = StatefulSetHelper.get_statefulset_name(deployed_statefulset)

        labels = deployed_statefulset.labels
        if not labels or 'app' not in labels:
            # ``labels or {}`` guards against None, which dict() cannot copy.
            deployed_statefulset.labels = dict(labels or {}, app=app_name)

        match_labels = deployed_statefulset.matchLabels
        if not match_labels or 'app' not in match_labels:
            deployed_statefulset.matchLabels = dict(match_labels or {}, app=app_name)

    @staticmethod
    def __read_match_expression(deployed_match_expression):
        """Convert one deployed match expression to a ``V1LabelSelectorRequirement``.

        ``values`` is only assigned when ``matchValues`` is non-empty.
        """
        match_expression = client.V1LabelSelectorRequirement(key=deployed_match_expression.key,
                                                             operator=deployed_match_expression.operator)
        if deployed_match_expression.matchValues:
            match_expression.values = deployed_match_expression.matchValues
        return match_expression

    @staticmethod
    def __read_selector(deployed_statefulset):
        """Build the ``V1LabelSelector`` from ``matchExpressions`` and ``matchLabels``.

        Either part is left unset when the corresponding attribute is empty.
        """
        selector = client.V1LabelSelector()
        if deployed_statefulset.matchExpressions:
            selector.match_expressions = [
                StatefulSetHelper.__read_match_expression(deployed_match_expression)
                for deployed_match_expression in deployed_statefulset.matchExpressions
            ]
        if deployed_statefulset.matchLabels:
            selector.match_labels = deployed_statefulset.matchLabels

        return selector

    @staticmethod
    def __read_statefulset_strategy(deployed_statefulset):
        """Build the update strategy for the configured ``strategyType``.

        The ``partition`` value is only applied for the ``RollingUpdate`` type.
        """
        strategy = client.V1beta2StatefulSetUpdateStrategy()
        if "RollingUpdate" == deployed_statefulset.strategyType:
            strategy.rolling_update = client.V1beta2RollingUpdateStatefulSetStrategy()
            strategy.rolling_update.partition = deployed_statefulset.partition
        strategy.type = deployed_statefulset.strategyType
        return strategy

    @staticmethod
    def __read_volume_claim_templates(deployed_statefulset):
        """Return one PVC definition (via ``PVCHelper.read_pvc``) per volume claim template."""
        return [
            PVCHelper(volume_claim_template).read_pvc()
            for volume_claim_template in deployed_statefulset.volumeClaimTemplates
        ]

    @staticmethod
    def validate_statefulset(deployed_statefulset):
        """Check the deployed StatefulSet configuration.

        :raises RuntimeError: when the minimum pod count exceeds the replica
            count, when both selectors are empty, when labels are empty, or
            when no volume claim templates are defined.
        """
        if deployed_statefulset.minimumPodCount > deployed_statefulset.replicasCount:
            raise RuntimeError("Minimum pod count should not be more than replicas count.")

        if not deployed_statefulset.matchExpressions and not deployed_statefulset.matchLabels:
            raise RuntimeError("Match Expressions or Match Labels should not be empty.")

        if not deployed_statefulset.labels:
            raise RuntimeError("Labels should not be empty.")

        if not deployed_statefulset.volumeClaimTemplates:
            raise RuntimeError("Volume Claim Templates should not be empty.")

    @staticmethod
    def verify_stateful_set_ready_on_create(response_statefulset, statefulset_name, minimum_pod_count):
        """Tell whether the caller still has to wait for pods to become ready.

        :param response_statefulset: API response exposing ``status.ready_replicas``.
        :param statefulset_name: name used in the progress message.
        :param minimum_pod_count: number of ready replicas required.
        :return: ``True`` while fewer than ``minimum_pod_count`` replicas are
            ready (a waiting message is printed), ``False`` otherwise.
        """
        # ``ready_replicas`` may be None when no replica is ready yet; treat it
        # as zero so the comparison is well defined on every Python version.
        ready_replicas = response_statefulset.status.ready_replicas or 0
        if ready_replicas < minimum_pod_count:
            print("Waiting for StatefulSet's pod {0} to be in running state".format(statefulset_name))
            return True
        return False

