blob: 34357024bd99609034ab1ae34001e330505eec5a [file] [log] [blame]
zeshuai00726e6c842020-05-06 14:37:43 +08001#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18#
19
20import _import_local_thrift # noqa
21from thrift.protocol import TCompactProtocol
22from thrift.transport import TTransport
23import unittest
Carel Combrinka715bdf2025-10-30 07:44:21 +010024import uuid
zeshuai00726e6c842020-05-06 14:37:43 +080025
# Protocol state-machine values. testNaked assigns these directly to
# ``protocol.state`` to read/write bare values outside of a struct, so they
# presumably have to match TCompactProtocol's internal state numbering --
# keep in sync with the library.
(
    CLEAR,
    FIELD_WRITE,
    VALUE_WRITE,
    CONTAINER_WRITE,
    BOOL_WRITE,
    FIELD_READ,
    CONTAINER_READ,
    VALUE_READ,
    BOOL_READ,
) = range(9)
35
36
def testNaked(type, data):
    """Round-trip a single bare value (not wrapped in a struct/field).

    The protocol's ``state`` is forced to a value/container state before
    writing and again before reading, so the typed write*/read* methods can
    be called directly. The written bytes are read back through a fresh
    protocol instance.

    :param type: case-insensitive type name: 'Byte', 'I16', 'I32', 'I64',
        'String', 'Double', 'Binary', 'Bool' or 'Uuid'.
    :param data: the value to serialize.
    :returns: the value deserialized from the written bytes.
    :raises ValueError: if ``type`` is not one of the supported names.
    """
    kind = type.capitalize()
    # kind -> (state forced before write, state forced before read).
    # Binary previously used FIELD_WRITE/FIELD_READ; those are not value
    # states, and TCompactProtocol's value writer/reader guards only accept
    # VALUE_* or CONTAINER_* states, so it now uses VALUE_*.
    states = {
        'Byte': (VALUE_WRITE, VALUE_READ),
        'I16': (CONTAINER_WRITE, CONTAINER_READ),
        'I32': (CONTAINER_WRITE, CONTAINER_READ),
        'I64': (CONTAINER_WRITE, CONTAINER_READ),
        'String': (CONTAINER_WRITE, VALUE_READ),
        'Double': (VALUE_WRITE, VALUE_READ),
        'Binary': (VALUE_WRITE, VALUE_READ),
        'Bool': (CONTAINER_WRITE, CONTAINER_READ),
        'Uuid': (CONTAINER_WRITE, CONTAINER_READ),
    }
    if kind not in states:
        raise ValueError("unsupported type: %r" % (type,))
    write_state, read_state = states[kind]

    buf = TTransport.TMemoryBuffer()
    transport = TTransport.TBufferedTransportFactory().getTransport(buf)
    protocol = TCompactProtocol.TCompactProtocol(transport)
    protocol.state = write_state
    # Dispatch to writeByte/writeI16/.../writeUuid. Bool now writes ``data``
    # (the old code always wrote True, hiding any False round-trip bug).
    getattr(protocol, 'write' + kind)(data)
    transport.flush()

    buf = TTransport.TMemoryBuffer(buf.getvalue())
    transport = TTransport.TBufferedTransportFactory().getTransport(buf)
    protocol = TCompactProtocol.TCompactProtocol(transport)
    protocol.state = read_state
    return getattr(protocol, 'read' + kind)()
118
zeshuai00726e6c842020-05-06 14:37:43 +0800119
def testField(type, data):
    """Round-trip *data* as field id 10 of a single-field struct.

    Writes ``struct { <type> field = 10 }`` with TCompactProtocol, reads it
    back through a fresh protocol instance, and returns the decoded field
    value. The read side closes the field and struct exactly like the write
    side opens them.

    :param type: case-insensitive type name: 'Bool', 'Byte', 'Binary', 'I16',
        'I32', 'I64', 'Double', 'String' or 'Uuid'.
    :param data: the value to serialize.
    :returns: the field value read back.
    :raises ValueError: if ``type`` is not one of the supported names.
    """
    kind = type.capitalize()
    # Field-header type ids, numbered like thrift.Thrift.TType (Binary shares
    # the STRING id). The previous table used 5 for Binary (not a TType id),
    # 11 (STRING) for Double and 12 (STRUCT) for String.
    # NOTE(review): 13 for Uuid is TType.MAP, not a UUID id; it is kept
    # because the header type is ignored on read-back (readFieldBegin's
    # result is discarded below) -- switch to the library's UUID type id
    # once confirmed.
    ttypes = {
        "Bool": 2,
        "Byte": 3,
        "Double": 4,
        "I16": 6,
        "I32": 8,
        "I64": 10,
        "String": 11,
        "Binary": 11,
        "Uuid": 13,
    }
    if kind not in ttypes:
        raise ValueError("unsupported type: %r" % (type,))

    buf = TTransport.TMemoryBuffer()
    transport = TTransport.TBufferedTransportFactory().getTransport(buf)
    protocol = TCompactProtocol.TCompactProtocol(transport)
    protocol.writeStructBegin('struct')
    protocol.writeFieldBegin("field", ttypes[kind], 10)
    # Dispatch to writeBool/writeByte/.../writeUuid.
    getattr(protocol, 'write' + kind)(data)
    protocol.writeFieldEnd()
    protocol.writeStructEnd()
    transport.flush()

    buf = TTransport.TMemoryBuffer(buf.getvalue())
    transport = TTransport.TBufferedTransportFactory().getTransport(buf)
    protocol = TCompactProtocol.TCompactProtocol(transport)
    protocol.readStructBegin()
    protocol.readFieldBegin()
    # Matching read* method; previously I64 was read with readI32.
    value = getattr(protocol, 'read' + kind)()
    protocol.readFieldEnd()
    protocol.readStructEnd()
    return value
