THRIFT-4405: Enhance python cross test client for pedantic sequence id handling
diff --git a/test/py/TestClient.py b/test/py/TestClient.py
index ddcce8d..cc9185c 100755
--- a/test/py/TestClient.py
+++ b/test/py/TestClient.py
@@ -23,10 +23,15 @@
import sys
import time
import unittest
-from optparse import OptionParser
+from optparse import OptionParser
from util import local_libpath
+sys.path.insert(0, local_libpath())
+
+from thrift.protocol import TProtocolDecorator
+from thrift.protocol import TProtocol
+
SCRIPT_DIR = os.path.abspath(os.path.dirname(__file__))
@@ -268,6 +273,43 @@
self.assertEqual(self.client.testString('Python'), 'Python')
+# LAST_SEQID is a global because we have one transport and multiple protocols
+# running on it (when multiplexec)
+LAST_SEQID = None
+
+class TPedanticSequenceIdProtocolWrapper(TProtocolDecorator.TProtocolDecorator):
+ """
+ Wraps any protocol with sequence ID checking: looks for outbound
+ uniqueness as well as request/response alignment.
+ """
+ def __init__(self, protocol):
+ # TProtocolDecorator.__new__ does all the heavy lifting
+ pass
+
+ def writeMessageBegin(self, name, type, seqid):
+ global LAST_SEQID
+ if LAST_SEQID and LAST_SEQID == seqid:
+ raise TProtocol.TProtocolException(INVALID_DATA,
+ "Python client reused sequence ID {0}".format(seqid))
+ LAST_SEQID = seqid
+ super(TPedanticSequenceIdProtocolWrapper, self).writeMessageBegin(
+ name, type, seqid)
+
+ def readMessageBegin(self):
+ global LAST_SEQID
+ (name, type, seqid) =\
+ super(TPedanticSequenceIdProtocolWrapper, self).readMessageBegin()
+ if LAST_SEQID != seqid:
+ raise TProtocol.TProtocolException(INVALID_DATA,
+ "We sent seqid {0} and server returned seqid {1}".format(
+ self.last, seqid))
+ return (name, type, seqid)
+
+
+def make_pedantic(proto):
+ """ Wrap a protocol in the pedantic sequence ID wrapper. """
+ return TPedanticSequenceIdProtocolWrapper(proto)
+
class MultiplexedOptionalTest(AbstractTest):
def get_protocol2(self, transport):
return None
@@ -275,83 +317,83 @@
class BinaryTest(MultiplexedOptionalTest):
def get_protocol(self, transport):
- return TBinaryProtocol.TBinaryProtocolFactory().getProtocol(transport)
+ return make_pedantic(TBinaryProtocol.TBinaryProtocolFactory().getProtocol(transport))
class MultiplexedBinaryTest(MultiplexedOptionalTest):
def get_protocol(self, transport):
- wrapped_proto = TBinaryProtocol.TBinaryProtocolFactory().getProtocol(transport)
+ wrapped_proto = make_pedantic(TBinaryProtocol.TBinaryProtocolFactory().getProtocol(transport))
return TMultiplexedProtocol.TMultiplexedProtocol(wrapped_proto, "ThriftTest")
def get_protocol2(self, transport):
- wrapped_proto = TBinaryProtocol.TBinaryProtocolFactory().getProtocol(transport)
+ wrapped_proto = make_pedantic(TBinaryProtocol.TBinaryProtocolFactory().getProtocol(transport))
return TMultiplexedProtocol.TMultiplexedProtocol(wrapped_proto, "SecondService")
class AcceleratedBinaryTest(MultiplexedOptionalTest):
def get_protocol(self, transport):
- return TBinaryProtocol.TBinaryProtocolAcceleratedFactory(fallback=False).getProtocol(transport)
+ return make_pedantic(TBinaryProtocol.TBinaryProtocolAcceleratedFactory(fallback=False).getProtocol(transport))
class MultiplexedAcceleratedBinaryTest(MultiplexedOptionalTest):
def get_protocol(self, transport):
- wrapped_proto = TBinaryProtocol.TBinaryProtocolAcceleratedFactory(fallback=False).getProtocol(transport)
+ wrapped_proto = make_pedantic(TBinaryProtocol.TBinaryProtocolAcceleratedFactory(fallback=False).getProtocol(transport))
return TMultiplexedProtocol.TMultiplexedProtocol(wrapped_proto, "ThriftTest")
def get_protocol2(self, transport):
- wrapped_proto = TBinaryProtocol.TBinaryProtocolAcceleratedFactory(fallback=False).getProtocol(transport)
+ wrapped_proto = make_pedantic(TBinaryProtocol.TBinaryProtocolAcceleratedFactory(fallback=False).getProtocol(transport))
return TMultiplexedProtocol.TMultiplexedProtocol(wrapped_proto, "SecondService")
class CompactTest(MultiplexedOptionalTest):
def get_protocol(self, transport):
- return TCompactProtocol.TCompactProtocolFactory().getProtocol(transport)
+ return make_pedantic(TCompactProtocol.TCompactProtocolFactory().getProtocol(transport))
class MultiplexedCompactTest(MultiplexedOptionalTest):
def get_protocol(self, transport):
- wrapped_proto = TCompactProtocol.TCompactProtocolFactory().getProtocol(transport)
+ wrapped_proto = make_pedantic(TCompactProtocol.TCompactProtocolFactory().getProtocol(transport))
return TMultiplexedProtocol.TMultiplexedProtocol(wrapped_proto, "ThriftTest")
def get_protocol2(self, transport):
- wrapped_proto = TCompactProtocol.TCompactProtocolFactory().getProtocol(transport)
+ wrapped_proto = make_pedantic(TCompactProtocol.TCompactProtocolFactory().getProtocol(transport))
return TMultiplexedProtocol.TMultiplexedProtocol(wrapped_proto, "SecondService")
class AcceleratedCompactTest(MultiplexedOptionalTest):
def get_protocol(self, transport):
- return TCompactProtocol.TCompactProtocolAcceleratedFactory(fallback=False).getProtocol(transport)
+ return make_pedantic(TCompactProtocol.TCompactProtocolAcceleratedFactory(fallback=False).getProtocol(transport))
class MultiplexedAcceleratedCompactTest(MultiplexedOptionalTest):
def get_protocol(self, transport):
- wrapped_proto = TCompactProtocol.TCompactProtocolAcceleratedFactory(fallback=False).getProtocol(transport)
+ wrapped_proto = make_pedantic(TCompactProtocol.TCompactProtocolAcceleratedFactory(fallback=False).getProtocol(transport))
return TMultiplexedProtocol.TMultiplexedProtocol(wrapped_proto, "ThriftTest")
def get_protocol2(self, transport):
- wrapped_proto = TCompactProtocol.TCompactProtocolAcceleratedFactory(fallback=False).getProtocol(transport)
+ wrapped_proto = make_pedantic(TCompactProtocol.TCompactProtocolAcceleratedFactory(fallback=False).getProtocol(transport))
return TMultiplexedProtocol.TMultiplexedProtocol(wrapped_proto, "SecondService")
class JSONTest(MultiplexedOptionalTest):
def get_protocol(self, transport):
- return TJSONProtocol.TJSONProtocolFactory().getProtocol(transport)
+ return make_pedantic(TJSONProtocol.TJSONProtocolFactory().getProtocol(transport))
class MultiplexedJSONTest(MultiplexedOptionalTest):
def get_protocol(self, transport):
- wrapped_proto = TJSONProtocol.TJSONProtocolFactory().getProtocol(transport)
+ wrapped_proto = make_pedantic(TJSONProtocol.TJSONProtocolFactory().getProtocol(transport))
return TMultiplexedProtocol.TMultiplexedProtocol(wrapped_proto, "ThriftTest")
def get_protocol2(self, transport):
- wrapped_proto = TJSONProtocol.TJSONProtocolFactory().getProtocol(transport)
+ wrapped_proto = make_pedantic(TJSONProtocol.TJSONProtocolFactory().getProtocol(transport))
return TMultiplexedProtocol.TMultiplexedProtocol(wrapped_proto, "SecondService")
class HeaderTest(MultiplexedOptionalTest):
def get_protocol(self, transport):
factory = THeaderProtocol.THeaderProtocolFactory()
- return factory.getProtocol(transport)
+ return make_pedantic(factory.getProtocol(transport))
def suite():
@@ -424,7 +466,6 @@
if options.genpydir:
sys.path.insert(0, os.path.join(SCRIPT_DIR, options.genpydir))
- sys.path.insert(0, local_libpath())
if options.http_path:
options.trans = 'http'