import time
import os
import json
import tarfile

import exec_helpers

from si_tests.clients.prometheus.prometheus_client import PrometheusClientOpenid
from si_tests.utils import utils
from si_tests import logger
from si_tests import settings


# Module-level logger used by the helpers and the test in this file.
LOG = logger.logger

# Longevity environment name: added as the ``env`` label to every exported
# series and embedded in the archive file name uploaded to the remote host.
ENV_NAME = settings.LONGEVITY_ENV_NAME


def _escape_label_value(value):
    """Return *value* escaped for use inside a quoted OpenMetrics label value.

    The OpenMetrics text format requires backslash, double quote and line feed
    to be escaped inside label values; left raw they make the line unparseable.
    """
    return (str(value)
            .replace("\\", "\\\\")
            .replace('"', '\\"')
            .replace("\n", "\\n"))


def write_metrics_to_file(metric, type, step, prometheus_data,
                          target_dir, counter_metric=False, node_dict=None):
    """Write one Prometheus range-query result to ``<target_dir>/<metric>.txt``.

    The file is OpenMetrics text: a ``# TYPE`` header, one
    ``name{labels} value timestamp`` line per sample and a ``# EOF`` trailer.

    Args:
        metric: metric name, used both as the sample name and the file name.
        type: metric type written on the ``# TYPE`` line (the name shadows the
            builtin but is kept because callers pass it by keyword).
        step: number added to the synthetic timestamp for every sample.
        prometheus_data: iterable of series dicts, each with a ``"metric"``
            label mapping and a ``"values"`` list of ``[timestamp, value]``
            pairs.
        target_dir: directory the file is written into (must exist).
        counter_metric: when true, a sample whose value equals the previously
            written value of the same series is skipped; the timestamp still
            advances for the skipped sample.
        node_dict: optional mapping ``machine type -> node names``. When given,
            a ``machine_type`` label is inserted before the ``node`` label of
            any series whose node is found in one of the lists (first match
            wins).
    """
    path = os.path.join(target_dir, f"{metric}.txt")
    with open(path, "w", encoding="utf-8") as f:
        f.write(f"# TYPE {metric} {type}\n")
        for metric_dict in prometheus_data:
            # Every series is re-based onto the same fixed start time
            # (Fri Feb 13 2009 23:31:30 GMT+0000), presumably so exports from
            # different runs/environments share one time axis.
            # NOTE(review): samples are placed by their index in "values", not
            # by their real Prometheus timestamp, so a series with missing
            # points in the queried range is shifted earlier relative to the
            # other series — confirm this is intended.
            timestamp = 1234567890
            metric_data = metric_dict["metric"]

            # Prepare metric to OpenMetrics format: metric{label="value", ...}
            label_pairs = []
            for key, value in metric_data.items():
                if key == "__name__":
                    continue
                if node_dict and key == "node":
                    machine_type = next(
                        (node_type for node_type, node_names in node_dict.items()
                         if value in node_names),
                        None)
                    if machine_type is not None:
                        label_pairs.append(
                            f'machine_type="{_escape_label_value(machine_type)}"')
                label_pairs.append(f'{key}="{_escape_label_value(value)}"')
            label_pairs.append(f'env="{_escape_label_value(ENV_NAME)}"')
            metric_str = f'{metric}{{{",".join(label_pairs)}}}'

            # Add values and timestamps for each metric
            previous_value = None
            for sample in metric_dict["values"]:
                value = sample[1]
                timestamp += step

                # Do not repeat an unchanged value for counter metrics
                if counter_metric and value == previous_value:
                    continue
                previous_value = value

                f.write(f"{metric_str} {value} {timestamp}\n")
        f.write("# EOF")


