import functools

import yaml
import json
import time
import pytest
from functools import wraps

from si_tests import settings
from si_tests import logger
from si_tests.utils import utils, templates, waiters
from si_tests.utils.log_step import parse_test_doc
from kubernetes.client.rest import ApiException


LOG = logger.logger


def create(cluster, machine):
    """Build the deletion-policy check matching the machine's configured policy.

    The policy is read from ``spec.providerSpec.value.deletionPolicy`` of the
    machine data. A missing policy is treated as ``graceful``.

    :param cluster: cluster object passed through to the check
    :param machine: machine whose deletion is going to be checked
    :returns: a Graceful/Unsafe/Forced ``*MachineDeletionPolicyCheck`` instance
    :raises RuntimeError: if the policy value is not recognised
    """
    provider_value = machine.data['spec']['providerSpec']['value']
    policy = provider_value.get('deletionPolicy')

    if policy in (None, 'graceful'):
        check_cls = GracefulMachineDeletionPolicyCheck
    elif policy == 'unsafe':
        check_cls = UnsafeMachineDeletionPolicyCheck
    elif policy == 'forced':
        check_cls = ForcedMachineDeletionPolicyCheck
    else:
        raise RuntimeError(f"Unknown deletion policy {policy}")

    return check_cls(cluster, machine)


class BaseMachineDeletionPolicyCheck:
    """Context manager preparing and cleaning up fixtures for a deletion policy check.

    Entering the context calls ``_init()``. If that fails, whatever was already
    created is cleaned up and the error is re-raised. Leaving the context calls
    ``_cleanup()``.

    ``_init()`` only does something when the machine's Kubernetes node name is
    known and ``settings.CHECK_WORKLOAD_LOCKS_ON_MACHINE_DELETE`` is enabled.
    In that case it creates:

    - a ClusterWorkloadLock for controller ``test``;
    - a NodeWorkloadLock for controller ``test`` with
      ``nodeDeletionRequestSupported: True``;
    - a NodeWorkloadLock for controller ``test-nolock`` with
      ``nodeDeletionRequestSupported: False``.

    Each lock's status is then patched to ``active``.

    Subclasses override ``check()`` with the policy-specific assertions. The
    base implementation does nothing.
    """

    CLUSTER_WORKLOAD_LOCK_TEMPLATE = """
apiVersion: lcm.mirantis.com/v1alpha1
kind: ClusterWorkloadLock
metadata:
    name: {controllerName}
spec:
    controllerName: {controllerName}
"""

    NODE_WORKLOAD_LOCK_TEMPLATE = """
apiVersion: lcm.mirantis.com/v1alpha1
kind: NodeWorkloadLock
metadata:
  labels:
    controller-name: {controllerName}
  name: {controllerName}-{nodeName}
spec:
  controllerName: {controllerName}
  nodeName: {nodeName}
  nodeDeletionRequestSupported: {nodeDeletionRequestSupported}
"""

    def __init__(self, cluster, machine, base_image_repo=None):
        """Capture the cluster, machine name and node name for later checks.

        :param cluster: cluster object providing ``k8sclient`` and ``check`` helpers
        :param machine: machine whose deletion is going to be checked
        :param base_image_repo: image registry for test workloads; defaults to
            ``cluster.determine_mcp_docker_registry()``
        """
        self.cluster = cluster
        self.base_image_repo = base_image_repo or cluster.determine_mcp_docker_registry()
        self.machine_name = machine.name
        self.check_workload_locks = settings.CHECK_WORKLOAD_LOCKS_ON_MACHINE_DELETE
        try:
            self.node_name = machine.get_k8s_node().read().metadata.labels['kubernetes.io/hostname']
        except Exception:
            # Best effort: if the node can't be resolved, node_name stays None and
            # every node-dependent setup step and check is skipped.
            self.node_name = None
        # Lock objects created by _init(); None until created.
        self.cwl = None
        self.nwl = None
        self.nwl_nolock = None

    @staticmethod
    def _activate_lock(lock):
        """Patch a workload lock's status to ``active``."""
        lock.patch(update_status=True, body={
            'status': {
                'state': 'active'
            }
        })

    def _node_lock_body(self, controller_name, deletion_request_supported):
        """Render NODE_WORKLOAD_LOCK_TEMPLATE for ``self.node_name`` and parse it to a dict."""
        manifest = self.NODE_WORKLOAD_LOCK_TEMPLATE.format(
            controllerName=controller_name,
            nodeName=self.node_name,
            nodeDeletionRequestSupported=deletion_request_supported,
        )
        return yaml.load(manifest, Loader=yaml.SafeLoader)

    def _init(self):
        """Create and activate the workload locks described in the class docstring.

        Each lock is stored on ``self`` right after creation and before it is
        patched, so ``_cleanup()`` can still delete it if activation fails.
        """
        if not self.node_name:
            return

        if not self.check_workload_locks:
            return

        self.cwl = self.cluster.k8sclient.clusterworkloadlocks.create(
            body=yaml.load(
                self.CLUSTER_WORKLOAD_LOCK_TEMPLATE.format(controllerName='test'),
                Loader=yaml.SafeLoader))
        self._activate_lock(self.cwl)

        self.nwl = self.cluster.k8sclient.nodeworkloadlocks.create(
            body=self._node_lock_body('test', True))
        self._activate_lock(self.nwl)

        self.nwl_nolock = self.cluster.k8sclient.nodeworkloadlocks.create(
            body=self._node_lock_body('test-nolock', False))
        self._activate_lock(self.nwl_nolock)

    def _cleanup(self):
        """Delete every workload lock created by ``_init()``.

        Deletion is attempted for every lock even if an earlier one fails, so
        one failure does not leak the rest. A lock that is already gone
        (HTTP 404) is ignored. The first other error is re-raised once all
        deletions have been attempted.
        """
        first_error = None
        for lock in (self.nwl, self.nwl_nolock, self.cwl):
            if not lock:
                continue
            try:
                lock.delete()
            except ApiException as e:
                if e.status != 404 and first_error is None:
                    first_error = e
            except Exception as e:
                if first_error is None:
                    first_error = e
        if first_error is not None:
            raise first_error

    def __enter__(self):
        try:
            self._init()
        except Exception:
            # Roll back whatever _init() managed to create before failing.
            self._cleanup()
            raise

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._cleanup()

    def check(self):
        """Policy-specific verification; no-op in the base class."""
        pass


