blob: da3655df4616a723d5eaf8b7d6306dfead1970e9 [file] [log] [blame]
Dennis Dmitrieve56c8b92017-06-16 01:53:16 +03001# Copyright 2013 - 2016 Mirantis, Inc.
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may
4# not use this file except in compliance with the License. You may obtain
5# a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations
13# under the License.
14
15from __future__ import unicode_literals
16
17import base64
18import os
19import posixpath
20import stat
21import sys
22import threading
23import time
24import warnings
25
26import paramiko
27import six
28
29from reclass_tools.helpers import decorators
30from reclass_tools.helpers import exec_result
31from reclass_tools.helpers import proc_enums
32from reclass_tools import logger
33
34
def get_private_keys(home, identity_files=None):
    """Load RSA private keys stored under a home directory.

    :param home: directory that every identity file path is joined to
    :type home: str
    :param identity_files: key file paths relative to ``home``; when empty
        or None, ``['.ssh/id_rsa']`` is used
    :type identity_files: list
    :return: one ``paramiko.RSAKey`` per identity file, in the given order
    :rtype: list
    """
    relative_paths = identity_files or ['.ssh/id_rsa']
    loaded_keys = []
    for relative_path in relative_paths:
        with open(os.path.join(home, relative_path)) as key_file:
            loaded_keys.append(paramiko.RSAKey.from_private_key(key_file))
    return loaded_keys
43
44
class SSHAuth(object):
    """Credentials (username, password, private keys) used by SSHClient."""

    __slots__ = ['__username', '__password', '__key', '__keys']

    def __init__(
            self,
            username=None, password=None, key=None, keys=None):
        """SSH authorisation object

        Used to authorize SSHClient.
        Single SSHAuth object is associated with single host:port.
        Password and key are private, other data is read-only.

        :param username: remote user name
        :type username: str
        :param password: password used for login and fed to ``sudo -S``
        :type password: str
        :param key: main private key, tried first by connect()
        :type key: paramiko.RSAKey
        :param keys: additional private keys to try on connect()
        :type keys: list
        """
        self.__username = username
        self.__password = password
        self.__key = key
        # None is always kept as a candidate so connect() can also attempt
        # authentication without an explicit private key.
        self.__keys = [None]
        if key is not None:
            # noinspection PyTypeChecker
            self.__keys.append(key)
        if keys is not None:
            # Distinct loop name: do not shadow the ``key`` parameter.
            for extra_key in keys:
                if extra_key not in self.__keys:
                    self.__keys.append(extra_key)

    @property
    def username(self):
        """Username for auth

        :rtype: str
        """
        return self.__username

    @staticmethod
    def __get_public_key(key):
        """Build the public key string ("<name> <base64>") for a private key

        :type key: paramiko.RSAKey
        :rtype: str
        """
        if key is None:
            return None
        return '{0} {1}'.format(key.get_name(), key.get_base64())

    @property
    def public_key(self):
        """Public key for the stored main private key if present, else None

        :rtype: str
        """
        return self.__get_public_key(self.__key)

    def enter_password(self, tgt):
        """Write the stored password followed by a newline to ``tgt``

        Note: required for 'sudo' call (``sudo -S`` reads it from STDIN).

        :type tgt: file
        :rtype: str
        """
        # noinspection PyTypeChecker
        return tgt.write('{}\n'.format(self.__password))

    def connect(self, client, hostname=None, port=22, log=True):
        """Connect SSH client object using credentials

        The main key is tried first, then every other stored key (including
        None). The first key that authenticates becomes the main key.

        :type client:
            paramiko.client.SSHClient
            paramiko.transport.Transport
        :type hostname: str
        :type port: int
        :type log: bool
        :raises paramiko.PasswordRequiredException: the server asked for a
            password-protected key (re-raised as-is)
        :raises paramiko.AuthenticationException: no stored key worked
        """
        kwargs = {
            'username': self.username,
            'password': self.__password}
        if hostname is not None:
            kwargs['hostname'] = hostname
            kwargs['port'] = port

        keys = [self.__key]
        keys.extend([k for k in self.__keys if k != self.__key])

        for key in keys:
            kwargs['pkey'] = key
            try:
                client.connect(**kwargs)
                if self.__key != key:
                    # Remember the working key so it is tried first next time
                    self.__key = key
                    logger.debug(
                        'Main key has been updated, public key is: \n'
                        '{}'.format(self.public_key))
                return
            except paramiko.PasswordRequiredException:
                # Must precede the AuthenticationException handler below.
                if self.__password is None:
                    logger.exception('No password has been set!')
                else:
                    logger.critical(
                        'Unexpected PasswordRequiredException, '
                        'when password is set!')
                raise
            except paramiko.AuthenticationException:
                continue
        msg = 'Connection using stored authentication info failed!'
        if log:
            # Not inside an except block: logger.exception would attach a
            # meaningless "NoneType: None" traceback.
            logger.error(msg)
        raise paramiko.AuthenticationException(msg)

    def __identity(self):
        """Tuple of everything that defines this auth object (hash/eq)."""
        return (
            self.__class__,
            self.username,
            self.__password,
            tuple(self.__keys)
        )

    def __hash__(self):
        return hash(self.__identity())

    def __eq__(self, other):
        # Compare the identity tuples themselves: equal hashes do not imply
        # equality, and hashing an arbitrary operand may raise TypeError.
        if not isinstance(other, SSHAuth):
            return False
        return self.__identity() == other.__identity()

    def __ne__(self, other):
        return not self.__eq__(other)

    def __deepcopy__(self, memo):
        """Return a new SSHAuth with the same data (key objects are shared)."""
        return self.__class__(
            username=self.username,
            password=self.__password,
            key=self.__key,
            # list() instead of list.copy(): the latter is missing on PY2
            keys=list(self.__keys)
        )

    def copy(self):
        """Return a new SSHAuth with the same credentials."""
        return self.__class__(
            username=self.username,
            password=self.__password,
            key=self.__key,
            keys=self.__keys
        )

    def __repr__(self):
        _key = (
            None if self.__key is None else
            '<private for pub: {}>'.format(self.public_key)
        )
        _keys = []
        for k in self.__keys:
            if k == self.__key:
                continue
            # noinspection PyTypeChecker
            _keys.append(
                '<private for pub: {}>'.format(
                    self.__get_public_key(key=k)) if k is not None else None)

        return (
            '{cls}(username={username}, '
            'password=<*masked*>, key={key}, keys={keys})'.format(
                cls=self.__class__.__name__,
                username=self.username,
                key=_key,
                keys=_keys)
        )

    def __str__(self):
        return (
            '{cls} for {username}'.format(
                cls=self.__class__.__name__,
                username=self.username,
            )
        )
