blob: 295541864d496c27413cbef49ba47c44020acd6b [file] [log] [blame]
Steve Baker450aa7f2014-08-25 10:37:27 +12001# Licensed under the Apache License, Version 2.0 (the "License"); you may
2# not use this file except in compliance with the License. You may obtain
3# a copy of the License at
4#
5# http://www.apache.org/licenses/LICENSE-2.0
6#
7# Unless required by applicable law or agreed to in writing, software
8# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
9# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
10# License for the specific language governing permissions and limitations
11# under the License.
12
13import cStringIO
Steve Baker450aa7f2014-08-25 10:37:27 +120014import re
15import select
Steve Baker450aa7f2014-08-25 10:37:27 +120016import socket
17import time
18
Steve Baker24641292015-03-13 10:47:50 +130019from oslo_log import log as logging
Pavlo Shchelokovskyy60e0ecd2014-12-14 22:17:21 +020020import paramiko
21import six
22
Steve Baker450aa7f2014-08-25 10:37:27 +120023from heat_integrationtests.common import exceptions
24
# Module-wide logger, named after this module.
LOG = logging.getLogger(name=__name__)
26
27
class Client(object):
    """Minimal SSH client used by the integration tests to run commands.

    Every call to :meth:`exec_command` or :meth:`test_connection_auth`
    opens a fresh SSH connection (retrying until ``timeout`` seconds have
    passed) and closes it again before returning.
    """

    def __init__(self, host, username, password=None, timeout=300, pkey=None,
                 channel_timeout=10, look_for_keys=False, key_filename=None):
        """
        :param host: address of the server to connect to.
        :param username: login name on the server.
        :param password: optional password used for authentication.
        :param timeout: deadline in seconds, used both for establishing the
            connection and for waiting on an idle command channel.
        :param pkey: optional private key; either a paramiko key object or
            the key text, which is parsed as an RSA key.
        :param channel_timeout: seconds passed to paramiko as the connect
            timeout and used as the poll interval while reading output.
        :param look_for_keys: forwarded to ``paramiko.SSHClient.connect``.
        :param key_filename: forwarded to ``paramiko.SSHClient.connect``.
        """
        self.host = host
        self.username = username
        self.password = password
        if isinstance(pkey, six.string_types):
            # Key supplied as text: parse it into an RSA key object.
            pkey = paramiko.RSAKey.from_private_key(
                cStringIO.StringIO(str(pkey)))
        self.pkey = pkey
        self.look_for_keys = look_for_keys
        self.key_filename = key_filename
        self.timeout = int(timeout)
        self.channel_timeout = float(channel_timeout)
        self.buf_size = 1024

    def _get_ssh_connection(self, sleep=1.5, backoff=1):
        """Returns an ssh connection to the specified host.

        Failed attempts are retried until ``self.timeout`` seconds have
        elapsed. The wait before the first retry is ``sleep + backoff``
        seconds, and it grows by ``backoff`` after every further failure.

        :raises: SSHTimeout if no attempt succeeds before the deadline.
        """
        bsleep = sleep
        ssh = paramiko.SSHClient()
        ssh.set_missing_host_key_policy(
            paramiko.AutoAddPolicy())
        _start_time = time.time()
        # Only report which authentication method is used; never write the
        # password itself to the log.
        if self.pkey is not None:
            LOG.info("Creating ssh connection to '%s' as '%s'"
                     " with public key authentication",
                     self.host, self.username)
        else:
            LOG.info("Creating ssh connection to '%s' as '%s'"
                     " with password authentication",
                     self.host, self.username)
        attempts = 0
        while True:
            try:
                ssh.connect(self.host, username=self.username,
                            password=self.password,
                            look_for_keys=self.look_for_keys,
                            key_filename=self.key_filename,
                            timeout=self.channel_timeout, pkey=self.pkey)
                LOG.info("ssh connection to %s@%s successfully created",
                         self.username, self.host)
                return ssh
            except (socket.error,
                    paramiko.SSHException) as e:
                # Count the attempt that just failed before reporting, so the
                # logged totals include it.
                attempts += 1
                if self._is_timed_out(_start_time):
                    LOG.exception("Failed to establish authenticated ssh"
                                  " connection to %s@%s after %d attempts",
                                  self.username, self.host, attempts)
                    # Release the client's resources before giving up.
                    ssh.close()
                    raise exceptions.SSHTimeout(host=self.host,
                                                user=self.username,
                                                password=self.password)
                bsleep += backoff
                LOG.warning("Failed to establish authenticated ssh"
                            " connection to %s@%s (%s). Number attempts: %s."
                            " Retry after %d seconds.",
                            self.username, self.host, e, attempts, bsleep)
                time.sleep(bsleep)

    def _is_timed_out(self, start_time):
        """Return True once more than ``self.timeout`` seconds have passed
        since ``start_time`` (a ``time.time()`` value)."""
        return (time.time() - self.timeout) > start_time

    def exec_command(self, cmd):
        """
        Execute the specified command on the server.

        Note that this method is reading whole command outputs to memory, thus
        shouldn't be used for large outputs. A dedicated SSH connection is
        opened for the command and always closed before returning.

        :returns: data read from standard output of the command.
        :raises: SSHExecCommandFailed if command returns nonzero
                 status. The exception contains command status stderr content.
        :raises: TimeoutException if the channel stays idle until
                 ``self.timeout`` seconds after the command was started.
        """
        ssh = self._get_ssh_connection()
        try:
            transport = ssh.get_transport()
            channel = transport.open_session()
            channel.fileno()  # Register event pipe
            channel.exec_command(cmd)
            channel.shutdown_write()
            out_data = []
            err_data = []
            poll = select.poll()
            poll.register(channel, select.POLLIN)
            start_time = time.time()
            # select.poll() expects its timeout in milliseconds, whereas
            # channel_timeout is expressed in seconds.
            poll_timeout_ms = int(self.channel_timeout * 1000)

            while True:
                ready = poll.poll(poll_timeout_ms)
                if not ready:
                    if not self._is_timed_out(start_time):
                        continue
                    raise exceptions.TimeoutException(
                        "Command: '{0}' executed on host '{1}'.".format(
                            cmd, self.host))
                out_chunk = err_chunk = None
                if channel.recv_ready():
                    out_chunk = channel.recv(self.buf_size)
                    out_data.append(out_chunk)
                if channel.recv_stderr_ready():
                    err_chunk = channel.recv_stderr(self.buf_size)
                    err_data.append(err_chunk)
                # Stop only once the channel is closed AND drained.
                if channel.closed and not err_chunk and not out_chunk:
                    break
            exit_status = channel.recv_exit_status()
        finally:
            # Each call opens its own connection; close it so repeated
            # commands do not leak SSH sessions and sockets.
            ssh.close()
        if 0 != exit_status:
            raise exceptions.SSHExecCommandFailed(
                command=cmd, exit_status=exit_status,
                strerror=''.join(err_data))
        return ''.join(out_data)

    def test_connection_auth(self):
        """Raises an exception when we can not connect to server via ssh."""
        connection = self._get_ssh_connection()
        connection.close()
