import time

from datetime import datetime, timezone

from si_tests import logger
from si_tests.clients.iam.keycloak_client import KeycloakUserClient

from prometheus_api_client import PrometheusConnect
from requests.auth import AuthBase, HTTPBasicAuth

LOG = logger.logger


class TokenAuth(AuthBase):
    """Requests auth hook that sends a token in the ``Authorization`` header.

    The header value is rendered as ``"<auth_scheme> <token>"``.
    """

    def __init__(self, token, auth_scheme='Bearer'):
        """Remember the token and the scheme word that precedes it."""
        self.auth_scheme = auth_scheme
        self.token = token

    def __call__(self, request):
        """Set the ``Authorization`` header on *request* and return it."""
        header_value = '{} {}'.format(self.auth_scheme, self.token)
        request.headers['Authorization'] = header_value
        return request


class PrometheusClient(object):
    """Wrapper around ``PrometheusConnect`` that authenticates with HTTP basic auth.

    The constructor builds the server URL, stores the credentials and
    immediately calls :meth:`login`, which creates ``self.auth`` and
    ``self.client``. Subclasses may override :meth:`login` to use a
    different authentication mechanism.
    """

    def __init__(self, host, port, proto, username="admin", password='topsecret', verify=False):
        """
        Args:
            host: Prometheus host name or address.
            port: Prometheus port.
            proto: URL scheme placed before ``://`` (for example ``https``).
            username: User name for authentication.
            password: Password for authentication.
            verify: Stored on the instance; its negation is passed as
                ``disable_ssl`` when the underlying client is created.
        """
        self.url = f'{proto}://{host}:{port}'
        self.verify = verify
        self.username = username
        self.password = password
        self.login()

    def login(self):
        """Create basic-auth credentials and (re)build the Prometheus client."""
        self.auth = HTTPBasicAuth(self.username, self.password)
        self.client = self.get_prom_client()

    def get_prom_client(self):
        """Return a new ``PrometheusConnect`` using the current URL and auth."""
        return PrometheusConnect(url=self.url, auth=self.auth, disable_ssl=not self.verify)

    def get_query(self, query: str, timestamp=None):
        """
        Evaluates an instant query at a single point in time

        Args:
            query: [string] Prometheus expression query string.
            timestamp: [unix_timestamp] Evaluation timestamp. When None,
                no ``time`` parameter is sent.

        Returns:
            List of dictionaries (metrics)
        """
        # `is not None` so that an explicit timestamp of 0 is still sent.
        params = {} if timestamp is None else {"time": timestamp}
        return self.client.custom_query(query, params=params)

    def get_query_range(self, query: str, start: float, end: float, step: float, timeout=None):
        """
        Evaluates an expression query over a range of time

        Args:
            query: [string] Prometheus expression query string
            start: [unix_timestamp] Start timestamp, inclusive.
            end: [unix_timestamp] End timestamp, inclusive.
            step: Query resolution step width in duration format or number
                of seconds. Forwarded unchanged; the helper methods of this
                class pass it as a string (for example "1").
            timeout: [duration] Evaluation timeout. Optional; omitted from
                the request when falsy.

        Returns:
            List of dictionaries (metrics)
        """
        params = {}
        if timeout:
            params["timeout"] = timeout
        # Unix timestamps are converted to timezone-aware UTC datetimes
        # before being handed to the underlying client.
        return self.client.custom_query_range(
            query,
            start_time=datetime.fromtimestamp(start, tz=timezone.utc),
            end_time=datetime.fromtimestamp(end, tz=timezone.utc),
            step=step,
            params=params
        )

    @staticmethod
    def _svc_probe_query(metric, namespace, service_name):
        """Build the blackbox TLS probe selector for *metric* of one service."""
        return (
            f"{metric}{{blackbox_probe_module='tls', job='mcc-blackbox', pod='', "
            f"namespace='{namespace}', service_name='{service_name}'}}"
        )

    def get_svc_probe_duration_seconds(self, namespace, service_name, start, end, step="1"):
        """Range-query ``probe_duration_seconds`` for the given service.

        Args:
            namespace: Value of the ``namespace`` label.
            service_name: Value of the ``service_name`` label.
            start: [unix_timestamp] Start timestamp.
            end: [unix_timestamp] End timestamp.
            step: Query resolution step, forwarded to :meth:`get_query_range`.

        Returns:
            Result of :meth:`get_query_range`.
        """
        query = self._svc_probe_query("probe_duration_seconds", namespace, service_name)
        return self.get_query_range(query=query, start=start, end=end, step=step)

    def get_svc_probe_success(self, namespace, service_name, start, end, step="1"):
        """Range-query ``probe_success`` for the given service.

        Args:
            namespace: Value of the ``namespace`` label.
            service_name: Value of the ``service_name`` label.
            start: [unix_timestamp] Start timestamp.
            end: [unix_timestamp] End timestamp.
            step: Query resolution step, forwarded to :meth:`get_query_range`.

        Returns:
            Result of :meth:`get_query_range`.
        """
        query = self._svc_probe_query("probe_success", namespace, service_name)
        return self.get_query_range(query=query, start=start, end=end, step=step)

    def get_cloudprober_success_rate(
        self, probe, start, end, aggregate_method=None, aggregate_label="dst", interval="30", step="5"
    ):
        """Range-query the cloudprober success ratio for one probe.

        The ratio is ``rate(cloudprober_success) / rate(cloudprober_total)``
        over a window of *interval* seconds.

        Args:
            probe: Value of the ``probe`` label.
            start: [unix_timestamp] Start timestamp.
            end: [unix_timestamp] End timestamp.
            aggregate_method: Optional aggregation operator name; when set,
                the ratio is wrapped as ``<method>(<ratio>) by (<label>)``.
            aggregate_label: Label used in the ``by (...)`` clause.
            interval: Rate window in seconds (the ``s`` suffix is appended).
            step: Query resolution step, forwarded to :meth:`get_query_range`.

        Returns:
            Result of :meth:`get_query_range`.
        """
        labels = f'{{probe="{probe}"}}'
        query = f"rate(cloudprober_success{labels}[{interval}s])/rate(cloudprober_total{labels}[{interval}s])"
        if aggregate_method:
            query = f"{aggregate_method}({query}) by ({aggregate_label})"
        return self.get_query_range(query=query, start=start, end=end, step=step)


