blob: fdf4c91888102af2db02f0b27a97b1d99ff0c12f [file] [log] [blame]
Alex9a4ad212020-10-01 18:04:25 -05001import queue
2import subprocess
3import traceback
4import threading
5
6from time import sleep
7from .exception import TimeoutException, CheckerException
8from .other import shell, piped_shell
9from .log import logger, logger_cli
10
11
# Paramiko is deliberately not used here: shelling out to the ssh binary
# keeps the system-level ssh configuration in effect.
def ssh_shell_p(
    command,
    host,
    username=None,
    keypath=None,
    port=None,
    silent=False,
    piped=False,
    use_sudo=False
):
    """Run ``command`` on ``host`` through the system ``ssh`` binary.

    The invocation is assembled as a single space-separated string in the
    order ``ssh [-q] [-i <keypath>] [-p <port>] [user@]host [sudo] command``;
    each optional piece is included only when its argument is truthy.

    :param command: command text appended verbatim (not quoted).
    :param host: remote host name or address.
    :param username: prepended as ``user@`` when truthy.
    :param keypath: identity file passed with ``-i``.
    :param port: remote port passed with ``-p``.
    :param silent: add ``-q``.
    :param piped: run via ``piped_shell`` instead of ``shell``.
    :param use_sudo: put ``sudo`` in front of ``command``.
    :returns: whatever ``piped_shell`` / ``shell`` returns.

    NOTE(review): the pieces are joined without quoting, so how ``command``
    and ``keypath`` are split depends on ``shell``/``piped_shell`` -- confirm.
    """
    parts = ["ssh"]
    if silent:
        parts += ["-q"]
    if keypath:
        parts += ["-i " + keypath]
    if port:
        parts += ["-p " + str(port)]
    parts += [username + '@' + host if username else host]
    if use_sudo:
        parts += ["sudo"]
    parts += [command]

    joined = " ".join(parts)
    return piped_shell(joined) if piped else shell(joined)
47
48
def scp_p(
    source,
    target,
    port=None,
    keypath=None,
    silent=False,
    piped=False
):
    """Copy ``source`` to ``target`` with the system ``scp`` binary.

    The invocation is assembled as a single space-separated string in the
    order ``scp [-P <port>] [-q] [-i <keypath>] source target``; each option
    is included only when its argument is truthy.

    :param source: scp source spec (local path or ``[user@]host:path``).
    :param target: scp target spec.
    :param port: remote port passed with ``-P``.
    :param keypath: identity file passed with ``-i``.
    :param silent: add ``-q``.
    :param piped: run via ``piped_shell`` instead of ``shell``.
    :returns: whatever ``piped_shell`` / ``shell`` returns.
    """
    argv = ["scp"]
    if port:
        argv += ["-P " + str(port)]
    if silent:
        argv += ["-q"]
    if keypath:
        argv += ["-i " + keypath]
    argv += [source, target]

    line = " ".join(argv)
    return piped_shell(line) if piped else shell(line)
73
74
def output_reader(_stdout, outq):
    """Copy lines from a binary stream into a queue as text.

    Meant to run on a background thread: ``_stdout.readline`` is called
    until it returns ``b''`` (end of stream), and each line is decoded as
    UTF-8 and put on ``outq``.

    :param _stdout: binary stream with a ``readline`` method.
    :param outq: queue receiving one ``str`` per line.
    """
    for line in iter(_stdout.readline, b''):
        # Replace undecodable bytes instead of raising: an exception here
        # would silently end the reader thread and leave every consumer
        # polling an empty queue until it times out.
        outq.put(line.decode('utf-8', errors='replace'))