217
218
class _MemorizedSSH(type):
    """Memorize metaclass for SSHClient

    This class implements caching and managing of SSHClient connections.
    The class is not in public scope: all required interfaces are accessible
    through SSHClient classmethods.

    Main flow is:
      SSHClient() -> check for cached connection (keyed by host, port) and
        - If exists the same: probe it with ``cd ~`` (5 s timeout),
          reconnect on any error and return it
        - If exists with different credentials: delete it (closing it when
          nothing else references it) and continue processing
      create new connection; SSHClient.__init__ records it in the cache.
      * Note: every command runs in its own SSH session, so a working
        directory does not persist between calls.
        If you need to enter some directory and execute command there, please
        use the following approach:
        cmd1 = "cd <some dir> && <command1>"
        cmd2 = "cd <some dir> && <command2>"

    Closing cached connections is allowed per-host and for all stored
    clients: connections will be closed, but still stored in cache for
    faster reconnect.

    Clearing the cache is strongly discouraged:
    from this moment all open connections should be managed manually,
    duplicates are possible.
    """
    # Shared cache: (host, port) -> SSHClient instance
    __cache = {}

    def __call__(
            cls,
            host, port=22,
            username=None, password=None, private_keys=None,
            auth=None
    ):
        """Main memorize method: check for cached instance and return it

        :type host: str
        :type port: int
        :type username: str
        :type password: str
        :type private_keys: list
        :type auth: SSHAuth
        :rtype: SSHClient
        """
        if (host, port) in cls.__cache:
            key = host, port
            if auth is None:
                # Build auth from old-style credentials so it can be compared
                # with the cached client (and reused for a new one).
                auth = SSHAuth(
                    username=username, password=password, keys=private_keys)
            # SSHClient.__hash__ hashes (class, hostname, port, auth), so an
            # equal hash means same target and same credentials.
            if hash((cls, host, port, auth)) == hash(cls.__cache[key]):
                ssh = cls.__cache[key]
                # noinspection PyBroadException
                try:
                    # Liveness probe for the cached connection
                    ssh.execute('cd ~', timeout=5)
                except BaseException:  # Note: Do not change to lower level!
                    logger.debug('Reconnect {}'.format(ssh))
                    ssh.reconnect()
                return ssh
            # Credentials differ: the cached client is dropped.
            # Do not add local references to it here - the refcount check
            # below depends on the exact number of references.
            if sys.getrefcount(cls.__cache[key]) == 2:
                # If we have only cache reference and temporary getrefcount
                # reference: close connection before deletion
                logger.debug('Closing {} as unused'.format(cls.__cache[key]))
                cls.__cache[key].close()
            del cls.__cache[key]
        # noinspection PyArgumentList
        return super(
            _MemorizedSSH, cls).__call__(
            host=host, port=port,
            username=username, password=password, private_keys=private_keys,
            auth=auth)

    @classmethod
    def record(mcs, ssh):
        """Record SSH client to cache under its (hostname, port)

        :type ssh: SSHClient
        """
        mcs.__cache[(ssh.hostname, ssh.port)] = ssh

    @classmethod
    def clear_cache(mcs):
        """Clear cached connections so the next call creates a new instance

        Clients not referenced outside the cache are closed first.
        """
        n_count = 3 if six.PY3 else 4
        # Expected refcount of a client referenced only by the cache:
        # PY3: cache, ssh, temporary
        # PY2: cache, values() list, ssh, temporary
        for ssh in mcs.__cache.values():
            if sys.getrefcount(ssh) == n_count:
                logger.debug('Closing {} as unused'.format(ssh))
                ssh.close()
        mcs.__cache = {}

    @classmethod
    def close_connections(mcs, hostname=None):
        """Close connections for selected or all cached records

        Only alive clients are closed; they stay in the cache.

        :param hostname: close only clients for this host; None means all
        :type hostname: str
        """
        if hostname is None:
            keys = [key for key, ssh in mcs.__cache.items() if ssh.is_alive]
        else:
            keys = [
                (host, port)
                for (host, port), ssh
                in mcs.__cache.items() if host == hostname and ssh.is_alive]
        for key in keys:
            mcs.__cache[key].close()
