import time
from hashlib import sha256

import pytest

from si_tests import settings
from si_tests.deployments.utils import commons
from si_tests.managers.clustercheck_mos_manager import ClusterCheckMosManager
from si_tests.utils import waiters


def host_hash(host):
    """Return a short, stable identifier derived from a host name.

    Only the part of ``host`` before the first dot is hashed, so a fully
    qualified name and its short form produce the same value. The result is
    the first 8 hexadecimal characters of the SHA-256 digest of that part.
    """
    short_name, _, _ = host.partition(".")
    digest = sha256(short_name.encode()).hexdigest()
    return digest[:8]


def check_node_controller(openstack_client_manager, os_manager, hosts=None):
    """Put hypervisors into maintenance one by one and verify evacuation.

    For every hypervisor returned by ``hypervisor.list`` (restricted to
    ``hosts`` when it is given) the function:

    1. creates a node maintenance request named ``nmr-si-test-<host hash>``;
    2. waits until the node workload lock ``openstack-<short hostname>``
       becomes ``inactive``;
    3. asserts that no server still reports this hypervisor as its host;
    4. deletes the maintenance request and waits for the lock to become
       ``active`` again.

    Args:
        openstack_client_manager: client exposing ``hypervisor.list``,
            ``server.list`` and ``server.show``.
        os_manager: manager exposing ``api.nodemaintenancerequests`` and
            ``api.nodeworkloadlocks``.
        hosts: optional collection of hypervisor hostnames; when truthy,
            hypervisors not listed in it are skipped.

    Raises:
        AssertionError: if fewer than two hypervisors exist, or if some
            instances are still placed on a host whose lock went inactive.
    """
    hostname_attr = "OS-EXT-SRV-ATTR:hypervisor_hostname"

    hypervisor_list = openstack_client_manager.hypervisor.list([])

    # The message has to be a single string: a tuple here would be shown
    # verbatim instead of being formatted.
    assert len(hypervisor_list) > 1, (
        f"Environment has 1 or less compute nodes {hypervisor_list}"
    )

    # The API handles do not depend on the hypervisor, so look them up once.
    nmr_api = os_manager.api.nodemaintenancerequests
    nwl_api = os_manager.api.nodeworkloadlocks

    def _check_instances_migrated(hostname):
        """Assert that no server is still hosted on ``hostname``."""
        instance_ids = [x["ID"] for x in openstack_client_manager.server.list(["--all"])]
        not_migrated = []
        for instance_id in instance_ids:
            # The listing is followed by a per-server ``show`` because the
            # hypervisor attribute is read from the detailed output.
            instance = openstack_client_manager.server.show([instance_id])
            if instance[hostname_attr] == hostname:
                not_migrated.append(instance["id"])
        assert not not_migrated, (
            f"Some instances are not migrated {not_migrated}, "
            "but workloadlock is inactive"
        )

    for hypervisor in hypervisor_list:
        hostname = hypervisor["Hypervisor Hostname"]
        if hosts and hostname not in hosts:
            commons.LOG.info(f"Skip maintenance mode on {hostname}")
            continue
        commons.LOG.info("Start maintenance mode on %s", hostname)
        nmr_name = "nmr-si-test-" + host_hash(hostname)
        nwl_name = "openstack-" + hostname.split(".")[0]
        nmr_api.create_node_maintenance_request(hostname, nmr_name)

        # Check evacuation finished successfully
        commons.LOG.info("Check evacuation finished successfully")

        nwl_api.wait_nwl_state(nwl_name, "inactive", timeout=2760, interval=30)
        commons.LOG.info("Check node controller evacuation " "finished successfully.")

        _check_instances_migrated(hostname)

        commons.LOG.info("Return node from maintenance mode to status mode.")
        nmr_api.delete_node_maintenance_request(nmr_name)
        nwl_api.wait_nwl_state(nwl_name, "active", timeout=1800, interval=30)


