#    Copyright 2013 - 2016 Mirantis, Inc.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

from __future__ import absolute_import

import signal
import socket
import time

import exec_helpers

from si_tests import logger

from si_tests.utils import exceptions

# Logger object provided by si_tests.logger; this module reports timeout
# errors through LOG.error(), while debug traces go through logger.debug().
LOG = logger.logger


def icmp_ping(host, timeout=1, ssh_client=None, bastion_kwargs=None, verbose=True):
    """Run a single ICMP ping against ``host``.

    The ``ping`` command is executed locally via ``exec_helpers.Subprocess``
    unless ``ssh_client`` is given, in which case it runs through that client:
    via ``execute_through_host`` when ``bastion_kwargs`` is given, otherwise
    via ``execute``.

    :param host: address or hostname to ping
    :param timeout: value passed to ``ping -W``
    :param ssh_client: optional remote client to run ``ping`` on
    :param bastion_kwargs: optional keyword arguments for
        ``ssh_client.execute_through_host``; the ``command`` key is supplied
        here and the caller's mapping is left unmodified
    :param verbose: forwarded to the local ``Subprocess().execute`` only
    :returns: True if host is pingable (``ping`` exit code 0), False otherwise
    """
    # NOTE(review): host/timeout are interpolated into a shell command inside
    # single quotes; a value containing "'" would break the quoting.
    cmd = f"ping -c 1 -W '{timeout}' '{host}'"
    if ssh_client:
        if bastion_kwargs:
            # Build a fresh dict instead of writing 'command' into the
            # caller's mapping, which may be shared between calls.
            through_host_kwargs = dict(bastion_kwargs, command=cmd)
            result = ssh_client.execute_through_host(**through_host_kwargs)
        else:
            result = ssh_client.execute(cmd)
    else:
        result = exec_helpers.Subprocess().execute(cmd, verbose=verbose)
    return result.exit_code == 0


def tcp_ping_(host, port, timeout=None):
    """Open and immediately close a TCP connection to ``host``:``port``.

    :param host: target host (converted with ``str()``)
    :param port: target port (converted with ``int()``)
    :param timeout: socket timeout in seconds; a falsy value (None or 0)
        leaves the socket's default timeout in place
    :raises socket.error: (``OSError``) if the connection cannot be made
    """
    # The context manager closes the socket even when int(port) or connect()
    # raises, so failed attempts do not leak file descriptors.
    with socket.socket() as sock:
        if timeout:
            sock.settimeout(timeout)
        sock.connect((str(host), int(port)))


def tcp_ping(host, port, timeout=None):
    """Check whether a TCP connection to ``host``:``port`` can be opened.

    :param host: target host
    :param port: target port
    :param timeout: per-attempt socket timeout forwarded to ``tcp_ping_``
    :returns: True when the connection succeeds, False when the attempt
        raised ``socket.error``
    """
    try:
        tcp_ping_(host, port, timeout)
    except socket.error:
        reachable = False
    else:
        reachable = True
    return reachable


class RunLimit(object):
    """Context manager that interrupts the wrapped code after a time limit.

    On entry it installs :meth:`handle_timeout` as the ``SIGALRM`` handler and
    arms ``signal.alarm(seconds)``. When the alarm fires, the handler raises
    ``exceptions.TimeoutError`` in whatever code is running. Because it relies
    on ``signal.signal``/``signal.alarm``, it works only on Unix and only from
    the main thread. The limit has whole-second resolution: ``timeout`` is
    truncated with ``int()``, and a value truncating to 0 arms no alarm.

    On exit the alarm is cancelled. The previously installed ``SIGALRM``
    handler is restored, and so is any alarm that was pending before entry
    (e.g. one armed by an enclosing ``RunLimit``), so nested limits do not
    silently discard the outer deadline.
    """

    def __init__(self, timeout=60, timeout_msg='Timeout',
                 status_msg_function=None):
        """
        :param timeout: limit in seconds (truncated to int)
        :param timeout_msg: ``str.format`` template for the error text; it is
            formatted with ``spent=<limit in seconds>``
        :param status_msg_function: optional callable whose result (as str)
            is appended to the error text when the limit is hit
        """
        self.seconds = int(timeout)
        self.error_message = timeout_msg
        self.status_msg_function = status_msg_function
        # SIGALRM state saved by __enter__ and restored by __exit__.
        self._previous_handler = None
        self._previous_alarm = 0
        self._entered_at = None
        logger.debug("RunLimit.__init__(timeout={0}, timeout_msg='{1}'"
                     .format(timeout, timeout_msg))

    def handle_timeout(self, signum, frame):
        """SIGALRM handler: raise ``exceptions.TimeoutError``."""
        logger.debug("RunLimit.handle_timeout reached!")
        err_msg = self.error_message.format(spent=self.seconds)
        if self.status_msg_function is not None:
            err_msg += str(self.status_msg_function())

        raise exceptions.TimeoutError(err_msg)

    def __enter__(self):
        # signal.signal() / signal.alarm() return the previous handler and
        # the seconds left on a previously armed alarm; keep both so they can
        # be put back on exit.
        self._previous_handler = signal.signal(signal.SIGALRM,
                                               self.handle_timeout)
        self._previous_alarm = signal.alarm(self.seconds)
        self._entered_at = time.monotonic()
        logger.debug("RunLimit.__enter__(seconds={0}".format(self.seconds))
        return self

    def __exit__(self, exc_type, value, traceback):
        time_remained = signal.alarm(0)
        # signal.signal() returns None when the previous handler was not
        # installed from Python; it cannot be re-installed in that case.
        if self._previous_handler is not None:
            signal.signal(signal.SIGALRM, self._previous_handler)
        if self._previous_alarm:
            elapsed = int(time.monotonic() - self._entered_at)
            # alarm(0) would cancel instead of firing, so an outer deadline
            # that already passed is re-armed to fire in 1 second.
            signal.alarm(max(1, self._previous_alarm - elapsed))
        logger.debug("RunLimit.__exit__ , remained '{0}' sec"
                     .format(time_remained))


