from si_tests import logger
from si_tests.clients.http_client import HttpClient, HttpClientOpenId

LOG = logger.logger


class GrafanaClient(object):

    def __init__(self, host, port, proto,
                 username=None, password=None,
                 keycloak_ip=None,
                 client_id='sl', verify=False):
        """
        Build the Grafana base URL and create the HTTP client used for requests.

        :param host: Grafana host name or address
        :param port: Grafana port
        :param proto: URL scheme placed in front of host and port
        :param username: user name handed over to the HTTP client
        :param password: password handed over to the HTTP client
        :param keycloak_ip: when set (truthy), an HttpClientOpenId is created
            with this value; otherwise a plain HttpClient is used
        :param client_id: client id handed over to HttpClientOpenId only
        :param verify: passed through to the HTTP client as is
            (presumably the TLS verification flag - confirm in http_client)
        """
        self.url = f"{proto}://{host}:{port}"

        # Without keycloak the plain client is enough.
        if not keycloak_ip:
            self.client = HttpClient(self.url, username, password, verify)
            return

        self.client = HttpClientOpenId(
            self.url, keycloak_ip, username, password, client_id, verify
        )

    def find_datasource(self, datasource_name):
        """
        Fetch a datasource description by its name.

        Sends GET /api/datasources/name/<datasource_name> through the
        configured client and decodes the response body as JSON.

        :param datasource_name: string name of data source
        :return: decoded JSON body of the response (dictionary with
            datasource data)
        """
        response = self.client.get(
            "/api/datasources/name/{}".format(datasource_name)
        )
        return response.json()

    def range_query_datasource(self, datasource_name, query, start, end, interval=15, max_data_points=1000):
        """
        Evaluates an expression query over a range of time

        The datasource is looked up by name first (see find_datasource) and
        its "uid" is used to address the query, which is then POSTed to
        /api/ds/query.

        Args:
            datasource_name: [string] name of the datasource to query
            query: [string] expression query string
            start: [unix_timestamp] Start timestamp in seconds, floating number.
            end: [unix_timestamp] End timestamp in seconds, floating number.
            interval: [int] Query resolution step width in seconds.
            max_data_points: [int] Number of datapoints to get

        Returns:
            Dictionary with results (decoded JSON body of the response)

        Raises:
            KeyError: if the datasource lookup result has no "uid" key
        """
        def _to_millis(timestamp):
            # Scale to milliseconds BEFORE dropping the fraction, otherwise
            # sub-second precision of floating timestamps is lost. round()
            # guards against float artefacts such as 1001.4999999.
            return str(int(round(float(timestamp) * 1000)))

        dts = self.find_datasource(datasource_name)
        url = "/api/ds/query"
        # Grafana api takes unix time in milliseconds as string
        _from = _to_millis(start)
        _to = _to_millis(end)
        body = {
            "queries": [
                {
                    "datasource": {
                        "uid": dts["uid"]
                    },
                    "expr": query,
                    "range": True,
                    "instant": False,
                    # sent as a plain number string, e.g. "15"
                    "interval": f"{interval}",
                    "maxDataPoints": max_data_points,
                }
            ],
            "from": _from,
            "to": _to
        }
        LOG.debug(f"Quering datasource {body}")
        res = self.client.post(url, json=body)
        LOG.debug(f"Got datasource response {res.content}")
        return res.json()

    def get_timeseries_data(self, dts_data, ref_id="A"):
        """
        Takes results for datasource query and returns list of metrics time series.

        Every frame found under results[ref_id]["frames"] becomes one item of
        the returned list. Labels are taken from the schema field named
        "Value"; samples are built from the first two columns of the frame
        data, treated as (timestamps in milliseconds, values).

        Args:
            dts_data: dictionary with results
            ref_id: optional ref_id of the query
        Returns:
            data: Prometheus metric timeseries data in format:
            [{
                "metric" : {
                    "job" : "<job_name>",
                    "instance" : "<instance_name>",
                    "<label>" : "<value>",
                },
                "values" : [
                    [ 1435781430, "1" ],
                    [ 1435781445, "1" ],
                    [ 1435781460, "0.5" ]
                ]
            }]
            Timestamps are returned in seconds (milliseconds / 1000).
            "metric" is an empty dict when the frame has no "Value" field or
            the field carries no labels; "values" is an empty list when the
            frame holds fewer than two data columns.
        Raises:
            KeyError: if a frame lacks the "schema"/"fields" or "data" keys
        """
        def _get_labels(fields):
            for field in fields:
                if field.get("name") == "Value":
                    # a missing or null "labels" entry still yields a dict
                    return field.get("labels") or {}
            # no "Value" field at all: keep the documented dict type
            return {}

        def _get_values(data):
            columns = data.get("values") or []
            # both the timestamp and the value column are required
            if len(columns) < 2:
                return []
            timestamps, samples = columns[0], columns[1]
            # grafana returns ts in milliseconds, converting to seconds
            return [[ts / 1000, val] for ts, val in zip(timestamps, samples)]

        # "or" fallbacks also cover keys that are present but null
        results = dts_data.get("results") or {}
        frames = (results.get(ref_id) or {}).get("frames") or []
        return [
            {
                "metric": _get_labels(frame["schema"]["fields"]),
                "values": _get_values(frame["data"]),
            }
            for frame in frames
        ]

    def get_os_instance_probe_success(self, probe, start, end, dts_name="prometheus"):
        """
        Query the success rate of Openstack instance availability probes.

        The expression divides the rate of cloudprober_success by the rate
        of cloudprober_total for the given probe, aggregates with max
        without (instance, openstack_hypervisor_hostname) and clamps the
        result to 1. It is meant to match the query of the Grafana
        Openstack Instance Availability dashboard.

        Args:
            probe:           [string] name of the Cloudprober probe
            start:           [unix_timestamp] Start timestamp in seconds, floating number.
            end:             [unix_timestamp] End timestamp in seconds, floating number.
            dts_name:        [string] name of datasource
        Returns:
            Prometheus timeseries data (List of dictionaries)
        """
        selector = f'{{probe="{probe}"}}'
        success = f"rate(cloudprober_success{selector}[$__rate_interval])"
        total = f"rate(cloudprober_total{selector}[$__rate_interval])"
        promql = (
            "clamp_max(max without (instance, openstack_hypervisor_hostname) "
            f"({success}/{total}), 1) >= 0"
        )
        return self.get_timeseries_data(
            self.range_query_datasource(dts_name, promql, start, end)
        )

    def get_os_portprobe_success(self, probe_type, start, end, dts_name="prometheus"):
        """
        Query the success rate of Openstack port probes.

        The expression divides the rate of
        portprober_<probe_type>_target_success by the rate of
        portprober_<probe_type>_target_total, both filtered by
        device_owner=~"compute:.+", aggregates with max without
        (instance, probber_host) and clamps the result to 1. It is meant to
        match the query of the Grafana Openstack Instance Availability
        dashboard.

        Args:
            probe_type:      [string] type of the portprober probe
            start:           [unix_timestamp] Start timestamp in seconds, floating number.
            end:             [unix_timestamp] End timestamp in seconds, floating number.
            dts_name:        [string] name of datasource
        Returns:
            Prometheus timeseries data (List of dictionaries)
        """
        selector = '{device_owner=~"compute:.+"}'
        metric_prefix = f"portprober_{probe_type}_target"
        # success rate first, total rate second: success / total
        ratio = "/".join(
            f"rate({metric_prefix}_{suffix}{selector}[$__rate_interval])"
            for suffix in ("success", "total")
        )
        promql = f"clamp_max(max without (instance, probber_host) ({ratio}), 1) >= 0"
        raw = self.range_query_datasource(dts_name, promql, start, end)
        return self.get_timeseries_data(raw)
