THRIFT-2231 Support tornado-4.x (Python)

Client: Python
Patch: Roey Berman
Signed-off-by: Roger Meier <roger@apache.org>
diff --git a/compiler/cpp/src/generate/t_py_generator.cc b/compiler/cpp/src/generate/t_py_generator.cc
index 4d155e0..5ac92c3 100644
--- a/compiler/cpp/src/generate/t_py_generator.cc
+++ b/compiler/cpp/src/generate/t_py_generator.cc
@@ -223,17 +223,6 @@
                                           t_doc* tdoc);
 
   /**
-   * a type for specifying to function_signature what type of Tornado callback
-   * parameter to add
-   */
-
-  enum tornado_callback_t {
-    NONE = 0,
-    MANDATORY_FOR_ONEWAY_ELSE_NONE = 1,
-    OPTIONAL_FOR_ONEWAY_ELSE_MANDATORY = 2,
-  };
-
-  /**
    * Helper rendering functions
    */
 
@@ -245,8 +234,7 @@
   std::string render_field_default_value(t_field* tfield);
   std::string type_name(t_type* ttype);
   std::string function_signature(t_function* tfunction,
-                                 bool interface=false,
-                                 tornado_callback_t callback=NONE);
+                                 bool interface=false);
   std::string argument_list(t_struct* tstruct,
                             std::vector<std::string> *pre=NULL,
                             std::vector<std::string> *post=NULL);
@@ -1067,7 +1055,8 @@
       "from thrift.transport import TTwisted" << endl;
   } else if (gen_tornado_) {
     f_service_ << "from tornado import gen" << endl;
-    f_service_ << "from tornado import stack_context" << endl;
+    f_service_ << "from tornado import concurrent" << endl;
+    f_service_ << "from thrift.transport import TTransport" << endl;
   }
 
   f_service_ << endl;
@@ -1156,7 +1145,7 @@
     vector<t_function*>::iterator f_iter;
     for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) {
       f_service_ <<
-        indent() << "def " << function_signature(*f_iter, true, OPTIONAL_FOR_ONEWAY_ELSE_MANDATORY) << ":" << endl;
+        indent() << "def " << function_signature(*f_iter, true) << ":" << endl;
       indent_up();
       generate_python_docstring(f_service_, (*f_iter));
       f_service_ <<
@@ -1229,6 +1218,7 @@
         indent() << "                         else iprot_factory)" << endl <<
         indent() << "  self._seqid = 0" << endl <<
         indent() << "  self._reqs = {}" << endl <<
+        indent() << "  self._transport.io_loop.spawn_callback(self._start_receiving)" << endl <<
         endl;
     } else {
       f_service_ <<
@@ -1257,18 +1247,29 @@
   if (gen_tornado_ && extends.empty()) {
     f_service_ <<
       indent() << "@gen.engine" << endl <<
-      indent() << "def recv_dispatch(self):" << endl <<
-      indent() << "  \"\"\"read a response from the wire. schedule exactly one per send that" << endl <<
-      indent() << "  expects a response, but it doesn't matter which callee gets which" << endl <<
-      indent() << "  response; they're dispatched here properly\"\"\"" << endl <<
-      endl <<
-      indent() << "  # wait for a frame header" << endl <<
-      indent() << "  frame = yield gen.Task(self._transport.readFrame)" << endl <<
-      indent() << "  tr = TTransport.TMemoryBuffer(frame)" << endl <<
-      indent() << "  iprot = self._iprot_factory.getProtocol(tr)" << endl <<
-      indent() << "  (fname, mtype, rseqid) = iprot.readMessageBegin()" << endl <<
-      indent() << "  method = getattr(self, 'recv_' + fname)" << endl <<
-      indent() << "  method(iprot, mtype, rseqid)" << endl <<
+      indent() << "def _start_receiving(self):" << endl <<
+      indent() << "  while True:" << endl <<
+      indent() << "    try:" << endl <<
+      indent() << "      frame = yield self._transport.readFrame()" << endl <<
+      indent() << "    except TTransport.TTransportException as e:" << endl <<
+      indent() << "      for future in self._reqs.itervalues():" << endl <<
+      indent() << "        future.set_exception(e)" << endl <<
+      indent() << "      self._reqs = {}" << endl <<
+      indent() << "      return" << endl <<
+      indent() << "    tr = TTransport.TMemoryBuffer(frame)" << endl <<
+      indent() << "    iprot = self._iprot_factory.getProtocol(tr)" << endl <<
+      indent() << "    (fname, mtype, rseqid) = iprot.readMessageBegin()" << endl <<
+      indent() << "    method = getattr(self, 'recv_' + fname)" << endl <<
+      indent() << "    future = self._reqs.pop(rseqid, None)" << endl <<
+      indent() << "    if not future:" << endl <<
+      indent() << "      # future has already been discarded" << endl <<
+      indent() << "      continue" << endl <<
+      indent() << "    try:" << endl <<
+      indent() << "      result = method(iprot, mtype, rseqid)" << endl <<
+      indent() << "    except Exception as e:" << endl <<
+      indent() << "      future.set_exception(e)" << endl <<
+      indent() << "    else:" << endl <<
+      indent() << "      future.set_result(result)" << endl <<
       endl;
   }
 
@@ -1283,7 +1284,7 @@
 
     // Open function
     indent(f_service_) <<
-      "def " << function_signature(*f_iter, false, OPTIONAL_FOR_ONEWAY_ELSE_MANDATORY) << ":" << endl;
+      "def " << function_signature(*f_iter, false) << ":" << endl;
     indent_up();
     generate_python_docstring(f_service_, (*f_iter));
     if (gen_twisted_) {
@@ -1296,7 +1297,7 @@
       indent(f_service_) << "self._seqid += 1" << endl;
       if (!(*f_iter)->is_oneway()) {
         indent(f_service_) <<
-          "self._reqs[self._seqid] = callback" << endl;
+          "future = self._reqs[self._seqid] = concurrent.Future()" << endl;
       }
     }
 
@@ -1313,15 +1314,6 @@
       f_service_ << (*fld_iter)->get_name();
     }
 
