#    Copyright 2025 Mirantis, Inc.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.
#
# This module provides common helpers for the built-in k8s client. They do not depend on the parent of the
# K8sCluster object and can be used anywhere; the only requirement is a K8sCluster object.
import re
import base64

from si_tests import logger
from si_tests.utils import waiters, exceptions
from si_tests import settings
from kubernetes.client.rest import ApiException

from typing import TYPE_CHECKING, Optional, List, Tuple

if TYPE_CHECKING:
    from si_tests.clients.k8s import K8sCluster

LOG = logger.logger


def __get_replica_number(client: "K8sCluster", obj, namespace):
    """Return the desired replica count of a pod owner object.

    :param client: K8sCluster client
    :param obj: owner reference rendered as "<Kind>/<name>"
        (e.g. "StatefulSet/mariadb-server")
    :param namespace: namespace of the owner object
    :return: ``desired_replicas`` of the owner for ReplicaSet, StatefulSet
        and DaemonSet owners; 0 (with a warning) for any other kind
    """
    owner_kind, owner_name = obj.split('/')
    # Map the owner kind to the client attribute holding that resource type
    collection_attr = {
        'ReplicaSet': 'replicasets',
        'StatefulSet': 'statefulsets',
        'DaemonSet': 'daemonsets',
    }.get(owner_kind)
    if collection_attr is None:
        LOG.warning(f"Unknown kind {owner_kind}")
        return 0
    owner = getattr(client, collection_attr).get(name=owner_name, namespace=namespace)
    return owner.desired_replicas


def _pod_names(pods):
    """Return the set of bare pod names (used only for filter logging)."""
    return {pod['metadata']['name'] for pod in pods}


def _exclude_job_pods(pods):
    """Drop pods whose first owner reference is a Job."""
    kept = [pod for pod in pods
            if not (pod['metadata']['owner_references'] and
                    pod['metadata']['owner_references'][0]['kind'] == 'Job')]
    LOG.debug(f"These pods are jobs and will be filtered out: "
              f"{_pod_names(pods) - _pod_names(kept)}")
    return kept


def _exclude_pods_on_removed_nodes(client: "K8sCluster", pods):
    """Drop pods bound (via node_name or hostname node selector) to nodes absent from the cluster."""
    # After cluster scale down, some pods may be still mapped on the
    # removed nodes. Garbage collector may clean these pods in few hours
    # but we can ignore such pods here
    node_names = {node.name for node in client.nodes.list_all()}
    kept = []
    filtered_nodes = set()
    for pod in pods:
        pod_name = f"{pod['metadata']['namespace']}/{pod['metadata']['name']}"

        node_name = pod['spec']['node_name']
        if node_name is not None and node_name not in node_names:
            LOG.warning(f"Pod {pod_name} node name is bound to non-existing node {node_name}, "
                        f"skip check the pod")
            filtered_nodes.add(node_name)
            continue

        node_selector = pod['spec']['node_selector'] or {}
        node_selector_hostname = node_selector.get('kubernetes.io/hostname')
        if node_selector_hostname is not None and node_selector_hostname not in node_names:
            LOG.warning(f"Pod {pod_name} node selector is bound to non-existing hostname "
                        f"{node_selector_hostname}, skip check the pod")
            filtered_nodes.add(node_selector_hostname)
            continue

        kept.append(pod)

    ignored = _pod_names(pods) - _pod_names(kept)
    if ignored:
        LOG.warning(f"These pods will be ignored because assigned to non-existing nodes "
                    f"{list(filtered_nodes)} : {ignored}")
    return kept


def _exclude_incorrect_affinity_pods(pods):
    """Drop pods that are in phase Failed with reason NodeAffinity."""
    # After changing node labels, some pods may no longer match the node affinity
    kept = [pod for pod in pods
            if not (pod['status'].get('phase', '') == 'Failed'
                    and pod['status'].get('reason') == 'NodeAffinity')]
    ignored = _pod_names(pods) - _pod_names(kept)
    if ignored:
        LOG.warning(f"These pods will be ignored because of wrong NodeAffinity: "
                    f"{ignored}")
    return kept