328
329
class SSHClient(six.with_metaclass(_MemorizedSSH, object)):
    """SSH/SFTP helper bound to a single remote host:port.

    Construction goes through the ``_MemorizedSSH`` metaclass, which may
    return an already cached instance for the same host, port and auth.
    """

    __slots__ = [
        '__hostname', '__port', '__auth', '__ssh', '__sftp', 'sudo_mode',
        '__lock'
    ]

    class __get_sudo(object):
        """Context manager for call commands with sudo"""

        def __init__(self, ssh, enforce=None):
            """Context manager for call commands with sudo

            :type ssh: SSHClient
            :type enforce: bool
            """
            self.__ssh = ssh
            self.__sudo_status = ssh.sudo_mode
            self.__enforce = enforce

        def __enter__(self):
            # Remember the mode active at entry so __exit__ can restore it
            self.__sudo_status = self.__ssh.sudo_mode
            if self.__enforce is not None:
                self.__ssh.sudo_mode = self.__enforce

        def __exit__(self, exc_type, exc_val, exc_tb):
            self.__ssh.sudo_mode = self.__sudo_status

    # noinspection PyPep8Naming
    class get_sudo(__get_sudo):
        """Context manager for call commands with sudo (deprecated)"""

        def __init__(self, ssh, enforce=True):
            warnings.warn(
                'SSHClient.get_sudo(SSHClient()) is deprecated in favor of '
                'SSHClient().sudo(enforce=...) , which is much more powerful.')
            # Name the class explicitly: super(self.__class__, self) recurses
            # forever as soon as get_sudo is subclassed.
            super(SSHClient.get_sudo, self).__init__(ssh=ssh, enforce=enforce)

    def __hash__(self):
        """Hash of (class, hostname, port, auth); used by the connection cache"""
        return hash((
            self.__class__,
            self.hostname,
            self.port,
            self.auth))

    def __init__(
            self,
            host, port=22,
            username=None, password=None, private_keys=None,
            auth=None
    ):
        """SSHClient helper

        Opens the SSH connection immediately and records the instance in
        the connection cache.

        :type host: str
        :type port: int
        :type username: str
        :type password: str
        :type private_keys: list
        :type auth: SSHAuth
        """
        self.__lock = threading.RLock()

        self.__hostname = host
        self.__port = port

        self.sudo_mode = False
        self.__ssh = paramiko.SSHClient()
        # NOTE(review): AutoAddPolicy accepts unknown host keys without
        # verification.
        self.__ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        self.__sftp = None

        # Work on a private copy of the caller's auth object
        self.__auth = auth if auth is None else auth.copy()

        if auth is None:
            msg = (
                'SSHClient(host={host}, port={port}, username={username}): '
                'initialization by username/password/private_keys '
                'is deprecated in favor of SSHAuth usage. '
                'Please update your code'.format(
                    host=host, port=port, username=username
                ))
            warnings.warn(msg, DeprecationWarning)
            logger.debug(msg)

            self.__auth = SSHAuth(
                username=username,
                password=password,
                keys=private_keys
            )

        self.__connect()
        _MemorizedSSH.record(ssh=self)
        if auth is None:
            logger.info(
                '{0}:{1}> SSHAuth was made from old style creds: '
                '{2}'.format(self.hostname, self.port, self.auth))

    @property
    def lock(self):
        """Connection lock

        :rtype: threading.RLock
        """
        return self.__lock

    @property
    def auth(self):
        """Internal authorisation object

        Attention: this public property is mainly for inheritance,
        debug and information purposes.
        Calls outside SSHClient and child classes are a sign of incorrect
        design. Change is completely disallowed.

        :rtype: SSHAuth
        """
        return self.__auth

    @property
    def hostname(self):
        """Connected remote host name

        :rtype: str
        """
        return self.__hostname

    @property
    def host(self):
        """Hostname access for backward compatibility (deprecated)

        :rtype: str
        """
        warnings.warn(
            'host has been deprecated in favor of hostname',
            DeprecationWarning
        )
        return self.hostname

    @property
    def port(self):
        """Connected remote port number

        :rtype: int
        """
        return self.__port

    @property
    def is_alive(self):
        """Paramiko status: ready to use|reconnect required

        :rtype: bool
        """
        return self.__ssh.get_transport() is not None

    def __repr__(self):
        return '{cls}(host={host}, port={port}, auth={auth!r})'.format(
            cls=self.__class__.__name__, host=self.hostname, port=self.port,
            auth=self.auth
        )

    def __str__(self):
        return '{cls}(host={host}, port={port}) for user {user}'.format(
            cls=self.__class__.__name__, host=self.hostname, port=self.port,
            user=self.auth.username
        )

    @property
    def _ssh(self):
        """ssh client object getter for inheritance support only

        Attention: ssh client object creation and change
        is allowed only by __init__ and reconnect call.

        :rtype: paramiko.SSHClient
        """
        return self.__ssh

    @decorators.retry(paramiko.SSHException, count=3, delay=3)
    def __connect(self):
        """Main method for connection open"""
        with self.lock:
            self.auth.connect(
                client=self.__ssh,
                hostname=self.hostname, port=self.port,
                log=True)

    def __connect_sftp(self):
        """SFTP connection opener (leaves __sftp as None on failure)"""
        with self.lock:
            try:
                self.__sftp = self.__ssh.open_sftp()
            except paramiko.SSHException:
                logger.warning('SFTP enable failed! SSH only is accessible.')

    @property
    def _sftp(self):
        """SFTP channel access for inheritance (opened lazily)

        :rtype: paramiko.sftp_client.SFTPClient
        :raises: paramiko.SSHException
        """
        if self.__sftp is not None:
            return self.__sftp
        logger.debug('SFTP is not connected, try to connect...')
        self.__connect_sftp()
        if self.__sftp is not None:
            return self.__sftp
        raise paramiko.SSHException('SFTP connection failed')

    def close(self):
        """Close SSH and SFTP sessions"""
        with self.lock:
            # SFTP is opened over the SSH connection: close it first.
            # (Previously __sftp was reset before this check, so the SFTP
            # client was never closed explicitly.)
            if self.__sftp is not None:
                # noinspection PyBroadException
                try:
                    self.__sftp.close()
                except Exception:
                    logger.exception("Could not close sftp connection")
                self.__sftp = None
            # noinspection PyBroadException
            try:
                self.__ssh.close()
            except Exception:
                logger.exception("Could not close ssh connection")

    @staticmethod
    def clear():
        """Removed API: only emits a DeprecationWarning"""
        warnings.warn(
            "clear is removed: use close() only if it mandatory: "
            "it's automatically called on revert|shutdown|suspend|destroy",
            DeprecationWarning
        )

    @classmethod
    def _clear_cache(cls):
        """Enforce clear memorized records"""
        warnings.warn(
            '_clear_cache() is dangerous and not recommended for normal use!',
            Warning
        )
        _MemorizedSSH.clear_cache()

    @classmethod
    def close_connections(cls, hostname=None):
        """Close cached connections: if hostname is not set, then close all

        :type hostname: str
        """
        _MemorizedSSH.close_connections(hostname=hostname)

    def __del__(self):
        """Destructor helper: close channel and threads BEFORE closing others

        Due to threading in paramiko, default destructor could generate asserts
        on close, so we call channel close before closing main ssh object.
        """
        self.__ssh.close()
        self.__sftp = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # The connection is cached and intentionally left open.
        pass

    def reconnect(self):
        """Reconnect SSH session"""
        with self.lock:
            self.close()

            self.__ssh = paramiko.SSHClient()
            self.__ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())

            self.__connect()

    def sudo(self, enforce=None):
        """Call contextmanager for sudo mode change

        :type enforce: bool
        :param enforce: Enforce sudo enabled or disabled. By default: None
        """
        return self.__get_sudo(ssh=self, enforce=enforce)

    def check_call(
            self,
            command, verbose=False, timeout=None,
            error_info=None,
            expected=None, raise_on_err=True, **kwargs):
        """Execute command and check for return code

        :type command: str
        :type verbose: bool
        :type timeout: int
        :type error_info: str
        :type expected: list
        :type raise_on_err: bool
        :rtype: ExecResult
        :raises: SSHCalledProcessError
        """
        if expected is None:
            expected = [proc_enums.ExitCodes.EX_OK]
        else:
            # Promote known integer codes to ExitCodes members
            expected = [
                proc_enums.ExitCodes(code)
                if (
                    isinstance(code, int) and
                    code in proc_enums.ExitCodes.__members__.values())
                else code
                for code in expected
            ]
        ret = self.execute(command, verbose, timeout, **kwargs)
        if ret['exit_code'] not in expected:
            message = (
                "{append}Command '{cmd!r}' returned exit code {code!s} while "
                "expected {expected!s}\n".format(
                    append=error_info + '\n' if error_info else '',
                    cmd=command,
                    code=ret['exit_code'],
                    expected=expected,
                ))
            logger.error(message)
            if raise_on_err:
                raise SSHCalledProcessError(
                    command, ret['exit_code'],
                    expected=expected,
                    stdout=ret.stdout_brief,
                    # was stdout_brief: stderr was never reported
                    stderr=ret.stderr_brief)
        return ret

    def check_stderr(
            self,
            command, verbose=False, timeout=None,
            error_info=None,
            raise_on_err=True, **kwargs):
        """Execute command expecting return code 0 and empty STDERR

        :type command: str
        :type verbose: bool
        :type timeout: int
        :type error_info: str
        :type raise_on_err: bool
        :rtype: ExecResult
        :raises: SSHCalledProcessError
        """
        ret = self.check_call(
            command, verbose, timeout=timeout,
            error_info=error_info, raise_on_err=raise_on_err, **kwargs)
        if ret['stderr']:
            message = (
                "{append}Command '{cmd!r}' STDERR while not expected\n"
                "\texit code: {code!s}\n".format(
                    append=error_info + '\n' if error_info else '',
                    cmd=command,
                    code=ret['exit_code'],
                ))
            logger.error(message)
            if raise_on_err:
                raise SSHCalledProcessError(
                    command,
                    ret['exit_code'],
                    stdout=ret.stdout_brief,
                    stderr=ret.stderr_brief)
        return ret

    @classmethod
    def execute_together(
            cls, remotes, command, expected=None, raise_on_err=True, **kwargs):
        """Execute command on multiple remotes in async mode

        :type remotes: list
        :type command: str
        :type expected: list
        :type raise_on_err: bool
        :raises: SSHCalledProcessError (returncode is a dict
            hostname -> exit code of the failed remotes)
        """
        if expected is None:
            expected = [0]
        futures = {}
        errors = {}
        for remote in set(remotes):  # Use distinct remotes
            chan, _, _, _ = remote.execute_async(command, **kwargs)
            futures[remote] = chan
        for remote, chan in futures.items():
            ret = chan.recv_exit_status()
            chan.close()
            if ret not in expected:
                errors[remote.hostname] = ret
        if errors and raise_on_err:
            raise SSHCalledProcessError(command, errors)

    @classmethod
    def __exec_command(
            cls, command, channel, stdout, stderr, timeout, verbose=False):
        """Get exit status from channel with timeout

        A background thread polls stdout/stderr until the channel reports
        an exit status; this method waits up to ``timeout`` for it.

        :type command: str
        :type channel: paramiko.channel.Channel
        :type stdout: paramiko.channel.ChannelFile
        :type stderr: paramiko.channel.ChannelFile
        :type timeout: int
        :type verbose: bool
        :rtype: ExecResult
        :raises: SSHTimeoutError
        """
        def poll_stream(src, verb_logger=None):
            dst = []
            try:
                for line in src:
                    dst.append(line)
                    if verb_logger is not None:
                        verb_logger(
                            line.decode('utf-8',
                                        errors='backslashreplace').rstrip()
                        )
            except IOError:
                pass
            return dst

        def poll_streams(result, channel, stdout, stderr, verbose):
            if channel.recv_ready():
                result.stdout += poll_stream(
                    src=stdout,
                    verb_logger=logger.info if verbose else logger.debug)
            if channel.recv_stderr_ready():
                result.stderr += poll_stream(
                    src=stderr,
                    verb_logger=logger.error if verbose else logger.debug)

        @decorators.threaded(started=True)
        def poll_pipes(stdout, stderr, result, stop, channel):
            """Polling task for FIFO buffers

            :type stdout: paramiko.channel.ChannelFile
            :type stderr: paramiko.channel.ChannelFile
            :type result: ExecResult
            :type stop: Event
            :type channel: paramiko.channel.Channel
            """

            while not stop.is_set():
                time.sleep(0.1)
                poll_streams(
                    result=result,
                    channel=channel,
                    stdout=stdout,
                    stderr=stderr,
                    verbose=verbose
                )

                if channel.status_event.is_set():
                    result.exit_code = channel.exit_status

                    # Drain whatever is left after the process exited
                    result.stdout += poll_stream(
                        src=stdout,
                        verb_logger=logger.info if verbose else logger.debug)
                    result.stderr += poll_stream(
                        src=stderr,
                        verb_logger=logger.error if verbose else logger.debug)

                    stop.set()

        result = exec_result.ExecResult(cmd=command)
        stop_event = threading.Event()
        if verbose:
            logger.info("\nExecuting command: {!r}".format(command.rstrip()))
        else:
            logger.debug("\nExecuting command: {!r}".format(command.rstrip()))
        poll_pipes(
            stdout=stdout,
            stderr=stderr,
            result=result,
            stop=stop_event,
            channel=channel
        )

        stop_event.wait(timeout)

        # Process closed?
        if stop_event.is_set():
            stop_event.clear()
            channel.close()
            return result

        # Timed out: stop the polling thread and give up on the channel
        stop_event.set()
        channel.close()

        wait_err_msg = ('Wait for {0!r} during {1}s: no return code!\n'
                        .format(command, timeout))
        output_brief_msg = ('\tSTDOUT:\n'
                            '{0}\n'
                            '\tSTDERR:\n'
                            '{1}'.format(result.stdout_brief,
                                         result.stderr_brief))
        logger.debug(wait_err_msg)
        raise SSHTimeoutError(wait_err_msg + output_brief_msg)

    def execute(self, command, verbose=False, timeout=None, **kwargs):
        """Execute command and wait for return code

        :type command: str
        :type verbose: bool
        :type timeout: int
        :param kwargs: passed to execute_async (e.g. get_pty)
        :rtype: ExecResult
        :raises: SSHTimeoutError
        """
        chan, _, stderr, stdout = self.execute_async(command, **kwargs)

        result = self.__exec_command(
            command, chan, stdout, stderr, timeout,
            verbose=verbose
        )

        message = (
            '\n{cmd!r} execution results: Exit code: {code!s}'.format(
                cmd=command,
                code=result.exit_code
            ))
        if verbose:
            logger.info(message)
        else:
            logger.debug(message)
        return result

    def execute_async(self, command, get_pty=False):
        """Execute command in async mode and return channel with IO objects

        In sudo mode the command is base64-encoded, run through
        ``sudo -S bash -c`` and the stored password is written to STDIN.

        :type command: str
        :type get_pty: bool
        :return: (channel, stdin, stderr, stdout) - note the order
        :rtype:
            tuple(
                paramiko.Channel,
                paramiko.ChannelFile,
                paramiko.ChannelFile,
                paramiko.ChannelFile
            )
        """
        logger.debug("Executing command: {!r}".format(command.rstrip()))

        chan = self._ssh.get_transport().open_session()

        if get_pty:
            # Open PTY
            chan.get_pty(
                term='vt100',
                width=80, height=24,
                width_pixels=0, height_pixels=0
            )

        stdin = chan.makefile('wb')
        stdout = chan.makefile('rb')
        stderr = chan.makefile_stderr('rb')
        cmd = "{}\n".format(command)
        if self.sudo_mode:
            # base64 avoids any quoting issues with the original command
            encoded_cmd = base64.b64encode(cmd.encode('utf-8')).decode('utf-8')
            cmd = ("sudo -S bash -c 'eval \"$(base64 -d "
                   "<(echo \"{0}\"))\"'").format(
                encoded_cmd
            )
            chan.exec_command(cmd)
            if stdout.channel.closed is False:
                self.auth.enter_password(stdin)
                stdin.flush()
        else:
            chan.exec_command(cmd)
        return chan, stdin, stderr, stdout

    def execute_through_host(
            self,
            hostname,
            cmd,
            auth=None,
            target_port=22,
            timeout=None,
            verbose=False
    ):
        """Execute command on remote host through currently connected host

        The tunnel channel and the nested transport are always closed when
        the call finishes, including on errors and timeouts.

        :type hostname: str
        :type cmd: str
        :type auth: SSHAuth
        :type target_port: int
        :type timeout: int
        :type verbose: bool
        :rtype: ExecResult
        :raises: SSHTimeoutError
        """
        if auth is None:
            auth = self.auth

        intermediate_channel = self._ssh.get_transport().open_channel(
            kind='direct-tcpip',
            dest_addr=(hostname, target_port),
            src_addr=(self.hostname, 0))
        transport = paramiko.Transport(sock=intermediate_channel)

        try:
            # start client and authenticate transport
            auth.connect(transport)

            # open ssh session
            channel = transport.open_session()

            # Make proxy objects for read
            stdout = channel.makefile('rb')
            stderr = channel.makefile_stderr('rb')

            channel.exec_command(cmd)

            return self.__exec_command(
                cmd, channel, stdout, stderr, timeout, verbose=verbose)
        finally:
            # Previously the transport was never closed (leaked thread)
            transport.close()
            intermediate_channel.close()

    def mkdir(self, path):
        """run 'mkdir -p path' on remote (no-op if path exists)

        :type path: str
        """
        if self.exists(path):
            return
        logger.debug("Creating directory: {}".format(path))
        # noinspection PyTypeChecker
        self.execute("mkdir -p {}\n".format(path))

    def rm_rf(self, path):
        """run 'rm -rf path' on remote

        :type path: str
        """
        logger.debug("rm -rf {}".format(path))
        # noinspection PyTypeChecker
        self.execute("rm -rf {}".format(path))

    def open(self, path, mode='r'):
        """Open file on remote using SFTP session

        :type path: str
        :type mode: str
        :return: file.open() stream
        """
        return self._sftp.open(path, mode)

    def upload(self, source, target):
        """Upload file(s) from source to target using SFTP session

        :param source: local file or directory
        :type source: str
        :param target: remote path; if it is a directory, source is placed
            inside it
        :type target: str
        """
        logger.debug("Copying '%s' -> '%s'", source, target)

        if self.isdir(target):
            target = posixpath.join(target, os.path.basename(source))

        source = os.path.expanduser(source)
        if not os.path.isdir(source):
            self._sftp.put(source, target)
            return

        for rootdir, _, files in os.walk(source):
            targetdir = os.path.normpath(
                os.path.join(
                    target,
                    os.path.relpath(rootdir, source))).replace("\\", "/")

            self.mkdir(targetdir)

            for entry in files:
                local_path = os.path.join(rootdir, entry)
                remote_path = posixpath.join(targetdir, entry)
                if self.exists(remote_path):
                    self._sftp.unlink(remote_path)
                self._sftp.put(local_path, remote_path)

    def download(self, destination, target):
        """Download file(s) to target from destination

        :param destination: remote file path (directories are not supported)
        :type destination: str
        :param target: local path; if it is a directory, the file is placed
            inside it
        :type target: str
        :return: whether the local target exists afterwards
        :rtype: bool
        """
        logger.debug(
            "Copying '%s' -> '%s' from remote to local host",
            destination, target
        )

        if os.path.isdir(target):
            # target is local (os.path), destination is remote (posixpath)
            target = os.path.join(target, posixpath.basename(destination))

        if not self.isdir(destination):
            if self.exists(destination):
                self._sftp.get(destination, target)
            else:
                logger.debug(
                    "Can't download %s because it doesn't exist", destination
                )
        else:
            logger.debug(
                "Can't download %s because it is a directory", destination
            )
        return os.path.exists(target)

    def exists(self, path):
        """Check for file existence using SFTP session (symlinks not followed)

        :type path: str
        :rtype: bool
        """
        try:
            self._sftp.lstat(path)
            return True
        except IOError:
            return False

    def stat(self, path):
        """Get stat info for path with following symlinks

        :type path: str
        :rtype: paramiko.sftp_attr.SFTPAttributes
        """
        return self._sftp.stat(path)

    def isfile(self, path, follow_symlink=False):
        """Check, that path is a regular file using SFTP session

        :type path: str
        :type follow_symlink: bool (default=False), resolve symlinks
        :rtype: bool
        """
        try:
            if follow_symlink:
                attrs = self._sftp.stat(path)
            else:
                attrs = self._sftp.lstat(path)
            # S_ISREG compares the whole file-type field; a plain
            # "& S_IFREG" also matched symlinks and sockets.
            return stat.S_ISREG(attrs.st_mode)
        except IOError:
            return False

    def isdir(self, path, follow_symlink=False):
        """Check, that path is directory using SFTP session

        :type path: str
        :type follow_symlink: bool (default=False), resolve symlinks
        :rtype: bool
        """
        try:
            if follow_symlink:
                attrs = self._sftp.stat(path)
            else:
                attrs = self._sftp.lstat(path)
            # S_ISDIR compares the whole file-type field; a plain
            # "& S_IFDIR" also matched block devices and sockets.
            return stat.S_ISDIR(attrs.st_mode)
        except IOError:
            return False

    def walk(self, path):
        """Walk the remote tree top-down, similar to os.walk

        Directories that cannot be listed are logged and reported as empty.

        :type path: str
        :return: generator of (dirpath, dirnames, filenames) tuples
        """
        files = []
        folders = []
        try:
            for item in self._sftp.listdir_iter(path):
                if stat.S_ISDIR(item.st_mode):
                    folders.append(item.filename)
                else:
                    files.append(item.filename)
        except IOError as e:
            logger.error(
                "Error opening directory {0}: {1}".format(path, e))

        yield path, folders, files
        for folder in folders:
            # Remote paths are POSIX regardless of the local OS
            for res in self.walk(posixpath.join(path, folder)):
                yield res
