blob: 52bedd5e31f0cc6349d40c61456d7d26495a98f2 [file] [log] [blame]
Gavin McDonald0b75e1a2010-10-28 02:12:01 +00001#!/usr/bin/env python
2
3#
4# Licensed to the Apache Software Foundation (ASF) under one
5# or more contributor license agreements. See the NOTICE file
6# distributed with this work for additional information
7# regarding copyright ownership. The ASF licenses this file
8# to you under the Apache License, Version 2.0 (the
9# "License"); you may not use this file except in compliance
10# with the License. You may obtain a copy of the License at
11#
12# http://www.apache.org/licenses/LICENSE-2.0
13#
14# Unless required by applicable law or agreed to in writing,
15# software distributed under the License is distributed on an
16# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
17# KIND, either express or implied. See the License for the
18# specific language governing permissions and limitations
19# under the License.
20#
21
22import sys, glob
23sys.path.insert(0, './gen-py')
24sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0])
25
26from ThriftTest.ttypes import *
27from thrift.transport import TTransport
28from thrift.transport import TSocket
29from thrift.protocol import TBinaryProtocol
30import unittest
31import time
32
class AbstractTest(unittest.TestCase):
    """Forward/backward compatibility checks between VersioningTestV1 and V2.

    Concrete subclasses set ``protocol_factory`` to the protocol under test.
    When this base class is collected directly (e.g. by test discovery), no
    factory is set, so its tests are skipped instead of erroring.
    """

    # Protocol factory used by _serialize/_deserialize; None marks the
    # abstract base, which must not run its tests itself.
    protocol_factory = None

    def setUp(self):
        """Build one V1 and one V2 struct that share their common fields."""
        if self.protocol_factory is None:
            # Checked first so nothing else in setUp runs for the abstract base.
            self.skipTest("AbstractTest requires a protocol_factory; run a subclass")

        self.v1obj = VersioningTestV1(
            begin_in_both=12345,
            old_string='aaa',
            end_in_both=54321,
        )

        self.v2obj = VersioningTestV2(
            begin_in_both=12345,
            newint=1,
            newbyte=2,
            newshort=3,
            newlong=4,
            newdouble=5.0,
            newstruct=Bonk(message="Hello!", type=123),
            newlist=[7, 8, 9],
            newset=[42, 1, 8],
            newmap={1: 2, 2: 3},
            newstring="Hola!",
            end_in_both=54321,
        )

    def _serialize(self, obj):
        """Write ``obj`` with ``self.protocol_factory`` and return the raw buffer contents."""
        trans = TTransport.TMemoryBuffer()
        prot = self.protocol_factory.getProtocol(trans)
        obj.write(prot)
        return trans.getvalue()

    def _deserialize(self, objtype, data):
        """Construct a fresh ``objtype()`` and populate it by reading ``data``."""
        prot = self.protocol_factory.getProtocol(TTransport.TMemoryBuffer(data))
        ret = objtype()
        ret.read(prot)
        return ret

    def testForwards(self):
        """Bytes written from a V1 struct read back as V2 with the shared fields intact."""
        obj = self._deserialize(VersioningTestV2, self._serialize(self.v1obj))
        self.assertEqual(obj.begin_in_both, self.v1obj.begin_in_both)
        self.assertEqual(obj.end_in_both, self.v1obj.end_in_both)

    def testBackwards(self):
        """Bytes written from a V2 struct read back as V1 with the shared fields intact."""
        obj = self._deserialize(VersioningTestV1, self._serialize(self.v2obj))
        self.assertEqual(obj.begin_in_both, self.v2obj.begin_in_both)
        self.assertEqual(obj.end_in_both, self.v2obj.end_in_both)
78
79
class NormalBinaryTest(AbstractTest):
    """Runs the versioning round-trip checks through the plain binary protocol factory."""

    protocol_factory = TBinaryProtocol.TBinaryProtocolFactory()
82
class AcceleratedBinaryTest(AbstractTest):
    """Runs the versioning round-trip checks through the accelerated binary protocol factory."""

    protocol_factory = TBinaryProtocol.TBinaryProtocolAcceleratedFactory()
85
86
class AcceleratedFramedTest(unittest.TestCase):
    """Interaction tests for TBinaryProtocolAccelerated over TFramedTransport."""

    def testSplit(self):
        """Test FramedTransport and BinaryProtocolAccelerated.

        Tests that TBinaryProtocolAccelerated and TFramedTransport
        play nicely together when a read spans a frame.
        """
        protocol_factory = TBinaryProtocol.TBinaryProtocolAcceleratedFactory()
        bigstring = "".join(chr(byte) for byte in range(ord("a"), ord("z") + 1))

        # Serialize an i32, a string and an i16 into an unframed buffer.
        databuf = TTransport.TMemoryBuffer()
        prot = protocol_factory.getProtocol(databuf)
        prot.writeI32(42)
        prot.writeString(bigstring)
        prot.writeI16(24)
        data = databuf.getvalue()
        # Floor division: `/` yields a float on Python 3, and slicing with a
        # float index raises TypeError.
        cutpoint = len(data) // 2
        parts = [data[:cutpoint], data[cutpoint:]]

        # Flush after each half so the payload is split across two frames and
        # the reads below have to span a frame boundary.
        framed_buffer = TTransport.TMemoryBuffer()
        framed_writer = TTransport.TFramedTransport(framed_buffer)
        for part in parts:
            framed_writer.write(part)
            framed_writer.flush()
        # Two frames, each adding a 4-byte length prefix.
        self.assertEqual(len(framed_buffer.getvalue()), len(data) + 8)

        # Recreate framed_buffer so we can read from it.
        framed_buffer = TTransport.TMemoryBuffer(framed_buffer.getvalue())
        framed_reader = TTransport.TFramedTransport(framed_buffer)
        prot = protocol_factory.getProtocol(framed_reader)
        self.assertEqual(prot.readI32(), 42)
        self.assertEqual(prot.readString(), bigstring)
        self.assertEqual(prot.readI16(), 24)
120
121
122
def suite():
    """Collect NormalBinaryTest, AcceleratedBinaryTest and AcceleratedFramedTest into one TestSuite."""
    combined = unittest.TestSuite()
    test_loader = unittest.TestLoader()
    for case_class in (NormalBinaryTest, AcceleratedBinaryTest, AcceleratedFramedTest):
        combined.addTest(test_loader.loadTestsFromTestCase(case_class))
    return combined
131
if __name__ == "__main__":
    # Run the module-level suite() with per-test (verbosity=2) output.
    verbose_runner = unittest.TextTestRunner(verbosity=2)
    unittest.main(defaultTest="suite", testRunner=verbose_runner)