Merge "Handle the case that RFP negotiation message arrived early."
diff --git a/tempest/common/compute.py b/tempest/common/compute.py
index 9110c4a..cb9525b 100644
--- a/tempest/common/compute.py
+++ b/tempest/common/compute.py
@@ -252,16 +252,34 @@
def __init__(self, client_socket, url):
    """Wrap ``client_socket`` and upgrade it to a WebSocket for ``url``.

    ``cached_stream`` is created (empty) before the upgrade runs so that
    any frame bytes the upgrade step stores there are not overwritten
    afterwards.
    """
    # Buffer for frame bytes that arrive in the same read as the HTTP
    # upgrade response; must exist before _upgrade() is called.
    self.cached_stream = b''
    self._socket = client_socket
    # Switch the HTTP connection over to the WebSocket protocol.
    self._upgrade(url)
def _recv(self, recv_size):
    """Read ``recv_size`` bytes, draining ``self.cached_stream`` first.

    Frame bytes that arrived together with the HTTP upgrade response are
    kept in ``self.cached_stream`` and are consumed before anything is
    read from ``self._socket``. Because ``socket.recv`` may return fewer
    bytes than requested, the socket is read repeatedly until
    ``recv_size`` bytes have been collected or the peer closes the
    connection (``recv`` returns an empty bytes object).

    :param recv_size: number of bytes wanted.
    :returns: ``None`` if ``recv_size`` is not positive. Otherwise returns
        the bytes read, which are shorter than ``recv_size`` only if the
        socket was closed first (``b''`` if nothing was available).
    """
    if recv_size <= 0:
        return None

    # Serve as much as possible from bytes buffered during the upgrade.
    from_cache = self.cached_stream[:recv_size]
    self.cached_stream = self.cached_stream[len(from_cache):]
    chunks = [from_cache]
    remaining = recv_size - len(from_cache)

    # socket.recv() can return a partial read; keep reading until the
    # request is satisfied or the peer closes the connection.
    while remaining > 0:
        data = self._socket.recv(remaining)
        if not data:
            break
        chunks.append(data)
        remaining -= len(data)
    return b''.join(chunks)
