from dataclasses import dataclass
from typing import Optional, Set

from tabulate import tabulate

from si_tests import logger
from si_tests import settings
from si_tests.clients.k8s.pods import K8sPod
from si_tests.managers.kaas_manager import Manager, Machine
from si_tests.utils.utils import NodeLabel

# Module-level logger used by the helpers in this file
LOG = logger.logger


@dataclass()
class MachineCephRoleData:
    """Ceph pod found on a machine together with its role and container id."""

    # Machine whose instance runs the pod
    machine: Machine
    # Name of the k8s node the pod is scheduled on
    k8s_node_name: str
    # Role key, one of "mgr", "mon", "osd", "csi-cephfsp", "csi-rbd"
    role: str
    # The Ceph pod itself
    pod: K8sPod
    # Container id with the "<prefix>//" part of the container status id stripped
    container_id: str


@dataclass
class MachineLabelData:
    """Labels collected for a single machine of the target cluster."""

    # Machine object
    machine: Machine
    # Name of the k8s node backing the machine
    node_name: str
    # Machine name (machine.name)
    machine_name: str
    # NodeLabel members detected for the machine, or None when nothing matched
    all_labels: Optional[Set[NodeLabel]]
    # Merged machine + k8s node labels as "key=value" strings
    all_labels_raw: list


@dataclass()
class TargetMachineData:
    """Machine picked for an HA scenario plus the expected labels it carries."""

    # Expected NodeLabel members (from HA_NODE_REBOOT_LABELS) present on the machine
    labels: list
    # Collected label data of the picked machine
    machine: MachineLabelData


@dataclass()
class OutageMachineData:
    """Basic machine description: object, IP address, k8s node name and name.

    NOTE(review): not referenced in this part of the file; presumably consumed
    by outage scenarios elsewhere.
    """

    # Machine object
    machine: Machine
    # IP address of the machine
    ip: str
    # Name of the k8s node backing the machine
    k8s_node_name: str
    # Machine name
    name: str


def _get_ceph_osd_machine_names(cluster):
    """
    Return the set of machine names listed in the Ceph OSD device mapping.

    The mapping is read from MiraCephHealth status when
    ``cluster.workaround.skip_kaascephcluster_usage()`` is true, otherwise from
    KaaSCephCluster status. A missing or null level of the nested status
    yields an empty set instead of raising AttributeError.
    """
    if cluster.workaround.skip_kaascephcluster_usage():
        ceph_object = cluster.get_miracephhealth()
        full_status_key = 'fullClusterStatus'
    else:
        ceph_object = cluster.get_kaascephcluster()
        full_status_key = 'fullClusterInfo'

    node = ceph_object.data.get('status') or {}
    for key in (full_status_key, 'cephDetails', 'deviceMapping'):
        # `or {}` also covers keys that are present but set to null
        node = node.get(key) or {}
    return set(node)


