blob: 1d6af8ee2ad57879d30c50e79b2dcc9b38603f6e [file] [log] [blame]
zeshuai00726e6c842020-05-06 14:37:43 +08001#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18#
19
20import _import_local_thrift # noqa
21from thrift.protocol import TCompactProtocol
22from thrift.transport import TTransport
23import unittest
24
# Values assigned to ``protocol.state`` by the helpers below to put a
# TCompactProtocol instance into a specific read/write state.
# NOTE(review): assumed to mirror TCompactProtocol's internal state numbering
# (0..8 in this order) — confirm against thrift.protocol.TCompactProtocol.
(CLEAR,
 FIELD_WRITE,
 VALUE_WRITE,
 CONTAINER_WRITE,
 BOOL_WRITE,
 FIELD_READ,
 CONTAINER_READ,
 VALUE_READ,
 BOOL_READ) = range(9)
34
35
def testNaked(type, data):
    """Round-trip one value through TCompactProtocol without struct framing.

    The value is written directly (no struct or field header) into an
    in-memory buffer. It is then read back by a fresh protocol instance over
    the bytes that were produced.

    :param type: type name, matched after ``str.capitalize()``; one of
        'Byte', 'I16', 'I32', 'I64', 'String', 'Double', 'Binary', 'Bool'.
    :param data: the value to write.
    :returns: the value read back from the encoded bytes.
    :raises ValueError: if ``type`` is not one of the supported names.
    """
    # The protocol tracks a ``state``, and its typed read/write methods check
    # it. Each entry forces a state the method accepts:
    #   name -> (write state, write method, read state, read method)
    # NOTE(review): Binary uses the VALUE_* states like Byte/Double; the
    # FIELD_* states are not among those TCompactProtocol's value
    # writer/reader guards accept — confirm against the library.
    plans = {
        'Byte': (VALUE_WRITE, 'writeByte', VALUE_READ, 'readByte'),
        'I16': (CONTAINER_WRITE, 'writeI16', CONTAINER_READ, 'readI16'),
        'I32': (CONTAINER_WRITE, 'writeI32', CONTAINER_READ, 'readI32'),
        'I64': (CONTAINER_WRITE, 'writeI64', CONTAINER_READ, 'readI64'),
        'String': (CONTAINER_WRITE, 'writeString', VALUE_READ, 'readString'),
        'Double': (VALUE_WRITE, 'writeDouble', VALUE_READ, 'readDouble'),
        'Binary': (VALUE_WRITE, 'writeBinary', VALUE_READ, 'readBinary'),
        'Bool': (CONTAINER_WRITE, 'writeBool', CONTAINER_READ, 'readBool'),
    }
    kind = type.capitalize()
    if kind not in plans:
        raise ValueError('unsupported naked type: %r' % (type,))
    write_state, write_name, read_state, read_name = plans[kind]

    # Write side: encode `data` itself (for Bool too, not a constant True).
    buf = TTransport.TMemoryBuffer()
    transport = TTransport.TBufferedTransportFactory().getTransport(buf)
    protocol = TCompactProtocol.TCompactProtocol(transport)
    protocol.state = write_state
    getattr(protocol, write_name)(data)
    # Flush so the buffered transport pushes its bytes into `buf`.
    transport.flush()

    # Read side: a fresh protocol over a copy of the encoded bytes.
    buf = TTransport.TMemoryBuffer(buf.getvalue())
    transport = TTransport.TBufferedTransportFactory().getTransport(buf)
    protocol = TCompactProtocol.TCompactProtocol(transport)
    protocol.state = read_state
    return getattr(protocol, read_name)()