def _check_wait_args(predicate,
                     predicate_args,
                     predicate_kwargs,
                     interval,
                     timeout):

    if not callable(predicate):
        raise TypeError("Not callable raising_predicate has been posted: '{0}'"
                        .format(predicate))
    if not isinstance(predicate_args, (list, tuple)):
        raise TypeError("Incorrect predicate_args type for '{0}', should be "
                        "list or tuple, got '{1}'"
                        .format(predicate, type(predicate_args)))
    if not isinstance(predicate_kwargs, dict):
        raise TypeError("Incorrect predicate_kwargs type, should be dict, "
                        "got {}".format(type(predicate_kwargs)))
    if interval <= 0:
        raise ValueError("For '{0}(*{1}, **{2})', waiting interval '{3}'sec is"
                         " wrong".format(predicate,
                                         predicate_args,
                                         predicate_kwargs,
                                         interval))
    if timeout <= 0:
        raise ValueError("For '{0}(*{1}, **{2})', timeout '{3}'sec is "
                         "wrong".format(predicate,
                                        predicate_args,
                                        predicate_kwargs,
                                        timeout))


def wait(predicate, interval=5, timeout=60,
         timeout_msg="Waiting timed out",
         predicate_args=None,
         predicate_kwargs=None,
         status_msg_function=None):
    """Wait until predicate will become True.

    Calls ``predicate(*predicate_args, **predicate_kwargs)`` every
    ``interval`` seconds until it returns a truthy value. The whole wait is
    also bounded by ``RunLimit`` (SIGALRM), so a predicate call that hangs
    is interrupted once the alarm fires.

    Options:

    :param predicate: - function that return True or False
    :param interval: - seconds between checks
    :param timeout:  - raise TimeoutError if predicate won't become True after
                      this amount of seconds
    :param timeout_msg: - text of the TimeoutError; used verbatim (braces in
                          it are not treated as format fields)
    :param predicate_args: - positional arguments for given predicate wrapped
                            in list or tuple
    :param predicate_kwargs: - dict with named arguments for the predicate
    :param status_msg_function: - optional callable; its result (as str) is
                                  appended to the TimeoutError text
    :returns: the first truthy value returned by ``predicate``
    :raises TypeError, ValueError: on invalid arguments
    :raises exceptions.TimeoutError: when ``timeout`` seconds pass first
    """
    predicate_args = predicate_args or []
    predicate_kwargs = predicate_kwargs or {}
    _check_wait_args(predicate, predicate_args, predicate_kwargs,
                     interval, timeout)
    # ``msg`` is a str.format() template with a single {spent} field that is
    # filled in later, here and by RunLimit.handle_timeout. Escape braces in
    # the caller's text and in repr(predicate), e.g. a dumped dict or a
    # functools.partial with dict kwargs. Otherwise they would be parsed as
    # replacement fields, and the timeout would surface as KeyError/ValueError.
    msg = (
        "{msg}\nWaited for pass {cmd}: {spent} seconds. "
        "".format(
            msg=str(timeout_msg).replace('{', '{{').replace('}', '}}'),
            cmd=repr(predicate).replace('{', '{{').replace('}', '}}'),
            spent="{spent:0.3f}"
        ))

    # Monotonic clock: unaffected by wall-clock adjustments during the wait.
    start_time = time.monotonic()
    with RunLimit(timeout, msg, status_msg_function):
        while True:
            result = predicate(*predicate_args, **predicate_kwargs)
            if result:
                logger.debug("wait() completed with result='{0}'"
                             .format(result))
                return result

            spent = time.monotonic() - start_time
            if spent > timeout:
                err_msg = msg.format(spent=spent)
                if status_msg_function:
                    err_msg += str(status_msg_function())
                LOG.error(err_msg)
                raise exceptions.TimeoutError(err_msg)

            time.sleep(interval)