def upload_metrics(host, src_metrics_path):
    """Upload a metrics archive to *host* and backfill it into Prometheus.

    The archive (a ``.tar.gz`` whose top-level directory is ``metrics/``) is
    copied over SSH to ``/var/lib/prometheus``, unpacked into a scratch
    directory, and every file in it is imported with
    ``promtool tsdb create-blocks-from openmetrics``. Prometheus is then
    killed and its HTTP port polled until it answers again.

    Args:
        host: address of the longevity Prometheus host.
        src_metrics_path: local path of the metrics ``.tar.gz`` archive.

    Raises:
        The ``exec_helpers`` error raised by ``check_call`` when a remote
        command exits non-zero; any failing backfill step now triggers it.
    """
    import shlex

    ssh_username = settings.LONGEVITY_SSH_LOGIN
    ssh_key = settings.LONGEVITY_SSH_PRIV_KEY_FILE

    keys = utils.load_keyfile(ssh_key)
    pkey = utils.get_rsa_key(keys['private'])

    auth = exec_helpers.SSHAuth(username=ssh_username, key=pkey)
    ssh = exec_helpers.SSHClient(host=host, port=22, auth=auth)
    ssh.logger.addHandler(logger.console)
    dst_metric_path = f"/var/lib/prometheus/metrics-{ENV_NAME}.tar.gz"
    work_dir = "/tmp/prometheus"

    try:
        # Upload collected metrics to longevity promstack
        ssh.upload(src_metrics_path, dst_metric_path)

        # Unarchive and backfill prometheus metrics.
        # A plain `;` chain only reports the status of its last command, so a
        # failing tar/promtool went unnoticed; `set -e` aborts on the first
        # failure and the EXIT trap still removes the scratch directory.
        backfill_cmd = (
            "set -e; "
            f"rm -rf {work_dir}; "
            f"trap 'rm -rf {work_dir}' EXIT; "
            f"mkdir -p {work_dir}; "
            f"tar zxf {shlex.quote(dst_metric_path)} -C {work_dir}; "
            f"cd {work_dir}/metrics; "
            "for metric in *; do "
            'echo "${metric}"; '
            'promtool tsdb create-blocks-from openmetrics "${metric}" /var/lib/prometheus/data; '
            "sleep 1; "
            "done"
        )
        ssh.check_call(backfill_cmd, timeout=18000, verbose=True)

        # Restart prometheus to apply the new metrics.
        # NOTE(review): relies on something (e.g. a supervisor/restart policy)
        # starting prometheus again after the kill — not visible here.
        ssh.check_call("kill $(pgrep prometheus); curl --retry 10 --retry-delay 10 --retry-connrefused localhost:9090")
    finally:
        ssh.close()


def _get_nodes_by_type(cluster):
    """Group the cluster's machines into control / stacklight / worker nodes.

    Control machines are ``"control"``; other machines labelled
    ``stacklight=enabled`` are ``"stacklight"``; everything else is ``"worker"``.

    Returns:
        tuple: ``(node_dict, sl_node_public_ip)`` where ``node_dict`` maps each
        node type to a list of k8s node names, and ``sl_node_public_ip`` is the
        public IP of the first StackLight machine found (``None`` if none).
    """
    sl_node_public_ip = None
    node_dict = {"control": [], "stacklight": [], "worker": []}
    for machine in cluster.get_machines():
        if machine.is_machine_type("control"):
            node_type = "control"
        elif machine.has_k8s_labels({"stacklight": "enabled"}):
            node_type = "stacklight"
            if sl_node_public_ip is None:
                sl_node_public_ip = machine.public_ip
        else:
            node_type = "worker"
        node_dict[node_type].append(machine.get_k8s_node_name())
    return node_dict, sl_node_public_ip


def _get_prometheus_client(kaas_manager, cluster, sl_node_public_ip):
    """Return a Prometheus client for *cluster*.

    When ``settings.RUN_ON_REMOTE`` is set, the client targets the ``https``
    NodePort of the ``iam-proxy-prometheus`` service on a StackLight node and
    is given the Keycloak credentials of the ``operator`` user; otherwise the
    cluster's own ``prometheusclient`` is returned.

    Raises:
        RuntimeError: in remote mode, when no StackLight node public IP or no
            ``https`` NodePort is available (previously the client was built
            with ``None`` and failed later with an obscure error).
    """
    if not settings.RUN_ON_REMOTE:
        return cluster.prometheusclient

    proto = "https"
    keycloak_user_name = "operator"

    if sl_node_public_ip is None:
        raise RuntimeError("No StackLight node public IP found; cannot reach Prometheus remotely")

    keycloak_ip = kaas_manager.get_keycloak_ip()
    password = kaas_manager.si_config.get_keycloak_user_password(keycloak_user_name)
    prometheus_svc = cluster.k8sclient.services.get(name="iam-proxy-prometheus", namespace="stacklight")
    prometheus_svc_node_port = next((s.node_port for s in prometheus_svc.get_ports() if s.name == proto), None)
    if prometheus_svc_node_port is None:
        raise RuntimeError(f"Service iam-proxy-prometheus has no '{proto}' port")

    return PrometheusClientOpenid(host=sl_node_public_ip,
                                  port=prometheus_svc_node_port,
                                  proto=proto,
                                  keycloak_ip=keycloak_ip,
                                  username=keycloak_user_name,
                                  password=password)


