WIP: refactor ssh_utils and add type annotations
diff --git a/wally/ssh_utils.py b/wally/ssh_utils.py
index 2941e7c..7728dfd 100644
--- a/wally/ssh_utils.py
+++ b/wally/ssh_utils.py
@@ -1,27 +1,25 @@
import re
-import json
import time
+import errno
import socket
import logging
import os.path
import getpass
+import selectors
from io import BytesIO
-import subprocess
-from typing import Union, Optional, cast, Dict, List, Tuple, Any, Callable
-from concurrent.futures import ThreadPoolExecutor
+from typing import Union, Optional, cast, Dict, List, Tuple
import paramiko
-import agent
-
-from . import interfaces, utils
+from . import utils
logger = logging.getLogger("wally")
+# A TCP endpoint as a (host, port) pair, in the form socket.connect() accepts.
+IPAddr = Tuple[str, int]
-class URIsNamespace(object):
- class ReParts(object):
+class URIsNamespace:
+ class ReParts:
user_rr = "[^:]*?"
host_rr = "[^:@]*?"
port_rr = "\\d+"
@@ -58,41 +56,27 @@
class ConnCreds:
+    """Address and credentials of one SSH endpoint.
+
+    passwd and key_file are optional (None when not given); port defaults to 22.
+    """
conn_uri_attrs = ("user", "passwd", "host", "port", "key_file")
- def __init__(self) -> None:
- self.user = None # type: Optional[str]
- self.passwd = None # type: Optional[str]
- self.host = None # type: str
- self.port = 22 # type: int
- self.key_file = None # type: Optional[str]
+    # Optional[...] spelled out: implicit Optional (`x: str = None`) is rejected
+    # by PEP 484 checkers such as mypy in their default configuration.
+    def __init__(self, host: str, user: str, passwd: Optional[str] = None, port: int = 22,
+                 key_file: Optional[str] = None) -> None:
+        self.user = user
+        self.passwd = passwd
+        self.host = host
+        self.port = port
+        self.key_file = key_file
def __str__(self) -> str:
return str(self.__dict__)
-SSHCredsType = Union[str, ConnCreds]
-
-
def parse_ssh_uri(uri: str) -> ConnCreds:
- # [ssh://]+
- # user:passwd@ip_host:port
- # user:passwd@ip_host
- # user@ip_host:port
- # user@ip_host
- # ip_host:port
- # ip_host
- # user@ip_host:port:path_to_key_file
- # user@ip_host::path_to_key_file
- # ip_host:port:path_to_key_file
- # ip_host::path_to_key_file
+    """Parse an ssh connection URI into a ConnCreds instance.
+
+    Accepted forms (a leading "ssh://" is stripped first):
+
+        user:passwd@host[:port]
+        [user@]host[:port]
+        [user@]host:[port]:key_file     (port may be empty, e.g. "host::key")
+
+    The user defaults to the local login name (getpass.getuser()).
+    Raises ValueError when no pattern in URIsNamespace.uri_reg_exprs yields a result.
+    """
if uri.startswith("ssh://"):
uri = uri[len("ssh://"):]
- res = ConnCreds()
- res.port = 22
- res.key_file = None
- res.passwd = None
- res.user = getpass.getuser()
+    # Empty host and the local user are defaults; presumably overwritten from the
+    # matching regex groups below (the loop body is not part of this diff).
+ res = ConnCreds("", getpass.getuser())
for rr in URIsNamespace.uri_reg_exprs:
rrm = re.match(rr, uri)
@@ -103,109 +87,18 @@
raise ValueError("Can't parse {0!r} as ssh uri value".format(uri))
-class LocalHost(interfaces.IHost):
- def __str__(self):
- return "<Local>"
-
- def get_ip(self) -> str:
- return 'localhost'
-
- def put_to_file(self, path: str, content: bytes) -> None:
- dirname = os.path.dirname(path)
- if not os.path.exists(dirname):
- os.makedirs(dirname)
- with open(path, "wb") as fd:
- fd.write(content)
-
- def run(self, cmd: str, timeout: int = 60, nolog: bool = False) -> str:
- proc = subprocess.Popen(cmd, shell=True,
- stdin=subprocess.PIPE,
- stdout=subprocess.PIPE,
- stderr=subprocess.STDOUT)
-
- stdout_data, _ = proc.communicate()
- if proc.returncode != 0:
- templ = "SSH:{0} Cmd {1!r} failed with code {2}. Output: {3}"
- raise OSError(templ.format(self, cmd, proc.returncode, stdout_data))
-
- return stdout_data
+# RSA private keys registered per (host, port) endpoint by set_key_for_node().
+NODE_KEYS = {} # type: Dict[IPAddr, paramiko.RSAKey]
-class SSHHost(interfaces.IHost):
- def __init__(self, ssh_conn, node_name: str, ip: str) -> None:
- self.conn = ssh_conn
- self.node_name = node_name
- self.ip = ip
-
- def get_ip(self) -> str:
- return self.ip
-
- def __str__(self) -> str:
- return self.node_name
-
- def put_to_file(self, path: str, content: bytes) -> None:
- with self.conn.open_sftp() as sftp:
- with sftp.open(path, "wb") as fd:
- fd.write(content)
-
- def run(self, cmd: str, timeout: int = 60, nolog: bool = False) -> str:
- transport = self.conn.get_transport()
- session = transport.open_session()
-
- try:
- session.set_combine_stderr(True)
-
- stime = time.time()
-
- if not nolog:
- logger.debug("SSH:{0} Exec {1!r}".format(self, cmd))
-
- session.exec_command(cmd)
- session.settimeout(1)
- session.shutdown_write()
- output = ""
-
- while True:
- try:
- ndata = session.recv(1024)
- output += ndata
- if "" == ndata:
- break
- except socket.timeout:
- pass
-
- if time.time() - stime > timeout:
- raise OSError(output + "\nExecution timeout")
-
- code = session.recv_exit_status()
- finally:
- found = False
-
- if found:
- session.close()
-
- if code != 0:
- templ = "SSH:{0} Cmd {1!r} failed with code {2}. Output: {3}"
- raise OSError(templ.format(self, cmd, code, output))
-
- return output
+def set_key_for_node(host_port: IPAddr, key: bytes) -> None:
+    """Parse a PEM-encoded RSA private key and store it in NODE_KEYS[host_port].
+
+    :param host_port: (host, port) endpoint the key belongs to.
+    :param key: key file contents; bytes are decoded as UTF-8 (str is also accepted).
+    :raises paramiko.SSHException: if the data is not a valid RSA private key.
+    """
+    # paramiko's PKey.from_private_key() reads the file object line by line and
+    # compares the lines with a str PEM header, so it needs a *text* stream:
+    # bytes lines from a BytesIO are never recognised as a key.
+    from io import StringIO
+    text = key.decode("utf-8") if isinstance(key, bytes) else key
+    with StringIO(text) as sio:
+        NODE_KEYS[host_port] = paramiko.RSAKey.from_private_key(sio)
-NODE_KEYS = {} # type: Dict[Tuple[str, int], paramiko.RSAKey]
-
-
-def set_key_for_node(host_port: Tuple[str, int], key: bytes) -> None:
- sio = BytesIO(key)
- NODE_KEYS[host_port] = paramiko.RSAKey.from_private_key(sio)
- sio.close()
-
-
-def ssh_connect(creds: SSHCredsType, conn_timeout: int = 60) -> interfaces.IHost:
- if creds == 'local':
- return LocalHost()
-
- tcp_timeout = 15
- default_banner_timeout = 30
+def ssh_connect(creds: ConnCreds,
+ conn_timeout: int = 60,
+ tcp_timeout: int = 15,
+ default_banner_timeout: int = 30) -> Tuple[paramiko.SSHClient, str, str]:
+    """Open a paramiko SSH connection to creds.host on port creds.port.
+
+    Returns (client, "host:port", host): the connected paramiko.SSHClient, a
+    display name formatted from creds, and the bare host.
+
+    A separate ssh.connect() call is used when creds.passwd is not None
+    (presumably password auth; the full argument list is outside this diff).
+    On paramiko >= 1.15.2 the banner timeout is min(default_banner_timeout,
+    time left). tcp_timeout presumably caps each attempt's TCP timeout and
+    conn_timeout the overall deadline -- confirm against the full body.
+    paramiko.PasswordRequiredException is re-raised; socket.error and
+    paramiko.SSHException are followed by a 1 s sleep, presumably before a retry.
+    """
ssh = paramiko.SSHClient()
ssh.load_host_keys('/dev/null')
@@ -223,8 +116,6 @@
if paramiko.__version_info__ >= (1, 15, 2):
banner_timeout_arg['banner_timeout'] = int(min(default_banner_timeout, time_left))
- creds = cast(ConnCreds, creds)
-
if creds.passwd is not None:
ssh.connect(creds.host,
timeout=c_tcp_timeout,
@@ -259,7 +150,7 @@
look_for_keys=False,
port=creds.port,
**banner_timeout_arg)
- return SSHHost(ssh, "{0.host}:{0.port}".format(creds), creds.host)
+ return ssh, "{0.host}:{0.port}".format(creds), creds.host
except paramiko.PasswordRequiredException:
raise
except (socket.error, paramiko.SSHException):
@@ -268,68 +159,35 @@
time.sleep(1)
-def connect(uri: str, **params) -> interfaces.IHost:
- if uri == 'local':
- res = LocalHost()
- else:
- creds = parse_ssh_uri(uri)
- creds.port = int(creds.port)
- res = ssh_connect(creds, **params)
- return res
-
-
-SetupResult = Tuple[interfaces.IRPC, Dict[str, Any]]
-
-
-RPCBeforeConnCallback = Callable[[interfaces.IHost, int], None]
-
-
-def setup_rpc(node: interfaces.IHost,
- rpc_server_code: bytes,
- port: int = 0,
- rpc_conn_callback: RPCBeforeConnCallback = None) -> SetupResult:
- code_file = node.run("mktemp").strip()
- log_file = node.run("mktemp").strip()
- node.put_to_file(code_file, rpc_server_code)
- cmd = "python {code_file} server --listen-addr={listen_ip}:{port} --daemon " + \
- "--show-settings --stdout-file={out_file}"
- params_js = node.run(cmd.format(code_file=code_file,
- listen_addr=node.get_ip(),
- out_file=log_file,
- port=port)).strip()
- params = json.loads(params_js)
- params['log_file'] = log_file
-
- if rpc_conn_callback:
- ip, port = rpc_conn_callback(node, port)
- else:
- ip = node.get_ip()
- port = int(params['addr'].split(":")[1])
-
- return agent.connect((ip, port)), params
-
-
-def wait_ssh_awailable(addrs: List[Tuple[str, int]],
+def wait_ssh_available(addrs: List[IPAddr],
timeout: int = 300,
- tcp_timeout: float = 1.0,
- max_threads: int = 32) -> None:
- addrs = addrs[:]
- tout = utils.Timeout(timeout)
- def check_sock(addr):
- s = socket.socket()
- s.settimeout(tcp_timeout)
- try:
- s.connect(addr)
- return True
- except (socket.timeout, ConnectionRefusedError):
- return False
- with ThreadPoolExecutor(max_workers=max_threads) as pool:
- while addrs:
- check_result = pool.map(check_sock, addrs)
- addrs = [addr for ok, addr in zip(check_result, addrs) if not ok] # type: List[Tuple[str, int]]
- tout.tick()
-
+                       tcp_timeout: float = 1.0) -> None:
+    """Wait until a TCP connection can be established to every address in addrs.
+
+    Rounds are driven by utils.Timeout(timeout); what happens when it expires
+    is defined in utils. Each round starts a non-blocking connect to every
+    address that is still pending and gives the attempts up to tcp_timeout
+    seconds to complete. Addresses that accepted a connection are not probed
+    again.
+
+    :param addrs: (host, port) pairs to wait for; an empty list returns at once.
+    :param timeout: overall time budget handed to utils.Timeout.
+    :param tcp_timeout: seconds a single round waits for its connects to finish.
+    """
+    pending = set(addrs)
+    if not pending:
+        return
+
+    for _ in utils.Timeout(timeout):
+        pending.difference_update(_probe_connect(list(pending), tcp_timeout))
+        # checked at the end of the round so that a round which completes the
+        # set never has to step utils.Timeout again
+        if not pending:
+            break
+
+
+def _probe_connect(addrs: List[IPAddr], tcp_timeout: float) -> List[IPAddr]:
+    """Try one non-blocking TCP connect to each of addrs; return those that succeeded.
+
+    Completion of a non-blocking connect is reported as writability, after which
+    SO_ERROR tells whether it succeeded (see connect(2)). Every socket opened
+    here is closed before returning, including on error.
+    """
+    connected = []  # type: List[IPAddr]
+    sockets = []  # type: List[socket.socket]
+    try:
+        with selectors.DefaultSelector() as selector:
+            for addr in addrs:
+                sock = socket.socket()
+                sockets.append(sock)
+                sock.setblocking(False)
+                err = sock.connect_ex(addr)
+                if err == 0:
+                    connected.append(addr)
+                elif err in (errno.EINPROGRESS, errno.EWOULDBLOCK):
+                    selector.register(sock, selectors.EVENT_WRITE, data=addr)
+                # any other errno (e.g. ECONNREFUSED reported immediately) just
+                # means "not available yet"; the caller retries next round
+
+            deadline = time.monotonic() + tcp_timeout
+            # stop as soon as every in-flight connect has resolved instead of
+            # sleeping on an empty selector until the deadline
+            while selector.get_map():
+                time_left = deadline - time.monotonic()
+                if time_left <= 0:
+                    break
+                for key, _ in selector.select(timeout=time_left):
+                    selector.unregister(key.fileobj)
+                    if key.fileobj.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) == 0:
+                        connected.append(key.data)
+    finally:
+        for sock in sockets:
+            sock.close()
+    return connected