class GracefulMachineDeletionPolicyCheck(BaseMachineDeletionPolicyCheck):
    """Check for the ``graceful`` machine deletion policy.

    Entering the context first creates the base workload locks. It then starts
    a dummy pod on the machine's node (rendered from
    ``settings.DUMMY_TEST_POD_YAML``) and waits for the pod to reach phase
    ``Succeeded``.

    ``check()`` then does three things:

    1. verifies that deletion is held while the NodeWorkloadLock is active;
    2. releases the lock;
    3. waits for the pod to be drained and prepare-deletion to reach
       ``completed``.
    """

    def __init__(self, cluster, machine, base_image_repo=None):
        super().__init__(cluster, machine, base_image_repo)
        # Dummy pod created by _init(); None until created.
        self.pod = None

    def _init(self):
        """Create the base locks, then start the dummy pod on the machine's node."""
        super()._init()

        if not self.node_name:
            return

        pod_name = f"dummy-pod-{utils.gen_random_string(6)}"
        template = templates.render_template(settings.DUMMY_TEST_POD_YAML,
                                             {'POD_NAME': pod_name, 'NODE_NAME': self.node_name,
                                              'IMAGE_BASE_REPO': self.base_image_repo})
        # NOTE(review): the JSON round-trip presumably normalizes the parsed YAML
        # into plain JSON types for the client — confirm it is still needed.
        json_body = json.dumps(yaml.load(template, Loader=yaml.SafeLoader))

        LOG.info(f"Create pod on {self.node_name}")
        # NOTE(review): name='test-pod' differs from POD_NAME rendered into the
        # manifest; presumably the manifest name is what gets used — confirm.
        self.pod = self.cluster.k8sclient.pods.create(name='test-pod', body=json.loads(json_body))
        self.pod.wait_phase('Succeeded')

    def _cleanup(self):
        """Delete the workload locks (base class), then the dummy pod.

        The pod is deleted even if the base cleanup raises. A pod that is
        already gone ("Not Found") is ignored.
        """
        try:
            super()._cleanup()
        finally:
            if self.pod and self.pod.exists():
                try:
                    self.pod.delete()
                except ApiException as e:
                    if e.reason != "Not Found":
                        raise

    def check(self):
        """Run the full graceful-deletion verification after the machine delete was issued."""
        self.check_delete_locked()
        self.release_lock()
        self.wait_machine_prepared_to_delete()

    def check_delete_locked(self):
        """Verify that deletion is held by the active NodeWorkloadLock.

        This waits for a NodeDeletionRequest for the node: 3600s on the
        ``equinixmetalv2`` provider, 900s otherwise. It then polls 4 times,
        15s apart. On each poll it asserts the dummy pod still exists and waits
        for the machine's prepare-deletion phase to be ``started``.

        Skipped when the node is unknown or no NodeWorkloadLock was created.
        """
        if not self.node_name:
            return

        if not self.nwl:
            return

        LOG.info("Wait until NodeDeletionRequest created")
        timeout = 900
        if self.cluster.provider == utils.Provider.equinixmetalv2:
            timeout = 3600
        self.cluster.check.wait_nodedeletionrequest_created(self.node_name, timeout=timeout)

        LOG.info("Check prepare deletion not started for 1m")
        for _ in range(4):
            LOG.info(f"Check pod exists {self.pod.name}")
            assert self.pod.exists(), \
                "Pod have to exists. Node prepare deletion locked"
            self.cluster.check.wait_machine_prepare_deletion_phase(self.machine_name, expected_phase="started")
            time.sleep(15)

    def release_lock(self):
        """Patch the NodeWorkloadLock status to ``inactive`` (no-op if it was not created)."""
        if not self.node_name:
            return

        if not self.nwl:
            return

        self.nwl.patch(update_status=True, body={
            'status': {
                'state': 'inactive'
            }
        })

    def wait_machine_prepared_to_delete(self):
        """Wait for the dummy pod to disappear, then for prepare-deletion ``completed``.

        The pod wait polls every 30s for up to 1800s.
        """
        if not self.node_name:
            return

        pod_name = self.pod.name

        def check_pod_exists():
            exists = self.pod.exists()
            LOG.info(f"Pod {pod_name} exists={exists}")
            return exists

        LOG.info(f"Waiting pod {pod_name} to be drained")
        waiters.wait(lambda: not check_pod_exists(),
                     timeout=1800,
                     interval=30,
                     timeout_msg=f"Timeout for waiting pod {pod_name} to be drained")

        self.cluster.check.wait_machine_prepare_deletion_phase(self.machine_name, "completed", interval=30)

    def wait_machine_delete_aborted(self):
        """Wait until an aborted graceful deletion has rolled back.

        First waits for this machine's prepare-deletion phase to become
        ``aborting``, then for it to become the empty string.
        """
        # The machine name must be passed first, matching every other
        # wait_machine_prepare_deletion_phase call in this class.
        self.cluster.check.wait_machine_prepare_deletion_phase(self.machine_name, 'aborting')
        self.cluster.check.wait_machine_prepare_deletion_phase(self.machine_name, '')


