import time
import yaml

from si_tests import logger
from si_tests.utils import waiters
from si_tests.utils import exceptions as timeout_exceptions
from si_tests.utils import packaging_version as version

LOG = logger.logger


def _count_kernel_io_errors(machine, verbose=False):
    """Return how many kernel journal lines on ``machine`` contain "I/O error"."""
    return int(machine.run_cmd('sudo journalctl -k | grep "I/O error" | wc -l',
                               verbose=verbose).stdout_str)


def _produce_io_errors(machine, verbose=False):
    """Generate kernel I/O errors on ``machine`` and return the resulting kernel I/O error count.

    The steps are:

    1. A 50 MiB zero-filled file (``/tmp/fake_device.img``) is attached to
       ``/dev/loop<highest attached index + 1>``.
    2. A device-mapper device ``error_device`` is stacked on top of it. Its
       first 2048 sectors map linearly onto the loop device, and its next
       2048 sectors use the ``error`` target.
    3. More data than the linear part holds is written to the device, so the
       kernel logs I/O errors.

    The mapping, the loop device and the temporary files are removed even if
    one of the steps fails.

    :param machine: machine object providing ``run_cmd``. ``name``,
        ``public_ip``, ``internal_ip`` and ``get_k8s_node_name`` are used
        for logging.
    :param verbose: passed through to every ``run_cmd`` call
    :return: number of "I/O error" lines in ``journalctl -k`` after generation
    :raises Exception: if the write did not fail with "Input/output error"
        (the exception text is the command's stderr)
    :raises AssertionError: if the kernel log has no more I/O errors than before
    """
    machine_name = machine.name
    k8s_node_name = machine.get_k8s_node_name()
    ip = machine.public_ip if machine.public_ip else machine.internal_ip
    LOG.info(f"Using machine {machine_name} (k8s_node_name: {k8s_node_name}, ip: {ip}) for generation I/O errors")

    # Calculate I/O errors before producing new errors
    LOG.info("Collect total i/o errors before generation new")
    dmesg_io_before = _count_kernel_io_errors(machine, verbose=verbose)
    LOG.info(f"Total I/O errors in dmesg on machine {machine.name} is: {dmesg_io_before}")

    # Creating fake block device
    LOG.info("Creating fake block device")
    machine.run_cmd('sudo dd if=/dev/zero of=/tmp/fake_device.img bs=1M count=50', verbose=verbose)
    # One attached loop device path (e.g. /dev/loop3) per output line; pick the next unused number
    losetup = machine.run_cmd('sudo losetup -a | cut -d ":" -f1', verbose=verbose).stdout_str
    indexes = [int(line.split('loop')[1]) for line in losetup.splitlines() if line.strip()]
    new_index = max(indexes, default=-1) + 1
    loop_device = f'/dev/loop{new_index}'

    # Setup device
    LOG.info(f"Setup block device as {loop_device} using /tmp/fake_device.img")
    machine.run_cmd(f'sudo losetup {loop_device} /tmp/fake_device.img', verbose=verbose)
    try:
        # Create error mapping table: sectors 0-2047 pass through to the loop device,
        # sectors 2048-4095 fail every request
        LOG.info("Create and setup error_device with error mapping table ")
        machine.run_cmd(f'echo "0 2048 linear {loop_device} 0\n2048 2048 error" > /tmp/error_table.txt',
                        verbose=verbose)
        machine.run_cmd('sudo dmsetup create error_device --table "$(cat /tmp/error_table.txt)"', verbose=verbose)
        try:
            # Write to device to get I/O errors in dmesg: 8192 * 512 B exceeds the 2048-sector
            # linear part, so the write must reach the error target
            LOG.info("Writing data to error_device to get I/O errors")
            res = machine.run_cmd('sudo dd if=/dev/zero of=/dev/mapper/error_device bs=512 count=8192',
                                  check_exit_code=False, verbose=verbose).stderr_str
        finally:
            # error_device is stacked on the loop device, so it is torn down first
            LOG.info(f"Cleanup machine {machine.name} from error device")
            machine.run_cmd('sudo dmsetup remove error_device', verbose=verbose)
    finally:
        machine.run_cmd(f'sudo losetup -d {loop_device}', verbose=verbose)
        machine.run_cmd('sudo rm -f /tmp/fake_device.img /tmp/error_table.txt', verbose=verbose)

    # Check that cmd failed exactly due to i/o error
    if 'Input/output error' not in res:
        raise Exception(res)

    # Calculate I/O errors after producing new
    LOG.info("Collect total i/o errors after generation")
    dmesg_io_after = _count_kernel_io_errors(machine, verbose=verbose)
    LOG.info(f"Total I/O errors in dmesg after generation on machine {machine.name} is: {dmesg_io_after}")

    # Check that there are more errors than before we generate
    assert dmesg_io_after > dmesg_io_before, "No new I/O errors in kernel. No reason to continue test"

    return dmesg_io_after


