#    Author: Alex Savatieiev (osavatieiev@mirantis.com; a.savex@gmail.com)
#    Copyright 2019-2022 Mirantis, Inc.
import queue
import subprocess
import traceback
import threading

from time import sleep
from .exception import TimeoutException, CheckerException
from .other import shell, piped_shell
from .log import logger, logger_cli


# We do not use paramiko here to preserve system level ssh config
def ssh_shell_p(
    command,
    host,
    username=None,
    keypath=None,
    port=None,
    silent=False,
    piped=False,
    use_sudo=False
):
    """Run *command* on *host* through the system ``ssh`` client.

    The ssh invocation is assembled as one space-joined string and handed
    to ``piped_shell`` when *piped* is true, otherwise to ``shell``; that
    helper's result is returned unchanged.  No argument is quoted —
    presumably callers pass values that are already safe for the local
    shell; verify at call sites.

    :param command: command line to execute remotely
    :param host: target host name or address
    :param username: remote user; when falsy only *host* is used
    :param keypath: private key file, also forces ``IdentitiesOnly=yes``
    :param port: ssh port, passed as ``-p``
    :param silent: add ``-q`` to suppress ssh diagnostics
    :param piped: choose ``piped_shell`` instead of ``shell``
    :param use_sudo: prefix the remote command with ``sudo``
    """
    parts = ["ssh"]
    if silent:
        parts.append("-q")
    if keypath:
        # Restrict ssh to the given key instead of agent/default keys
        parts.extend(["-i " + keypath, "-o IdentitiesOnly=yes"])
    if port:
        parts.append("-p " + str(port))
    parts.append(username + '@' + host if username else host)
    if use_sudo:
        parts.append("sudo")
    parts.append(command)

    cmdline = " ".join(parts)
    runner = piped_shell if piped else shell
    return runner(cmdline)


def scp_p(
    source,
    target,
    port=None,
    keypath=None,
    silent=False,
    piped=False
):
    """Copy *source* to *target* with the system ``scp`` client.

    The scp invocation is built as one space-joined, unquoted string and
    passed to ``piped_shell`` when *piped* is true, otherwise to ``shell``;
    that helper's result is returned unchanged.

    :param source: scp source spec (local path or ``[user@]host:path``)
    :param target: scp target spec
    :param port: ssh port, passed as ``-P`` (scp uses upper-case P)
    :param keypath: private key file, passed as ``-i``
    :param silent: add ``-q`` to suppress progress output
    :param piped: choose ``piped_shell`` instead of ``shell``
    """
    args = ["scp"]
    if port:
        args.append("-P " + str(port))
    if silent:
        args.append("-q")
    if keypath:
        args.append("-i " + keypath)
    args += [source, target]

    cmdline = " ".join(args)
    if piped:
        return piped_shell(cmdline)
    return shell(cmdline)


def output_reader(_stdout, outq):
    """Copy lines from a binary stream into a queue as text until EOF.

    Meant to run as a thread target: it blocks on ``readline`` and stops
    when the stream returns ``b''``.  Bytes that are not valid UTF-8 are
    replaced with U+FFFD instead of raising, because an exception here
    would silently end the thread and leave consumers of *outq* waiting
    until their timeouts expire.

    :param _stdout: binary file-like object with ``readline``
    :param outq: queue receiving one decoded ``str`` per line
    """
    for line in iter(_stdout.readline, b''):
        outq.put(line.decode('utf-8', errors='replace'))


