import yaml
import pytest

import numpy as np
from dateutil import parser
from datetime import datetime
from dataclasses import dataclass
from operator import attrgetter
from si_tests import logger, settings

LOG = logger.logger


@dataclass
class MachineData:
    """Upgrade facts collected for a single machine.

    Ordering (``<``) looks at ``upgrade_index`` alone, so ``sorted()`` over a list of
    machines yields upgradeIndex order. Equality is the dataclass default and compares
    every field.
    """

    # upgradeIndex value reported for the machine
    upgrade_index: int
    # Name of the updateGroup the machine belongs to
    update_group: str
    # Index of that updateGroup
    update_group_index: int
    # Concurrency value of that updateGroup
    update_group_concurrency: int
    # Start/finish of the 'deploy' phase
    started_at: datetime
    finished_at: datetime
    # Start/finish of the 'reconfigure' phase
    reconfigure_started_at: datetime
    reconfigure_finished_at: datetime
    # Machine name (key in the source data)
    machine_name: str

    def __lt__(self, other):
        # Only the upgrade index participates in ordering.
        mine, theirs = self.upgrade_index, other.upgrade_index
        return mine < theirs


def _parse_machines(after_test):
    """Build MachineData records from the ``after_test`` section of the timestamps file.

    :param after_test: mapping of machine name -> machine record. Each record is read for
        ``machine_type``, ``upgradeIndex``, ``updateGroup`` (``Name``/``Index``/``Concurrency``)
        and ``phases`` (``deploy``/``reconfigure``, each with ``startedAt``/``finishedAt``).
    :returns: tuple ``(control_machines, worker_machines)``; ``storage`` machines go to workers
    :raises NotImplementedError: for any other ``machine_type``
    """
    control_machines = []
    worker_machines = []
    for machine, d in after_test.items():
        m_type = d.get('machine_type')
        upgrade_index = d.get('upgradeIndex')
        update_group = d.get('updateGroup')
        phases = d.get('phases')
        deploy = phases.get('deploy')
        reconfigure = phases.get('reconfigure')
        started_at = parser.isoparse(deploy.get('startedAt'))
        finished_at = parser.isoparse(deploy.get('finishedAt'))
        reconfigure_started = parser.isoparse(reconfigure.get('startedAt'))
        reconfigure_finished = parser.isoparse(reconfigure.get('finishedAt'))

        LOG.info(f"Machine {machine} type is {m_type} with upgrade index {upgrade_index} "
                 f"started to upgrade at: {started_at}")
        m_data = MachineData(
            upgrade_index=upgrade_index,
            update_group=update_group.get('Name'),
            update_group_index=update_group.get('Index'),
            update_group_concurrency=update_group.get('Concurrency'),
            started_at=started_at,
            finished_at=finished_at,
            reconfigure_started_at=reconfigure_started,
            reconfigure_finished_at=reconfigure_finished,
            machine_name=machine)
        if m_type == 'control':
            control_machines.append(m_data)
        elif m_type in ('worker', 'storage'):
            worker_machines.append(m_data)
        else:
            raise NotImplementedError(f"Unexpected machine type: {m_type}")
    return control_machines, worker_machines


def _join_names(machines):
    # Render machines as "name1 --> name2 --> ..." for log output.
    return " --> ".join(m.machine_name for m in machines)


def _split_into_batches(machines, batch_size):
    # Chunk machines into consecutive batches of batch_size. Each batch is re-sorted by
    # upgrade_index (MachineData.__lt__), so start order inside a batch does not matter.
    return [sorted(machines[i:i + batch_size]) for i in range(0, len(machines), batch_size)]


def _check_control_order(control_machines):
    """Compare controllers' start-time order with (updateGroup Index, upgradeIndex) order.

    :returns: list of failure messages (empty when the orders match)
    """
    by_time = sorted(control_machines, key=attrgetter('started_at'))
    by_index = sorted(control_machines, key=attrgetter('update_group_index', 'upgrade_index'))
    LOG.info(f"Expected sequence for control machines: {_join_names(by_index)}")
    LOG.info(f"Actual   sequence for control machines: {_join_names(by_time)}")
    if by_time != by_index:
        return ["Upgrade sequence for controllers is not as expected."]
    return []


def _check_worker_order_within_groups(worker_machines):
    """Within each worker updateGroup, compare start-time batches with upgradeIndex batches.

    Batch size is the Concurrency of the group, taken from the group's first member.

    :returns: list of failure messages
    """
    LOG.info(f"All updateGroup for workers: {np.unique([w.update_group for w in worker_machines])}")
    LOG.info(f"All updateGroup indexes: {np.unique([w.update_group_index for w in worker_machines])}")

    fails = []
    # dict.fromkeys keeps first-seen order, so log/fail output is deterministic
    # (iterating a set is not).
    for group_name in dict.fromkeys(w.update_group for w in worker_machines):
        members = [w for w in worker_machines if w.update_group == group_name]
        # NOTE(review): assumes Concurrency is a positive int; 0/None would break batching.
        parallel = members[0].update_group_concurrency

        expected_batches = _split_into_batches(sorted(members, key=attrgetter('upgrade_index')), parallel)
        actual_batches = _split_into_batches(sorted(members, key=attrgetter('started_at')), parallel)

        expected_seq = " --> ".join(str([m.machine_name for m in batch]) for batch in expected_batches)
        actual_seq = " --> ".join(str([m.machine_name for m in batch]) for batch in actual_batches)
        LOG.info(f"Concurrency is {parallel} for updateGroup {group_name}")
        LOG.info(f"Expected sequence for worker machines: {expected_seq}")
        LOG.info(f"Actual   sequence for worker machines: {actual_seq}")

        if actual_batches != expected_batches:
            fails.append(f"Upgrade sequence for workers is not as expected in UpdateGroup {group_name}")
    return fails