def collect_machines_labels():
    """
    Collect all labels from machines and k8s nodes

    Machine metadata labels and the labels of the machine's k8s node are merged
    into ``all_labels_raw`` ("key=value" strings) and matched against NodeLabel
    values to build ``all_labels``. Additional labels are derived from cluster
    state:
      - kaas_worker for every machine without kaas_master;
      - ceph_osd for machines present in the Ceph OSD device mapping;
      - kaas_keepalive for the machine returned by
        ``cluster.get_keepalive_master_machine()``.

    Returns: list of MachineLabelData objects, one per cluster machine
    """
    result = []
    kaas_manager = Manager(kubeconfig=settings.KUBECONFIG_PATH)
    managed_ns = kaas_manager.get_namespace(settings.TARGET_NAMESPACE)
    cluster = managed_ns.get_cluster(settings.TARGET_CLUSTER)

    # find ceph_osd machines (set for O(1) membership checks below)
    ceph_osd_machines = _get_ceph_osd_machine_names(cluster)
    # find keepalive machine
    keepalive_machine = cluster.get_keepalive_master_machine()

    for machine in cluster.get_machines():
        machine_labels = machine.metadata['labels']
        k8s_node = machine.get_k8s_node().read()
        k8s_node_name = machine.get_k8s_node_name()
        k8s_node_labels = k8s_node.metadata.labels
        # k8s node labels override machine labels with the same key
        all_labels_raw = [f"{k}={v}" for k, v in {**machine_labels, **k8s_node_labels}.items()]
        # NOTE(review): substring match against all "key=value" pairs joined
        # with ';' - a NodeLabel value contained in an unrelated label would
        # also match; confirm NodeLabel values are specific enough.
        joined_labels = ";".join(all_labels_raw)
        all_labels = {item for item in NodeLabel if item.value in joined_labels}
        if NodeLabel.kaas_master not in all_labels:
            # Workaround for missing labels for node role on worker nodes
            all_labels.add(NodeLabel.kaas_worker)
        if machine.name in ceph_osd_machines:
            # ceph_osd label comes from the Ceph device mapping, not node labels
            all_labels.add(NodeLabel.ceph_osd)
        if machine.name == keepalive_machine.name:
            all_labels.add(NodeLabel.kaas_keepalive)
        result.append(MachineLabelData(machine=machine,
                                       machine_name=machine.name,
                                       node_name=k8s_node_name,
                                       all_labels=all_labels or None,
                                       all_labels_raw=all_labels_raw))
    return result


def collect_machine_sequence():
    """
    Generate machine sequences by expected HA_NODE_REBOOT_LABELS

    Each name from settings.HA_NODE_REBOOT_LABELS is resolved with
    NodeLabel.get_by_name(). Every machine carrying at least one of the
    resolved labels is returned together with the matching labels, kept in
    the order they appear in the setting.

    Returns: list of TargetMachineData objects; empty list when
        HA_NODE_REBOOT_LABELS is not set

    Raises:
        AssertionError: if some label names cannot be resolved
        Exception: if no machine carries one of the expected labels
    """
    if not settings.HA_NODE_REBOOT_LABELS:
        LOG.error("No HA node labels provided.")
        return []
    actual_labels = collect_machines_labels()

    def _labels_of(machine_data):
        # all_labels is Optional in MachineLabelData; treat None as "no labels"
        return machine_data.all_labels or set()

    expected_labels = []
    wrong_labels = []
    for label in settings.HA_NODE_REBOOT_LABELS:
        node_label = NodeLabel.get_by_name(label)
        if node_label:
            expected_labels.append(node_label)
        else:
            wrong_labels.append(label)
    assert not wrong_labels, f"Can't parse next labels from HA_NODE_REBOOT_LABELS: {wrong_labels}"

    # check if expected labels exist in target machines
    for item in expected_labels:
        if not any(item in _labels_of(x) for x in actual_labels):
            # raise exception if no machine found with expected label
            raise Exception(f"Cannot find any node with label: '{item.name}'")

    result = []
    for item in actual_labels:
        labels = [i for i in expected_labels if i in _labels_of(item)]
        if labels:
            result.append(TargetMachineData(labels=labels, machine=item))

    # Sorted so the logged table does not depend on set iteration order
    res_table = {item.machine_name: sorted(label.name for label in _labels_of(item)) for item in actual_labels}
    LOG.info("Expected labels: {}".format(", ".join([label.name for label in expected_labels])))
    LOG.info('Actual node labels:\n' + tabulate(res_table.items(), headers=["Node", "Labels"], tablefmt="presto"))
    return result


