blob: f7d05ff975db41fbbde491a2802e7d2369f61951 [file] [log] [blame]
Zezeng Wangc3728122020-04-27 15:48:19 +08001#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18#
19
import unittest

import _import_local_thrift  # noqa
from thrift.Thrift import TType
from thrift.protocol.TBinaryProtocol import TBinaryProtocol
from thrift.transport import TTransport
25
26
def testNaked(type, data):
    """Round-trip a single bare value through TBinaryProtocol.

    *data* is serialized with the protocol's ``write<Kind>`` method into an
    in-memory buffered transport. It is then deserialized with the matching
    ``read<Kind>`` method from a fresh transport over the written bytes.

    Args:
        type: Name of the value kind, normalized with ``str.capitalize()``
            (so matching is case-insensitive). Supported kinds are Bool,
            Byte, I16, I32, I64, Double, String and Binary.
        data: The value to serialize.

    Returns:
        The value read back from the serialized bytes.

    Raises:
        ValueError: If *type* does not name a supported kind. This makes a
            typo fail loudly instead of round-tripping nothing.
    """
    # ``type`` shadows the builtin but is kept for caller compatibility.
    kind = type.capitalize()
    if kind not in ('Bool', 'Byte', 'I16', 'I32', 'I64',
                    'Double', 'String', 'Binary'):
        raise ValueError('unsupported type: %r' % (type,))

    buf = TTransport.TMemoryBuffer()
    transport = TTransport.TBufferedTransportFactory().getTransport(buf)
    getattr(TBinaryProtocol(transport), 'write' + kind)(data)
    # Flush so the buffered transport pushes the written bytes into ``buf``
    # before they are read back.
    transport.flush()

    # Read back through a brand-new transport/protocol stack built over the
    # serialized bytes, so nothing is shared with the writer side.
    buf = TTransport.TMemoryBuffer(buf.getvalue())
    transport = TTransport.TBufferedTransportFactory().getTransport(buf)
    return getattr(TBinaryProtocol(transport), 'read' + kind)()
83
84
def testField(type, data):
    """Round-trip a value wrapped as a single field of a struct.

    A struct containing one field (id 10) of the wire type that matches
    *type* is serialized with TBinaryProtocol. The field is terminated by
    the STOP marker. The struct is then read back from a fresh transport:
    the field header and the trailing STOP marker are verified, and the
    field value is returned.

    Args:
        type: Name of the value kind, normalized with ``str.capitalize()``.
            Supported kinds are Bool, Byte, I16, I32, I64, Double, String
            and Binary.
        data: The value to serialize as the field's payload.

    Returns:
        The field value read back from the serialized struct.

    Raises:
        KeyError: If *type* does not name a supported kind.
        AssertionError: If the decoded field header (type id, field id) or
            the STOP marker does not match what was written.
    """
    # Wire type ids come from thrift.Thrift.TType. Binary values share the
    # STRING wire type in Thrift.
    field_types = {
        'Bool': TType.BOOL,
        'Byte': TType.BYTE,
        'I16': TType.I16,
        'I32': TType.I32,
        'I64': TType.I64,
        'Double': TType.DOUBLE,
        'String': TType.STRING,
        'Binary': TType.STRING,
    }
    field_id = 10

    # ``type`` shadows the builtin but is kept for caller compatibility.
    kind = type.capitalize()
    ttype = field_types[kind]

    buf = TTransport.TMemoryBuffer()
    transport = TTransport.TBufferedTransportFactory().getTransport(buf)
    protocol = TBinaryProtocol(transport)
    protocol.writeStructBegin('struct')
    protocol.writeFieldBegin("field", ttype, field_id)
    getattr(protocol, 'write' + kind)(data)
    protocol.writeFieldEnd()
    # A struct's field list must be terminated by a STOP marker.
    protocol.writeFieldStop()
    protocol.writeStructEnd()
    # Flush so the buffered transport pushes the bytes into ``buf``.
    transport.flush()

    buf = TTransport.TMemoryBuffer(buf.getvalue())
    transport = TTransport.TBufferedTransportFactory().getTransport(buf)
    protocol = TBinaryProtocol(transport)
    protocol.readStructBegin()
    _, read_type, read_id = protocol.readFieldBegin()
    if (read_type, read_id) != (ttype, field_id):
        raise AssertionError(
            'field header mismatch: expected (%r, %r), got (%r, %r)'
            % (ttype, field_id, read_type, read_id))
    value = getattr(protocol, 'read' + kind)()
    protocol.readFieldEnd()
    # Consume the terminator so the whole struct is read, not just a prefix.
    _, stop_type, _ = protocol.readFieldBegin()
    if stop_type != TType.STOP:
        raise AssertionError('expected STOP marker, got type %r' % (stop_type,))
    protocol.readStructEnd()
    return value
153
154
def testMessage(data):
    """Round-trip a message header through TBinaryProtocol.

    Args:
        data: An indexable whose first three items are the message name,
            the message type code and the sequence id.

    Returns:
        Whatever ``readMessageBegin()`` produces for the serialized header.
    """
    name, msg_type, seq_id = data[0], data[1], data[2]

    # Writer side: serialize the header into an in-memory buffer.
    sink = TTransport.TMemoryBuffer()
    writer = TTransport.TBufferedTransportFactory().getTransport(sink)
    out_proto = TBinaryProtocol(writer)
    out_proto.writeMessageBegin(name, msg_type, seq_id)
    out_proto.writeMessageEnd()
    writer.flush()
    encoded = sink.getvalue()

    # Reader side: decode from a fresh transport over the written bytes.
    reader = TTransport.TBufferedTransportFactory().getTransport(
        TTransport.TMemoryBuffer(encoded))
    in_proto = TBinaryProtocol(reader)
    header = in_proto.readMessageBegin()
    in_proto.readMessageEnd()
    return header
