blob: b257626b1f8a7d79e9aadaccfe818440c2d79075 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
19
20import unittest
21
22import _import_local_thrift # noqa
23from thrift.protocol.TBinaryProtocol import TBinaryProtocol
24from thrift.transport import TTransport
25
26
def testNaked(type, data):
    """Serialize *data* as a bare value of the given type and read it back.

    The value is written with the matching ``TBinaryProtocol.write*`` method
    into an in-memory buffer, flushed, and then decoded with the matching
    ``read*`` method through a fresh buffered transport.

    :param type: scalar type name, case-insensitive (normalized with
        ``str.capitalize``): ``Byte``, ``I16``, ``I32``, ``I64``, ``String``,
        ``Double``, ``Binary`` or ``Bool``. The parameter name shadows the
        ``type`` builtin; it is kept for caller compatibility.
    :param data: the value to write.
    :returns: the value produced by the matching ``read*`` call.
    :raises ValueError: if *type* is not one of the supported names (an
        unknown name used to write nothing and silently return ``None``).
    """
    kind = type.capitalize()
    supported = ('Byte', 'I16', 'I32', 'I64', 'String', 'Double', 'Binary', 'Bool')
    if kind not in supported:
        raise ValueError('unsupported Thrift type: %r' % (type,))

    buf = TTransport.TMemoryBuffer()
    transport = TTransport.TBufferedTransportFactory().getTransport(buf)
    protocol = TBinaryProtocol(transport)
    # e.g. 'I16' -> protocol.writeI16(data)
    getattr(protocol, 'write' + kind)(data)
    transport.flush()

    # Decode from a brand-new buffer holding exactly the bytes written above.
    buf = TTransport.TMemoryBuffer(buf.getvalue())
    transport = TTransport.TBufferedTransportFactory().getTransport(buf)
    protocol = TBinaryProtocol(transport)
    return getattr(protocol, 'read' + kind)()
83
84
def testField(type, data):
    """Round-trip *data* as the single field of a struct and return it.

    The value is written as field id 10 of a struct, tagged with the wire
    type id that matches *type* and followed by a field-stop marker. It is
    then read back through a fresh buffered transport, and the field header
    and trailing stop marker are checked.

    :param type: scalar type name, case-insensitive (normalized with
        ``str.capitalize``): ``Bool``, ``Byte``, ``Double``, ``I16``, ``I32``,
        ``I64``, ``String`` or ``Binary``. The parameter name shadows the
        ``type`` builtin; it is kept for caller compatibility.
    :param data: the value to write.
    :returns: the value decoded by the matching ``read*`` method.
    :raises KeyError: if *type* is not one of the names above.
    :raises AssertionError: if the decoded field header (type id, field id)
        or the trailing stop marker differs from what was written.
    """
    # Wire type ids, matching thrift.Thrift.TType: DOUBLE is 4 and STRING is
    # 11. Binary values travel with the STRING id because the binary protocol
    # has no separate binary wire type (12 is STRUCT; 5 is unassigned).
    wire_types = {"Bool": 2, "Byte": 3, "Double": 4, "I16": 6, "I32": 8,
                  "I64": 10, "String": 11, "Binary": 11}
    stop = 0  # TType.STOP
    field_id = 10

    kind = type.capitalize()
    field_type = wire_types[kind]  # KeyError for unsupported names

    buf = TTransport.TMemoryBuffer()
    transport = TTransport.TBufferedTransportFactory().getTransport(buf)
    protocol = TBinaryProtocol(transport)
    protocol.writeStructBegin('struct')
    protocol.writeFieldBegin("field", field_type, field_id)
    # e.g. 'I32' -> protocol.writeI32(data)
    getattr(protocol, 'write' + kind)(data)
    protocol.writeFieldEnd()
    # A struct's field list is terminated by a STOP marker; without it a
    # reader that loops until STOP cannot find the end of the struct.
    protocol.writeFieldStop()
    protocol.writeStructEnd()
    transport.flush()

    # Decode from a brand-new buffer holding exactly the bytes written above.
    buf = TTransport.TMemoryBuffer(buf.getvalue())
    transport = TTransport.TBufferedTransportFactory().getTransport(buf)
    protocol = TBinaryProtocol(transport)
    protocol.readStructBegin()
    _, read_type, read_id = protocol.readFieldBegin()
    if (read_type, read_id) != (field_type, field_id):
        raise AssertionError(
            "field header mismatch: wrote (type=%d, id=%d), read (type=%r, id=%r)"
            % (field_type, field_id, read_type, read_id))
    value = getattr(protocol, 'read' + kind)()
    protocol.readFieldEnd()
    # Finish reading the struct before returning: the next header must be STOP.
    _, trailing_type, _ = protocol.readFieldBegin()
    if trailing_type != stop:
        raise AssertionError("expected field stop, read type %r" % (trailing_type,))
    protocol.readStructEnd()
    return value
