#    Copyright 2024 Mirantis, Inc.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

import pytest
import yaml

from si_tests import logger, settings
from si_tests.deployments.mcc_deploy.restart_daemonset import RestartChecker
from si_tests.utils import update_release_names

# Module-level logger shared by the fixture(s) in this file.
LOG = logger.logger


def _collect_containerd_restarts(ts_before, ts_after, containerd_nodes_names):
    """Return the daemonset timestamp changes observed on containerd nodes.

    Used for mixed-runtime clusters: a restart is expected on docker nodes
    (only logged as a warning) but must not happen on containerd nodes.

    :param ts_before: mapping node name -> timestamp captured before the update
    :param ts_after: mapping node name -> timestamp captured after the update
    :param containerd_nodes_names: iterable of k8s node names running containerd
    :return: list of ``{node: {'before': old, 'after': new}}`` entries, one per
        containerd node whose timestamp changed; ``new`` is ``None`` when the
        node reported no timestamp after the update
    """
    containerd_nodes = set(containerd_nodes_names)
    restarted = []
    for node, old_stamp in ts_before.items():
        # .get(): a node missing after the update is reported as a change
        # instead of aborting the check with a bare KeyError.
        new_stamp = ts_after.get(node)
        if new_stamp == old_stamp:
            continue
        if node in containerd_nodes:
            LOG.error(f"Found restart on node {node} with containerd runtime")
            restarted.append({node: {'before': old_stamp, 'after': new_stamp}})
        else:
            # For docker restart is expected
            LOG.warning(f"Found restart on node {node} with docker runtime.")
    return restarted


@pytest.fixture(scope='function')
def runtime_restart_checker(target_cluster):
    """Create simple daemonset workload and hold it during cluster upgrade. Check logs after.

    Setup deploys a ``RestartChecker`` daemonset and saves its timestamps.
    Teardown reads the timestamps again and fails when the workload was
    restarted where that is not allowed:

    * all-containerd cluster: no timestamp may change;
    * mixed-runtime cluster: only nodes of containerd machines are checked,
      changes on the other nodes are logged as expected docker restarts.

    The daemonset is removed in teardown even when the check fails.

    The check is skipped (the fixture only yields) when
    ``KAAS_RESTART_CHECKER_ENABLED`` is not set, no target release is passed
    or detected, maintenance is not skipped for the release pair, the
    release is MKE and the update requires reboot, or every machine uses
    docker.

    :param target_cluster: cluster object that is being upgraded
    """
    if not settings.KAAS_RESTART_CHECKER_ENABLED:
        LOG.info('MCC restart checker not enabled. Daemonset will not be deployed. '
                 'To enable - set KAAS_RESTART_CHECKER_ENABLED to True.')
        yield
        return

    cl_version = target_cluster.clusterrelease_version
    target_version = settings.KAAS_CHILD_CLUSTER_UPDATE_RELEASE_NAME

    if not target_version:
        LOG.info("Target version was not passed. Will determine automatically")
        update_releases = list(update_release_names.generate_update_release_names())
        if not update_releases:
            LOG.warning("Target version was not found. Skipping check")
            yield
            return
        # The last generated update release name is used as the target
        target_version = update_releases[-1]
        LOG.info(f"Detected target version: {target_version}")

    is_maintenance_skip = target_cluster.is_skip_maintenance_set(
        cr_before=cl_version,
        target_clusterrelease=target_version)

    if not is_maintenance_skip:
        LOG.banner(f'Skip pod restart check as '
                   f'is_maintenance_skip is {is_maintenance_skip} for pair {cl_version}->{target_version}')
        yield
        return

    reboot_required = target_cluster.update_requires_reboot(
        cr_before=cl_version, target_clusterrelease=target_version)

    if cl_version.startswith('mke-') and reboot_required:
        LOG.info('Cluster release is MKE and update requires reboot. Skipping check.')
        yield
        return

    machines_runtimes = target_cluster.get_runtime_dict()

    if all(runtime == 'docker' for runtime in machines_runtimes.values()):
        LOG.banner('All machines have docker runtime. Skipping runtime restart checks')
        yield
        return

    # All-docker clusters have already returned above, so any cluster that is
    # not fully containerd is a mixed-runtime one.
    is_mixed = not all(runtime == 'containerd' for runtime in machines_runtimes.values())
    containerd_nodes_names = []
    if is_mixed:
        LOG.warning("Cluster contains mixed runtime")
        LOG.info(f'Found next runtimes for machines: \n{yaml.dump(machines_runtimes)}')
        for machine, runtime in machines_runtimes.items():
            if runtime == 'containerd':
                k8s_node_name = target_cluster.get_machine(machine).get_k8s_node().name
                containerd_nodes_names.append(k8s_node_name)

    checker = RestartChecker(target_cluster)
    LOG.banner('Restart DS: Deploying checker daemonset')
    checker.deploy()
    checker.save_timestamps()
    ts_before = checker.timestamps
    LOG.banner('Restart DS: DS ready')

    yield

    try:
        LOG.banner('Restart DS: Check phase')
        LOG.info('Check daemonset timestamp logs to detect restarts')
        new_stamps = checker.get_timestamps()

        err_msg = (f"Runtime restart found. DS Timestamps before update and after are different. \n"
                   f"Old raw timestamps: {ts_before} \nNew raw timestamps: {new_stamps}")
        if is_mixed:
            restarted_ds = _collect_containerd_restarts(
                ts_before, new_stamps, containerd_nodes_names)
            assert not restarted_ds, f"{err_msg} \nRestarted containerd nodes: {restarted_ds}"
        else:
            # Full containerd cluster. All timestamps should not be changed
            assert ts_before == new_stamps, err_msg
        LOG.banner('Restart DS: No restarts found. Check finished')
    finally:
        # Always remove the checker daemonset, even if the check failed,
        # so it does not leak into the cluster used by later tests.
        LOG.banner('Restart DS: Cleanup phase')
        checker.cleanup()
