from si_tests import settings
from si_tests import logger
from si_tests.utils import waiters
from yaml import safe_dump
from os.path import join
from deepdiff import DeepDiff
from kubernetes.client.exceptions import ApiException as clientApiException

# Module logger from si_tests; every progress/wait message in this file goes through it.
LOG = logger.logger


def get_app_sources(app):
    """
    Build a sorted list of "|"-joined source descriptors of an Argo CD app.

    Each descriptor consists of whichever of repoURL, path, chart and targetRevision are
    present in the compared source (in that order), followed by the revision the source
    was synced to. Sources without a known synced revision are skipped.

    :param app: Argo CD Application as a dict (status.sync is required)
    :return: sorted list of source descriptors (possibly empty)
    """
    sync = app["status"]["sync"]
    # status.sync.comparedTo / revision(s) can be missing due to issues with sync; handle that
    # the same way for multi-source and single-source apps instead of raising KeyError.
    compared_to = sync.get("comparedTo") or dict()
    static_keys = ("repoURL", "path", "chart", "targetRevision")

    def _describe(source, revision):
        # Static fields present in the source, then the synced revision as the last field.
        return "|".join([source[key] for key in static_keys if key in source] + [revision])

    in_app_sources = compared_to.get("sources") or list()
    if in_app_sources:
        # Multi-source app: status.sync.revisions is indexed in parallel with comparedTo.sources
        # and may be shorter or empty; zip() skips sources that have no revision yet.
        in_app_revisions = sync.get("revisions") or list()
        sources = [_describe(source, revision)
                   for source, revision in zip(in_app_sources, in_app_revisions)]
    else:
        source = compared_to.get("source")
        revision = sync.get("revision")
        sources = [_describe(source, revision)] if source is not None and revision is not None else []
    return sorted(sources)


def get_app_workload_namespaces(app):
    """
    Collect the namespaces where an Argo CD app keeps its workload resources.

    Only resources of workload kinds (pods, their controllers, jobs and Zalando
    postgresql clusters) listed in status.resources are considered; resources
    without a namespace are ignored.

    :param app: Argo CD Application as a dict
    :return: sorted list of unique namespaces
    """
    workload_classes = {
        "v1/Pod",
        "apps/v1/ReplicaSet",
        "apps/v1/Deployment",
        "apps/v1/StatefulSet",
        "apps/v1/DaemonSet",
        "batch/v1/Job",
        "batch/v1/CronJob",
        "acid.zalan.do/v1/postgresql",
    }
    found = set()
    for res in app["status"].get("resources", dict()):
        # Core-group resources come without "group", giving e.g. "v1/Pod".
        res_class = "/".join(res[part] for part in ("group", "version", "kind") if part in res)
        if res_class not in workload_classes or "namespace" not in res:
            continue
        found.add(res["namespace"])
    return sorted(found)


def get_app_spec_images(app: dict):
    """
    Gets list of images from spec Argo CD apps.
    Does the same as get_box_images.py from box-cluster repo.
    Implemented primarily for box-cluster apps and unable to extract images from other apps.

    Images are taken from two places of every source (spec.sources and spec.source):
    helm parameters whose names contain "image.repository" / "image.tag", and
    helm.valuesObject["helm-common-chart"].image.{repository,tag}.
    The passed app is not modified.

    :param app: a dictionary representing an Argo CD app
    :return: sorted list of images
    """
    app_spec = app.get("spec") or dict()
    # Copy the list: appending spec.source to the app's own spec.sources list would
    # otherwise mutate the caller's data.
    sources: list = list(app_spec.get("sources") or list())
    single_source: dict = app_spec.get("source") or dict()
    if single_source:
        sources.append(single_source)
    images = set()
    source: dict
    for source in sources:
        repository = None
        tag = None
        helm = source.get("helm") or dict()
        for param in helm.get("parameters") or list():
            param_name = param.get("name") or ""
            # Values may be quoted in the manifest, hence strip('"').
            if "image.repository" in param_name:
                repository = (param.get("value") or "").strip('"')
            elif "image.tag" in param_name:
                tag = (param.get("value") or "").strip('"')
        if repository and tag:
            images.add(f"{repository}:{tag}")
        # Walk helm.valuesObject["helm-common-chart"].image, tolerating missing/null levels.
        values_obj_image = source
        for key in ("helm", "valuesObject", "helm-common-chart", "image"):
            values_obj_image = values_obj_image.get(key)
            if not isinstance(values_obj_image, dict):
                values_obj_image = dict()
                break
        repository, tag = values_obj_image.get("repository"), values_obj_image.get("tag")
        if repository and tag:
            images.add(f"{repository}:{tag}")
    return sorted(images)