def compare_objects(client: "K8sCluster",
                    expected_pods_dict,
                    exclude_jobs,
                    exclude_removed_nodes,
                    exclude_incorrect_affinity,
                    check_all_nss):
    """Compare pods present in the cluster with the expected pods list.

    :param client: K8sCluster client
    :param expected_pods_dict: {namespace: {"<name prefix>[/no_owner]": expected count}}.
        For prefixes tagged "/no_owner" the expected count is used as is; otherwise
        the replica count of the owner of the first matching pod is also inspected.
    :param exclude_jobs: ignore pods whose first owner reference is a Job
    :param exclude_removed_nodes: ignore pods bound to nodes that no longer exist
    :param exclude_incorrect_affinity: ignore Failed pods with reason NodeAffinity
    :param check_all_nss: also report pods from namespaces that are not present
        in expected_pods_dict as "Not checked pods"
    :return: None if everything matches, otherwise a dict with
        "Pods mismatch" ({prefix: details}) and "Not checked pods" (names of pods
        not claimed by any expected prefix)
    """
    failed = {}
    namespaces = expected_pods_dict.keys()
    actual_list = list(client.pods.list_raw().to_dict()['items'])

    if exclude_jobs:
        actual_list = _exclude_job_pods(actual_list)
    if exclude_removed_nodes:
        actual_list = _exclude_pods_on_removed_nodes(client, actual_list)
    if exclude_incorrect_affinity:
        actual_list = _exclude_incorrect_affinity_pods(actual_list)

    # special exclusions by pod name prefix
    excluded_pods = ('conformance',)
    actual_list = [x for x in actual_list
                   if not x['metadata']['name'].startswith(excluded_pods)]

    other_ns_pods = [pod for pod in actual_list
                     if pod['metadata']['namespace'] not in namespaces]
    actual_list = [pod for pod in actual_list
                   if pod['metadata']['namespace'] in namespaces]

    # Pods are tracked by (namespace, name): pod names are unique only within
    # a namespace, so tracking bare names would let a pod in one namespace
    # "claim" a same-named pod in another one.
    pod_keys = [(pod['metadata']['namespace'], pod['metadata']['name'])
                for pod in actual_list]
    checked = set()

    for ns in namespaces:
        # Longest prefix first, so that e.g. "foo-bar" claims its pods before
        # "foo" does. The "/no_owner" tag is not counted in the length.
        expected_keys = sorted(
            expected_pods_dict[ns],
            key=lambda x: len(x.split("/")[0]), reverse=True
        )
        for expected_key in expected_keys:
            desired_num = expected_pods_dict[ns][expected_key]
            prefix = expected_key.split("/")[0]
            compare_list = [
                pod for pod in actual_list
                if pod['metadata']['namespace'] == ns
                and pod['metadata']['name'].startswith(prefix)
                and (ns, pod['metadata']['name']) not in checked
            ]

            if not compare_list:
                failed[prefix] = {"actual": 0,
                                  "desired/expected": desired_num}
                continue

            if '/no_owner' in expected_key:
                # for pods that are marked "no_owner" we do not fetch
                # owner and use # of pods from the file
                LOG.debug(f"Number of pods "
                          f"for {prefix} group will be checked "
                          f"according to expected pods list")
            else:
                # get owner kind and name for the first pod in list
                first_pod = compare_list[0]
                owner_references = first_pod['metadata']['owner_references']
                if not owner_references or owner_references[0]['kind'] == "Node":
                    # Pods without owner references, and mirror pods (starting
                    # from 1.17 they have Node/node-name in owner_references),
                    # have no replica count to query.
                    LOG.warning(f"No replica count info for pod "
                                f"{first_pod['metadata']['name']}. "
                                f"Using expected number ({desired_num}) "
                                f"from the list")
                    # Better to handle this type of pods as '/no_owner'
                    LOG.info("Please add '/no_owner' for this group of"
                             " pods to expected pods list file")
                else:
                    kind_name = f"{owner_references[0]['kind']}/" \
                                f"{owner_references[0]['name']}"
                    try:
                        replicas_num = __get_replica_number(
                            client,
                            kind_name,
                            first_pod['metadata']['namespace']
                        )
                        LOG.debug(f"First pod is "
                                  f"{first_pod['metadata']['name']}, "
                                  f"owner: {kind_name}, "
                                  f"replica #: {replicas_num}")
                        if int(replicas_num) != int(desired_num):
                            if kind_name.startswith("DaemonSet"):
                                LOG.warning(f"Replicas num ({replicas_num}) from {kind_name} "
                                            f"is not equal to the "
                                            f"number ({desired_num}) in expected "
                                            f"pod list. Pod: {prefix} . "
                                            f"Assuming DaemonSet replicas as expected pods number")
                                desired_num = int(replicas_num)
                            else:
                                LOG.warning(f"Replicas num ({replicas_num}) from {kind_name} "
                                            f"is not equal to the "
                                            f"number ({desired_num}) in expected "
                                            f"pod list. Pod: {prefix}")
                            # PRODX-30560: DaemonSets may have different replicas
                            # on different labs because of node taints or labels,
                            # so a replica mismatch alone does not fail the check.
                    except (Exception, ApiException) as e:
                        LOG.error(e)
                        LOG.error(f"Cannot process {prefix} "
                                  f"group of pods. Skipping")
                        failed[prefix] = {"actual": 0,
                                          "desired/expected": desired_num}
                        # pods of this group stay in "Not checked pods"
                        continue

            if int(desired_num) != len(compare_list):
                failed[prefix] = {"actual": len(compare_list),
                                  "desired/expected": desired_num}
            checked.update((ns, pod['metadata']['name']) for pod in compare_list)

    # Preserve the original cluster listing order for the report
    not_checked = [name for pod_ns, name in pod_keys if (pod_ns, name) not in checked]

    for pod in other_ns_pods:
        LOG.error(f"Extra pod {pod['metadata']['name']} "
                  f"found in {pod['metadata']['namespace']} "
                  f"namespace")
        if check_all_nss:
            not_checked.append(pod['metadata']['name'])

    if failed or not_checked:
        result = {"Pods mismatch": failed,
                  "Not checked pods": not_checked}
        LOG.warning(f"Compare pod check failed: {result}")
        return result
    return None


