blob: 6d57aeba9a82b32310e7420761488435eae88366 [file] [log] [blame]
David Reissabafd792010-09-27 17:28:15 +00001from TProtocol import *
2from struct import pack, unpack
3
# Public names exported by ``from <this module> import *``.
__all__ = [
    'TCompactProtocol',
    'TCompactProtocolFactory',
]
5
# States of the TCompactProtocol read/write state machine.  The protocol
# stores one of these in ``self.state`` and asserts on it before each call.
(CLEAR,
 FIELD_WRITE,
 VALUE_WRITE,
 CONTAINER_WRITE,
 BOOL_WRITE,
 FIELD_READ,
 CONTAINER_READ,
 VALUE_READ,
 BOOL_READ) = range(9)
15
def make_helper(v_from, container):
    """Build a decorator that checks protocol state before a method runs.

    The returned decorator wraps a method taking ``self`` so that, on every
    call, it asserts ``self.state`` is either *v_from* or *container* and
    then delegates to the wrapped method with the same arguments.

    The wrapper is built with ``functools.wraps`` so the decorated method
    keeps its own ``__name__`` and ``__doc__`` (useful in tracebacks and
    introspection).
    """
    from functools import wraps

    def helper(func):
        @wraps(func)
        def nested(self, *args, **kwargs):
            # Debug-time guard against calling a primitive read/write while
            # the protocol is in the wrong state (e.g. between fields).
            assert self.state in (v_from, container), (self.state, v_from, container)
            return func(self, *args, **kwargs)
        return nested
    return helper
# State-checking decorators for primitive values: ``writer`` allows a call
# while writing a plain value or a container element, ``reader`` the same
# for reading.
writer, reader = (make_helper(VALUE_WRITE, CONTAINER_WRITE),
                  make_helper(VALUE_READ, CONTAINER_READ))
25
def makeZigZag(n, bits):
    """Zigzag-encode the signed, *bits*-wide integer *n*.

    Maps 0, -1, 1, -2, 2, ... onto 0, 1, 2, 3, 4, ... so that values of
    small magnitude (of either sign) become small non-negative varints.
    """
    # For n inside the signed range this is 0 when n >= 0 and -1 otherwise.
    sign_mask = n >> (bits - 1)
    return sign_mask ^ (n << 1)
28
def fromZigZag(n):
    """Invert makeZigZag: map 0, 1, 2, 3, ... back to 0, -1, 1, -2, ..."""
    magnitude = n >> 1
    sign = -(n & 1)  # 0 for even n, -1 (all bits set) for odd n
    return magnitude ^ sign
31
def writeVarint(trans, n):
    """Write the non-negative integer *n* to *trans* as a base-128 varint.

    Each output byte carries 7 bits of *n*, least-significant group first;
    the high bit is set on every byte except the last.  The encoded bytes
    are passed to ``trans.write`` in a single call.

    Raises ValueError if *n* is negative: the encoding is unsigned, and
    shifting a negative int right never reaches zero, so the loop would
    otherwise never terminate.
    """
    if n < 0:
        raise ValueError('varint value must be non-negative: %r' % (n,))
    out = bytearray()
    while n > 0x7f:
        out.append((n & 0x7f) | 0x80)
        n >>= 7
    out.append(n)
    # bytes(bytearray) yields a byte string on both Python 2 and Python 3.
    trans.write(bytes(out))
42
def readVarint(trans):
    """Read one unsigned base-128 varint from *trans* and return it.

    Bytes are consumed one at a time via ``trans.readAll(1)``.  Each byte
    contributes its low 7 bits, least-significant group first, and a byte
    with the high bit clear ends the number.

    Raises TProtocolException (INVALID_DATA) if the encoding runs past 10
    bytes, the most any 64-bit value needs.  Without this cap, a corrupt or
    hostile stream of continuation bytes would grow the result without bound.
    """
    result = 0
    shift = 0
    while shift < 70:  # 10 bytes * 7 bits covers a full 64-bit value
        byte = ord(trans.readAll(1))
        result |= (byte & 0x7f) << shift
        if byte >> 7 == 0:
            return result
        shift += 7
    raise TProtocolException(TProtocolException.INVALID_DATA,
                             'Variable-length int over 10 bytes.')
