blob: a831dbda11df58899694b6cffa73a8404e56f985 [file] [log] [blame]
Matthew Treinish9e26ca82016-02-23 11:43:20 -05001# Copyright 2012 OpenStack Foundation
2# All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License"); you may
5# not use this file except in compliance with the License. You may obtain
6# a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13# License for the specific language governing permissions and limitations
14# under the License.
15
16
17import select
18import socket
19import time
20import warnings
21
22from oslo_log import log as logging
23import six
24
25from tempest.lib import exceptions
26
27
# Silence any warnings emitted while paramiko is being imported; the
# filter change is scoped to this block by warnings.catch_warnings().
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    import paramiko


# Module-level logger used by the SSH client below.
LOG = logging.getLogger(__name__)
34
35
class Client(object):
    """Small paramiko-based SSH client for running commands on a host.

    Connections are retried with a linearly growing sleep until ``timeout``
    seconds have elapsed. Every call to exec_command opens (and closes) its
    own SSH connection.
    """

    def __init__(self, host, username, password=None, timeout=300, pkey=None,
                 channel_timeout=10, look_for_keys=False, key_filename=None):
        """Store connection settings; no connection is opened here.

        :param host: host name or address of the remote server.
        :param username: user to authenticate as.
        :param password: password for password authentication, if any.
        :param timeout: overall number of seconds to keep retrying the
                        connection, and to wait for command output.
        :param pkey: a paramiko key object, or a private RSA key given as a
                     string (parsed with ``paramiko.RSAKey``).
        :param channel_timeout: seconds used both as the per-attempt connect
                                timeout and as the per-poll wait while
                                reading command output.
        :param look_for_keys: passed to ``paramiko.SSHClient.connect``.
        :param key_filename: passed to ``paramiko.SSHClient.connect``.
        """
        self.host = host
        self.username = username
        self.password = password
        if isinstance(pkey, six.string_types):
            pkey = paramiko.RSAKey.from_private_key(
                six.StringIO(str(pkey)))
        self.pkey = pkey
        self.look_for_keys = look_for_keys
        self.key_filename = key_filename
        self.timeout = int(timeout)
        self.channel_timeout = float(channel_timeout)
        self.buf_size = 1024

    def _get_ssh_connection(self, sleep=1.5, backoff=1):
        """Returns an ssh connection to the specified host.

        Failed attempts are retried until ``self.timeout`` seconds have
        elapsed. The sleep between attempts starts at ``sleep + backoff``
        seconds and grows by ``backoff`` after every failure.

        :returns: a connected ``paramiko.SSHClient``.
        :raises: SSHTimeout if no attempt succeeds before the timeout.
        """
        bsleep = sleep
        ssh = paramiko.SSHClient()
        ssh.set_missing_host_key_policy(
            paramiko.AutoAddPolicy())
        _start_time = time.time()
        # Key-based auth may come from an in-memory key or from a key file.
        if self.pkey is not None or self.key_filename is not None:
            LOG.info("Creating ssh connection to '%s' as '%s'"
                     " with public key authentication",
                     self.host, self.username)
        else:
            LOG.info("Creating ssh connection to '%s' as '%s'"
                     " with password %s",
                     self.host, self.username, str(self.password))
        attempts = 0
        while True:
            try:
                ssh.connect(self.host, username=self.username,
                            password=self.password,
                            look_for_keys=self.look_for_keys,
                            key_filename=self.key_filename,
                            timeout=self.channel_timeout, pkey=self.pkey)
                LOG.info("ssh connection to %s@%s successfully created",
                         self.username, self.host)
                return ssh
            except (EOFError,
                    socket.error,
                    paramiko.SSHException) as e:
                # Count this failure first so the final error message
                # reports the real number of attempts that were made.
                attempts += 1
                if self._is_timed_out(_start_time):
                    LOG.exception("Failed to establish authenticated ssh"
                                  " connection to %s@%s after %d attempts",
                                  self.username, self.host, attempts)
                    raise exceptions.SSHTimeout(host=self.host,
                                                user=self.username,
                                                password=self.password)
                bsleep += backoff
                LOG.warning("Failed to establish authenticated ssh"
                            " connection to %s@%s (%s). Number attempts: %s."
                            " Retry after %d seconds.",
                            self.username, self.host, e, attempts, bsleep)
                time.sleep(bsleep)

    def _is_timed_out(self, start_time):
        """Return True once more than ``self.timeout`` seconds have passed
        since ``start_time`` (a ``time.time()`` value)."""
        return (time.time() - self.timeout) > start_time

    @staticmethod
    def _can_system_poll():
        """Return True when the platform provides ``select.poll``."""
        return hasattr(select, 'poll')

    def exec_command(self, cmd, encoding="utf-8"):
        """Execute the specified command on the server

        Note that this method is reading whole command outputs to memory, thus
        shouldn't be used for large outputs. The SSH connection opened for the
        command is closed before this method returns or raises.

        :param str cmd: Command to run at remote server.
        :param str encoding: Encoding for result from paramiko.
                             Result will not be decoded if None.
        :returns: data read from standard output of the command.
        :raises: SSHTimeout if the connection cannot be established.
        :raises: SSHExecCommandFailed if command returns nonzero
                 status. The exception contains command status stderr content.
        :raises: TimeoutException if cmd doesn't end when timeout expires.
        """
        ssh = self._get_ssh_connection()
        try:
            transport = ssh.get_transport()
            with transport.open_session() as channel:
                channel.fileno()  # Register event pipe
                channel.exec_command(cmd)
                channel.shutdown_write()

                # If the executing host is linux-based, poll the channel
                if self._can_system_poll():
                    out_data, err_data = self._read_polled(channel, cmd)
                else:
                    out_data, err_data = self._read_blocking(channel)

                # Only ask for the exit status after stdout/stderr have been
                # drained: output larger than the channel window would
                # otherwise keep the remote side blocked and this call would
                # never return.
                exit_status = channel.recv_exit_status()
        finally:
            ssh.close()

        if encoding:
            out_data = out_data.decode(encoding)
            err_data = err_data.decode(encoding)

        if 0 != exit_status:
            raise exceptions.SSHExecCommandFailed(
                command=cmd, exit_status=exit_status,
                stderr=err_data, stdout=out_data)
        return out_data

    def _read_polled(self, channel, cmd):
        """Drain stdout and stderr of ``channel`` using ``select.poll``.

        :returns: tuple of (stdout bytes, stderr bytes).
        :raises: TimeoutException if a poll interval passes with no data
                 after more than ``self.timeout`` seconds have elapsed.
        """
        out_chunks = []
        err_chunks = []
        poller = select.poll()
        poller.register(channel, select.POLLIN)
        start_time = time.time()
        # select.poll() expects milliseconds; channel_timeout is in seconds.
        poll_timeout_ms = int(self.channel_timeout * 1000)

        while True:
            ready = poller.poll(poll_timeout_ms)
            if not ready:
                if not self._is_timed_out(start_time):
                    continue
                raise exceptions.TimeoutException(
                    "Command: '{0}' executed on host '{1}'.".format(
                        cmd, self.host))
            out_chunk = err_chunk = None
            if channel.recv_ready():
                out_chunk = channel.recv(self.buf_size)
                out_chunks.append(out_chunk)
            if channel.recv_stderr_ready():
                err_chunk = channel.recv_stderr(self.buf_size)
                err_chunks.append(err_chunk)
            # Readiness with nothing to read is treated as end of output
            # (presumably the channel was closed by the remote side).
            if not err_chunk and not out_chunk:
                break
        return b''.join(out_chunks), b''.join(err_chunks)

    def _read_blocking(self, channel):
        """Read stdout, then stderr, of ``channel`` to EOF with blocking reads.

        Used where ``select.poll`` is unavailable.

        :returns: tuple of (stdout bytes, stderr bytes).
        """
        # NOTE(review): stdout is read to EOF before stderr is touched; a
        # command writing a lot to stderr could stall here - confirm.
        out_file = channel.makefile('rb', self.buf_size)
        err_file = channel.makefile_stderr('rb', self.buf_size)
        out_data = out_file.read()
        err_data = err_file.read()
        return out_data, err_data

    def test_connection_auth(self):
        """Raises an exception when we can not connect to server via ssh."""
        connection = self._get_ssh_connection()
        connection.close()