| #!/usr/bin/env python |
| |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # |
| |
| from ThriftTest.ttypes import Xtruct |
| from thrift.transport import TTransport |
| from thrift.protocol import TBinaryProtocol |
| from thrift.protocol import TCompactProtocol |
| import unittest |
| |
| |
| class TestEof(unittest.TestCase): |
| |
| def make_data(self, pfactory=None): |
| trans = TTransport.TMemoryBuffer() |
| if pfactory: |
| prot = pfactory.getProtocol(trans) |
| else: |
| prot = TBinaryProtocol.TBinaryProtocol(trans) |
| |
| x = Xtruct() |
| x.string_thing = "Zero" |
| x.byte_thing = 0 |
| |
| x.write(prot) |
| |
| x = Xtruct() |
| x.string_thing = "One" |
| x.byte_thing = 1 |
| |
| x.write(prot) |
| |
| return trans.getvalue() |
| |
| def testTransportReadAll(self): |
| """Test that readAll on any type of transport throws an EOFError""" |
| trans = TTransport.TMemoryBuffer(self.make_data()) |
| trans.readAll(1) |
| |
| try: |
| trans.readAll(10000) |
| except EOFError: |
| return |
| |
| self.fail("Should have gotten EOFError") |
| |
| def eofTestHelper(self, pfactory): |
| trans = TTransport.TMemoryBuffer(self.make_data(pfactory)) |
| prot = pfactory.getProtocol(trans) |
| |
| x = Xtruct() |
| x.read(prot) |
| self.assertEqual(x.string_thing, "Zero") |
| self.assertEqual(x.byte_thing, 0) |
| |
| x = Xtruct() |
| x.read(prot) |
| self.assertEqual(x.string_thing, "One") |
| self.assertEqual(x.byte_thing, 1) |
| |
| try: |
| x = Xtruct() |
| x.read(prot) |
| except EOFError: |
| return |
| |
| self.fail("Should have gotten EOFError") |
| |
| def eofTestHelperStress(self, pfactory): |
| """Test the ability of TBinaryProtocol to deal with the removal of every byte in the file""" |
| # TODO: we should make sure this covers more of the code paths |
| |
| data = self.make_data(pfactory) |
| for i in range(0, len(data) + 1): |
| trans = TTransport.TMemoryBuffer(data[0:i]) |
| prot = pfactory.getProtocol(trans) |
| try: |
| x = Xtruct() |
| x.read(prot) |
| x.read(prot) |
| x.read(prot) |
| except EOFError: |
| continue |
| self.fail("Should have gotten an EOFError") |
| |
| def testBinaryProtocolEof(self): |
| """Test that TBinaryProtocol throws an EOFError when it reaches the end of the stream""" |
| self.eofTestHelper(TBinaryProtocol.TBinaryProtocolFactory()) |
| self.eofTestHelperStress(TBinaryProtocol.TBinaryProtocolFactory()) |
| |
| def testBinaryProtocolAcceleratedBinaryEof(self): |
| """Test that TBinaryProtocolAccelerated throws an EOFError when it reaches the end of the stream""" |
| self.eofTestHelper(TBinaryProtocol.TBinaryProtocolAcceleratedFactory(fallback=False)) |
| self.eofTestHelperStress(TBinaryProtocol.TBinaryProtocolAcceleratedFactory(fallback=False)) |
| |
| def testCompactProtocolEof(self): |
| """Test that TCompactProtocol throws an EOFError when it reaches the end of the stream""" |
| self.eofTestHelper(TCompactProtocol.TCompactProtocolFactory()) |
| self.eofTestHelperStress(TCompactProtocol.TCompactProtocolFactory()) |
| |
| def testCompactProtocolAcceleratedCompactEof(self): |
| """Test that TCompactProtocolAccelerated throws an EOFError when it reaches the end of the stream""" |
| self.eofTestHelper(TCompactProtocol.TCompactProtocolAcceleratedFactory(fallback=False)) |
| self.eofTestHelperStress(TCompactProtocol.TCompactProtocolAcceleratedFactory(fallback=False)) |
| |
| |
| def suite(): |
| suite = unittest.TestSuite() |
| loader = unittest.TestLoader() |
| suite.addTest(loader.loadTestsFromTestCase(TestEof)) |
| return suite |
| |
| if __name__ == "__main__": |
| unittest.main(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=2)) |