koder aka kdanilov | e06762a | 2015-03-22 23:32:09 +0200 | [diff] [blame] | 1 | import re |
koder aka kdanilov | 22d134e | 2016-11-08 11:33:19 +0200 | [diff] [blame^] | 2 | import json |
koder aka kdanilov | 3a6633e | 2015-03-26 18:20:00 +0200 | [diff] [blame] | 3 | import time |
koder aka kdanilov | 652cd80 | 2015-04-13 12:21:07 +0300 | [diff] [blame] | 4 | import socket |
koder aka kdanilov | e06762a | 2015-03-22 23:32:09 +0200 | [diff] [blame] | 5 | import logging |
| 6 | import os.path |
koder aka kdanilov | 3a6633e | 2015-03-26 18:20:00 +0200 | [diff] [blame] | 7 | import getpass |
koder aka kdanilov | 22d134e | 2016-11-08 11:33:19 +0200 | [diff] [blame^] | 8 | from io import BytesIO |
koder aka kdanilov | 0c598a1 | 2015-04-21 03:01:40 +0300 | [diff] [blame] | 9 | import subprocess |
koder aka kdanilov | 22d134e | 2016-11-08 11:33:19 +0200 | [diff] [blame^] | 10 | from typing import Union, Optional, cast, Dict, List, Tuple, Any, Callable |
| 11 | from concurrent.futures import ThreadPoolExecutor |
koder aka kdanilov | 652cd80 | 2015-04-13 12:21:07 +0300 | [diff] [blame] | 12 | |
koder aka kdanilov | 3a6633e | 2015-03-26 18:20:00 +0200 | [diff] [blame] | 13 | import paramiko |
koder aka kdanilov | e06762a | 2015-03-22 23:32:09 +0200 | [diff] [blame] | 14 | |
koder aka kdanilov | 22d134e | 2016-11-08 11:33:19 +0200 | [diff] [blame^] | 15 | import agent |
| 16 | |
| 17 | from . import interfaces, utils |
| 18 | |
koder aka kdanilov | e06762a | 2015-03-22 23:32:09 +0200 | [diff] [blame] | 19 | |
# Shared "wally" logger used by the SSH/RPC helpers in this module.
logger = logging.getLogger(name="wally")
koder aka kdanilov | e06762a | 2015-03-22 23:32:09 +0200 | [diff] [blame] | 21 | |
| 22 | |
class URIsNamespace(object):
    """Regular expressions for the SSH URI forms understood by parse_ssh_uri().

    Each ``*_rr`` attribute of ReParts is rewritten into a named group
    (``(?P<user>...)``, ``(?P<host>...)``, ...). A successful match's
    ``groupdict()`` therefore uses ConnCreds attribute names as keys.
    ``uri_reg_exprs`` holds the anchored patterns, tried in order.
    """

    class ReParts(object):
        user_rr = "[^:]*?"
        host_rr = "[^:@]*?"
        port_rr = r"\d+"
        key_file_rr = "[^:@]*"
        passwd_rr = ".*?"

    # Iterate over a snapshot: the loop body rebinds ReParts attributes,
    # i.e. it writes into the very mapping that would otherwise be iterated.
    for attr_name, val in list(ReParts.__dict__.items()):
        if attr_name.endswith('_rr'):
            new_rr = "(?P<{0}>{1})".format(attr_name[:-3], val)
            setattr(ReParts, attr_name, new_rr)

    re_dct = ReParts.__dict__

    templs = [
        "^{host_rr}$",
        "^{host_rr}:{port_rr}$",
        "^{host_rr}::{key_file_rr}$",
        "^{host_rr}:{port_rr}:{key_file_rr}$",
        "^{user_rr}@{host_rr}$",
        "^{user_rr}@{host_rr}:{port_rr}$",
        "^{user_rr}@{host_rr}::{key_file_rr}$",
        "^{user_rr}@{host_rr}:{port_rr}:{key_file_rr}$",
        "^{user_rr}:{passwd_rr}@{host_rr}$",
        "^{user_rr}:{passwd_rr}@{host_rr}:{port_rr}$",
    ]

    # A plain loop on purpose: a comprehension body here could not see the
    # class-level name `re_dct` (class scopes are not visible to nested scopes).
    uri_reg_exprs = []  # type: List[str]
    for templ in templs:
        uri_reg_exprs.append(templ.format(**re_dct))