def wait_expected_pods(client: "K8sCluster",
                       timeout=settings.CHECK_ACTUAL_EXPECTED_PODS_TIMEOUT,
                       interval=10,
                       expected_pods=None,
                       exclude_jobs=True,
                       exclude_removed_nodes=True,
                       exclude_incorrect_affinity=True,
                       check_all_nss=False):
    """Poll compare_objects until the cluster pods match the expected pods.

    :param client: K8sCluster client
    :param timeout: overall wait time in seconds
    :param interval: polling interval in seconds
    :param expected_pods: dict passed to compare_objects as expected_pods_dict
    :param exclude_jobs: see compare_objects
    :param exclude_removed_nodes: see compare_objects
    :param exclude_incorrect_affinity: see compare_objects
    :param check_all_nss: see compare_objects
    :raises TimeoutError: (builtin) when pods still mismatch once waiters.wait
        times out and a final comparison also fails
    """
    compare_kwargs = dict(
        expected_pods_dict=expected_pods,
        exclude_jobs=exclude_jobs,
        exclude_removed_nodes=exclude_removed_nodes,
        exclude_incorrect_affinity=exclude_incorrect_affinity,
        check_all_nss=check_all_nss,
    )
    try:
        waiters.wait(lambda: not compare_objects(client, **compare_kwargs),
                     timeout=timeout, interval=interval)
    except exceptions.TimeoutError:
        # One more comparison: pods may have converged right after the last poll
        mismatch = compare_objects(client, **compare_kwargs)
        if mismatch:
            raise TimeoutError(f"Timeout waiting for pods. "
                               f"After {timeout}s there are some fails: "
                               f"{mismatch}")
    LOG.info("All pods and their replicas are found")


def get_corefile(client: "K8sCluster", name='coredns', namespace='kube-system'):
    """Return the "Corefile" entry of the CoreDNS configmap.

    :param client: K8sCluster client
    :param name: configmap name
    :param namespace: configmap namespace
    :return: Corefile text read from the configmap data
    """
    return client.configmaps.get(name, namespace).read().data["Corefile"]


