THRIFT-1094. py: bug in TCompactProto python readMessageEnd method and updated test cases


This patch fixes a TCompactProtocol bug and expands the test cases to exercise the problem.

Patch: Will Pierce

git-svn-id: https://svn.apache.org/repos/asf/thrift/trunk@1083877 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/py/src/protocol/TCompactProtocol.py b/lib/py/src/protocol/TCompactProtocol.py
index 7ff0798..280b54f 100644
--- a/lib/py/src/protocol/TCompactProtocol.py
+++ b/lib/py/src/protocol/TCompactProtocol.py
@@ -291,9 +291,8 @@
     return (name, type, seqid)
 
   def readMessageEnd(self):
-    assert self.state == VALUE_READ
+    assert self.state == CLEAR
     assert len(self.__structs) == 0
-    self.state = CLEAR
 
   def readStructBegin(self):
     assert self.state in (CLEAR, CONTAINER_READ, VALUE_READ), self.state
diff --git a/test/py/RunClientServer.py b/test/py/RunClientServer.py
index 2bd6094..dced91a 100755
--- a/test/py/RunClientServer.py
+++ b/test/py/RunClientServer.py
@@ -19,42 +19,93 @@
 # under the License.
 #
 
+from __future__ import division
 import time
 import subprocess
 import sys
 import os
 import signal
+from optparse import OptionParser
+
+parser = OptionParser()
+parser.add_option("--port", type="int", dest="port", default=9090,
+    help="port number for server to listen on")
+options, args = parser.parse_args()
+
+FRAMED = ["TNonblockingServer"]
+EXTRA_DELAY = ['TProcessPoolServer']
+EXTRA_SLEEP = 3.5
+
+PROTOS= [
+    'accel',
+    'binary',
+    'compact' ]
+
+SERVERS = [
+  "TSimpleServer",
+  "TThreadedServer",
+  "TThreadPoolServer",
+  "TProcessPoolServer", # new!
+  "TForkingServer",
+  "TNonblockingServer",
+  "THttpServer" ]
+
+# Test for presence of multiprocessing module, and if it is not present, then
+# remove it from the list of available servers.
+try:
+  import multiprocessing
+except:
+  print 'Warning: the multiprocessing module is unavailable. Skipping tests for TProcessPoolServer'
+  SERVERS.remove('TProcessPoolServer')
+
+
+# commandline permits a single class name to be specified to override SERVERS=[...]
+if len(args) == 1:
+  if args[0] in SERVERS:
+    SERVERS = args
+  else:
+    print 'Unavailable server type "%s", please choose one of: %s' % (args[0], SERVERS)
+    sys.exit(0)
+
 
 def relfile(fname):
     return os.path.join(os.path.dirname(__file__), fname)
 
-FRAMED = ["TNonblockingServer"]
-
-def runTest(server_class):
-    print "Testing ", server_class
-    serverproc = subprocess.Popen([sys.executable, relfile("TestServer.py"), server_class])
+def runTest(server_class, proto, port):
+    server_args = [sys.executable, # /usr/bin/python or similar
+      relfile('TestServer.py'), # ./TestServer.py
+      '--proto=%s' % proto, # accel, binary or compact
+      '--port=%d' % port, # usually 9090, given on cmdline
+      server_class] # name of class to test, from SERVERS[] or cmdline
+    print "Testing server %s: %s" % (server_class, ' '.join(server_args))
+    serverproc = subprocess.Popen(server_args)
     time.sleep(0.25)
     try:
-        argv = [sys.executable, relfile("TestClient.py")]
+        argv = [sys.executable, relfile("TestClient.py"),
+           '--proto=%s' % (proto), '--port=%d' % (port) ]
         if server_class in FRAMED:
             argv.append('--framed')
         if server_class == 'THttpServer':
             argv.append('--http=/')
+        print 'Testing client %s: %s' % (server_class, ' '.join(argv))
         ret = subprocess.call(argv)
         if ret != 0:
-            raise Exception("subprocess failed")
+            raise Exception("subprocess %s failed, args: %s" % (server_class, ' '.join(argv)))
     finally:
-        # fixme: should check that server didn't die
-        os.kill(serverproc.pid, signal.SIGKILL)
-
+        # check that server didn't die
+        time.sleep(0.05)
+        serverproc.poll()
+        if serverproc.returncode is not None:
+          print 'Server process (%s) failed with retcode %d' % (' '.join(server_args), serverproc.returncode)
+          raise Exception('subprocess %s died, args: %s' % (server_class, ' '.join(server_args)))
+        else:
+          if server_class in EXTRA_DELAY:
+            print 'Giving %s (proto=%s) an extra %d seconds for child processes to terminate via alarm' % (server_class, proto, EXTRA_SLEEP)
+            time.sleep(EXTRA_SLEEP)
+          os.kill(serverproc.pid, signal.SIGKILL)
     # wait for shutdown