| 56 | |
| 57 | |
class ConnCreds:
    """Endpoint and credentials for a single SSH connection.

    Attribute names match the named regex groups of URIsNamespace, so
    parse_ssh_uri() can populate an instance straight from ``groupdict()``.
    """
    conn_uri_attrs = ("user", "passwd", "host", "port", "key_file")

    def __init__(self) -> None:
        # Assignment order fixes the key order shown by __str__.
        self.user = None  # type: Optional[str]
        self.passwd = None  # type: Optional[str]
        self.host = None  # type: str
        self.port = 22  # type: int
        self.key_file = None  # type: Optional[str]

    def __str__(self) -> str:
        return repr(vars(self))


# Connection target: a string (ssh_connect() recognizes 'local') or parsed ConnCreds.
SSHCredsType = Union[str, ConnCreds]
| 73 | |
| 74 | |
def parse_ssh_uri(uri: str) -> ConnCreds:
    """Parse an SSH URI string into a ConnCreds object.

    Accepted forms, each optionally prefixed with ``ssh://``::

        user:passwd@ip_host:port
        user:passwd@ip_host
        user@ip_host:port
        user@ip_host
        ip_host:port
        ip_host
        user@ip_host:port:path_to_key_file
        user@ip_host::path_to_key_file
        ip_host:port:path_to_key_file
        ip_host::path_to_key_file

    Parts absent from the URI keep their defaults: port 22, no key file,
    no password, and the current local user (``getpass.getuser()``).
    The port is always returned as an int.

    :raises ValueError: if ``uri`` matches none of the supported forms.
    """
    if uri.startswith("ssh://"):
        uri = uri[len("ssh://"):]

    res = ConnCreds()
    res.port = 22
    res.key_file = None
    res.passwd = None
    res.user = getpass.getuser()

    for rr in URIsNamespace.uri_reg_exprs:
        rrm = re.match(rr, uri)
        if rrm is not None:
            parts = rrm.groupdict()
            # Regex groups are strings, while ConnCreds.port is declared as int.
            if 'port' in parts:
                parts['port'] = int(parts['port'])
            res.__dict__.update(parts)
            return res

    raise ValueError("Can't parse {0!r} as ssh uri value".format(uri))
| 104 | |
| 105 | |
class LocalHost(interfaces.IHost):
    """IHost implementation that operates on the local machine."""

    def __str__(self):
        return "<Local>"

    def get_ip(self) -> str:
        return 'localhost'

    def put_to_file(self, path: str, content: bytes) -> None:
        """Write ``content`` to ``path``, creating parent directories as needed."""
        dirname = os.path.dirname(path)
        # dirname is '' for a bare file name; os.makedirs('') would raise.
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        with open(path, "wb") as fd:
            fd.write(content)

    def run(self, cmd: str, timeout: int = 60, nolog: bool = False) -> str:
        """Run ``cmd`` through the shell and return its output as text.

        stderr is merged into stdout. The output is decoded as UTF-8; invalid
        bytes are replaced, so binary output cannot crash the call.
        ``nolog`` is accepted for interface compatibility and is unused here.

        :raises OSError: if the command exits non-zero or runs longer than
            ``timeout`` seconds (the shell process is then killed).
        """
        # The context manager closes the pipes and reaps the child on every exit path.
        with subprocess.Popen(cmd, shell=True,
                              stdin=subprocess.PIPE,
                              stdout=subprocess.PIPE,
                              stderr=subprocess.STDOUT) as proc:
            try:
                stdout_data, _ = proc.communicate(timeout=timeout)
            except subprocess.TimeoutExpired:
                proc.kill()
                raise OSError("{0} Cmd {1!r}\nExecution timeout".format(self, cmd))

        output = stdout_data.decode("utf8", errors="replace")
        if proc.returncode != 0:
            templ = "SSH:{0} Cmd {1!r} failed with code {2}. Output: {3}"
            raise OSError(templ.format(self, cmd, proc.returncode, output))

        return output