78
79
# Base class for all SSH related actions
class SshBase(object):
    """Shared plumbing for classes that drive an interactive ssh process.

    Subclasses assemble an ``ssh`` command line from ``self._options``,
    ``self._extra_options`` and ``self._host_uri`` and hand it to
    ``_init_connection``. A background thread running ``output_reader``
    copies the process's stdout into ``self.outq`` line by line. The
    methods below poll that queue and collect every consumed line in
    ``self.output`` until ``flush_output`` hands it back.

    Wait budgets count one-second sleeps taken while the queue is empty,
    so a steady stream of output does not use up the budget.
    """

    def __init__(
        self,
        tgt_host,
        user=None,
        keypath=None,
        port=None,
        timeout=15,
        silent=False,
        piped=False
    ):
        """Store connection parameters and pre-build the ssh options.

        :param tgt_host: remote host name or address.
        :param user: remote user name; left out of the URI when falsy.
        :param keypath: identity file passed with ``-i`` when set.
        :param port: remote ssh port; falls back to 22 when falsy.
        :param timeout: wait budget for the login banner and the default
            budget for ``_wait_for_string``.
        :param silent: stored as ``self.silent``; not read by this class.
        :param piped: stored as ``self.piped``; not read by this class.
        """
        self._cmd = ["ssh"]
        self.timeout = timeout
        self.port = port if port else 22
        self.host = tgt_host
        self.username = user
        self.keypath = keypath
        self.silent = silent
        self.piped = piped
        # Lines consumed from ``outq`` that have not been flushed yet.
        self.output = []

        # "-tt" forces tty allocation even though stdin is a pipe.
        self._options = ["-tt"]
        if self.keypath:
            self._options += ["-i", self.keypath]
        if self.port:
            self._options += ["-p", str(self.port)]
        # Neither verify nor remember remote host keys.
        self._extra_options = [
            "-o", "UserKnownHostsFile=/dev/null",
            "-o", "StrictHostKeyChecking=no"
        ]

        if self.username:
            self._host_uri = self.username + "@" + self.host
        else:
            self._host_uri = self.host

    def _poll_for_prefix(self, prefix, timeout):
        """Consume queued output until a line starts with ``prefix``.

        Every consumed line is appended to ``self.output``.

        :param prefix: text the wanted line must start with.
        :param timeout: number of one-second sleeps allowed while the
            queue is empty.
        :returns: True when a matching line was seen, False on timeout.
        """
        # Count down a local copy; decrementing ``self.timeout`` would eat
        # into the budget of every later wait.
        _left = timeout
        while True:
            try:
                line = self.outq.get(block=False)
            except queue.Empty:
                logger.debug("... {} sec".format(_left))
                sleep(1)
                _left -= 1
                # "<= 0" rather than "== 0", so a zero, negative or
                # fractional budget still terminates.
                if _left <= 0:
                    logger.debug(
                        "...timed out after {} sec".format(str(timeout))
                    )
                    return False
                continue
            line = line.decode() if isinstance(line, bytes) else line
            self.output.append(line)
            if line.startswith(prefix):
                return True

    def _connect(self, banner="Welcome"):
        """Wait up to ``self.timeout`` idle seconds for the login banner.

        :param banner: prefix of the line that marks a usable session.
        :returns: True when the banner line arrived, False on timeout.
        :raises CheckerException: if ``banner`` is not a ``str``.
        """
        if not isinstance(banner, str):
            raise CheckerException(
                "Invalid SSH banner type: '{}'".format(type(banner))
            )
        logger.debug("...connecting")
        if not self._poll_for_prefix(banner, self.timeout):
            return False
        logger.debug("...connected")
        return True

    def _wait_for_string(self, string, timeout=None):
        """Wait for an output line starting with ``string``.

        :param string: prefix to look for.
        :param timeout: idle-second budget; ``None`` means ``self.timeout``.
        :returns: True when found, False on timeout.
        """
        logger.debug("...waiting for '{}'".format(string))
        if timeout is None:
            timeout = self.timeout
        if not self._poll_for_prefix(string, timeout):
            return False
        logger.debug("...found")
        return True

    def _init_connection(self, cmd):
        """Start ``cmd`` as a subprocess and wait for the login banner.

        stdin, stdout and stderr are pipes. Unbuffered stdout bytes are
        decoded by ``output_reader`` on a background thread into
        ``self.outq``. On success ``self.input`` is the process's stdin,
        and the lines received so far are logged and flushed.

        :param cmd: argument list for ``subprocess.Popen``.
        :raises TimeoutException: when the banner never arrives; the
            ssh process is killed before raising.
        """
        self._proc = subprocess.Popen(
            cmd,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            universal_newlines=False,
            bufsize=0
        )
        # Thread-safe output getter. The thread is a daemon so a session
        # that is never killed cannot keep the interpreter from exiting.
        self.outq = queue.Queue()
        self._t = threading.Thread(
            target=output_reader,
            args=(self._proc.stdout, self.outq),
            daemon=True
        )
        self._t.start()

        if not self._connect():
            # Do not leak the ssh process when the banner never shows up.
            self._end_connection()
            raise TimeoutException(
                "SSH connection to '{}'".format(self.host)
            )

        self.input = self._proc.stdin
        self.get_output()
        logger.debug(
            "Connected. Banners:\n{}".format(self.flush_output())
        )

    def _end_connection(self):
        """Kill the ssh process if it is still running and drain output.

        Does nothing when no process has been started yet.
        """
        _proc = getattr(self, "_proc", None)
        if _proc is None:
            return
        if _proc.poll() is None:
            _proc.kill()
            # Reap the killed child so it does not linger as a zombie.
            _proc.wait()
        self.get_output()

    def do(self, cmd, timeout=30, sudo=False, strip_cmd=True):
        """Send ``cmd`` to the remote shell and return what it printed.

        :param cmd: command as ``str`` (encoded as UTF-8) or ``bytes``.
        :param timeout: idle-second budget passed to ``wait_ready``.
        :param sudo: prefix the command with ``sudo``.
        :param strip_cmd: drop the first output line (normally the echo
            of the command itself).
        :returns: the collected output with all carriage returns removed.

        NOTE(review): the result of ``wait_ready`` is ignored, so after a
        timeout this returns whatever output has arrived so far.
        """
        cmd = cmd if isinstance(cmd, bytes) else cmd.encode('utf-8')
        logger.debug("...ssh: '{}'".format(cmd))
        _cmd = b"sudo " + cmd if sudo else cmd
        self.input.write(_cmd + b'\n')
        self.wait_ready(_cmd, timeout=timeout)
        self.get_output()
        _output = self.flush_output().replace('\r', '')
        if strip_cmd:
            return "\n".join(_output.splitlines()[1:])
        return _output

    def get_output(self):
        """Move every line currently queued into ``self.output``.

        :returns: ``self.output`` (the same list object, not a copy).
        """
        while True:
            try:
                line = self.outq.get(block=False)
            except queue.Empty:
                return self.output
            # Decode rather than ``str()``, which would store "b'...'".
            if isinstance(line, bytes):
                line = line.decode()
            self.output.append(line)

    def flush_output(self, as_string=True):
        """Return the collected lines and reset the collection.

        :param as_string: join the lines into one string when true,
            otherwise return the list of lines.
        """
        _out = self.output
        self.output = []
        if as_string:
            return "".join(_out)
        return _out

    def wait_ready(self, cmd, timeout=60):
        """Wait for the shell prompt that follows ``cmd``.

        Queued lines are consumed into ``self.output``. A line containing
        ``$`` is treated as a prompt. It ends the wait unless the text
        after its first ``$`` is just the echo of ``cmd``.

        :param cmd: the command that was sent, as ``bytes`` or ``str``.
        :param timeout: number of one-second sleeps allowed while the
            queue is empty.
        :returns: True when a prompt was seen, False on timeout.
        """
        def _strip_cmd_carrets(_str, carret='\r', skip_chars=1):
            # Drop each ``carret`` plus the ``skip_chars`` characters after
            # it. These are presumably line-wrap artifacts the remote tty
            # puts into long echoed commands (TODO confirm). Looping on
            # membership rather than a precomputed count means a carret
            # swallowed by ``skip_chars`` cannot make ``index`` raise.
            while carret in _str:
                _idx = _str.index(carret)
                _str = _str[:_idx] + _str[_idx + 1 + skip_chars:]
            return _str

        _expected = cmd.decode() if isinstance(cmd, bytes) else cmd
        _left = timeout
        while True:
            try:
                _line = self.outq.get(block=False)
            except queue.Empty:
                logger.debug("... {} sec".format(_left))
                sleep(1)
                _left -= 1
                # "<= 0" so a zero or fractional budget still terminates.
                if _left <= 0:
                    logger.debug("...timed out")
                    return False
                continue
            line = _line.decode() if isinstance(_line, bytes) else _line
            self.output.append(line)
            # Skip the prompt line that merely echoes the command itself.
            if '$' in line:
                _echo = _strip_cmd_carrets(line.split('$', 1)[1].strip())
                if _echo != _expected:
                    return True

    def wait_for_string(self, string, timeout=60):
        """Block until an output line starts with ``string``.

        :param string: prefix to look for.
        :param timeout: number of one-second sleeps allowed while no output
            is queued.
        :returns: True once the line has been seen.
        :raises TimeoutException: when the budget runs out first.
        """
        if not self._wait_for_string(string, timeout=timeout):
            raise TimeoutException(
                "Time out waiting for string '{}'".format(string)
            )
        return True
