blob: de1ad886dfbabbef9534ee3eb9565565b6973d54 [file] [log] [blame]
Maru Newbyb096d9f2015-03-09 18:54:54 +00001# Copyright 2012 OpenStack Foundation
2# All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License"); you may
5# not use this file except in compliance with the License. You may obtain
6# a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13# License for the specific language governing permissions and limitations
14# under the License.
15
16
17import cStringIO
18import select
19import socket
20import time
21import warnings
22
23import six
24
25from neutron.tests.tempest import exceptions
26from neutron.openstack.common import log as logging
27
28
29with warnings.catch_warnings():
30 warnings.simplefilter("ignore")
31 import paramiko
32
33
34LOG = logging.getLogger(__name__)
35
36
class Client(object):
    """Small SSH helper used by the tempest tests.

    Every public operation opens its own SSH connection (see
    ``_get_ssh_connection``), retrying until ``timeout`` seconds have passed.
    """

    def __init__(self, host, username, password=None, timeout=300, pkey=None,
                 channel_timeout=10, look_for_keys=False, key_filename=None):
        """Store connection parameters; no connection is opened here.

        :param host: address of the remote host.
        :param username: login name on the remote host.
        :param password: optional password, forwarded to paramiko.
        :param timeout: overall deadline in seconds, used both for connection
            retries and while waiting for command output.
        :param pkey: optional private key; a string is parsed as an RSA
            private key, anything else is passed to paramiko unchanged.
        :param channel_timeout: seconds; the per-attempt connect timeout given
            to paramiko and the poll interval used in ``exec_command``.
        :param look_for_keys: forwarded to ``paramiko.SSHClient.connect``.
        :param key_filename: forwarded to ``paramiko.SSHClient.connect``.
        """
        self.host = host
        self.username = username
        self.password = password
        if isinstance(pkey, six.string_types):
            pkey = paramiko.RSAKey.from_private_key(
                cStringIO.StringIO(str(pkey)))
        self.pkey = pkey
        self.look_for_keys = look_for_keys
        self.key_filename = key_filename
        self.timeout = int(timeout)
        self.channel_timeout = float(channel_timeout)
        self.buf_size = 1024

    def _get_ssh_connection(self, sleep=1.5, backoff=1):
        """Return a connected ``paramiko.SSHClient`` for ``self.host``.

        Attempts are retried until ``self.timeout`` seconds have elapsed since
        the first one. The pause before each retry starts at
        ``sleep + backoff`` seconds and grows by ``backoff`` every time.

        :raises: SSHTimeout if no attempt succeeds before the deadline.
        """
        bsleep = sleep
        ssh = paramiko.SSHClient()
        ssh.set_missing_host_key_policy(
            paramiko.AutoAddPolicy())
        _start_time = time.time()
        if self.pkey is not None:
            LOG.info("Creating ssh connection to '%s' as '%s'"
                     " with public key authentication",
                     self.host, self.username)
        else:
            LOG.info("Creating ssh connection to '%s' as '%s'"
                     " with password %s",
                     self.host, self.username, str(self.password))
        attempts = 0
        while True:
            # Count the attempt up front so the final error message reports
            # the true number of attempts made.
            attempts += 1
            try:
                ssh.connect(self.host, username=self.username,
                            password=self.password,
                            look_for_keys=self.look_for_keys,
                            key_filename=self.key_filename,
                            timeout=self.channel_timeout, pkey=self.pkey)
                LOG.info("ssh connection to %s@%s successfully created",
                         self.username, self.host)
                return ssh
            except (EOFError,
                    socket.error,
                    paramiko.SSHException) as e:
                # Drop whatever half-open transport the failed attempt left
                # behind; SSHClient.connect() builds a new one on retry.
                ssh.close()
                if self._is_timed_out(_start_time):
                    LOG.exception("Failed to establish authenticated ssh"
                                  " connection to %s@%s after %d attempts",
                                  self.username, self.host, attempts)
                    raise exceptions.SSHTimeout(host=self.host,
                                                user=self.username,
                                                password=self.password)
                bsleep += backoff
                LOG.warning("Failed to establish authenticated ssh"
                            " connection to %s@%s (%s). Number attempts: %s."
                            " Retry after %d seconds.",
                            self.username, self.host, e, attempts, bsleep)
                time.sleep(bsleep)

    def _is_timed_out(self, start_time):
        """Return True once more than ``self.timeout`` seconds have passed
        since ``start_time`` (a ``time.time()`` value)."""
        return (time.time() - self.timeout) > start_time

    def exec_command(self, cmd):
        """Execute the specified command on the server.

        Note that this method reads the whole command output into memory, so
        it shouldn't be used for large outputs. A new SSH connection is opened
        for the call and closed before returning.

        :param cmd: command line to run on the remote host.
        :returns: data read from standard output of the command.
        :raises: TimeoutException if, once ``self.timeout`` seconds have
            elapsed since the command started, a poll interval passes with no
            activity on the channel.
        :raises: SSHExecCommandFailed if the command returns a nonzero
            status. The exception carries the exit status and stderr content.
        """
        ssh = self._get_ssh_connection()
        try:
            transport = ssh.get_transport()
            channel = transport.open_session()
            channel.fileno()  # Register event pipe so the channel is pollable.
            channel.exec_command(cmd)
            channel.shutdown_write()
            out_data = []
            err_data = []
            poll = select.poll()
            poll.register(channel, select.POLLIN)
            # select.poll() expects milliseconds; channel_timeout is seconds.
            poll_timeout_ms = int(self.channel_timeout * 1000)
            start_time = time.time()

            while True:
                ready = poll.poll(poll_timeout_ms)
                if not ready:
                    if not self._is_timed_out(start_time):
                        continue
                    raise exceptions.TimeoutException(
                        "Command: '{0}' executed on host '{1}'.".format(
                            cmd, self.host))
                out_chunk = err_chunk = None
                if channel.recv_ready():
                    out_chunk = channel.recv(self.buf_size)
                    out_data.append(out_chunk)
                if channel.recv_stderr_ready():
                    err_chunk = channel.recv_stderr(self.buf_size)
                    err_data.append(err_chunk)
                # Finish only after the channel is closed and nothing was
                # read in this pass, i.e. both streams are fully drained.
                if channel.closed and not err_chunk and not out_chunk:
                    break
            exit_status = channel.recv_exit_status()
            if 0 != exit_status:
                raise exceptions.SSHExecCommandFailed(
                    command=cmd, exit_status=exit_status,
                    strerror=''.join(err_data))
            return ''.join(out_data)
        finally:
            # Each call opens its own connection; close it so repeated
            # commands do not leak SSH transports.
            ssh.close()

    def test_connection_auth(self):
        """Open and immediately close a connection to check SSH login.

        :raises: SSHTimeout (from ``_get_ssh_connection``) when we cannot
            connect to the server via ssh.
        """
        connection = self._get_ssh_connection()
        connection.close()