1098 yield res
1099
1100
class SSHClientError(Exception):
    """Common base class of the exceptions defined for the SSH client."""
1103
1104
class SSHCalledProcessError(SSHClientError):
    """Remote command finished with an unexpected exit code."""

    @staticmethod
    def _makestr(data):
        """Return text for data: decode bytes as UTF-8, keep text, else repr"""
        if isinstance(data, six.binary_type):
            return data.decode('utf-8', errors='backslashreplace')
        elif isinstance(data, six.text_type):
            return data
        else:
            return repr(data)

    def __init__(
            self, command, returncode, expected=0, stdout=None, stderr=None):
        """Build the error and its human-readable message

        :param command: executed command
        :param returncode: received exit code (SSHClient.execute_together
            passes a dict of hostname -> exit code instead)
        :param expected: expected exit code(s)
        :param stdout: captured stdout, appended to the message if non-empty
        :param stderr: captured stderr, appended to the message if non-empty
        """
        self.returncode = returncode
        self.expected = expected
        self.cmd = command
        self.stdout = stdout
        self.stderr = stderr
        message = (
            "Command '{cmd}' returned exit code {code} while "
            "expected {expected}".format(
                cmd=self._makestr(self.cmd),
                code=self.returncode,
                expected=self.expected
            ))
        if self.stdout:
            message += "\n\tSTDOUT:\n{}".format(self._makestr(self.stdout))
        if self.stderr:
            message += "\n\tSTDERR:\n{}".format(self._makestr(self.stderr))
        super(SSHCalledProcessError, self).__init__(message)

    @property
    def output(self):
        """Deprecated: stdout followed by stderr

        A stream that was not captured (None) is skipped instead of raising
        TypeError; None is returned only if neither stream was captured.
        """
        warnings.warn(
            'output is deprecated, please use stdout and stderr separately',
            DeprecationWarning)
        if self.stdout is None:
            return self.stderr
        if self.stderr is None:
            return self.stdout
        return self.stdout + self.stderr
1141
1142
class SSHTimeoutError(SSHClientError):
    """Raised when a remote command gives no exit code within its timeout."""
1145
1146
# Public API of this module
__all__ = [
    'SSHAuth',
    'SSHClient',
    'SSHClientError',
    'SSHCalledProcessError',
    'SSHTimeoutError',
]