import pytest

from si_tests import logger
from si_tests.utils import utils, waiters

LOG = logger.logger


def get_process_pids(ssh, name):
    """Return the PIDs of remote processes whose full command line matches ``name``.

    Runs ``pgrep -fa <name>`` through ``ssh.check_call`` and parses each
    ``<pid> <command line>`` output line, keeping only the leading PID field.
    Lines that contain the pgrep command string itself are discarded.
    Presumably these come from the shell used to run pgrep remotely, so only
    the real target processes are reported.

    NOTE(review): pgrep exits with status 1 when nothing matches, and
    ``check_call`` presumably raises on a non-zero exit. Confirm whether a
    missing process should fail here.

    :param ssh: remote-execution helper exposing ``check_call(cmd)``, whose
        result has a ``stdout`` iterable of byte lines.
    :param name: pattern matched against full process command lines.
    :return: list of PID strings, in pgrep output order.
    """
    pgrep_cmd = f"pgrep -fa {name}"
    result = ssh.check_call(pgrep_cmd)

    output_lines = []
    for raw_line in result.stdout:
        output_lines.append(raw_line.decode().strip())
    LOG.info(f"Output for {pgrep_cmd}:\n{output_lines}")

    found = []
    for entry in output_lines:
        # Skip the process that ran pgrep: its command line embeds pgrep_cmd.
        if pgrep_cmd in entry:
            continue
        found.append(entry.split(' ')[0])
    return found


@pytest.mark.usefixtures('mos_workload_downtime_report')
@pytest.mark.usefixtures('mos_per_node_workload_check_after_test')
def test_restart_containerd_tf_compute_node(tf_manager, os_manager, show_step):
    """Restart containerd with TF vRouter pod and check node isn't rebooted.

    Scenario:
        1. Collect compute nodes with containerd runtime
        2. Get vRouter pod restart and pid process on target compute nodes
        3. Restart containerd on nodes
        4. Wait all pods are ready
        5. Check vRouter pods restart count and pid of vRouter agent on nodes
    """
    expected_runtime = "containerd"
    process_name = "/usr/bin/contrail-vrouter-agent"

    show_step(1)
    compute_nodes = tf_manager.get_vrouter_nodes()
    nodes_with_exp_runtime = [
        cmp for cmp in compute_nodes
        if expected_runtime in cmp.data['status']['node_info']['container_runtime_version']
    ]
    if not nodes_with_exp_runtime:
        pytest.skip(f"No {expected_runtime} runtime enabled")

    # Per-node state recorded before the restart and compared after it (step 5).
    nodes_data = {}
    LOG.info(f"Compute nodes with {expected_runtime} runtime:")
    for node in nodes_with_exp_runtime:
        node_name = node.data['metadata']['name']
        LOG.info(f"Node: {node_name}")
        nodes_data[node_name] = {
            'node': node,
            'resource_version': node.data['metadata']['resource_version'],
            'ssh': utils.ssh_k8s_node(node_name),
        }

    show_step(2)
    # List vRouter pods once; each node's pod is then looked up in that list,
    # stopping at the first match instead of reading every pod for every node.
    vrouter_pods = tf_manager.get_vrouter_pods()
    for node_name, data in nodes_data.items():
        data['pids'] = get_process_pids(data['ssh'], process_name)
        pod = next(
            (p for p in vrouter_pods if p.read().spec.node_name == node_name),
            None,
        )
        assert pod is not None, f"No vRouter pod found on node {node_name}"
        data['pod'] = pod
        data['pod_restarts'] = pod.get_restarts_number()

    show_step(3)
    for node_name, data in nodes_data.items():
        LOG.info(f"{expected_runtime} service on node {node_name} will be restarted "
                 f"(vRouter pod: {data['pod'].name})")
        data['ssh'].check_call(f"sudo systemctl restart {expected_runtime}")

    show_step(4)
    # After restart of containerd node for a short time become 'NotReady' and
    # resource version is changed. Waiting for that change first presumably
    # keeps the 'Ready' check below from passing on the pre-restart state.
    # The lambdas below are evaluated synchronously inside the same loop
    # iteration, so closing over the loop variables is safe here.
    for node_name, data in nodes_data.items():
        node = data['node']
        init_version = data['resource_version']
        waiters.wait(
            lambda: node.data['metadata']['resource_version'] != init_version, timeout=210
        )
        LOG.info(f"Node {node_name} resource version was changed from {init_version} to "
                 f"{node.data['metadata']['resource_version']}")

    for node_name, data in nodes_data.items():
        LOG.info(f"Waiting {node_name} returns in 'Ready' state.")
        node = data['node']
        waiters.wait_pass(
            lambda: node.check_status(expected_status='True'), timeout=120
        )

        tf_manager.wait_all_pods_on_node(node_name)
    tf_manager.wait_tf_controllers_healthy()
    os_manager.wait_openstackdeployment_health_status(timeout=300)

    show_step(5)
    errors = []
    for node_name, data in nodes_data.items():
        pod = data['pod']
        before = data['pod_restarts']
        pids_before = data['pids']
        after = pod.get_restarts_number()
        pids_after = get_process_pids(data['ssh'], process_name)

        LOG.info(f"Number of restarts for pod {pod.name} "
                 f"(node: {node_name}) (before/after): {before}/{after}")
        LOG.info(f"Pids of {process_name} process (node {node_name}) "
                 f"(before/after): {pids_before}/{pids_after}")
        if before != after:
            errors.append(
                f"Node {node_name}: unexpected number of restarts for pod {pod.name}: "
                f"expected {before}, actual: {after}. Restart wasn't expected.\n")

        # A changed PID set means the agent process was restarted even if the
        # pod restart counter did not move; report the PIDs, not the counters.
        if pids_before != pids_after:
            errors.append(
                f"Node {node_name}: unexpected pids of {process_name}: "
                f"expected {pids_before}, actual: {pids_after}. "
                f"Restart wasn't expected.\n")
    errors_msg = ''.join(errors)
    assert not errors_msg, (f"vRouters shouldn't restarted:\n"
                            f"{errors_msg}")
