blob: d500e361d9ca99614cbfaa90edb2dca4e988bea2 [file] [log] [blame]
import queue
import subprocess
import traceback
import threading
from time import sleep
from .exception import TimeoutException, CheckerException
from .other import shell, piped_shell
from .log import logger, logger_cli
# We do not use paramiko here to preserve system level ssh config
def ssh_shell_p(
    command,
    host,
    username=None,
    keypath=None,
    port=None,
    silent=False,
    piped=False,
    use_sudo=False
):
    """Run ``command`` on ``host`` via the system ``ssh`` client.

    The pieces of the ssh invocation are joined with single spaces into one
    command string and handed to ``shell`` (or ``piped_shell`` when
    ``piped`` is true); whatever that helper returns is returned unchanged.

    :param command: remote command text, appended verbatim (not quoted).
    :param host: target host name or address.
    :param username: optional remote user; produces ``user@host``.
    :param keypath: optional private key; adds ``-i <keypath>`` and
        ``-o IdentitiesOnly=yes``.
    :param port: optional ssh port, passed as ``-p <port>``.
    :param silent: add ``-q`` to the ssh flags.
    :param piped: select ``piped_shell`` instead of ``shell``.
    :param use_sudo: prefix the remote command with ``sudo``.

    NOTE(review): no argument is shell-quoted; if ``shell`` runs the string
    through a shell, paths or commands with spaces/metacharacters must be
    quoted by the caller — confirm against ``.other.shell``.
    """
    parts = ["ssh"]
    if silent:
        parts.append("-q")
    if keypath:
        parts.extend(["-i " + keypath, "-o IdentitiesOnly=yes"])
    if port:
        parts.append("-p " + str(port))
    parts.append(host if not username else username + '@' + host)
    if use_sudo:
        parts.append("sudo")
    parts.append(command)
    runner = piped_shell if piped else shell
    return runner(" ".join(parts))
def scp_p(
    source,
    target,
    port=None,
    keypath=None,
    silent=False,
    piped=False
):
    """Copy ``source`` to ``target`` with the system ``scp`` client.

    The scp invocation is joined with single spaces into one command string
    and executed by ``shell`` (or ``piped_shell`` when ``piped`` is true);
    that helper's return value is passed straight back.

    :param source: scp source spec (local path or ``[user@]host:path``).
    :param target: scp target spec.
    :param port: optional port, passed as ``-P <port>`` (capital P for scp).
    :param keypath: optional private key, passed as ``-i <keypath>``.
    :param silent: add ``-q``.
    :param piped: select ``piped_shell`` instead of ``shell``.

    NOTE(review): arguments are not quoted; callers presumably must quote
    paths containing spaces/shell metacharacters — confirm against
    ``.other.shell``.
    """
    args = ["scp"]
    if port:
        args.append("-P " + str(port))
    if silent:
        args.append("-q")
    # Build the scp identity option
    if keypath:
        args.append("-i " + keypath)
    args += [source, target]
    command_line = " ".join(args)
    return piped_shell(command_line) if piped else shell(command_line)
def output_reader(_stdout, outq):
    """Copy lines from a binary stream into a queue as text.

    Used as the target of a background thread: reads ``_stdout`` line by
    line until EOF (``readline`` returning ``b''``) and puts each line,
    decoded as UTF-8, onto ``outq``. Line terminators are kept.

    :param _stdout: binary file-like object with ``readline`` (here the
        stdout pipe of the ssh subprocess).
    :param outq: queue receiving the decoded ``str`` lines.
    """
    for line in iter(_stdout.readline, b''):
        # Remote output is not guaranteed to be valid UTF-8. A strict decode
        # would raise inside the reader thread, silently stopping it and
        # leaving every waiter to time out; substitute bad bytes instead.
        outq.put(line.decode('utf-8', errors='replace'))