def collect_ceph_role_machines():
    """
    Collect Ceph role for machine

    Lists pods in settings.ROOK_CEPH_NS of the target cluster and keeps the
    Running ones. Pods are classified by name into the roles mgr, mon, osd,
    csi-cephfsp and csi-rbd; CSI provisioner pods are excluded. For each pod,
    the first container whose name contains the role string is taken, and the
    part of its container id after "//" is recorded.

    Pods whose node matches no machine ``status.instanceName``, or that have no
    matching container with an id, are skipped with a warning.

    Returns: list of MachineCephRoleData, grouped by role in the order above
    """
    kaas_manager = Manager(kubeconfig=settings.KUBECONFIG_PATH)
    managed_ns = kaas_manager.get_namespace(settings.TARGET_NAMESPACE)
    child_cluster = managed_ns.get_cluster(settings.TARGET_CLUSTER)

    ceph_pods = child_cluster.k8sclient.pods.list(namespace=settings.ROOK_CEPH_NS)
    machine_mapping = {(m.data.get('status') or {}).get('instanceName'): m for m in child_cluster.get_machines()}
    running_pods = [pod for pod in ceph_pods if (pod.data.get('status') or {}).get('phase') == 'Running']

    # Role key -> predicate on the pod name. The role key is also matched
    # against container names inside the pod, so it must be a substring of the
    # role's container name.
    role_matchers = {
        "mgr": lambda name: 'mgr' in name,
        "mon": lambda name: 'mon' in name,
        "osd": lambda name: 'osd' in name,
        "csi-cephfsp": lambda name: 'csi-cephfsplugin' in name and 'provisioner' not in name,
        "csi-rbd": lambda name: 'csi-rbdplugin' in name and 'provisioner' not in name,
    }

    result = []
    for role, matches_role in role_matchers.items():
        for pod in running_pods:
            if not matches_role(pod.name):
                continue
            machine = machine_mapping.get(pod.node_name)
            if machine is None:
                LOG.warning(f"Skip pod {pod.name} ({role}): no machine found for node {pod.node_name}")
                continue
            container_statuses = (pod.data.get('status') or {}).get('container_statuses') or []
            container = next((c for c in container_statuses if role in (c.get('name') or '')), None)
            raw_container_id = (container or {}).get('container_id') or ''
            if "//" not in raw_container_id:
                LOG.warning(f"Skip pod {pod.name} ({role}): container matching '{role}' or its id not found")
                continue
            result.append(MachineCephRoleData(
                machine=machine,
                k8s_node_name=pod.node_name,
                role=role,
                pod=pod,
                container_id=raw_container_id.split("//", 1)[1],
            ))
    return result


def id_label_node(param):
    """
    Build a readable test id for a parametrized value.

    Objects exposing ``labels`` (items with ``.name``) and
    ``machine.machine_name`` become "LABEL=<names joined by ';'>;NODE=<machine>";
    anything else gets an empty id.
    """
    if not hasattr(param, 'labels'):
        return ""
    label_names = ";".join(label.name for label in param.labels)
    return f'LABEL={label_names};NODE={param.machine.machine_name}'


def collect_pods_restart_count(pods):
    """
    Map each pod name to its current restart number.

    Returns: dict {pod metadata name: pod.get_restarts_number()}
    """
    return {pod.data['metadata']['name']: pod.get_restarts_number() for pod in pods}


def compare_restarts_number(pods, restarts):
    """
    Check that pods' restart numbers did not change since ``restarts`` was taken.

    pods: pods to check; each must provide ``data['metadata']['name']`` and
        ``get_restarts_number()``
    restarts: dict {pod name: restart number}, e.g. from
        collect_pods_restart_count()

    Raises: AssertionError listing every pod missing from ``restarts`` or
        whose restart number differs from the recorded one
    """
    errors = []
    for pod in pods:
        name = pod.data['metadata']['name']
        # Check pod data is present in dict
        initial_restart_number = restarts.get(name)
        if initial_restart_number is None:
            errors.append(f"Restart number for pod {name} wasn't found.")
            continue
        # Compare restart number
        restart_number = pod.get_restarts_number()
        if initial_restart_number == restart_number:
            LOG.info(f"Compare restarts for pod {name}: {initial_restart_number} "
                     f"-> {restart_number}")
            continue
        LOG.warning(f"Compare restarts for pod {name}: {initial_restart_number} "
                    f"-> {restart_number}")
        # A lower value means the counter was reset (presumably the pod was
        # recreated under the same name); don't report that as an increase.
        change = "increased" if restart_number > initial_restart_number else "decreased"
        errors.append(f"Restarts number for pod {name} was {change}: from {initial_restart_number} to "
                      f"{restart_number}")
    assert not errors, "\n".join(errors)