class UnsafeMachineDeletionPolicyCheck(BaseMachineDeletionPolicyCheck):
    """Check for the ``unsafe`` machine deletion policy.

    With this policy the cluster is expected to have no NodeDeletionRequest.
    The constructor is inherited unchanged from the base class.
    """

    def check(self):
        """Assert that ``cluster.get_nodedeletionrequests()`` is empty.

        Skipped when the machine's node name could not be resolved.
        """
        if self.node_name:
            requests = self.cluster.get_nodedeletionrequests()
            assert len(requests) == 0, \
                "Unexpected NodeDeletionRequest for unsafe machine deletion policy"


class ForcedMachineDeletionPolicyCheck(BaseMachineDeletionPolicyCheck):
    """Check for the ``forced`` machine deletion policy.

    With this policy the cluster is expected to have no NodeDeletionRequest.
    The constructor is inherited unchanged from the base class.
    """

    def check(self):
        """Assert that ``cluster.get_nodedeletionrequests()`` is empty.

        Skipped when the machine's node name could not be resolved.
        """
        if self.node_name:
            requests = self.cluster.get_nodedeletionrequests()
            assert len(requests) == 0, \
                "Unexpected NodeDeletionRequest for forced machine deletion policy"


def subtest(func):
    """Decorate a sub-test scenario function with banner and step logging.

    The wrapped function is called as ``func(*args, show_step, **kwargs)``.
    The wrapper injects a ``show_step(step_num)`` callable right after the
    caller's positional arguments.

    ``show_step(step_num)`` logs the ``step_num``-th (1-based) step that
    ``parse_test_doc`` extracts from ``func``'s docstring. It logs an error
    instead if there is no such step.

    Before each call, a banner with the function name and its docstring (each
    line prefixed with ``### ``) is logged. The wrapped function's return value
    is passed back to the caller.
    """
    def docstring_printer(func):
        @wraps(func)
        def _docstring_printer(*args, **kwargs):
            head = "#" * 30 + "[ SUBTEST: {} ]" + "#" * 30
            head = head.format(func.__name__)
            docstring = func.__doc__
            info = "\n{head}".format(head=head)
            if docstring and isinstance(docstring, str):
                docstring = '\n'.join([f"### {line}"
                                       for line in docstring.splitlines()])
                info = f"{info}\n{docstring}"
            LOG.info(info)
            return func(*args, **kwargs)

        return _docstring_printer

    def show_substep_printer(func):
        def show_substep(step_num):
            test_case_steps = parse_test_doc(func.__doc__)['steps']
            # Explicit range check: a plain index would accept 0 or negative
            # numbers and silently log a step counted from the end.
            if 1 <= step_num <= len(test_case_steps):
                LOG.info("\n\n*** [SUBSTEP: {0}] {1} ***".format(
                    step_num,
                    test_case_steps[step_num - 1]))
            else:
                LOG.error("Can't show step #{0}: docstring for method {1} does't "
                          "contain it!".format(step_num, func.__name__))

        @wraps(func)
        def _show_substep_printer(*args, **kwargs):
            return func(*args, show_substep, **kwargs)

        return _show_substep_printer

    return docstring_printer(show_substep_printer(func))