277
278
class SshShell(SshBase):
    """Interactive ssh shell session usable as a context manager.

    ``__enter__`` starts ``ssh`` with the options prepared by ``SshBase``
    and hands it to ``_init_connection``. ``__exit__`` ends the process.
    """

    def __enter__(self):
        """Start the ssh shell session and return ``self``.

        :raises TimeoutException: propagated from ``_init_connection``
            when the connection banner does not arrive in time.
        """
        self._cmd = ["ssh"]
        self._cmd += self._options
        self._cmd += self._extra_options
        self._cmd += [self._host_uri]

        logger.debug("...shell to: '{}'".format(" ".join(self._cmd)))
        self._init_connection(self._cmd)
        return self

    def __exit__(self, _type, _value, _traceback):
        """End the ssh process, then log any exception from the body.

        Always returns True, so an exception raised inside the ``with``
        body is swallowed after being logged rather than re-raised.
        NOTE(review): this also swallows KeyboardInterrupt/SystemExit.
        """
        self._end_connection()
        if _value:
            # ``warning`` rather than the deprecated ``warn`` alias.
            logger.warning(
                "Error running SSH:\r\n{}".format(
                    "".join(traceback.format_exception(
                        _type,
                        _value,
                        _traceback
                    ))
                )
            )

        return True

    def connect(self):
        """Non-context-manager alias for ``__enter__``."""
        return self.__enter__()

    def kill(self):
        """End the ssh process via ``_end_connection``."""
        self._end_connection()

    def get_host_path(self, path):
        """Return ``[user@]host:path`` for use as an scp source/target."""
        _uri = self.host + ":" + path
        if self.username:
            _uri = self.username + "@" + _uri
        return _uri

    def scp(self, _src, _dst):
        """Copy ``_src`` to ``_dst`` with ``scp``, reusing key/port options.

        The scp options are also stored on ``self._scp_options``.

        :param _src: scp source spec (see ``get_host_path``).
        :param _dst: scp target spec.
        :returns: ``(returncode, stdout, stderr)`` with both streams
            decoded to ``str``.
        """
        self._scp_options = []
        if self.keypath:
            self._scp_options += ["-i", self.keypath]
        if self.port:
            # scp takes the port as upper-case -P.
            self._scp_options += ["-P", str(self.port)]

        _cmd = ["scp"]
        _cmd += self._scp_options
        _cmd += self._extra_options
        _cmd += [_src]
        _cmd += [_dst]

        logger.debug("...scp: '{}'".format(" ".join(_cmd)))
        _proc = subprocess.Popen(
            _cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE
        )
        _r = _proc.communicate()
        _e = _r[1].decode() if _r[1] else ""
        return _proc.returncode, _r[0].decode(), _e
