THRIFT-2231 Support tornado-4.x (Python)

Client: Python
Patch: Roey Berman
Signed-off-by: Roger Meier <roger@apache.org>
diff --git a/test/py.tornado/test_suite.py b/test/py.tornado/test_suite.py
index f04ba04..c783962 100755
--- a/test/py.tornado/test_suite.py
+++ b/test/py.tornado/test_suite.py
@@ -22,11 +22,13 @@
 import datetime
 import glob
 import sys
+import os
 import time
 import unittest
 
-sys.path.insert(0, './gen-py.tornado')
-sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0])
+basepath = os.path.abspath(os.path.dirname(__file__))
+sys.path.insert(0, os.path.join(basepath, 'gen-py.tornado'))
+sys.path.insert(0, glob.glob(os.path.join(basepath, '../../lib/py/build/lib.*'))[0])
 
 try:
     __import__('tornado')
@@ -34,11 +36,12 @@
     print "module `tornado` not found, skipping test"
     sys.exit(0)
 
-from tornado import gen, ioloop, stack_context
-from tornado.testing import AsyncTestCase, get_unused_port
+from tornado import gen
+from tornado.testing import AsyncTestCase, get_unused_port, gen_test
 
 from thrift import TTornado
 from thrift.protocol import TBinaryProtocol
+from thrift.transport.TTransport import TTransportException
 
 from ThriftTest import ThriftTest
 from ThriftTest.ttypes import *
@@ -48,31 +51,31 @@
     def __init__(self, test_instance):
         self.test_instance = test_instance
 
-    def testVoid(self, callback):
-        callback()
+    def testVoid(self):
+        pass
 
-    def testString(self, s, callback):
-        callback(s)
+    def testString(self, s):
+        return s
 
-    def testByte(self, b, callback):
-        callback(b)
+    def testByte(self, b):
+        return b
 
-    def testI16(self, i16, callback):
-        callback(i16)
+    def testI16(self, i16):
+        return i16
 
-    def testI32(self, i32, callback):
-        callback(i32)
+    def testI32(self, i32):
+        return i32
 
-    def testI64(self, i64, callback):
-        callback(i64)
+    def testI64(self, i64):
+        return i64
 
-    def testDouble(self, dub, callback):
-        callback(dub)
+    def testDouble(self, dub):
+        return dub
 
-    def testStruct(self, thing, callback):
-        callback(thing)
+    def testStruct(self, thing):
+        return thing
 
-    def testException(self, s, callback):
+    def testException(self, s):
         if s == 'Xception':
             x = Xception()
             x.errorCode = 1001
@@ -80,133 +83,139 @@
             raise x
         elif s == 'throw_undeclared':
             raise ValueError("foo")
-        callback()
 
-    def testOneway(self, seconds, callback=None):
+    def testOneway(self, seconds):
         start = time.time()
+
         def fire_oneway():
             end = time.time()
             self.test_instance.stop((start, end, seconds))
 