| 132 | |
| 133 | |
class SSHHost(interfaces.IHost):
    """IHost implementation backed by an already-connected paramiko SSHClient."""

    def __init__(self, ssh_conn, node_name: str, ip: str) -> None:
        self.conn = ssh_conn
        self.node_name = node_name
        self.ip = ip

    def get_ip(self) -> str:
        return self.ip

    def __str__(self) -> str:
        return self.node_name

    def put_to_file(self, path: str, content: bytes) -> None:
        """Write ``content`` to remote ``path`` over SFTP."""
        with self.conn.open_sftp() as sftp:
            with sftp.open(path, "wb") as fd:
                fd.write(content)

    def run(self, cmd: str, timeout: int = 60, nolog: bool = False) -> str:
        """Execute ``cmd`` in a new SSH session and return its output as text.

        stderr is merged into stdout. Output bytes are collected in full and
        decoded once as UTF-8 with replacement, so a multi-byte character split
        across reads is decoded correctly. Unless ``nolog`` is set, the command
        is logged at debug level.

        :raises OSError: if the command exits non-zero, or if output has not
            finished within ``timeout`` seconds.
        """
        transport = self.conn.get_transport()
        session = transport.open_session()

        try:
            session.set_combine_stderr(True)

            stime = time.time()

            if not nolog:
                logger.debug("SSH:{0} Exec {1!r}".format(self, cmd))

            session.exec_command(cmd)
            # Short recv timeout so the overall `timeout` is re-checked regularly.
            session.settimeout(1)
            session.shutdown_write()

            chunks = []  # type: List[bytes]
            while True:
                try:
                    ndata = session.recv(1024)
                except socket.timeout:
                    pass
                else:
                    if not ndata:  # b"" signals EOF on the channel
                        break
                    chunks.append(ndata)

                if time.time() - stime > timeout:
                    partial = b"".join(chunks).decode("utf8", errors="replace")
                    raise OSError(partial + "\nExecution timeout")

            output = b"".join(chunks).decode("utf8", errors="replace")
            code = session.recv_exit_status()
        finally:
            # Always release the channel, including on timeout or errors.
            session.close()

        if code != 0:
            templ = "SSH:{0} Cmd {1!r} failed with code {2}. Output: {3}"
            raise OSError(templ.format(self, cmd, code, output))

        return output
| 192 | |
| 193 | |
# RSA private keys registered per (host, port) by set_key_for_node(); ssh_connect()
# uses a registered key when the creds carry neither a password nor a key file.
NODE_KEYS = {}  # type: Dict[Tuple[str, int], paramiko.RSAKey]


def set_key_for_node(host_port: Tuple[str, int], key: bytes) -> None:
    """Parse PEM-encoded RSA private ``key`` and register it for ``host_port``.

    ``key`` may be bytes (decoded as UTF-8) or str. paramiko's key parser
    works on text lines, so the key is fed to it through a text stream.
    A BytesIO would yield bytes lines, which the parser does not recognize
    on Python 3.

    :raises paramiko.SSHException: if ``key`` is not a valid RSA private key.
    """
    from io import StringIO

    text = key.decode("utf8") if isinstance(key, bytes) else key
    with StringIO(text) as sio:
        NODE_KEYS[host_port] = paramiko.RSAKey.from_private_key(sio)
