blob: fbc156a8f1df5b57b10e377545fdfe5018c00f6e [file] [log] [blame]
David Reissabafd792010-09-27 17:28:15 +00001from TProtocol import *
2from struct import pack, unpack
3
# Public names exported by a star-import of this module.
__all__ = [
    'TCompactProtocol',
    'TCompactProtocolFactory',
]
5
# States of the TCompactProtocol state machine, stored in ``self.state``.
# Values are the consecutive integers 0..8 in the order listed.
(CLEAR,
 FIELD_WRITE,
 VALUE_WRITE,
 CONTAINER_WRITE,
 BOOL_WRITE,
 FIELD_READ,
 CONTAINER_READ,
 VALUE_READ,
 BOOL_READ) = range(9)
15
def make_helper(v_from, container):
    """Build a method decorator that asserts the protocol state.

    The returned decorator wraps a method so that, before the method runs,
    it asserts ``self.state`` is either *v_from* or *container*.  On failure
    the assertion message is the tuple ``(self.state, v_from, container)``.
    The wrapper keeps the wrapped method's name and docstring.
    """
    from functools import wraps

    def helper(func):
        @wraps(func)  # keep __name__/__doc__ of the decorated method
        def nested(self, *args, **kwargs):
            assert self.state in (v_from, container), (self.state, v_from, container)
            return func(self, *args, **kwargs)
        return nested
    return helper
# Decorators for the value-level methods of TCompactProtocol: a value may be
# written/read either directly (VALUE_*) or as a container element
# (CONTAINER_*).
writer, reader = (
    make_helper(VALUE_WRITE, CONTAINER_WRITE),
    make_helper(VALUE_READ, CONTAINER_READ),
)
25
def makeZigZag(n, bits):
    """ZigZag-encode the signed integer *n* of width *bits*.

    Maps 0, -1, 1, -2, ... to 0, 1, 2, 3, ... so that values of small
    magnitude become small non-negative numbers for varint encoding.
    """
    # For n inside the signed range of *bits*, this is all ones when n is
    # negative and 0 otherwise.
    sign_mask = n >> (bits - 1)
    return sign_mask ^ (n << 1)
28
def fromZigZag(n):
    """Invert makeZigZag: map 0, 1, 2, 3, ... back to 0, -1, 1, -2, ..."""
    magnitude = n >> 1
    sign = -(n & 1)  # 0 for even n, -1 (all ones) for odd n
    return sign ^ magnitude
31
def writeVarint(trans, n):
    """Write the non-negative integer *n* to *trans* as a varint.

    Seven bits are stored per byte, least-significant group first.  Every
    byte except the last has its high bit (0x80) set.  All bytes go out in a
    single ``trans.write`` call, as a string built with ``chr``.

    Raises:
        ValueError: if *n* is negative.  Right-shifting a negative Python
            int never reaches 0, so encoding one would loop forever.
    """
    if n < 0:
        raise ValueError('varint value must be non-negative: %d' % n)
    out = []
    while n > 0x7f:
        out.append((n & 0x7f) | 0x80)  # low 7 bits plus continuation flag
        n >>= 7
    out.append(n)  # final group, continuation flag clear
    trans.write(''.join(map(chr, out)))
42
def readVarint(trans):
    """Read one varint from *trans* and return it as a non-negative int.

    Bytes are read one at a time with ``trans.readAll(1)``.  The low 7 bits
    of each byte are accumulated least-significant group first, and reading
    stops at the first byte whose value shifted right by 7 is zero.
    """
    value = 0
    shift = 0
    byte = 0x80  # primes the loop so at least one byte is read
    while byte >> 7 != 0:
        byte = ord(trans.readAll(1))
        value |= (byte & 0x7f) << shift
        shift += 7
    return value
53
class CompactType:
    """Type codes written on the wire by the compact protocol.

    TRUE and FALSE also appear as the type nibble of bool field headers,
    where they carry the field's value.
    """
    TRUE = 0x01
    FALSE = 0x02
    BYTE = 0x03
    I16 = 0x04
    I32 = 0x05
    I64 = 0x06
    DOUBLE = 0x07
    BINARY = 0x08
    LIST = 0x09
    SET = 0x0A
    MAP = 0x0B
    STRUCT = 0x0C
