THRIFT-2231 Support tornado-4.x (Python)
Client: Python
Patch: Roey Berman
Signed-off-by: Roger Meier <roger@apache.org>
diff --git a/compiler/cpp/src/generate/t_py_generator.cc b/compiler/cpp/src/generate/t_py_generator.cc
index 4d155e0..5ac92c3 100644
--- a/compiler/cpp/src/generate/t_py_generator.cc
+++ b/compiler/cpp/src/generate/t_py_generator.cc
@@ -223,17 +223,6 @@
t_doc* tdoc);
/**
- * a type for specifying to function_signature what type of Tornado callback
- * parameter to add
- */
-
- enum tornado_callback_t {
- NONE = 0,
- MANDATORY_FOR_ONEWAY_ELSE_NONE = 1,
- OPTIONAL_FOR_ONEWAY_ELSE_MANDATORY = 2,
- };
-
- /**
* Helper rendering functions
*/
@@ -245,8 +234,7 @@
std::string render_field_default_value(t_field* tfield);
std::string type_name(t_type* ttype);
std::string function_signature(t_function* tfunction,
- bool interface=false,
- tornado_callback_t callback=NONE);
+ bool interface=false);
std::string argument_list(t_struct* tstruct,
std::vector<std::string> *pre=NULL,
std::vector<std::string> *post=NULL);
@@ -1067,7 +1055,8 @@
"from thrift.transport import TTwisted" << endl;
} else if (gen_tornado_) {
f_service_ << "from tornado import gen" << endl;
- f_service_ << "from tornado import stack_context" << endl;
+ f_service_ << "from tornado import concurrent" << endl;
+ f_service_ << "from thrift.transport import TTransport" << endl;
}
f_service_ << endl;
@@ -1156,7 +1145,7 @@
vector<t_function*>::iterator f_iter;
for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) {
f_service_ <<
- indent() << "def " << function_signature(*f_iter, true, OPTIONAL_FOR_ONEWAY_ELSE_MANDATORY) << ":" << endl;
+ indent() << "def " << function_signature(*f_iter, true) << ":" << endl;
indent_up();
generate_python_docstring(f_service_, (*f_iter));
f_service_ <<
@@ -1229,6 +1218,7 @@
indent() << " else iprot_factory)" << endl <<
indent() << " self._seqid = 0" << endl <<
indent() << " self._reqs = {}" << endl <<
+ indent() << " self._transport.io_loop.spawn_callback(self._start_receiving)" << endl <<
endl;
} else {
f_service_ <<
@@ -1257,18 +1247,29 @@
if (gen_tornado_ && extends.empty()) {
f_service_ <<
indent() << "@gen.engine" << endl <<
- indent() << "def recv_dispatch(self):" << endl <<
- indent() << " \"\"\"read a response from the wire. schedule exactly one per send that" << endl <<
- indent() << " expects a response, but it doesn't matter which callee gets which" << endl <<
- indent() << " response; they're dispatched here properly\"\"\"" << endl <<
- endl <<
- indent() << " # wait for a frame header" << endl <<
- indent() << " frame = yield gen.Task(self._transport.readFrame)" << endl <<
- indent() << " tr = TTransport.TMemoryBuffer(frame)" << endl <<
- indent() << " iprot = self._iprot_factory.getProtocol(tr)" << endl <<
- indent() << " (fname, mtype, rseqid) = iprot.readMessageBegin()" << endl <<
- indent() << " method = getattr(self, 'recv_' + fname)" << endl <<
- indent() << " method(iprot, mtype, rseqid)" << endl <<
+ indent() << "def _start_receiving(self):" << endl <<
+ indent() << " while True:" << endl <<
+ indent() << " try:" << endl <<
+ indent() << " frame = yield self._transport.readFrame()" << endl <<
+ indent() << " except TTransport.TTransportException as e:" << endl <<
+ indent() << " for future in self._reqs.itervalues():" << endl <<
+ indent() << " future.set_exception(e)" << endl <<
+ indent() << " self._reqs = {}" << endl <<
+ indent() << " return" << endl <<
+ indent() << " tr = TTransport.TMemoryBuffer(frame)" << endl <<
+ indent() << " iprot = self._iprot_factory.getProtocol(tr)" << endl <<
+ indent() << " (fname, mtype, rseqid) = iprot.readMessageBegin()" << endl <<
+ indent() << " method = getattr(self, 'recv_' + fname)" << endl <<
+ indent() << " future = self._reqs.pop(rseqid, None)" << endl <<
+ indent() << " if not future:" << endl <<
+ indent() << " # future has already been discarded" << endl <<
+ indent() << " continue" << endl <<
+ indent() << " try:" << endl <<
+ indent() << " result = method(iprot, mtype, rseqid)" << endl <<
+ indent() << " except Exception as e:" << endl <<
+ indent() << " future.set_exception(e)" << endl <<
+ indent() << " else:" << endl <<
+ indent() << " future.set_result(result)" << endl <<
endl;
}
@@ -1283,7 +1284,7 @@
// Open function
indent(f_service_) <<
- "def " << function_signature(*f_iter, false, OPTIONAL_FOR_ONEWAY_ELSE_MANDATORY) << ":" << endl;
+ "def " << function_signature(*f_iter, false) << ":" << endl;
indent_up();
generate_python_docstring(f_service_, (*f_iter));
if (gen_twisted_) {
@@ -1296,7 +1297,7 @@
indent(f_service_) << "self._seqid += 1" << endl;
if (!(*f_iter)->is_oneway()) {
indent(f_service_) <<
- "self._reqs[self._seqid] = callback" << endl;
+ "future = self._reqs[self._seqid] = concurrent.Future()" << endl;
}
}
@@ -1313,15 +1314,6 @@
f_service_ << (*fld_iter)->get_name();
}
- if (gen_tornado_ && (*f_iter)->is_oneway()) {
- if (first) {
- first = false;
- } else {
- f_service_ << ", ";
- }
- f_service_ << "callback";
- }
-
f_service_ << ")" << endl;
if (!(*f_iter)->is_oneway()) {
@@ -1329,7 +1321,7 @@
if (gen_twisted_) {
f_service_ << "return d" << endl;
} else if (gen_tornado_) {
- f_service_ << "self.recv_dispatch()" << endl;
+ f_service_ << "return future" << endl;
} else {
if (!(*f_iter)->get_returntype()->is_void()) {
f_service_ << "return ";
@@ -1347,7 +1339,7 @@
f_service_ << endl;
indent(f_service_) <<
- "def send_" << function_signature(*f_iter, false, MANDATORY_FOR_ONEWAY_ELSE_NONE) << ":" << endl;
+ "def send_" << function_signature(*f_iter, false) << ":" << endl;
indent_up();
@@ -1374,24 +1366,11 @@
}
// Write to the stream
- if (gen_twisted_) {
+ if (gen_twisted_ || gen_tornado_) {
f_service_ <<
indent() << "args.write(oprot)" << endl <<
indent() << "oprot.writeMessageEnd()" << endl <<
indent() << "oprot.trans.flush()" << endl;
- } else if (gen_tornado_) {
- f_service_ <<
- indent() << "args.write(oprot)" << endl <<
- indent() << "oprot.writeMessageEnd()" << endl;
- if ((*f_iter)->is_oneway()) {
- // send_* carry the callback so you can block on the write's flush
- // (rather than on receipt of the response)
- f_service_ <<
- indent() << "oprot.trans.flush(callback=callback)" << endl;
- } else {
- f_service_ <<
- indent() << "oprot.trans.flush()" << endl;
- }
} else {
f_service_ <<
indent() << "args.write(self._oprot)" << endl <<
@@ -1426,11 +1405,10 @@
f_service_ <<
indent() << "d = self._reqs.pop(rseqid)" << endl;
} else if (gen_tornado_) {
- f_service_ <<
- indent() << "callback = self._reqs.pop(rseqid)" << endl;
} else {
f_service_ <<
- indent() << "(fname, mtype, rseqid) = self._iprot.readMessageBegin()" << endl;
+ indent() << "iprot = self._iprot" << endl <<
+ indent() << "(fname, mtype, rseqid) = iprot.readMessageBegin()" << endl;
}
f_service_ <<
@@ -1445,23 +1423,14 @@
indent() << "result = " << resultname << "()" << endl <<
indent() << "result.read(iprot)" << endl <<
indent() << "iprot.readMessageEnd()" << endl;
- } else if (gen_tornado_) {
+ } else {
f_service_ <<
indent() << " x.read(iprot)" << endl <<
indent() << " iprot.readMessageEnd()" << endl <<
- indent() << " callback(x)" << endl <<
- indent() << " return" << endl <<
+ indent() << " raise x" << endl <<
indent() << "result = " << resultname << "()" << endl <<
indent() << "result.read(iprot)" << endl <<
indent() << "iprot.readMessageEnd()" << endl;
- } else {
- f_service_ <<
- indent() << " x.read(self._iprot)" << endl <<
- indent() << " self._iprot.readMessageEnd()" << endl <<
- indent() << " raise x" << endl <<
- indent() << "result = " << resultname << "()" << endl <<
- indent() << "result.read(self._iprot)" << endl <<
- indent() << "self._iprot.readMessageEnd()" << endl;
}
// Careful, only return _result if not a void function
@@ -1471,10 +1440,6 @@
if (gen_twisted_) {
f_service_ <<
indent() << " return d.callback(result.success)" << endl;
- } else if (gen_tornado_) {
- f_service_ <<
- indent() << " callback(result.success)" << endl <<
- indent() << " return" << endl;
} else {
f_service_ <<
indent() << " return result.success" << endl;
@@ -1490,11 +1455,6 @@
if (gen_twisted_) {
f_service_ <<
indent() << " return d.errback(result." << (*x_iter)->get_name() << ")" << endl;
-
- } else if (gen_tornado_) {
- f_service_ <<
- indent() << " callback(result." << (*x_iter)->get_name() << ")" << endl <<
- indent() << " return" << endl;
} else {
f_service_ <<
indent() << " raise result." << (*x_iter)->get_name() << "" << endl;
@@ -1506,10 +1466,6 @@
if (gen_twisted_) {
f_service_ <<
indent() << "return d.callback(None)" << endl;
- } else if (gen_tornado_) {
- f_service_ <<
- indent() << "callback(None)" << endl <<
- indent() << "return" << endl;
} else {
f_service_ <<
indent() << "return" << endl;
@@ -1518,10 +1474,6 @@
if (gen_twisted_) {
f_service_ <<
indent() << "return d.errback(TApplicationException(TApplicationException.MISSING_RESULT, \"" << (*f_iter)->get_name() << " failed: unknown result\"))" << endl;
- } else if (gen_tornado_) {
- f_service_ <<
- indent() << "callback(TApplicationException(TApplicationException.MISSING_RESULT, \"" << (*f_iter)->get_name() << " failed: unknown result\"))" << endl <<
- indent() << "return" << endl;
} else {
f_service_ <<
indent() << "raise TApplicationException(TApplicationException.MISSING_RESULT, \"" << (*f_iter)->get_name() << " failed: unknown result\");" << endl;
@@ -1785,22 +1737,9 @@
f_service_ << endl;
// Generate the server implementation
- if (gen_tornado_) {
- f_service_ <<
- indent() << "@gen.engine" << endl <<
- indent() << "def process(self, transport, iprot_factory, oprot, callback):" << endl;
- indent_up();
- f_service_ <<
- indent() << "# wait for a frame header" << endl <<
- indent() << "frame = yield gen.Task(transport.readFrame)" << endl <<
- indent() << "tr = TTransport.TMemoryBuffer(frame)" << endl <<
- indent() << "iprot = iprot_factory.getProtocol(tr)" << endl <<
- endl;
- } else {
- f_service_ <<
- indent() << "def process(self, iprot, oprot):" << endl;
- indent_up();
- }
+ f_service_ <<
+ indent() << "def process(self, iprot, oprot):" << endl;
+ indent_up();
f_service_ <<
indent() << "(name, type, seqid) = iprot.readMessageBegin()" << endl;
@@ -1821,8 +1760,6 @@
if (gen_twisted_) {
f_service_ <<
indent() << " return defer.succeed(None)" << endl;
- } else if (gen_tornado_) {
- // nothing
} else {
f_service_ <<
indent() << " return" << endl;
@@ -1831,13 +1768,9 @@
f_service_ <<
indent() << "else:" << endl;
- if (gen_twisted_) {
+ if (gen_twisted_ || gen_tornado_) {
f_service_ <<
indent() << " return self._processMap[name](self, seqid, iprot, oprot)" << endl;
- } else if (gen_tornado_) {
- f_service_ <<
- indent() << " yield gen.Task(self._processMap[name], self, seqid, iprot, oprot)" << endl <<
- indent() << "callback()" << endl;
} else {
f_service_ <<
indent() << " self._processMap[name](self, seqid, iprot, oprot)" << endl;
@@ -1870,9 +1803,9 @@
// Open function
if (gen_tornado_) {
f_service_ <<
- indent() << "@gen.engine" << endl <<
+ indent() << "@gen.coroutine" << endl <<
indent() << "def process_" << tfunction->get_name() <<
- "(self, seqid, iprot, oprot, callback):" << endl;
+ "(self, seqid, iprot, oprot):" << endl;
} else {
f_service_ <<
indent() << "def process_" << tfunction->get_name() <<
@@ -1996,6 +1929,7 @@
}
} else if (gen_tornado_) {
+ /*
if (!tfunction->is_oneway() && xceptions.size() > 0) {
f_service_ <<
endl <<
@@ -2014,21 +1948,27 @@
f_service_ <<
endl <<
- indent() << "with stack_context.ExceptionStackContext(handle_exception):" << endl;
+ indent() << "try:" << endl;
indent_up();
}
+ */
// Generate the function call
t_struct* arg_struct = tfunction->get_arglist();
const std::vector<t_field*>& fields = arg_struct->get_members();
vector<t_field*>::const_iterator f_iter;
+ if (xceptions.size() > 0) {
+ f_service_ <<
+ indent() << "try:" << endl;
+ indent_up();
+ }
f_service_ << indent();
if (!tfunction->is_oneway() && !tfunction->get_returntype()->is_void()) {
f_service_ << "result.success = ";
}
f_service_ <<
- "yield gen.Task(self._handler." << tfunction->get_name() << ", ";
+ "yield gen.maybe_future(self._handler." << tfunction->get_name() << "(";
bool first = true;
for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
if (first) {
@@ -2038,22 +1978,27 @@
}
f_service_ << "args." << (*f_iter)->get_name();
}
- f_service_ << ")" << endl;
-
- if (xceptions.size() > 0) {
- f_service_ << endl;
- }
+ f_service_ << "))" << endl;
if (!tfunction->is_oneway() && xceptions.size() > 0) {
indent_down();
+ for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) {
+ f_service_ <<
+ indent() << "except " << type_name((*x_iter)->get_type()) << ", " << (*x_iter)->get_name() << ":" << endl;
+ if (!tfunction->is_oneway()) {
+ indent_up();
+ f_service_ <<
+ indent() << "result." << (*x_iter)->get_name() << " = " << (*x_iter)->get_name() << endl;
+ indent_down();
+ } else {
+ f_service_ <<
+ indent() << "pass" << endl;
+ }
+ }
}
// Shortcut out here for oneway functions
if (tfunction->is_oneway()) {
- f_service_ <<
- indent() << "callback()" << endl;
- indent_down();
- f_service_ << endl;
return;
}
@@ -2061,8 +2006,7 @@
indent() << "oprot.writeMessageBegin(\"" << tfunction->get_name() << "\", TMessageType.REPLY, seqid)" << endl <<
indent() << "result.write(oprot)" << endl <<
indent() << "oprot.writeMessageEnd()" << endl <<
- indent() << "oprot.trans.flush()" << endl <<
- indent() << "callback()" << endl;
+ indent() << "oprot.trans.flush()" << endl;
// Close function
indent_down();
@@ -2621,8 +2565,7 @@
* @return String of rendered function definition
*/
string t_py_generator::function_signature(t_function* tfunction,
- bool interface,
- tornado_callback_t callback) {
+ bool interface) {
vector<string> pre;
vector<string> post;
string signature = tfunction->get_name() + "(";
@@ -2631,22 +2574,6 @@
pre.push_back("self");
}
- if (gen_tornado_) {
- if (callback == NONE) {
- } else if (callback == MANDATORY_FOR_ONEWAY_ELSE_NONE) {
- if (tfunction->is_oneway()) {
- // Tornado send_* carry the callback so you can block on the write's flush
- // (rather than on receipt of the response)
- post.push_back("callback");
- }
- } else if (callback == OPTIONAL_FOR_ONEWAY_ELSE_MANDATORY) {
- if (tfunction->is_oneway()) {
- post.push_back("callback=None");
- } else {
- post.push_back("callback");
- }
- }
- }
signature += argument_list(tfunction->get_arglist(), &pre, &post) + ")";
return signature;
}
diff --git a/lib/py/src/TTornado.py b/lib/py/src/TTornado.py
index d90f672..c8498c5 100644
--- a/lib/py/src/TTornado.py
+++ b/lib/py/src/TTornado.py
@@ -17,58 +17,91 @@
# under the License.
#
-from cStringIO import StringIO
+from __future__ import absolute_import
import logging
import socket
import struct
-from thrift.transport import TTransport
-from thrift.transport.TTransport import TTransportException
+from thrift.transport.TTransport import TTransportException, TTransportBase, TMemoryBuffer
-from tornado import gen
-from tornado import iostream
-from tornado import tcpserver
+from io import BytesIO
+from collections import deque
+from contextlib import contextmanager
+from tornado import gen, iostream, ioloop, tcpserver, concurrent
+
+__all__ = ['TTornadoServer', 'TTornadoStreamTransport']
-class TTornadoStreamTransport(TTransport.TTransportBase):
+class _Lock(object):
+ def __init__(self):
+ self._waiters = deque()
+
+ def acquired(self):
+ return len(self._waiters) > 0
+
+ @gen.coroutine
+ def acquire(self):
+ blocker = self._waiters[-1] if self.acquired() else None
+ future = concurrent.Future()
+ self._waiters.append(future)
+ if blocker:
+ yield blocker
+
+ raise gen.Return(self._lock_context())
+
+ def release(self):
+ assert self.acquired(), 'Lock not aquired'
+ future = self._waiters.popleft()
+ future.set_result(None)
+
+ @contextmanager
+ def _lock_context(self):
+ try:
+ yield
+ finally:
+ self.release()
+
+
+class TTornadoStreamTransport(TTransportBase):
"""a framed, buffered transport over a Tornado stream"""
- def __init__(self, host, port, stream=None):
+ def __init__(self, host, port, stream=None, io_loop=None):
self.host = host
self.port = port
- self.is_queuing_reads = False
- self.read_queue = []
- self.__wbuf = StringIO()
+ self.io_loop = io_loop or ioloop.IOLoop.current()
+ self.__wbuf = BytesIO()
+ self._read_lock = _Lock()
# servers provide a ready-to-go stream
self.stream = stream
- if self.stream is not None:
- self._set_close_callback()
- # not the same number of parameters as TTransportBase.open
- def open(self, callback):
+ def with_timeout(self, timeout, future):
+ return gen.with_timeout(timeout, future, self.io_loop)
+
+ @gen.coroutine
+ def open(self, timeout=None):
logging.debug('socket connecting')
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
self.stream = iostream.IOStream(sock)
- def on_close_in_connect(*_):
- message = 'could not connect to {}:{}'.format(self.host, self.port)
+ try:
+ connect = self.stream.connect((self.host, self.port))
+ if timeout is not None:
+ yield self.with_timeout(timeout, connect)
+ else:
+ yield connect
+ except (socket.error, IOError, ioloop.TimeoutError) as e:
+ message = 'could not connect to {}:{} ({})'.format(self.host, self.port, e)
raise TTransportException(
type=TTransportException.NOT_OPEN,
message=message)
- self.stream.set_close_callback(on_close_in_connect)
- def finish(*_):
- self._set_close_callback()
- callback()
+ raise gen.Return(self)
- self.stream.connect((self.host, self.port), callback=finish)
-
- def _set_close_callback(self):
- def on_close():
- raise TTransportException(
- type=TTransportException.END_OF_FILE,
- message='socket closed')
- self.stream.set_close_callback(self.close)
+ def set_close_callback(self, callback):
+ """
+ Should be called only after open() returns
+ """
+ self.stream.set_close_callback(callback)
def close(self):
# don't raise if we intend to close
@@ -78,51 +111,45 @@
def read(self, _):
# The generated code for Tornado shouldn't do individual reads -- only
# frames at a time
- assert "you're doing it wrong" is True
+ assert False, "you're doing it wrong"
- @gen.engine
- def readFrame(self, callback):
- self.read_queue.append(callback)
- logging.debug('read queue: %s', self.read_queue)
+ @contextmanager
+ def io_exception_context(self):
+ try:
+ yield
+ except (socket.error, IOError) as e:
+ raise TTransportException(
+ type=TTransportException.END_OF_FILE,
+ message=str(e))
+ except iostream.StreamBufferFullError as e:
+ raise TTransportException(
+ type=TTransportException.UNKNOWN,
+ message=str(e))
- if self.is_queuing_reads:
- # If a read is already in flight, then the while loop below should
- # pull it from self.read_queue
- return
-
- self.is_queuing_reads = True
- while self.read_queue:
- next_callback = self.read_queue.pop()
- result = yield gen.Task(self._readFrameFromStream)
- next_callback(result)
- self.is_queuing_reads = False
-
- @gen.engine
- def _readFrameFromStream(self, callback):
- logging.debug('_readFrameFromStream')
- frame_header = yield gen.Task(self.stream.read_bytes, 4)
- frame_length, = struct.unpack('!i', frame_header)
- logging.debug('received frame header, frame length = %i', frame_length)
- frame = yield gen.Task(self.stream.read_bytes, frame_length)
- logging.debug('received frame payload')
- callback(frame)
+ @gen.coroutine
+ def readFrame(self):
+ # IOStream processes reads one at a time
+ with (yield self._read_lock.acquire()):
+ with self.io_exception_context():
+ frame_header = yield self.stream.read_bytes(4)
+ if len(frame_header) == 0:
+ raise iostream.StreamClosedError('Read zero bytes from stream')
+ frame_length, = struct.unpack('!i', frame_header)
+ logging.debug('received frame header, frame length = %d', frame_length)
+ frame = yield self.stream.read_bytes(frame_length)
+ logging.debug('received frame payload: %r', frame)
+ raise gen.Return(frame)
def write(self, buf):
self.__wbuf.write(buf)
- def flush(self, callback=None):
- wout = self.__wbuf.getvalue()
- wsz = len(wout)
+ def flush(self):
+ frame = self.__wbuf.getvalue()
# reset wbuf before write/flush to preserve state on underlying failure
- self.__wbuf = StringIO()
- # N.B.: Doing this string concatenation is WAY cheaper than making
- # two separate calls to the underlying socket object. Socket writes in
- # Python turn out to be REALLY expensive, but it seems to do a pretty
- # good job of managing string buffer operations without excessive copies
- buf = struct.pack("!i", wsz) + wout
-
- logging.debug('writing frame length = %i', wsz)
- self.stream.write(buf, callback)
+ frame_length = struct.pack('!i', len(frame))
+ self.__wbuf = BytesIO()
+ with self.io_exception_context():
+ return self.stream.write(frame_length + frame)
class TTornadoServer(tcpserver.TCPServer):
@@ -135,19 +162,21 @@
self._oprot_factory = (oprot_factory if oprot_factory is not None
else iprot_factory)
+ @gen.coroutine
def handle_stream(self, stream, address):
+ host, port = address
+ trans = TTornadoStreamTransport(host=host, port=port, stream=stream,
+ io_loop=self.io_loop)
+ oprot = self._oprot_factory.getProtocol(trans)
+
try:
- host, port = address
- trans = TTornadoStreamTransport(host=host, port=port, stream=stream)
- oprot = self._oprot_factory.getProtocol(trans)
-
- def next_pass():
- if not trans.stream.closed():
- self._processor.process(trans, self._iprot_factory, oprot,
- callback=next_pass)
-
- next_pass()
-
+ while not trans.stream.closed():
+ frame = yield trans.readFrame()
+ tr = TMemoryBuffer(frame)
+ iprot = self._iprot_factory.getProtocol(tr)
+ yield self._processor.process(iprot, oprot)
except Exception:
logging.exception('thrift exception in handle_stream')
trans.close()
+
+ logging.info('client disconnected %s:%d', host, port)
diff --git a/test/py.tornado/test_suite.py b/test/py.tornado/test_suite.py
index f04ba04..c783962 100755
--- a/test/py.tornado/test_suite.py
+++ b/test/py.tornado/test_suite.py
@@ -22,11 +22,13 @@
import datetime
import glob
import sys
+import os
import time
import unittest
-sys.path.insert(0, './gen-py.tornado')
-sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0])
+basepath = os.path.abspath(os.path.dirname(__file__))
+sys.path.insert(0, basepath+'/gen-py.tornado')
+sys.path.insert(0, glob.glob(os.path.join(basepath, '../../lib/py/build/lib.*'))[0])
try:
__import__('tornado')
@@ -34,11 +36,12 @@
print "module `tornado` not found, skipping test"
sys.exit(0)
-from tornado import gen, ioloop, stack_context
-from tornado.testing import AsyncTestCase, get_unused_port
+from tornado import gen
+from tornado.testing import AsyncTestCase, get_unused_port, gen_test
from thrift import TTornado
from thrift.protocol import TBinaryProtocol
+from thrift.transport.TTransport import TTransportException
from ThriftTest import ThriftTest
from ThriftTest.ttypes import *
@@ -48,31 +51,31 @@
def __init__(self, test_instance):
self.test_instance = test_instance
- def testVoid(self, callback):
- callback()
+ def testVoid(self):
+ pass
- def testString(self, s, callback):
- callback(s)
+ def testString(self, s):
+ return s
- def testByte(self, b, callback):
- callback(b)
+ def testByte(self, b):
+ return b
- def testI16(self, i16, callback):
- callback(i16)
+ def testI16(self, i16):
+ return i16
- def testI32(self, i32, callback):
- callback(i32)
+ def testI32(self, i32):
+ return i32
- def testI64(self, i64, callback):
- callback(i64)
+ def testI64(self, i64):
+ return i64
- def testDouble(self, dub, callback):
- callback(dub)
+ def testDouble(self, dub):
+ return dub
- def testStruct(self, thing, callback):
- callback(thing)
+ def testStruct(self, thing):
+ return thing
- def testException(self, s, callback):
+ def testException(self, s):
if s == 'Xception':
x = Xception()
x.errorCode = 1001
@@ -80,133 +83,139 @@
raise x
elif s == 'throw_undeclared':
raise ValueError("foo")
- callback()
- def testOneway(self, seconds, callback=None):
+ def testOneway(self, seconds):
start = time.time()
+
def fire_oneway():
end = time.time()
self.test_instance.stop((start, end, seconds))
- ioloop.IOLoop.instance().add_timeout(
+ self.test_instance.io_loop.add_timeout(
datetime.timedelta(seconds=seconds),
fire_oneway)
- if callback:
- callback()
+ def testNest(self, thing):
+ return thing
- def testNest(self, thing, callback):
- callback(thing)
+ @gen.coroutine
+ def testMap(self, thing):
+ yield gen.moment
+ raise gen.Return(thing)
- def testMap(self, thing, callback):
- callback(thing)
+ def testSet(self, thing):
+ return thing
- def testSet(self, thing, callback):
- callback(thing)
+ def testList(self, thing):
+ return thing
- def testList(self, thing, callback):
- callback(thing)
+ def testEnum(self, thing):
+ return thing
- def testEnum(self, thing, callback):
- callback(thing)
-
- def testTypedef(self, thing, callback):
- callback(thing)
+ def testTypedef(self, thing):
+ return thing
class ThriftTestCase(AsyncTestCase):
- def get_new_ioloop(self):
- return ioloop.IOLoop.instance()
-
def setUp(self):
+ super(ThriftTestCase, self).setUp()
+
self.port = get_unused_port()
- self.io_loop = self.get_new_ioloop()
# server
self.handler = TestHandler(self)
self.processor = ThriftTest.Processor(self.handler)
self.pfactory = TBinaryProtocol.TBinaryProtocolFactory()
- self.server = TTornado.TTornadoServer(self.processor, self.pfactory)
+ self.server = TTornado.TTornadoServer(self.processor, self.pfactory, io_loop=self.io_loop)
self.server.bind(self.port)
self.server.start(1)
# client
- transport = TTornado.TTornadoStreamTransport('localhost', self.port)
+ transport = TTornado.TTornadoStreamTransport('localhost', self.port, io_loop=self.io_loop)
pfactory = TBinaryProtocol.TBinaryProtocolFactory()
+ self.io_loop.run_sync(transport.open)
self.client = ThriftTest.Client(transport, pfactory)
- transport.open(callback=self.stop)
- self.wait(timeout=1)
+ @gen_test
def test_void(self):
- self.client.testVoid(callback=self.stop)
- v = self.wait(timeout=1)
- self.assertEquals(v, None)
+ v = yield self.client.testVoid()
+ self.assertEqual(v, None)
+ @gen_test
def test_string(self):
- self.client.testString('Python', callback=self.stop)
- v = self.wait(timeout=1)
- self.assertEquals(v, 'Python')
+ v = yield self.client.testString('Python')
+ self.assertEqual(v, 'Python')
+ @gen_test
def test_byte(self):
- self.client.testByte(63, callback=self.stop)
- v = self.wait(timeout=1)
- self.assertEquals(v, 63)
+ v = yield self.client.testByte(63)
+ self.assertEqual(v, 63)
+ @gen_test
def test_i32(self):
- self.client.testI32(-1, callback=self.stop)
- v = self.wait(timeout=1)
- self.assertEquals(v, -1)
+ v = yield self.client.testI32(-1)
+ self.assertEqual(v, -1)
- self.client.testI32(0, callback=self.stop)
- v = self.wait(timeout=1)
- self.assertEquals(v, 0)
+ v = yield self.client.testI32(0)
+ self.assertEqual(v, 0)
+ @gen_test
def test_i64(self):
- self.client.testI64(-34359738368, callback=self.stop)
- v = self.wait(timeout=1)
- self.assertEquals(v, -34359738368)
+ v = yield self.client.testI64(-34359738368)
+ self.assertEqual(v, -34359738368)
+ @gen_test
def test_double(self):
- self.client.testDouble(-5.235098235, callback=self.stop)
- v = self.wait(timeout=1)
- self.assertEquals(v, -5.235098235)
+ v = yield self.client.testDouble(-5.235098235)
+ self.assertEqual(v, -5.235098235)
+ @gen_test
def test_struct(self):
x = Xtruct()
x.string_thing = "Zero"
x.byte_thing = 1
x.i32_thing = -3
x.i64_thing = -5
- self.client.testStruct(x, callback=self.stop)
+ y = yield self.client.testStruct(x)
- y = self.wait(timeout=1)
- self.assertEquals(y.string_thing, "Zero")
- self.assertEquals(y.byte_thing, 1)
- self.assertEquals(y.i32_thing, -3)
- self.assertEquals(y.i64_thing, -5)
-
- def test_exception(self):
- self.client.testException('Safe', callback=self.stop)
- v = self.wait(timeout=1)
-
- self.client.testException('Xception', callback=self.stop)
- ex = self.wait(timeout=1)
- if type(ex) == Xception:
- self.assertEquals(ex.errorCode, 1001)
- self.assertEquals(ex.message, 'Xception')
- else:
- self.fail("should have gotten exception")
+ self.assertEqual(y.string_thing, "Zero")
+ self.assertEqual(y.byte_thing, 1)
+ self.assertEqual(y.i32_thing, -3)
+ self.assertEqual(y.i64_thing, -5)
def test_oneway(self):
- def return_from_send():
- self.stop('done with send')
- self.client.testOneway(0.5, callback=return_from_send)
- self.assertEquals(self.wait(timeout=1), 'done with send')
-
+ self.client.testOneway(0.5)
start, end, seconds = self.wait(timeout=1)
self.assertAlmostEquals(seconds, (end - start), places=3)
+ @gen_test
+ def test_map(self):
+ """
+ TestHandler.testMap is a coroutine, this test checks if gen.Return() from a coroutine works.
+ """
+ expected = {1: 1}
+ res = yield self.client.testMap(expected)
+ self.assertEqual(res, expected)
+
+ @gen_test
+ def test_exception(self):
+ yield self.client.testException('Safe')
+
+ try:
+ yield self.client.testException('Xception')
+ except Xception as ex:
+ self.assertEqual(ex.errorCode, 1001)
+ self.assertEqual(ex.message, 'Xception')
+ else:
+ self.fail("should have gotten exception")
+ try:
+ yield self.client.testException('throw_undeclared')
+ except TTransportException as ex:
+ pass
+ else:
+ self.fail("should have gotten exception")
+
def suite():
suite = unittest.TestSuite()