#    Author: Alex Savatieiev (osavatieiev@mirantis.com; a.savex@gmail.com)
#    Copyright 2019-2022 Mirantis, Inc.
from cfg_checker.common import logger_cli
from cfg_checker.modules.network.mapper import SaltNetworkMapper, \
    KubeNetworkMapper
from cfg_checker.modules.network.network_errors import NetworkErrors
from cfg_checker.reports import reporter


class NetworkChecker(object):
    """Base class for network configuration checks.

    Creates a ``NetworkErrors`` log in ``self.errors``. Subclasses are
    expected to assign ``self.mapper``; every method other than
    ``__init__``, ``print_summary`` and ``print_error_details`` reads it.
    """

    def __init__(self):
        logger_cli.debug("... init error logs folder")
        self.errors = NetworkErrors()

    def _check_duplicate_ips(self):
        """Record IP addresses used more than once within a network.

        Addresses are tracked separately for each network in
        ``self.mapper.map``. The first interface carrying an address is
        accepted. Every later interface with the same address is logged as
        a warning and added to ``self.errors`` as ``NET_DUPLICATE_IP``.
        """
        logger_cli.debug("... checking for duplicate ips")
        for _net, _nodes in self.mapper.map.items():
            # addresses already seen in this network only
            _seen_ips = set()
            for _node_name, _interfaces in _nodes.items():
                for _if in _interfaces:
                    _ip = _if["ip_address"]
                    if _ip not in _seen_ips:
                        # first occurrence of this address in the network
                        _seen_ips.add(_ip)
                        continue
                    # this ip is already used by an earlier interface
                    logger_cli.warning(
                        "Warning: Duplicate ip address: "
                        "'{}: {}' at {} -> {}".format(
                            _if["interface"],
                            _ip,
                            _net,
                            _node_name
                        )
                    )
                    self.errors.add_error(
                        self.errors.NET_DUPLICATE_IP,
                        network=_net,
                        node_name=_node_name,
                        if_name=_if["interface"],
                        ip_address=_ip
                    )

    def _check_non_uniform_mtu(self):
        """Record interfaces whose MTU differs within a network.

        For each network in ``self.mapper.map``, the ``rt_mtu`` of the first
        interface visited becomes the reference value. Every later interface
        in that network with a different ``rt_mtu`` is logged as a warning
        and added to ``self.errors``.
        """
        logger_cli.debug("... checking for non-uniform mtu")
        for _net, _nodes in self.mapper.map.items():
            # holds at most one value: the reference mtu for this network
            _mtus = set()
            for _node_name, _interfaces in _nodes.items():
                for _if in _interfaces:
                    if not _mtus:
                        # first interface in this network sets the reference
                        _mtus.add(_if["rt_mtu"])
                    elif _if["rt_mtu"] not in _mtus:
                        # mtu differs from the reference value
                        logger_cli.warning(
                            "Non-uniform MTU value of '{}' in '{}': "
                            "{}:{}:{}".format(
                                _net,
                                _node_name,
                                _if["interface"],
                                _if["ip_address"],
                                _if["rt_mtu"]
                            )
                        )
                        # NOTE(review): this reuses the duplicate-ip error
                        # code for an MTU mismatch; switch to an MTU-specific
                        # code if NetworkErrors defines one.
                        self.errors.add_error(
                            self.errors.NET_DUPLICATE_IP,
                            network=_net,
                            node_name=_node_name,
                            if_name=_if["interface"],
                            ip_address=_if["ip_address"],
                            mtu=_if["rt_mtu"]
                        )

    def check_networks(self, map=True, skip_keywords=None):
        """
        Build the network map and run checks that mapping cannot detect

        :param map: print the resulting map when true (the name shadows the
            builtin ``map`` but is kept for caller compatibility)
        :param skip_keywords: passed through to ``mapper.create_map``
        :return: none
        """
        # Load map
        self.mapper.map_networks()
        self.mapper.create_map(skip_keywords=skip_keywords)
        # Check for errors that are not detectable during mapping
        self._check_duplicate_ips()
        self._check_non_uniform_mtu()
        # print map if requested
        if map:
            self.mapper.print_map()

    def print_summary(self):
        """Log the error summary, omitting zero-count entries."""
        logger_cli.info(self.errors.get_summary(print_zeros=False))

    def print_error_details(self):
        """Log the detailed error list, surrounded by blank lines."""
        logger_cli.info(
            "\n{}\n".format(
                self.errors.get_errors()
            )
        )

    def create_html_report(self, filename):
        """
        Create static html showing network schema-like report

        :param filename: path of the report file, handed to
            ``reporter.ReportToFile``
        :return: none
        """
        logger_cli.info("### Generating report to '{}'".format(filename))
        _report = reporter.ReportToFile(
            reporter.HTMLNetworkReport(self.mapper.master),
            filename
        )
        # NOTE(review): 'mcp_release' and 'openstack_release' are read with
        # [] and will raise if the mapper's cluster data lacks them.
        _report(
            {
                "domain": self.mapper.domain,
                "nodes": self.mapper.nodes,
                "map": self.mapper.map,
                "mcp_release": self.mapper.cluster['mcp_release'],
                "openstack_release": self.mapper.cluster['openstack_release']
            }
        )
        logger_cli.info("-> Done")


class SaltNetworkChecker(NetworkChecker):
    """Network checker that uses ``SaltNetworkMapper`` as its mapper."""

    def __init__(self, config, skip_list=None, skip_list_file=None):
        # The base initializer creates self.errors first, so the shared
        # error log exists before it is handed to the mapper below.
        super(SaltNetworkChecker, self).__init__()
        _mapper_options = {
            "errors_class": self.errors,
            "skip_list": skip_list,
            "skip_list_file": skip_list_file,
        }
        self.mapper = SaltNetworkMapper(config, **_mapper_options)


class KubeNetworkChecker(NetworkChecker):
    """Network checker that uses ``KubeNetworkMapper`` as its mapper."""

    def __init__(self, config, skip_list=None, skip_list_file=None):
        # The base initializer creates self.errors first, so the shared
        # error log exists before it is handed to the mapper below.
        super(KubeNetworkChecker, self).__init__()
        _mapper_options = {
            "errors_class": self.errors,
            "skip_list": skip_list,
            "skip_list_file": skip_list_file,
        }
        self.mapper = KubeNetworkMapper(config, **_mapper_options)
