typing and refactoring on the way
diff --git a/wally/ssh_utils.py b/wally/ssh_utils.py
index ada4af6..2941e7c 100644
--- a/wally/ssh_utils.py
+++ b/wally/ssh_utils.py
@@ -1,298 +1,25 @@
import re
+import json
import time
-import errno
-import random
import socket
-import shutil
import logging
import os.path
import getpass
-import StringIO
-import threading
+from io import BytesIO
import subprocess
+from typing import Union, Optional, cast, Dict, List, Tuple, Any, Callable
+from concurrent.futures import ThreadPoolExecutor
import paramiko
+import agent
+
+from . import interfaces, utils
+
logger = logging.getLogger("wally")
-class Local(object):
- "simulate ssh connection to local"
- @classmethod
- def open_sftp(cls):
- return cls()
-
- @classmethod
- def mkdir(cls, remotepath, mode=None):
- os.mkdir(remotepath)
- if mode is not None:
- os.chmod(remotepath, mode)
-
- @classmethod
- def put(cls, localfile, remfile):
- dirname = os.path.dirname(remfile)
- if not os.path.exists(dirname):
- os.makedirs(dirname)
- shutil.copyfile(localfile, remfile)
-
- @classmethod
- def get(cls, remfile, localfile):
- dirname = os.path.dirname(localfile)
- if not os.path.exists(dirname):
- os.makedirs(dirname)
- shutil.copyfile(remfile, localfile)
-
- @classmethod
- def chmod(cls, path, mode):
- os.chmod(path, mode)
-
- @classmethod
- def copytree(cls, src, dst):
- shutil.copytree(src, dst)
-
- @classmethod
- def remove(cls, path):
- os.unlink(path)
-
- @classmethod
- def close(cls):
- pass
-
- @classmethod
- def open(cls, *args, **kwarhgs):
- return open(*args, **kwarhgs)
-
- @classmethod
- def stat(cls, path):
- return os.stat(path)
-
- def __enter__(self):
- return self
-
- def __exit__(self, x, y, z):
- return False
-
-
-NODE_KEYS = {}
-
-
-def exists(sftp, path):
- "os.path.exists for paramiko's SCP object"
- try:
- sftp.stat(path)
- return True
- except IOError as e:
- if e.errno == errno.ENOENT:
- return False
- raise
-
-
-def set_key_for_node(host_port, key):
- sio = StringIO.StringIO(key)
- NODE_KEYS[host_port] = paramiko.RSAKey.from_private_key(sio)
- sio.close()
-
-
-def ssh_connect(creds, conn_timeout=60, reuse_conn=None):
- if creds == 'local':
- return Local()
-
- tcp_timeout = 15
- banner_timeout = 30
-
- if reuse_conn is None:
- ssh = paramiko.SSHClient()
- ssh.load_host_keys('/dev/null')
- ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
- ssh.known_hosts = None
- else:
- ssh = reuse_conn
-
- etime = time.time() + conn_timeout
-
- while True:
- try:
- tleft = etime - time.time()
- c_tcp_timeout = min(tcp_timeout, tleft)
-
- if paramiko.__version_info__ >= (1, 15, 2):
- banner_timeout = {'banner_timeout': min(banner_timeout, tleft)}
- else:
- banner_timeout = {}
-
- if creds.passwd is not None:
- ssh.connect(creds.host,
- timeout=c_tcp_timeout,
- username=creds.user,
- password=creds.passwd,
- port=creds.port,
- allow_agent=False,
- look_for_keys=False,
- **banner_timeout)
- elif creds.key_file is not None:
- ssh.connect(creds.host,
- username=creds.user,
- timeout=c_tcp_timeout,
- key_filename=creds.key_file,
- look_for_keys=False,
- port=creds.port,
- **banner_timeout)
- elif (creds.host, creds.port) in NODE_KEYS:
- ssh.connect(creds.host,
- username=creds.user,
- timeout=c_tcp_timeout,
- pkey=NODE_KEYS[(creds.host, creds.port)],
- look_for_keys=False,
- port=creds.port,
- **banner_timeout)
- else:
- key_file = os.path.expanduser('~/.ssh/id_rsa')
- ssh.connect(creds.host,
- username=creds.user,
- timeout=c_tcp_timeout,
- key_filename=key_file,
- look_for_keys=False,
- port=creds.port,
- **banner_timeout)
- return ssh
- except paramiko.PasswordRequiredException:
- raise
- except (socket.error, paramiko.SSHException):
- if time.time() > etime:
- raise
- time.sleep(1)
-
-
-def save_to_remote(sftp, path, content):
- with sftp.open(path, "wb") as fd:
- fd.write(content)
-
-
-def read_from_remote(sftp, path):
- with sftp.open(path, "rb") as fd:
- return fd.read()
-
-
-def normalize_dirpath(dirpath):
- while dirpath.endswith("/"):
- dirpath = dirpath[:-1]
- return dirpath
-
-
-ALL_RWX_MODE = ((1 << 9) - 1)
-
-
-def ssh_mkdir(sftp, remotepath, mode=ALL_RWX_MODE, intermediate=False):
- remotepath = normalize_dirpath(remotepath)
- if intermediate:
- try:
- sftp.mkdir(remotepath, mode=mode)
- except (IOError, OSError):
- upper_dir = remotepath.rsplit("/", 1)[0]
-
- if upper_dir == '' or upper_dir == '/':
- raise
-
- ssh_mkdir(sftp, upper_dir, mode=mode, intermediate=True)
- return sftp.mkdir(remotepath, mode=mode)
- else:
- sftp.mkdir(remotepath, mode=mode)
-
-
-def ssh_copy_file(sftp, localfile, remfile, preserve_perm=True):
- sftp.put(localfile, remfile)
- if preserve_perm:
- sftp.chmod(remfile, os.stat(localfile).st_mode & ALL_RWX_MODE)
-
-
-def put_dir_recursively(sftp, localpath, remotepath, preserve_perm=True):
- "upload local directory to remote recursively"
-
- # hack for localhost connection
- if hasattr(sftp, "copytree"):
- sftp.copytree(localpath, remotepath)
- return
-
- assert remotepath.startswith("/"), "%s must be absolute path" % remotepath
-
- # normalize
- localpath = normalize_dirpath(localpath)
- remotepath = normalize_dirpath(remotepath)
-
- try:
- sftp.chdir(remotepath)
- localsuffix = localpath.rsplit("/", 1)[1]
- remotesuffix = remotepath.rsplit("/", 1)[1]
- if localsuffix != remotesuffix:
- remotepath = os.path.join(remotepath, localsuffix)
- except IOError:
- pass
-
- for root, dirs, fls in os.walk(localpath):
- prefix = os.path.commonprefix([localpath, root])
- suffix = root.split(prefix, 1)[1]
- if suffix.startswith("/"):
- suffix = suffix[1:]
-
- remroot = os.path.join(remotepath, suffix)
-
- try:
- sftp.chdir(remroot)
- except IOError:
- if preserve_perm:
- mode = os.stat(root).st_mode & ALL_RWX_MODE
- else:
- mode = ALL_RWX_MODE
- ssh_mkdir(sftp, remroot, mode=mode, intermediate=True)
- sftp.chdir(remroot)
-
- for f in fls:
- remfile = os.path.join(remroot, f)
- localfile = os.path.join(root, f)
- ssh_copy_file(sftp, localfile, remfile, preserve_perm)
-
-
-def delete_file(conn, path):
- sftp = conn.open_sftp()
- sftp.remove(path)
- sftp.close()
-
-
-def copy_paths(conn, paths):
- sftp = conn.open_sftp()
- try:
- for src, dst in paths.items():
- try:
- if os.path.isfile(src):
- ssh_copy_file(sftp, src, dst)
- elif os.path.isdir(src):
- put_dir_recursively(sftp, src, dst)
- else:
- templ = "Can't copy {0!r} - " + \
- "it neither a file not a directory"
- raise OSError(templ.format(src))
- except Exception as exc:
- tmpl = "Scp {0!r} => {1!r} failed - {2!r}"
- raise OSError(tmpl.format(src, dst, exc))
- finally:
- sftp.close()
-
-
-class ConnCreds(object):
- conn_uri_attrs = ("user", "passwd", "host", "port", "path")
-
- def __init__(self):
- for name in self.conn_uri_attrs:
- setattr(self, name, None)
-
- def __str__(self):
- return str(self.__dict__)
-
-
-uri_reg_exprs = []
-
-
class URIsNamespace(object):
class ReParts(object):
user_rr = "[^:]*?"
@@ -323,11 +50,29 @@
"^{user_rr}:{passwd_rr}@{host_rr}:{port_rr}$",
]
+ uri_reg_exprs = [] # type: List[str]
for templ in templs:
uri_reg_exprs.append(templ.format(**re_dct))
-def parse_ssh_uri(uri):
+class ConnCreds:
+    """Plain container for the parameters of one SSH connection."""
+
+    conn_uri_attrs = ("user", "passwd", "host", "port", "key_file")
+
+    def __init__(self) -> None:
+        self.user = None  # type: Optional[str]
+        self.passwd = None  # type: Optional[str]
+        # Starts as None until a caller fills it in, so the type must be Optional.
+        self.host = None  # type: Optional[str]
+        self.port = 22  # type: int
+        self.key_file = None  # type: Optional[str]
+
+    def __str__(self) -> str:
+        # NOTE: dumps every attribute, including ``passwd`` in clear text.
+        return str(self.__dict__)
+
+
+SSHCredsType = Union[str, ConnCreds]
+
+
+def parse_ssh_uri(uri: str) -> ConnCreds:
# [ssh://]+
# user:passwd@ip_host:port
# user:passwd@ip_host
@@ -344,12 +89,12 @@
uri = uri[len("ssh://"):]
res = ConnCreds()
- res.port = "22"
+ res.port = 22
res.key_file = None
res.passwd = None
res.user = getpass.getuser()
- for rr in uri_reg_exprs:
+ for rr in URIsNamespace.uri_reg_exprs:
rrm = re.match(rr, uri)
if rrm is not None:
res.__dict__.update(rrm.groupdict())
@@ -358,18 +103,174 @@
raise ValueError("Can't parse {0!r} as ssh uri value".format(uri))
-def reconnect(conn, uri, **params):
+class LocalHost(interfaces.IHost):
+    """IHost implementation that operates on the local machine."""
+
+    def __str__(self) -> str:
+        return "<Local>"
+
+    def get_ip(self) -> str:
+        return 'localhost'
+
+    def put_to_file(self, path: str, content: bytes) -> None:
+        """Write ``content`` to ``path``, creating missing parent directories."""
+        dirname = os.path.dirname(path)
+        # dirname is '' for a bare file name and os.makedirs('') would raise.
+        if dirname:
+            os.makedirs(dirname, exist_ok=True)
+        with open(path, "wb") as fd:
+            fd.write(content)
+
+    def run(self, cmd: str, timeout: int = 60, nolog: bool = False) -> str:
+        """Run ``cmd`` through the local shell and return its output as text.
+
+        stderr is merged into stdout. Raises OSError when the command exits
+        with a non-zero code or does not finish within ``timeout`` seconds.
+        """
+        if not nolog:
+            logger.debug("SSH:{0} Exec {1!r}".format(self, cmd))
+
+        proc = subprocess.Popen(cmd, shell=True,
+                                stdin=subprocess.PIPE,
+                                stdout=subprocess.PIPE,
+                                stderr=subprocess.STDOUT)
+
+        try:
+            raw_output, _ = proc.communicate(timeout=timeout)
+        except subprocess.TimeoutExpired:
+            # Reap the child and collect what it printed so far, then report
+            # the timeout with the same exception type SSHHost.run uses.
+            proc.kill()
+            raw_output, _ = proc.communicate()
+            raise OSError(raw_output.decode("utf8", "replace") + "\nExecution timeout")
+
+        # communicate() yields bytes; the interface promises str.
+        stdout_data = raw_output.decode("utf8", "replace")
+        if proc.returncode != 0:
+            templ = "SSH:{0} Cmd {1!r} failed with code {2}. Output: {3}"
+            raise OSError(templ.format(self, cmd, proc.returncode, stdout_data))
+
+        return stdout_data
+
+
+class SSHHost(interfaces.IHost):
+    """IHost implementation backed by an already connected SSH client."""
+
+    def __init__(self, ssh_conn, node_name: str, ip: str) -> None:
+        self.conn = ssh_conn
+        self.node_name = node_name
+        self.ip = ip
+
+    def get_ip(self) -> str:
+        return self.ip
+
+    def __str__(self) -> str:
+        return self.node_name
+
+    def put_to_file(self, path: str, content: bytes) -> None:
+        """Write ``content`` to ``path`` on the remote side over SFTP."""
+        with self.conn.open_sftp() as sftp:
+            with sftp.open(path, "wb") as fd:
+                fd.write(content)
+
+    def run(self, cmd: str, timeout: int = 60, nolog: bool = False) -> str:
+        """Execute ``cmd`` on the remote side and return its output as text.
+
+        stderr is merged into stdout. Raises OSError when the command exits
+        with a non-zero code or runs longer than ``timeout`` seconds.
+        """
+        transport = self.conn.get_transport()
+        session = transport.open_session()
+        chunks = []  # type: List[bytes]
+
+        try:
+            session.set_combine_stderr(True)
+
+            stime = time.time()
+
+            if not nolog:
+                logger.debug("SSH:{0} Exec {1!r}".format(self, cmd))
+
+            session.exec_command(cmd)
+            # Short recv timeout so the overall deadline is checked every second.
+            session.settimeout(1)
+            session.shutdown_write()
+
+            while True:
+                try:
+                    ndata = session.recv(1024)
+                    # recv() returns bytes; an empty chunk means the stream closed.
+                    if not ndata:
+                        break
+                    chunks.append(ndata)
+                except socket.timeout:
+                    pass
+
+                if time.time() - stime > timeout:
+                    partial = b"".join(chunks).decode("utf8", "replace")
+                    raise OSError(partial + "\nExecution timeout")
+
+            code = session.recv_exit_status()
+        finally:
+            # Always release the channel, on success, failure or timeout.
+            session.close()
+
+        output = b"".join(chunks).decode("utf8", "replace")
+        if code != 0:
+            templ = "SSH:{0} Cmd {1!r} failed with code {2}. Output: {3}"
+            raise OSError(templ.format(self, cmd, code, output))
+
+        return output
+
+
+# Private keys registered per (host, port); consulted by ssh_connect.
+NODE_KEYS = {}  # type: Dict[Tuple[str, int], paramiko.RSAKey]
+
+
+def set_key_for_node(host_port: Tuple[str, int], key: bytes) -> None:
+    """Register the private RSA key ``key`` for the node at ``host_port``.
+
+    ``key`` holds the PEM text of the key, as bytes (str is accepted too).
+    """
+    from io import StringIO
+
+    # paramiko parses the key as text lines, so hand it str rather than bytes.
+    key_text = key.decode("utf8") if isinstance(key, bytes) else key
+    with StringIO(key_text) as sio:
+        NODE_KEYS[host_port] = paramiko.RSAKey.from_private_key(sio)
+
+
+def ssh_connect(creds: SSHCredsType, conn_timeout: int = 60) -> interfaces.IHost:
+    """Connect to the host described by ``creds`` and wrap it in an IHost.
+
+    The string 'local' yields a LocalHost. Otherwise authentication uses, in
+    order of preference: the password, the key file, a key registered in
+    NODE_KEYS for (host, port), and finally ~/.ssh/id_rsa.
+
+    Socket and SSH errors are retried once per second until ``conn_timeout``
+    seconds have passed, then the last error is re-raised.
+    paramiko.PasswordRequiredException is never retried.
+    """
+    if creds == 'local':
+        return LocalHost()
+
+    creds = cast(ConnCreds, creds)
+
+    tcp_timeout = 15
+    default_banner_timeout = 30
+    # Lower bound for per-attempt timeouts: near the deadline the remaining
+    # time would otherwise become zero or negative.
+    min_timeout = 1
+
+    # The authentication method does not change between attempts.
+    auth_args = {
+        'username': creds.user,
+        'port': creds.port,
+        'look_for_keys': False,
+    }  # type: Dict[str, Any]
+
+    if creds.passwd is not None:
+        auth_args['password'] = creds.passwd
+        auth_args['allow_agent'] = False
+    elif creds.key_file is not None:
+        auth_args['key_filename'] = creds.key_file
+    elif (creds.host, creds.port) in NODE_KEYS:
+        auth_args['pkey'] = NODE_KEYS[(creds.host, creds.port)]
+    else:
+        auth_args['key_filename'] = os.path.expanduser('~/.ssh/id_rsa')
+
+    # banner_timeout is only passed for paramiko >= 1.15.2
+    use_banner_timeout = paramiko.__version_info__ >= (1, 15, 2)
+
+    ssh = paramiko.SSHClient()
+    # Empty known-hosts file plus AutoAddPolicy: unknown host keys are accepted.
+    ssh.load_host_keys('/dev/null')
+    ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+    ssh.known_hosts = None
+
+    end_time = time.time() + conn_timeout  # type: float
+
+    while True:
+        time_left = max(end_time - time.time(), min_timeout)
+
+        connect_args = dict(auth_args)
+        connect_args['timeout'] = min(tcp_timeout, time_left)
+        if use_banner_timeout:
+            connect_args['banner_timeout'] = int(min(default_banner_timeout, time_left))
+
+        try:
+            ssh.connect(creds.host, **connect_args)
+            return SSHHost(ssh, "{0.host}:{0.port}".format(creds), creds.host)
+        except paramiko.PasswordRequiredException:
+            ssh.close()
+            raise
+        except (socket.error, paramiko.SSHException):
+            if time.time() > end_time:
+                # Giving up: do not leave a half-open client behind.
+                ssh.close()
+                raise
+            time.sleep(1)
+
+
+def connect(uri: str, **params) -> interfaces.IHost:
if uri == 'local':
- return conn
-
- creds = parse_ssh_uri(uri)
- creds.port = int(creds.port)
- return ssh_connect(creds, reuse_conn=conn, **params)
-
-
-def connect(uri, **params):
- if uri == 'local':
- res = Local()
+ res = LocalHost()
else:
creds = parse_ssh_uri(uri)
creds.port = int(creds.port)
@@ -377,180 +278,58 @@
return res
-all_sessions_lock = threading.Lock()
-all_sessions = {}
+SetupResult = Tuple[interfaces.IRPC, Dict[str, Any]]
-class BGSSHTask(object):
- CHECK_RETRY = 5
+# Called as callback(node, port); setup_rpc unpacks the result as (ip, port).
+RPCBeforeConnCallback = Callable[[interfaces.IHost, int], Tuple[str, int]]
- def __init__(self, node, use_sudo):
- self.node = node
- self.pid = None
- self.use_sudo = use_sudo
- def start(self, orig_cmd, **params):
- uniq_name = 'test'
- cmd = "screen -S {0} -d -m {1}".format(uniq_name, orig_cmd)
- run_over_ssh(self.node.connection, cmd,
- timeout=10, node=self.node.get_conn_id(),
- **params)
- processes = run_over_ssh(self.node.connection, "ps aux", nolog=True)
+def setup_rpc(node: interfaces.IHost,
+              rpc_server_code: bytes,
+              port: int = 0,
+              rpc_conn_callback: Optional[RPCBeforeConnCallback] = None) -> SetupResult:
+    """Upload the RPC server code to ``node``, start it and connect to it.
+
+    The code is written to a ``mktemp`` file on the node and started with
+    ``python <file> server ... --daemon --show-settings``. The command output
+    is parsed as JSON and returned, extended with a ``log_file`` entry,
+    together with the result of ``agent.connect``.
+
+    When ``rpc_conn_callback`` is given it is called with the node and the
+    port the server reports, and must return the ``(ip, port)`` to connect to.
+    """
+    code_file = node.run("mktemp").strip()
+    log_file = node.run("mktemp").strip()
+    node.put_to_file(code_file, rpc_server_code)
+    cmd = "python {code_file} server --listen-addr={listen_ip}:{port} --daemon " + \
+          "--show-settings --stdout-file={out_file}"
+    # Keyword names must match the placeholders in ``cmd`` ({listen_ip}),
+    # otherwise str.format raises KeyError.
+    params_js = node.run(cmd.format(code_file=code_file,
+                                    listen_ip=node.get_ip(),
+                                    out_file=log_file,
+                                    port=port)).strip()
+    params = json.loads(params_js)
+    params['log_file'] = log_file
+    # The requested ``port`` may be 0, so take the port from the address the
+    # server reports.
+    rpc_port = int(params['addr'].rsplit(":", 1)[1])
- for iter in range(self.CHECK_RETRY):
- for proc in processes.split("\n"):
- if orig_cmd in proc and "SCREEN" not in proc:
- self.pid = proc.split()[1]
- break
- if self.pid is not None:
- break
- time.sleep(1)
+    if rpc_conn_callback:
+        ip, rpc_port = rpc_conn_callback(node, rpc_port)
+    else:
+        ip = node.get_ip()
- if self.pid is None:
- self.pid = -1
+    return agent.connect((ip, rpc_port)), params
- def check_running(self):
- assert self.pid is not None
- if -1 == self.pid:
- return False
+
+def wait_ssh_awailable(addrs: List[Tuple[str, int]],
+                       timeout: int = 300,
+                       tcp_timeout: float = 1.0,
+                       max_threads: int = 32) -> None:
+    """Block until a TCP connection can be opened to every address in ``addrs``.
+
+    Addresses are probed in parallel on up to ``max_threads`` threads, each
+    probe limited to ``tcp_timeout`` seconds. After every round
+    ``utils.Timeout(timeout).tick()`` is called; presumably it paces the loop
+    and raises once ``timeout`` seconds are exceeded (defined in utils).
+    """
+    pending = list(addrs)
+    tout = utils.Timeout(timeout)
+
+    def check_sock(addr: Tuple[str, int]) -> bool:
+        # Any OSError (timeout, refused, host or network unreachable) only
+        # means the address is not reachable yet.
+        s = socket.socket()
+        s.settimeout(tcp_timeout)
         try:
