THRIFT-2231 Support tornado-4.x (Python)

Client: Python
Patch: Roey Berman
Signed-off-by: Roger Meier <roger@apache.org>
diff --git a/lib/py/src/TTornado.py b/lib/py/src/TTornado.py
index d90f672..c8498c5 100644
--- a/lib/py/src/TTornado.py
+++ b/lib/py/src/TTornado.py
@@ -17,58 +17,91 @@
 # under the License.
 #
 
-from cStringIO import StringIO
+from __future__ import absolute_import
 import logging
 import socket
 import struct
 
-from thrift.transport import TTransport
-from thrift.transport.TTransport import TTransportException
+from thrift.transport.TTransport import TTransportException, TTransportBase, TMemoryBuffer
 
-from tornado import gen
-from tornado import iostream
-from tornado import tcpserver
+from io import BytesIO
+from collections import deque
+from contextlib import contextmanager
+from tornado import gen, iostream, ioloop, tcpserver, concurrent
+
+__all__ = ['TTornadoServer', 'TTornadoStreamTransport']
 
 
-class TTornadoStreamTransport(TTransport.TTransportBase):
+class _Lock(object):
+    """FIFO coroutine lock; acquire() resolves to a context manager that releases it."""
+    def __init__(self):
+        self._waiters = deque()
+
+    def acquired(self):
+        return len(self._waiters) > 0
+
+    @gen.coroutine
+    def acquire(self):
+        # queue behind the newest waiter: its future is resolved by its own release()
+        blocker = self._waiters[-1] if self.acquired() else None
+        self._waiters.append(concurrent.Future())
+        if blocker is not None:
+            yield blocker
+        raise gen.Return(self._lock_context())
+
+    def release(self):
+        assert self.acquired(), 'Lock not acquired'
+        # resolving the head future wakes the coroutine queued directly behind it
+        self._waiters.popleft().set_result(None)
+
+    @contextmanager
+    def _lock_context(self):
+        try:
+            yield
+        finally:
+            self.release()
+
+
+class TTornadoStreamTransport(TTransportBase):
     """a framed, buffered transport over a Tornado stream"""
-    def __init__(self, host, port, stream=None):
+    def __init__(self, host, port, stream=None, io_loop=None):
         self.host = host
         self.port = port
-        self.is_queuing_reads = False
-        self.read_queue = []
-        self.__wbuf = StringIO()
+        self.io_loop = io_loop or ioloop.IOLoop.current()
+        self.__wbuf = BytesIO()
+        self._read_lock = _Lock()
 
         # servers provide a ready-to-go stream
         self.stream = stream
-        if self.stream is not None:
-            self._set_close_callback()
 
-    # not the same number of parameters as TTransportBase.open
-    def open(self, callback):
+    def with_timeout(self, timeout, future):
+        """gen.with_timeout on this transport's IOLoop (timeout: deadline or timedelta)."""
+        return gen.with_timeout(timeout, future, self.io_loop)
+
+    @gen.coroutine
+    def open(self, timeout=None):
+        """Connect to (host, port); resolve to self, or raise TTransportException(NOT_OPEN)."""
         logging.debug('socket connecting')
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
         self.stream = iostream.IOStream(sock)
 
-        def on_close_in_connect(*_):
-            message = 'could not connect to {}:{}'.format(self.host, self.port)
+        try:
+            connect = self.stream.connect((self.host, self.port))
+            if timeout is not None:
+                connect = self.with_timeout(timeout, connect)
+            yield connect
+        except (socket.error, IOError, gen.TimeoutError, ioloop.TimeoutError) as e:
+            self.stream.close()  # with_timeout doesn't cancel connect(); don't leak the socket
+            message = 'could not connect to {}:{} ({})'.format(self.host, self.port, e)
             raise TTransportException(
                 type=TTransportException.NOT_OPEN,
                 message=message)
-        self.stream.set_close_callback(on_close_in_connect)
 
-        def finish(*_):
-            self._set_close_callback()
-            callback()
+        raise gen.Return(self)
 
-        self.stream.connect((self.host, self.port), callback=finish)
-
-    def _set_close_callback(self):
-        def on_close():
-            raise TTransportException(
-                type=TTransportException.END_OF_FILE,
-                message='socket closed')
-        self.stream.set_close_callback(self.close)
+    def set_close_callback(self, callback):
+        """Install `callback` as the stream's close callback; call only after open() returns."""
+        self.stream.set_close_callback(callback)
 
     def close(self):
         # don't raise if we intend to close
@@ -78,51 +111,45 @@
     def read(self, _):
         # The generated code for Tornado shouldn't do individual reads -- only
         # frames at a time
-        assert "you're doing it wrong" is True
+        raise AssertionError("you're doing it wrong")  # unlike assert, survives python -O