def _check_update_group_index_order(worker_machines):
    """Check that a worker updateGroup with a higher Index starts only after the previous one finishes.

    The first start (deploy startedAt) in each Index must be later than the last
    reconfigure finishedAt of the previous (lower) Index. The check only runs when
    workers span more than one updateGroup name and more than one Index.

    :returns: list of failure messages
    """
    # Count distinct group names directly; the original reused a name that the
    # per-group loop had rebound to a single group-name string.
    group_names = {w.update_group for w in worker_machines}
    group_indexes = sorted({w.update_group_index for w in worker_machines})
    if len(group_names) <= 1 or len(group_indexes) <= 1:
        LOG.info("Skip check for compare timestamps between different index of updateGroups. "
                 "Update groups for workers is only one with the same index")
        return []

    fails = []
    # None (not 0) marks "no previous group": an updateGroup Index of 0 is a real group.
    previous_index = None
    previous_last_finished = None
    for ugi in group_indexes:
        members = [w for w in worker_machines if w.update_group_index == ugi]
        first_started = min(w.started_at for w in members)
        if previous_index is not None:
            if not first_started > previous_last_finished:
                fails.append(f"First worker in group with {ugi} index started to upgrade earlier then "
                             f"last worker in group with {previous_index} index is finished.\n"
                             f"Last finished: {previous_last_finished}\n"
                             f"First started: {first_started}")
            else:
                LOG.info(f"Last worker in updateGroup with index {previous_index} finished upgrade earlier "
                         f"({previous_last_finished}) than first started upgrade worker "
                         f"({first_started}) in next updateGroup with index {ugi}, "
                         f"as expected")
        previous_index = ugi
        previous_last_finished = max(w.reconfigure_finished_at for w in members)
    return fails


def _check_workers_after_controllers(control_machines, worker_machines, dedicated_control_plane):
    """Check when the first worker started relative to controllers finishing reconfigure.

    Without a dedicated control plane, the first worker must start after ALL controllers
    finished reconfigure. With one, it must start after at least ONE controller finished.

    :returns: list of failure messages
    """
    if not control_machines or not worker_machines:
        # Nothing to compare; indexing [0]/[-1] here used to raise IndexError.
        LOG.info("Skip check for workers vs controllers timestamps: "
                 "no control or no worker machines in data")
        return []

    first_worker_started = min(w.started_at for w in worker_machines)
    if not dedicated_control_plane:
        last_ctrl_finished = max(c.reconfigure_finished_at for c in control_machines)
        if not first_worker_started > last_ctrl_finished:
            return [f"Dedicated controlplane is set to False. Workers should start to upgrade only "
                    f"after all of controllers are upgraded. But, first worker started to upgrade at "
                    f"{first_worker_started} and last controller finished at: {last_ctrl_finished}"]
    else:
        first_ctl_finished = min(c.reconfigure_finished_at for c in control_machines)
        if not first_worker_started > first_ctl_finished:
            return ["First worker started to upgrade earlier then first control is finished"]
    return []


def test_compare_indexes_timestamps():
    """
    Compare the actual upgrade order (from timestamps) with the expected order from
    UpdateGroup configuration (Index, Concurrency) and UpgradeIndex values:
    - controllers: start-time order vs (updateGroup Index, upgradeIndex) order
    - workers: within every UpdateGroup, start-time batches vs upgradeIndex batches
      (batch size = group Concurrency)
    - workers: groups with a higher Index start only after the lower-Index group finished
    - workers vs controllers, depending on dedicated_control_plane True/False

    All failures are collected and raised together as one AssertionError.
    """
    timestamps_file = settings.MACHINES_TIMESTAMPS_YAML_FILE
    if not timestamps_file:
        message = "No data was provided for testing. Skipping"
        LOG.warning(message)
        pytest.skip(message)
    with open(timestamps_file, 'r', encoding='utf-8') as f:
        data = yaml.safe_load(f)

    if not data['is_ansible_version_changed']:
        message = "Ansible version on machines is the same as before, nothing to check here. Skipping"
        LOG.warning(message)
        pytest.skip(message)

    control_machines, worker_machines = _parse_machines(data.get('after_test'))

    fail_messages = []
    fail_messages += _check_control_order(control_machines)
    fail_messages += _check_worker_order_within_groups(worker_machines)
    fail_messages += _check_update_group_index_order(worker_machines)
    fail_messages += _check_workers_after_controllers(
        control_machines, worker_machines, data.get('dedicated_control_plane', True))

    if fail_messages:
        for message in fail_messages:
            LOG.error(message)
        raise AssertionError(fail_messages)
