blob: 16fd9be28c33a73c96ae33ec7cabd23dec29315c [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
19
import functools
from struct import pack, unpack

from ..compat import binary_to_str, str_to_binary
from .TProtocol import TType, TProtocolBase, TProtocolException, checkIntegerLimits
24
# Public API exported by ``from ... import *``; includes the accelerated
# variants defined at the bottom of this module.
__all__ = [
    'TCompactProtocol',
    'TCompactProtocolFactory',
    'TCompactProtocolAccelerated',
    'TCompactProtocolAcceleratedFactory',
]
26
# States of the TCompactProtocol encoder/decoder state machine, numbered
# consecutively from 0.
(CLEAR,
 FIELD_WRITE,
 VALUE_WRITE,
 CONTAINER_WRITE,
 BOOL_WRITE,
 FIELD_READ,
 CONTAINER_READ,
 VALUE_READ,
 BOOL_READ) = range(9)
36
Bryan Duxbury69720412012-01-03 17:32:30 +000037
def make_helper(v_from, container):
    """Build a method decorator that asserts the protocol's current state.

    The returned decorator wraps a method so that, before it runs,
    ``self.state`` must equal ``v_from`` or ``container``; otherwise an
    AssertionError carrying ``(state, v_from, container)`` is raised.

    :param v_from: state in which a plain value may be processed.
    :param container: state in which a container element may be processed.
    :returns: a decorator for protocol methods.
    """
    def helper(func):
        # Preserve the wrapped method's name/docstring for introspection.
        @functools.wraps(func)
        def nested(self, *args, **kwargs):
            assert self.state in (v_from, container), (self.state, v_from, container)
            return func(self, *args, **kwargs)
        return nested
    return helper
# Guards for primitive writes/reads: legal only as a plain value (VALUE_*)
# or as a container element (CONTAINER_*).
writer, reader = (make_helper(VALUE_WRITE, CONTAINER_WRITE),
                  make_helper(VALUE_READ, CONTAINER_READ))
47
Bryan Duxbury69720412012-01-03 17:32:30 +000048
def makeZigZag(n, bits):
    """ZigZag-encode the signed integer ``n`` of width ``bits``.

    ``n`` is first validated by ``checkIntegerLimits``; the result maps small
    magnitudes (positive or negative) to small non-negative integers.
    """
    checkIntegerLimits(n, bits)
    sign_mask = n >> (bits - 1)
    return sign_mask ^ (n << 1)
David Reissabafd792010-09-27 17:28:15 +000052
Bryan Duxbury69720412012-01-03 17:32:30 +000053
def fromZigZag(n):
    """Decode a ZigZag-encoded non-negative integer back to a signed one."""
    magnitude = n >> 1
    # An odd code denotes a negative value: -(magnitude + 1) == ~magnitude.
    return ~magnitude if n & 1 else magnitude
David Reissabafd792010-09-27 17:28:15 +000056
Bryan Duxbury69720412012-01-03 17:32:30 +000057
def writeVarint(trans, n):
    """Write the non-negative integer ``n`` to ``trans`` as an unsigned varint.

    Seven bits are emitted per byte, least-significant group first; the high
    bit of a byte is set when more bytes follow.

    :param trans: transport providing ``write(bytes)``.
    :param n: value to encode; must be >= 0.
    :raises ValueError: if ``n`` is negative. Right-shifting a negative
        Python int never reaches zero, so encoding one would loop forever.
    """
    if n < 0:
        raise ValueError('Cannot encode negative value as varint: %d' % n)
    out = bytearray()
    while n > 0x7f:
        out.append((n & 0x7f) | 0x80)
        n >>= 7
    out.append(n)
    trans.write(bytes(out))
David Reissabafd792010-09-27 17:28:15 +000068
Bryan Duxbury69720412012-01-03 17:32:30 +000069
def readVarint(trans):
    """Read an unsigned varint (as produced by ``writeVarint``) from ``trans``.

    At most 10 bytes are consumed, which is enough for any 64-bit value.
    A longer run of continuation bytes is rejected instead of being
    accumulated, so malformed input cannot force unbounded integer growth.

    :param trans: transport providing ``readAll(n)``.
    :returns: the decoded non-negative integer.
    :raises TProtocolException: INVALID_DATA if more than 10 bytes carry
        the continuation bit.
    """
    max_bytes = 10
    result = 0
    shift = 0
    for _ in range(max_bytes):
        byte = ord(trans.readAll(1))
        result |= (byte & 0x7f) << shift
        if byte & 0x80 == 0:
            return result
        shift += 7
    raise TProtocolException(TProtocolException.INVALID_DATA,
                             'Variable-length int over %d bytes' % max_bytes)
David Reissabafd792010-09-27 17:28:15 +000080
Bryan Duxbury69720412012-01-03 17:32:30 +000081
class CompactType(object):
    """Type ids written on the wire by the compact protocol.

    Booleans have two ids (TRUE and FALSE) because the value of a bool field
    is carried in the field header's type nibble.
    """
    STOP = 0
    TRUE = 1
    FALSE = 2
    BYTE = 3
    I16 = 4
    I32 = 5
    I64 = 6
    DOUBLE = 7
    BINARY = 8
    LIST = 9
    SET = 10
    MAP = 11
    STRUCT = 12
David Reissabafd792010-09-27 17:28:15 +000096
# Generic TType id -> compact wire type id.
CTYPES = {
    TType.STOP: CompactType.STOP,
    TType.BOOL: CompactType.TRUE,  # used for collection
    TType.BYTE: CompactType.BYTE,
    TType.I16: CompactType.I16,
    TType.I32: CompactType.I32,
    TType.I64: CompactType.I64,
    TType.DOUBLE: CompactType.DOUBLE,
    TType.STRING: CompactType.BINARY,
    TType.STRUCT: CompactType.STRUCT,
    TType.LIST: CompactType.LIST,
    TType.SET: CompactType.SET,
    TType.MAP: CompactType.MAP,
}

# Compact wire type id -> generic TType id (the inverse of CTYPES). FALSE is
# added explicitly so that both bool wire ids decode to TType.BOOL. A dict
# comprehension avoids leaking loop variables into the module namespace.
TTYPES = {v: k for k, v in CTYPES.items()}
TTYPES[CompactType.FALSE] = TType.BOOL
118
Bryan Duxbury69720412012-01-03 17:32:30 +0000119
class TCompactProtocol(TProtocolBase):
    """Compact implementation of the Thrift protocol driver.

    Progress through a message is tracked in ``self.state`` (one of the
    module-level CLEAR / *_WRITE / *_READ constants). The ``writer`` and
    ``reader`` decorators assert that primitive values are only written or
    read as a field value or as a container element. Field ids are
    delta-encoded against the previous field of the enclosing struct, so the
    last field id is saved and restored around nested structs.
    """

    PROTOCOL_ID = 0x82
    VERSION = 1
    VERSION_MASK = 0x1f
    TYPE_MASK = 0xe0
    TYPE_BITS = 0x07
    TYPE_SHIFT_AMOUNT = 5

    def __init__(self, trans,
                 string_length_limit=None,
                 container_length_limit=None):
        """Create a compact protocol bound to ``trans``.

        :param trans: transport used for every read and write.
        :param string_length_limit: limit handed to ``_check_length``
            (inherited from TProtocolBase) for each decoded string/binary
            length.
        :param container_length_limit: limit handed to ``_check_length`` for
            each decoded list/set/map size.
        """
        TProtocolBase.__init__(self, trans)
        self.state = CLEAR
        self.__last_fid = 0
        # Field id of a bool field whose header is deferred until writeBool.
        self.__bool_fid = None
        # Bool value decoded from a field header, returned by readBool.
        self.__bool_value = None
        # Stack of (state, last_fid) pairs saved on entering a struct.
        self.__structs = []
        # Stack of states saved on entering a container.
        self.__containers = []
        self.string_length_limit = string_length_limit
        self.container_length_limit = container_length_limit

    def _check_string_length(self, length):
        """Validate a decoded string/binary length against the limit."""
        self._check_length(self.string_length_limit, length)

    def _check_container_length(self, length):
        """Validate a decoded container size against the limit."""
        self._check_length(self.container_length_limit, length)

    # ------------------------------------------------------------------
    # Writing
    # ------------------------------------------------------------------

    def __writeVarint(self, n):
        """Write ``n`` as an unsigned varint."""
        writeVarint(self.trans, n)

    def writeMessageBegin(self, name, type, seqid):
        """Write the message header: protocol id, version/type, seqid, name."""
        assert self.state == CLEAR
        self.__writeUByte(self.PROTOCOL_ID)
        # Message type goes in the bits above TYPE_SHIFT_AMOUNT, version below.
        self.__writeUByte(self.VERSION | (type << self.TYPE_SHIFT_AMOUNT))
        self.__writeVarint(seqid)
        self.__writeBinary(str_to_binary(name))
        self.state = VALUE_WRITE

    def writeMessageEnd(self):
        """Finish the current message and return to the CLEAR state."""
        assert self.state == VALUE_WRITE
        self.state = CLEAR

    def writeStructBegin(self, name):
        """Enter a struct, saving the outer state and last field id."""
        assert self.state in (CLEAR, CONTAINER_WRITE, VALUE_WRITE), self.state
        self.__structs.append((self.state, self.__last_fid))
        self.state = FIELD_WRITE
        self.__last_fid = 0

    def writeStructEnd(self):
        """Leave a struct, restoring the outer state and last field id."""
        assert self.state == FIELD_WRITE
        self.state, self.__last_fid = self.__structs.pop()

    def writeFieldStop(self):
        """Write the STOP byte (0) that terminates a struct's fields."""
        self.__writeByte(0)

    def __writeFieldHeader(self, ctype, fid):
        """Write a field header for compact type ``ctype`` and id ``fid``."""
        delta = fid - self.__last_fid
        if 0 < delta <= 15:
            # Short form: id delta in the high nibble, type in the low nibble.
            self.__writeUByte(delta << 4 | ctype)
        else:
            # Long form: bare type byte followed by the zigzag-varint id.
            self.__writeByte(ctype)
            self.__writeI16(fid)
        self.__last_fid = fid

    def writeFieldBegin(self, name, type, fid):
        """Begin a field; bool field headers are deferred to writeBool."""
        assert self.state == FIELD_WRITE, self.state
        if type == TType.BOOL:
            self.state = BOOL_WRITE
            self.__bool_fid = fid
        else:
            self.state = VALUE_WRITE
            self.__writeFieldHeader(CTYPES[type], fid)

    def writeFieldEnd(self):
        """Finish a field and return to the FIELD_WRITE state."""
        assert self.state in (VALUE_WRITE, BOOL_WRITE), self.state
        self.state = FIELD_WRITE

    def __writeUByte(self, byte):
        """Write one unsigned byte."""
        self.trans.write(pack('!B', byte))

    def __writeByte(self, byte):
        """Write one signed byte."""
        self.trans.write(pack('!b', byte))

    def __writeI16(self, i16):
        """Write a 16-bit integer as a zigzag varint."""
        self.__writeVarint(makeZigZag(i16, 16))

    def __writeSize(self, i32):
        """Write a size/length as an unsigned varint."""
        self.__writeVarint(i32)

    def writeCollectionBegin(self, etype, size):
        """Begin a list or set of ``size`` elements of type ``etype``.

        Sizes up to 14 share one byte with the element type; larger sizes
        write 0xF in the size nibble followed by a varint size.
        """
        assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state
        if size <= 14:
            self.__writeUByte(size << 4 | CTYPES[etype])
        else:
            self.__writeUByte(0xf0 | CTYPES[etype])
            self.__writeSize(size)
        self.__containers.append(self.state)
        self.state = CONTAINER_WRITE
    writeSetBegin = writeCollectionBegin
    writeListBegin = writeCollectionBegin

    def writeMapBegin(self, ktype, vtype, size):
        """Begin a map; an empty map is written as a single 0 byte."""
        assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state
        if size == 0:
            self.__writeByte(0)
        else:
            self.__writeSize(size)
            self.__writeUByte(CTYPES[ktype] << 4 | CTYPES[vtype])
        self.__containers.append(self.state)
        self.state = CONTAINER_WRITE

    def writeCollectionEnd(self):
        """Leave a list/set/map, restoring the saved outer state."""
        assert self.state == CONTAINER_WRITE, self.state
        self.state = self.__containers.pop()
    writeMapEnd = writeCollectionEnd
    writeSetEnd = writeCollectionEnd
    writeListEnd = writeCollectionEnd

    def writeBool(self, bool):
        """Write a bool.

        For a field, the value is folded into the deferred field header;
        inside a container it is written as a single TRUE/FALSE byte.
        """
        if self.state == BOOL_WRITE:
            if bool:
                ctype = CompactType.TRUE
            else:
                ctype = CompactType.FALSE
            self.__writeFieldHeader(ctype, self.__bool_fid)
        elif self.state == CONTAINER_WRITE:
            if bool:
                self.__writeByte(CompactType.TRUE)
            else:
                self.__writeByte(CompactType.FALSE)
        else:
            raise AssertionError("Invalid state in compact protocol")

    writeByte = writer(__writeByte)
    writeI16 = writer(__writeI16)

    @writer
    def writeI32(self, i32):
        """Write a 32-bit integer as a zigzag varint."""
        self.__writeVarint(makeZigZag(i32, 32))

    @writer
    def writeI64(self, i64):
        """Write a 64-bit integer as a zigzag varint."""
        self.__writeVarint(makeZigZag(i64, 64))

    @writer
    def writeDouble(self, dub):
        """Write a double as 8 little-endian bytes."""
        self.trans.write(pack('<d', dub))

    def __writeBinary(self, s):
        """Write a varint length followed by the raw bytes of ``s``."""
        self.__writeSize(len(s))
        self.trans.write(s)
    writeBinary = writer(__writeBinary)

    # ------------------------------------------------------------------
    # Reading
    # ------------------------------------------------------------------

    def readFieldBegin(self):
        """Read a field header and return ``(None, ttype, fid)``.

        Returns ``(None, 0, 0)`` on the STOP marker. A bool field's value is
        carried in its header and is stashed for the following readBool.
        """
        assert self.state == FIELD_READ, self.state
        header = self.__readUByte()
        ctype = header & 0x0f
        if ctype == CompactType.STOP:
            return (None, 0, 0)
        delta = header >> 4
        if delta == 0:
            # Long form: the field id follows as a zigzag varint.
            fid = self.__readI16()
        else:
            fid = self.__last_fid + delta
        self.__last_fid = fid
        if ctype == CompactType.TRUE:
            self.state = BOOL_READ
            self.__bool_value = True
        elif ctype == CompactType.FALSE:
            self.state = BOOL_READ
            self.__bool_value = False
        else:
            self.state = VALUE_READ
        return (None, self.__getTType(ctype), fid)

    def readFieldEnd(self):
        """Finish a field and return to the FIELD_READ state."""
        assert self.state in (VALUE_READ, BOOL_READ), self.state
        self.state = FIELD_READ

    def __readUByte(self):
        """Read one unsigned byte."""
        result, = unpack('!B', self.trans.readAll(1))
        return result

    def __readByte(self):
        """Read one signed byte."""
        result, = unpack('!b', self.trans.readAll(1))
        return result

    def __readVarint(self):
        """Read an unsigned varint."""
        return readVarint(self.trans)

    def __readZigZag(self):
        """Read a zigzag-encoded varint and return the signed value."""
        return fromZigZag(self.__readVarint())

    def __readSize(self):
        """Read a size/length varint.

        :raises TProtocolException: NEGATIVE_SIZE if the value is negative.
        """
        result = self.__readVarint()
        if result < 0:
            raise TProtocolException(TProtocolException.NEGATIVE_SIZE,
                                     "Length < 0")
        return result

    def readMessageBegin(self):
        """Read a message header and return ``(name, type, seqid)``.

        :raises TProtocolException: BAD_VERSION on an unexpected protocol id
            or version.
        """
        assert self.state == CLEAR
        proto_id = self.__readUByte()
        if proto_id != self.PROTOCOL_ID:
            raise TProtocolException(TProtocolException.BAD_VERSION,
                                     'Bad protocol id in the message: %d' % proto_id)
        ver_type = self.__readUByte()
        msg_type = (ver_type >> self.TYPE_SHIFT_AMOUNT) & self.TYPE_BITS
        version = ver_type & self.VERSION_MASK
        if version != self.VERSION:
            raise TProtocolException(TProtocolException.BAD_VERSION,
                                     'Bad version: %d (expect %d)' % (version, self.VERSION))
        seqid = self.__readVarint()
        name = binary_to_str(self.__readBinary())
        # NOTE: state stays CLEAR; readStructBegin accepts CLEAR and
        # readMessageEnd asserts it.
        return (name, msg_type, seqid)

    def readMessageEnd(self):
        """Check that the message was fully consumed."""
        assert self.state == CLEAR
        assert len(self.__structs) == 0

    def readStructBegin(self):
        """Enter a struct, saving the outer state and last field id."""
        assert self.state in (CLEAR, CONTAINER_READ, VALUE_READ), self.state
        self.__structs.append((self.state, self.__last_fid))
        self.state = FIELD_READ
        self.__last_fid = 0

    def readStructEnd(self):
        """Leave a struct, restoring the outer state and last field id."""
        assert self.state == FIELD_READ
        self.state, self.__last_fid = self.__structs.pop()

    def readCollectionBegin(self):
        """Begin reading a list or set; returns ``(etype, size)``."""
        assert self.state in (VALUE_READ, CONTAINER_READ), self.state
        size_type = self.__readUByte()
        size = size_type >> 4
        etype = self.__getTType(size_type)
        if size == 15:
            # 0xF in the size nibble means the real size follows as a varint.
            size = self.__readSize()
        self._check_container_length(size)
        self.__containers.append(self.state)
        self.state = CONTAINER_READ
        return etype, size
    readSetBegin = readCollectionBegin
    readListBegin = readCollectionBegin

    def readMapBegin(self):
        """Begin reading a map; returns ``(ktype, vtype, size)``."""
        assert self.state in (VALUE_READ, CONTAINER_READ), self.state
        size = self.__readSize()
        self._check_container_length(size)
        # An empty map has no type byte, so both types decode as STOP.
        types = 0
        if size > 0:
            types = self.__readUByte()
        vtype = self.__getTType(types)
        ktype = self.__getTType(types >> 4)
        self.__containers.append(self.state)
        self.state = CONTAINER_READ
        return (ktype, vtype, size)

    def readCollectionEnd(self):
        """Leave a list/set/map, restoring the saved outer state."""
        assert self.state == CONTAINER_READ, self.state
        self.state = self.__containers.pop()
    readSetEnd = readCollectionEnd
    readListEnd = readCollectionEnd
    readMapEnd = readCollectionEnd

    def readBool(self):
        """Read a bool.

        For a field, returns the value decoded by readFieldBegin; inside a
        container, reads one byte and compares it to CompactType.TRUE.
        """
        if self.state == BOOL_READ:
            # readFieldBegin stored a real bool; return it directly.
            return self.__bool_value
        elif self.state == CONTAINER_READ:
            return self.__readByte() == CompactType.TRUE
        else:
            raise AssertionError("Invalid state in compact protocol: %d" %
                                 self.state)

    readByte = reader(__readByte)
    __readI16 = __readZigZag
    readI16 = reader(__readZigZag)
    readI32 = reader(__readZigZag)
    readI64 = reader(__readZigZag)

    @reader
    def readDouble(self):
        """Read a double from 8 little-endian bytes."""
        buff = self.trans.readAll(8)
        val, = unpack('<d', buff)
        return val

    def __readBinary(self):
        """Read a varint length and then that many raw bytes."""
        size = self.__readSize()
        self._check_string_length(size)
        return self.trans.readAll(size)
    readBinary = reader(__readBinary)

    def __getTType(self, byte):
        """Map the compact type in the low nibble of ``byte`` to a TType.

        :raises TProtocolException: INVALID_DATA for an unknown compact type.
        """
        ctype = byte & 0x0f
        ttype = TTYPES.get(ctype)
        if ttype is None:
            raise TProtocolException(TProtocolException.INVALID_DATA,
                                     'Unknown compact type: %d' % ctype)
        return ttype
David Reissabafd792010-09-27 17:28:15 +0000414
415
class TCompactProtocolFactory(object):
    """Factory creating TCompactProtocol instances that share size limits."""

    def __init__(self,
                 string_length_limit=None,
                 container_length_limit=None):
        """Store the limits passed to every protocol this factory builds."""
        self.string_length_limit = string_length_limit
        self.container_length_limit = container_length_limit

    def getProtocol(self, trans):
        """Return a new TCompactProtocol bound to ``trans``."""
        return TCompactProtocol(
            trans,
            string_length_limit=self.string_length_limit,
            container_length_limit=self.container_length_limit,
        )
Nobuaki Sukegawa6525f6a2016-02-11 13:58:39 +0900427
428
class TCompactProtocolAccelerated(TCompactProtocol):
    """C-Accelerated version of TCompactProtocol.

    This class does not override any of TCompactProtocol's encoding methods.
    The generated code recognizes it directly and calls into our C module to
    do the encoding, bypassing this object entirely. It inherits from
    TCompactProtocol so that the normal TCompactProtocol encoding can still
    happen if the fastbinary module doesn't work for some reason.
    By default a missing fastbinary module is tolerated; pass fallback=False
    to the constructor to have the ImportError raised instead.

    In order to take advantage of the C module, just use
    TCompactProtocolAccelerated instead of TCompactProtocol.
    """

    def __init__(self, *args, **kwargs):
        allow_fallback = kwargs.pop('fallback', True)
        super(TCompactProtocolAccelerated, self).__init__(*args, **kwargs)
        try:
            from thrift.protocol import fastbinary
        except ImportError:
            if allow_fallback:
                # Keep the pure-Python encoding inherited from TCompactProtocol.
                return
            raise
        self._fast_decode = fastbinary.decode_compact
        self._fast_encode = fastbinary.encode_compact
456
457
class TCompactProtocolAcceleratedFactory(object):
    """Factory creating TCompactProtocolAccelerated instances."""

    def __init__(self,
                 string_length_limit=None,
                 container_length_limit=None,
                 fallback=True):
        """Store the limits and the fallback flag for later protocols."""
        self.string_length_limit = string_length_limit
        self.container_length_limit = container_length_limit
        self._fallback = fallback

    def getProtocol(self, trans):
        """Return a new TCompactProtocolAccelerated bound to ``trans``."""
        limits = dict(string_length_limit=self.string_length_limit,
                      container_length_limit=self.container_length_limit)
        return TCompactProtocolAccelerated(trans, fallback=self._fallback,
                                           **limits)