# Base class for all SSH related actions
class SshBase(object):
    """Shared machinery for driving an ``ssh`` process over pipes.

    Subclasses assemble a command line from ``_options``,
    ``_extra_options`` and ``_host_uri`` and pass it to
    ``_init_connection``.  A background thread running ``output_reader``
    feeds the process stdout into ``self.outq``; the methods below drain
    that queue into ``self.output``, polling once per second while the
    queue is empty.
    """

    def __init__(
        self,
        tgt_host,
        user=None,
        keypath=None,
        port=None,
        timeout=15,
        silent=False,
        piped=False
    ):
        """Store connection settings and pre-build the ssh arguments.

        :param tgt_host: host name or address to connect to
        :param user: remote user; when falsy the URI is just the host
        :param keypath: private key passed to ssh with ``-i``
        :param port: ssh port; falsy values fall back to 22
        :param timeout: default wait in seconds (one poll per second) for
            ``_connect`` and ``_wait_for_string``
        :param silent: stored on the instance; not read by this class
        :param piped: stored on the instance; not read by this class
        """
        self._cmd = ["ssh"]
        self.timeout = timeout
        self.port = port if port else 22
        self.host = tgt_host
        self.username = user
        self.keypath = keypath
        self.silent = silent
        self.piped = piped
        self.output = []

        # "-tt" forces a remote pseudo-terminal even though stdin is a pipe
        self._options = ["-tt"]
        if self.keypath:
            self._options += ["-i", self.keypath]
        if self.port:
            self._options += ["-p", str(self.port)]
        # Accept unknown host keys without prompting, discard them instead
        # of writing known_hosts, and use only the configured identity
        self._extra_options = [
            "-o", "UserKnownHostsFile=/dev/null",
            "-o", "StrictHostKeyChecking=no",
            "-o", "IdentitiesOnly=yes"
        ]

        self._host_uri = ""
        if self.username:
            self._host_uri = self.username + "@" + self.host
        else:
            self._host_uri = self.host

    @staticmethod
    def _to_text(line):
        # Queue items are normally str (output_reader decodes them); raw
        # bytes are decoded too instead of being turned into their repr.
        return line.decode() if isinstance(line, bytes) else line

    def _connect(self, banner="Welcome"):
        """Wait until an output line starting with *banner* arrives.

        Every line read is appended to ``self.output``.  Waits at most
        ``self.timeout`` seconds; ``self.timeout`` itself is not modified,
        so a later call gets the full budget again.

        :returns: True when the banner is seen, False on timeout
        :raises CheckerException: if *banner* is not a str
        """
        if not isinstance(banner, str):
            raise CheckerException(
                "Invalid SSH banner type: '{}'".format(type(banner))
            )
        logger.debug("... connecting")
        _remaining = self.timeout
        while True:
            try:
                line = self._to_text(self.outq.get(block=False))
            except queue.Empty:
                logger.debug("... {} sec".format(_remaining))
                sleep(1)
                _remaining -= 1
                # '<=' so a zero or negative budget cannot loop forever
                if _remaining <= 0:
                    logger.debug(
                        "...timed out after {} sec".format(str(self.timeout))
                    )
                    return False
                continue
            self.output.append(line)
            if line.startswith(banner):
                break
        logger.debug("... connected")
        return True

    def _wait_for_string(self, string, timeout=None):
        """Consume output until a line starting with *string* arrives.

        :param string: prefix to look for
        :param timeout: seconds to wait; ``None`` means ``self.timeout``
        :returns: True when found, False on timeout
        """
        if timeout is None:
            timeout = self.timeout
        logger.debug("... waiting for '{}'".format(string))
        _remaining = timeout
        while True:
            try:
                line = self._to_text(self.outq.get(block=False))
            except queue.Empty:
                logger.debug("... {} sec".format(_remaining))
                sleep(1)
                _remaining -= 1
                if _remaining <= 0:
                    logger.debug(
                        "... timed out after {} sec".format(str(timeout))
                    )
                    return False
                continue
            self.output.append(line)
            if line.startswith(string):
                break
        logger.debug("... found")
        return True

    def _init_connection(self, cmd):
        """Start *cmd*, attach the stdout reader and wait for the banner.

        On success ``self.input`` is the process stdin and the banner lines
        are logged at debug level and flushed from ``self.output``.

        :param cmd: argument list for ``subprocess.Popen``
        :raises TimeoutException: if the banner does not appear within
            ``self.timeout`` seconds; the process is killed first
        """
        self._proc = subprocess.Popen(
            cmd,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            # NOTE(review): stderr is piped but never read in this class;
            # a very chatty ssh client could fill the pipe buffer and stall
            stderr=subprocess.PIPE,
            universal_newlines=False,
            bufsize=0
        )
        # Read stdout in a thread so polling never blocks on readline
        self.outq = queue.Queue()
        self._t = threading.Thread(
            target=output_reader,
            args=(self._proc.stdout, self.outq)
        )
        self._t.start()

        # Wait for the login banner before accepting commands
        if not self._connect():
            # Do not leak a half-open ssh process and its reader thread
            self._end_connection()
            raise TimeoutException(
                "SSH connection to '{}'".format(self.host)
            )

        self.input = self._proc.stdin
        self.get_output()
        logger.debug(
            "Connected. Banners:\n{}".format(self.flush_output())
        )

    def _end_connection(self):
        """Kill the ssh process if it is still running and drain output.

        Safe to call more than once.
        """
        if self._proc.poll() is None:
            self._proc.kill()
        # Reap the child so it does not linger as a zombie; it is either
        # already finished or has just been killed, so this cannot block
        # on a full pipe.
        self._proc.wait()
        # With ssh gone its stdout reaches EOF and the reader thread ends;
        # give it a moment so trailing lines land in the queue.
        self._t.join(timeout=1)
        self.get_output()

        return

    def do(self, cmd, timeout=30, sudo=False, strip_cmd=True):
        """Send *cmd* to the remote shell and return the collected output.

        :param cmd: command as str (encoded as UTF-8) or bytes
        :param timeout: seconds ``wait_ready`` waits for output
        :param sudo: prefix the command with ``sudo``
        :param strip_cmd: drop the first output line, presumably the echo
            of the command itself
        :returns: output text with ``\\r`` removed; it is returned even if
            ``wait_ready`` timed out
        """
        cmd = cmd if isinstance(cmd, bytes) else cmd.encode('utf-8')
        logger.debug("... ssh: '{}'".format(cmd))
        _cmd = b"sudo " + cmd if sudo else cmd
        # run command
        self.input.write(_cmd + b'\n')
        # wait for completion
        self.wait_ready(_cmd, timeout=timeout)
        self.get_output()
        _output = self.flush_output().replace('\r', '')
        if strip_cmd:
            return "\n".join(_output.splitlines()[1:])
        return _output

    def get_output(self):
        """Move every queued line into ``self.output`` and return that list."""
        while True:
            try:
                line = self.outq.get(block=False)
            except queue.Empty:
                return self.output
            self.output.append(self._to_text(line))

    def flush_output(self, as_string=True):
        """Return the collected output and reset the buffer.

        :param as_string: join the lines into one str when true, otherwise
            return the list of lines
        """
        _out = self.output
        self.output = []
        if as_string:
            return "".join(_out)
        else:
            return _out

    def wait_ready(self, cmd, timeout=60):
        """Wait for the first output line that is not the echo of *cmd*.

        A line containing ``$`` whose text after the first ``$`` (stripped,
        with each ``\\r`` and the character following it removed) equals
        *cmd* is treated as the command echo and skipped.  Every line read
        is appended to ``self.output``.

        :param cmd: the command that was sent, as bytes
        :param timeout: seconds to wait
        :returns: True when such a line arrived, False on timeout
        """
        def _strip_cmd_carrets(_str, carret='\r', skip_chars=1):
            # Drop each carret together with the skip_chars characters
            # after it.  Re-check membership every pass: a removed
            # neighbour may itself be a carret, so a precomputed count
            # could make str.index() raise ValueError.
            while carret in _str:
                _idx = _str.index(carret)
                _str = _str[:_idx] + _str[_idx+1+skip_chars:]
            return _str

        _expected = cmd.decode()
        while True:
            try:
                line = self._to_text(self.outq.get(block=False))
            except queue.Empty:
                logger.debug("... {} sec".format(timeout))
                sleep(1)
                timeout -= 1
                if timeout <= 0:
                    logger.debug("...timed out")
                    return False
                continue
            self.output.append(line)
            # check if this is the command itself and skip
            if '$' in line:
                _cmd = line.split('$', 1)[1].strip()
                if _strip_cmd_carrets(_cmd) == _expected:
                    continue
            break
        return True

    def wait_for_string(self, string, timeout=60):
        """Block until an output line starting with *string* arrives.

        :param string: prefix to look for
        :param timeout: seconds to wait
        :returns: True
        :raises TimeoutException: if the string does not show up in time
        """
        if not self._wait_for_string(string, timeout=timeout):
            raise TimeoutException(
                "Time out waiting for string '{}'".format(string)
            )
        return True


