blob: 2941e7c36354f66e2f64026abbafe3316337980d [file] [log] [blame]
koder aka kdanilove06762a2015-03-22 23:32:09 +02001import re
koder aka kdanilov22d134e2016-11-08 11:33:19 +02002import json
koder aka kdanilov3a6633e2015-03-26 18:20:00 +02003import time
koder aka kdanilov652cd802015-04-13 12:21:07 +03004import socket
koder aka kdanilove06762a2015-03-22 23:32:09 +02005import logging
6import os.path
koder aka kdanilov3a6633e2015-03-26 18:20:00 +02007import getpass
koder aka kdanilov22d134e2016-11-08 11:33:19 +02008from io import BytesIO
koder aka kdanilov0c598a12015-04-21 03:01:40 +03009import subprocess
koder aka kdanilov22d134e2016-11-08 11:33:19 +020010from typing import Union, Optional, cast, Dict, List, Tuple, Any, Callable
11from concurrent.futures import ThreadPoolExecutor
koder aka kdanilov652cd802015-04-13 12:21:07 +030012
koder aka kdanilov3a6633e2015-03-26 18:20:00 +020013import paramiko
koder aka kdanilove06762a2015-03-22 23:32:09 +020014
koder aka kdanilov22d134e2016-11-08 11:33:19 +020015import agent
16
17from . import interfaces, utils
18
koder aka kdanilove06762a2015-03-22 23:32:09 +020019
# Module logger; uses the fixed name "wally" rather than __name__.
logger = logging.getLogger(name="wally")
koder aka kdanilove06762a2015-03-22 23:32:09 +020021
22
class URIsNamespace(object):
    """Regular expressions for the SSH URI forms understood by parse_ssh_uri.

    After class creation ``uri_reg_exprs`` holds one anchored pattern string
    per supported form; each pattern uses named groups drawn from ``user``,
    ``passwd``, ``host``, ``port`` and ``key_file``.
    """

    class ReParts(object):
        # Building blocks for the URI patterns. The loop below rewrites each
        # *_rr attribute in place into a named group, e.g. user_rr becomes
        # "(?P<user>[^:]*?)".
        user_rr = "[^:]*?"
        host_rr = "[^:@]*?"
        port_rr = "\\d+"
        key_file_rr = "[^:@]*"
        passwd_rr = ".*?"

    re_dct = ReParts.__dict__

    # NOTE: this loop runs in the class body, so attr_name, val and new_rr
    # are also left behind as (unused) attributes of URIsNamespace.
    for attr_name, val in re_dct.items():
        if attr_name.endswith('_rr'):
            # "xxx_rr" -> named group "xxx"
            new_rr = "(?P<{0}>{1})".format(attr_name[:-3], val)
            setattr(ReParts, attr_name, new_rr)

    # ReParts.__dict__ is a live mappingproxy view, so this re-read is
    # redundant but harmless: format() below sees the named-group versions.
    re_dct = ReParts.__dict__

    # parse_ssh_uri tries these in list order and the first match wins.
    templs = [
        "^{host_rr}$",
        "^{host_rr}:{port_rr}$",
        "^{host_rr}::{key_file_rr}$",
        "^{host_rr}:{port_rr}:{key_file_rr}$",
        "^{user_rr}@{host_rr}$",
        "^{user_rr}@{host_rr}:{port_rr}$",
        "^{user_rr}@{host_rr}::{key_file_rr}$",
        "^{user_rr}@{host_rr}:{port_rr}:{key_file_rr}$",
        "^{user_rr}:{passwd_rr}@{host_rr}$",
        "^{user_rr}:{passwd_rr}@{host_rr}:{port_rr}$",
    ]

    # Final pattern strings (not compiled). Like the loop above, `templ`
    # leaks as a class attribute.
    uri_reg_exprs = []  # type: List[str]
    for templ in templs:
        uri_reg_exprs.append(templ.format(**re_dct))
56
57
class ConnCreds:
    """Connection parameters for one SSH endpoint.

    Every field starts unset (None) except ``port``, which defaults to 22.
    ``str()`` renders the instance's attribute dict.
    """

    conn_uri_attrs = ("user", "passwd", "host", "port", "key_file")

    def __init__(self) -> None:
        # Assignment order matters: __str__ prints the attribute dict, so
        # this order fixes the key order of the printed form.
        self.user, self.passwd = None, None  # type: Optional[str], Optional[str]
        self.host, self.port = None, 22  # type: Optional[str], int
        self.key_file = None  # type: Optional[str]

    def __str__(self) -> str:
        return "{0}".format(vars(self))