338 return _proc.returncode, _r[0].decode(), _e
339
340
class PortForward(SshBase):
    """ssh local port forward (``-L``) usable as a context manager.

    The forward spec is ``loc_port:fwd_host:fwd_port``, passed to ssh
    together with the options prepared by ``SshBase``.
    """

    def __init__(
        self,
        host,
        fwd_host,
        user=None,
        keypath=None,
        port=None,
        loc_port=10022,
        fwd_port=22,
        timeout=15
    ):
        """Prepare the forward without starting ssh.

        :param host: host ssh connects to.
        :param fwd_host: destination host of the forward.
        :param user: ssh user name.
        :param keypath: identity file passed with ``-i``.
        :param port: ssh port of ``host`` (``SshBase`` defaults it to 22).
        :param loc_port: local port of the forward.
        :param fwd_port: destination port on ``fwd_host``.
        :param timeout: banner wait budget, see ``SshBase``.
        """
        super(PortForward, self).__init__(
            host,
            user=user,
            keypath=keypath,
            port=port,
            timeout=timeout,
            silent=True,
            piped=False
        )
        self.f_host = fwd_host
        self.l_port = loc_port
        self.f_port = fwd_port

        self._forward_options = [
            "-L",
            ":".join([str(self.l_port), self.f_host, str(self.f_port)])
        ]

    def __enter__(self):
        """Start ssh with the forward options and return ``self``.

        :raises TimeoutException: propagated from ``_init_connection``
            when the connection banner does not arrive in time.
        """
        self._cmd = ["ssh"]
        self._cmd += self._forward_options
        self._cmd += self._options
        self._cmd += self._extra_options
        self._cmd += [self._host_uri]

        logger.debug(
            "...port forwarding: '{}'".format(" ".join(self._cmd))
        )
        self._init_connection(self._cmd)
        return self

    def __exit__(self, _type, _value, _traceback):
        """End the ssh process, then log any exception from the body.

        Always returns True, so an exception raised inside the ``with``
        body is swallowed after being logged rather than re-raised.
        NOTE(review): this also swallows KeyboardInterrupt/SystemExit.
        """
        self._end_connection()
        if _value:
            # ``warning`` rather than the deprecated ``warn`` alias.
            logger_cli.warning(
                "Error running SSH:\r\n{}".format(
                    "".join(traceback.format_exception(
                        _type,
                        _value,
                        _traceback
                    ))
                )
            )

        return True

    def connect(self):
        """Non-context-manager alias for ``__enter__``."""
        return self.__enter__()

    def kill(self):
        """End the ssh process via ``_end_connection``."""
        self._end_connection()
403 self._end_connection()