# Copyright 2023 Mirantis Inc.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

import abc
import json
import os

import docker
import jinja2
from oslo_log import log
from tempest import config
import tenacity


# Tempest configuration; ``state_path`` and ``dynamic_routing`` options are
# read by the container helper below.
CONF = config.CONF
LOG = log.getLogger(__name__)
# Path of the bgpd log file inside the FRR container. It is handed to the
# ``daemons`` template as the bgpd ``log_path`` and dumped (via ``cat``)
# when a vtysh command fails.
BGPD_LOG_PATH = "/var/log/bgpd.log"


class GenericConnector(abc.ABC):
    """Interface for objects that can run a vtysh command somewhere.

    ``BGPClient.exec`` expects the value returned by :meth:`exec` to expose
    ``exit_code`` and ``output`` attributes.
    """

    # Deriving from ``abc.ABC`` is required for ``@abstractmethod`` to be
    # enforced; without it the decorator is silently ignored.
    @abc.abstractmethod
    def exec(self, cmd):
        """Run ``cmd`` and return its result."""
        raise NotImplementedError


class ContainerConnector(GenericConnector):
    """Run vtysh commands against bgpd inside a docker container.

    :param container: docker container object whose ``exec_run`` method is
        used to execute commands.
    """

    def __init__(self, container):
        self.ctn = container

    def exec(self, cmd):
        """Execute ``cmd`` via ``vtysh -d bgpd -c`` inside the container.

        The command is passed to ``exec_run`` as an argument list rather than
        a single quoted string, so quote characters inside ``cmd`` need no
        shell-style escaping and cannot break argument splitting.

        :returns: the ``exec_run`` result (an object with ``exit_code`` and
            ``output`` attributes, as consumed by ``BGPClient.exec``).
        """
        return self.ctn.exec_run(["vtysh", "-d", "bgpd", "-c", cmd])


class BGPClient:
    """Run FRR ``show ... json`` commands through a connector.

    :param connector: object with an ``exec(cmd)`` method returning a result
        that exposes ``exit_code`` and ``output`` attributes.
    """

    def __init__(self, connector):
        self.connector = connector

    def exec(self, cmd):
        """Run ``cmd`` through the connector and return its raw output.

        :raises Exception: if the command exits with a non-zero status.
        """
        res = self.connector.exec(cmd)
        if res.exit_code == 0:
            return res.output

        LOG.debug("Command %r failed with exit code %s:\n%s",
                  cmd, res.exit_code, res.output)
        # Container diagnostics are only available for connectors that wrap
        # a container; other GenericConnector implementations have no ctn.
        ctn = getattr(self.connector, "ctn", None)
        if ctn is not None:
            LOG.debug("Container %s logs:\n%s\nbgpd:\n%s",
                      ctn.name,
                      ctn.logs(),
                      ctn.exec_run(f"cat {BGPD_LOG_PATH}"))
        raise Exception(f"Failed to run command {cmd}")

    def _show_json(self, *words):
        """Run ``show <words...> json`` and decode the JSON output."""
        return json.loads(self.exec(" ".join(["show", *words, "json"])))

    def show_bgp_neighbors(self, *args):
        """Return the decoded ``show bgp neighbors [args...] json`` output.

        :param args: extra command words (e.g. a neighbor address). They are
            placed before the trailing ``json`` keyword, which FRR's command
            grammar expects to come last.
        """
        return self._show_json("bgp", "neighbors", *args)

    def show_bgp_ipv4(self):
        """Return the decoded ``show bgp ipv4 json`` output."""
        return self._show_json("bgp", "ipv4")

    def show_bgp_ipv6(self):
        """Return the decoded ``show bgp ipv6 json`` output."""
        return self._show_json("bgp", "ipv6")

    def show_bgp(self):
        """Return the decoded ``show bgp json`` output."""
        return self._show_json("bgp")