67
# Thrift TType -> compact-protocol wire type code.
CTYPES = dict([
    (TType.BOOL, CompactType.TRUE),  # element-type code for bool collections
    (TType.BYTE, CompactType.BYTE),
    (TType.I16, CompactType.I16),
    (TType.I32, CompactType.I32),
    (TType.I64, CompactType.I64),
    (TType.DOUBLE, CompactType.DOUBLE),
    (TType.STRING, CompactType.BINARY),
    (TType.STRUCT, CompactType.STRUCT),
    (TType.LIST, CompactType.LIST),
    (TType.SET, CompactType.SET),
    (TType.MAP, CompactType.MAP),
])
80
# Compact wire type code -> Thrift TType: the inverse of CTYPES, plus FALSE
# so that both bool codes decode to TType.BOOL.  A generator expression is
# used so no loop variables are left behind at module level.
TTYPES = dict((ctype, ttype) for ttype, ctype in CTYPES.items())
TTYPES[CompactType.FALSE] = TType.BOOL
87
class TCompactProtocol(TProtocolBase):
    """Compact implementation of the Thrift protocol driver.

    i16/i32/i64 values are written as ZigZag-encoded varints, and sizes as
    plain varints.  A field id is stored as a 4-bit delta from the previous
    field id when that delta is 1..15.  A bool field folds its value into
    the type nibble of its field header.

    ``self.state`` holds one of the module-level CLEAR / *_WRITE / *_READ
    constants, and each call asserts it is legal in the current state.
    Nested structs and containers push the enclosing state onto a stack,
    and the matching End call restores it.
    """

    PROTOCOL_ID = 0x82
    VERSION = 1
    VERSION_MASK = 0x1f
    TYPE_MASK = 0xe0
    TYPE_SHIFT_AMOUNT = 5

    def __init__(self, trans):
        TProtocolBase.__init__(self, trans)
        self.state = CLEAR
        # Id of the previous field; field headers are written relative to it.
        self.__last_fid = 0
        # Field id remembered by writeFieldBegin for a bool field.  Its header
        # is only written by writeBool, once the value is known.
        self.__bool_fid = None
        # Bool value decoded from a field header by readFieldBegin.
        self.__bool_value = None
        # Stack of (state, last field id) saved by each nested struct.
        self.__structs = []
        # Stack of states saved by each nested list/set/map.
        self.__containers = []

    def __writeVarint(self, n):
        writeVarint(self.trans, n)

    def writeMessageBegin(self, name, type, seqid):
        """Write the message header: protocol id, version/type, seqid, name."""
        assert self.state == CLEAR
        self.__writeUByte(self.PROTOCOL_ID)
        # Low 5 bits carry the version, the bits above carry the message type.
        self.__writeUByte(self.VERSION | (type << self.TYPE_SHIFT_AMOUNT))
        self.__writeVarint(seqid)
        self.__writeString(name)
        self.state = VALUE_WRITE

    def writeMessageEnd(self):
        assert self.state == VALUE_WRITE
        self.state = CLEAR

    def writeStructBegin(self, name):
        assert self.state in (CLEAR, CONTAINER_WRITE, VALUE_WRITE), self.state
        # Field-id deltas restart from 0 inside every struct.
        self.__structs.append((self.state, self.__last_fid))
        self.state = FIELD_WRITE
        self.__last_fid = 0

    def writeStructEnd(self):
        assert self.state == FIELD_WRITE
        self.state, self.__last_fid = self.__structs.pop()

    def writeFieldStop(self):
        self.__writeByte(0)

    def __writeFieldHeader(self, type, fid):
        """Write the header for a field of compact type *type* and id *fid*.

        If *fid* is 1..15 above the previous field id, the delta goes in the
        high nibble of the same byte as the type.  Otherwise the type byte is
        followed by the full id as a ZigZag varint.
        """
        delta = fid - self.__last_fid
        if 0 < delta <= 15:
            self.__writeUByte(delta << 4 | type)
        else:
            self.__writeByte(type)
            self.__writeI16(fid)
        self.__last_fid = fid

    def writeFieldBegin(self, name, type, fid):
        assert self.state == FIELD_WRITE, self.state
        if type == TType.BOOL:
            # A bool field's header encodes its value, so writing it is
            # deferred until writeBool receives the value.
            self.state = BOOL_WRITE
            self.__bool_fid = fid
        else:
            self.state = VALUE_WRITE
            self.__writeFieldHeader(CTYPES[type], fid)

    def writeFieldEnd(self):
        assert self.state in (VALUE_WRITE, BOOL_WRITE), self.state
        self.state = FIELD_WRITE

    def __writeUByte(self, byte):
        self.trans.write(pack('!B', byte))

    def __writeByte(self, byte):
        self.trans.write(pack('!b', byte))

    def __writeI16(self, i16):
        self.__writeVarint(makeZigZag(i16, 16))

    def __writeSize(self, i32):
        self.__writeVarint(i32)

    def writeCollectionBegin(self, etype, size):
        """Begin a list or set of *size* elements of Thrift type *etype*.

        Sizes up to 14 share the header byte with the element type.  Larger
        sizes put 0xf in the high nibble and follow with a varint size.
        """
        assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state
        if size <= 14:
            self.__writeUByte(size << 4 | CTYPES[etype])
        else:
            self.__writeUByte(0xf0 | CTYPES[etype])
            self.__writeSize(size)
        self.__containers.append(self.state)
        self.state = CONTAINER_WRITE
    writeSetBegin = writeCollectionBegin
    writeListBegin = writeCollectionBegin

    def writeMapBegin(self, ktype, vtype, size):
        """Begin a map; an empty map is a single 0 byte with no type byte."""
        assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state
        if size == 0:
            self.__writeByte(0)
        else:
            self.__writeSize(size)
            self.__writeUByte(CTYPES[ktype] << 4 | CTYPES[vtype])
        self.__containers.append(self.state)
        self.state = CONTAINER_WRITE

    def writeCollectionEnd(self):
        assert self.state == CONTAINER_WRITE, self.state
        self.state = self.__containers.pop()
    writeMapEnd = writeCollectionEnd
    writeSetEnd = writeCollectionEnd
    writeListEnd = writeCollectionEnd

    def writeBool(self, bool):
        """Write a boolean value.

        For a struct field (state BOOL_WRITE), the value becomes the type
        nibble (CompactType.TRUE / FALSE) of the header that writeFieldBegin
        deferred.  Inside a container, a single CompactType.TRUE / FALSE
        byte is written, matching the Java and C++ compact protocols.

        Raises:
            AssertionError: if called in any other state.
        """
        if bool:
            ctype = CompactType.TRUE
        else:
            ctype = CompactType.FALSE
        if self.state == BOOL_WRITE:
            self.__writeFieldHeader(ctype, self.__bool_fid)
        elif self.state == CONTAINER_WRITE:
            self.__writeByte(ctype)
        else:
            raise AssertionError("Invalid state in compact protocol")

    writeByte = writer(__writeByte)
    writeI16 = writer(__writeI16)

    @writer
    def writeI32(self, i32):
        self.__writeVarint(makeZigZag(i32, 32))

    @writer
    def writeI64(self, i64):
        self.__writeVarint(makeZigZag(i64, 64))

    @writer
    def writeDouble(self, dub):
        # The compact protocol sends doubles little-endian (unlike the binary
        # protocol); this matches the other language implementations.
        self.trans.write(pack('<d', dub))

    def __writeString(self, s):
        self.__writeSize(len(s))
        self.trans.write(s)
    writeString = writer(__writeString)

    def readFieldBegin(self):
        """Read a field header and return (name, ttype, fid).

        The name is always None.  A STOP byte yields (None, 0, 0).  For a bool
        field, the value is taken from the header and later returned by
        readBool.
        """
        assert self.state == FIELD_READ, self.state
        header = self.__readUByte()
        ctype = header & 0x0f
        if ctype == TType.STOP:
            return (None, 0, 0)
        delta = header >> 4
        if delta == 0:
            # Long form: the full field id follows as a ZigZag varint.
            fid = self.__readI16()
        else:
            fid = self.__last_fid + delta
        self.__last_fid = fid
        if ctype == CompactType.TRUE:
            self.state = BOOL_READ
            self.__bool_value = True
        elif ctype == CompactType.FALSE:
            self.state = BOOL_READ
            self.__bool_value = False
        else:
            self.state = VALUE_READ
        return (None, self.__getTType(ctype), fid)

    def readFieldEnd(self):
        assert self.state in (VALUE_READ, BOOL_READ), self.state
        self.state = FIELD_READ

    def __readUByte(self):
        result, = unpack('!B', self.trans.readAll(1))
        return result

    def __readByte(self):
        result, = unpack('!b', self.trans.readAll(1))
        return result

    def __readVarint(self):
        return readVarint(self.trans)

    def __readZigZag(self):
        return fromZigZag(self.__readVarint())

    def __readSize(self):
        result = self.__readVarint()
        # Defensive only: readVarint never returns a negative number.
        if result < 0:
            raise TException("Length < 0")
        return result

    def readMessageBegin(self):
        """Read the message header and return (name, type, seqid).

        Raises:
            TProtocolException: BAD_VERSION if the protocol id or the version
                does not match.
        """
        assert self.state == CLEAR
        proto_id = self.__readUByte()
        if proto_id != self.PROTOCOL_ID:
            raise TProtocolException(TProtocolException.BAD_VERSION,
                                     'Bad protocol id in the message: %d' % proto_id)
        ver_type = self.__readUByte()
        type = (ver_type & self.TYPE_MASK) >> self.TYPE_SHIFT_AMOUNT
        version = ver_type & self.VERSION_MASK
        if version != self.VERSION:
            raise TProtocolException(TProtocolException.BAD_VERSION,
                                     'Bad version: %d (expect %d)' % (version, self.VERSION))
        seqid = self.__readVarint()
        name = self.__readString()
        # Mirror writeMessageBegin: the message body is read in VALUE_READ,
        # which readMessageEnd asserts.
        self.state = VALUE_READ
        return (name, type, seqid)

    def readMessageEnd(self):
        assert self.state == VALUE_READ
        assert len(self.__structs) == 0
        self.state = CLEAR

    def readStructBegin(self):
        assert self.state in (CLEAR, CONTAINER_READ, VALUE_READ), self.state
        self.__structs.append((self.state, self.__last_fid))
        self.state = FIELD_READ
        self.__last_fid = 0

    def readStructEnd(self):
        assert self.state == FIELD_READ
        self.state, self.__last_fid = self.__structs.pop()

    def readCollectionBegin(self):
        """Begin reading a list or set; return (element ttype, size)."""
        assert self.state in (VALUE_READ, CONTAINER_READ), self.state
        size_type = self.__readUByte()
        size = size_type >> 4
        type = self.__getTType(size_type)
        if size == 15:
            # 0xf in the high nibble means the real size follows as a varint.
            size = self.__readSize()
        self.__containers.append(self.state)
        self.state = CONTAINER_READ
        return type, size
    readSetBegin = readCollectionBegin
    readListBegin = readCollectionBegin

    def readMapBegin(self):
        """Begin reading a map; return (ktype, vtype, size).

        An empty map has no key/value type byte on the wire, so both types
        are reported as TType.STOP in that case.
        """
        assert self.state in (VALUE_READ, CONTAINER_READ), self.state
        size = self.__readSize()
        if size > 0:
            types = self.__readUByte()
            ktype = self.__getTType(types >> 4)
            vtype = self.__getTType(types)
        else:
            # No type byte to decode (and 0 is not a key of TTYPES).
            ktype = vtype = TType.STOP
        self.__containers.append(self.state)
        self.state = CONTAINER_READ
        return (ktype, vtype, size)

    def readCollectionEnd(self):
        assert self.state == CONTAINER_READ, self.state
        self.state = self.__containers.pop()
    readSetEnd = readCollectionEnd
    readListEnd = readCollectionEnd
    readMapEnd = readCollectionEnd

    def readBool(self):
        """Return a boolean value.

        For a struct field, returns the value decoded by readFieldBegin.
        Inside a container, reads one byte: CompactType.TRUE (1) is True and
        any other value (CompactType.FALSE, or a legacy 0) is False.

        Raises:
            AssertionError: if called in any other state.
        """
        if self.state == BOOL_READ:
            return self.__bool_value
        elif self.state == CONTAINER_READ:
            return self.__readByte() == CompactType.TRUE
        else:
            raise AssertionError("Invalid state in compact protocol: %d" % self.state)

    readByte = reader(__readByte)
    __readI16 = __readZigZag
    readI16 = reader(__readZigZag)
    readI32 = reader(__readZigZag)
    readI64 = reader(__readZigZag)

    @reader
    def readDouble(self):
        # Little-endian, mirroring writeDouble.
        buff = self.trans.readAll(8)
        val, = unpack('<d', buff)
        return val

    def __readString(self):
        size = self.__readSize()
        return self.trans.readAll(size)
    readString = reader(__readString)

    def __getTType(self, byte):
        """Map the compact type code in the low nibble of *byte* to a TType."""
        return TTYPES[byte & 0x0f]
361
362
class TCompactProtocolFactory:
    """Factory producing a TCompactProtocol for each transport it is given."""

    def __init__(self):
        """Take no configuration; each protocol depends only on its transport."""

    def getProtocol(self, trans):
        """Return a new TCompactProtocol wrapping the transport *trans*."""
        protocol = TCompactProtocol(trans)
        return protocol