Merge "Align multiple lines in tox"
diff --git a/tempest/lib/common/ssh.py b/tempest/lib/common/ssh.py
index 511dd08..a831dbd 100644
--- a/tempest/lib/common/ssh.py
+++ b/tempest/lib/common/ssh.py
@@ -117,56 +117,56 @@
         """
         ssh = self._get_ssh_connection()
         transport = ssh.get_transport()
-        channel = transport.open_session()
-        channel.fileno()  # Register event pipe
-        channel.exec_command(cmd)
-        channel.shutdown_write()
-        exit_status = channel.recv_exit_status()
+        with transport.open_session() as channel:
+            channel.fileno()  # Register event pipe
+            channel.exec_command(cmd)
+            channel.shutdown_write()
+            exit_status = channel.recv_exit_status()
 
-        # If the executing host is linux-based, poll the channel
-        if self._can_system_poll():
-            out_data_chunks = []
-            err_data_chunks = []
-            poll = select.poll()
-            poll.register(channel, select.POLLIN)
-            start_time = time.time()
+            # If the executing host is linux-based, poll the channel
+            if self._can_system_poll():
+                out_data_chunks = []
+                err_data_chunks = []
+                poll = select.poll()
+                poll.register(channel, select.POLLIN)
+                start_time = time.time()
 
-            while True:
-                ready = poll.poll(self.channel_timeout)
-                if not any(ready):
-                    if not self._is_timed_out(start_time):
+                while True:
+                    ready = poll.poll(self.channel_timeout)
+                    if not any(ready):
+                        if not self._is_timed_out(start_time):
+                            continue
+                        raise exceptions.TimeoutException(
+                            "Command: '{0}' executed on host '{1}'.".format(
+                                cmd, self.host))
+                    if not ready[0]:  # If there is nothing to read.
                         continue
-                    raise exceptions.TimeoutException(
-                        "Command: '{0}' executed on host '{1}'.".format(
-                            cmd, self.host))
-                if not ready[0]:  # If there is nothing to read.
-                    continue
-                out_chunk = err_chunk = None
-                if channel.recv_ready():
-                    out_chunk = channel.recv(self.buf_size)
-                    out_data_chunks += out_chunk,
-                if channel.recv_stderr_ready():
-                    err_chunk = channel.recv_stderr(self.buf_size)
-                    err_data_chunks += err_chunk,
-                if channel.closed and not err_chunk and not out_chunk:
-                    break
-            out_data = b''.join(out_data_chunks)
-            err_data = b''.join(err_data_chunks)
-        # Just read from the channels
-        else:
-            out_file = channel.makefile('rb', self.buf_size)
-            err_file = channel.makefile_stderr('rb', self.buf_size)
-            out_data = out_file.read()
-            err_data = err_file.read()
-        if encoding:
-            out_data = out_data.decode(encoding)
-            err_data = err_data.decode(encoding)
+                    out_chunk = err_chunk = None
+                    if channel.recv_ready():
+                        out_chunk = channel.recv(self.buf_size)
+                        out_data_chunks += out_chunk,
+                    if channel.recv_stderr_ready():
+                        err_chunk = channel.recv_stderr(self.buf_size)
+                        err_data_chunks += err_chunk,
+                    if not err_chunk and not out_chunk:
+                        break
+                out_data = b''.join(out_data_chunks)
+                err_data = b''.join(err_data_chunks)
+            # Just read from the channels
+            else:
+                out_file = channel.makefile('rb', self.buf_size)
+                err_file = channel.makefile_stderr('rb', self.buf_size)
+                out_data = out_file.read()
+                err_data = err_file.read()
+            if encoding:
+                out_data = out_data.decode(encoding)
+                err_data = err_data.decode(encoding)
 
-        if 0 != exit_status:
-            raise exceptions.SSHExecCommandFailed(
-                command=cmd, exit_status=exit_status,
-                stderr=err_data, stdout=out_data)
-        return out_data
+            if 0 != exit_status:
+                raise exceptions.SSHExecCommandFailed(
+                    command=cmd, exit_status=exit_status,
+                    stderr=err_data, stdout=out_data)
+            return out_data
 
     def test_connection_auth(self):
         """Raises an exception when we can not connect to server via ssh."""
diff --git a/tempest/tests/lib/test_ssh.py b/tempest/tests/lib/test_ssh.py
index f6efd47..d001c51 100644
--- a/tempest/tests/lib/test_ssh.py
+++ b/tempest/tests/lib/test_ssh.py
@@ -141,8 +141,6 @@
     def test_exec_command(self):
         chan_mock, poll_mock, select_mock = (
             self._set_mocks_for_select([[1, 0, 0]], True))
-        closed_prop = mock.PropertyMock(return_value=True)
-        type(chan_mock).closed = closed_prop
 
         chan_mock.recv_exit_status.return_value = 0
         chan_mock.recv.return_value = b''
@@ -164,7 +162,6 @@
         chan_mock.recv_stderr_ready.assert_called_once_with()
         chan_mock.recv_stderr.assert_called_once_with(1024)
         chan_mock.recv_exit_status.assert_called_once_with()
-        closed_prop.assert_called_once_with()
 
     def _set_mocks_for_select(self, poll_data, ito_value=False):
         gsc_mock = self.patch('tempest.lib.common.ssh.Client.'
@@ -184,7 +181,7 @@
         gsc_mock.return_value = client_mock
         ito_mock.return_value = ito_value
         client_mock.get_transport.return_value = tran_mock
-        tran_mock.open_session.return_value = chan_mock
+        tran_mock.open_session().__enter__.return_value = chan_mock
         if isinstance(poll_data[0], list):
             poll_mock.poll.side_effect = poll_data
         else:
@@ -242,7 +239,7 @@
 
         gsc_mock.return_value = client_mock
         client_mock.get_transport.return_value = tran_mock
-        tran_mock.open_session.return_value = chan_mock
+        tran_mock.open_session().__enter__.return_value = chan_mock
         chan_mock.recv_exit_status.return_value = 0
 
         std_out_mock = mock.MagicMock(StringIO)