def wait_pass(raising_predicate, expected=Exception,
              interval=5, timeout=60, timeout_msg="Waiting timed out",
              predicate_args=None, predicate_kwargs=None,
              status_msg_function=None):
    """Call raising_predicate until it stops raising ``expected`` exceptions.

    Returns the predicate's result from the first call that does not raise.
    Exceptions not matching ``expected`` propagate immediately.

    Options:

    :param raising_predicate: - callable to retry
    :param expected: - exception class that can be ignored while waiting; several
                       can be passed as a list or tuple
    :param interval: - seconds between checks.
    :param timeout:  - raise TimeoutError if predicate still throwing expected
                       exception after this amount of seconds.
    :param timeout_msg: - text of the TimeoutError; used verbatim (braces in
                          it are not treated as format fields)
    :param predicate_args: - positional arguments for given predicate wrapped
                            in list or tuple
    :param predicate_kwargs: - dict with named arguments for the predicate
    :param status_msg_function: - optional callable; its result (as str) is
                                  appended to the TimeoutError text
    :returns: the result of the first call that does not raise
    :raises TypeError, ValueError: on invalid arguments
    :raises exceptions.TimeoutError: when ``timeout`` seconds pass first
    """

    predicate_args = predicate_args or []
    predicate_kwargs = predicate_kwargs or {}
    _check_wait_args(raising_predicate, predicate_args, predicate_kwargs,
                     interval, timeout)
    # ``except`` accepts an exception class or a *tuple* of classes only;
    # a list (as documented above) would raise TypeError on first failure.
    if isinstance(expected, list):
        expected = tuple(expected)
    # ``msg`` is a str.format() template with a single {spent} field filled
    # in later, here and by RunLimit.handle_timeout. Escape braces coming
    # from the caller's text / repr so they are not parsed as fields.
    msg = (
        "{msg}\nWaited for pass {cmd}: {spent} seconds."
        "".format(
            msg=str(timeout_msg).replace('{', '{{').replace('}', '}}'),
            cmd=repr(raising_predicate).replace('{', '{{').replace('}', '}}'),
            spent="{spent:0.3f}"
        ))

    # Monotonic clock: unaffected by wall-clock adjustments during the wait.
    start_time = time.monotonic()
    # Pass status_msg_function (as wait() does) so the status text is also
    # appended when the SIGALRM-based limit fires.
    with RunLimit(timeout, msg, status_msg_function):
        while True:
            try:
                result = raising_predicate(*predicate_args, **predicate_kwargs)
            except expected as e:
                # NOTE(review): if exceptions.TimeoutError matches ``expected``
                # (plausible with the default Exception), the alarm raised by
                # RunLimit lands here too and is re-raised below once the
                # elapsed time exceeds ``timeout``.
                spent = time.monotonic() - start_time
                if spent > timeout:
                    err_msg = msg.format(spent=spent)
                    if status_msg_function:
                        err_msg += str(status_msg_function())
                    LOG.error(err_msg)
                    raise exceptions.TimeoutError(err_msg) from e

                logger.debug("Got expected exception {!r}, continue".format(e))
                time.sleep(interval)
            else:
                logger.debug("wait_pass() completed with result='{0}'"
                             .format(result))
                return result


def wait_tcp(host, port, timeout, timeout_msg="Waiting timed out"):
    """Block until a TCP connection to ``host``:``port`` can be opened.

    Polls ``tcp_ping`` through ``wait`` (with ``wait``'s default interval) and
    lets whatever ``wait`` raises on timeout propagate. Returns None.
    """
    endpoint = dict(host=host, port=port)
    wait(
        tcp_ping,
        predicate_kwargs=endpoint,
        timeout_msg=timeout_msg,
        timeout=timeout,
    )