class PrometheusClientOpenid(PrometheusClient):
    """Prometheus client that authenticates with an OpenID token from Keycloak.

    Query methods are wrapped with :meth:`refresh_token`, which refreshes the
    token or performs a full re-login before the query when the stored expiry
    windows have elapsed.
    """

    def __init__(self, host, port, proto, keycloak_ip,
                 username, password,
                 client_id='sl', verify=False):
        """
        Args:
            host: Prometheus host name or address.
            port: Prometheus port.
            proto: URL scheme (for example ``https``).
            keycloak_ip: Keycloak address passed to ``KeycloakUserClient``.
            username: User name used to obtain the token.
            password: Password used to obtain the token.
            client_id: Keycloak client id.
            verify: Forwarded to the parent constructor.
        """
        # These must be set before the parent constructor runs, because it
        # calls login(), which reads them.
        self.keycloak_ip = keycloak_ip
        self.client_id = client_id  # 'sl' for StackLight services
        super().__init__(host, port, proto, username, password, verify)

    def login(self):
        """Authorize with username/password and get new token"""
        self.keycloak_client = KeycloakUserClient(self.keycloak_ip,
                                                  self.username,
                                                  self.password,
                                                  client_id=self.client_id,
                                                  realm_name="iam")
        result = self.keycloak_client.get_openid_token()
        # NOTE(security): this writes the whole token response, including
        # the tokens themselves, to the debug log.
        LOG.debug(f"OpenID data for Prometheus UI:\n{result}")
        # NOTE(review): login uses 'id_token' while the refresh path uses
        # 'access_token' - confirm that the difference is intentional.
        self.auth = TokenAuth(result['id_token'])
        self.token_creation_time = int(time.time())
        # Seconds after which the current token has to be refreshed.
        self.refresh_expire = result['expires_in']
        # Seconds after which refreshing is no longer attempted and a full
        # re-login is performed instead.
        self.login_expire = result['refresh_expires_in']
        LOG.info("Got new Prometheus UI auth token")
        self.client = self.get_prom_client()

    def refresh_token(func):
        """Decorator: renew authentication, if needed, before calling *func*.

        Compares the current time with ``token_creation_time`` plus the
        stored expiry windows: past ``login_expire`` a full :meth:`login` is
        done; otherwise, past ``refresh_expire`` the token is refreshed via
        the Keycloak client and the Prometheus client is rebuilt.
        """
        # Imported here so the block does not depend on a module-level import.
        from functools import wraps

        @wraps(func)  # keep the wrapped method's name and docstring
        def wrap(self, *args, **kwargs):
            now = int(time.time())
            if now > self.token_creation_time + self.login_expire:
                LOG.info("Re-login with OpenID")
                self.login()
            elif now > self.token_creation_time + self.refresh_expire:
                LOG.info("Refresh OpenID token")
                result = self.keycloak_client.refresh_token()
                LOG.info("Prometheus UI auth token has been refreshed")
                self.auth = TokenAuth(result['access_token'])
                # NOTE(review): only the creation time is reset here; the
                # expiry windows keep the values from the last login() -
                # confirm the refresh response does not carry new ones.
                self.token_creation_time = int(time.time())
                self.client = self.get_prom_client()

            return func(self, *args, **kwargs)

        return wrap

    @refresh_token
    def get_query(self, query: str, timestamp=None):
        """Run an instant query, renewing the auth token first if needed.

        Args:
            query: [string] Prometheus expression query string.
            timestamp: [unix_timestamp] Evaluation timestamp. Optional.

        Returns:
            Result of the parent class ``get_query``.
        """
        return super().get_query(query, timestamp=timestamp)

    @refresh_token
    def get_query_range(self, query: str, start: float, end: float, step: float, timeout=None):
        """Run a range query, renewing the auth token first if needed.

        Args:
            query: [string] Prometheus expression query string.
            start: [unix_timestamp] Start timestamp.
            end: [unix_timestamp] End timestamp.
            step: Query resolution step, forwarded unchanged.
            timeout: [duration] Evaluation timeout. Optional.

        Returns:
            Result of the parent class ``get_query_range``.
        """
        return super().get_query_range(query=query, start=start, end=end, step=step, timeout=timeout)
