THRIFT-212. python: Make TFramedTransport implement CReadableTransport
This involved adding a few methods to provide lower-level access to the
internal read buffer. This will allow us to use TBinaryProtocolAccelerated
with TFramedTransport.
git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@739632 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/py/src/transport/TTransport.py b/lib/py/src/transport/TTransport.py
index ddb368f..67f97ba 100644
--- a/lib/py/src/transport/TTransport.py
+++ b/lib/py/src/transport/TTransport.py
@@ -228,7 +228,7 @@
return framed
-class TFramedTransport(TTransportBase):
+class TFramedTransport(TTransportBase, CReadableTransport):
"""Class that wraps another transport and frames its I/O when writing."""
@@ -274,3 +274,18 @@
buf = pack("!i", wsz) + wout
self.__trans.write(buf)
self.__trans.flush()
+
+ # Implement the CReadableTransport interface.
+ @property
+ def cstringio_buf(self):
+ return self.__rbuf
+
+ def cstringio_refill(self, prefix, reqlen):
+ # self.__rbuf will already be empty here because fastbinary doesn't
+ # ask for a refill until the previous buffer is empty. Therefore,
+ # we can start reading new frames immediately.
+ while len(prefix) < reqlen:
+ readFrame()
+ prefix += self.__rbuf.getvalue()
+ self.__rbuf = StringIO(prefix)
+ return self.__rbuf
diff --git a/test/py/SerializationTest.py b/test/py/SerializationTest.py
index 4be8b8c..a99bce6 100755
--- a/test/py/SerializationTest.py
+++ b/test/py/SerializationTest.py
@@ -64,12 +64,49 @@
protocol_factory = TBinaryProtocol.TBinaryProtocolAcceleratedFactory()
+class AcceleratedFramedTest(unittest.TestCase):
+ def testSplit(self):
+ """Test FramedTransport and BinaryProtocolAccelerated
+
+ Tests that TBinaryProtocolAccelerated and TFramedTransport
+ play nicely together when a read spans a frame"""
+
+ protocol_factory = TBinaryProtocol.TBinaryProtocolAcceleratedFactory()
+ bigstring = "".join(chr(byte) for byte in range(ord("a"), ord("z")+1))
+
+ databuf = TTransport.TMemoryBuffer()
+ prot = protocol_factory.getProtocol(databuf)
+ prot.writeI32(42)
+ prot.writeString(bigstring)
+ prot.writeI16(24)
+ data = databuf.getvalue()
+ cutpoint = len(data)/2
+ parts = [ data[:cutpoint], data[cutpoint:] ]
+
+ framed_buffer = TTransport.TMemoryBuffer()
+ framed_writer = TTransport.TFramedTransport(framed_buffer)
+ for part in parts:
+ framed_writer.write(part)
+ framed_writer.flush()
+ self.assertEquals(len(framed_buffer.getvalue()), len(data) + 8)
+
+ # Recreate framed_buffer so we can read from it.
+ framed_buffer = TTransport.TMemoryBuffer(framed_buffer.getvalue())
+ framed_reader = TTransport.TFramedTransport(framed_buffer)
+ prot = protocol_factory.getProtocol(framed_reader)
+ self.assertEqual(prot.readI32(), 42)
+ self.assertEqual(prot.readString(), bigstring)
+ self.assertEqual(prot.readI16(), 24)
+
+
+
def suite():
suite = unittest.TestSuite()
loader = unittest.TestLoader()
suite.addTest(loader.loadTestsFromTestCase(NormalBinaryTest))
suite.addTest(loader.loadTestsFromTestCase(AcceleratedBinaryTest))
+ suite.addTest(loader.loadTestsFromTestCase(AcceleratedFramedTest))
return suite
if __name__ == "__main__":