def _get_prometheus_response_for_node(cluster, node_name, query='kernel_io_errors_total'):
    """Fetch the value of a Prometheus query for a single k8s node.

    The query goes through ``cluster.prometheusclient.get_query``. The result
    is filtered to samples whose ``node`` metric label equals ``node_name``.
    Only the first matching sample is used.

    :param cluster: cluster object exposing ``prometheusclient``, ``name`` and
        ``clusterrelease_version``
    :param node_name: k8s node name to match against the ``node`` label
    :param query: query string, ``kernel_io_errors_total`` by default
    :return: the sample value converted to ``int``
    :raises AssertionError: if no sample carries the requested node label
    """
    node_samples = [
        sample
        for sample in cluster.prometheusclient.get_query(query=query)
        if sample.get('metric').get('node') == node_name
    ]
    assert node_samples, (f"No response from prometheus for node {node_name} in cluster {cluster.name}. Looks like"
                          f" cluster version {cluster.clusterrelease_version} is not supporting {query} metric")

    # 'value' is indexed as [timestamp, value]
    first_value = node_samples[0]['value']
    error_cnt = int(first_value[1])
    collected_at = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(float(first_value[0])))
    LOG.info(f"Value for {query} from prometheus is: {error_cnt}. Collected at: {collected_at}")
    return error_cnt


def _child_supports_io_errors_metric(cl_version, min_mke='mke-16-2-0-3-7-8', min_mos='mosk-17-2-0-24-2'):
    """Return True if a child cluster release is recent enough to be checked by this test.

    Only ``mke*`` and ``mos*`` releases at or above the given minimal versions
    qualify. Any other release prefix is skipped.
    """
    if cl_version.startswith('mke'):
        return version.parse(cl_version) >= version.parse(min_mke)
    if cl_version.startswith('mos'):
        return version.parse(cl_version) >= version.parse(min_mos)
    return False


def _wait_prometheus_io_errors(cluster, node_name, expected_errors, err_data):
    """Wait until Prometheus reports ``expected_errors`` for ``node_name``; record a mismatch otherwise.

    The wait does not raise on timeout. Instead, the mismatch message is
    stored in ``err_data[cluster.name][node_name]`` so the caller can report
    all failures at once.
    """
    try:
        # Prometheus requests data from fluentd every 30 sec. So will wait 1 min to be sure
        LOG.info("Waiting for actual errors number is equal errors number from prometheus")
        waiters.wait(lambda: _get_prometheus_response_for_node(cluster, node_name) == expected_errors,
                     timeout=60, interval=10)
    except timeout_exceptions.TimeoutError:
        error_cnt_from_prometheus = _get_prometheus_response_for_node(cluster, node_name)
        msg = (f"Errors count from prometheus is not equal actual errors count from kernel. "
               f"Actual: {expected_errors}, Prometheus: {error_cnt_from_prometheus}")
        LOG.error(msg)
        err_data.setdefault(cluster.name, {})[node_name] = msg


def test_io_errors_monitoring(kaas_manager, show_step):
    """Test for monitoring i/o errors in stacklight

    Mismatches from the mgmt cluster and every checked child cluster are
    collected first. A single AssertionError listing all of them is raised
    at the end.

    Scenario:
            1. Check mgmt cluster
            2. Get mgmt machine
            3. Setup fake device
            4. Setup errors table
            5. Write to fake device to get errors
            6. Compare errors number from prometheus with number from kernel
            7. Check child clusters
            8. Get one control and one worker from every existing child
            9. Repeat steps 2-6 for every taken machine
    """
    mgmt_cluster = kaas_manager.get_mgmt_cluster()
    mgmt_cluster_name = mgmt_cluster.name

    # Get 1 mgmt machine for test
    mgmt_machine = mgmt_cluster.get_machines()[0]
    mgmt_node_name = mgmt_machine.get_k8s_node_name()

    # {cluster_name: {k8s_node_name: mismatch message}}
    err_data = {}
    show_step(1)
    mgmt_node_io_err = _produce_io_errors(mgmt_machine)
    _wait_prometheus_io_errors(mgmt_cluster, mgmt_node_name, mgmt_node_io_err, err_data)
    LOG.info(f"========== Finished for cluster {mgmt_cluster_name} ==========\n")

    # Get all child clusters
    child_clusters = kaas_manager.get_child_clusters()
    if not child_clusters:
        LOG.info("No child clusters existed.")
    for cl in child_clusters:
        cl_version = cl.clusterrelease_version
        if not _child_supports_io_errors_metric(cl_version):
            LOG.info(f"Skipping cluster {cl.name} with release {cl_version}: "
                     f"release is not covered by the I/O errors monitoring check")
            continue

        show_step(7)
        cl_name = cl.name
        # Get 1 worker and 1 control machines
        worker = cl.get_machines(machine_type='worker')[0]
        control = cl.get_machines(machine_type='control')[0]

        for child_machine in (worker, control):
            child_k8s_node_name = child_machine.get_k8s_node_name()
            io_err = _produce_io_errors(child_machine)
            _wait_prometheus_io_errors(cl, child_k8s_node_name, io_err, err_data)
            LOG.info(f"========== Finished for machine {child_machine.name} ==========\n")
        LOG.info(f"========== Finished for cluster {cl_name} ==========\n")

    # Raise only after every cluster was checked. Mgmt-only environments still fail on
    # mismatches, and one bad child does not hide the others.
    if err_data:
        raise AssertionError(f"\n{yaml.dump(err_data)}")