194
195
def testMessage(data):
    """Round-trip a message header through TCompactProtocol.

    :param data: indexable whose first three items are the message name,
        the message type and the sequence id.
    :returns: the value returned by ``readMessageBegin()`` on the re-read
        bytes.
    """
    name, mtype, seqid = data[0], data[1], data[2]

    out_buf = TTransport.TMemoryBuffer()
    out_transport = TTransport.TBufferedTransportFactory().getTransport(out_buf)
    writer = TCompactProtocol.TCompactProtocol(out_transport)
    writer.writeMessageBegin(name, mtype, seqid)
    writer.writeMessageEnd()
    out_transport.flush()
    encoded = out_buf.getvalue()

    in_transport = TTransport.TBufferedTransportFactory().getTransport(
        TTransport.TMemoryBuffer(encoded))
    reader = TCompactProtocol.TCompactProtocol(in_transport)
    header = reader.readMessageBegin()
    reader.readMessageEnd()
    return header
217
218
class TestTCompactProtocol(unittest.TestCase):
    """Round-trip tests for TCompactProtocol.

    Each module-level helper (testNaked, testField, testMessage) serializes
    a value and deserializes it again; these tests assert that what comes
    back equals what went in.
    """

    def _assertRoundTrip(self, helper, type_name, values):
        """Assert ``helper(type_name, v) == v`` for every ``v`` in *values*."""
        for value in values:
            self.assertEqual(value, helper(type_name, value))

    def test_TCompactProtocol_write_read(self):
        # Byte: the naked result is now actually checked, and the field loop
        # covers the full signed 8-bit range (previously -128 was skipped).
        self.assertEqual(123, testNaked('Byte', 123))
        self._assertRoundTrip(testField, 'Byte', range(-128, 128))

        self._assertRoundTrip(testNaked, 'I16', (
            0, 1, 15000, 0x7fff, -1, -15000, -0x7fff))
        self._assertRoundTrip(testField, 'I16', (
            0, 1, 7, 150, 15000, 0x7fff, -1, -7, -150, -15000, -0xfff))

        self._assertRoundTrip(testNaked, 'I32', (
            0, 1, 15000, 0xfff, -1, -15000, -0xfff,
            2147483647, -2147483647))
        self._assertRoundTrip(testField, 'I32', (
            0, 1, 7, 150, 15000, 31337, 0xffff, 0xffffff,
            -1, -7, -150, -15000, -0xffff, -0xffffff))

        self._assertRoundTrip(testNaked, 'I64', (
            9223372036854775807, -9223372036854775807, 0))

        self.assertEqual(True, testNaked('Bool', True))
        self.assertEqual(3.14159261, testNaked('Double', 3.14159261))
        self.assertEqual("hello thrift", testNaked('String', "hello thrift"))

        self.assertEqual(True, testField('Bool', True))
        # False is carried in the compact field header, so check it too.
        self.assertEqual(False, testField('Bool', False))
        self.assertEqual(3.14159261, testField('Double', 3.14159261))
        self.assertEqual("hello thrift", testField('String', "hello thrift"))

        sample_uuid = uuid.UUID('{00010203-0405-0607-0809-0a0b0c0d0e0f}')
        self.assertEqual(sample_uuid, testNaked("Uuid", sample_uuid))
        self.assertEqual(sample_uuid, testField("Uuid", sample_uuid))

        TMessage = {"T_CALL": 1, "T_REPLY": 2, "T_EXCEPTION": 3, "T_ONEWAY": 4}
        test_data = [("short message name", TMessage["T_CALL"], 0),
                     ("1", TMessage["T_REPLY"], 12345),
                     ("loooooooooooooooooooong", TMessage["T_EXCEPTION"], 1 << 16),
                     ("one way push", TMessage["T_ONEWAY"], 12),
                     ("JANKY", TMessage["T_CALL"], 0)]
        for name, mtype, seqid in test_data:
            result = testMessage((name, mtype, seqid))
            self.assertEqual(result[0], name)
            self.assertEqual(result[1], mtype)
            self.assertEqual(result[2], seqid)
303
304
if __name__ == '__main__':
    # Run this module's TestCase classes with the standard unittest runner.
    unittest.main()