Merging EOFException changes from Ben Maurer

Summary: Throw a proper EOFError in this case. Long term we want to change this to properly fit into the Thrift TException heirarchy with a good way to handle the original exception as well. For now, Ben is the primary user of this so we'll go ahead with his patch.

Reviewed By: mcslee

Test Plan: Included in test/py/TestEof.py


git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@665365 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/py/src/protocol/fastbinary.c b/lib/py/src/protocol/fastbinary.c
index 61ccd8f..3b7b5c0 100644
--- a/lib/py/src/protocol/fastbinary.c
+++ b/lib/py/src/protocol/fastbinary.c
@@ -684,6 +684,9 @@
 static bool
 checkTypeByte(DecodeBuffer* input, TType expected) {
   TType got = readByte(input);
+  if (INT_CONV_ERROR_OCCURRED(got)) {
+    return false;
+  }
 
   if (expected != got) {
     PyErr_SetString(PyExc_TypeError, "got wrong ttype while reading field");
@@ -829,11 +832,16 @@
     StructItemSpec parsedspec;
 
     type = readByte(input);
+    if (type == -1) {
+      return false;
+    }
     if (type == T_STOP) {
       break;
     }
     tag = readI16(input);
-
+    if (INT_CONV_ERROR_OCCURRED(tag)) {
+      return false;
+    }
     if (tag >= 0 && tag < spec_seq_len) {
       item_spec = PyTuple_GET_ITEM(spec_seq, tag);
     } else {
diff --git a/lib/py/src/transport/TTransport.py b/lib/py/src/transport/TTransport.py
index 60f4233..3968a71 100644
--- a/lib/py/src/transport/TTransport.py
+++ b/lib/py/src/transport/TTransport.py
@@ -47,6 +47,10 @@
       chunk = self.read(sz-have)
       have += len(chunk)
       buff += chunk
+
+      if len(chunk) == 0:
+        raise EOFError()
+
     return buff
 
   def write(self, buf):
@@ -213,7 +217,7 @@
 
   def cstringio_refill(self, partialread, reqlen):
     # only one shot at reading...
-    raise EOFException()
+    raise EOFError()
 
 class TFramedTransportFactory:
 
diff --git a/test/py/TestEof.py b/test/py/TestEof.py
new file mode 100644
index 0000000..a23a441
--- /dev/null
+++ b/test/py/TestEof.py
@@ -0,0 +1,101 @@
+#!/usr/bin/env python
+
+import sys, glob
+sys.path.insert(0, './gen-py')
+sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0])
+
+from ThriftTest import ThriftTest
+from ThriftTest.ttypes import *
+from thrift.transport import TTransport
+from thrift.transport import TSocket
+from thrift.protocol import TBinaryProtocol
+import unittest
+import time
+
+class TestEof(unittest.TestCase):
+
+  def setUp(self):
+    trans = TTransport.TMemoryBuffer()
+    prot = TBinaryProtocol.TBinaryProtocol(trans)
+
+    x = Xtruct()
+    x.string_thing = "Zero"
+    x.byte_thing = 0
+
+    x.write(prot)
+
+    x = Xtruct()
+    x.string_thing = "One"
+    x.byte_thing = 1
+
+    x.write(prot)
+
+    self.data = trans.getvalue()
+
+  def testTransportReadAll(self):
+    """Test that readAll on any type of transport throws an EOFError"""
+    trans = TTransport.TMemoryBuffer(self.data)
+    trans.readAll(1)
+
+    try:
+      trans.readAll(10000)
+    except EOFError:
+      return
+
+    self.fail("Should have gotten EOFError")
+
+  def eofTestHelper(self, pfactory):
+    trans = TTransport.TMemoryBuffer(self.data)
+    prot = pfactory.getProtocol(trans)
+
+    x = Xtruct()
+    x.read(prot)
+    self.assertEqual(x.string_thing, "Zero")
+    self.assertEqual(x.byte_thing, 0)
+
+    x = Xtruct()
+    x.read(prot)
+    self.assertEqual(x.string_thing, "One")
+    self.assertEqual(x.byte_thing, 1)
+
+    try:
+      x = Xtruct()
+      x.read(prot)
+    except EOFError:
+      return
+
+    self.fail("Should have gotten EOFError")
+
+  def eofTestHelperStress(self, pfactory):
+    """Teest the ability of TBinaryProtocol to deal with the removal of every byte in the file"""
+    # TODO: we should make sure this covers more of the code paths
+
+    for i in xrange(0, len(self.data) + 1):
+      trans = TTransport.TMemoryBuffer(self.data[0:i])
+      prot = pfactory.getProtocol(trans)
+      try:
+        x = Xtruct()
+        x.read(prot)
+        x.read(prot)
+        x.read(prot)
+      except EOFError:
+        continue
+      self.fail("Should have gotten an EOFError")
+
+  def testBinaryProtocolEof(self):
+    """Test that TBinaryProtocol throws an EOFError when it reaches the end of the stream"""
+    self.eofTestHelper(TBinaryProtocol.TBinaryProtocolFactory())
+    self.eofTestHelperStress(TBinaryProtocol.TBinaryProtocolFactory())
+
+  def testBinaryProtocolAcceleratedEof(self):
+    """Test that TBinaryProtocolAccelerated throws an EOFError when it reaches the end of the stream"""
+    self.eofTestHelper(TBinaryProtocol.TBinaryProtocolAcceleratedFactory())
+    self.eofTestHelperStress(TBinaryProtocol.TBinaryProtocolAcceleratedFactory())
+
+suite = unittest.TestSuite()
+loader = unittest.TestLoader()
+
+suite.addTest(loader.loadTestsFromTestCase(TestEof))
+
+testRunner = unittest.TextTestRunner(verbosity=2)
+testRunner.run(suite)