THRIFT-4405: Enhance python cross test client for pedantic sequence id handling
diff --git a/test/py/TestClient.py b/test/py/TestClient.py
index ddcce8d..cc9185c 100755
--- a/test/py/TestClient.py
+++ b/test/py/TestClient.py
@@ -23,10 +23,15 @@
 import sys
 import time
 import unittest
-from optparse import OptionParser
 
+from optparse import OptionParser
 from util import local_libpath
 
+sys.path.insert(0, local_libpath())
+
+from thrift.protocol import TProtocolDecorator
+from thrift.protocol import TProtocol
+
 SCRIPT_DIR = os.path.abspath(os.path.dirname(__file__))
 
 
@@ -268,6 +273,43 @@
         self.assertEqual(self.client.testString('Python'), 'Python')
 
 
+# LAST_SEQID is a global because we have one transport and multiple protocols
+# running on it (when multiplexec)
+LAST_SEQID = None
+
+class TPedanticSequenceIdProtocolWrapper(TProtocolDecorator.TProtocolDecorator):
+    """
+    Wraps any protocol with sequence ID checking: looks for outbound
+    uniqueness as well as request/response alignment.
+    """
+    def __init__(self, protocol):
+        # TProtocolDecorator.__new__ does all the heavy lifting
+        pass
+
+    def writeMessageBegin(self, name, type, seqid):
+        global LAST_SEQID
+        if LAST_SEQID and LAST_SEQID == seqid:
+            raise TProtocol.TProtocolException(INVALID_DATA,
+                "Python client reused sequence ID {0}".format(seqid))
+        LAST_SEQID = seqid
+        super(TPedanticSequenceIdProtocolWrapper, self).writeMessageBegin(
+            name, type, seqid)
+
+    def readMessageBegin(self):
+        global LAST_SEQID
+        (name, type, seqid) =\
+            super(TPedanticSequenceIdProtocolWrapper, self).readMessageBegin()
+        if LAST_SEQID != seqid:
+            raise TProtocol.TProtocolException(INVALID_DATA,
+                "We sent seqid {0} and server returned seqid {1}".format(
+                    self.last, seqid))
+        return (name, type, seqid)
+
+
+def make_pedantic(proto):
+    """ Wrap a protocol in the pedantic sequence ID wrapper. """
+    return TPedanticSequenceIdProtocolWrapper(proto)
+
 class MultiplexedOptionalTest(AbstractTest):
     def get_protocol2(self, transport):
         return None
@@ -275,83 +317,83 @@
 
 class BinaryTest(MultiplexedOptionalTest):
     def get_protocol(self, transport):
-        return TBinaryProtocol.TBinaryProtocolFactory().getProtocol(transport)
+        return make_pedantic(TBinaryProtocol.TBinaryProtocolFactory().getProtocol(transport))
 
 
 class MultiplexedBinaryTest(MultiplexedOptionalTest):
     def get_protocol(self, transport):
-        wrapped_proto = TBinaryProtocol.TBinaryProtocolFactory().getProtocol(transport)
+        wrapped_proto = make_pedantic(TBinaryProtocol.TBinaryProtocolFactory().getProtocol(transport))
         return TMultiplexedProtocol.TMultiplexedProtocol(wrapped_proto, "ThriftTest")
 
     def get_protocol2(self, transport):
-        wrapped_proto = TBinaryProtocol.TBinaryProtocolFactory().getProtocol(transport)
+        wrapped_proto = make_pedantic(TBinaryProtocol.TBinaryProtocolFactory().getProtocol(transport))
         return TMultiplexedProtocol.TMultiplexedProtocol(wrapped_proto, "SecondService")
 
 
 class AcceleratedBinaryTest(MultiplexedOptionalTest):
     def get_protocol(self, transport):
-        return TBinaryProtocol.TBinaryProtocolAcceleratedFactory(fallback=False).getProtocol(transport)
+        return make_pedantic(TBinaryProtocol.TBinaryProtocolAcceleratedFactory(fallback=False).getProtocol(transport))
 
 
 class MultiplexedAcceleratedBinaryTest(MultiplexedOptionalTest):
     def get_protocol(self, transport):
-        wrapped_proto = TBinaryProtocol.TBinaryProtocolAcceleratedFactory(fallback=False).getProtocol(transport)
+        wrapped_proto = make_pedantic(TBinaryProtocol.TBinaryProtocolAcceleratedFactory(fallback=False).getProtocol(transport))
         return TMultiplexedProtocol.TMultiplexedProtocol(wrapped_proto, "ThriftTest")
 
     def get_protocol2(self, transport):
