#    Copyright 2019 Mirantis, Inc.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

from contextlib import contextmanager
import signal
import os
import paramiko
from io import StringIO

from ksi_runtest import logger

LOG = logger.logger


def load_keyfile(file_path):
    """
    Load an SSH private key (RSA, ECDSA or Ed25519), auto-detecting the type.

    Each supported paramiko key class is tried in turn against the file
    contents; the first class that parses it wins.

    :param file_path: path to an unencrypted private key file
    :return: dict with keys:
        'private_str' - raw private key text as read from the file,
        'private_obj' - the parsed paramiko.PKey subclass instance,
        'public' - public key line "<type> <base64> ksiuser-framework"
    :raises ValueError: if the key is encrypted (no passphrase is accepted
        here) or if no supported key type matches
    """
    key_classes = [
        paramiko.RSAKey,
        paramiko.ECDSAKey,
        paramiko.Ed25519Key,
    ]
    with open(file_path, 'r') as private_key_file:
        private = private_key_file.read()
    for cls in key_classes:
        try:
            # A fresh StringIO per attempt: each parser consumes the stream.
            key = cls.from_private_key(file_obj=StringIO(private))
        except paramiko.PasswordRequiredException as exc:
            # Must precede the SSHException clause (it is a subclass).
            # The file is encrypted and no passphrase is supported here;
            # trying the remaining classes would only end in a misleading
            # "unsupported format" error.
            raise ValueError(
                f"Unable to parse key {file_path}: key is encrypted") from exc
        except paramiko.SSHException:
            LOG.debug(f'File is not ssh key of format {cls.__name__}')
            continue
        public_str = f"{key.get_name()} {key.get_base64()} ksiuser-framework"
        LOG.debug(f"Detected key type: {cls.__name__}")
        return {'private_str': private,
                'private_obj': key,
                'public': public_str}

    raise ValueError(f"Unable to parse key {file_path}: unsupported format")


@contextmanager
def pushd(path):
    """
    Temporarily switch the current working directory to *path*.

    A leading ``~`` in *path* is expanded. Whatever directory was current on
    entry is restored when the block exits, including when it raises.

    :param path: directory to switch into for the duration of the block
    """
    saved_cwd = os.getcwd()
    try:
        target_dir = os.path.expanduser(path)
        os.chdir(target_dir)
        yield
    finally:
        # Runs on normal exit, on exceptions from the body, and even when
        # chdir() into target_dir itself fails.
        os.chdir(saved_cwd)


class RunLimit(object):
    """
    Context manager that interrupts the enclosed block after a timeout.

    Built on SIGALRM / ``signal.alarm()``, so it requires a platform that
    provides SIGALRM and must be used from the main thread
    (``signal.signal()`` raises ValueError elsewhere). ``signal.alarm()``
    replaces any alarm already pending, so nesting RunLimit blocks is not
    supported.

    When the timeout expires, TimeoutError is raised inside the block. Its
    message is ``timeout_msg.format(spent=<seconds>)`` followed by
    ``str(status_msg_function())`` if a status function was given.

    The SIGALRM handler that was installed before entering is restored on
    exit.
    """
    # Copy-paste of the helper in si_tests/utils/waiters.py.

    def __init__(self, timeout=60, timeout_msg='Timeout',
                 status_msg_function=None):
        """
        :param timeout: limit in seconds; truncated to an int, and 0 means
            "no limit" because signal.alarm(0) schedules nothing
        :param timeout_msg: message template, may reference ``{spent}``
        :param status_msg_function: optional zero-arg callable whose result
            is appended to the timeout message
        """
        self.seconds = int(timeout)
        self.error_message = timeout_msg
        self.status_msg_function = status_msg_function
        # SIGALRM handler that was active before __enter__; restored on exit.
        self._previous_handler = None
        LOG.debug("RunLimit.__init__(timeout={0}, timeout_msg='{1}')"
                  .format(timeout, timeout_msg))

    def handle_timeout(self, signum, frame):
        """SIGALRM handler: raise TimeoutError in the interrupted code."""
        LOG.debug("RunLimit.handle_timeout reached!")
        err_msg = self.error_message.format(spent=self.seconds)
        if self.status_msg_function is not None:
            err_msg += str(self.status_msg_function())

        raise TimeoutError(err_msg)

    def __enter__(self):
        # Install the handler before arming the alarm so an immediate expiry
        # cannot hit the previous handler.
        self._previous_handler = signal.signal(signal.SIGALRM,
                                               self.handle_timeout)
        signal.alarm(self.seconds)
        LOG.debug("RunLimit.__enter__(seconds={0})".format(self.seconds))
        return self

    def __exit__(self, exc_type, value, traceback):
        # Cancel the alarm first so it cannot fire while the handler is
        # being swapped back.
        time_remained = signal.alarm(0)
        # signal.signal() returns None when the previous handler was not
        # installed from Python; such a value cannot be re-installed.
        if self._previous_handler is not None:
            signal.signal(signal.SIGALRM, self._previous_handler)
            self._previous_handler = None
        LOG.debug("RunLimit.__exit__ , remained '{0}' sec"
                  .format(time_remained))
        # Returning None lets any exception (including TimeoutError)
        # propagate.