# Base class for all SSH related actions
class SshBase(object):
    """Shared plumbing for classes that drive an interactive ``ssh`` process.

    ``_init_connection`` starts the subprocess and a background thread
    running ``output_reader``, which copies the process's stdout line by
    line into the queue ``self.outq``. The waiting helpers (``_connect``,
    ``_wait_for_string``, ``wait_ready``) move lines from that queue into
    ``self.output`` until they see what they are waiting for, sleeping one
    second whenever the queue is empty. ``flush_output`` returns and
    clears the accumulated lines.

    Subclasses build the complete command (``ssh`` plus ``self._options``,
    ``self._extra_options`` and ``self._host_uri``) and pass it to
    ``_init_connection``.
    """

    def __init__(
        self,
        tgt_host,
        user=None,
        keypath=None,
        port=None,
        timeout=15,
        silent=False,
        piped=False
    ):
        """Store connection settings and precompute the ssh arguments.

        :param tgt_host: host name or address to connect to.
        :param user: optional remote user; produces ``user@host``.
        :param keypath: optional private key, passed as ``-i``.
        :param port: ssh port; falsy values fall back to 22.
        :param timeout: number of one-second polls to wait for the login
            banner, and the default budget for ``_wait_for_string``.
        :param silent: stored only; not read by this class.
        :param piped: stored only; not read by this class.
        """
        self._cmd = ["ssh"]
        self.timeout = timeout
        self.port = port if port else 22
        self.host = tgt_host
        self.username = user
        self.keypath = keypath
        self.silent = silent
        self.piped = piped
        self.output = []
        # -tt forces pseudo-terminal allocation even though stdin is a pipe
        self._options = ["-tt"]
        if self.keypath:
            self._options += ["-i", self.keypath]
        if self.port:
            self._options += ["-p", str(self.port)]
        # NOTE: these disable host key verification entirely
        self._extra_options = [
            "-o", "UserKnownHostsFile=/dev/null",
            "-o", "StrictHostKeyChecking=no",
            "-o", "IdentitiesOnly=yes"
        ]
        self._host_uri = ""
        if self.username:
            self._host_uri = self.username + "@" + self.host
        else:
            self._host_uri = self.host

    def _read_until(self, match, timeout):
        """Move queued output lines into ``self.output`` until one matches.

        :param match: callable taking a decoded line and returning True
            for the line being waited for (that line is still appended to
            ``self.output``).
        :param timeout: number of one-second sleeps allowed while the
            queue is empty before giving up.
        :returns: True once ``match`` accepts a line, False on timeout.
        """
        _remaining = timeout
        while True:
            try:
                line = self.outq.get(block=False)
            except queue.Empty:
                logger.debug("... {} sec".format(_remaining))
                sleep(1)
                _remaining -= 1
                # '<= 0' rather than 'not': a zero, negative or fractional
                # budget must still terminate instead of looping forever.
                if _remaining <= 0:
                    logger.debug(
                        "... timed out after {} sec".format(timeout)
                    )
                    return False
                continue
            line = line.decode() if isinstance(line, bytes) else line
            self.output.append(line)
            if match(line):
                return True

    def _connect(self, banner="Welcome"):
        """Wait for a line starting with ``banner`` (the login greeting).

        Uses ``self.timeout`` as the budget without consuming it, so later
        waits still get the full configured timeout.

        :param banner: prefix of the line that marks a successful login.
        :returns: True when the banner was seen, False on timeout.
        :raises CheckerException: if ``banner`` is not a str.
        """
        if not isinstance(banner, str):
            raise CheckerException(
                "Invalid SSH banner type: '{}'".format(type(banner))
            )
        logger.debug("... connecting")
        if not self._read_until(
            lambda line: line.startswith(banner), self.timeout
        ):
            return False
        logger.debug("... connected")
        return True

    def _wait_for_string(self, string, timeout=None):
        """Wait for an output line starting with ``string``.

        :param string: prefix to look for.
        :param timeout: number of one-second polls; defaults to
            ``self.timeout`` (which is not modified).
        :returns: True when found, False on timeout.
        """
        if timeout is None:
            timeout = self.timeout
        logger.debug("... waiting for '{}'".format(string))
        if not self._read_until(
            lambda line: line.startswith(string), timeout
        ):
            return False
        logger.debug("... found")
        return True

    def _init_connection(self, cmd):
        """Start ``cmd`` and wait for the login banner.

        Starts the subprocess with all three standard streams piped, spawns
        the ``output_reader`` thread feeding ``self.outq``, waits for the
        banner, then exposes the process's stdin as ``self.input`` and logs
        (and flushes) everything received so far.

        :param cmd: argv list for ``subprocess.Popen``.
        :raises TimeoutException: if the banner does not arrive within
            ``self.timeout``; the subprocess is killed first.
        """
        self._proc = subprocess.Popen(
            cmd,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            universal_newlines=False,
            bufsize=0
        )
        # Create thread safe output getter
        self.outq = queue.Queue()
        self._t = threading.Thread(
            target=output_reader,
            args=(self._proc.stdout, self.outq)
        )
        # A reader blocked on a never-closed pipe must not keep the
        # interpreter alive at exit.
        self._t.daemon = True
        self._t.start()
        # Wait for the login banner before accepting commands
        if not self._connect():
            # Don't leave an orphaned ssh process and reader thread behind;
            # callers get no handle to clean them up after this raise.
            self._end_connection()
            raise TimeoutException(
                "SSH connection to '{}'".format(self.host)
            )
        self.input = self._proc.stdin
        self.get_output()
        logger.debug(
            "Connected. Banners:\n{}".format(
                "".join(self.flush_output())
            )
        )

    def _end_connection(self):
        """Kill the ssh process if it is still alive and collect output."""
        if self._proc.poll() is None:
            self._proc.kill()
            # Reap the child so it does not linger as a zombie
            self._proc.wait()
        # With ssh gone its stdout reaches EOF and the reader thread exits;
        # give it a moment to queue the remaining lines (bounded, in case
        # something else still holds the pipe open).
        self._t.join(timeout=1)
        self.get_output()
        return

    def do(self, cmd, timeout=30, sudo=False, strip_cmd=True):
        """Send ``cmd`` to the remote shell and return the collected output.

        :param cmd: command as str (UTF-8 encoded) or bytes.
        :param timeout: poll budget passed to ``wait_ready``.
        :param sudo: prefix the command with ``sudo``.
        :param strip_cmd: drop the first output line (the echoed command).
        :returns: all unflushed output with ``\\r`` removed, as one string.

        NOTE(review): a ``wait_ready`` timeout is not reported to the
        caller; whatever output has arrived is returned.
        """
        cmd = cmd if isinstance(cmd, bytes) else cmd.encode('utf-8')
        logger.debug("... ssh: '{}'".format(cmd))
        _cmd = b"sudo " + cmd if sudo else cmd
        # run command
        self.input.write(_cmd + b'\n')
        # wait for completion
        self.wait_ready(_cmd, timeout=timeout)
        self.get_output()
        _output = self.flush_output().replace('\r', '')
        if strip_cmd:
            return "\n".join(_output.splitlines()[1:])
        return _output

    def get_output(self):
        """Drain every line currently queued into ``self.output``.

        Non-blocking.

        :returns: ``self.output`` (all unflushed lines).
        """
        while True:
            try:
                line = self.outq.get(block=False)
            except queue.Empty:
                return self.output
            # str(bytes) would store the repr "b'...'"; decode instead,
            # like the other readers in this class.
            self.output.append(
                line.decode() if isinstance(line, bytes) else line
            )

    def flush_output(self, as_string=True):
        """Return the accumulated output and reset it.

        :param as_string: join the lines into one string (default) or
            return the list itself.
        """
        _out = self.output
        self.output = []
        if as_string:
            return "".join(_out)
        else:
            return _out

    def wait_ready(self, cmd, timeout=60):
        """Wait until a line other than the echo of ``cmd`` arrives.

        Lines of the form ``<anything>$ <cmd>`` (presumably the terminal
        echoing the command after a ``$`` prompt) are skipped; the first
        other line ends the wait. NOTE(review): that line is not
        necessarily the end of the command's output.

        :param cmd: the command as sent, bytes or str.
        :param timeout: number of one-second polls on an empty queue.
        :returns: True on a non-echo line, False on timeout.
        """
        def _strip_cmd_carrets(_str, carret='\r', skip_chars=1):
            # Remove each carret together with the skip_chars characters
            # that follow it (line-wrap artifacts in the echoed command).
            _cnt = _str.count(carret)
            while _cnt > 0:
                _idx = _str.index(carret)
                _str = _str[:_idx] + _str[_idx+1+skip_chars:]
                _cnt -= 1
            return _str

        _expected = cmd.decode() if isinstance(cmd, bytes) else cmd

        def _is_done(line):
            # check if this is the command itself and skip
            if '$' in line:
                _echo = _strip_cmd_carrets(line.split('$', 1)[1].strip())
                if _echo == _expected:
                    return False
            return True

        return self._read_until(_is_done, timeout)

    def wait_for_string(self, string, timeout=60):
        """Wait for an output line starting with ``string``.

        :param string: prefix to look for.
        :param timeout: number of one-second polls on an empty queue.
        :returns: True when found.
        :raises TimeoutException: if the line does not arrive in time.
        """
        if not self._wait_for_string(string, timeout=timeout):
            raise TimeoutException(
                "Time out waiting for string '{}'".format(string)
            )
        return True