70
71
# Connection target accepted by ssh_connect: the string 'local' or a ConnCreds.
SSHCredsType = Union[str, ConnCreds]
73
74
def parse_ssh_uri(uri: str) -> ConnCreds:
    """Parse an SSH connection string into a ConnCreds.

    An optional leading "ssh://" is stripped, then the rest must match one of::

        user:passwd@ip_host:port
        user:passwd@ip_host
        user@ip_host:port
        user@ip_host
        ip_host:port
        ip_host
        user@ip_host:port:path_to_key_file
        user@ip_host::path_to_key_file
        ip_host:port:path_to_key_file
        ip_host::path_to_key_file

    Fields absent from the URI keep their defaults: port 22, no key file,
    no password, and the current local user (getpass.getuser()).
    ``port`` is always returned as an int.

    :raises ValueError: if the URI matches none of the accepted forms.
    """
    if uri.startswith("ssh://"):
        uri = uri[len("ssh://"):]

    for rr in URIsNamespace.uri_reg_exprs:
        rrm = re.match(rr, uri)
        if rrm is None:
            continue

        fields = rrm.groupdict()

        res = ConnCreds()
        res.port = 22
        res.key_file = None
        res.passwd = None
        res.__dict__.update(fields)

        # The regex captures the port as text; ConnCreds.port is an int.
        res.port = int(res.port)

        # Only consult the local account when the URI names no user:
        # getpass.getuser() can fail where no login name is available.
        if 'user' not in fields:
            res.user = getpass.getuser()
        return res

    raise ValueError("Can't parse {0!r} as ssh uri value".format(uri))
104
105
class LocalHost(interfaces.IHost):
    """IHost implementation that runs commands and writes files on this machine."""

    def __str__(self) -> str:
        return "<Local>"

    def get_ip(self) -> str:
        return 'localhost'

    def put_to_file(self, path: str, content: bytes) -> None:
        """Write *content* to *path*, creating missing parent directories."""
        dirname = os.path.dirname(path)
        # dirname is '' for a bare file name, and os.makedirs('') raises.
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        with open(path, "wb") as fd:
            fd.write(content)

    def run(self, cmd: str, timeout: int = 60, nolog: bool = False) -> str:
        """Run *cmd* through the shell and return its combined stdout/stderr.

        :param cmd: shell command line.
        :param timeout: seconds to wait before the command is killed.
        :param nolog: suppress the debug log line for this command.
        :returns: command output decoded as UTF-8 (undecodable bytes replaced).
        :raises OSError: on timeout or non-zero exit code.
        """
        if not nolog:
            logger.debug("SSH:{0} Exec {1!r}".format(self, cmd))

        # The context manager closes the pipes and reaps the process on every path.
        with subprocess.Popen(cmd, shell=True,
                              stdin=subprocess.PIPE,
                              stdout=subprocess.PIPE,
                              stderr=subprocess.STDOUT) as proc:
            try:
                stdout_data, _ = proc.communicate(timeout=timeout)
            except subprocess.TimeoutExpired:
                proc.kill()
                raise OSError("SSH:{0} Cmd {1!r}\nExecution timeout".format(self, cmd))

        # Callers treat the result as text (e.g. .strip() then str.format).
        output = stdout_data.decode("utf8", errors="replace")
        if proc.returncode != 0:
            templ = "SSH:{0} Cmd {1!r} failed with code {2}. Output: {3}"
            raise OSError(templ.format(self, cmd, proc.returncode, output))

        return output
132
133
class SSHHost(interfaces.IHost):
    """IHost implementation that runs commands over an SSH connection.

    :param ssh_conn: connected client; presumably a paramiko.SSHClient
        (ssh_connect passes one). Only get_transport() and open_sftp() are used.
    :param node_name: label returned by __str__ and used in log/error messages.
    :param ip: address returned by get_ip().
    """

    def __init__(self, ssh_conn, node_name: str, ip: str) -> None:
        self.conn = ssh_conn
        self.node_name = node_name
        self.ip = ip

    def get_ip(self) -> str:
        return self.ip

    def __str__(self) -> str:
        return self.node_name

    def put_to_file(self, path: str, content: bytes) -> None:
        """Write *content* to *path* on the remote host via SFTP."""
        with self.conn.open_sftp() as sftp:
            with sftp.open(path, "wb") as fd:
                fd.write(content)

    def run(self, cmd: str, timeout: int = 60, nolog: bool = False) -> str:
        """Execute *cmd* remotely and return its combined stdout/stderr.

        :param cmd: command line passed to exec_command.
        :param timeout: overall deadline in seconds for the command to finish.
        :param nolog: suppress the debug log line for this command.
        :returns: output decoded as UTF-8 (undecodable bytes replaced).
        :raises OSError: on timeout or non-zero exit code.
        """
        transport = self.conn.get_transport()
        session = transport.open_session()

        try:
            session.set_combine_stderr(True)
            stime = time.time()

            if not nolog:
                logger.debug("SSH:{0} Exec {1!r}".format(self, cmd))

            session.exec_command(cmd)
            # Short recv timeout so the overall deadline is checked about once a second.
            session.settimeout(1)
            session.shutdown_write()

            # recv() returns bytes; collect chunks and decode once (avoids
            # quadratic concatenation and str/bytes mixing).
            chunks = []  # type: List[bytes]
            while True:
                try:
                    ndata = session.recv(1024)
                except socket.timeout:
                    pass
                else:
                    if not ndata:  # zero-length read: the channel stream has closed
                        break
                    chunks.append(ndata)

                if time.time() - stime > timeout:
                    partial = b"".join(chunks).decode("utf8", errors="replace")
                    raise OSError(partial + "\nExecution timeout")

            code = session.recv_exit_status()
        finally:
            # Always release the channel, including on timeout or errors.
            session.close()

        output = b"".join(chunks).decode("utf8", errors="replace")
        if code != 0:
            templ = "SSH:{0} Cmd {1!r} failed with code {2}. Output: {3}"
            raise OSError(templ.format(self, cmd, code, output))

        return output