class SshShell(SshBase):
    """Interactive ``ssh`` shell session, usable as a context manager.

    Entering starts ``ssh`` with the options prepared by the base class and
    waits for the login banner; leaving kills the process.  ``scp`` reuses
    the same key, port and ``-o`` options for file copies.
    """

    def __enter__(self):
        """Build the ssh command line, connect and return ``self``."""
        self._cmd = ["ssh"]
        self._cmd += self._options
        self._cmd += self._extra_options
        self._cmd += [self._host_uri]

        logger.debug("...shell to: '{}'".format(" ".join(self._cmd)))
        self._init_connection(self._cmd)
        return self

    def __exit__(self, _type, _value, _traceback):
        """Kill the session and log any exception raised in the block.

        Returning True tells Python to suppress the exception, so errors in
        the ``with`` body surface only as a warning in the log.
        NOTE(review): this also swallows KeyboardInterrupt/SystemExit —
        confirm that is intended.
        """
        self._end_connection()
        if _value:
            logger.warning(
                "Error running SSH:\r\n{}".format(
                    "".join(traceback.format_exception(
                        _type,
                        _value,
                        _traceback
                    ))
                )
            )

        return True

    def connect(self):
        """Open the session without a ``with`` statement."""
        return self.__enter__()

    def kill(self):
        """Terminate the session opened by ``connect``."""
        self._end_connection()

    def get_host_path(self, path):
        """Return an scp-style remote spec: ``[user@]host:path``."""
        _uri = self.host + ":" + path
        if self.username:
            _uri = self.username + "@" + _uri
        return _uri

    def scp(self, _src, _dst):
        """Copy *_src* to *_dst* with ``scp``, reusing this session's options.

        Build remote endpoints with ``get_host_path``.

        :returns: tuple ``(returncode, stdout, stderr)``; output is decoded
            as UTF-8 with undecodable bytes replaced, and stderr is ``""``
            when empty
        """
        self._scp_options = []
        if self.keypath:
            self._scp_options += ["-i", self.keypath]
        if self.port:
            # scp takes the port as upper-case -P
            self._scp_options += ["-P", str(self.port)]

        _cmd = ["scp"]
        _cmd += self._scp_options
        _cmd += self._extra_options
        _cmd += [_src, _dst]

        logger.debug("...scp: '{}'".format(" ".join(_cmd)))
        _proc = subprocess.Popen(
            _cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE
        )
        _out, _err = _proc.communicate()
        _e = _err.decode(errors="replace") if _err else ""
        return _proc.returncode, _out.decode(errors="replace"), _e