def check_nodes_parallel(os_manager, nodes, concurrency=1, tf_computes=False):
    """Verify that maintenance requests are honoured with limited concurrency.

    A drain-scoped node maintenance request is created for every node up
    front. The function then repeatedly waits until the expected number of
    node workload locks is ``inactive``, checks twice (one minute apart) that
    this number stays the same, deletes the maintenance requests of the nodes
    owning those locks and waits for each lock to become ``active`` again.
    The loop ends once every node in ``nodes`` has been handled.

    Args:
        os_manager: manager exposing ``api.nodemaintenancerequests`` and
            ``api.nodeworkloadlocks``.
        nodes: node names to put into maintenance.
        concurrency: how many nodes are expected to be released for
            maintenance at the same time.
        tf_computes: when true, the expected number of inactive locks is
            doubled to account for the additional TF locks on compute nodes.

    Raises:
        AssertionError: if the number of inactive locks differs from the
            expected one during the stability check.
    """
    commons.LOG.info(f"Checking parallel nwl with concurrency {concurrency} for nodes {nodes}")
    handled_nodes = []
    nmr_prefix = "nmr-si-test"
    nmr_api = os_manager.api.nodemaintenancerequests
    nwl_api = os_manager.api.nodeworkloadlocks
    for hostname in nodes:
        nmr_name = f"{nmr_prefix}-{hostname}"
        nmr_api.create_node_maintenance_request(hostname, nmr_name, scope="drain")

    while set(nodes) - set(handled_nodes):
        to_handle = set(nodes) - set(handled_nodes)
        commons.LOG.info(f"Checking locks for nodes {to_handle}")
        # On the last round fewer nodes than `concurrency` may be left.
        wait_len = min(concurrency, len(to_handle))
        # If it's compute node and TF enabled we should track NWLs for TF
        if tf_computes:
            wait_len = wait_len*2
        waiters.wait(
            nwl_api.check_nwl_by_state_number,
            predicate_args=(
                "inactive",
                wait_len,
            ),
            timeout=600,
            interval=10,
        )
        # Make sure number of inactive locks is not changed
        for i in range(1, 3):
            commons.LOG.info(f"Checking number of inactive locks {wait_len}, attempt {i}")
            nwl_inactive = nwl_api.get_nwl_by_state("inactive")
            # The message must be an f-string, otherwise the placeholders are
            # reported literally instead of the actual numbers.
            assert (
                len(nwl_inactive) == wait_len
            ), f"Expected number of nwl in inactive is {wait_len}, but got {len(nwl_inactive)}"
            time.sleep(60)
        for nwl in nwl_inactive:
            hostname = nwl.nodename
            nmr_name = f"{nmr_prefix}-{hostname}"
            # Several locks may belong to the same node (e.g. with TF), but
            # its maintenance request has to be removed only once.
            if hostname not in handled_nodes:
                nmr_api.delete_node_maintenance_request(nmr_name)
                handled_nodes.append(hostname)
            nwl_api.wait_nwl_state(nwl.name, "active")


@pytest.mark.usefixtures('mos_workload_downtime_report')
@pytest.mark.usefixtures('mos_loadtest_os_refapp')  # Should be used if ALLOW_WORKLOAD == True
@pytest.mark.usefixtures('mos_per_node_workload_check_after_test')
def test_node_controller_parallel(os_manager):
    """Check parallel node maintenance for controller and compute nodes.

    1) Create nmr for all controller nodes simultaneously
    2) Check that the nwl of one controller becomes inactive
    3) Remove the nmr of that controller
    4) Wait until its nwl is active again
    5) Repeat 2-4 for the remaining controllers
    6) Create nmr for all compute nodes
    7) Check that the allowed number of them moved to inactive state
    8) Remove nmr for those nodes
    9) Check their nwl is active again
    10) Repeat 7-9 until all compute nodes are checked.
    """

    tf_enabled = os_manager.is_tf_enabled
    ctl_concurrency = 1
    compute_concurrency = 30

    control_nodes = os_manager.api.nodes.list(label_selector="openstack-control-plane=enabled")
    compute_nodes = os_manager.api.nodes.list(label_selector="openstack-compute-node=enabled")

    control_node_names = []
    for node in control_nodes:
        control_node_names.append(node.name)

    compute_node_names = []
    for node in compute_nodes:
        # Exclude tf-vrouter gateway nodes in case of TF:
        if tf_enabled and node.data['metadata']['labels'].get('openstack-gateway') == "enabled":
            continue
        compute_node_names.append(node.name)

    # NOTE(vsaienko): we need +1 node up to concurrency to check that nwl still
    # active while we have nmr for it.
    enough_controllers = len(control_node_names) >= ctl_concurrency + 1
    assert enough_controllers, (
        f"Environment has controller nodes less node than required minimum {ctl_concurrency + 1}"
    )

    commons.LOG.info("Check nmr for controllers")
    check_nodes_parallel(
        os_manager,
        control_node_names,
        concurrency=ctl_concurrency,
    )

    commons.LOG.info("Check nmr for computes")
    check_nodes_parallel(
        os_manager,
        compute_node_names,
        concurrency=compute_concurrency,
        tf_computes=tf_enabled,
    )