-        wrapped_proto = TBinaryProtocol.TBinaryProtocolAcceleratedFactory(fallback=False).getProtocol(transport)
+        wrapped_proto = make_pedantic(TBinaryProtocol.TBinaryProtocolAcceleratedFactory(fallback=False).getProtocol(transport))
         return TMultiplexedProtocol.TMultiplexedProtocol(wrapped_proto, "SecondService")
 
 
 class CompactTest(MultiplexedOptionalTest):
     def get_protocol(self, transport):
-        return TCompactProtocol.TCompactProtocolFactory().getProtocol(transport)
+        return make_pedantic(TCompactProtocol.TCompactProtocolFactory().getProtocol(transport))
 
 
 class MultiplexedCompactTest(MultiplexedOptionalTest):
     def get_protocol(self, transport):
-        wrapped_proto = TCompactProtocol.TCompactProtocolFactory().getProtocol(transport)
+        wrapped_proto = make_pedantic(TCompactProtocol.TCompactProtocolFactory().getProtocol(transport))
         return TMultiplexedProtocol.TMultiplexedProtocol(wrapped_proto, "ThriftTest")
 
     def get_protocol2(self, transport):
-        wrapped_proto = TCompactProtocol.TCompactProtocolFactory().getProtocol(transport)
+        wrapped_proto = make_pedantic(TCompactProtocol.TCompactProtocolFactory().getProtocol(transport))
         return TMultiplexedProtocol.TMultiplexedProtocol(wrapped_proto, "SecondService")
 
 
 class AcceleratedCompactTest(MultiplexedOptionalTest):
     def get_protocol(self, transport):
-        return TCompactProtocol.TCompactProtocolAcceleratedFactory(fallback=False).getProtocol(transport)
+        return make_pedantic(TCompactProtocol.TCompactProtocolAcceleratedFactory(fallback=False).getProtocol(transport))
 
 
 class MultiplexedAcceleratedCompactTest(MultiplexedOptionalTest):
     def get_protocol(self, transport):
-        wrapped_proto = TCompactProtocol.TCompactProtocolAcceleratedFactory(fallback=False).getProtocol(transport)
+        wrapped_proto = make_pedantic(TCompactProtocol.TCompactProtocolAcceleratedFactory(fallback=False).getProtocol(transport))
         return TMultiplexedProtocol.TMultiplexedProtocol(wrapped_proto, "ThriftTest")
 
     def get_protocol2(self, transport):
-        wrapped_proto = TCompactProtocol.TCompactProtocolAcceleratedFactory(fallback=False).getProtocol(transport)
+        wrapped_proto = make_pedantic(TCompactProtocol.TCompactProtocolAcceleratedFactory(fallback=False).getProtocol(transport))
         return TMultiplexedProtocol.TMultiplexedProtocol(wrapped_proto, "SecondService")
 
 
 class JSONTest(MultiplexedOptionalTest):
     def get_protocol(self, transport):
-        return TJSONProtocol.TJSONProtocolFactory().getProtocol(transport)
+        return make_pedantic(TJSONProtocol.TJSONProtocolFactory().getProtocol(transport))
 
 
 class MultiplexedJSONTest(MultiplexedOptionalTest):
     def get_protocol(self, transport):
-        wrapped_proto = TJSONProtocol.TJSONProtocolFactory().getProtocol(transport)
+        wrapped_proto = make_pedantic(TJSONProtocol.TJSONProtocolFactory().getProtocol(transport))
         return TMultiplexedProtocol.TMultiplexedProtocol(wrapped_proto, "ThriftTest")
 
     def get_protocol2(self, transport):
-        wrapped_proto = TJSONProtocol.TJSONProtocolFactory().getProtocol(transport)
+        wrapped_proto = make_pedantic(TJSONProtocol.TJSONProtocolFactory().getProtocol(transport))
         return TMultiplexedProtocol.TMultiplexedProtocol(wrapped_proto, "SecondService")
 
 
 class HeaderTest(MultiplexedOptionalTest):
     def get_protocol(self, transport):
         factory = THeaderProtocol.THeaderProtocolFactory()
-        return factory.getProtocol(transport)
+        return make_pedantic(factory.getProtocol(transport))
 
 
 def suite():
@@ -424,7 +466,6 @@
 
     if options.genpydir:
         sys.path.insert(0, os.path.join(SCRIPT_DIR, options.genpydir))
-    sys.path.insert(0, local_libpath())
 
     if options.http_path:
         options.trans = 'http'