class PortForward(SshBase):
    """Local port forwarding (``ssh -L``) through *host*, as a context manager.

    While connected, ``loc_port`` on this machine is forwarded to
    ``fwd_host:fwd_port`` as resolved from *host*.  Leaving the context
    kills the ssh process.
    """

    def __init__(
        self,
        host,
        fwd_host,
        user=None,
        keypath=None,
        port=None,
        loc_port=10022,
        fwd_port=22,
        timeout=15
    ):
        """Prepare the forwarding command line.

        :param host: ssh host the tunnel goes through
        :param fwd_host: destination host of the forwarded connection
        :param user: ssh user on *host*
        :param keypath: private key for *host*
        :param port: ssh port of *host*
        :param loc_port: local port to listen on
        :param fwd_port: destination port on *fwd_host*
        :param timeout: seconds to wait for the login banner
        """
        super().__init__(
            host,
            user=user,
            keypath=keypath,
            port=port,
            timeout=timeout,
            silent=True,
            piped=False
        )
        self.f_host = fwd_host
        self.l_port = loc_port
        self.f_port = fwd_port

        # -L <local port>:<forward host>:<forward port>
        self._forward_options = [
            "-L",
            ":".join([str(self.l_port), self.f_host, str(self.f_port)])
        ]

    def __enter__(self):
        """Start ssh with the ``-L`` forward, connect and return ``self``."""
        self._cmd = ["ssh"]
        self._cmd += self._forward_options
        self._cmd += self._options
        self._cmd += self._extra_options
        self._cmd += [self._host_uri]

        logger.debug(
            "... port forwarding: '{}'".format(" ".join(self._cmd))
        )
        self._init_connection(self._cmd)
        return self

    def __exit__(self, _type, _value, _traceback):
        """Kill the tunnel and log any exception raised in the block.

        Returning True tells Python to suppress the exception; it is only
        reported through ``logger_cli``.
        NOTE(review): this also swallows KeyboardInterrupt/SystemExit —
        confirm that is intended.
        """
        self._end_connection()
        if _value:
            logger_cli.warning(
                "Error running SSH:\r\n{}".format(
                    "".join(traceback.format_exception(
                        _type,
                        _value,
                        _traceback
                    ))
                )
            )

        return True

    def connect(self):
        """Open the tunnel without a ``with`` statement."""
        return self.__enter__()

    def kill(self):
        """Close the tunnel opened by ``connect``."""
        self._end_connection()