def override_coredns_configmap(client: "K8sCluster", corefile, name='coredns', namespace='kube-system'):
    """Replace the Corefile stored in the CoreDNS configmap.

    The configmap wrapper's ``data`` attribute is taken as the patch body and
    its nested ``'data'`` mapping gets the new Corefile (presumably ``data`` is
    the configmap rendered as a dict - confirm against the configmaps wrapper).

    :param client: k8s client
    :param corefile: new Corefile text; must be non-empty
    :param name: configmap name
    :param namespace: configmap namespace
    :return: None
    """
    assert corefile, "Corefile can not be empty"
    LOG.info(f"Overriding Corefile:\n {corefile}")
    target_cm = client.configmaps.get(name, namespace)
    patch_body = target_cm.data
    patch_body['data'].update({"Corefile": corefile})
    target_cm.patch(body=patch_body)


def get_kubeconfig_from_secret(client: "K8sCluster", secret=None, cluster_name='', namespace=settings.KCM_NAMESPACE):
    """Get kubeconfig for the cluster from a secret.

    :param client: k8s Cluster client
    :param secret: optional secret object; its ``name`` is used as the secret
        name when set
    :param cluster_name: Cluster deployment name; when ``secret`` is not given
        (or has an empty name) the secret name defaults to
        ``<cluster_name>-kubeconfig``
    :param namespace: Cluster namespace
    :return: tuple (secret name, kubeconfig text decoded from the base64
        ``value`` key of the secret data as UTF-8)
    """
    # ``secret`` defaults to None, so only dereference it when it was passed
    provided_name = secret.name if secret is not None else None
    secret_name = provided_name or f"{cluster_name}-kubeconfig"
    LOG.debug(f"Fetching kubeconfig of {cluster_name} cluster deployment")
    assert client.secrets.present(secret_name, namespace=namespace), (
        f"Secret {namespace}/{secret_name} not found. May be not populated yet.")
    secret_body = client.secrets.get(secret_name, namespace=namespace).read()
    assert secret_body.data, f"Secret {namespace}/{secret_name} is empty"
    encoded_kubeconfig = secret_body.data.get('value')
    assert encoded_kubeconfig, f"Secret {namespace}/{secret_name} doesn't contain kubeconfig"
    content = base64.b64decode(encoded_kubeconfig).decode("utf-8")
    return secret_name, content