class FrrBGPContainer:
    """Manage an FRR container running bgpd for dynamic-routing tests.

    On construction the ``daemons``, ``vtysh.conf`` and ``bgpd.conf``
    templates are rendered into a per-container directory under
    ``CONF.state_path``. That directory is bind-mounted at ``conf_dir`` when
    the container is created by :meth:`run`.

    :param name: container name; also used for the config directory.
    :param image: docker image to pull and run.
    :param bgpd: value passed to the templates as ``bgpd``.
    :param daemons: mapping passed to the templates as ``daemons``. Defaults
        to enabling bgpd and vtysh. It is copied, so the caller's mapping is
        never modified.
    """

    conf_dir = "/etc/frr"

    def __init__(self, name, image, bgpd, daemons=None):
        self.image = image
        self.name = name
        self.config_dir = os.path.join(CONF.state_path, f"ctn_base/{name}/")
        self.volumes = [f"{self.config_dir}:{self.conf_dir}"]
        # Copy the caller's mapping so adding ``log_path`` below does not
        # mutate a dict the caller (or another container) still owns.
        self.daemons = dict(daemons) if daemons else {
            "bgpd": {"enabled": "yes"},
            "vtysh": {"enabled": "yes"},
        }
        if self.daemons.get("bgpd"):
            self.daemons["bgpd"] = dict(self.daemons["bgpd"],
                                        log_path=BGPD_LOG_PATH)
        self.bgpd = bgpd
        self.docker_client = docker.from_env()
        self._create_config_debian()
        self.ctn = None
        self._bgp_client = None

    @property
    def bgp_client(self):
        """Lazily built ``BGPClient`` bound to the current container."""
        if self._bgp_client is None:
            self._bgp_client = BGPClient(ContainerConnector(self.ctn))
        return self._bgp_client

    def _create_config_debian(self):
        """Render the FRR config templates into ``self.config_dir``."""
        environment = jinja2.Environment(
            loader=jinja2.FileSystemLoader(
                os.path.join(os.path.dirname(__file__), "templates")
            )
        )
        os.makedirs(self.config_dir, exist_ok=True)
        for cfg_file in ("daemons", "vtysh.conf", "bgpd.conf"):
            template = environment.get_template(cfg_file)
            # Render before opening so a template error does not leave a
            # truncated config file behind.
            data = template.render(bgpd=self.bgpd, daemons=self.daemons)
            with open(
                os.path.join(self.config_dir, cfg_file),
                mode="w",
                encoding="utf-8",
            ) as conf:
                conf.write(data)

    def run(self, wait=True):
        """Pull the image, then create and start the container.

        :param wait: if true, block until docker reports the container as
            ``running``.
        """
        self.docker_client.images.pull(self.image)
        self.ctn = self.docker_client.containers.create(
            image=self.image,
            name=self.name,
            volumes=self.volumes,
            privileged=True,
            network_mode="host",
        )
        # Drop a client cached before run() (bound to ctn=None) or bound to
        # a previous container.
        self._bgp_client = None
        self.ctn.start()
        if wait:
            self._wait_running()

    @tenacity.retry(
        retry=tenacity.retry_if_result(lambda val: val is not True),
        wait=tenacity.wait_random(min=5, max=15),
        stop=tenacity.stop_after_delay(60),
    )
    def _wait_running(self):
        """Poll the container status until it is ``running`` (up to 60s)."""
        self.ctn.reload()
        if self.ctn.status == "running":
            return True

    def exec_on_ctn(self, cmd, capture=True, detach=False):
        """Run ``cmd`` in the container and return the ``exec_run`` result.

        ``capture`` is accepted for backward compatibility and is unused.
        """
        return self.ctn.exec_run(cmd, detach=detach)

    @tenacity.retry(
        retry=tenacity.retry_if_result(lambda val: val is not True),
        wait=tenacity.wait_random(min=5, max=15),
        stop=tenacity.stop_after_delay(CONF.dynamic_routing.frr_bgp_timeout),
    )
    def bgp_check_neighbor_state(self, nei_ident, expected_state):
        """Wait until every neighbor in peer group ``nei_ident`` is in
        ``expected_state``. At least one such neighbor must exist.
        """
        res = self.bgp_client.show_bgp_neighbors()
        neighbor_states = [
            nei.get("bgpState") == expected_state
            for nei in res.values()
            if nei.get("peerGroup") == nei_ident
        ]
        return len(neighbor_states) > 0 and all(neighbor_states)

    @tenacity.retry(
        retry=tenacity.retry_if_result(lambda val: val is not True),
        wait=tenacity.wait_random(min=5, max=15),
        stop=tenacity.stop_after_delay(CONF.dynamic_routing.frr_bgp_timeout),
    )
    def bgp_check_neighbor_absent(self, nei_ident):
        """Wait until no neighbor belongs to peer group ``nei_ident``."""
        res = self.bgp_client.show_bgp_neighbors()
        neighbors = [
            nei for nei in res.values() if nei.get("peerGroup") == nei_ident
        ]
        return len(neighbors) == 0

    @tenacity.retry(
        retry=tenacity.retry_if_result(lambda val: val is not True),
        wait=tenacity.wait_random(min=5, max=15),
        stop=tenacity.stop_after_delay(CONF.dynamic_routing.frr_bgp_timeout),
    )
    def bgp_check_rib(self, ip_version, cidr, nexthop=None):
        """Wait until ``cidr`` is in the BGP RIB, optionally via ``nexthop``.

        :param ip_version: ``"ipv4"`` or ``"ipv6"``; selects the
            ``show_bgp_<ip_version>`` query.
        :param cidr: prefix to look for among the RIB ``routes`` keys.
        :param nexthop: if given, at least one path of ``cidr`` must list
            this IP among its ``nexthops``.
        """
        res = getattr(self.bgp_client, f"show_bgp_{ip_version}")()
        rib = res.get("routes", {})
        if cidr not in rib:
            return False
        if not nexthop:
            return True
        # Only paths of the requested prefix count; a matching next hop on
        # some other prefix must not satisfy the check.
        return any(
            hop_data.get("ip") == nexthop
            for route in rib[cidr]
            for hop_data in route.get("nexthops", [])
        )