-    if (gen_tornado_ && (*f_iter)->is_oneway()) {
-      if (first) {
-        first = false;
-      } else {
-        f_service_ << ", ";
-      }
-      f_service_ << "callback";
-    }
-
     f_service_ << ")" << endl;
 
     if (!(*f_iter)->is_oneway()) {
@@ -1329,7 +1321,7 @@
       if (gen_twisted_) {
         f_service_ << "return d" << endl;
       } else if (gen_tornado_) {
-        f_service_ << "self.recv_dispatch()" << endl;
+        f_service_ << "return future" << endl;
       } else {
         if (!(*f_iter)->get_returntype()->is_void()) {
           f_service_ << "return ";
@@ -1347,7 +1339,7 @@
     f_service_ << endl;
 
     indent(f_service_) <<
-      "def send_" << function_signature(*f_iter, false, MANDATORY_FOR_ONEWAY_ELSE_NONE) << ":" << endl;
+      "def send_" << function_signature(*f_iter, false) << ":" << endl;
 
     indent_up();
 
@@ -1374,24 +1366,11 @@
     }
 
     // Write to the stream
-    if (gen_twisted_) {
+    if (gen_twisted_ || gen_tornado_) {
       f_service_ <<
         indent() << "args.write(oprot)" << endl <<
         indent() << "oprot.writeMessageEnd()" << endl <<
         indent() << "oprot.trans.flush()" << endl;
-    } else if (gen_tornado_) {
-      f_service_ <<
-        indent() << "args.write(oprot)" << endl <<
-        indent() << "oprot.writeMessageEnd()" << endl;
-      if ((*f_iter)->is_oneway()) {
-        // send_* carry the callback so you can block on the write's flush
-        // (rather than on receipt of the response)
-        f_service_ <<
-          indent() << "oprot.trans.flush(callback=callback)" << endl;
-      } else {
-        f_service_ <<
-          indent() << "oprot.trans.flush()" << endl;
-      }
     } else {
       f_service_ <<
         indent() << "args.write(self._oprot)" << endl <<
@@ -1426,11 +1405,10 @@
         f_service_ <<
           indent() << "d = self._reqs.pop(rseqid)" << endl;
       } else if (gen_tornado_) {
-        f_service_ <<
-          indent() << "callback = self._reqs.pop(rseqid)" << endl;
       } else {
         f_service_ <<
-          indent() << "(fname, mtype, rseqid) = self._iprot.readMessageBegin()" << endl;
+          indent() << "iprot = self._iprot" << endl <<
+          indent() << "(fname, mtype, rseqid) = iprot.readMessageBegin()" << endl;
       }
 
       f_service_ <<
@@ -1445,23 +1423,14 @@
           indent() << "result = " << resultname << "()" << endl <<
           indent() << "result.read(iprot)" << endl <<
           indent() << "iprot.readMessageEnd()" << endl;
-      } else if (gen_tornado_) {
+      } else {
         f_service_ <<
           indent() << "  x.read(iprot)" << endl <<
           indent() << "  iprot.readMessageEnd()" << endl <<
-          indent() << "  callback(x)" << endl <<
-          indent() << "  return" << endl <<
+          indent() << "  raise x" << endl <<
           indent() << "result = " << resultname << "()" << endl <<
           indent() << "result.read(iprot)" << endl <<
           indent() << "iprot.readMessageEnd()" << endl;
-      } else {
-        f_service_ <<
-          indent() << "  x.read(self._iprot)" << endl <<
-          indent() << "  self._iprot.readMessageEnd()" << endl <<
-          indent() << "  raise x" << endl <<
-          indent() << "result = " << resultname << "()" << endl <<
-          indent() << "result.read(self._iprot)" << endl <<
-          indent() << "self._iprot.readMessageEnd()" << endl;
       }
 
       // Careful, only return _result if not a void function
@@ -1471,10 +1440,6 @@
           if (gen_twisted_) {
             f_service_ <<
               indent() << "  return d.callback(result.success)" << endl;
-          } else if (gen_tornado_) {
-            f_service_ <<
-              indent() << "  callback(result.success)" << endl <<
-              indent() << "  return" << endl;
           } else {
             f_service_ <<
               indent() << "  return result.success" << endl;
@@ -1490,11 +1455,6 @@
           if (gen_twisted_) {
             f_service_ <<
               indent() << "  return d.errback(result." << (*x_iter)->get_name() << ")" << endl;
-
-          } else if (gen_tornado_) {
-            f_service_ <<
-              indent() << "  callback(result." << (*x_iter)->get_name() << ")" << endl <<
-              indent() << "  return" << endl;
           } else {
             f_service_ <<
               indent() << "  raise result." << (*x_iter)->get_name() << "" << endl;
@@ -1506,10 +1466,6 @@
         if (gen_twisted_) {
           f_service_ <<
             indent() << "return d.callback(None)" << endl;
-        } else if (gen_tornado_) {
-          f_service_ <<
-            indent() << "callback(None)" << endl <<
-            indent() << "return" << endl;
         } else {
           f_service_ <<
             indent() << "return" << endl;
@@ -1518,10 +1474,6 @@
         if (gen_twisted_) {
           f_service_ <<
             indent() << "return d.errback(TApplicationException(TApplicationException.MISSING_RESULT, \"" << (*f_iter)->get_name() << " failed: unknown result\"))" << endl;
-        } else if (gen_tornado_) {
-          f_service_ <<
-            indent() << "callback(TApplicationException(TApplicationException.MISSING_RESULT, \"" << (*f_iter)->get_name() << " failed: unknown result\"))" << endl <<
-            indent() << "return" << endl;
         } else {
           f_service_ <<
             indent() << "raise TApplicationException(TApplicationException.MISSING_RESULT, \"" << (*f_iter)->get_name() << " failed: unknown result\");" << endl;
@@ -1785,22 +1737,9 @@
   f_service_ << endl;
 
   // Generate the server implementation
-  if (gen_tornado_) {
-    f_service_ <<
-      indent() << "@gen.engine" << endl <<
-      indent() << "def process(self, transport, iprot_factory, oprot, callback):" << endl;
-    indent_up();
-    f_service_ <<
-      indent() << "# wait for a frame header" << endl <<
-      indent() << "frame = yield gen.Task(transport.readFrame)" << endl <<
-      indent() << "tr = TTransport.TMemoryBuffer(frame)" << endl <<
-      indent() << "iprot = iprot_factory.getProtocol(tr)" << endl <<
-      endl;
-  } else {
-    f_service_ <<
-      indent() << "def process(self, iprot, oprot):" << endl;
-    indent_up();
-  }
+  f_service_ <<
+    indent() << "def process(self, iprot, oprot):" << endl;
+  indent_up();
 
   f_service_ <<
     indent() << "(name, type, seqid) = iprot.readMessageBegin()" << endl;
@@ -1821,8 +1760,6 @@
   if (gen_twisted_) {
     f_service_ <<
       indent() << "  return defer.succeed(None)" << endl;
-  } else if (gen_tornado_) {
-    // nothing
   } else {
     f_service_ <<
       indent() << "  return" << endl;
@@ -1831,13 +1768,9 @@
   f_service_ <<
     indent() << "else:" << endl;
 
-  if (gen_twisted_) {
+  if (gen_twisted_ || gen_tornado_) {
     f_service_ <<
       indent() << "  return self._processMap[name](self, seqid, iprot, oprot)" << endl;
-  } else if (gen_tornado_) {
-    f_service_ <<
-      indent() << "  yield gen.Task(self._processMap[name], self, seqid, iprot, oprot)" << endl <<
-      indent() << "callback()" << endl;
   } else {
     f_service_ <<
       indent() << "  self._processMap[name](self, seqid, iprot, oprot)" << endl;
@@ -1870,9 +1803,9 @@
   // Open function
   if (gen_tornado_) {
     f_service_ <<
-      indent() << "@gen.engine" << endl <<
+      indent() << "@gen.coroutine" << endl <<
       indent() << "def process_" << tfunction->get_name() <<
-                  "(self, seqid, iprot, oprot, callback):" << endl;
+                  "(self, seqid, iprot, oprot):" << endl;
   } else {
     f_service_ <<
       indent() << "def process_" << tfunction->get_name() <<
@@ -1996,6 +1929,7 @@
     }
 
   } else if (gen_tornado_) {
+    /*
     if (!tfunction->is_oneway() && xceptions.size() > 0) {
       f_service_ <<
         endl <<
@@ -2014,21 +1948,27 @@
 
       f_service_ <<
         endl <<
-        indent() << "with stack_context.ExceptionStackContext(handle_exception):" << endl;
+        indent() << "try:" << endl;
       indent_up();
     }
+    */
 
     // Generate the function call
     t_struct* arg_struct = tfunction->get_arglist();
     const std::vector<t_field*>& fields = arg_struct->get_members();
     vector<t_field*>::const_iterator f_iter;
 
+    if (xceptions.size() > 0) {
+      f_service_ <<
+        indent() << "try:" << endl;
+      indent_up();
+    }
     f_service_ << indent();
     if (!tfunction->is_oneway() && !tfunction->get_returntype()->is_void()) {
       f_service_ << "result.success = ";
     }
     f_service_ <<
-      "yield gen.Task(self._handler." << tfunction->get_name() << ", ";
+      "yield gen.maybe_future(self._handler." << tfunction->get_name() << "(";
     bool first = true;
     for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
       if (first) {
@@ -2038,22 +1978,27 @@
       }
       f_service_ << "args." << (*f_iter)->get_name();
     }
-    f_service_ << ")" << endl;
-
-    if (xceptions.size() > 0) {
-      f_service_ << endl;
-    }
+    f_service_ << "))" << endl;
 
     if (!tfunction->is_oneway() && xceptions.size() > 0) {
       indent_down();
+      for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) {
+        f_service_ <<
+          indent() << "except " << type_name((*x_iter)->get_type()) << ", " << (*x_iter)->get_name() << ":" << endl;
+        if (!tfunction->is_oneway()) {
+          indent_up();
+          f_service_ <<
+            indent() << "result." << (*x_iter)->get_name() << " = " << (*x_iter)->get_name() << endl;
+          indent_down();
+        } else {
+          f_service_ <<
+            indent() << "pass" << endl;
+        }
+      }
     }
 
     // Shortcut out here for oneway functions
     if (tfunction->is_oneway()) {
-      f_service_ <<
-        indent() << "callback()" << endl;
-      indent_down();
-      f_service_ << endl;
       return;
     }
 
@@ -2061,8 +2006,7 @@
       indent() << "oprot.writeMessageBegin(\"" << tfunction->get_name() << "\", TMessageType.REPLY, seqid)" << endl <<
       indent() << "result.write(oprot)" << endl <<
       indent() << "oprot.writeMessageEnd()" << endl <<
-      indent() << "oprot.trans.flush()" << endl <<
-      indent() << "callback()" << endl;
+      indent() << "oprot.trans.flush()" << endl;
 
     // Close function
     indent_down();
@@ -2621,8 +2565,7 @@
  * @return String of rendered function definition
  */
 string t_py_generator::function_signature(t_function* tfunction,
-                                          bool interface,
-                                          tornado_callback_t callback) {
+                                          bool interface) {
   vector<string> pre;
   vector<string> post;
   string signature = tfunction->get_name() + "(";
@@ -2631,22 +2574,6 @@
     pre.push_back("self");
   }
 
-  if (gen_tornado_) {
-    if (callback == NONE) {
-    } else if (callback == MANDATORY_FOR_ONEWAY_ELSE_NONE) {
-      if (tfunction->is_oneway()) {
-        // Tornado send_* carry the callback so you can block on the write's flush
-        // (rather than on receipt of the response)
-        post.push_back("callback");
-      }
-    } else if (callback == OPTIONAL_FOR_ONEWAY_ELSE_MANDATORY) {
-      if (tfunction->is_oneway()) {
-        post.push_back("callback=None");
-      } else {
-        post.push_back("callback");
-      }
-    }
-  }
   signature += argument_list(tfunction->get_arglist(), &pre, &post) + ")";
   return signature;
 }
diff --git a/lib/py/src/TTornado.py b/lib/py/src/TTornado.py
index d90f672..c8498c5 100644
--- a/lib/py/src/TTornado.py
+++ b/lib/py/src/TTornado.py
@@ -17,58 +17,91 @@
 # under the License.
 #
 
-from cStringIO import StringIO
+from __future__ import absolute_import
 import logging
 import socket
 import struct
 
-from thrift.transport import TTransport
-from thrift.transport.TTransport import TTransportException
+from thrift.transport.TTransport import TTransportException, TTransportBase, TMemoryBuffer
 
-from tornado import gen
-from tornado import iostream
-from tornado import tcpserver
+from io import BytesIO
+from collections import deque
+from contextlib import contextmanager
+from tornado import gen, iostream, ioloop, tcpserver, concurrent
+
+__all__ = ['TTornadoServer', 'TTornadoStreamTransport']
 
 
-class TTornadoStreamTransport(TTransport.TTransportBase):
+class _Lock(object):
+    def __init__(self):
+        self._waiters = deque()
+
+    def acquired(self):
+        return len(self._waiters) > 0
+
+    @gen.coroutine
+    def acquire(self):
+        blocker = self._waiters[-1] if self.acquired() else None
+        future = concurrent.Future()
+        self._waiters.append(future)
+        if blocker:
+            yield blocker
+
+        raise gen.Return(self._lock_context())
+
+    def release(self):
+        assert self.acquired(), 'Lock not aquired'
+        future = self._waiters.popleft()
+        future.set_result(None)
+
+    @contextmanager
+    def _lock_context(self):
+        try:
+            yield
+        finally:
+            self.release()
+
+
+class TTornadoStreamTransport(TTransportBase):
     """a framed, buffered transport over a Tornado stream"""
-    def __init__(self, host, port, stream=None):
+    def __init__(self, host, port, stream=None, io_loop=None):
         self.host = host
         self.port = port
-        self.is_queuing_reads = False
-        self.read_queue = []
-        self.__wbuf = StringIO()
+        self.io_loop = io_loop or ioloop.IOLoop.current()
+        self.__wbuf = BytesIO()
+        self._read_lock = _Lock()
 
         # servers provide a ready-to-go stream
         self.stream = stream
-        if self.stream is not None:
-            self._set_close_callback()
 
-    # not the same number of parameters as TTransportBase.open
-    def open(self, callback):
+    def with_timeout(self, timeout, future):
+        return gen.with_timeout(timeout, future, self.io_loop)
+
+    @gen.coroutine
+    def open(self, timeout=None):
         logging.debug('socket connecting')
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
         self.stream = iostream.IOStream(sock)
 
-        def on_close_in_connect(*_):
-            message = 'could not connect to {}:{}'.format(self.host, self.port)
+        try:
+            connect = self.stream.connect((self.host, self.port))
+            if timeout is not None:
+                yield self.with_timeout(timeout, connect)
+            else:
+                yield connect
+        except (socket.error, IOError, ioloop.TimeoutError) as e:
+            message = 'could not connect to {}:{} ({})'.format(self.host, self.port, e)
             raise TTransportException(
                 type=TTransportException.NOT_OPEN,
                 message=message)
-        self.stream.set_close_callback(on_close_in_connect)
 
-        def finish(*_):
-            self._set_close_callback()
-            callback()
+        raise gen.Return(self)
 
-        self.stream.connect((self.host, self.port), callback=finish)
-
-    def _set_close_callback(self):
-        def on_close():
-            raise TTransportException(
-                type=TTransportException.END_OF_FILE,
-                message='socket closed')
-        self.stream.set_close_callback(self.close)
+    def set_close_callback(self, callback):
+        """
+        Should be called only after open() returns
+        """
+        self.stream.set_close_callback(callback)
 
     def close(self):
         # don't raise if we intend to close
@@ -78,51 +111,45 @@
     def read(self, _):
         # The generated code for Tornado shouldn't do individual reads -- only
         # frames at a time
-        assert "you're doing it wrong" is True
+        assert False, "you're doing it wrong"
 
-    @gen.engine
-    def readFrame(self, callback):
-        self.read_queue.append(callback)
-        logging.debug('read queue: %s', self.read_queue)
+    @contextmanager
+    def io_exception_context(self):
+        try:
+            yield
+        except (socket.error, IOError) as e:
+            raise TTransportException(
+                type=TTransportException.END_OF_FILE,
+                message=str(e))
+        except iostream.StreamBufferFullError as e:
+            raise TTransportException(
+                type=TTransportException.UNKNOWN,
+                message=str(e))
 
-        if self.is_queuing_reads:
-            # If a read is already in flight, then the while loop below should
-            # pull it from self.read_queue
-            return
-
-        self.is_queuing_reads = True
-        while self.read_queue:
-            next_callback = self.read_queue.pop()
-            result = yield gen.Task(self._readFrameFromStream)
-            next_callback(result)
-        self.is_queuing_reads = False
-
-    @gen.engine
-    def _readFrameFromStream(self, callback):
-        logging.debug('_readFrameFromStream')
-        frame_header = yield gen.Task(self.stream.read_bytes, 4)
-        frame_length, = struct.unpack('!i', frame_header)
-        logging.debug('received frame header, frame length = %i', frame_length)
-        frame = yield gen.Task(self.stream.read_bytes, frame_length)
-        logging.debug('received frame payload')
-        callback(frame)
+    @gen.coroutine
+    def readFrame(self):
+        # IOStream processes reads one at a time
+        with (yield self._read_lock.acquire()):
+            with self.io_exception_context():
+                frame_header = yield self.stream.read_bytes(4)
+                if len(frame_header) == 0:
+                    raise iostream.StreamClosedError('Read zero bytes from stream')
+                frame_length, = struct.unpack('!i', frame_header)
+                logging.debug('received frame header, frame length = %d', frame_length)
+                frame = yield self.stream.read_bytes(frame_length)
+                logging.debug('received frame payload: %r', frame)
+                raise gen.Return(frame)
 
     def write(self, buf):
         self.__wbuf.write(buf)
 
-    def flush(self, callback=None):
-        wout = self.__wbuf.getvalue()
-        wsz = len(wout)
+    def flush(self):
+        frame = self.__wbuf.getvalue()
         # reset wbuf before write/flush to preserve state on underlying failure
-        self.__wbuf = StringIO()
-        # N.B.: Doing this string concatenation is WAY cheaper than making
-        # two separate calls to the underlying socket object. Socket writes in
-        # Python turn out to be REALLY expensive, but it seems to do a pretty
-        # good job of managing string buffer operations without excessive copies
-        buf = struct.pack("!i", wsz) + wout
-
-        logging.debug('writing frame length = %i', wsz)
-        self.stream.write(buf, callback)
+        frame_length = struct.pack('!i', len(frame))
+        self.__wbuf = BytesIO()
+        with self.io_exception_context():
+            return self.stream.write(frame_length + frame)
 
 
 class TTornadoServer(tcpserver.TCPServer):
@@ -135,19 +162,21 @@
         self._oprot_factory = (oprot_factory if oprot_factory is not None
                                else iprot_factory)
 
+    @gen.coroutine
     def handle_stream(self, stream, address):
+        host, port = address
+        trans = TTornadoStreamTransport(host=host, port=port, stream=stream,
+                                        io_loop=self.io_loop)
+        oprot = self._oprot_factory.getProtocol(trans)
+
         try:
-            host, port = address
-            trans = TTornadoStreamTransport(host=host, port=port, stream=stream)
-            oprot = self._oprot_factory.getProtocol(trans)
-
-            def next_pass():
-                if not trans.stream.closed():
-                    self._processor.process(trans, self._iprot_factory, oprot,
-                                            callback=next_pass)
-
-            next_pass()
-
+            while not trans.stream.closed():
+                frame = yield trans.readFrame()
+                tr = TMemoryBuffer(frame)
+                iprot = self._iprot_factory.getProtocol(tr)
+                yield self._processor.process(iprot, oprot)
         except Exception:
             logging.exception('thrift exception in handle_stream')
             trans.close()
+
+        logging.info('client disconnected %s:%d', host, port)
diff --git a/test/py.tornado/test_suite.py b/test/py.tornado/test_suite.py
index f04ba04..c783962 100755
--- a/test/py.tornado/test_suite.py
+++ b/test/py.tornado/test_suite.py
@@ -22,11 +22,13 @@
 import datetime
 import glob
 import sys
+import os
 import time
 import unittest
 
-sys.path.insert(0, './gen-py.tornado')
-sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0])
+basepath = os.path.abspath(os.path.dirname(__file__))
+sys.path.insert(0, basepath+'/gen-py.tornado')
+sys.path.insert(0, glob.glob(os.path.join(basepath, '../../lib/py/build/lib.*'))[0])
 
 try:
     __import__('tornado')
@@ -34,11 +36,12 @@
     print "module `tornado` not found, skipping test"
     sys.exit(0)
 
-from tornado import gen, ioloop, stack_context
-from tornado.testing import AsyncTestCase, get_unused_port
+from tornado import gen
+from tornado.testing import AsyncTestCase, get_unused_port, gen_test
 
 from thrift import TTornado
 from thrift.protocol import TBinaryProtocol
+from thrift.transport.TTransport import TTransportException
 
 from ThriftTest import ThriftTest
 from ThriftTest.ttypes import *
@@ -48,31 +51,31 @@
     def __init__(self, test_instance):
         self.test_instance = test_instance
 
-    def testVoid(self, callback):
-        callback()
+    def testVoid(self):
+        pass
 
-    def testString(self, s, callback):
-        callback(s)
+    def testString(self, s):
+        return s
 
-    def testByte(self, b, callback):
-        callback(b)
+    def testByte(self, b):
+        return b
 
-    def testI16(self, i16, callback):
-        callback(i16)
+    def testI16(self, i16):
+        return i16
 
-    def testI32(self, i32, callback):
-        callback(i32)
+    def testI32(self, i32):
+        return i32
 
-    def testI64(self, i64, callback):
-        callback(i64)
+    def testI64(self, i64):
+        return i64
 
-    def testDouble(self, dub, callback):
-        callback(dub)
+    def testDouble(self, dub):
+        return dub
 
-    def testStruct(self, thing, callback):
-        callback(thing)
+    def testStruct(self, thing):
+        return thing
 
-    def testException(self, s, callback):
+    def testException(self, s):
         if s == 'Xception':
             x = Xception()
             x.errorCode = 1001
@@ -80,133 +83,139 @@
             raise x
         elif s == 'throw_undeclared':
             raise ValueError("foo")
-        callback()
 
-    def testOneway(self, seconds, callback=None):
+    def testOneway(self, seconds):
         start = time.time()
+
         def fire_oneway():
             end = time.time()
             self.test_instance.stop((start, end, seconds))
 
-        ioloop.IOLoop.instance().add_timeout(
+        self.test_instance.io_loop.add_timeout(
             datetime.timedelta(seconds=seconds),
             fire_oneway)
 
-        if callback:
-            callback()
+    def testNest(self, thing):
+        return thing
 
-    def testNest(self, thing, callback):
-        callback(thing)
+    @gen.coroutine
+    def testMap(self, thing):
+        yield gen.moment
+        raise gen.Return(thing)
 
-    def testMap(self, thing, callback):
-        callback(thing)
+    def testSet(self, thing):
+        return thing
 
-    def testSet(self, thing, callback):
-        callback(thing)
+    def testList(self, thing):
+        return thing
 
-    def testList(self, thing, callback):
-        callback(thing)
+    def testEnum(self, thing):
+        return thing
 
-    def testEnum(self, thing, callback):
-        callback(thing)
-
-    def testTypedef(self, thing, callback):
-        callback(thing)
+    def testTypedef(self, thing):
+        return thing
 
 
 class ThriftTestCase(AsyncTestCase):
-    def get_new_ioloop(self):
-        return ioloop.IOLoop.instance()
-
     def setUp(self):
+        super(ThriftTestCase, self).setUp()
+
         self.port = get_unused_port()
-        self.io_loop = self.get_new_ioloop()
 
         # server
         self.handler = TestHandler(self)
         self.processor = ThriftTest.Processor(self.handler)
         self.pfactory = TBinaryProtocol.TBinaryProtocolFactory()
 
-        self.server = TTornado.TTornadoServer(self.processor, self.pfactory)
+        self.server = TTornado.TTornadoServer(self.processor, self.pfactory, io_loop=self.io_loop)
         self.server.bind(self.port)
         self.server.start(1)
 
         # client
-        transport = TTornado.TTornadoStreamTransport('localhost', self.port)
+        transport = TTornado.TTornadoStreamTransport('localhost', self.port, io_loop=self.io_loop)
         pfactory = TBinaryProtocol.TBinaryProtocolFactory()
+        self.io_loop.run_sync(transport.open)
         self.client = ThriftTest.Client(transport, pfactory)
-        transport.open(callback=self.stop)
-        self.wait(timeout=1)
 
+    @gen_test
     def test_void(self):
-        self.client.testVoid(callback=self.stop)
-        v = self.wait(timeout=1)
-        self.assertEquals(v, None)
+        v = yield self.client.testVoid()
+        self.assertEqual(v, None)
 
+    @gen_test
     def test_string(self):
-        self.client.testString('Python', callback=self.stop)
-        v = self.wait(timeout=1)
-        self.assertEquals(v, 'Python')
+        v = yield self.client.testString('Python')
+        self.assertEqual(v, 'Python')
 
+    @gen_test
     def test_byte(self):
-        self.client.testByte(63, callback=self.stop)
-        v = self.wait(timeout=1)
-        self.assertEquals(v, 63)
+        v = yield self.client.testByte(63)
+        self.assertEqual(v, 63)
 
+    @gen_test
     def test_i32(self):
-        self.client.testI32(-1, callback=self.stop)
-        v = self.wait(timeout=1)
-        self.assertEquals(v, -1)
+        v = yield self.client.testI32(-1)
+        self.assertEqual(v, -1)
 
-        self.client.testI32(0, callback=self.stop)
-        v = self.wait(timeout=1)
-        self.assertEquals(v, 0)
+        v = yield self.client.testI32(0)
+        self.assertEqual(v, 0)
 
+    @gen_test
     def test_i64(self):
-        self.client.testI64(-34359738368, callback=self.stop)
-        v = self.wait(timeout=1)
-        self.assertEquals(v, -34359738368)
+        v = yield self.client.testI64(-34359738368)
+        self.assertEqual(v, -34359738368)
 
+    @gen_test
     def test_double(self):
-        self.client.testDouble(-5.235098235, callback=self.stop)
-        v = self.wait(timeout=1)
-        self.assertEquals(v, -5.235098235)
+        v = yield self.client.testDouble(-5.235098235)
+        self.assertEqual(v, -5.235098235)
 
+    @gen_test
     def test_struct(self):
         x = Xtruct()
         x.string_thing = "Zero"
         x.byte_thing = 1
         x.i32_thing = -3
         x.i64_thing = -5
-        self.client.testStruct(x, callback=self.stop)
+        y = yield self.client.testStruct(x)
 
-        y = self.wait(timeout=1)
-        self.assertEquals(y.string_thing, "Zero")
-        self.assertEquals(y.byte_thing, 1)
-        self.assertEquals(y.i32_thing, -3)
-        self.assertEquals(y.i64_thing, -5)
-
-    def test_exception(self):
-        self.client.testException('Safe', callback=self.stop)
-        v = self.wait(timeout=1)
-
-        self.client.testException('Xception', callback=self.stop)
-        ex = self.wait(timeout=1)
-        if type(ex) == Xception:
-            self.assertEquals(ex.errorCode, 1001)
-            self.assertEquals(ex.message, 'Xception')
-        else:
-            self.fail("should have gotten exception")
+        self.assertEqual(y.string_thing, "Zero")
+        self.assertEqual(y.byte_thing, 1)
+        self.assertEqual(y.i32_thing, -3)
+        self.assertEqual(y.i64_thing, -5)
 
     def test_oneway(self):
-        def return_from_send():
-            self.stop('done with send')
-        self.client.testOneway(0.5, callback=return_from_send)
-        self.assertEquals(self.wait(timeout=1), 'done with send')
-
+        self.client.testOneway(0.5)
         start, end, seconds = self.wait(timeout=1)
         self.assertAlmostEquals(seconds, (end - start), places=3)
 
+    @gen_test
+    def test_map(self):
+        """
+        TestHandler.testMap is a coroutine, this test checks if gen.Return() from a coroutine works.
+        """
+        expected = {1: 1}
+        res = yield self.client.testMap(expected)
+        self.assertEqual(res, expected)
+
+    @gen_test
+    def test_exception(self):
+        yield self.client.testException('Safe')
+
+        try:
+            yield self.client.testException('Xception')
+        except Xception as ex:
+            self.assertEqual(ex.errorCode, 1001)
+            self.assertEqual(ex.message, 'Xception')
+        else:
+            self.fail("should have gotten exception")
+        try:
+            yield self.client.testException('throw_undeclared')
+        except TTransportException as ex:
+            pass
+        else:
+            self.fail("should have gotten exception")
+
 
 def suite():
     suite = unittest.TestSuite()