@pytest.mark.usefixtures('mos_workload_downtime_report')
@pytest.mark.usefixtures('mos_loadtest_os_refapp')  # Should be used if ALLOW_WORKLOAD == True
@pytest.mark.usefixtures('mos_per_node_workload_check_after_test')
@pytest.mark.usefixtures('skip_by_network_backend')
@pytest.mark.network_backend('tf')
def test_tf_node_controller_parallel(os_manager):
    """Check parallel node maintenance for TF controller and gateway nodes.

    1) Create nmr for all tf controller nodes simultaneously
    2) Check that the nwl of one controller becomes inactive
    3) Remove the nmr of that controller
    4) Wait until its nwl is active again
    5) Repeat 2-4 for the remaining controllers
    6) Create nmr for all gateway (vrouter-vgw) compute nodes
    7) Check that the nwl of one gateway becomes inactive
    8) Remove the nmr of that gateway
    9) Wait until its nwl is active again
    10) Repeat 7-9 until all gateway compute nodes are checked.
    """

    concurrency = 1

    tf_control_nodes = os_manager.api.nodes.list(label_selector="tfcontrol=enabled")
    tf_control_node_names = []
    for node in tf_control_nodes:
        tf_control_node_names.append(node.name)

    gateway_selector = "openstack-compute-node=enabled,openstack-gateway=enabled"
    compute_gw_nodes = os_manager.api.nodes.list(label_selector=gateway_selector)
    compute_gw_node_names = []
    for node in compute_gw_nodes:
        compute_gw_node_names.append(node.name)

    # One node more than the concurrency is required in each group.
    enough_tf_controllers = len(tf_control_nodes) >= concurrency + 1
    assert enough_tf_controllers, (
        f"Environment has tf controller nodes less node than required minimum {concurrency + 1}"
    )
    enough_gateways = len(compute_gw_nodes) >= concurrency + 1
    assert enough_gateways, (
        f"Environment has gw compute nodes less node than required minimum {concurrency + 1}"
    )

    commons.LOG.info("Check nmr for tf controllers")
    check_nodes_parallel(
        os_manager,
        tf_control_node_names,
        concurrency=concurrency,
    )

    commons.LOG.info("Check nmr for gateway compute nodes")
    check_nodes_parallel(
        os_manager,
        compute_gw_node_names,
        concurrency=concurrency,
        tf_computes=True,
    )


@pytest.mark.usefixtures('mos_workload_downtime_report')
@pytest.mark.usefixtures('mos_loadtest_os_refapp')  # Should be used if ALLOW_WORKLOAD == True
@pytest.mark.usefixtures('mos_per_node_workload_check_after_test')
def test_node_controller_evacuate(openstack_client_manager, os_manager, request, func_name):
    """Test node controller evacuation. Test can be run only on vms.

    Runs the maintenance/evacuation check against every hypervisor.
    ``request`` and ``func_name`` are requested as fixtures but are not used
    in the test body.
    """
    log = commons.LOG
    log.info("Checking node controller")
    # No host filter: every hypervisor is put into maintenance in turn.
    check_node_controller(openstack_client_manager, os_manager, hosts=None)
    log.info("Node controller evacuation finished successfully.")


@pytest.mark.usefixtures('mos_workload_downtime_report')
@pytest.mark.usefixtures('mos_loadtest_os_refapp')  # Should be used if ALLOW_WORKLOAD == True
@pytest.mark.usefixtures('mos_per_node_workload_check_after_test')
@pytest.mark.usefixtures('skip_by_network_backend')
@pytest.mark.network_backend('tf')
def test_tf_amphora_evacuate(openstack_client_manager, os_manager, request,
                             show_step):
    """Check migration of amphora LB.

        Scenario:
            1. Deploy LB with amphora
            2. Find host with amphora instance
            3. Create node maintenance request and wait instances are
            evacuated
            4. Check LB is working
    """

    show_step(1)
    az = settings.OPENSTACK_AVAILABILITY_ZONE_NAME
    hosts = openstack_client_manager.host.list([])
    hosts_az = [host for host in hosts if host.get('Zone') == az]
    # The amphora instance needs another host in the same zone to move to.
    assert len(hosts_az) >= 2, (f"Test requires not less than 2 compute "
                                f"nodes in az {az}")

    stack = ClusterCheckMosManager.created_stack_tf_lb(request, openstack_client_manager,
                                                       custom_params={"lb_provider": "amphorav2"})

    # Index the stack outputs by key so that the values can be looked up.
    outputs = {output['output_key']: output['output_value'] for output in stack['outputs']}
    lb_url = outputs.get('lb_url')
    net_name = outputs.get('network_name')
    # Fail early with a clear message instead of passing None further on.
    assert lb_url is not None, "Stack outputs do not contain 'lb_url'"
    assert net_name is not None, "Stack outputs do not contain 'network_name'"

    show_step(2)
    # Get amphora instance by network affiliation
    amphora_vm = None
    servers = openstack_client_manager.server.list(['--all'])
    for server in servers:
        if net_name in server['Networks'] and server['Name'].startswith("amphora"):
            amphora_vm = openstack_client_manager.server.show([server['ID']])
            break
    # Without this check a missing amphora would surface as an opaque
    # TypeError when subscripting None below.
    assert amphora_vm is not None, f"Amphora instance is not found in network {net_name}"
    host = amphora_vm['OS-EXT-SRV-ATTR:hypervisor_hostname']
    commons.LOG.info(f"Amphora instance landing on {host} host")
    assert ClusterCheckMosManager.is_lb_functional(openstack_client_manager, 2, lb_url)

    show_step(3)
    # Only the host running the amphora instance is put into maintenance.
    check_node_controller(openstack_client_manager, os_manager, hosts=[host])

    show_step(4)
    assert ClusterCheckMosManager.is_lb_functional(openstack_client_manager, 2, lb_url)