@subtest
def check_machine_delete_without_policy(cluster, machine, show_step, check_deletion_policy=False,
                                        wait_deletion_timeout=1200):
    """Delete machine.

    Scenario:
        1. Delete machine
        2. Wait until machine deleted
    """
    # NOTE: the docstring above is parsed at runtime by @subtest for step
    # titles; keep its "Scenario" list in sync with the show_step() calls.

    # Capture identifiers up front; they are used after the machine is deleted.
    name = machine.name
    k8s_node = machine.get_k8s_node_name()
    poll_interval = 30

    show_step(1)
    # Legacy (non-"new") delete API, without any deletion-policy handling.
    machine.delete(check_deletion_policy=check_deletion_policy, new_delete_api=False)

    show_step(2)
    cluster.wait_machine_deletion(name, interval=poll_interval,
                                  retries=wait_deletion_timeout / poll_interval)
    cluster.check.check_deleted_node(k8s_node)


@subtest
def check_machine_deletion_policy(cluster, machine, show_step, deletion_policy,
                                  wait_deletion_timeout=1200, check_deleted_node=True):
    """Delete machine, check deletion policy.

    Scenario:
        1. Run delete
        2. Check deletion policy
        3. Wait until machine deleted
    """
    # NOTE: the docstring above is parsed at runtime by @subtest for step
    # titles; keep its "Scenario" list in sync with the show_step() calls.
    is_forced = deletion_policy == "forced"

    if not is_forced:
        machine.set_deletion_policy(deletion_policy)
    else:
        # Setting 'forced' before deletion starts is expected to be rejected
        # with 400 Bad Request; it is applied after delete() below instead.
        LOG.info("Check forced policy forbidden to set initially")
        with pytest.raises(ApiException) as exc_info:
            machine.set_deletion_policy('forced')
        rejection = exc_info.value
        assert rejection.status == 400
        assert rejection.reason == 'Bad Request'

    # The check class is chosen from the machine's current deletionPolicy.
    with create(cluster, machine) as policy_check:
        machine_name = machine.name
        node_name = machine.get_k8s_node_name()

        show_step(1)
        machine.delete(check_deletion_policy=False, new_delete_api=True)
        if is_forced:
            machine.set_deletion_policy(deletion_policy)

        show_step(2)
        policy_check.check()

        show_step(3)
        retries = wait_deletion_timeout / 30
        cluster.wait_machine_deletion(machine_name, interval=30, retries=retries)
        if check_deleted_node:
            cluster.check.check_deleted_node(node_name)


