import functools
import io
import os
import paramiko
import pytest
import time
import uuid

from exec_helpers import SSHClient, SSHAuth
from si_tests.utils import waiters, utils
from si_tests import logger
from si_tests.deployments.utils import (
    commons,
    file_utils,
)

LOG = logger.logger


class TestTFLLGR(object):
    """TungstenFabric BGP long-lived graceful restart (LLGR) tests."""

    @pytest.fixture()
    def restore_cr(self, tf_manager):
        """Back up the tfoperator gracefulRestart settings and restore them on teardown.

        Only the ``gracefulRestart`` section is backed up. Its ``enabled`` flag
        is taken from ``tf_manager.is_gracefulRestart_enabled()``. The section
        lives under ``spec.features.control`` when ``tf_manager.apiv2`` is set,
        and under ``spec.settings`` otherwise. ``apiv2`` is read after
        ``tfoperator(detect_api=True)``, which presumably detects and sets it
        (confirm in tf_manager).

        NOTE(review): the restore patch re-sends only the keys that existed
        before the test. If ``patch`` performs a merge patch, keys the test
        adds (e.g. ``restartTime``) will remain in the CR. Confirm whether
        they should be nulled out explicitly.
        """
        commons.LOG.info("Backup tfoperator CR specs")
        tfoperator_cr = tf_manager.tfoperator(detect_api=True)
        spec = tfoperator_cr.data['spec']
        enabled = tf_manager.is_gracefulRestart_enabled()
        if tf_manager.apiv2:
            # Copy the section so the backup is a snapshot and the fetched CR
            # data is not mutated in place.
            gr_backup = dict(
                (spec.get('features') or {}).get('control', {}).get('gracefulRestart') or {})
            gr_backup['enabled'] = enabled
            restore_spec = {
                'features': {
                    'control': {
                        'gracefulRestart': gr_backup
                    }
                }
            }
        else:
            gr_backup = dict(
                (spec.get('settings') or {}).get('gracefulRestart') or {})
            gr_backup['enabled'] = enabled
            restore_spec = {
                'settings': {
                    'gracefulRestart': gr_backup
                }
            }
        yield
        commons.LOG.info("Restore tfoperator CR specs")
        tfoperator_cr.patch({"spec": restore_spec})

    @pytest.mark.usefixtures("restore_cr")
    def test_tf_llgr(self, openstack_client_manager, tf_manager, request,
                     show_step):
        """Test TungstenFabric BGP LLGR.

        Scenario:
            1. Enable LLGR
            2. Run VMs
            3. Emulate lost connection between vrouter and control
            4. Check network connectivity between VMs
        """
        show_step(1)

        gr_settings = {
            "gracefulRestart": {
                "enabled": True,
                "bgpHelperEnabled": True,
                "xmppHelperEnabled": True,
                "restartTime": 600,
                "llgrRestartTime": 600,
                "endOfRibTimeout": 600,
            },
        }

        # The gracefulRestart section lives in a different place per API version.
        if tf_manager.apiv2:
            settings = {
                "features": {
                    "control": gr_settings
                }
            }
        else:
            settings = {
                "settings": gr_settings
            }
        tfoperator_cr = tf_manager.tfoperator()
        tfoperator_cr.patch({"spec": settings})

        # Wait for the graceful-restart job to appear, then for it to succeed.
        gr_job = tf_manager.wait_gr_gob(timeout=300)
        gr_job.wait_succeded(timeout=240)

        show_step(2)
        az = 'nova'
        hosts = openstack_client_manager.host.list([])
        hosts_az = [host for host in hosts if host.get('Zone') == az]
        assert len(hosts_az) >= 2, (f"Test requires not less than 2 compute "
                                    f"nodes in az {az}")

        # VM 1 is placed on test_host; that host's vrouter is isolated in step 3.
        test_host = hosts_az[0]['Host Name']
        stack_name = f"si-test-servers-{uuid.uuid4()}"
        template_path = file_utils.join(os.path.dirname(os.path.abspath(__file__)),
                                        "templates/servers.yaml")
        ssh_keys = utils.generate_keys()
        stack_params = {
            "host_name_1": test_host,
            "host_name_2": hosts_az[1]['Host Name'],
            "ssh_public_key": f"ssh-rsa {ssh_keys['public']}",
            "ssh_private_key": ssh_keys["private"]
            }
        openstack_client_manager.create_stack(request, stack_name, template_path,
                                              stack_params=stack_params)

        stack = openstack_client_manager.stack.show([stack_name])
        outputs = {output['output_key']: output['output_value']
                   for output in stack['outputs']}
        required = ('private_key', 'server_1_public_ip', 'server_2_ip')
        missing = [key for key in required if key not in outputs]
        assert not missing, f"Stack {stack_name} is missing outputs: {missing}"
        pkey = paramiko.RSAKey.from_private_key(io.StringIO(outputs['private_key']))
        vm1_fip = outputs['server_1_public_ip']
        vm2_ip = outputs['server_2_ip']

        vm1_ssh = SSHClient(
            host=vm1_fip,
            port=22,
            auth=SSHAuth(username='ubuntu', key=pkey)
        )
        ping_success = waiters.icmp_ping(host=vm2_ip, ssh_client=vm1_ssh)
        assert ping_success, f"VM 2 {vm2_ip} is unreachable"

        show_step(3)

        # Make sure a vrouter pod runs on the OpenStack host of VM 1.
        node_name = None
        for pod in openstack_client_manager.tf_manager.get_vrouter_pods():
            pod_node = pod.read().spec.node_name
            if pod_node == test_host:
                node_name = pod_node
                break
        assert node_name is not None, f"Compute node {test_host} wasn't found"

        # Control nodes ip can be obtained from env variables
        cfg_map = openstack_client_manager.tf_manager.get_tf_services_cfgmap()
        # Tolerate surrounding whitespace and empty entries (e.g. a trailing comma),
        # which would otherwise produce an invalid "iptables -d  ..." command.
        control_nodes = [
            ctl.strip()
            for ctl in cfg_map.data['data']['CONTROL_NODES'].split(',')
            if ctl.strip()
        ]
        assert control_nodes, "No CONTROL_NODES found in TF services configmap"

        commons.LOG.info(f"Emulate all Control nodes outage situation by "
                         f"using iptables command on compute node {node_name}")
        # Port 5269 is presumably the vrouter-agent -> control XMPP port; dropping
        # it cuts the vrouter off from every control node.
        rules = [f"-d {ctl} -p tcp --dport 5269 -j DROP" for ctl in control_nodes]
        cmd = ' && '.join(f"iptables -A OUTPUT {rule}" for rule in rules)
        # Same order as the insert chain: if inserting stops part-way, the delete
        # chain removes the rules that were added before it hits a missing one.
        cmd_cleanup = ' && '.join(f"iptables -D OUTPUT {rule}" for rule in rules)
        # Register the cleanup BEFORE inserting the rules, so a failure in the
        # middle of the insert chain does not leave DROP rules on the node.
        request.addfinalizer(functools.partial(
            openstack_client_manager.tf_manager.exec_pod_cmd, node_name,
            cmd_cleanup)
        )
        output = openstack_client_manager.tf_manager.exec_pod_cmd(
            node_name, cmd, verbose=False)['logs']
        commons.LOG.info(output)

        show_step(4)
        # RFC 4271 suggests a default BGP hold time of 90 seconds, and Contrail
        # uses the same hold time for BGP peers. Checking connectivity after
        # 120 seconds means the session would already have timed out without
        # graceful restart.
        timeout = 120
        commons.LOG.info(f"Check network connectivity after {timeout} seconds")
        time.sleep(timeout)
        ping_success = waiters.icmp_ping(host=vm2_ip, ssh_client=vm1_ssh)
        assert ping_success, f"VM 2 {vm2_ip} is unreachable from VM 1"
