import exec_helpers
import yaml
import time

from si_tests import logger
from si_tests import settings
from si_tests.utils import utils, waiters
from si_tests.fixtures.kubectl import os_manager # noqa
from si_tests.fixtures.openstack import openstack_client_manager # noqa
from si_tests.fixtures.mos import mos_per_node_workload # noqa

LOG = logger.logger


def ssh_node(hostname):
    """Return a sudo-enabled SSH client for a node listed in NODES_INFO.

    The YAML file at ``settings.NODES_INFO`` is read and the entry keyed by
    ``hostname`` is used; this function reads its ``ip.address``,
    ``ssh.username`` and ``ssh.private_key`` fields.

    :param hostname: key of the node entry in the NODES_INFO file
    :returns: exec_helpers.SSHClient to ``<ip.address>:22`` authenticated with
        the node's RSA key, with ``sudo_mode`` enabled
    :raises AssertionError: if the node entry or one of the required fields
        is missing
    """
    with open(settings.NODES_INFO) as f:
        nodes_info = yaml.safe_load(f) or {}

    # Use tolerant lookups so a missing host/section produces a readable
    # assertion message instead of "'NoneType' object is not subscriptable".
    node = nodes_info.get(hostname) or {}
    assert node, f"Node {hostname} is not defined in {settings.NODES_INFO}"
    ssh_info = node.get('ssh') or {}
    ip_info = node.get('ip') or {}

    private_key = ssh_info.get('private_key')
    assert private_key, "Private key is not defined"
    username = ssh_info.get('username')
    assert username, f"SSH username is not defined for node {hostname}"
    node_ip = ip_info.get('address')
    assert node_ip, f"IP address is not defined for node {hostname}"

    pkey = utils.get_rsa_key(private_key)
    auth = exec_helpers.SSHAuth(username=username, password='', key=pkey)
    ssh = exec_helpers.SSHClient(host=node_ip, port=22, auth=auth)
    ssh.logger.addHandler(logger.console)
    ssh.sudo_mode = True
    return ssh


def test_node_reboot_compute(mos_per_node_workload, os_manager): # noqa
    """Test a compute node reboot with VMs on it
    1) Create VM on the compute
    2) Reboot node gracefully
    3) Wait compute restarted
    4) Ensure VM is started after reboot
    5) Check that VM was gracefully rebooted

    Only the first stack in ``mos_per_node_workload.stacks_data`` is used.
    The VMs are considered back once TCP port 22 on each floating IP accepts
    connections. Each VM's reboot counter must then have grown by exactly one.
    """
    assert mos_per_node_workload.stacks_data, "No per-node workload stacks found"
    stack_name = next(iter(mos_per_node_workload.stacks_data))
    servers_to_verify = mos_per_node_workload.get_stack_servers_to_verify(stack_name)
    servers_reboot_before = mos_per_node_workload.get_servers_reboot_count(stack_name)

    # Reboot
    # NOTE(review): assumes the stack name is the compute hostname with a
    # 'pernode-test' marker added — confirm against the workload fixture.
    compute_hostname = stack_name.replace('pernode-test', '')
    LOG.info("Rebooting host %s", compute_hostname)
    ssh = ssh_node(compute_hostname)
    try:
        ssh.check_call("systemctl reboot")
    finally:
        # The remote side is going down anyway; release the local session.
        ssh.close()
    # Fixed grace period, presumably so the node has actually started going
    # down before the VMs are polled below.
    time.sleep(60)

    for server in servers_to_verify.values():
        fip = server["fip"]
        waiters.wait_tcp(fip, 22, timeout=900, timeout_msg=f"Waiting {fip} port 22 timed out")

    servers_reboot_after = mos_per_node_workload.get_servers_reboot_count(stack_name)
    servers_reboot_expected = {
        server: count + 1 for server, count in servers_reboot_before.items()
    }
    assert servers_reboot_after == servers_reboot_expected, (
        "Reboot count of VMs does not match expected value. "
        f"Expected: {servers_reboot_expected}, actual: {servers_reboot_after}")
    os_manager.wait_openstackdeployment_health_status(timeout=900)