+
def receive_frame(self):
"""Wrapper for receiving data to parse the WebSocket frame format"""
# We need to loop until we either get some bytes back in the frame
# or no data was received (meaning the socket was closed). This is
# done to handle the case where we get back some empty frames
while True:
- header = self._socket.recv(2)
+ header = self._recv(2)
# If we didn't receive any data, just return None
if not header:
return None
@@ -270,7 +288,7 @@
# that only the 2nd byte contains the length, and since the
# server doesn't do masking, we can just read the data length
if ord_func(header[1]) & 127 > 0:
- return self._socket.recv(ord_func(header[1]) & 127)
+ return self._recv(ord_func(header[1]) & 127)
def send_frame(self, data):
"""Wrapper for sending data to add in the WebSocket frame format."""
@@ -318,6 +336,15 @@
self._socket.sendall(reqdata.encode('utf8'))
self.response = data = self._socket.recv(4096)
# Loop through & concatenate all of the data in the response body
- while data and self.response.find(b'\r\n\r\n') < 0:
+ end_loc = self.response.find(b'\r\n\r\n')
+ while data and end_loc < 0:
data = self._socket.recv(4096)
self.response += data
+ end_loc = self.response.find(b'\r\n\r\n')
+
+ if len(self.response) > end_loc + 4:
+ # In case some frames (e.g. the first RFP negotiation) have
+ # arrived, cache it for next reading.
+ self.cached_stream = self.response[end_loc + 4:]
+ # ensure response ends with '\r\n\r\n'.
+ self.response = self.response[:end_loc + 4]
diff --git a/tempest/tests/common/test_compute.py b/tempest/tests/common/test_compute.py
new file mode 100644
index 0000000..c108be9
--- /dev/null
+++ b/tempest/tests/common/test_compute.py
@@ -0,0 +1,106 @@
+# Copyright 2017 Citrix Systems
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+from six.moves.urllib import parse as urlparse
+
+import mock
+
+from tempest.common import compute
+from tempest.tests import base
+
+
class TestCompute(base.TestCase):
    """Tests for how ``compute._WebSocket`` handles frame bytes that
    arrive in the same socket read as the HTTP upgrade response.
    """

    # Version string an RFB (VNC) server sends first, plus the unmasked
    # WebSocket frame carrying it. The header is 0x82 (FIN + binary
    # opcode), followed by a length byte 0x0c, i.e. a 12-byte payload.
    RFB_VERSION = b'RFB.003.003\x0a'
    RFB_FRAME_HEADER = b'\x82\x0c'
    RFB_VERSION_FRAME = RFB_FRAME_HEADER + RFB_VERSION
    # The fake upgrade response, delivered over two socket reads.
    RESPONSE_START = b'fake response start\r\n'
    RESPONSE_END = b'fake response end\r\n\r\n'
    EXPECTED_RESPONSE = RESPONSE_START + RESPONSE_END

    def setUp(self):
        super(TestCompute, self).setUp()
        self.client_sock = mock.Mock()
        self.url = urlparse.urlparse("http://www.fake.com:80")

    def test_rfp_frame_not_cached(self):
        # The RFB version frame arrives in separate reads after the
        # upgrade response, so nothing should be cached.
        self.client_sock.recv.side_effect = [
            self.RESPONSE_START,
            self.RESPONSE_END,
            self.RFB_FRAME_HEADER,
            self.RFB_VERSION]

        web_socket = compute._WebSocket(self.client_sock, self.url)

        self.assertEqual(self.EXPECTED_RESPONSE, web_socket.response)
        self.assertEqual(b'', web_socket.cached_stream)
        self.client_sock.recv.assert_has_calls([mock.call(4096),
                                                mock.call(4096)])

        self.client_sock.recv.reset_mock()
        recv_version = web_socket.receive_frame()

        self.assertEqual(self.RFB_VERSION, recv_version)
        # The 2-byte header, then the 12-byte payload, come from the socket.
        self.client_sock.recv.assert_has_calls([mock.call(2),
                                                mock.call(12)])

    def test_rfp_frame_fully_cached(self):
        # The whole version frame arrives in the same read as the end of
        # the upgrade response, so it must be served from the cache.
        self.client_sock.recv.side_effect = [
            self.RESPONSE_START,
            self.RESPONSE_END + self.RFB_VERSION_FRAME]

        web_socket = compute._WebSocket(self.client_sock, self.url)

        self.client_sock.recv.assert_has_calls([mock.call(4096),
                                                mock.call(4096)])
        self.assertEqual(self.EXPECTED_RESPONSE, web_socket.response)
        self.assertEqual(self.RFB_VERSION_FRAME, web_socket.cached_stream)

        self.client_sock.recv.reset_mock()
        recv_version = web_socket.receive_frame()

        # Everything came from the cache; the socket was not read.
        self.client_sock.recv.assert_not_called()
        self.assertEqual(self.RFB_VERSION, recv_version)
        self.assertEqual(b'', web_socket.cached_stream)

    def test_rfp_frame_partially_cached(self):
        # Only the start of the version frame arrives with the upgrade
        # response; the remainder must be read from the socket.
        frame_part1 = self.RFB_VERSION_FRAME[:6]
        frame_part2 = self.RFB_VERSION_FRAME[6:]

        self.client_sock.recv.side_effect = [
            self.RESPONSE_START,
            self.RESPONSE_END + frame_part1,
            frame_part2]

        web_socket = compute._WebSocket(self.client_sock, self.url)

        self.client_sock.recv.assert_has_calls([mock.call(4096),
                                                mock.call(4096)])
        self.assertEqual(self.EXPECTED_RESPONSE, web_socket.response)
        self.assertEqual(frame_part1, web_socket.cached_stream)

        self.client_sock.recv.reset_mock()
        recv_version = web_socket.receive_frame()

        # Only the bytes missing from the cache were requested.
        self.client_sock.recv.assert_called_once_with(len(frame_part2))
        self.assertEqual(self.RFB_VERSION, recv_version)
        self.assertEqual(b'', web_socket.cached_stream)