-    time.sleep(1)
+    time.sleep(0.5)
 
-map(runTest, [
-  "TSimpleServer",
-  "TThreadedServer",
-  "TThreadPoolServer",
-  "TForkingServer",
-  "TNonblockingServer",
-  "THttpServer",
-  ])
+for try_server in SERVERS:
+  for try_proto in PROTOS:
+    runTest(try_server, try_proto, options.port)
diff --git a/test/py/TestClient.py b/test/py/TestClient.py
index 0a38b03..eecb850 100755
--- a/test/py/TestClient.py
+++ b/test/py/TestClient.py
@@ -29,13 +29,13 @@
 from thrift.transport import TSocket
 from thrift.transport import THttpClient
 from thrift.protocol import TBinaryProtocol
+from thrift.protocol import TCompactProtocol
 import unittest
 import time
 from optparse import OptionParser
 
 
 parser = OptionParser()
-parser.set_defaults(framed=False, http_path=None, verbose=1, host='localhost', port=9090)
 parser.add_option("--port", type="int", dest="port",
     help="connect to server at port")
 parser.add_option("--host", type="string", dest="host",
@@ -50,7 +50,9 @@
 parser.add_option('-q', '--quiet', action="store_const", 
     dest="verbose", const=0,
     help="minimal output")
-
+parser.add_option('--proto',  dest="proto", type="string",
+    help="protocol to use, one of: accel, binary, compact")
+parser.set_defaults(framed=False, http_path=None, verbose=1, host='localhost', port=9090, proto='binary')
 options, args = parser.parse_args()
 
 class AbstractTest(unittest.TestCase):
@@ -81,19 +83,24 @@
 
   def testString(self):
     self.assertEqual(self.client.testString('Python'), 'Python')
+    self.assertEqual(self.client.testString(''), '')
 
   def testByte(self):
     self.assertEqual(self.client.testByte(63), 63)
+    self.assertEqual(self.client.testByte(-127), -127)
 
   def testI32(self):
     self.assertEqual(self.client.testI32(-1), -1)
     self.assertEqual(self.client.testI32(0), 0)
 
   def testI64(self):
+    self.assertEqual(self.client.testI64(1), 1)
     self.assertEqual(self.client.testI64(-34359738368), -34359738368)
 
   def testDouble(self):
     self.assertEqual(self.client.testDouble(-5.235098235), -5.235098235)
+    self.assertEqual(self.client.testDouble(0), 0)
+    self.assertEqual(self.client.testDouble(-1), -1)
 
   def testStruct(self):
     x = Xtruct()
@@ -102,11 +109,57 @@
     x.i32_thing = -3
     x.i64_thing = -5
     y = self.client.testStruct(x)
+    self.assertEqual(y, x)
 
-    self.assertEqual(y.string_thing, "Zero")
-    self.assertEqual(y.byte_thing, 1)
-    self.assertEqual(y.i32_thing, -3)
-    self.assertEqual(y.i64_thing, -5)
+  def testNest(self):
+    inner = Xtruct(string_thing="Zero", byte_thing=1, i32_thing=-3,
+      i64_thing=-5)
+    x = Xtruct2(struct_thing=inner)
+    y = self.client.testNest(x)
+    self.assertEqual(y, x)
+
+  def testMap(self):
+    x = {0:1, 1:2, 2:3, 3:4, -1:-2}
+    y = self.client.testMap(x)
+    self.assertEqual(y, x)
+
+  def testSet(self):
+    x = set([8, 1, 42])
+    y = self.client.testSet(x)
+    self.assertEqual(y, x)
+
+  def testList(self):
+    x = [1, 4, 9, -42]
+    y = self.client.testList(x)
+    self.assertEqual(y, x)
+
+  def testEnum(self):
+    x = Numberz.FIVE
+    y = self.client.testEnum(x)
+    self.assertEqual(y, x)
+
+  def testTypedef(self):
+    x = 0xffffffffffffff # 7 bytes of 0xff
+    y = self.client.testTypedef(x)
+    self.assertEqual(y, x)
+
+  def testMapMap(self):
+    # does not work: dict() is not a hashable type, so a dict() cannot be used as a key in another dict()
+    #x = { {1:10, 2:20}, {1:100, 2:200, 3:300}, {1:1000, 2:2000, 3:3000, 4:4000} }
+    try:
+      y = self.client.testMapMap()
+    except:
+      pass
+
+  def testMulti(self):
+    xpected = Xtruct(byte_thing=74, i32_thing=0xff00ff, i64_thing=0xffffffffd0d0)
+    y = self.client.testMulti(xpected.byte_thing,
+          xpected.i32_thing,
+          xpected.i64_thing,
+          { 0:'abc' },
+          Numberz.FIVE,
+          0xf0f0f0)
+    self.assertEqual(y, xpected)
 
   def testException(self):
     self.client.testException('Safe')
@@ -125,27 +178,35 @@
 
   def testOneway(self):
     start = time.time()
-    self.client.testOneway(0.5)
+    self.client.testOneway(1) # type is int, not float
     end = time.time()
-    self.assertTrue(end - start < 0.2,
+    self.assertTrue(end - start < 3,
                     "oneway sleep took %f sec" % (end - start))
   
   def testOnewayThenNormal(self):
-    self.client.testOneway(0.5)
+    self.client.testOneway(1) # type is int, not float
     self.assertEqual(self.client.testString('Python'), 'Python')
 
 class NormalBinaryTest(AbstractTest):
   protocol_factory = TBinaryProtocol.TBinaryProtocolFactory()
 
+class CompactTest(AbstractTest):
+  protocol_factory = TCompactProtocol.TCompactProtocolFactory()
+
 class AcceleratedBinaryTest(AbstractTest):
   protocol_factory = TBinaryProtocol.TBinaryProtocolAcceleratedFactory()
 
 def suite():
   suite = unittest.TestSuite()
   loader = unittest.TestLoader()
-
-  suite.addTest(loader.loadTestsFromTestCase(NormalBinaryTest))
-  suite.addTest(loader.loadTestsFromTestCase(AcceleratedBinaryTest))
+  if options.proto == 'binary': # look for --proto on cmdline
+    suite.addTest(loader.loadTestsFromTestCase(NormalBinaryTest))
+  elif options.proto == 'accel':
+    suite.addTest(loader.loadTestsFromTestCase(AcceleratedBinaryTest))
+  elif options.proto == 'compact':
+    suite.addTest(loader.loadTestsFromTestCase(CompactTest))
+  else:
+    raise AssertionError('Unknown protocol given with --proto: %s' % options.proto)
   return suite
 
 class OwnArgsTestProgram(unittest.TestProgram):
diff --git a/test/py/TestServer.py b/test/py/TestServer.py
index 581bed6..99d925a 100755
--- a/test/py/TestServer.py
+++ b/test/py/TestServer.py
@@ -18,53 +18,75 @@
 # specific language governing permissions and limitations
 # under the License.
 #
-
+from __future__ import division
 import sys, glob, time
 sys.path.insert(0, './gen-py')
 sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0])
