Merge "Handle the case that RFP negotiation message arrived early."
diff --git a/tempest/common/compute.py b/tempest/common/compute.py
index 9110c4a..cb9525b 100644
--- a/tempest/common/compute.py
+++ b/tempest/common/compute.py
@@ -252,16 +252,34 @@
     def __init__(self, client_socket, url):
         """Contructor for the WebSocket wrapper to the socket."""
         self._socket = client_socket
+        # Bytes read past the upgrade response; consumed first by _recv().
+        self.cached_stream = b''
         # Upgrade the HTTP connection to a WebSocket
         self._upgrade(url)
 
+    def _recv(self, recv_size):
+        """Return up to recv_size bytes, serving cached_stream first.
+
+        Buffered bytes are used before at most one socket.recv() for the rest;
+        returns None when recv_size is not positive.
+        """
+        if recv_size <= 0:
+            return None
+        head = b''
+        if self.cached_stream:
+            take = min(len(self.cached_stream), recv_size)
+            head, self.cached_stream = (self.cached_stream[:take],
+                                        self.cached_stream[take:])
+            recv_size -= take
+        return head + self._socket.recv(recv_size) if recv_size > 0 else head
+
     def receive_frame(self):
         """Wrapper for receiving data to parse the WebSocket frame format"""
         # We need to loop until we either get some bytes back in the frame
         # or no data was received (meaning the socket was closed).  This is
         # done to handle the case where we get back some empty frames
         while True:
-            header = self._socket.recv(2)
+            header = self._recv(2)
             # If we didn't receive any data, just return None
             if not header:
                 return None
@@ -270,7 +288,7 @@
             # that only the 2nd byte contains the length, and since the
             # server doesn't do masking, we can just read the data length
             if ord_func(header[1]) & 127 > 0:
-                return self._socket.recv(ord_func(header[1]) & 127)
+                return self._recv(ord_func(header[1]) & 127)
 
     def send_frame(self, data):
         """Wrapper for sending data to add in the WebSocket frame format."""
@@ -318,6 +336,17 @@
         self._socket.sendall(reqdata.encode('utf8'))
         self.response = data = self._socket.recv(4096)
         # Loop through & concatenate all of the data in the response body
-        while data and self.response.find(b'\r\n\r\n') < 0:
+        end_loc = self.response.find(b'\r\n\r\n')
+        while data and end_loc < 0:
             data = self._socket.recv(4096)
             self.response += data
+            end_loc = self.response.find(b'\r\n\r\n')
+
+        # end_loc is -1 if recv() hit EOF before '\r\n\r\n' arrived; in that
+        # case keep the whole response rather than truncating it to 3 bytes.
+        if end_loc >= 0 and len(self.response) > end_loc + 4:
+            # In case some frames (e.g. the first RFB negotiation) have
+            # arrived, cache them for the next read.
+            self.cached_stream = self.response[end_loc + 4:]
+            # Ensure the response ends with '\r\n\r\n'.
+            self.response = self.response[:end_loc + 4]
diff --git a/tempest/tests/common/test_compute.py b/tempest/tests/common/test_compute.py
new file mode 100644
index 0000000..c108be9
--- /dev/null
+++ b/tempest/tests/common/test_compute.py
@@ -0,0 +1,106 @@
+# Copyright 2017 Citrix Systems
+# All Rights Reserved.
+#
+#    Licensed under the Apache License, Version 2.0 (the "License"); you may
+#    not use this file except in compliance with the License. You may obtain
+#    a copy of the License at
+#
+#         http://www.apache.org/licenses/LICENSE-2.0
+#
+#    Unless required by applicable law or agreed to in writing, software
+#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+#    License for the specific language governing permissions and limitations
+#    under the License.
+
+from six.moves.urllib import parse as urlparse
+
+import mock
+
+from tempest.common import compute
+from tempest.tests import base
+
+
+class TestCompute(base.TestCase):
+    def setUp(self):
+        super(TestCompute, self).setUp()
+        self.client_sock = mock.Mock()
+        self.url = urlparse.urlparse("http://www.fake.com:80")
+
+    def test_rfp_frame_not_cached(self):
+        # The RFB version frame arrives in its own recv() after the upgrade
+        # response, so nothing is cached and it is read from the socket.
+        rfb_version = b'RFB.003.003\x0a'
+        rfb_frame_header = b'\x82\x0c'
+
+        self.client_sock.recv.side_effect = [
+            b'fake response start\r\n',
+            b'fake response end\r\n\r\n',
+            rfb_frame_header,
+            rfb_version]
+        expect_response = b'fake response start\r\nfake response end\r\n\r\n'
+
+        websocket = compute._WebSocket(self.client_sock, self.url)
+
+        self.assertEqual(expect_response, websocket.response)
+        # no cache
+        self.assertEqual(b'', websocket.cached_stream)
+        self.client_sock.recv.assert_has_calls([mock.call(4096),
+                                                mock.call(4096)])
+
+        self.client_sock.recv.reset_mock()
+        recv_version = websocket.receive_frame()
+
+        self.assertEqual(rfb_version, recv_version)
+        self.client_sock.recv.assert_has_calls([mock.call(2),
+                                                mock.call(len(rfb_version))])
+
+    def test_rfp_frame_fully_cached(self):
+        rfb_version = b'RFB.003.003\x0a'
+        rfb_version_frame = b'\x82\x0c' + rfb_version
+
+        self.client_sock.recv.side_effect = [
+            b'fake response start\r\n',
+            b'fake response end\r\n\r\n' + rfb_version_frame]
+        expect_response = b'fake response start\r\nfake response end\r\n\r\n'
+        websocket = compute._WebSocket(self.client_sock, self.url)
+
+        self.client_sock.recv.assert_has_calls([mock.call(4096),
+                                                mock.call(4096)])
+        self.assertEqual(expect_response, websocket.response)
+        self.assertEqual(rfb_version_frame, websocket.cached_stream)
+
+        self.client_sock.recv.reset_mock()
+        recv_version = websocket.receive_frame()
+
+        self.client_sock.recv.assert_not_called()
+        self.assertEqual(rfb_version, recv_version)
+        # cached_stream should be empty in the end.
+        self.assertEqual(b'', websocket.cached_stream)
+
+    def test_rfp_frame_partially_cached(self):
+        rfb_version = b'RFB.003.003\x0a'
+        rfb_version_frame = b'\x82\x0c' + rfb_version
+        frame_part1 = rfb_version_frame[:6]
+        frame_part2 = rfb_version_frame[6:]
+
+        self.client_sock.recv.side_effect = [
+            b'fake response start\r\n',
+            b'fake response end\r\n\r\n' + frame_part1,
+            frame_part2]
+        expect_response = b'fake response start\r\nfake response end\r\n\r\n'
+        websocket = compute._WebSocket(self.client_sock, self.url)
+
+        self.client_sock.recv.assert_has_calls([mock.call(4096),
+                                                mock.call(4096)])
+        self.assertEqual(expect_response, websocket.response)
+        self.assertEqual(frame_part1, websocket.cached_stream)
+
+        self.client_sock.recv.reset_mock()
+
+        recv_version = websocket.receive_frame()
+
+        self.client_sock.recv.assert_called_once_with(len(frame_part2))
+        self.assertEqual(rfb_version, recv_version)
+        # cached_stream should be empty in the end.
+        self.assertEqual(b'', websocket.cached_stream)