def get_argo_apps_state(client, label_selector):
    """
    Snapshot the Argo CD applications matching a label selector.

    :param client: k8s client exposing ``argocd_applications.list_raw``
    :param label_selector: label selector passed to the list call ("" is used by the test to get all apps)
    :return: dict keyed by "<namespace>/<name>"; each value holds the app name and namespace,
        its sync and health statuses, the images found in its spec and in status.summary,
        its synced sources and the namespaces of its workloads
    """
    raw_apps = client.argocd_applications.list_raw(label_selector=label_selector).to_dict()['items']
    snapshot = {}
    for raw_app in raw_apps:
        metadata = raw_app["metadata"]
        app_key = f"{metadata['namespace']}/{metadata['name']}"
        app_status = raw_app["status"]
        snapshot[app_key] = {
            "name": metadata["name"],
            "namespace": metadata["namespace"],
            "sync_status": app_status["sync"]["status"],
            "health_status": app_status["health"]["status"],
            "spec_images": get_app_spec_images(raw_app),
            "summary_images": app_status.get("summary", dict()).get("images", list()),
            "sources": get_app_sources(raw_app),
            "workload_namespaces": get_app_workload_namespaces(raw_app),
        }
    return snapshot


def get_argo_apps_images(client, apps_namespaces):
    """
    Collect the images reported by pods running in the given namespaces.

    Both regular and init container statuses are inspected.

    :param client: k8s client exposing ``pods.list_raw()``
    :param apps_namespaces: namespaces whose pods are inspected (e.g. apps' workload namespaces)
    :return: sorted list of unique images
    """
    images = set()
    for pod in client.pods.list_raw().to_dict()["items"]:
        if pod["metadata"]["namespace"] not in apps_namespaces:
            continue
        status = pod.get("status") or dict()
        # Serialized pod models may carry None for unset lists (pods that have not started
        # yet, pods without init containers), hence "or []" instead of a .get() default.
        container_statuses = ((status.get("container_statuses") or [])
                              + (status.get("init_container_statuses") or []))
        for container in container_statuses:
            images.add(container["image"])
    return sorted(images)


def get_app_sources_static(app_sources):
    """
    Strip the synced revision from source descriptors built by get_app_sources.

    get_app_sources always appends the synced revision as the last "|"-separated field,
    after a variable number of static fields (repoURL, path, chart, targetRevision -
    whichever are present). Dropping exactly the last field keeps the result correct
    for any number of static fields; a fixed [0:3] slice would keep the revision for
    sources with only two static fields.

    :param app_sources: the "sources" list of an app returned by get_argo_apps_state
    :return: sorted list of sources without revision
    """
    return sorted(source.rsplit("|", 1)[0] for source in app_sources)


def wait_for_app_sync_status(client, name, namespace, expected_sync_status, expected_sources):
    """
    Predicate for waiters.wait: check whether an Argo CD app reached the expected sync state.

    :param client: k8s client exposing ``argocd_applications.get``
    :param name: app name
    :param namespace: app namespace
    :param expected_sync_status: expected value of status.sync.status (e.g. "Synced")
    :param expected_sources: revision-less sources (see get_app_sources_static) that all
        have to be among the app's current sources
    :return: True when both the status and the sources match, False otherwise
    """
    app = client.argocd_applications.get(name, namespace)
    current_status = app.data["status"]["sync"]["status"]
    current_sources = get_app_sources_static(get_app_sources(app.data))
    reached = current_status == expected_sync_status and set(expected_sources).issubset(current_sources)
    if reached:
        report = (
            f"The app {namespace}/{name} got expected sync status.",
            f"Status: {current_status}",
            "Sources:",
            safe_dump(current_sources),
        )
    else:
        report = (
            f"Wait for app {namespace}/{name} to get expected sync status",
            f"Current status: {current_status}",
            "Current sources:",
            safe_dump(current_sources),
            f"Expected status: {expected_sync_status}",
            "Expected sources:",
            safe_dump(expected_sources),
        )
    LOG.info("\n".join(report))
    return reached