| 201 | |
| 202 | |
def ssh_connect(creds: SSHCredsType, conn_timeout: int = 60) -> interfaces.IHost:
    """Open an SSH connection described by ``creds``, retrying until ``conn_timeout``.

    ``creds`` may be ``'local'`` (returns LocalHost), an SSH URI string (parsed
    with parse_ssh_uri), or a ConnCreds. Authentication, in order of preference:
    password, explicit key file, key registered in NODE_KEYS for (host, port),
    and finally ``~/.ssh/id_rsa``. Host keys are not verified (AutoAddPolicy,
    empty known-hosts).

    Socket and SSH errors are retried once per second until the deadline has
    passed; then the last error is re-raised. PasswordRequiredException is
    re-raised immediately.
    """
    if creds == 'local':
        return LocalHost()

    if isinstance(creds, str):
        creds = parse_ssh_uri(creds)
    creds = cast(ConnCreds, creds)

    tcp_timeout = 15
    default_banner_timeout = 30
    # Socket timeouts must be positive (settimeout() rejects negative values);
    # after the retry sleep the remaining time can already be <= 0.
    min_attempt_timeout = 1

    port = int(creds.port)

    ssh = paramiko.SSHClient()
    ssh.load_host_keys('/dev/null')
    ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    ssh.known_hosts = None

    # Arguments shared by every attempt; the auth method is chosen once.
    connect_kwargs = {
        'username': creds.user,
        'port': port,
        'look_for_keys': False,
    }  # type: Dict[str, Any]
    if creds.passwd is not None:
        connect_kwargs['password'] = creds.passwd
        connect_kwargs['allow_agent'] = False
    elif creds.key_file is not None:
        connect_kwargs['key_filename'] = creds.key_file
    elif (creds.host, port) in NODE_KEYS:
        connect_kwargs['pkey'] = NODE_KEYS[(creds.host, port)]
    else:
        connect_kwargs['key_filename'] = os.path.expanduser('~/.ssh/id_rsa')

    end_time = time.time() + conn_timeout  # type: float

    while True:
        time_left = max(end_time - time.time(), min_attempt_timeout)
        attempt_kwargs = dict(connect_kwargs)
        attempt_kwargs['timeout'] = min(tcp_timeout, time_left)
        if paramiko.__version_info__ >= (1, 15, 2):
            attempt_kwargs['banner_timeout'] = int(min(default_banner_timeout, time_left))

        try:
            ssh.connect(creds.host, **attempt_kwargs)
            return SSHHost(ssh, "{0.host}:{0.port}".format(creds), creds.host)
        except paramiko.PasswordRequiredException:
            ssh.close()
            raise
        except (socket.error, paramiko.SSHException):
            if time.time() > end_time:
                ssh.close()
                raise
            time.sleep(1)
| 269 | |
| 270 | |
def connect(uri: str, **params) -> interfaces.IHost:
    """Return an IHost for ``uri``.

    ``'local'`` yields a LocalHost. Any other value is parsed as an SSH URI,
    its port is coerced to int, and it is connected via ssh_connect(), which
    receives the extra keyword ``params``.
    """
    if uri == 'local':
        return LocalHost()

    creds = parse_ssh_uri(uri)
    creds.port = int(creds.port)
    return ssh_connect(creds, **params)
koder aka kdanilov | e06762a | 2015-03-22 23:32:09 +0200 | [diff] [blame] | 279 | |
| 280 | |
# Return type of setup_rpc(): (RPC connection, settings dict reported by the RPC
# server, with 'log_file' added).
SetupResult = Tuple[interfaces.IRPC, Dict[str, Any]]