class CorefileEditor:
    """In-memory text editor for a CoreDNS Corefile.

    All edits modify ``self.corefile``; ``get()`` returns the current text and
    ``reset()`` restores the text passed to the constructor.  A zone block is
    located by ``find_zone_block``: it spans from ``<zone> {`` up to the first
    ``}`` placed at column 0, so plugin blocks nested inside the zone are
    expected to close with an indented ``}``.
    """

    def __init__(self, corefile):
        self.original_corefile = corefile
        self.corefile = corefile

    def find_zone_block(self, zone=".:53") -> Optional[re.Match]:
        """Return the match of the zone block, or None if it is not found.

        Groups: 1 - the ``<zone> {`` header, 2 - the zone body,
        3 - the closing newline plus column-0 brace.
        """
        pattern = rf"({re.escape(zone)}\s*\{{)(.*?)(\n\}})"
        return re.search(pattern, self.corefile, flags=re.DOTALL)

    def has_plugin(self, plugin_name, zone=".:53") -> bool:
        """Return True if plugin_name occurs in the zone body (plain substring check)."""
        match = self.find_zone_block(zone)
        if not match:
            return False
        return plugin_name in match.group(2)

    @staticmethod
    def _plugin_block_pattern(plugin_name) -> str:
        """Build a regex matching ``<plugin_name> [args] { ... }`` inside a zone body.

        The match includes the newline preceding the plugin line.  The zone body
        never contains a column-0 ``}`` (find_zone_block stops there), so the
        nested block ends at the first ``}`` preceded only by spaces/tabs on its
        line.  Use with re.DOTALL.
        """
        return rf"\n[ \t]*{re.escape(plugin_name)}\b[^\n{{]*\{{.*?\n[ \t]*\}}"

    def _replace_zone_body(self, match, new_body):
        """Replace exactly the body (group 2) of a find_zone_block match."""
        self.corefile = self.corefile[:match.start(2)] + new_body + self.corefile[match.end(2):]

    def insert_plugin(self, plugin_block, before_plugin="forward", zone=".:53") -> bool:
        """Insert plugin_block into the zone body.

        The block is placed right before the first line starting with
        ``before_plugin``; if there is no such line it is appended at the end of
        the zone body.  plugin_block is inserted literally (no regex escapes).

        :return: True if the Corefile was changed, False if the block is already present
        :raises ValueError: if the zone block is not found
        """
        match = self.find_zone_block(zone)
        if not match:
            raise ValueError(f"Zone block '{zone}' not found in Corefile")

        body = match.group(2)
        if plugin_block.strip() in body:
            return False  # Already present

        block = plugin_block.strip("\n")
        anchor = None
        if before_plugin:
            anchor = re.search(rf"^[ \t]*{re.escape(before_plugin)}\b", body, flags=re.MULTILINE)

        if anchor:
            # Insert as a whole line just before the anchor plugin line
            new_body = body[:anchor.start()] + block + "\n" + body[anchor.start():]
        else:
            new_body = body + "\n" + block

        self._replace_zone_body(match, new_body)
        return True

    def insert_hosts_entry(self, ip, domain, before_plugin="forward", zone=".:53") -> bool:
        """Insert a ``hosts`` plugin with a single ``ip domain`` entry (see insert_hosts_entries)."""
        return self.insert_hosts_entries([(ip, domain)], before_plugin, zone)

    def insert_hosts_entries(self, entries: List[Tuple[str, str]], before_plugin="forward",
                             zone=".:53") -> bool:
        """Insert a ``hosts { ... fallthrough }`` plugin with the given (ip, domain) entries.

        :return: True if the Corefile was changed; False if entries is empty or
            an identical hosts block is already present
        :raises ValueError: if the zone block is not found
        """
        if not entries:
            return False
        host_lines = "\n".join(f"        {ip} {domain}" for ip, domain in entries)
        plugin_block = "\n    hosts {\n" + host_lines + "\n        fallthrough\n    }"
        return self.insert_plugin(plugin_block, before_plugin, zone)

    def remove_hosts_entries(self, entries: List[Tuple[str, str]], zone=".:53") -> bool:
        """Remove ``ip domain`` lines from ``hosts`` plugin blocks of the zone.

        Only whole lines matching an entry exactly (any spaces/tabs between ip
        and domain) are removed; the hosts block itself is kept.

        :return: True if the Corefile was changed, False otherwise
        """
        match = self.find_zone_block(zone)
        if not match:
            LOG.info(f"Zone block '{zone}' not found in Corefile")
            return False

        body = match.group(2)
        line_patterns = [
            re.compile(rf"^[ \t]*{re.escape(ip)}[ \t]+{re.escape(domain)}[ \t]*(?:\n|$)",
                       flags=re.MULTILINE)
            for ip, domain in entries
        ]

        def remove_lines(hosts_match) -> str:
            block = hosts_match.group(0)
            for line_pattern in line_patterns:
                # drop the whole line including its newline
                block = line_pattern.sub("", block)
            return block

        # Remove entries only from the "hosts { ... }" plugin
        new_body = re.sub(self._plugin_block_pattern("hosts"), remove_lines, body, flags=re.DOTALL)

        if new_body == body:
            return False  # No change

        self._replace_zone_body(match, new_body)
        return True

    def remove_hosts_plugin(self, zone=".:53") -> bool:
        """Remove every ``hosts { ... }`` plugin block from the zone.

        :return: True if the Corefile was changed, False otherwise
        """
        return self.remove_plugin("hosts", zone)

    def remove_plugin(self, plugin_name, zone=".:53") -> bool:
        """Remove every ``<plugin_name> [args] { ... }`` block from the zone.

        Single-line plugins (without a ``{ ... }`` block) are not removed.

        :return: True if the Corefile was changed, False otherwise
        """
        match = self.find_zone_block(zone)
        if not match:
            LOG.info(f"Zone block '{zone}' not found in Corefile")
            return False

        body = match.group(2)
        new_body = re.sub(self._plugin_block_pattern(plugin_name), "", body, flags=re.DOTALL)

        if new_body == body:
            LOG.info('No changes to Corefile detected')
            return False  # No change

        self._replace_zone_body(match, new_body)
        return True

    def get(self) -> str:
        """Return the current (possibly edited) Corefile text."""
        return self.corefile

    def reset(self):
        """Discard all edits and restore the Corefile passed to the constructor."""
        self.corefile = self.original_corefile
