blob: cb59a828613f19798a40cc97a5733f4ec8bd8eed [file] [log] [blame]
# Copyright 2012 OpenStack Foundation
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import io
import select
import socket
import time
import warnings
from oslo_log import log as logging
from oslo_utils.secretutils import md5
from tempest.lib import exceptions
with warnings.catch_warnings():
warnings.simplefilter("ignore")
import paramiko
# Module-level logger; used by Client to report SSH connection attempts,
# retries and the final failure when connecting times out.
LOG = logging.getLogger(__name__)
def get_fingerprint(self):
    """Return the MD5 fingerprint of a paramiko key, FIPS-friendly.

    paramiko's own ``PKey.get_fingerprint`` needs to be patched to allow
    paramiko to work under FIPS, so this replacement hashes the key with
    ``usedforsecurity=False``. It is installed over
    ``paramiko.pkey.PKey.get_fingerprint`` right after this definition.

    TODO(alee) Remove this when paramiko is patched.
    See https://github.com/paramiko/paramiko/pull/1928

    :returns: the MD5 digest of ``self.asbytes()``.
    """
    serialized_key = self.asbytes()
    hasher = md5(serialized_key, usedforsecurity=False)
    return hasher.digest()


# Monkey-patch paramiko at import time so every key class inherits the
# FIPS-friendly fingerprint implementation above.
setattr(paramiko.pkey.PKey, 'get_fingerprint', get_fingerprint)
class Client(object):
    """SSH client that runs commands on a remote host through paramiko.

    A new paramiko connection is opened for every operation
    (``exec_command``, ``test_connection_auth``) and connection attempts are
    retried until ``timeout`` expires. The connection can optionally be
    tunnelled through another ``Client`` instance (ssh-over-ssh, see
    ``proxy_client``).
    """

    def __init__(self, host, username, password=None, timeout=300, pkey=None,
                 channel_timeout=10, look_for_keys=False, key_filename=None,
                 port=22, proxy_client=None, ssh_key_type='rsa'):
        """SSH client.

        Many of the parameters are passed to the underlying implementation
        as-is. See the paramiko documentation for more details.
        http://docs.paramiko.org/en/2.1/api/client.html#paramiko.client.SSHClient.connect

        :param host: Host to login.
        :param username: SSH username.
        :param password: SSH password, or a password to unlock private key.
        :param timeout: Timeout in seconds, including retries.
            Default is 300 seconds.
        :param pkey: Private key, either a paramiko key object or the key
            text as a string (parsed according to ``ssh_key_type``).
        :param channel_timeout: Channel timeout in seconds, passed to the
            paramiko. Default is 10 seconds.
        :param look_for_keys: Whether or not to search for private keys
            in ``~/.ssh``. Default is False.
        :param key_filename: Filename for private key to use.
        :param port: SSH port number.
        :param proxy_client: Another SSH client to provide a transport
            for ssh-over-ssh. The default is None, which means
            not to use ssh-over-ssh.
        :param ssh_key_type: ssh key type (rsa, ecdsa); only used when
            ``pkey`` is given as a string.
        :type proxy_client: ``tempest.lib.common.ssh.Client`` object
        :raises SSHClientUnsupportedKeyType: if ``pkey`` is a string and
            ``ssh_key_type`` is neither ``'rsa'`` nor ``'ecdsa'``.
        :raises SSHClientProxyClientLoop: if ``proxy_client`` has the same
            host, port and username as this client.
        """
        self.host = host
        self.username = username
        self.port = port
        self.password = password
        if isinstance(pkey, str):
            if ssh_key_type == 'rsa':
                pkey = paramiko.RSAKey.from_private_key(
                    io.StringIO(str(pkey)))
            elif ssh_key_type == 'ecdsa':
                pkey = paramiko.ECDSAKey.from_private_key(
                    io.StringIO(str(pkey)))
            else:
                raise exceptions.SSHClientUnsupportedKeyType(
                    key_type=ssh_key_type)
        self.pkey = pkey
        self.look_for_keys = look_for_keys
        self.key_filename = key_filename
        self.timeout = int(timeout)
        self.channel_timeout = float(channel_timeout)
        self.buf_size = 1024
        self.proxy_client = proxy_client
        # Proxying through a client that targets this very endpoint would
        # recurse forever in _get_proxy_channel.
        if (self.proxy_client and self.proxy_client.host == self.host and
                self.proxy_client.port == self.port and
                self.proxy_client.username == self.username):
            raise exceptions.SSHClientProxyClientLoop(
                host=self.host, port=self.port, username=self.username)
        self._proxy_conn = None

    def _get_ssh_connection(self, sleep=1.5, backoff=1):
        """Returns an ssh connection to the specified host.

        Attempts are retried until ``self.timeout`` seconds have elapsed since
        the first one. The pause between attempts starts at
        ``sleep + backoff`` seconds and grows by ``backoff`` after every
        failed attempt.

        :param sleep: base delay in seconds between attempts.
        :param backoff: seconds added to the delay after each failure.
        :returns: a connected ``paramiko.SSHClient``; the caller closes it.
        :raises SSHTimeout: if no connection is established before timeout.
        """
        bsleep = sleep
        ssh = paramiko.SSHClient()
        # Accept and record host keys that are not already known.
        ssh.set_missing_host_key_policy(
            paramiko.AutoAddPolicy())
        _start_time = time.time()
        if self.pkey is not None:
            LOG.info("Creating ssh connection to '%s:%d' as '%s'"
                     " with public key authentication",
                     self.host, self.port, self.username)
        else:
            # Never write the actual password to the logs.
            LOG.info("Creating ssh connection to '%s:%d' as '%s'"
                     " with password %s",
                     self.host, self.port, self.username,
                     '***' if self.password else self.password)
        attempts = 0
        while True:
            if self.proxy_client is not None:
                proxy_chan = self._get_proxy_channel()
            else:
                proxy_chan = None
            # Count the attempt before it runs so the failure log is exact.
            attempts += 1
            try:
                ssh.connect(self.host, port=self.port, username=self.username,
                            password=self.password,
                            look_for_keys=self.look_for_keys,
                            key_filename=self.key_filename,
                            timeout=self.channel_timeout, pkey=self.pkey,
                            sock=proxy_chan)
                LOG.info("ssh connection to %s@%s successfully created",
                         self.username, self.host)
                return ssh
            except (EOFError,
                    socket.error, socket.timeout,
                    paramiko.SSHException) as e:
                ssh.close()
                if self._is_timed_out(_start_time):
                    LOG.exception("Failed to establish authenticated ssh"
                                  " connection to %s@%s after %d attempts. "
                                  "Proxy client: %s",
                                  self.username, self.host, attempts,
                                  self._get_proxy_client_info())
                    raise exceptions.SSHTimeout(host=self.host,
                                                user=self.username,
                                                password=self.password)
                bsleep += backoff
                LOG.warning("Failed to establish authenticated ssh"
                            " connection to %s@%s (%s). Number attempts: %s."
                            " Retry after %d seconds.",
                            self.username, self.host, e, attempts, bsleep)
                time.sleep(bsleep)

    def _is_timed_out(self, start_time):
        """Return True once more than ``self.timeout`` seconds have passed
        since ``start_time`` (a ``time.time()`` value)."""
        return (time.time() - self.timeout) > start_time

    @staticmethod
    def _can_system_poll():
        """Return True if this platform provides ``select.poll``."""
        return hasattr(select, 'poll')

    def _poll_channel(self, channel, cmd):
        """Read stdout and stderr of an executing ``channel`` with poll().

        :param channel: paramiko channel that already runs ``cmd``.
        :param cmd: the command, only used in the timeout message.
        :returns: tuple ``(stdout_bytes, stderr_bytes)``.
        :raises TimeoutException: if no event arrives and more than
            ``self.timeout`` seconds have passed since polling started.
        """
        out_data_chunks = []
        err_data_chunks = []
        poll = select.poll()
        poll.register(channel, select.POLLIN)
        # poll() takes its timeout in milliseconds; channel_timeout is in
        # seconds.
        poll_timeout_ms = self.channel_timeout * 1000
        start_time = time.time()
        while True:
            ready = poll.poll(poll_timeout_ms)
            if not ready:
                if not self._is_timed_out(start_time):
                    continue
                raise exceptions.TimeoutException(
                    "Command: '{0}' executed on host '{1}'.".format(
                        cmd, self.host))
            out_chunk = err_chunk = None
            if channel.recv_ready():
                out_chunk = channel.recv(self.buf_size)
                out_data_chunks.append(out_chunk)
            if channel.recv_stderr_ready():
                err_chunk = channel.recv_stderr(self.buf_size)
                err_data_chunks.append(err_chunk)
            # An event with nothing left to read means the channel is done.
            if not err_chunk and not out_chunk:
                break
        return b''.join(out_data_chunks), b''.join(err_data_chunks)

    def exec_command(self, cmd, encoding="utf-8"):
        """Execute the specified command on the server

        Note that this method is reading whole command outputs to memory, thus
        shouldn't be used for large outputs.

        :param str cmd: Command to run at remote server.
        :param str encoding: Encoding for result from paramiko.
                             Result will not be decoded if None.
        :returns: data read from standard output of the command.
        :raises: SSHExecCommandFailed if command returns nonzero
                 status. The exception contains command status stderr content.
        :raises: TimeoutException if cmd doesn't end when timeout expires.
        """
        ssh = self._get_ssh_connection()
        try:
            transport = ssh.get_transport()
            with transport.open_session() as channel:
                channel.fileno()  # Register event pipe
                channel.exec_command(cmd)
                channel.shutdown_write()
                # If the executing host is linux-based, poll the channel
                if self._can_system_poll():
                    out_data, err_data = self._poll_channel(channel, cmd)
                # Just read from the channels
                else:
                    out_file = channel.makefile('rb', self.buf_size)
                    err_file = channel.makefile_stderr('rb', self.buf_size)
                    out_data = out_file.read()
                    err_data = err_file.read()
                if encoding:
                    out_data = out_data.decode(encoding)
                    err_data = err_data.decode(encoding)
                exit_status = channel.recv_exit_status()
        finally:
            # Close the connection on every path (including timeouts and
            # decode errors) so repeated calls do not leak SSH sessions.
            ssh.close()
        if 0 != exit_status:
            raise exceptions.SSHExecCommandFailed(
                command=cmd, exit_status=exit_status,
                stderr=err_data, stdout=out_data)
        return out_data

    def test_connection_auth(self):
        """Raises an exception when we can not connect to server via ssh."""
        connection = self._get_ssh_connection()
        connection.close()

    def _get_proxy_channel(self):
        """Open a channel on the proxy client that runs ``nc host port``.

        The returned channel is passed as ``sock`` to ``ssh.connect`` so the
        real connection is tunnelled through the proxy. This requires ``nc``
        to be available on the proxy host.
        """
        conn = self.proxy_client._get_ssh_connection()
        # Keep a reference to avoid g/c
        # https://github.com/paramiko/paramiko/issues/440
        self._proxy_conn = conn
        transport = conn.get_transport()
        chan = transport.open_session()
        cmd = 'nc %s %s' % (self.host, self.port)
        chan.exec_command(cmd)
        return chan

    def _get_proxy_client_info(self):
        """Describe the proxy chain (recursively) for log messages."""
        if not self.proxy_client:
            return 'no proxy client'
        nested_pclient = self.proxy_client._get_proxy_client_info()
        return ('%(username)s@%(host)s:%(port)s, nested proxy client: '
                '%(nested_pclient)s' % {'username': self.proxy_client.username,
                                        'host': self.proxy_client.host,
                                        'port': self.proxy_client.port,
                                        'nested_pclient': nested_pclient})