# Optional hook for setup_rpc(): called with (node, port) before connecting and
# must return the (ip, port) pair to connect to; setup_rpc unpacks the result.
RPCBeforeConnCallback = Callable[[interfaces.IHost, int], Tuple[str, int]]
koder aka kdanilov | 7647164 | 2015-08-14 11:44:43 +0300 | [diff] [blame] | 285 | |
koder aka kdanilov | 416b87a | 2015-05-12 00:26:04 +0300 | [diff] [blame] | 286 | |
def setup_rpc(node: interfaces.IHost,
              rpc_server_code: bytes,
              port: int = 0,
              rpc_conn_callback: Optional[RPCBeforeConnCallback] = None) -> SetupResult:
    """Upload and start the RPC server on ``node`` and connect to it.

    The server code is written to a ``mktemp`` file on the node. It is
    started with ``python <file> server --daemon`` listening on
    ``node.get_ip():port``, and its settings are read back as JSON from the
    command output; ``params['log_file']`` is set to the server's stdout file.

    Without ``rpc_conn_callback``, the connection goes to ``node.get_ip()``
    and the port reported in ``params['addr']``. With a callback, it is given
    (node, reported port) and returns the (ip, port) to use.

    :returns: ``(agent.connect((ip, port)), params)``
    """
    code_file = node.run("mktemp").strip()
    log_file = node.run("mktemp").strip()
    node.put_to_file(code_file, rpc_server_code)
    cmd = "python {code_file} server --listen-addr={listen_ip}:{port} --daemon " + \
          "--show-settings --stdout-file={out_file}"
    params_js = node.run(cmd.format(code_file=code_file,
                                    listen_ip=node.get_ip(),
                                    out_file=log_file,
                                    port=port)).strip()
    params = json.loads(params_js)
    params['log_file'] = log_file

    # Port the server actually bound; differs from `port` when port == 0.
    # rsplit keeps hosts containing ':' (IPv6) intact.
    server_port = int(params['addr'].rsplit(":", 1)[1])

    if rpc_conn_callback:
        ip, port = rpc_conn_callback(node, server_port)
    else:
        ip = node.get_ip()
        port = server_port

    return agent.connect((ip, port)), params
koder aka kdanilov | 4af1c1d | 2015-05-18 15:48:58 +0300 | [diff] [blame] | 310 | |
koder aka kdanilov | 22d134e | 2016-11-08 11:33:19 +0200 | [diff] [blame^] | 311 | |
def wait_ssh_awailable(addrs: List[Tuple[str, int]],
                       timeout: int = 300,
                       tcp_timeout: float = 1.0,
                       max_threads: int = 32) -> None:
    """Block until a TCP connection succeeds to every (host, port) in ``addrs``.

    Addresses are probed in parallel, at most ``max_threads`` at a time.
    Each round re-probes only the addresses that failed the previous round,
    with ``tcp_timeout`` seconds per connection attempt. Any OSError (refused,
    timed out, unreachable, ...) counts as "not yet available". The overall
    deadline is handled by ``utils.Timeout(timeout).tick()``, which is
    presumably expected to raise once ``timeout`` expires (confirm in utils).
    ``addrs`` itself is not modified.
    """
    addrs = addrs[:]
    tout = utils.Timeout(timeout)

    def check_sock(addr: Tuple[str, int]) -> bool:
        # One probe; the socket is always closed so retries don't leak fds.
        try:
            with socket.create_connection(addr, timeout=tcp_timeout):
                return True
        except OSError:
            return False

    with ThreadPoolExecutor(max_workers=max_threads) as pool:
        while addrs:
            check_result = pool.map(check_sock, addrs)
            addrs = [addr for ok, addr in zip(check_result, addrs) if not ok]
            tout.tick()
koder aka kdanilov | e06762a | 2015-03-22 23:32:09 +0200 | [diff] [blame] | 333 | |
koder aka kdanilov | e06762a | 2015-03-22 23:32:09 +0200 | [diff] [blame] | 334 | |
koder aka kdanilov | 0c598a1 | 2015-04-21 03:01:40 +0300 | [diff] [blame] | 335 | |