-    @gen.engine
-    def readFrame(self, callback):
-        self.read_queue.append(callback)
-        logging.debug('read queue: %s', self.read_queue)
+    @contextmanager
+    def io_exception_context(self):
+        """Re-raise socket/stream errors from the managed block as TTransportException."""
+        try:
+            yield
+        except (socket.error, IOError) as e:
+            raise TTransportException(type=TTransportException.END_OF_FILE, message=str(e))
+        except iostream.StreamBufferFullError as e:
+            raise TTransportException(type=TTransportException.UNKNOWN, message=str(e))

-        if self.is_queuing_reads:
-            # If a read is already in flight, then the while loop below should
-            # pull it from self.read_queue
-            return
-
-        self.is_queuing_reads = True
-        while self.read_queue:
-            next_callback = self.read_queue.pop()
-            result = yield gen.Task(self._readFrameFromStream)
-            next_callback(result)
-        self.is_queuing_reads = False
-
-    @gen.engine
-    def _readFrameFromStream(self, callback):
-        logging.debug('_readFrameFromStream')
-        frame_header = yield gen.Task(self.stream.read_bytes, 4)
-        frame_length, = struct.unpack('!i', frame_header)
-        logging.debug('received frame header, frame length = %i', frame_length)
-        frame = yield gen.Task(self.stream.read_bytes, frame_length)
-        logging.debug('received frame payload')
-        callback(frame)
+    @gen.coroutine
+    def readFrame(self):
+        """Read one frame ('!i' length prefix + payload); resolves to the payload bytes."""
+        # IOStream processes reads one at a time
+        with (yield self._read_lock.acquire()):
+            with self.io_exception_context():
+                frame_header = yield self.stream.read_bytes(4)
+                if len(frame_header) == 0:
+                    raise iostream.StreamClosedError('Read zero bytes from stream')
+                frame_length, = struct.unpack('!i', frame_header)
+                logging.debug('received frame header, frame length = %d', frame_length)
+                if frame_length < 0:  # corrupt or hostile header: never hand it to read_bytes
+                    raise TTransportException(type=TTransportException.UNKNOWN,
+                                              message='negative frame length %d' % frame_length)
+                frame = yield self.stream.read_bytes(frame_length)
+                logging.debug('received frame payload: %r', frame)
+                raise gen.Return(frame)

     def write(self, buf):
         self.__wbuf.write(buf)

-    def flush(self, callback=None):
-        wout = self.__wbuf.getvalue()
-        wsz = len(wout)
+    def flush(self):
+        frame = self.__wbuf.getvalue()
         # reset wbuf before write/flush to preserve state on underlying failure
-        self.__wbuf = StringIO()
-        # N.B.: Doing this string concatenation is WAY cheaper than making
-        # two separate calls to the underlying socket object. Socket writes in
-        # Python turn out to be REALLY expensive, but it seems to do a pretty
-        # good job of managing string buffer operations without excessive copies
-        buf = struct.pack("!i", wsz) + wout
-
-        logging.debug('writing frame length = %i', wsz)
-        self.stream.write(buf, callback)
+        self.__wbuf = BytesIO()
+        with self.io_exception_context():
+            return self.stream.write(struct.pack('!i', len(frame)) + frame)


 class TTornadoServer(tcpserver.TCPServer):
@@ -135,19 +162,28 @@
         self._oprot_factory = (oprot_factory if oprot_factory is not None
                                else iprot_factory)
 
+    @gen.coroutine
     def handle_stream(self, stream, address):
+        # IPv6 peers give a 4-tuple (host, port, flowinfo, scope_id); keep host and port
+        host, port = address[:2]
+        trans = TTornadoStreamTransport(host=host, port=port, stream=stream,
+                                        io_loop=self.io_loop)
+        oprot = self._oprot_factory.getProtocol(trans)
+
         try:
-            host, port = address
-            trans = TTornadoStreamTransport(host=host, port=port, stream=stream)
-            oprot = self._oprot_factory.getProtocol(trans)
-
-            def next_pass():
-                if not trans.stream.closed():
-                    self._processor.process(trans, self._iprot_factory, oprot,
-                                            callback=next_pass)
-
-            next_pass()
-
+            while not trans.stream.closed():
+                try:
+                    frame = yield trans.readFrame()
+                except TTransportException as e:
+                    # END_OF_FILE (stream closed/reset) means the client went away: not an error
+                    if e.type == TTransportException.END_OF_FILE:
+                        break
+                    raise
+                tr = TMemoryBuffer(frame)
+                iprot = self._iprot_factory.getProtocol(tr)
+                yield self._processor.process(iprot, oprot)
         except Exception:
             logging.exception('thrift exception in handle_stream')
             trans.close()
+
+        logging.info('client disconnected %s:%d', host, port)