- run_over_ssh(self.node.connection,
- "ls /proc/{0}".format(self.pid),
- timeout=10, nolog=True)
+            s.connect(addr)
             return True
         except OSError:
             return False
+        finally:
+            # Probes repeat every round; close the socket to avoid leaking fds.
+            s.close()
- def kill(self, soft=True, use_sudo=True):
- assert self.pid is not None
- if self.pid == -1:
- return True
- try:
- if soft:
- cmd = "kill {0}"
- else:
- cmd = "kill -9 {0}"
-
- if self.use_sudo:
- cmd = "sudo " + cmd
-
- run_over_ssh(self.node.connection,
- cmd.format(self.pid), nolog=True)
- return True
- except OSError:
- return False
-
- def wait(self, soft_timeout, timeout):
- end_of_wait_time = timeout + time.time()
- soft_end_of_wait_time = soft_timeout + time.time()
-
- # time_till_check = random.randint(5, 10)
- time_till_check = 2
-
- # time_till_first_check = random.randint(2, 6)
- time_till_first_check = 2
- time.sleep(time_till_first_check)
- if not self.check_running():
- return True
-
- while self.check_running() and time.time() < soft_end_of_wait_time:
- # time.sleep(soft_end_of_wait_time - time.time())
- time.sleep(time_till_check)
-
- while end_of_wait_time > time.time():
- time.sleep(time_till_check)
- if not self.check_running():
- break
- else:
- self.kill()
- time.sleep(1)
- if self.check_running():
- self.kill(soft=False)
- return False
- return True
+    with ThreadPoolExecutor(max_workers=max_threads) as pool:
+        while pending:
+            check_result = pool.map(check_sock, pending)
+            # Keep only the addresses that still refuse a connection.
+            pending = [addr for ok, addr in zip(check_result, pending) if not ok]
+            tout.tick()
-def run_over_ssh(conn, cmd, stdin_data=None, timeout=60,
- nolog=False, node=None):
- "should be replaces by normal implementation, with select"
- if isinstance(conn, Local):
- if not nolog:
- logger.debug("SSH:local Exec {0!r}".format(cmd))
- proc = subprocess.Popen(cmd, shell=True,
- stdin=subprocess.PIPE,
- stdout=subprocess.PIPE,
- stderr=subprocess.STDOUT)
-
- stdoutdata, _ = proc.communicate(input=stdin_data)
- if proc.returncode != 0:
- templ = "SSH:{0} Cmd {1!r} failed with code {2}. Output: {3}"
- raise OSError(templ.format(node, cmd, proc.returncode, stdoutdata))
-
- return stdoutdata
-
- transport = conn.get_transport()
- session = transport.open_session()
-
- if node is None:
- node = ""
-
- with all_sessions_lock:
- all_sessions[id(session)] = session
-
- try:
- session.set_combine_stderr(True)
-
- stime = time.time()
-
- if not nolog:
- logger.debug("SSH:{0} Exec {1!r}".format(node, cmd))
-
- session.exec_command(cmd)
-
- if stdin_data is not None:
- session.sendall(stdin_data)
-
- session.settimeout(1)
- session.shutdown_write()
- output = ""
-
- while True:
- try:
- ndata = session.recv(1024)
- output += ndata
- if "" == ndata:
- break
- except socket.timeout:
- pass
-
- if time.time() - stime > timeout:
- raise OSError(output + "\nExecution timeout")
-
- code = session.recv_exit_status()
- finally:
- found = False
- with all_sessions_lock:
- if id(session) in all_sessions:
- found = True
- del all_sessions[id(session)]
-
- if found:
- session.close()
-
- if code != 0:
- templ = "SSH:{0} Cmd {1!r} failed with code {2}. Output: {3}"
- raise OSError(templ.format(node, cmd, code, output))
-
- return output
-
-
-def close_all_sessions():
- with all_sessions_lock:
- for session in all_sessions.values():
- try:
- session.sendall('\x03')
- session.close()
- except:
- pass
- all_sessions.clear()