176
177
class TestTBinaryProtocol(unittest.TestCase):
    """Round-trip tests for TBinaryProtocol over an in-memory transport."""

    def test_TBinaryProtocol_write_read(self):
        # Byte: the naked result is asserted, and the field checks cover
        # the whole signed range [-128, 127].
        self.assertEqual(123, testNaked('Byte', 123))
        for i in range(0, 128):
            self.assertEqual(i, testField('Byte', i))
            self.assertEqual(-i, testField('Byte', -i))
        self.assertEqual(-128, testField('Byte', -128))

        self.assertEqual(0, testNaked("I16", 0))
        self.assertEqual(1, testNaked("I16", 1))
        self.assertEqual(15000, testNaked("I16", 15000))
        self.assertEqual(0x7fff, testNaked("I16", 0x7fff))
        self.assertEqual(-1, testNaked("I16", -1))
        self.assertEqual(-15000, testNaked("I16", -15000))
        self.assertEqual(-0x7fff, testNaked("I16", -0x7fff))
        self.assertEqual(32767, testNaked("I16", 32767))
        self.assertEqual(-32768, testNaked("I16", -32768))

        self.assertEqual(0, testField("I16", 0))
        self.assertEqual(1, testField("I16", 1))
        self.assertEqual(7, testField("I16", 7))
        self.assertEqual(150, testField("I16", 150))
        self.assertEqual(15000, testField("I16", 15000))
        self.assertEqual(0x7fff, testField("I16", 0x7fff))
        self.assertEqual(-1, testField("I16", -1))
        self.assertEqual(-7, testField("I16", -7))
        self.assertEqual(-150, testField("I16", -150))
        self.assertEqual(-15000, testField("I16", -15000))
        self.assertEqual(-0xfff, testField("I16", -0xfff))
        self.assertEqual(-0x7fff, testField("I16", -0x7fff))
        self.assertEqual(-32768, testField("I16", -32768))

        self.assertEqual(0, testNaked("I32", 0))
        self.assertEqual(1, testNaked("I32", 1))
        self.assertEqual(15000, testNaked("I32", 15000))
        self.assertEqual(0xffff, testNaked("I32", 0xffff))
        self.assertEqual(-1, testNaked("I32", -1))
        self.assertEqual(-15000, testNaked("I32", -15000))
        self.assertEqual(-0xffff, testNaked("I32", -0xffff))
        self.assertEqual(2147483647, testNaked("I32", 2147483647))
        self.assertEqual(-2147483647, testNaked("I32", -2147483647))
        self.assertEqual(-2147483648, testNaked("I32", -2147483648))

        self.assertEqual(0, testField("I32", 0))
        self.assertEqual(1, testField("I32", 1))
        self.assertEqual(7, testField("I32", 7))
        self.assertEqual(150, testField("I32", 150))
        self.assertEqual(15000, testField("I32", 15000))
        self.assertEqual(31337, testField("I32", 31337))
        self.assertEqual(0xffff, testField("I32", 0xffff))
        self.assertEqual(0xffffff, testField("I32", 0xffffff))
        self.assertEqual(-1, testField("I32", -1))
        self.assertEqual(-7, testField("I32", -7))
        self.assertEqual(-150, testField("I32", -150))
        self.assertEqual(-15000, testField("I32", -15000))
        self.assertEqual(-0xffff, testField("I32", -0xffff))
        self.assertEqual(-0xffffff, testField("I32", -0xffffff))

        self.assertEqual(9223372036854775807, testNaked("I64", 9223372036854775807))
        self.assertEqual(-9223372036854775807, testNaked("I64", -9223372036854775807))
        self.assertEqual(-9223372036854775808, testNaked("I64", -9223372036854775808))
        self.assertEqual(0, testNaked("I64", 0))

        # Bare values.
        self.assertEqual(True, testNaked("Bool", True))
        self.assertEqual(False, testNaked("Bool", False))
        self.assertEqual(3.14159261, testNaked("Double", 3.14159261))
        self.assertEqual("hello thrift", testNaked("String", "hello thrift"))
        self.assertEqual(b"hello thrift", testNaked("Binary", b"hello thrift"))

        # The same kinds wrapped as a struct field.
        self.assertEqual(True, testField('Bool', True))
        self.assertEqual(3.1415926, testField("Double", 3.1415926))
        self.assertEqual("hello thrift", testField("String", "hello thrift"))
        self.assertEqual(b"hello thrift", testField("Binary", b"hello thrift"))

        # Message-type codes used for the header round-trips below.
        TMessageType = {"T_CALL": 1, "T_REPLY": 2, "T_EXCEPTION": 3, "T_ONEWAY": 4}
        test_data = [("short message name", TMessageType['T_CALL'], 0),
                     ("1", TMessageType['T_REPLY'], 12345),
                     ("loooooooooooooooooooooooooooooooooong", TMessageType['T_EXCEPTION'], 1 << 16),
                     ("one way push", TMessageType['T_ONEWAY'], 12),
                     ("Janky", TMessageType['T_CALL'], 0)]

        for dt in test_data:
            result = testMessage(dt)
            self.assertEqual(result[0], dt[0])
            self.assertEqual(result[1], dt[1])
            self.assertEqual(result[2], dt[2])
261
262
if __name__ == "__main__":
    # Run directly: let unittest collect and run this module's TestCases.
    unittest.main(module="__main__")