53
class CompactType:
    """Type ids written on the wire by the compact protocol.

    Booleans have two ids, TRUE and FALSE. A boolean struct field puts
    the chosen id straight into the field header, and a boolean container
    element is written as that id byte.
    """
    STOP, TRUE, FALSE = 0x00, 0x01, 0x02
    BYTE, I16, I32, I64 = 0x03, 0x04, 0x05, 0x06
    DOUBLE, BINARY = 0x07, 0x08
    LIST, SET, MAP, STRUCT = 0x09, 0x0A, 0x0B, 0x0C
68
# Generic Thrift TType id -> compact-protocol wire id.
CTYPES = {
    TType.STOP: CompactType.STOP,
    # BOOL maps to TRUE; this is the element type written in collection
    # headers.  Boolean struct fields bypass this table and write TRUE or
    # FALSE directly into the field header.
    TType.BOOL: CompactType.TRUE,
    TType.BYTE: CompactType.BYTE,
    TType.I16: CompactType.I16,
    TType.I32: CompactType.I32,
    TType.I64: CompactType.I64,
    TType.DOUBLE: CompactType.DOUBLE,
    TType.STRING: CompactType.BINARY,
    TType.STRUCT: CompactType.STRUCT,
    TType.LIST: CompactType.LIST,
    TType.SET: CompactType.SET,
    TType.MAP: CompactType.MAP,
}
82
# Reverse lookup: compact wire id -> TType.  CTYPES has no entry producing
# FALSE, so it is added explicitly; both TRUE and FALSE decode to BOOL.
# (The generator expression leaves no loop variables in the module.)
TTYPES = dict((ctype, ttype) for ttype, ctype in CTYPES.items())
TTYPES[CompactType.FALSE] = TType.BOOL
89
class TCompactProtocol(TProtocolBase):
    """Compact implementation of the Thrift protocol driver.

    ``self.state`` holds one of the module-level state constants (CLEAR,
    FIELD_WRITE, VALUE_WRITE, ...). Most calls assert on it, so out-of-order
    use is caught.

    Entering a struct or container pushes the current state (and, for
    structs, the last field id) onto ``__structs`` / ``__containers``. The
    matching ``*End`` call pops it back.
    """

    PROTOCOL_ID = 0x82
    VERSION = 1
    VERSION_MASK = 0x1f
    TYPE_MASK = 0xe0
    TYPE_SHIFT_AMOUNT = 5

    def __init__(self, trans):
        TProtocolBase.__init__(self, trans)
        self.state = CLEAR
        self.__last_fid = 0       # previous field id; headers store deltas
        self.__bool_fid = None    # field id of a bool waiting for writeBool
        self.__bool_value = None  # bool decoded from the last field header
        self.__structs = []       # saved (state, last_fid) per open struct
        self.__containers = []    # saved state per open container

    # ------------------------------------------------------------------
    # Writing
    # ------------------------------------------------------------------

    def __writeVarint(self, n):
        writeVarint(self.trans, n)

    def writeMessageBegin(self, name, type, seqid):
        assert self.state == CLEAR
        self.__writeUByte(self.PROTOCOL_ID)
        self.__writeUByte(self.VERSION | (type << self.TYPE_SHIFT_AMOUNT))
        # seqid is a signed 32-bit value but travels as an unsigned varint,
        # which cannot encode a negative number.  Send its 32-bit two's-
        # complement form instead; readMessageBegin converts it back.
        if seqid < 0:
            seqid &= 0xffffffff
        self.__writeVarint(seqid)
        self.__writeString(name)
        self.state = VALUE_WRITE

    def writeMessageEnd(self):
        assert self.state == VALUE_WRITE
        self.state = CLEAR

    def writeStructBegin(self, name):
        assert self.state in (CLEAR, CONTAINER_WRITE, VALUE_WRITE), self.state
        # Field ids are delta-encoded relative to the previous field of the
        # same struct, so save the enclosing struct's last id and restart.
        self.__structs.append((self.state, self.__last_fid))
        self.state = FIELD_WRITE
        self.__last_fid = 0

    def writeStructEnd(self):
        assert self.state == FIELD_WRITE
        self.state, self.__last_fid = self.__structs.pop()

    def writeFieldStop(self):
        self.__writeByte(CompactType.STOP)

    def __writeFieldHeader(self, ctype, fid):
        delta = fid - self.__last_fid
        if 0 < delta <= 15:
            # Short form: field-id delta in the high nibble, type in the low.
            self.__writeUByte(delta << 4 | ctype)
        else:
            # Long form: a type byte with zero delta, then the zigzag fid.
            self.__writeByte(ctype)
            self.__writeI16(fid)
        self.__last_fid = fid

    def writeFieldBegin(self, name, type, fid):
        assert self.state == FIELD_WRITE, self.state
        if type == TType.BOOL:
            # The bool's value is folded into the field header, so the header
            # is deferred until writeBool supplies the value.
            self.state = BOOL_WRITE
            self.__bool_fid = fid
        else:
            self.state = VALUE_WRITE
            self.__writeFieldHeader(CTYPES[type], fid)

    def writeFieldEnd(self):
        assert self.state in (VALUE_WRITE, BOOL_WRITE), self.state
        self.state = FIELD_WRITE

    def __writeUByte(self, byte):
        self.trans.write(pack('!B', byte))

    def __writeByte(self, byte):
        self.trans.write(pack('!b', byte))

    def __writeI16(self, i16):
        self.__writeVarint(makeZigZag(i16, 16))

    def __writeSize(self, i32):
        # Sizes are unsigned varints; reject negatives explicitly rather than
        # handing them to the varint encoder.
        if i32 < 0:
            raise TProtocolException(TProtocolException.NEGATIVE_SIZE,
                                     'Negative size: %d' % i32)
        self.__writeVarint(i32)

    def writeCollectionBegin(self, etype, size):
        assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state
        if size <= 14:
            # Sizes 0..14 fit in the high nibble of the header byte.
            self.__writeUByte(size << 4 | CTYPES[etype])
        else:
            # A high nibble of 0xf means "size follows as a varint".
            self.__writeUByte(0xf0 | CTYPES[etype])
            self.__writeSize(size)
        self.__containers.append(self.state)
        self.state = CONTAINER_WRITE
    writeSetBegin = writeCollectionBegin
    writeListBegin = writeCollectionBegin

    def writeMapBegin(self, ktype, vtype, size):
        assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state
        if size == 0:
            # An empty map is a single zero byte with no key/value types.
            self.__writeByte(0)
        else:
            self.__writeSize(size)
            self.__writeUByte(CTYPES[ktype] << 4 | CTYPES[vtype])
        self.__containers.append(self.state)
        self.state = CONTAINER_WRITE

    def writeCollectionEnd(self):
        assert self.state == CONTAINER_WRITE, self.state
        self.state = self.__containers.pop()
    writeMapEnd = writeCollectionEnd
    writeSetEnd = writeCollectionEnd
    writeListEnd = writeCollectionEnd

    def writeBool(self, bool):
        ctype = CompactType.TRUE if bool else CompactType.FALSE
        if self.state == BOOL_WRITE:
            # Emit the header deferred by writeFieldBegin, value included.
            self.__writeFieldHeader(ctype, self.__bool_fid)
        elif self.state == CONTAINER_WRITE:
            self.__writeByte(ctype)
        else:
            raise AssertionError("Invalid state in compact protocol")

    writeByte = writer(__writeByte)
    writeI16 = writer(__writeI16)

    @writer
    def writeI32(self, i32):
        self.__writeVarint(makeZigZag(i32, 32))

    @writer
    def writeI64(self, i64):
        self.__writeVarint(makeZigZag(i64, 64))

    @writer
    def writeDouble(self, dub):
        # The compact protocol sends doubles little-endian (unlike the
        # binary protocol); this matches the Java and C++ implementations.
        self.trans.write(pack('<d', dub))

    def __writeString(self, s):
        self.__writeSize(len(s))
        self.trans.write(s)
    writeString = writer(__writeString)

    # ------------------------------------------------------------------
    # Reading
    # ------------------------------------------------------------------

    def readFieldBegin(self):
        assert self.state == FIELD_READ, self.state
        header = self.__readUByte()
        ctype = header & 0x0f
        if ctype == CompactType.STOP:
            return (None, TType.STOP, 0)
        delta = header >> 4
        if delta == 0:
            # Long form: the absolute field id follows as a zigzag varint.
            fid = self.__readI16()
        else:
            fid = self.__last_fid + delta
        self.__last_fid = fid
        if ctype == CompactType.TRUE:
            self.state = BOOL_READ
            self.__bool_value = True
        elif ctype == CompactType.FALSE:
            self.state = BOOL_READ
            self.__bool_value = False
        else:
            self.state = VALUE_READ
        return (None, self.__getTType(ctype), fid)

    def readFieldEnd(self):
        assert self.state in (VALUE_READ, BOOL_READ), self.state
        self.state = FIELD_READ

    def __readUByte(self):
        result, = unpack('!B', self.trans.readAll(1))
        return result

    def __readByte(self):
        result, = unpack('!b', self.trans.readAll(1))
        return result

    def __readVarint(self):
        return readVarint(self.trans)

    def __readZigZag(self):
        return fromZigZag(self.__readVarint())

    def __readSize(self):
        result = self.__readVarint()
        # Defensive: varints decode as non-negative.
        if result < 0:
            raise TException("Length < 0")
        return result

    def readMessageBegin(self):
        """Read a message header and return ``(name, type, seqid)``.

        The state is left at CLEAR; readMessageEnd expects CLEAR too.
        """
        assert self.state == CLEAR
        proto_id = self.__readUByte()
        if proto_id != self.PROTOCOL_ID:
            raise TProtocolException(TProtocolException.BAD_VERSION,
                                     'Bad protocol id in the message: %d' % proto_id)
        ver_type = self.__readUByte()
        msg_type = (ver_type & self.TYPE_MASK) >> self.TYPE_SHIFT_AMOUNT
        version = ver_type & self.VERSION_MASK
        if version != self.VERSION:
            raise TProtocolException(TProtocolException.BAD_VERSION,
                                     'Bad version: %d (expect %d)' % (version, self.VERSION))
        seqid = self.__readVarint()
        # seqid is a signed 32-bit value sent as an unsigned varint; map the
        # upper half of the 32-bit range back to negative numbers.
        if 0x7fffffff < seqid <= 0xffffffff:
            seqid -= 1 << 32
        name = self.__readString()
        return (name, msg_type, seqid)

    def readMessageEnd(self):
        assert self.state == CLEAR
        assert len(self.__structs) == 0

    def readStructBegin(self):
        assert self.state in (CLEAR, CONTAINER_READ, VALUE_READ), self.state
        self.__structs.append((self.state, self.__last_fid))
        self.state = FIELD_READ
        self.__last_fid = 0

    def readStructEnd(self):
        assert self.state == FIELD_READ
        self.state, self.__last_fid = self.__structs.pop()

    def readCollectionBegin(self):
        assert self.state in (VALUE_READ, CONTAINER_READ), self.state
        size_type = self.__readUByte()
        size = size_type >> 4
        etype = self.__getTType(size_type)
        if size == 15:
            # 0xf in the high nibble: the real size follows as a varint.
            size = self.__readSize()
        self.__containers.append(self.state)
        self.state = CONTAINER_READ
        return etype, size
    readSetBegin = readCollectionBegin
    readListBegin = readCollectionBegin

    def readMapBegin(self):
        assert self.state in (VALUE_READ, CONTAINER_READ), self.state
        size = self.__readSize()
        types = 0
        if size > 0:
            # The key/value type byte is omitted for empty maps.
            types = self.__readUByte()
        vtype = self.__getTType(types)
        ktype = self.__getTType(types >> 4)
        self.__containers.append(self.state)
        self.state = CONTAINER_READ
        return (ktype, vtype, size)

    def readCollectionEnd(self):
        assert self.state == CONTAINER_READ, self.state
        self.state = self.__containers.pop()
    readSetEnd = readCollectionEnd
    readListEnd = readCollectionEnd
    readMapEnd = readCollectionEnd

    def readBool(self):
        if self.state == BOOL_READ:
            # Already decoded from the field header by readFieldBegin.
            return self.__bool_value
        elif self.state == CONTAINER_READ:
            return self.__readByte() == CompactType.TRUE
        else:
            raise AssertionError("Invalid state in compact protocol: %d" % self.state)

    readByte = reader(__readByte)
    __readI16 = __readZigZag
    readI16 = reader(__readZigZag)
    readI32 = reader(__readZigZag)
    readI64 = reader(__readZigZag)

    @reader
    def readDouble(self):
        # Little-endian, mirroring writeDouble and the other implementations.
        buff = self.trans.readAll(8)
        val, = unpack('<d', buff)
        return val

    def __readString(self):
        size = self.__readSize()
        return self.trans.readAll(size)
    readString = reader(__readString)

    def __getTType(self, byte):
        return TTYPES[byte & 0x0f]
369
370
class TCompactProtocolFactory:
    """Factory producing a TCompactProtocol for each transport it is given."""

    def __init__(self):
        """Nothing to configure; the factory holds no state."""

    def getProtocol(self, trans):
        """Return a new TCompactProtocol that reads/writes via *trans*."""
        protocol = TCompactProtocol(trans)
        return protocol