Merging EOFException changes from Ben Maurer
Summary: Throw a proper EOFError in this case. Long term we want to change this to properly fit into the Thrift TException heirarchy with a good way to handle the original exception as well. For now, Ben is the primary user of this so we'll go ahead with his patch.
Reviewed By: mcslee
Test Plan: Included in test/py/TestEof.py
git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@665365 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/py/src/protocol/fastbinary.c b/lib/py/src/protocol/fastbinary.c
index 61ccd8f..3b7b5c0 100644
--- a/lib/py/src/protocol/fastbinary.c
+++ b/lib/py/src/protocol/fastbinary.c
@@ -684,6 +684,9 @@
static bool
checkTypeByte(DecodeBuffer* input, TType expected) {
TType got = readByte(input);
+ if (INT_CONV_ERROR_OCCURRED(got)) {
+ return false;
+ }
if (expected != got) {
PyErr_SetString(PyExc_TypeError, "got wrong ttype while reading field");
@@ -829,11 +832,16 @@
StructItemSpec parsedspec;
type = readByte(input);
+ if (type == -1) {
+ return false;
+ }
if (type == T_STOP) {
break;
}
tag = readI16(input);
-
+ if (INT_CONV_ERROR_OCCURRED(tag)) {
+ return false;
+ }
if (tag >= 0 && tag < spec_seq_len) {
item_spec = PyTuple_GET_ITEM(spec_seq, tag);
} else {
diff --git a/lib/py/src/transport/TTransport.py b/lib/py/src/transport/TTransport.py
index 60f4233..3968a71 100644
--- a/lib/py/src/transport/TTransport.py
+++ b/lib/py/src/transport/TTransport.py
@@ -47,6 +47,10 @@
chunk = self.read(sz-have)
have += len(chunk)
buff += chunk
+
+ if len(chunk) == 0:
+ raise EOFError()
+
return buff
def write(self, buf):
@@ -213,7 +217,7 @@
def cstringio_refill(self, partialread, reqlen):
# only one shot at reading...
- raise EOFException()
+ raise EOFError()
class TFramedTransportFactory:
diff --git a/test/py/TestEof.py b/test/py/TestEof.py
new file mode 100644
index 0000000..a23a441
--- /dev/null
+++ b/test/py/TestEof.py
@@ -0,0 +1,101 @@
+#!/usr/bin/env python
+
+import sys, glob
+sys.path.insert(0, './gen-py')
+sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0])
+
+from ThriftTest import ThriftTest
+from ThriftTest.ttypes import *
+from thrift.transport import TTransport
+from thrift.transport import TSocket
+from thrift.protocol import TBinaryProtocol
+import unittest
+import time
+
+class TestEof(unittest.TestCase):
+
+ def setUp(self):
+ trans = TTransport.TMemoryBuffer()
+ prot = TBinaryProtocol.TBinaryProtocol(trans)
+
+ x = Xtruct()
+ x.string_thing = "Zero"
+ x.byte_thing = 0
+
+ x.write(prot)
+
+ x = Xtruct()
+ x.string_thing = "One"
+ x.byte_thing = 1
+
+ x.write(prot)
+
+ self.data = trans.getvalue()
+
+ def testTransportReadAll(self):
+ """Test that readAll on any type of transport throws an EOFError"""
+ trans = TTransport.TMemoryBuffer(self.data)
+ trans.readAll(1)
+
+ try:
+ trans.readAll(10000)
+ except EOFError:
+ return
+
+ self.fail("Should have gotten EOFError")
+
+ def eofTestHelper(self, pfactory):
+ trans = TTransport.TMemoryBuffer(self.data)
+ prot = pfactory.getProtocol(trans)
+
+ x = Xtruct()
+ x.read(prot)
+ self.assertEqual(x.string_thing, "Zero")
+ self.assertEqual(x.byte_thing, 0)
+
+ x = Xtruct()
+ x.read(prot)
+ self.assertEqual(x.string_thing, "One")
+ self.assertEqual(x.byte_thing, 1)
+
+ try:
+ x = Xtruct()
+ x.read(prot)
+ except EOFError:
+ return
+
+ self.fail("Should have gotten EOFError")
+
+ def eofTestHelperStress(self, pfactory):
+ """Teest the ability of TBinaryProtocol to deal with the removal of every byte in the file"""
+ # TODO: we should make sure this covers more of the code paths
+
+ for i in xrange(0, len(self.data) + 1):
+ trans = TTransport.TMemoryBuffer(self.data[0:i])
+ prot = pfactory.getProtocol(trans)
+ try:
+ x = Xtruct()
+ x.read(prot)
+ x.read(prot)
+ x.read(prot)
+ except EOFError:
+ continue
+ self.fail("Should have gotten an EOFError")
+
+ def testBinaryProtocolEof(self):
+ """Test that TBinaryProtocol throws an EOFError when it reaches the end of the stream"""
+ self.eofTestHelper(TBinaryProtocol.TBinaryProtocolFactory())
+ self.eofTestHelperStress(TBinaryProtocol.TBinaryProtocolFactory())
+
+ def testBinaryProtocolAcceleratedEof(self):
+ """Test that TBinaryProtocolAccelerated throws an EOFError when it reaches the end of the stream"""
+ self.eofTestHelper(TBinaryProtocol.TBinaryProtocolAcceleratedFactory())
+ self.eofTestHelperStress(TBinaryProtocol.TBinaryProtocolAcceleratedFactory())
+
+suite = unittest.TestSuite()
+loader = unittest.TestLoader()
+
+suite.addTest(loader.loadTestsFromTestCase(TestEof))
+
+testRunner = unittest.TextTestRunner(verbosity=2)
+testRunner.run(suite)