@subtest
def check_machine_graceful_delete_with_abort(cluster, machine, show_step, wait_deletion_timeout=1200):
    """Delete machine, check deletion policy.

    Scenario:
        1. Run graceful delete
        2. Wait until delete locked by NodeWorkloadLock
        3. Abort graceful delete
        4. Wait until all machines and cluster ready
        5. Run graceful delete
        6. Release NodeWorkloadLock
        7. Wait until machine deleted
    """
    # NOTE: the docstring above is parsed at runtime by @subtest for step
    # titles; keep its "Scenario" list in sync with the show_step() calls.
    # The NodeWorkloadLock created on enter stays active through the abort and
    # the second delete; it is only released in step 6.

    with GracefulMachineDeletionPolicyCheck(cluster, machine) as policy_check:
        # Capture identifiers up front; they are used after the machine is deleted.
        machine_name = machine.name
        node_name = machine.get_k8s_node_name()

        show_step(1)
        machine.set_deletion_policy('graceful')
        machine.delete(check_deletion_policy=False, new_delete_api=True)

        show_step(2)
        policy_check.check_delete_locked()

        show_step(3)
        machine.abort_graceful_deletion()
        policy_check.wait_machine_delete_aborted()

        show_step(4)
        # Wait until machines are Ready
        cluster.check.check_machines_status(timeout=1800)
        cluster.check.check_cluster_nodes()
        cluster.check.check_k8s_nodes()

        # Check/wait for correct docker service replicas in cluster
        ucp_worker_agent_name = cluster.check.get_ucp_worker_agent_name()
        cluster.check.check_actual_expected_docker_services(
            changed_after_upd={'ucp-worker-agent-x': ucp_worker_agent_name})
        cluster.check.check_k8s_pods()
        cluster.check.check_actual_expected_pods(timeout=3200)
        cluster.check.check_cluster_readiness()
        cluster.check.check_diagnostic_cluster_status()

        show_step(5)
        # Re-issue the delete; the policy is not set again here, presumably it
        # is still 'graceful' from step 1 — confirm abort does not reset it.
        machine.delete(check_deletion_policy=False, new_delete_api=True)
        policy_check.check_delete_locked()

        show_step(6)
        policy_check.release_lock()

        show_step(7)
        policy_check.wait_machine_prepared_to_delete()
        # Retry count is derived from the timeout at a 30s polling interval.
        cluster.wait_machine_deletion(machine_name, interval=30, retries=wait_deletion_timeout/30)
        cluster.check.check_deleted_node(node_name)


