#    Copyright 2025 Mirantis, Inc.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

import cachetools.func as cachetools_func
import os
import yaml

from si_tests import logger
from si_tests import settings
from si_tests.utils import waiters

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from si_tests.managers.kaas_manager import Cluster

# Module logger: the shared si_tests logger object (it also provides LOG.banner(),
# which this module uses to mark the start of each check).
LOG = logger.logger


class NetcheckerManager(object):
    """NetChecker manager.

    Helper around the netchecker-related objects of a cluster:

    * InfraConnectivityMonitor objects, read/created through the cluster manager API
      (``cluster._manager.api.kaas_infraconnectivitymonitors``) in the cluster namespace;
    * NetCheckerTargetsConfig / CheckerInventoryConfig objects and the ``cnnc-agent``
      ConfigMap, read through the cluster's own k8s client in ``netchecker_namespace``.

    Also provides checks/waiters comparing those objects with the actual cluster machines,
    IPAM subnets and netplan IP addresses, and a helper saving them to the artifacts dir.
    """

    # Name of the helm release (in spec.providerSpec.value.helmReleases) carrying cloudprober overrides
    _CNNC_AGENT_CHART_NAME = 'cnnc-agent'

    def __init__(self, cluster: "Cluster"):
        self._cluster: "Cluster" = cluster
        # Default object names / locations come from si_tests settings
        self.inventory_config_name = settings.NETCHECKER_INVENTORY_CONFIG_NAME
        self.targets_config_name = settings.NETCHECKER_TARGETS_CONFIG_NAME
        self.netchecker_namespace = settings.NETCHECKER_NAMESPACE
        self.netchecker_config_path = settings.NETCHECKER_FILE_PATH
        self.netchecker_obj_name = settings.NETCHECKER_OBJECT_NAME
        # ConfigMap (in netchecker_namespace) whose 'cloudprober.cfg' key holds the cloudprober config
        self.cnnc_agent_configmap_name = 'cnnc-agent'

    @property
    def cluster(self) -> "Cluster":
        """Cluster object this manager works with."""
        return self._cluster

    @property
    @cachetools_func.ttl_cache(ttl=300)
    def k8sclient(self):
        """k8s client of the cluster, cached for up to 300 seconds.

        NOTE(review): ttl_cache keys on ``self`` and therefore keeps a reference to this
        manager instance in the cache; other methods use ``self._cluster.k8sclient`` directly.
        """
        return self._cluster.k8sclient

    @property
    def infraconnectivitymonitors(self):
        """Accessor for InfraConnectivityMonitor objects from the cluster manager API."""
        return self._cluster._manager.api.kaas_infraconnectivitymonitors

    @property
    def netcheckertargetsconfigs(self):
        """Accessor for NetCheckerTargetsConfig objects of the cluster."""
        return self._cluster.k8sclient.cnnc_netcheckertargetsconfigs

    @property
    def checkerinventoryconfigs(self):
        """Accessor for CheckerInventoryConfig objects of the cluster."""
        return self._cluster.k8sclient.cnnc_checkerinventoryconfigs

    def create_infraconnectivitymonitor(self, data=None):
        """Create an InfraConnectivityMonitor object from ``data`` in the cluster namespace."""
        return self.infraconnectivitymonitors.create(namespace=self.cluster.namespace, body=data)

    def get_infraconnectivitymonitor(self, name=None):
        """Get the InfraConnectivityMonitor object ``name`` from the cluster namespace."""
        return self.infraconnectivitymonitors.get(namespace=self.cluster.namespace, name=name)

    def get_targets_config(self, name=None):
        """Get the NetCheckerTargetsConfig object ``name`` from the netchecker namespace."""
        return self.netcheckertargetsconfigs.get(namespace=self.netchecker_namespace, name=name)

    def get_targets_for_source(self, source_node=None, targets_config_name=None):
        """Return targets configured for a source node.

        :param source_node: node name matched against 'spec.nodesConfig[].nodeName'
        :param targets_config_name: NetCheckerTargetsConfig object name
        :return: 'targets' list of the matching nodesConfig entry; [] when the source node
                 is not given or has no entry
        """
        if not source_node:
            LOG.warning("Source node was not given. Can't get targets")
            return []
        all_targets = self.get_targets_config(name=targets_config_name).data['spec'].get('nodesConfig') or []
        node_config = next((conf for conf in all_targets if conf['nodeName'] == source_node), None)
        if node_config is None:
            LOG.error(f"Targets not found for source node {source_node}")
            return []
        return node_config.get('targets', [])

    def get_targets_count_for_source(self, source_node=None, targets_config_name=None):
        """Return 'targetsCount' reported for a source node in the targets config status.

        :param source_node: node name matched against 'status.nodes[].name'
        :param targets_config_name: NetCheckerTargetsConfig object name
        :return: targets count as int; 0 when the node is not given or not found
        """
        if not source_node:
            LOG.warning("Source node was not given. Can't get targets count")
            return 0
        targets_status = self.get_targets_config(name=targets_config_name).data.get('status') or {}
        targets_count = next((node['targetsCount'] for node in targets_status.get('nodes') or []
                              if node['name'] == source_node), None)
        if targets_count is None:
            LOG.error(f"Can't get targets count for node {source_node}")
            return 0
        return int(targets_count)

    def get_inventory_config(self, name=None):
        """Get the CheckerInventoryConfig object ``name`` from the netchecker namespace."""
        return self.checkerinventoryconfigs.get(namespace=self.netchecker_namespace, name=name)

    def get_inventory_for_node(self, node_name=None, inventory_config_name=None):
        """Return inventory data of a node from the CheckerInventoryConfig object.

        :param node_name: node name matched against 'spec.nodesConfig[].nodeName' and
                          'status.nodes[].name'
        :param inventory_config_name: CheckerInventoryConfig object name
        :return: {'expectedSubnetTags': <from spec>, 'ipAddresses': <from status>}
        :raises AssertionError: if the node is missing from spec and/or status
        """
        inventory = self.get_inventory_config(name=inventory_config_name)
        inventory_data = inventory.data
        inventory_name = inventory.name
        all_nodes_config = (inventory_data.get('spec') or {}).get('nodesConfig') or []
        nodes_status = (inventory_data.get('status') or {}).get('nodes') or []

        expected_subnets_node_config = [
            conf.get('expectedSubnetTags') for conf in all_nodes_config if conf['nodeName'] == node_name]
        ips_for_node = [data['ipAddresses'] for data in nodes_status if data['name'] == node_name]

        err_msg = ''
        if not expected_subnets_node_config:
            err_msg += (f"Subnets config not found for node {node_name} in netchecker inventory "
                        f"config {inventory_name}.\n")
        if not ips_for_node:
            err_msg += (f"Ips config not found for node {node_name} in netchecker inventory "
                        f"config {inventory_name}.\n")
        if err_msg:
            raise AssertionError(f"\n{err_msg}")
        return {'expectedSubnetTags': expected_subnets_node_config[0], 'ipAddresses': ips_for_node[0]}

    def get_values_from_cnnc_configmap_raw_data(self, param_list):
        """Read integer options from the cloudprober config stored in the cnnc-agent ConfigMap.

        Parses 'data.cloudprober.cfg' line by line; a line ``<key>: <value>`` whose key is
        exactly one of ``param_list`` yields ``int(value)``. If a key occurs several times,
        the last occurrence wins.

        :param param_list: iterable of option names (e.g. dict values of a params mapping)
        :return: dict option name -> int value, only for options found in the config
        :raises ValueError: if a matched value is not an integer
        """
        wanted = set(param_list)  # materialize once: param_list may be a one-shot iterable
        configmap = self.cluster.k8sclient.configmaps.get(namespace=self.netchecker_namespace,
                                                         name=self.cnnc_agent_configmap_name)
        conf_data = (configmap.data.get('data') or {}).get('cloudprober.cfg') or ''
        configmap_options = {}
        for line in conf_data.splitlines():
            # Exact key match: a prefix match would confuse e.g. 'interval' with 'interval_msec'
            key, sep, value = line.strip().partition(':')
            key = key.strip()
            if sep and key in wanted:
                configmap_options[key] = int(value.strip())
        return configmap_options

    def get_netchecker_cloudprober_overrides(self):
        """Return the cnnc-agent entry of the cluster 'helmReleases', or {} if there is none."""
        charts_overrides = self.cluster.data['spec']['providerSpec']['value'].get('helmReleases') or []
        return next((chart for chart in charts_overrides if chart.get('name') == self._CNNC_AGENT_CHART_NAME),
                    {})

    def apply_cnnc_agent_chart_overrides(self, options):
        """
        Setup helm overrides for cloudprober.

        Sets 'values.probe_options' of the cnnc-agent helm release in the cluster spec
        (adding the release entry if it is absent) and patches the cluster.

        :param options: map of key:value options for configmap
        :return: result of the cluster patch
        """
        helm_charts = self.cluster.data['spec']['providerSpec']['value'].get('helmReleases') or []

        for chart in helm_charts:
            if chart.get('name', '') == self._CNNC_AGENT_CHART_NAME:
                # An existing release entry may have no (or a null) 'values' section
                values = chart.get('values') or {}
                values['probe_options'] = options
                chart['values'] = values
                break
        else:
            helm_charts.append({
                'name': self._CNNC_AGENT_CHART_NAME,
                'values': {
                    'probe_options': options
                }
            })

        patch_data = {
            'spec': {
                'providerSpec': {
                    'value': {
                        'helmReleases': helm_charts
                    }
                }
            }
        }
        return self.cluster.patch(patch_data)

    def check_cnnc_agent_chart_values_applied(self, params_mapping, expected_values):
        """
        Check if values are applied in configmap.

        :param params_mapping: As we have different names in overrides and actual configmap, we need to have actual
                               parameters mapping, e.g. {'ping_interval_msec': 'interval_msec',
                               'ping_timeout_msec': 'timeout_msec', 'any_override_param': 'param_from_configmap'}
        :param expected_values: dict with key:value of expected parameters:values (keys are override names)
        :return: bool; a parameter not present in the configmap yet counts as not applied
        """
        configmap_values = self.get_values_from_cnnc_configmap_raw_data(params_mapping.values())
        not_applied_params = {}

        for param_name, param_value in expected_values.items():
            # .get(): the option may not be rendered into the configmap yet
            configmap_param_value = configmap_values.get(params_mapping[param_name])
            if param_value != configmap_param_value:
                not_applied_params[param_name] = {
                    'actual_value': configmap_param_value,
                    'expected_value': param_value
                }

        if not_applied_params:
            LOG.warning(f"Next parameters are not applied yet:\n{yaml.dump(not_applied_params)}")
            return False
        return True

    def wait_cnnc_agent_chart_values_applied(self, params_mapping, expected_values, timeout=600, interval=30):
        """Wait until check_cnnc_agent_chart_values_applied() returns True.

        :param params_mapping: override name -> configmap option name
        :param expected_values: override name -> expected value
        :param timeout: timeout to wait, seconds
        :param interval: time between checks, seconds
        """
        waiters.wait(lambda: self.check_cnnc_agent_chart_values_applied(params_mapping=params_mapping,
                                                                        expected_values=expected_values),
                     timeout=timeout, interval=interval)

    def compare_targets(self, request_object_name, machines=None):
        """
        Compare targets from NetCheckerTargetsConfig, CheckerInventoryConfig and actual cluster machines.

        Machine names from 'status.inventoryConfigStatus.nodes' (except nodes with status
        'Maintenance') and from 'status.targetsConfigStatus.nodes' of the InfraConnectivityMonitor
        must both equal the names of the machines to check that are not in maintenance.

        :param request_object_name: InfraConnectivityMonitor object name
        :param machines: Machine objects to check; all cluster machines if empty/None. Ignored when
                         the object has 'spec.machineSelector.matchLabels': machines are then
                         selected by that label instead.
        :raises AssertionError: if a status section is empty or the machine sets differ
        """
        netchecker_obj = self.get_infraconnectivitymonitor(name=request_object_name)
        icm_data = netchecker_obj.data  # read once so logging and checks see the same state
        LOG.banner("Compare targets from inventory, targets and actual machines")
        LOG.info(f"infraconnectivitymonitor:\n{yaml.dump(icm_data)}")

        icm_status = icm_data.get('status') or {}
        inv_status = (icm_status.get('inventoryConfigStatus') or {}).get('nodes') or []
        targets_status = (icm_status.get('targetsConfigStatus') or {}).get('nodes') or []
        if not inv_status:
            raise AssertionError(f"Field 'status.inventoryConfigStatus.nodes' is empty in "
                                 f"InfraConnectivityMonitor object named {request_object_name}")
        if not targets_status:
            raise AssertionError(f"Field 'status.targetsConfigStatus.nodes' is empty in "
                                 f"InfraConnectivityMonitor object named {request_object_name}")

        machines_from_inv_status = [s['machineName'] for s in inv_status if s.get('status') != 'Maintenance']
        machines_from_targets_status = [s['machineName'] for s in targets_status]

        machine_selector = (icm_data['spec'].get('machineSelector') or {}).get('matchLabels') or {}
        if machine_selector:
            LOG.info(
                f"Machine selector is used in infraconnectivitymonitor. "
                f"Only machines with given label {machine_selector} will be used for checking")
            # NOTE(review): exactly one label is supported (ValueError otherwise), and only the
            # label key is passed to get_machines_by_label(); the 'machines' argument is ignored here.
            l_name, = machine_selector
            candidates = self.cluster.get_machines_by_label(l_name)
        else:
            candidates = machines or self.cluster.get_machines()
        machines_names_to_check = [m.name for m in candidates if not m.is_in_maintenance]

        machines_message = (f"Machines from netchecker inventory:\n{yaml.dump(machines_from_inv_status)}\n"
                            f"Machines from targets status:\n{yaml.dump(machines_from_targets_status)}\n"
                            f"Machines to check from cluster:\n{yaml.dump(machines_names_to_check)}")

        if not (set(machines_from_inv_status) == set(machines_from_targets_status) ==
                set(machines_names_to_check)):
            raise AssertionError(f"Machines from infraconnectivitymonitor are not equal to actual cluster machines.\n"
                                 f"{machines_message}")
        LOG.info(f"All targets as expected:\n{machines_message}")

    def wait_compare_targets(self, netchecker_obj_name, machines=None, timeout=300, interval=20):
        """Wait for equality of machine names listed in InfraConnectivityMonitor object status field
        (under both inventoryConfigStatus and targetsConfigStatus subfields) and machine names of
        Machine objects of the current cluster (see compare_targets()).

        Args:
            netchecker_obj_name: InfraConnectivityMonitor object name
            machines: list of Machine objects to compare with. All cluster machines are used if None.
                If InfraConnectivityMonitor has 'spec.machineSelector.matchLabels', Machine objects
                with the corresponding label are selected instead and this parameter is ignored.
            timeout: timeout to wait
            interval: time between checks
        Returns: None
        """
        timeout_msg = (f"Machines from InfraConnectivityMonitor are not equal to actual cluster machines after "
                       f"{timeout} seconds")
        waiters.wait_pass(lambda: self.compare_targets(request_object_name=netchecker_obj_name, machines=machines),
                          timeout=timeout,
                          interval=interval,
                          timeout_msg=timeout_msg)

    def compare_subnets(self, inventory_config_name, machines=None):
        """
        Compare subnets from CheckerInventoryConfig with subnets actually used by machines.

        For each machine the used subnets come from the 'l3Layout' of its L2 template: entries
        without 'labelSelector' contribute their 'subnetName'; every 'ipam/...' selector item
        contributes all IPAM subnets of the cluster namespace having that label. They are compared
        (as sets) with the last '/'-separated part of each 'expectedSubnetTags' item of the node's
        'spec.nodesConfig' entry in the inventory.

        :param inventory_config_name: CheckerInventoryConfig object name
        :param machines: Machine objects to check; all cluster machines if empty/None
        :raises AssertionError: listing machines whose subnets differ
        """
        if not machines:
            machines = self.cluster.get_machines()

        LOG.banner("Compare actual used subnets with subnets from netchecker inventory")
        cluster = self.cluster
        all_subnets = cluster._manager.get_ipam_subnets(namespace=cluster.namespace)
        # Read subnet labels once instead of once per machine and label
        subnets_labels = [(s.name, s.data['metadata'].get('labels') or {}) for s in all_subnets]
        inventory_spec = self.get_inventory_config(name=inventory_config_name).data.get('spec') or {}
        inventory_nodes_config = inventory_spec.get('nodesConfig') or []
        failed_subnets = {}

        for machine in machines:
            k8s_node_name = machine.get_k8s_node_name()
            l2t = cluster._manager.get_l2template(machine.l2template_name, namespace=cluster.namespace)
            l2t_layout = l2t.data['spec'].get('l3Layout') or []

            labels_to_check = []
            subnets_to_compare = []
            for layout in l2t_layout:
                selector = layout.get('labelSelector') or []
                if not selector:
                    subnets_to_compare.append(layout.get('subnetName'))
                labels_to_check.extend(item for item in selector if item.startswith('ipam/'))

            subnets_from_inventory_conf = [
                tag.split('/')[-1]
                for node in inventory_nodes_config if node.get('nodeName') == k8s_node_name
                for tag in node.get('expectedSubnetTags') or []]

            for label in labels_to_check:
                subnets_to_compare.extend(name for name, labels in subnets_labels if label in labels)

            if set(subnets_to_compare) != set(subnets_from_inventory_conf):
                failed_subnets[machine.name] = {
                    'subnets_from_inventory': subnets_from_inventory_conf,
                    'subnets_from_l2t': subnets_to_compare
                }

            LOG.info(f"\nMachine {machine.name} subnets:\n{yaml.dump(subnets_to_compare)}\n"
                     f"Subnets from inventory:\n{yaml.dump(subnets_from_inventory_conf)}")

        if failed_subnets:
            raise AssertionError(f"Subnets from netchecker inventory and subnets from l2 template are different:\n"
                                 f"{yaml.dump(failed_subnets)}")

    def compare_ips(self, inventory_config_name, machines=None):
        """
        Check that every netplan IP address of each machine is monitored by netchecker.

        Machines in maintenance are skipped. For the others, each value of
        machine.get_machine_ipaddresses_from_netplan() (interface -> IP) must be among the
        'ipAddress' values of the node's 'ipAddresses' in the CheckerInventoryConfig status.

        :param inventory_config_name: CheckerInventoryConfig object name
        :param machines: Machine objects to check; all cluster machines if empty/None
        :raises AssertionError: listing not monitored IPs per machine, or (from
                                get_inventory_for_node) if a node is missing from the inventory
        """
        if not machines:
            machines = self.cluster.get_machines()

        k8s_nodes_machines_map = {m.get_k8s_node_name(): m for m in machines}
        not_monitored = {}

        LOG.banner("Compare IPaddresses from netplan and netchecker inventory")

        for node_name, machine in k8s_nodes_machines_map.items():
            machine_name = machine.name
            if machine.is_in_maintenance:
                LOG.info(f"Machine {machine_name} is in maintenance mode. Skipping")
                continue

            inventory_ips = self.get_inventory_for_node(
                node_name=node_name, inventory_config_name=inventory_config_name).get('ipAddresses') or []
            ips_from_inventory = [data.get('ipAddress', '') for data in inventory_ips]
            ips_from_machine_netplan = machine.get_machine_ipaddresses_from_netplan()

            LOG.info(f"\nIPs for machine {machine_name} from netchecker inventory:\n{yaml.dump(ips_from_inventory)}\n"
                     f"IPs for machine {machine_name} from netplan:\n"
                     f"{yaml.dump(list(ips_from_machine_netplan.values()))}")

            for iface, ip in ips_from_machine_netplan.items():
                if ip not in ips_from_inventory:
                    not_monitored.setdefault(machine_name, {})[iface] = ip

        if not_monitored:
            raise AssertionError(f"Some ips are not monitored by netchecker.\n{yaml.dump(not_monitored)}")

    def save_configs(self):
        """Save netchecker objects of the cluster as YAML files into settings.ARTIFACTS_DIR.

        Writes the InfraConnectivityMonitor, CheckerInventoryConfig and NetCheckerTargetsConfig
        objects (each to '<object name>-<cluster name>.yaml'), an aggregated file with their
        'spec' sections, and, if present, the cnnc-agent helm release overrides of the cluster.

        :return: dict with the 'spec' sections of the three objects
        """
        cl_name = self.cluster.name
        # Use the same names for fetching the objects and for naming their files
        netchecker_obj_data = self.get_infraconnectivitymonitor(name=self.netchecker_obj_name).data
        inventory_config = self.get_inventory_config(name=self.inventory_config_name).data
        targets_config = self.get_targets_config(name=self.targets_config_name).data

        infraconnectivitymonitor_info_path = os.path.join(
            settings.ARTIFACTS_DIR, f"{self.netchecker_obj_name}-{cl_name}.yaml")
        inventory_nodes_config_path = os.path.join(
            settings.ARTIFACTS_DIR, f"{self.inventory_config_name}-{cl_name}.yaml")
        targets_config_path = os.path.join(
            settings.ARTIFACTS_DIR, f"{self.targets_config_name}-{cl_name}.yaml")
        netchecker_objects_configs_specs_file_path = os.path.join(
            settings.ARTIFACTS_DIR, 'netchecker-objects-specs-aggregated.yaml')

        netchecker_objects_configs_specs = {'infraconnectivitymonitor_specs': netchecker_obj_data['spec'],
                                            'inventory_nodes_config_specs': inventory_config['spec'],
                                            'targets_config_specs': targets_config['spec']}

        LOG.info(f"Saving inventory_nodes_config to {inventory_nodes_config_path}")
        self._dump_yaml(inventory_nodes_config_path, inventory_config)

        LOG.info(f"Saving targets_config to {targets_config_path}")
        self._dump_yaml(targets_config_path, targets_config)

        LOG.info(f"Saving request object to {infraconnectivitymonitor_info_path}")
        self._dump_yaml(infraconnectivitymonitor_info_path, netchecker_obj_data)

        LOG.info(f"Saving objects specs to file {netchecker_objects_configs_specs_file_path}")
        self._dump_yaml(netchecker_objects_configs_specs_file_path, netchecker_objects_configs_specs)

        cloudprober_overrides = self.get_netchecker_cloudprober_overrides()
        if cloudprober_overrides:
            overrides_config_path = os.path.join(
                settings.ARTIFACTS_DIR, 'netchecker-cloudprober-overrides-data.yaml')
            LOG.info(f"Saving cnnc-agent overrides data to {overrides_config_path}")
            self._dump_yaml(overrides_config_path, cloudprober_overrides)
        return netchecker_objects_configs_specs

    @staticmethod
    def _dump_yaml(path, data):
        """Write ``data`` as YAML to ``path``, overwriting the file."""
        with open(path, 'w', encoding='utf-8') as f:
            yaml.dump(data, f)

    @staticmethod
    def check_infraconnectivitymonitor_status(icm,
                                              expected_inventoryconfig_status='ok',
                                              expected_targetsconfig_status='ok',
                                              overall_status_only=False):
        """Check the 'status' section of an InfraConnectivityMonitor object.

        :param icm: InfraConnectivityMonitor object (its ``data`` is read)
        :param expected_inventoryconfig_status: expected 'status.inventoryConfigStatus.statusMsg'
        :param expected_targetsconfig_status: expected 'status.targetsConfigStatus.statusMsg'
        :param overall_status_only: if True, only the two 'statusMsg' values are compared with the
                                    expected ones; otherwise every entry of
                                    'status.inventoryConfigStatus.nodes' and
                                    'status.targetsConfigStatus.nodes' must have status 'ok'
        :return: True if the status is as expected, False otherwise (the reason is logged)
        """
        status = icm.data.get('status') or {}
        if not status:
            LOG.error("Empty 'status' section")
            return False
        inventory_config_status = status.get('inventoryConfigStatus') or {}
        if not inventory_config_status:
            LOG.error("Empty 'status.inventoryConfigStatus' section")
            return False
        targets_config_status = status.get('targetsConfigStatus') or {}
        if not targets_config_status:
            LOG.error("Empty 'status.targetsConfigStatus' section")
            return False

        inventory_config_overall_status = inventory_config_status.get('statusMsg', '')
        targets_config_overall_status = targets_config_status.get('statusMsg', '')
        overall_status_pass = (inventory_config_overall_status == expected_inventoryconfig_status and
                               targets_config_overall_status == expected_targetsconfig_status)
        if overall_status_only:
            if not overall_status_pass:
                LOG.warning(f"Overall status is not as expected.\n"
                            f"Inventory config status: \'{inventory_config_overall_status}\' "
                            f"but expected \'{expected_inventoryconfig_status}\'.\n"
                            f"Targets config status: \'{targets_config_overall_status}\' "
                            f"but expected \'{expected_targetsconfig_status}\'")
                return False
            return True

        ic_nodes_status_data = inventory_config_status.get('nodes') or []
        tc_nodes_status_data = targets_config_status.get('nodes') or []
        if not ic_nodes_status_data:
            LOG.error("Empty 'status.inventoryConfigStatus.nodes' section")
            return False
        if not tc_nodes_status_data:
            LOG.error("Empty 'status.targetsConfigStatus.nodes' section")
            return False

        ic_nodes_wrong_status = {node.get('machineName', ''): node.get('status', '')
                                 for node in ic_nodes_status_data if node.get('status', '') != 'ok'}
        tc_nodes_wrong_status = {node.get('machineName', ''): node.get('status', '')
                                 for node in tc_nodes_status_data if node.get('status', '') != 'ok'}
        # NOTE(review): in per-node mode an overall statusMsg mismatch is only reported when some
        # node is also not 'ok'; with all nodes 'ok' the result is True regardless of the
        # expected_* values. Confirm this is intended.
        if ic_nodes_wrong_status or tc_nodes_wrong_status:
            LOG.error(f"Unexpected status for some nodes.\n"
                      f"Inventory config nodes status data:\n{yaml.dump(ic_nodes_status_data)}\n"
                      f"Targets config nodes status data:\n{yaml.dump(tc_nodes_status_data)}")
            if not overall_status_pass:
                LOG.error(f"Overall status is not as expected too.\nInventory config overall status is: "
                          f"{inventory_config_overall_status} but expected \'{expected_inventoryconfig_status}\'\n"
                          f"Targets config overall status is: {targets_config_overall_status} but expected "
                          f"\'{expected_targetsconfig_status}\'")
            return False
        LOG.info("All statuses as expected")
        return True

    def wait_infraconnectivitymonitor_status(self, icm, timeout=300, interval=20,
                                             expected_inventoryconfig_status='ok',
                                             expected_targetsconfig_status='ok',
                                             overall_status_only=False):
        """Wait for good status of InfraConnectivityMonitor object
        (see check_infraconnectivitymonitor_status()).

        Args:
            icm: InfraConnectivityMonitor object
            timeout: timeout to wait
            interval: time between checks
            expected_inventoryconfig_status: status of inventory config status block
            expected_targetsconfig_status: status of targets config status block
            overall_status_only: If True, then only overall block status will be checked
        Returns: None
        """
        timeout_msg = (f"InfraConnectivityMonitor object status fields have not been set "
                       f"to expected value:\nExpected targets_config_status '{expected_targetsconfig_status}'\n"
                       f"Expected inventory_config_status: '{expected_inventoryconfig_status}'")
        waiters.wait(lambda: self.check_infraconnectivitymonitor_status(
            icm=icm,
            expected_inventoryconfig_status=expected_inventoryconfig_status,
            expected_targetsconfig_status=expected_targetsconfig_status,
            overall_status_only=overall_status_only),
                     timeout=timeout,
                     interval=interval,
                     timeout_msg=timeout_msg)