109
110
def testField(type, data):
    """Round-trip one value encoded as field id 10 of a struct.

    Writes a struct containing a single field (id 10) of the given type with
    TCompactProtocol into an in-memory buffer. It then reads the struct back,
    consuming the field and struct end markers, and returns the field value.

    :param type: type name, matched after ``str.capitalize()``; one of
        'Bool', 'Byte', 'I16', 'I32', 'I64', 'Double', 'String', 'Binary'.
    :param data: the value to write.
    :returns: the value read back from the encoded field.
    :raises ValueError: if ``type`` is not one of the supported names.
    """
    from thrift.Thrift import TType

    # name -> (field wire type, write method, read method).
    # Use the library's TType constants rather than literal numbers.
    # Thrift has no separate binary wire type, so Binary travels as STRING.
    plans = {
        'Bool': (TType.BOOL, 'writeBool', 'readBool'),
        'Byte': (TType.BYTE, 'writeByte', 'readByte'),
        'I16': (TType.I16, 'writeI16', 'readI16'),
        'I32': (TType.I32, 'writeI32', 'readI32'),
        'I64': (TType.I64, 'writeI64', 'readI64'),
        'Double': (TType.DOUBLE, 'writeDouble', 'readDouble'),
        'String': (TType.STRING, 'writeString', 'readString'),
        'Binary': (TType.STRING, 'writeBinary', 'readBinary'),
    }
    kind = type.capitalize()
    if kind not in plans:
        raise ValueError('unsupported field type: %r' % (type,))
    field_type, write_name, read_name = plans[kind]

    buf = TTransport.TMemoryBuffer()
    transport = TTransport.TBufferedTransportFactory().getTransport(buf)
    protocol = TCompactProtocol.TCompactProtocol(transport)
    protocol.writeStructBegin('struct')
    protocol.writeFieldBegin("field", field_type, 10)
    getattr(protocol, write_name)(data)
    protocol.writeFieldEnd()
    protocol.writeStructEnd()
    # Flush so the buffered transport pushes its bytes into `buf`.
    transport.flush()

    buf = TTransport.TMemoryBuffer(buf.getvalue())
    transport = TTransport.TBufferedTransportFactory().getTransport(buf)
    protocol = TCompactProtocol.TCompactProtocol(transport)
    protocol.readStructBegin()
    protocol.readFieldBegin()
    value = getattr(protocol, read_name)()
    # Finish the field and struct before returning, mirroring the write side.
    protocol.readFieldEnd()
    protocol.readStructEnd()
    return value
179
180
def testMessage(data):
    """Round-trip a message header through TCompactProtocol.

    :param data: indexable whose items 0, 1 and 2 are the message name,
        message type and sequence id passed to ``writeMessageBegin``.
    :returns: the result of ``readMessageBegin`` on the encoded bytes (the
        test in this file indexes it as name, type, seqid).
    """
    name, msg_type, seqid = data[0], data[1], data[2]

    out_buf = TTransport.TMemoryBuffer()
    out_trans = TTransport.TBufferedTransportFactory().getTransport(out_buf)
    writer = TCompactProtocol.TCompactProtocol(out_trans)
    writer.writeMessageBegin(name, msg_type, seqid)
    writer.writeMessageEnd()
    out_trans.flush()
    encoded = out_buf.getvalue()

    in_buf = TTransport.TMemoryBuffer(encoded)
    in_trans = TTransport.TBufferedTransportFactory().getTransport(in_buf)
    reader = TCompactProtocol.TCompactProtocol(in_trans)
    header = reader.readMessageBegin()
    reader.readMessageEnd()
    return header