class SshShell(SshBase):
    """Interactive ssh session usable as a context manager.

    Entering the context (or calling ``connect``) starts ``ssh`` with the
    options prepared by ``SshBase`` and waits for the login banner; leaving
    it (or calling ``kill``) kills the process. Also offers ``scp`` using
    the same key, port and extra options.
    """

    def __enter__(self):
        """Start the ssh process and return this session."""
        self._cmd = (
            ["ssh"] + self._options + self._extra_options + [self._host_uri]
        )
        logger.debug("...shell to: '{}'".format(" ".join(self._cmd)))
        self._init_connection(self._cmd)
        return self

    def __exit__(self, _type, _value, _traceback):
        """Kill the ssh process and log any in-block exception.

        Always returns True, so an exception raised inside the ``with``
        body is logged as a warning and then suppressed.
        """
        self._end_connection()
        if _value:
            _details = "".join(
                traceback.format_exception(_type, _value, _traceback)
            )
            logger.warn("Error running SSH:\r\n{}".format(_details))
        return True

    def connect(self):
        """Open the session outside a ``with`` statement."""
        return self.__enter__()

    def kill(self):
        """Terminate the session opened with ``connect``."""
        self._end_connection()

    def get_host_path(self, path):
        """Return ``[user@]host:path`` for use as an scp endpoint."""
        _remote = self.host + ":" + path
        return self.username + "@" + _remote if self.username else _remote

    def scp(self, _src, _dst):
        """Run ``scp`` from ``_src`` to ``_dst`` and wait for it to finish.

        :returns: ``(returncode, stdout, stderr)`` with both streams
            decoded to str (stderr is ``""`` when empty).
        """
        _opts = []
        if self.keypath:
            _opts.extend(["-i", self.keypath])
        if self.port:
            # scp takes the port as capital -P
            _opts.extend(["-P", str(self.port)])
        self._scp_options = _opts
        _cmd = ["scp"] + _opts + self._extra_options + [_src, _dst]
        logger.debug("...scp: '{}'".format(" ".join(_cmd)))
        _proc = subprocess.Popen(
            _cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE
        )
        _stdout, _stderr = _proc.communicate()
        _err_text = _stderr.decode() if _stderr else ""
        return _proc.returncode, _stdout.decode(), _err_text