-        ioloop.IOLoop.instance().add_timeout(
+        self.test_instance.io_loop.add_timeout(
             datetime.timedelta(seconds=seconds),
             fire_oneway)
 
-        if callback:
-            callback()
+    def testNest(self, thing):
+        return thing
 
-    def testNest(self, thing, callback):
-        callback(thing)
+    @gen.coroutine
+    def testMap(self, thing):
+        yield gen.moment
+        raise gen.Return(thing)
 
-    def testMap(self, thing, callback):
-        callback(thing)
+    def testSet(self, thing):
+        return thing
 
-    def testSet(self, thing, callback):
-        callback(thing)
+    def testList(self, thing):
+        return thing
 
-    def testList(self, thing, callback):
-        callback(thing)
+    def testEnum(self, thing):
+        return thing
 
-    def testEnum(self, thing, callback):
-        callback(thing)
-
-    def testTypedef(self, thing, callback):
-        callback(thing)
+    def testTypedef(self, thing):
+        return thing
 
 
 class ThriftTestCase(AsyncTestCase):
-    def get_new_ioloop(self):
-        return ioloop.IOLoop.instance()
-
     def setUp(self):
+        super(ThriftTestCase, self).setUp()
+
         self.port = get_unused_port()
-        self.io_loop = self.get_new_ioloop()
 
         # server
         self.handler = TestHandler(self)
         self.processor = ThriftTest.Processor(self.handler)
         self.pfactory = TBinaryProtocol.TBinaryProtocolFactory()
 
-        self.server = TTornado.TTornadoServer(self.processor, self.pfactory)
+        self.server = TTornado.TTornadoServer(self.processor, self.pfactory, io_loop=self.io_loop)
         self.server.bind(self.port)
         self.server.start(1)
 
         # client
-        transport = TTornado.TTornadoStreamTransport('localhost', self.port)
+        transport = TTornado.TTornadoStreamTransport('localhost', self.port, io_loop=self.io_loop)
         pfactory = TBinaryProtocol.TBinaryProtocolFactory()
+        self.io_loop.run_sync(transport.open)
         self.client = ThriftTest.Client(transport, pfactory)
-        transport.open(callback=self.stop)
-        self.wait(timeout=1)
 
+    @gen_test
     def test_void(self):
-        self.client.testVoid(callback=self.stop)
-        v = self.wait(timeout=1)
-        self.assertEquals(v, None)
+        v = yield self.client.testVoid()
+        self.assertEqual(v, None)
 
+    @gen_test
     def test_string(self):
-        self.client.testString('Python', callback=self.stop)
-        v = self.wait(timeout=1)
-        self.assertEquals(v, 'Python')
+        v = yield self.client.testString('Python')
+        self.assertEqual(v, 'Python')
 
+    @gen_test
     def test_byte(self):
-        self.client.testByte(63, callback=self.stop)
-        v = self.wait(timeout=1)
-        self.assertEquals(v, 63)
+        v = yield self.client.testByte(63)
+        self.assertEqual(v, 63)
 
+    @gen_test
     def test_i32(self):
-        self.client.testI32(-1, callback=self.stop)
-        v = self.wait(timeout=1)
-        self.assertEquals(v, -1)
+        v = yield self.client.testI32(-1)
+        self.assertEqual(v, -1)
 
-        self.client.testI32(0, callback=self.stop)
-        v = self.wait(timeout=1)
-        self.assertEquals(v, 0)
+        v = yield self.client.testI32(0)
+        self.assertEqual(v, 0)
 
+    @gen_test
     def test_i64(self):
-        self.client.testI64(-34359738368, callback=self.stop)
-        v = self.wait(timeout=1)
-        self.assertEquals(v, -34359738368)
+        v = yield self.client.testI64(-34359738368)
+        self.assertEqual(v, -34359738368)
 
+    @gen_test
     def test_double(self):
-        self.client.testDouble(-5.235098235, callback=self.stop)
-        v = self.wait(timeout=1)
-        self.assertEquals(v, -5.235098235)
+        v = yield self.client.testDouble(-5.235098235)
+        self.assertEqual(v, -5.235098235)
 
+    @gen_test
     def test_struct(self):
         x = Xtruct()
         x.string_thing = "Zero"
         x.byte_thing = 1
         x.i32_thing = -3
         x.i64_thing = -5
-        self.client.testStruct(x, callback=self.stop)
+        y = yield self.client.testStruct(x)
 
-        y = self.wait(timeout=1)
-        self.assertEquals(y.string_thing, "Zero")
-        self.assertEquals(y.byte_thing, 1)
-        self.assertEquals(y.i32_thing, -3)
-        self.assertEquals(y.i64_thing, -5)
-
-    def test_exception(self):
-        self.client.testException('Safe', callback=self.stop)
-        v = self.wait(timeout=1)
-
-        self.client.testException('Xception', callback=self.stop)
-        ex = self.wait(timeout=1)
-        if type(ex) == Xception:
-            self.assertEquals(ex.errorCode, 1001)
-            self.assertEquals(ex.message, 'Xception')
-        else:
-            self.fail("should have gotten exception")
+        self.assertEqual(y.string_thing, "Zero")
+        self.assertEqual(y.byte_thing, 1)
+        self.assertEqual(y.i32_thing, -3)
+        self.assertEqual(y.i64_thing, -5)
 
     def test_oneway(self):
-        def return_from_send():
-            self.stop('done with send')
-        self.client.testOneway(0.5, callback=return_from_send)
-        self.assertEquals(self.wait(timeout=1), 'done with send')
-
+        self.client.testOneway(0.5)
         start, end, seconds = self.wait(timeout=1)
         self.assertAlmostEquals(seconds, (end - start), places=3)
 
+    @gen_test
+    def test_map(self):
+        """
+        TestHandler.testMap is a coroutine; check that a value it returns via gen.Return() reaches the client.
+        """
+        expected = {1: 1}
+        res = yield self.client.testMap(expected)
+        self.assertEqual(res, expected)
+
+    @gen_test
+    def test_exception(self):
+        yield self.client.testException('Safe')
+
+        try:
+            yield self.client.testException('Xception')
+        except Xception as ex:
+            self.assertEqual(ex.errorCode, 1001)
+            self.assertEqual(ex.message, 'Xception')
+        else:
+            self.fail("should have gotten Xception")
+        try:
+            yield self.client.testException('throw_undeclared')
+        except TTransportException:
+            pass
+        else:
+            self.fail("should have gotten TTransportException for undeclared exception")
+
 
 def suite():
     suite = unittest.TestSuite()