THRIFT-2231 Support tornado-4.x (Python)
Client: Python
Patch: Roey Berman
Signed-off-by: Roger Meier <roger@apache.org>
diff --git a/lib/py/src/TTornado.py b/lib/py/src/TTornado.py
index d90f672..c8498c5 100644
--- a/lib/py/src/TTornado.py
+++ b/lib/py/src/TTornado.py
@@ -17,58 +17,91 @@
# under the License.
#
-from cStringIO import StringIO
+from __future__ import absolute_import
import logging
import socket
import struct
-from thrift.transport import TTransport
-from thrift.transport.TTransport import TTransportException
+from thrift.transport.TTransport import TTransportException, TTransportBase, TMemoryBuffer
-from tornado import gen
-from tornado import iostream
-from tornado import tcpserver
+from io import BytesIO
+from collections import deque
+from contextlib import contextmanager
+from tornado import gen, iostream, ioloop, tcpserver, concurrent
+
+__all__ = ['TTornadoServer', 'TTornadoStreamTransport']
-class TTornadoStreamTransport(TTransport.TTransportBase):
+class _Lock(object):
+ def __init__(self):
+ self._waiters = deque()
+
+ def acquired(self):
+ return len(self._waiters) > 0
+
+ @gen.coroutine
+ def acquire(self):
+ blocker = self._waiters[-1] if self.acquired() else None
+ future = concurrent.Future()
+ self._waiters.append(future)
+ if blocker:
+ yield blocker
+
+ raise gen.Return(self._lock_context())
+
+ def release(self):  # resolve the current holder's (leftmost) future so the next queued acquire() proceeds
+ assert self.acquired(), 'Lock not acquired'
+ future = self._waiters.popleft()
+ future.set_result(None)
+
+ @contextmanager
+ def _lock_context(self):
+ try:
+ yield
+ finally:
+ self.release()
+
+
+class TTornadoStreamTransport(TTransportBase):
"""a framed, buffered transport over a Tornado stream"""
- def __init__(self, host, port, stream=None):
+ def __init__(self, host, port, stream=None, io_loop=None):
self.host = host
self.port = port
- self.is_queuing_reads = False
- self.read_queue = []
- self.__wbuf = StringIO()
+ self.io_loop = io_loop or ioloop.IOLoop.current()
+ self.__wbuf = BytesIO()
+ self._read_lock = _Lock()
# servers provide a ready-to-go stream
self.stream = stream
- if self.stream is not None:
- self._set_close_callback()
- # not the same number of parameters as TTransportBase.open
- def open(self, callback):
+ def with_timeout(self, timeout, future):
+ return gen.with_timeout(timeout, future, self.io_loop)
+
+ @gen.coroutine
+ def open(self, timeout=None):
logging.debug('socket connecting')
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
self.stream = iostream.IOStream(sock)
- def on_close_in_connect(*_):
- message = 'could not connect to {}:{}'.format(self.host, self.port)
+ try:
+ connect = self.stream.connect((self.host, self.port))
+ if timeout is not None:
+ yield self.with_timeout(timeout, connect)
+ else:
+ yield connect
+ except (socket.error, IOError, ioloop.TimeoutError) as e:
+ message = 'could not connect to {}:{} ({})'.format(self.host, self.port, e)
raise TTransportException(
type=TTransportException.NOT_OPEN,
message=message)
- self.stream.set_close_callback(on_close_in_connect)
- def finish(*_):
- self._set_close_callback()
- callback()
+ raise gen.Return(self)
- self.stream.connect((self.host, self.port), callback=finish)
-
- def _set_close_callback(self):
- def on_close():
- raise TTransportException(
- type=TTransportException.END_OF_FILE,
- message='socket closed')
- self.stream.set_close_callback(self.close)
+ def set_close_callback(self, callback):
+ """
+ Should be called only after open() returns
+ """
+ self.stream.set_close_callback(callback)
def close(self):
# don't raise if we intend to close
@@ -78,51 +111,45 @@
def read(self, _):
# The generated code for Tornado shouldn't do individual reads -- only
# frames at a time
- assert "you're doing it wrong" is True
+ assert False, "you're doing it wrong"
- @gen.engine
- def readFrame(self, callback):
- self.read_queue.append(callback)
- logging.debug('read queue: %s', self.read_queue)
+ @contextmanager
+ def io_exception_context(self):
+ try:
+ yield
+ except (socket.error, IOError) as e:
+ raise TTransportException(
+ type=TTransportException.END_OF_FILE,
+ message=str(e))
+ except iostream.StreamBufferFullError as e:
+ raise TTransportException(
+ type=TTransportException.UNKNOWN,
+ message=str(e))
- if self.is_queuing_reads:
- # If a read is already in flight, then the while loop below should
- # pull it from self.read_queue
- return
-
- self.is_queuing_reads = True
- while self.read_queue:
- next_callback = self.read_queue.pop()
- result = yield gen.Task(self._readFrameFromStream)
- next_callback(result)
- self.is_queuing_reads = False
-
- @gen.engine
- def _readFrameFromStream(self, callback):
- logging.debug('_readFrameFromStream')
- frame_header = yield gen.Task(self.stream.read_bytes, 4)
- frame_length, = struct.unpack('!i', frame_header)
- logging.debug('received frame header, frame length = %i', frame_length)
- frame = yield gen.Task(self.stream.read_bytes, frame_length)
- logging.debug('received frame payload')
- callback(frame)
+ @gen.coroutine
+ def readFrame(self):
+ # IOStream processes reads one at a time
+ with (yield self._read_lock.acquire()):
+ with self.io_exception_context():
+ frame_header = yield self.stream.read_bytes(4)
+ if len(frame_header) == 0:
+ raise iostream.StreamClosedError('Read zero bytes from stream')
+ frame_length, = struct.unpack('!i', frame_header)
+ logging.debug('received frame header, frame length = %d', frame_length)
+ frame = yield self.stream.read_bytes(frame_length)
+ logging.debug('received frame payload: %r', frame)
+ raise gen.Return(frame)
def write(self, buf):
self.__wbuf.write(buf)
- def flush(self, callback=None):
- wout = self.__wbuf.getvalue()
- wsz = len(wout)
+ def flush(self):
+ frame = self.__wbuf.getvalue()
# reset wbuf before write/flush to preserve state on underlying failure
- self.__wbuf = StringIO()
- # N.B.: Doing this string concatenation is WAY cheaper than making
- # two separate calls to the underlying socket object. Socket writes in
- # Python turn out to be REALLY expensive, but it seems to do a pretty
- # good job of managing string buffer operations without excessive copies
- buf = struct.pack("!i", wsz) + wout
-
- logging.debug('writing frame length = %i', wsz)
- self.stream.write(buf, callback)
+ frame_length = struct.pack('!i', len(frame))
+ self.__wbuf = BytesIO()
+ with self.io_exception_context():
+ return self.stream.write(frame_length + frame)
class TTornadoServer(tcpserver.TCPServer):
@@ -135,19 +162,29 @@
self._oprot_factory = (oprot_factory if oprot_factory is not None
else iprot_factory)
+ @gen.coroutine
def handle_stream(self, stream, address):
+ """Serve framed Thrift requests from one client connection until it disconnects."""
+ # IPv6 peers report (host, port, flowinfo, scope_id); only host and port are used
+ host, port = address[:2]
+ trans = TTornadoStreamTransport(host=host, port=port, stream=stream,
+ io_loop=self.io_loop)
+ oprot = self._oprot_factory.getProtocol(trans)
+
try:
- host, port = address
- trans = TTornadoStreamTransport(host=host, port=port, stream=stream)
- oprot = self._oprot_factory.getProtocol(trans)
-
- def next_pass():
- if not trans.stream.closed():
- self._processor.process(trans, self._iprot_factory, oprot,
- callback=next_pass)
-
- next_pass()
-
+ while not trans.stream.closed():
+ try:
+ frame = yield trans.readFrame()
+ except TTransportException as e:
+ # END_OF_FILE is what readFrame raises when the peer closes/resets the stream: end quietly
+ if e.type == TTransportException.END_OF_FILE:
+ break
+ raise
+ tr = TMemoryBuffer(frame)
+ iprot = self._iprot_factory.getProtocol(tr)
+ yield self._processor.process(iprot, oprot)
except Exception:
logging.exception('thrift exception in handle_stream')
trans.close()
+
+ logging.info('client disconnected %s:%d', host, port)