+from optparse import OptionParser
 
 from ThriftTest import ThriftTest
 from ThriftTest.ttypes import *
 from thrift.transport import TTransport
 from thrift.transport import TSocket
 from thrift.protocol import TBinaryProtocol
+from thrift.protocol import TCompactProtocol
 from thrift.server import TServer, TNonblockingServer, THttpServer
 
+parser = OptionParser()
+parser.set_defaults(port=9090, verbose=1, proto='binary')
+parser.add_option("--port", type="int", dest="port",
+    help="port number for server to listen on")
+parser.add_option('-v', '--verbose', action="store_const", 
+    dest="verbose", const=2,
+    help="verbose output")
+parser.add_option('--proto',  dest="proto", type="string",
+    help="protocol to use, one of: accel, binary, compact")
+options, args = parser.parse_args()
+
 class TestHandler:
 
   def testVoid(self):
-    print 'testVoid()'
+    if options.verbose:
+      print 'testVoid()'
 
   def testString(self, str):
-    print 'testString(%s)' % str
+    if options.verbose:
+      print 'testString(%s)' % str
     return str
 
   def testByte(self, byte):
-    print 'testByte(%d)' % byte
+    if options.verbose:
+      print 'testByte(%d)' % byte
     return byte
 
   def testI16(self, i16):
-    print 'testI16(%d)' % i16
+    if options.verbose:
+      print 'testI16(%d)' % i16
     return i16
 
   def testI32(self, i32):
-    print 'testI32(%d)' % i32
+    if options.verbose:
+      print 'testI32(%d)' % i32
     return i32
 
   def testI64(self, i64):