144
145
class RemoteClient(object):
    """Convenience wrapper that runs common commands on a server over SSH."""

    # NOTE(afazekas): It should always get an address instead of server
    def __init__(self, server, username, password=None, pkey=None,
                 conf=None):
        """
        :param server: either an address string, or a server dict whose
            ``server['addresses'][conf.network_for_ssh]`` list holds entries
            with ``'version'`` and ``'addr'`` keys.
        :param username: login name on the server.
        :param password: optional password passed through to the SSH client.
        :param pkey: optional private key passed through to the SSH client.
        :param conf: configuration object providing ``ssh_timeout``,
            ``network_for_ssh``, ``ip_version_for_ssh`` and
            ``ssh_channel_timeout``. Despite the ``None`` default, it is
            required: its attributes are read unconditionally.
        :raises: ServerUnreachable if the server dict has no address of the
            configured IP version on the configured network.
        """
        self.conf = conf
        ssh_timeout = self.conf.ssh_timeout
        network = self.conf.network_for_ssh
        ip_version = self.conf.ip_version_for_ssh
        ssh_channel_timeout = self.conf.ssh_channel_timeout
        if isinstance(server, six.string_types):
            ip_address = server
        else:
            addresses = server['addresses'][network]
            for address in addresses:
                if address['version'] == ip_version:
                    ip_address = address['addr']
                    break
            else:
                raise exceptions.ServerUnreachable()
        self.ssh_client = Client(ip_address, username, password,
                                 ssh_timeout, pkey=pkey,
                                 channel_timeout=ssh_channel_timeout)

    def exec_command(self, cmd):
        """Run ``cmd`` on the server and return its standard output."""
        return self.ssh_client.exec_command(cmd)

    def validate_authentication(self):
        """Validate ssh connection and authentication.

        This method raises an Exception when the validation fails.
        """
        self.ssh_client.test_connection_auth()

    def get_partitions(self):
        """Return the contents of /proc/partitions on the server."""
        command = 'cat /proc/partitions'
        output = self.exec_command(command)
        return output

    def get_boot_time(self):
        """Return the server's boot time as a local ``time.struct_time``.

        The boot time is derived from the whole seconds in /proc/uptime,
        measured against this machine's clock.
        """
        cmd = 'cut -f1 -d. /proc/uptime'
        boot_secs = self.exec_command(cmd)
        boot_time = time.time() - int(boot_secs)
        return time.localtime(boot_time)

    def write_to_console(self, message):
        """Echo ``message`` to /dev/console on the server via sudo.

        The command string is parsed by two shells. The first is the shell
        the SSH server uses to run it, which is assumed to be POSIX-like,
        as with OpenSSH's login-shell ``-c``. The second is the inner
        ``sh -c``. Both see the message inside double quotes, so every
        character that is special there ($, backtick, double quote and
        backslash) is escaped for the inner shell. That escape is then
        escaped again for the outer shell, giving three backslashes per
        special character.
        """
        message = re.sub(r'([$`"\\])', r'\\\\\\\1', message)
        # usually to /dev/ttyS0
        cmd = 'sudo sh -c "echo \\"%s\\" >/dev/console"' % message
        return self.exec_command(cmd)

    def ping_host(self, host):
        """Send a single ping from the server to ``host``."""
        cmd = 'ping -c1 -w1 %s' % host
        return self.exec_command(cmd)

    def get_ip_list(self):
        """Return the output of ``/bin/ip address`` on the server."""
        cmd = "/bin/ip address"
        return self.exec_command(cmd)
203 return self.exec_command(cmd)