@subtest
def check_machine_graceful_delete_stuck_complete_with_unsafe(cluster, machine, show_step):
    """Delete machine, check deletion policy.

    Scenario:
        1. Run graceful delete
        2. Wait until delete locked by NodeWorkloadLock
        3. Run unsafe delete
        4. Wait until machine deleted
    """
    # NOTE: the docstring above is parsed at runtime by @subtest for step
    # titles; keep its "Scenario" list in sync with the show_step() calls.

    with GracefulMachineDeletionPolicyCheck(cluster, machine) as graceful_check:
        # Capture identifiers up front; they are used after the machine is deleted.
        name = machine.name
        k8s_node = machine.get_k8s_node_name()

        show_step(1)
        machine.set_deletion_policy('graceful')
        machine.delete(check_deletion_policy=False, new_delete_api=True)

        show_step(2)
        # Deletion must be held by the still-active NodeWorkloadLock.
        graceful_check.check_delete_locked()

        show_step(3)
        # Switching to 'unsafe' is expected to let deletion finish without
        # the lock ever being released.
        machine.set_deletion_policy('unsafe')

        show_step(4)
        cluster.wait_machine_deletion(name, retries=60)
        cluster.check.check_deleted_node(k8s_node)


@subtest
def check_machine_graceful_delete_stuck_complete_with_forced(cluster, machine, show_step):
    """Delete machine, check deletion policy.

    Scenario:
        1. Check forced policy forbidden to set initially
        2. Run graceful delete
        3. Wait until delete locked by NodeWorkloadLock
        4. Run forced delete
        5. Wait until machine deleted
    """
    # NOTE: the docstring above is parsed at runtime by @subtest for step
    # titles; keep its "Scenario" list in sync with the show_step() calls.

    show_step(1)
    # Setting 'forced' before deletion starts must be rejected with 400.
    with pytest.raises(ApiException) as exc_info:
        machine.set_deletion_policy('forced')
    rejection = exc_info.value
    assert rejection.status == 400
    assert rejection.reason == 'Bad Request'

    with GracefulMachineDeletionPolicyCheck(cluster, machine) as graceful_check:
        # Capture identifiers up front; they are used after the machine is deleted.
        name = machine.name
        k8s_node = machine.get_k8s_node_name()

        show_step(2)
        machine.set_deletion_policy('graceful')
        machine.delete(check_deletion_policy=False, new_delete_api=True)

        show_step(3)
        # Deletion must be held by the still-active NodeWorkloadLock.
        graceful_check.check_delete_locked()

        show_step(4)
        # Once deletion is in progress 'forced' is accepted and is expected to
        # let deletion finish without the lock ever being released.
        machine.set_deletion_policy('forced')

        show_step(5)
        cluster.wait_machine_deletion(name, retries=60)
        cluster.check.check_deleted_node(k8s_node)


# Public entry points, one per deletion policy.
# - Workload-lock checks enabled: each maps to a scenario that first gets a
#   graceful deletion held by a NodeWorkloadLock.
# - Otherwise: each is check_machine_deletion_policy bound to that policy.
if settings.CHECK_WORKLOAD_LOCKS_ON_MACHINE_DELETE:
    check_machine_graceful_delete = check_machine_graceful_delete_with_abort
    check_machine_unsafe_delete = check_machine_graceful_delete_stuck_complete_with_unsafe
    check_machine_forced_delete = check_machine_graceful_delete_stuck_complete_with_forced
else:
    check_machine_graceful_delete = functools.partial(check_machine_deletion_policy, deletion_policy='graceful')
    check_machine_unsafe_delete = functools.partial(check_machine_deletion_policy, deletion_policy='unsafe')
    check_machine_forced_delete = functools.partial(check_machine_deletion_policy, deletion_policy='forced')