-    print 'testI64(%d)' % i64
+    if options.verbose:
+      print 'testI64(%d)' % i64
     return i64
 
   def testDouble(self, dub):
-    print 'testDouble(%f)' % dub
+    if options.verbose:
+      print 'testDouble(%f)' % dub
     return dub
 
   def testStruct(self, thing):
-    print 'testStruct({%s, %d, %d, %d})' % (thing.string_thing, thing.byte_thing, thing.i32_thing, thing.i64_thing)
+    if options.verbose:
+      print 'testStruct({%s, %d, %d, %d})' % (thing.string_thing, thing.byte_thing, thing.i32_thing, thing.i64_thing)
     return thing
 
   def testException(self, str):
-    print 'testException(%s)' % str
+    if options.verbose:
+      print 'testException(%s)' % str
     if str == 'Xception':
       x = Xception()
       x.errorCode = 1001
@@ -74,43 +96,90 @@
       raise ValueError("foo")
 
   def testOneway(self, seconds):
-    print 'testOneway(%d) => sleeping...' % seconds
-    time.sleep(seconds)
-    print 'done sleeping'
+    if options.verbose:
+      print 'testOneway(%d) => sleeping...' % seconds
+    time.sleep(seconds / 3) # be quick
+    if options.verbose:
+      print 'done sleeping'
 
   def testNest(self, thing):
+    if options.verbose:
+      print 'testNest(%s)' % thing
     return thing
 
   def testMap(self, thing):
+    if options.verbose:
+      print 'testMap(%s)' % thing
     return thing
 
   def testSet(self, thing):
+    if options.verbose:
+      print 'testSet(%s)' % thing
     return thing
 
   def testList(self, thing):
+    if options.verbose:
+      print 'testList(%s)' % thing
     return thing
 
   def testEnum(self, thing):
+    if options.verbose:
+      print 'testEnum(%s)' % thing
     return thing
 
   def testTypedef(self, thing):
+    if options.verbose:
+      print 'testTypedef(%s)' % thing
     return thing
 
-pfactory = TBinaryProtocol.TBinaryProtocolFactory()
+  def testMapMap(self, thing):
+    if options.verbose:
+      print 'testMapMap(%s)' % thing
+    return thing
+
+  def testMulti(self, arg0, arg1, arg2, arg3, arg4, arg5):
+    if options.verbose:
+      print 'testMulti(%s)' % [arg0, arg1, arg2, arg3, arg4, arg5]
+    x = Xtruct(byte_thing=arg0, i32_thing=arg1, i64_thing=arg2)
+    return x
+
+if options.proto == 'binary':
+  pfactory = TBinaryProtocol.TBinaryProtocolFactory()
+elif options.proto == 'accel':
+  pfactory = TBinaryProtocol.TBinaryProtocolAcceleratedFactory()
+elif options.proto == 'compact':
+  pfactory = TCompactProtocol.TCompactProtocolFactory()
+else:
+  raise AssertionError('Unknown --proto option: %s' % options.proto)
 handler = TestHandler()
 processor = ThriftTest.Processor(handler)
 
-if sys.argv[1] == "THttpServer":
-  server = THttpServer.THttpServer(processor, ('', 9090), pfactory)
+if args[0] == "THttpServer":
+  server = THttpServer.THttpServer(processor, ('', options.port), pfactory)
 else:
   host = None
-  transport = TSocket.TServerSocket(host, 9090)
+  transport = TSocket.TServerSocket(host, options.port)
   tfactory = TTransport.TBufferedTransportFactory()
 
-  if sys.argv[1] == "TNonblockingServer":
-    server = TNonblockingServer.TNonblockingServer(processor, transport)
+  if args[0] == "TNonblockingServer":
+    server = TNonblockingServer.TNonblockingServer(processor, transport, inputProtocolFactory=pfactory)
+  elif args[0] == "TProcessPoolServer":
+    import signal
+    def set_alarm():
+      def clean_shutdown(signum, frame):
+        for worker in server.workers:
+          print 'Terminating worker: %s' % worker
+          worker.terminate()
+        print 'Requesting server to stop()'
+        server.stop()
+      signal.signal(signal.SIGALRM, clean_shutdown)
+      signal.alarm(2)
+    from thrift.server import TProcessPoolServer
+    server = TProcessPoolServer.TProcessPoolServer(processor, transport, tfactory, pfactory)
+    server.setNumWorkers(5)
+    set_alarm()
   else:
-    ServerClass = getattr(TServer, sys.argv[1])
+    ServerClass = getattr(TServer, args[0])
     server = ServerClass(processor, transport, tfactory, pfactory)
 
 server.serve()