class PortForward(SshBase):
    """ssh session that forwards a local port to a host behind ``host``.

    Builds ``-L <loc_port>:<fwd_host>:<fwd_port>`` and otherwise behaves
    like an interactive session: entering the context (or ``connect``)
    starts ssh and waits for the banner, leaving it (or ``kill``) kills it.
    """

    def __init__(
        self,
        host,
        fwd_host,
        user=None,
        keypath=None,
        port=None,
        loc_port=10022,
        fwd_port=22,
        timeout=15
    ):
        """Prepare the forwarding session.

        :param host: ssh host to connect through.
        :param fwd_host: destination host as seen from ``host``.
        :param loc_port: local port to listen on.
        :param fwd_port: destination port on ``fwd_host``.
        Remaining arguments are passed to ``SshBase`` (always with
        ``silent=True`` and ``piped=False``).
        """
        super().__init__(
            host,
            user=user,
            keypath=keypath,
            port=port,
            timeout=timeout,
            silent=True,
            piped=False
        )
        self.f_host, self.l_port, self.f_port = fwd_host, loc_port, fwd_port
        _spec = ":".join((str(self.l_port), self.f_host, str(self.f_port)))
        self._forward_options = ["-L", _spec]

    def __enter__(self):
        """Start the forwarding ssh process and return this session."""
        self._cmd = (
            ["ssh"]
            + self._forward_options
            + self._options
            + self._extra_options
            + [self._host_uri]
        )
        logger.debug(
            "... port forwarding: '{}'".format(" ".join(self._cmd))
        )
        self._init_connection(self._cmd)
        return self

    def __exit__(self, _type, _value, _traceback):
        """Kill the ssh process; log and suppress any in-block exception.

        Always returns True, so exceptions raised inside the ``with`` body
        are reported via ``logger_cli`` and not propagated.
        """
        self._end_connection()
        if _value:
            _details = "".join(
                traceback.format_exception(_type, _value, _traceback)
            )
            logger_cli.warn("Error running SSH:\r\n{}".format(_details))
        return True

    def connect(self):
        """Open the forward outside a ``with`` statement."""
        return self.__enter__()

    def kill(self):
        """Tear down a forward opened with ``connect``."""
        self._end_connection()