153
154
def testMessage(data, strict=True):
    """Write a message header, read it back, and return what was decoded.

    *data* is indexed as ``(name, message type, seqid)``; any elements past
    the third are ignored. The writing and the reading ``TBinaryProtocol``
    are both built with ``strictRead`` and ``strictWrite`` set to
    ``bool(strict)``.

    :returns: the value returned by ``readMessageBegin``.
    """
    name, message_type, seqid = data[0], data[1], data[2]
    strict_mode = bool(strict)

    def make_protocol(memory):
        # Wrap *memory* in a buffered transport and a binary protocol that
        # uses the requested strictness for both directions.
        buffered = TTransport.TBufferedTransportFactory().getTransport(memory)
        codec = TBinaryProtocol(buffered, strictRead=strict_mode, strictWrite=strict_mode)
        return buffered, codec

    out_buffer = TTransport.TMemoryBuffer()
    writer_transport, writer = make_protocol(out_buffer)
    writer.writeMessageBegin(name, message_type, seqid)
    writer.writeMessageEnd()
    writer_transport.flush()

    _, reader = make_protocol(TTransport.TMemoryBuffer(out_buffer.getvalue()))
    header = reader.readMessageBegin()
    reader.readMessageEnd()
    return header
180
181
class TestTBinaryProtocol(unittest.TestCase):
    """Round-trip tests for TBinaryProtocol built on the module's helpers.

    ``testNaked``, ``testField`` and ``testMessage`` encode a value into an
    in-memory transport and decode it again; every test here asserts that the
    decoded value equals the one that was written.
    """

    # Message type ids used by the message-header cases below.
    TMessageType = {"T_CALL": 1, "T_REPLY": 2, "T_EXCEPTION": 3, "T_ONEWAY": 4}

    # (name, message type, seqid) triples shared by the strict and
    # non-strict message tests.
    MESSAGES = [("short message name", TMessageType['T_CALL'], 0),
                ("1", TMessageType['T_REPLY'], 12345),
                ("loooooooooooooooooooooooooooooooooong", TMessageType['T_EXCEPTION'], 1 << 16),
                ("one way push", TMessageType['T_ONEWAY'], 12),
                ("Janky", TMessageType['T_CALL'], 0)]

    def _assert_round_trips(self, helper, type_name, values):
        """Assert ``helper(type_name, v) == v`` for every ``v`` in *values*."""
        for value in values:
            self.assertEqual(value, helper(type_name, value),
                             "%s round-trip of %r via %s"
                             % (type_name, value, helper.__name__))

    def _assert_messages_round_trip(self, strict):
        """Assert every entry of MESSAGES survives ``testMessage`` unchanged."""
        for message in self.MESSAGES:
            result = testMessage(message, strict=strict)
            self.assertEqual(message, tuple(result[:3]),
                             "message round-trip (strict=%r)" % (strict,))

    def test_TBinaryProtocol_write_read(self):
        # 8-bit: the full signed range, including the -128 lower bound.
        self._assert_round_trips(testNaked, 'Byte', [123, 0, 127, -128])
        self._assert_round_trips(testField, 'Byte', range(-128, 128))

        self._assert_round_trips(testNaked, 'I16', [
            0, 1, 15000, 0x7fff, -1, -15000, -0x7fff, -32768])
        self._assert_round_trips(testField, 'I16', [
            0, 1, 7, 150, 15000, 0x7fff, -1, -7, -150, -15000,
            -0xfff, -0x7fff, -32768])

        self._assert_round_trips(testNaked, 'I32', [
            0, 1, 15000, 0xffff, -1, -15000, -0xffff,
            2147483647, -2147483647, -2147483648])
        self._assert_round_trips(testField, 'I32', [
            0, 1, 7, 150, 15000, 31337, 0xffff, 0xffffff, -1, -7, -150,
            -15000, -0xffff, -0xffffff, 2147483647, -2147483648])

        i64_limits = [0, 9223372036854775807, -9223372036854775807,
                      -9223372036854775808]
        self._assert_round_trips(testNaked, 'I64', i64_limits)
        self._assert_round_trips(testField, 'I64', i64_limits)

        self._assert_round_trips(testNaked, 'Bool', [True, False])
        self._assert_round_trips(testField, 'Bool', [True, False])

        self._assert_round_trips(testNaked, 'Double', [3.14159261, 3.1415926, 0.0, -2.5])
        self._assert_round_trips(testField, 'Double', [3.1415926, -2.5])

        self._assert_round_trips(testNaked, 'String', ["hello thrift", ""])
        self._assert_round_trips(testField, 'String', ["hello thrift"])

        self._assert_round_trips(testNaked, 'Binary', [b"", b"\x00\x01\xfe\xff"])
        self._assert_round_trips(testField, 'Binary', [b"\x00\x01\xfe\xff"])

        self._assert_messages_round_trip(strict=True)

    def test_TBinaryProtocol_no_strict_write_read(self):
        self._assert_messages_round_trip(strict=False)
283
Zezeng Wangc3728122020-04-27 15:48:19 +0800284
if __name__ == "__main__":
    # Run the TestCase classes in this module when executed as a script.
    unittest.main()