class AppUpgradeChecker:
    """
    Polling predicate that decides whether Argo CD apps have settled after an upgrade.

    On every ``wait`` call the state of the apps matching ``app_label_selector`` is
    re-collected, dumped to ``argo_apps_state_upgrade.yaml`` in ``artifacts_dir`` and
    checked against:

    * the images expected after the upgrade - all of them must be used by pods in the
      apps' workload namespaces;
    * the app statuses - with ``PERFECT`` every app must be Healthy/Synced; with
      ``AS_BEFORE`` every app must be either Healthy/Synced or have the same statuses
      as recorded in ``app_state_base``.

    ``wait`` is meant to be used as the predicate of ``waiters.wait``.
    """

    PERFECT = "perfect"
    AS_BEFORE = "as_before"

    def __init__(self, client, app_state_base, images_manifest_upgrade, artifacts_dir,
                 app_label_selector, app_expected_status):
        """
        :param client: k8s client exposing ``argocd_applications`` and ``pods``
        :param app_state_base: state captured before the upgrade; its "apps" key holds the
            result of get_argo_apps_state
        :param images_manifest_upgrade: images expected to be used after the upgrade
        :param artifacts_dir: directory where the refreshed state is dumped
        :param app_label_selector: label selector of the Argo CD apps to check
        :param app_expected_status: PERFECT or AS_BEFORE
        :raises RuntimeError: if app_expected_status is neither PERFECT nor AS_BEFORE
        """
        self.client = client
        self.app_state_base = app_state_base
        self.app_state_upgrade = {"images_manifest": images_manifest_upgrade}
        self.artifacts_dir = artifacts_dir
        self.app_label_selector = app_label_selector
        if app_expected_status not in (self.__class__.PERFECT, self.__class__.AS_BEFORE):
            raise RuntimeError("app_expected_status is not in (\"perfect\", \"as_before\")")
        else:
            self.app_expected_status = app_expected_status
        # Check -> "has already passed" flag; wait() skips checks that passed once.
        self.waiters = {self._check_new_images: False, self._check_statuses: False}

    @staticmethod
    def _get_statuses(app_state):
        """Reduce an app state to {"<namespace>/<name>": {"health_status": ..., "sync_status": ...}}."""
        return {k: {kk: vv for kk, vv in v.items()
                    if kk in ("health_status", "sync_status")} for k, v in app_state["apps"].items()}

    def _refresh_state(self):
        """Re-collect app states and live images and dump them to the artifacts dir."""
        self.app_state_upgrade["apps"] = get_argo_apps_state(self.client, self.app_label_selector)
        self.app_state_upgrade["images_live_cluster"] = get_argo_apps_images(
            self.client,
            set(namespace for app in self.app_state_upgrade["apps"].values()
                for namespace in app["workload_namespaces"]))
        with open(join(self.artifacts_dir, "argo_apps_state_upgrade.yaml"), "w") as argo_apps_state_dump:
            argo_apps_state_dump.write(safe_dump(self.app_state_upgrade))

    def _check_new_images(self):
        """Return True when every expected image is used by pods in the apps' namespaces."""
        # Empty entries (e.g. from splitting an empty "|"-separated setting) can never show up
        # in the cluster and would keep the wait going until timeout, so they are ignored.
        expected_images = set(image for image in self.app_state_upgrade["images_manifest"] if image)
        current_images = set(self.app_state_upgrade["images_live_cluster"])
        missing_images = expected_images - current_images
        if not missing_images:
            LOG.info("All expected images are in cluster")
            return True
        LOG.info("\n".join((
            "Wait for images to appear in cluster",
            safe_dump(sorted(missing_images))
        )))
        return False

    def _check_statuses(self):
        """
        Return True when the current app statuses match the expected ones.

        Apps that are Healthy/Synced now are always accepted. In PERFECT mode every app,
        including apps missing from the base state, must be Healthy/Synced. In AS_BEFORE
        mode apps missing from the base state are ignored and the others must keep their
        base statuses. Apps present in the base state but absent now show up as a
        discrepancy in the diff.
        """
        expected_statuses = self._get_statuses(self.app_state_base)
        if self.app_expected_status == self.__class__.PERFECT:
            for k in expected_statuses:
                expected_statuses[k]["health_status"] = "Healthy"
                expected_statuses[k]["sync_status"] = "Synced"
        current_statuses = self._get_statuses(self.app_state_upgrade)
        # Iterate over a snapshot of the keys because entries are deleted in the loop.
        for k in tuple(current_statuses):
            if k not in expected_statuses:
                if self.app_expected_status == self.__class__.PERFECT:
                    expected_statuses[k] = {
                        "health_status": "Healthy",
                        "sync_status": "Synced"
                    }
                elif self.app_expected_status == self.__class__.AS_BEFORE:
                    # I can't expect a status of an app that I haven't seen previously
                    del current_statuses[k]
            elif current_statuses[k]["health_status"] == "Healthy" and current_statuses[k]["sync_status"] == "Synced":
                # Nothing to expect if it is already perfect
                del current_statuses[k]
                del expected_statuses[k]
        diff = DeepDiff(expected_statuses, current_statuses)
        if diff:
            LOG.info("\n".join((
                "Wait for apps to get expected statuses",
                "Current discrepancies:",
                diff.pretty()
            )))
            return False
        else:
            LOG.info("All apps are in expected statuses")
            return True

    def wait(self):
        """
        Refresh the state and run the checks in order, stopping at the first failing one.

        A check that passed once is not re-run on later calls.

        :return: True when all checks have passed, False otherwise
        """
        self._refresh_state()
        for waiter in self.waiters:
            if self.waiters[waiter]:
                continue
            self.waiters[waiter] = waiter()
            if not self.waiters[waiter]:
                return False
        return True