202
203
class TestTCompactProtocol(unittest.TestCase):
    """Round-trip tests for TCompactProtocol built on the module helpers."""

    def test_TCompactProtocol_write_read(self):
        """Naked values, struct fields and message headers survive a round trip."""
        self.assertEqual(123, testNaked('Byte', 123))
        for i in range(0, 128):
            self.assertEqual(i, testField('Byte', i))
            self.assertEqual(-i, testField('Byte', -i))

        self.assertEqual(0, testNaked("I16", 0))
        self.assertEqual(1, testNaked("I16", 1))
        self.assertEqual(15000, testNaked("I16", 15000))
        self.assertEqual(0x7fff, testNaked('I16', 0x7fff))
        self.assertEqual(-1, testNaked('I16', -1))
        self.assertEqual(-15000, testNaked('I16', -15000))
        self.assertEqual(-0x7fff, testNaked('I16', -0x7fff))
        self.assertEqual(32767, testNaked('I16', 32767))

        self.assertEqual(0, testField('I16', 0))
        self.assertEqual(1, testField('I16', 1))
        self.assertEqual(7, testField('I16', 7))
        self.assertEqual(150, testField('I16', 150))
        self.assertEqual(15000, testField('I16', 15000))
        self.assertEqual(0x7fff, testField('I16', 0x7fff))
        self.assertEqual(-1, testField('I16', -1))
        self.assertEqual(-7, testField('I16', -7))
        self.assertEqual(-150, testField('I16', -150))
        self.assertEqual(-15000, testField('I16', -15000))
        self.assertEqual(-0xfff, testField('I16', -0xfff))

        self.assertEqual(0, testNaked('I32', 0))
        self.assertEqual(1, testNaked('I32', 1))
        self.assertEqual(15000, testNaked('I32', 15000))
        self.assertEqual(0xfff, testNaked('I32', 0xfff))
        self.assertEqual(-1, testNaked('I32', -1))
        self.assertEqual(-15000, testNaked('I32', -15000))
        self.assertEqual(-0xfff, testNaked('I32', -0xfff))
        self.assertEqual(2147483647, testNaked('I32', 2147483647))
        self.assertEqual(-2147483647, testNaked('I32', -2147483647))

        self.assertEqual(0, testField('I32', 0))
        self.assertEqual(1, testField('I32', 1))
        self.assertEqual(7, testField('I32', 7))
        self.assertEqual(150, testField('I32', 150))
        self.assertEqual(15000, testField('I32', 15000))
        self.assertEqual(31337, testField('I32', 31337))
        self.assertEqual(0xffff, testField('I32', 0xffff))
        self.assertEqual(0xffffff, testField('I32', 0xffffff))
        self.assertEqual(-1, testField('I32', -1))
        self.assertEqual(-7, testField('I32', -7))
        self.assertEqual(-150, testField('I32', -150))
        self.assertEqual(-15000, testField('I32', -15000))
        self.assertEqual(-0xffff, testField('I32', -0xffff))
        self.assertEqual(-0xffffff, testField('I32', -0xffffff))

        self.assertEqual(9223372036854775807, testNaked("I64", 9223372036854775807))
        self.assertEqual(-9223372036854775807, testNaked('I64', -9223372036854775807))
        self.assertEqual(0, testNaked('I64', 0))
        self.assertEqual(True, testNaked('Bool', True))
        self.assertEqual(3.14159261, testNaked('Double', 3.14159261))
        self.assertEqual("hello thrift", testNaked('String', "hello thrift"))
        self.assertEqual(True, testField('Bool', True))
        self.assertEqual(False, testField('Bool', False))
        self.assertEqual(3.14159261, testField('Double', 3.14159261))
        self.assertEqual("hello thrift", testField('String', "hello thrift"))

        # Message type codes used for the header round trips below.
        TMessage = {"T_CALL": 1, "T_REPLY": 2, "T_EXCEPTION": 3, "T_ONEWAY": 4}
        test_data = [("short message name", TMessage["T_CALL"], 0),
                     ("1", TMessage["T_REPLY"], 12345),
                     ("loooooooooooooooooooong", TMessage["T_EXCEPTION"], 1 << 16),
                     ("one way push", TMessage["T_ONEWAY"], 12),
                     ("JANKY", TMessage["T_CALL"], 0)]
        for dt in test_data:
            name, mtype, seqid = dt
            result = testMessage(dt)
            self.assertEqual(result[0], name)
            self.assertEqual(result[1], mtype)
            self.assertEqual(result[2], seqid)
285
286
if __name__ == "__main__":
    # Run every TestCase in this module when executed as a script.
    unittest.main(module="__main__")