192
193
# Private keys registered with set_key_for_node, keyed by (host, port).
# ssh_connect uses them when the credentials carry neither a password nor a key file.
NODE_KEYS = {}  # type: Dict[Tuple[str, int], paramiko.RSAKey]
195
196
def set_key_for_node(host_port: Tuple[str, int], key: Union[bytes, str]) -> None:
    """Register an RSA private key for SSH connections to *host_port*.

    :param host_port: (host, port) pair; ssh_connect looks keys up by
        (creds.host, creds.port).
    :param key: contents of an RSA private key file, as bytes or str.
    """
    # paramiko parses key files as text (its header checks use str patterns),
    # so a bytes buffer cannot be parsed on Python 3.
    from io import StringIO

    if isinstance(key, bytes):
        key = key.decode("utf8")
    with StringIO(key) as sio:
        NODE_KEYS[host_port] = paramiko.RSAKey.from_private_key(sio)
201
202
def _ssh_auth_kwargs(creds: ConnCreds) -> Dict[str, Any]:
    """Pick the paramiko authentication arguments for *creds*.

    Priority: explicit password, explicit key file, key registered in
    NODE_KEYS for (host, port), and finally ~/.ssh/id_rsa.
    """
    if creds.passwd is not None:
        return {'password': cast(str, creds.passwd),
                'allow_agent': False,
                'look_for_keys': False}
    if creds.key_file is not None:
        return {'key_filename': cast(str, creds.key_file),
                'look_for_keys': False}
    if (creds.host, creds.port) in NODE_KEYS:
        return {'pkey': NODE_KEYS[(creds.host, creds.port)],
                'look_for_keys': False}
    return {'key_filename': os.path.expanduser('~/.ssh/id_rsa'),
            'look_for_keys': False}


def ssh_connect(creds: SSHCredsType, conn_timeout: int = 60) -> interfaces.IHost:
    """Open an SSH connection described by *creds*, retrying until *conn_timeout*.

    :param creds: 'local' (returns a LocalHost without connecting), an SSH URI
        string accepted by parse_ssh_uri, or a ConnCreds.
    :param conn_timeout: overall time budget in seconds; failed attempts are
        retried with a 1 second pause until it expires.
    :returns: an SSHHost named "host:port".
    :raises paramiko.PasswordRequiredException: immediately, without retrying.
    :raises socket.error, paramiko.SSHException: the last connection error once
        conn_timeout has elapsed.
    """
    if creds == 'local':
        return LocalHost()

    if isinstance(creds, str):
        # SSHCredsType allows strings; anything other than 'local' is a URI.
        creds = parse_ssh_uri(creds)
        creds.port = int(creds.port)
    creds = cast(ConnCreds, creds)

    tcp_timeout = 15
    default_banner_timeout = 30

    ssh = paramiko.SSHClient()
    ssh.load_host_keys('/dev/null')
    ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    ssh.known_hosts = None

    end_time = time.time() + conn_timeout  # type: float

    while True:
        # Floor at 1s: after the retry sleep the deadline may already have
        # passed, and a zero socket timeout means non-blocking while a
        # negative one raises ValueError.
        time_left = max(end_time - time.time(), 1.0)

        connect_args = {
            'username': creds.user,
            'port': creds.port,
            'timeout': min(tcp_timeout, time_left),
        }  # type: Dict[str, Any]
        if paramiko.__version_info__ >= (1, 15, 2):
            connect_args['banner_timeout'] = int(min(default_banner_timeout, time_left))
        # Re-evaluated on each attempt so keys registered meanwhile are used.
        connect_args.update(_ssh_auth_kwargs(creds))

        try:
            ssh.connect(creds.host, **connect_args)
        except paramiko.PasswordRequiredException:
            # A subclass of SSHException, but retrying cannot help.
            raise
        except (socket.error, paramiko.SSHException):
            # Release the transport left over from the failed attempt.
            ssh.close()
            if time.time() > end_time:
                raise
            time.sleep(1)
        else:
            return SSHHost(ssh, "{0.host}:{0.port}".format(creds), creds.host)