def test_argo_apps_upgrade(kcm_manager, show_step):
    """Switch the root app to the upgrade ref and check if the cluster upgrade succeed

    Scenario:
        1. Check pods before upgrade
        2. Get app data from live cluster before upgrade
        3. Get app data from manifests at base ref
        4. Check base ref
        5. Check upgrade ref
        6. Patch the root app
        7. Wait for cluster to get into expected state
        8. Check pods after upgrade
    """

    ns = kcm_manager.get_namespace(settings.TARGET_NAMESPACE)
    cld = ns.get_cluster_deployment(settings.TARGET_CLD)
    client = cld.k8sclient

    # 1. Check pods before upgrade
    show_step(1)
    cld.check.check_k8s_pods()

    # 2. Get app data from live cluster before upgrade
    show_step(2)
    # All apps (empty selector) are collected so the root app can be found even if it
    # does not match KSI_INFERENCE_APP_LABEL_SELECTOR.
    app_state_base_no_selector = {"apps": get_argo_apps_state(client, "")}
    app_state_base = {"apps": get_argo_apps_state(client, settings.KSI_INFERENCE_APP_LABEL_SELECTOR)}
    app_state_base["images_live_cluster"] = get_argo_apps_images(
        client,
        set(namespace for app in app_state_base["apps"].values() for namespace in app["workload_namespaces"]))

    # 3. Get app data from manifests at base ref
    show_step(3)
    app_state_base["images_manifest"] = settings.KSI_INFERENCE_BOX_CLUSTER_BASE_IMAGES.split("|")
    with open(join(settings.ARTIFACTS_DIR, "argo_apps_state_base.yaml"), "w") as argo_apps_state_dump:
        argo_apps_state_dump.write(safe_dump(app_state_base))

    # 4. Check base ref
    show_step(4)
    root_app_key = f"{settings.KSI_INFERENCE_ROOT_APP_NAMESPACE}/{settings.KSI_INFERENCE_ROOT_APP_NAME}"
    assert root_app_key in app_state_base_no_selector["apps"], f"Root app {root_app_key} is not found"
    root_app = app_state_base_no_selector["apps"][root_app_key]
    root_app_sources = get_app_sources_static(root_app["sources"])
    # get_app_sources skips sources without a synced revision, so this list can be empty.
    assert root_app_sources, f"Root app {root_app_key} has no synced sources"
    base_source = root_app_sources[0]
    expected_base_source = "|".join((
        settings.KSI_INFERENCE_ROOT_APP_REPO_URL,
        settings.KSI_INFERENCE_ROOT_APP_PATH,
        settings.KSI_INFERENCE_ROOT_APP_BASE_TARGET_REVISION
    ))
    assert base_source == expected_base_source, f"Base source {base_source} is not {expected_base_source}"

    # 5. Check upgrade ref
    show_step(5)
    upgrade_source = "|".join((
        settings.KSI_INFERENCE_ROOT_APP_REPO_URL,
        settings.KSI_INFERENCE_ROOT_APP_PATH,
        settings.KSI_INFERENCE_ROOT_APP_UPGRADE_TARGET_REVISION
    ))
    assert base_source != upgrade_source, f"Base and upgrade sources are equal ({upgrade_source})"

    # 6. Patch the root app
    show_step(6)
    client.argocd_applications.update(root_app["name"], root_app["namespace"],
                                      {
                                          "spec": {
                                              "source": {
                                                  "targetRevision":
                                                      settings.KSI_INFERENCE_ROOT_APP_UPGRADE_TARGET_REVISION
                                              }
                                          }
                                      })
    waiters.wait(wait_for_app_sync_status,
                 predicate_args=(client,
                                 settings.KSI_INFERENCE_ROOT_APP_NAME,
                                 settings.KSI_INFERENCE_ROOT_APP_NAMESPACE,
                                 "Synced",
                                 (upgrade_source,)),
                 interval=15,
                 timeout=120)

    # 7. Wait for cluster to get into expected state
    show_step(7)
    checker = AppUpgradeChecker(client,
                                app_state_base,
                                settings.KSI_INFERENCE_BOX_CLUSTER_UPGRADE_IMAGES.split("|"),
                                settings.ARTIFACTS_DIR,
                                settings.KSI_INFERENCE_APP_LABEL_SELECTOR,
                                settings.KSI_INFERENCE_APP_EXPECTED_STATUS)
    waiters.wait(checker.wait, interval=60, timeout=3600)

    # 8. Check pods after upgrade
    show_step(8)
    cld.check.check_k8s_pods()