def _query_range_with_retries(prometheus_client, metric_name, start, end, step,
                              attempts=3, delay=60):
    """Run a Prometheus range query for *metric_name*, retrying on any error.

    Waits *delay* seconds between attempts (but not after the last one).

    Raises:
        Exception: when all *attempts* fail; the last error is chained as the
            cause.
    """
    last_error = None
    for attempt in range(1, attempts + 1):
        try:
            return prometheus_client.get_query_range(
                query=metric_name,
                start=start,
                end=end,
                step=step)
        except Exception as e:
            last_error = e
            LOG.error(f"An unexpected error occurred. Exception: '{e}', retry {attempt}")
            if attempt < attempts:
                time.sleep(delay)
    raise Exception(f"Failed to get {metric_name} metric after multiple retries. Aborting.") from last_error


def test_collect_prometheus_metrics(kaas_manager):
    """Collects Prometheus metrics and uploads them to the specified host.

    Scenario:
        1. Gets target cluster nodes by type
        2. Gets the Prometheus client
        3. Collects specified Prometheus metrics data from the target cluster.
        4. Organizes and prepares the collected metrics for storage.
        5. Creates an archive of the metrics data.
        6. Uploads the metrics archive to the specified host.
        7. Unarchives and backfills the metrics data on the remote host.
        8. Restarts Prometheus to apply the new metrics.
    """
    tg_name = settings.TARGET_CLUSTER
    tg_ns = kaas_manager.get_namespace(settings.TARGET_NAMESPACE)
    tg_cluster = tg_ns.get_cluster(tg_name)
    artifacts_dir = settings.ARTIFACTS_DIR
    host = settings.LONGEVITY_HOST
    metric_step = settings.LONGEVITY_METRICS_STEP
    # Mapping of metric name -> metric type (e.g. "counter")
    metrics = json.loads(settings.LONGEVITY_METRICS)
    # Default to the last hour when no explicit window is configured
    start = settings.LONGEVITY_START_TIMESTAMP or (time.time() - 3600)
    end = settings.LONGEVITY_END_TIMESTAMP or time.time()

    # Directory for storing metrics
    target_dir = os.path.join(artifacts_dir, "metrics")
    os.makedirs(target_dir, exist_ok=True)

    # node type: [node names]
    node_dict, sl_node_public_ip = _get_nodes_by_type(tg_cluster)
    LOG.info(f"Managment nodes: {node_dict['control']}")
    LOG.info(f"Stacklight nodes: {node_dict['stacklight']}")
    LOG.info(f"Worker nodes: {node_dict['worker']}")

    prometheus_client = _get_prometheus_client(kaas_manager, tg_cluster, sl_node_public_ip)

    for metric_name, metric_type in metrics.items():
        LOG.info(f'Get "{metric_name}" metric data from prometheus')
        prometheus_response = _query_range_with_retries(
            prometheus_client, metric_name, start, end, metric_step)

        # Prepare and write metric to file; machine_type labels are only
        # added for kube_node_info.
        write_metrics_to_file(
            metric=metric_name,
            type=metric_type,
            step=metric_step,
            prometheus_data=prometheus_response,
            target_dir=target_dir,
            counter_metric=metric_type == "counter",
            node_dict=node_dict if metric_name == "kube_node_info" else None)

    LOG.info("Create archive with metrics")
    metrics_path = os.path.join(artifacts_dir, "metrics.tar.gz")
    with tarfile.open(metrics_path, "w:gz") as tar:
        tar.add(target_dir, arcname=os.path.basename(target_dir))

    LOG.info(f"Upload collected metrics to {host}")
    upload_metrics(host, metrics_path)

    LOG.info("Done")