269
270
def connect(uri: str, **params) -> interfaces.IHost:
    """Return a host handle for *uri*.

    'local' yields a LocalHost. Anything else is parsed with parse_ssh_uri
    and connected through ssh_connect, which receives the extra keyword
    arguments in *params*.
    """
    if uri != 'local':
        creds = parse_ssh_uri(uri)
        creds.port = int(creds.port)
        return ssh_connect(creds, **params)
    return LocalHost()
koder aka kdanilove06762a2015-03-22 23:32:09 +0200279
280
# Result of setup_rpc: the RPC connection plus the server settings dict.
SetupResult = Tuple[interfaces.IRPC, Dict[str, Any]]


# Hook called by setup_rpc as callback(node, port) before connecting; its result
# is unpacked as (ip, port) and used as the address to connect to.
RPCBeforeConnCallback = Callable[[interfaces.IHost, int], Tuple[str, int]]
koder aka kdanilov76471642015-08-14 11:44:43 +0300285
koder aka kdanilov416b87a2015-05-12 00:26:04 +0300286
def setup_rpc(node: interfaces.IHost,
              rpc_server_code: bytes,
              port: int = 0,
              rpc_conn_callback: Optional[RPCBeforeConnCallback] = None) -> SetupResult:
    """Deploy the RPC server code to *node*, start it there and connect to it.

    :param node: host to run the server on.
    :param rpc_server_code: server source, written to a mktemp file on *node*.
    :param port: port passed to the server's --listen-addr (0 presumably lets
        it pick a free port - TODO confirm against the server).
    :param rpc_conn_callback: optional hook called as callback(node, port) with
        the port the server reports it listens on; it returns the (ip, port) to
        connect to, e.g. for a forwarded port.
    :returns: (agent connection, settings JSON printed by the server, plus
        'log_file' with the server's stdout file path).
    """
    code_file = node.run("mktemp").strip()
    log_file = node.run("mktemp").strip()
    node.put_to_file(code_file, rpc_server_code)

    # NOTE: placeholder names must match the keywords passed to format() below.
    cmd = ("python {code_file} server --listen-addr={listen_ip}:{port} --daemon "
           "--show-settings --stdout-file={out_file}")
    params_js = node.run(cmd.format(code_file=code_file,
                                    listen_ip=node.get_ip(),
                                    out_file=log_file,
                                    port=port)).strip()
    params = json.loads(params_js)
    params['log_file'] = log_file

    # 'addr' is "host:port"; rsplit keeps working if the host contains ':'.
    server_port = int(params['addr'].rsplit(":", 1)[1])

    if rpc_conn_callback:
        ip, conn_port = rpc_conn_callback(node, server_port)
    else:
        ip, conn_port = node.get_ip(), server_port

    return agent.connect((ip, conn_port)), params
koder aka kdanilov4af1c1d2015-05-18 15:48:58 +0300310
koder aka kdanilov22d134e2016-11-08 11:33:19 +0200311
def wait_ssh_awailable(addrs: List[Tuple[str, int]],
                       timeout: int = 300,
                       tcp_timeout: float = 1.0,
                       max_threads: int = 32) -> None:
    """Block until a TCP connection can be opened to every (host, port) in *addrs*.

    Addresses that are not reachable yet are re-probed in rounds, in parallel
    (up to *max_threads* at once), with utils.Timeout(timeout).tick() between
    rounds. Presumably tick() sleeps and raises once *timeout* has elapsed -
    TODO confirm in utils.Timeout.

    :param tcp_timeout: per-probe connect timeout in seconds.
    """
    pending = list(addrs)  # work on a copy; the caller's list is not modified
    tout = utils.Timeout(timeout)

    def check_sock(addr: Tuple[str, int]) -> bool:
        # Closed on exit, so repeated rounds do not leak descriptors.
        with socket.socket() as sock:
            sock.settimeout(tcp_timeout)
            try:
                sock.connect(addr)
            except OSError:
                # Timeout, refused, unreachable, reset, ...: all mean "not yet".
                return False
            return True

    with ThreadPoolExecutor(max_workers=max_threads) as pool:
        while pending:
            results = list(pool.map(check_sock, pending))
            pending = [addr for ok, addr in zip(results, pending) if not ok]
            if pending:
                tout.tick()
koder aka kdanilove06762a2015-03-22 23:32:09 +0200333
koder aka kdanilove06762a2015-03-22 23:32:09 +0200334
koder aka kdanilov0c598a12015-04-21 03:01:40 +0300335