blob: 016a3317185e4496b19dcae76492f25de3ed79e0 [file] [log] [blame]
Roger Meierf4eec7a2011-09-11 18:16:21 +00001#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18#
19
David Reissabafd792010-09-27 17:28:15 +000020from TProtocol import *
21from struct import pack, unpack
22
# Public API of this module: the protocol driver and its factory.
__all__ = [
    'TCompactProtocol',
    'TCompactProtocolFactory',
]
24
# States of TCompactProtocol's internal state machine (stored in self.state).
# The values are arbitrary distinct ints 0..8, used only for assertions and
# state bookkeeping; they never appear on the wire.
(CLEAR,
 FIELD_WRITE,
 VALUE_WRITE,
 CONTAINER_WRITE,
 BOOL_WRITE,
 FIELD_READ,
 CONTAINER_READ,
 VALUE_READ,
 BOOL_READ) = range(9)
34
def make_helper(v_from, container):
    """Build a decorator that asserts the protocol is in a value state.

    The returned decorator wraps a method whose first argument is ``self``.
    Before delegating, the wrapper asserts that ``self.state`` is either
    ``v_from`` or ``container``. ``v_from`` is the state for a standalone
    value in a message or field, and ``container`` is the state for an
    element of a list, set or map.

    :param v_from: state in which a standalone value may be processed.
    :param container: state in which a container element may be processed.
    :returns: a decorator for protocol methods.
    """
    def helper(func):
        def nested(self, *args, **kwargs):
            # NOTE: this is an assert, so the check disappears under -O.
            assert self.state in (v_from, container), (self.state, v_from, container)
            return func(self, *args, **kwargs)
        # Preserve the wrapped method's identity for tracebacks and
        # introspection; otherwise every wrapped method is named 'nested'.
        # Copied by hand rather than via functools.wraps to stay usable on
        # very old Python 2 interpreters.
        nested.__name__ = func.__name__
        nested.__doc__ = func.__doc__
        return nested
    return helper
# Decorators for value-level methods. Writers are legal directly after a
# message/field begin (VALUE_WRITE) or inside a container (CONTAINER_WRITE);
# readers mirror this with VALUE_READ / CONTAINER_READ.
writer, reader = (make_helper(VALUE_WRITE, CONTAINER_WRITE),
                  make_helper(VALUE_READ, CONTAINER_READ))
44
def makeZigZag(n, bits):
    """Zigzag-encode the signed ``bits``-wide integer ``n``.

    Maps 0, -1, 1, -2, 2, ... to 0, 1, 2, 3, 4, ..., so small magnitudes of
    either sign become small non-negative numbers suitable for a varint.
    """
    # Arithmetic shift: 0 for non-negative n, -1 (all ones) for negative n
    # within the ``bits`` range.
    sign_mask = n >> (bits - 1)
    return sign_mask ^ (n << 1)
47
def fromZigZag(n):
    """Invert makeZigZag: map 0, 1, 2, 3, 4, ... back to 0, -1, 1, -2, 2, ..."""
    low_bit = n & 1
    # -low_bit is 0 or -1 (all ones); XOR with it flips the magnitude for
    # odd (negative) encodings.
    return -low_bit ^ (n >> 1)
50
def writeVarint(trans, n):
    """Write the non-negative integer ``n`` to ``trans`` as an unsigned varint.

    Seven bits go into each byte, least-significant group first. Every byte
    except the last has its high bit (0x80) set as a continuation flag. The
    whole encoding is handed to ``trans.write`` in a single call.

    :raises ValueError: if ``n`` is negative. Python's ``>>`` never drives a
        negative int to zero, so the previous loop spun forever on one.
    """
    if n < 0:
        raise ValueError("varint value must be non-negative: %d" % n)
    out = []
    while n > 0x7f:
        out.append((n & 0x7f) | 0x80)
        n >>= 7
    out.append(n)
    # struct.pack yields a byte string on both Python 2 (str) and Python 3
    # (bytes); ''.join(map(chr, ...)) produced text on Python 3.
    trans.write(pack('!%dB' % len(out), *out))
61
def readVarint(trans):
    """Read one unsigned varint from ``trans`` and return it as an int.

    Bytes are consumed one at a time via ``trans.readAll(1)``. Each
    contributes its low seven bits, least-significant group first, and
    decoding stops after the first byte whose high bit is clear.
    """
    value = 0
    bit_offset = 0
    more = True
    while more:
        octet = ord(trans.readAll(1))
        value |= (octet & 0x7f) << bit_offset
        more = (octet & 0x80) != 0
        bit_offset += 7
    return value
71 shift += 7
72
class CompactType:
    """Type codes written in the low nibble of compact-protocol headers."""

    # End-of-struct marker.
    STOP = 0
    # Booleans: in a field header the value itself is encoded as the type.
    TRUE = 1
    FALSE = 2
    # Scalars.
    BYTE = 3
    I16 = 4
    I32 = 5
    I64 = 6
    DOUBLE = 7
    BINARY = 8
    # Containers and nested structs.
    LIST = 9
    SET = 10
    MAP = 11
    STRUCT = 12
87
# Thrift TType -> compact-protocol type code. TType.BOOL maps to
# CompactType.TRUE, which is the code used for bool elements of collections.
CTYPES = dict([
    (TType.STOP, CompactType.STOP),
    (TType.BOOL, CompactType.TRUE),
    (TType.BYTE, CompactType.BYTE),
    (TType.I16, CompactType.I16),
    (TType.I32, CompactType.I32),
    (TType.I64, CompactType.I64),
    (TType.DOUBLE, CompactType.DOUBLE),
    (TType.STRING, CompactType.BINARY),
    (TType.STRUCT, CompactType.STRUCT),
    (TType.LIST, CompactType.LIST),
    (TType.SET, CompactType.SET),
    (TType.MAP, CompactType.MAP),
])
101
# Compact type code -> Thrift TType: the inverse of CTYPES, plus FALSE so that
# both boolean codes decode to TType.BOOL. A generator expression is used so
# no loop variables leak into the module namespace.
TTYPES = dict((ctype, ttype) for ttype, ctype in CTYPES.items())
TTYPES[CompactType.FALSE] = TType.BOOL
108
class TCompactProtocol(TProtocolBase):
    """Compact implementation of the Thrift protocol driver.

    The driver keeps a small state machine in ``self.state``, using the
    module-level CLEAR / *_WRITE / *_READ constants, and asserts that each
    call is legal in the current state. Stacks record the enclosing struct
    (state plus last field id) and container (state) so that both can be
    restored when a nested struct or container ends.
    """

    PROTOCOL_ID = 0x82
    VERSION = 1
    VERSION_MASK = 0x1f
    TYPE_MASK = 0xe0
    TYPE_SHIFT_AMOUNT = 5

    def __init__(self, trans):
        TProtocolBase.__init__(self, trans)
        self.state = CLEAR
        # Id of the previous field in the current struct; field headers are
        # delta-encoded against it.
        self.__last_fid = 0
        # Field id of a pending bool field; its header is written by
        # writeBool because the value is folded into the header.
        self.__bool_fid = None
        # Value of a bool field, decoded from its header by readFieldBegin.
        self.__bool_value = None
        # Saved (state, last_fid) for each enclosing struct.
        self.__structs = []
        # Saved state for each enclosing container.
        self.__containers = []

    def __writeVarint(self, n):
        writeVarint(self.trans, n)

    def writeMessageBegin(self, name, type, seqid):
        """Write protocol id, version|type byte, seqid varint and name."""
        assert self.state == CLEAR
        self.__writeUByte(self.PROTOCOL_ID)
        self.__writeUByte(self.VERSION | (type << self.TYPE_SHIFT_AMOUNT))
        # seqid is a signed i32 but travels as an unsigned varint. A negative
        # value would never terminate the varint writer, so send its 32-bit
        # two's-complement pattern instead; readMessageBegin undoes this.
        self.__writeVarint(seqid & 0xffffffff)
        self.__writeString(name)
        self.state = VALUE_WRITE

    def writeMessageEnd(self):
        assert self.state == VALUE_WRITE
        self.state = CLEAR

    def writeStructBegin(self, name):
        assert self.state in (CLEAR, CONTAINER_WRITE, VALUE_WRITE), self.state
        self.__structs.append((self.state, self.__last_fid))
        self.state = FIELD_WRITE
        self.__last_fid = 0

    def writeStructEnd(self):
        assert self.state == FIELD_WRITE
        self.state, self.__last_fid = self.__structs.pop()

    def writeFieldStop(self):
        self.__writeByte(0)

    def __writeFieldHeader(self, ctype, fid):
        # Short form packs (delta << 4 | ctype) into one byte when the id
        # delta is 1..15; otherwise write the type byte then the zigzag id.
        delta = fid - self.__last_fid
        if 0 < delta <= 15:
            self.__writeUByte(delta << 4 | ctype)
        else:
            self.__writeByte(ctype)
            self.__writeI16(fid)
        self.__last_fid = fid

    def writeFieldBegin(self, name, type, fid):
        assert self.state == FIELD_WRITE, self.state
        if type == TType.BOOL:
            # Defer the header until writeBool supplies the value.
            self.state = BOOL_WRITE
            self.__bool_fid = fid
        else:
            self.state = VALUE_WRITE
            self.__writeFieldHeader(CTYPES[type], fid)

    def writeFieldEnd(self):
        assert self.state in (VALUE_WRITE, BOOL_WRITE), self.state
        self.state = FIELD_WRITE

    def __writeUByte(self, byte):
        self.trans.write(pack('!B', byte))

    def __writeByte(self, byte):
        self.trans.write(pack('!b', byte))

    def __writeI16(self, i16):
        self.__writeVarint(makeZigZag(i16, 16))

    def __writeSize(self, i32):
        self.__writeVarint(i32)

    def writeCollectionBegin(self, etype, size):
        """Write a list/set header: size and element type."""
        assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state
        if size <= 14:
            self.__writeUByte(size << 4 | CTYPES[etype])
        else:
            # A size nibble of 15 means the real size follows as a varint.
            self.__writeUByte(0xf0 | CTYPES[etype])
            self.__writeSize(size)
        self.__containers.append(self.state)
        self.state = CONTAINER_WRITE
    writeSetBegin = writeCollectionBegin
    writeListBegin = writeCollectionBegin

    def writeMapBegin(self, ktype, vtype, size):
        """Write a map header: size varint, then key/value types if non-empty."""
        assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state
        if size == 0:
            self.__writeByte(0)
        else:
            self.__writeSize(size)
            self.__writeUByte(CTYPES[ktype] << 4 | CTYPES[vtype])
        self.__containers.append(self.state)
        self.state = CONTAINER_WRITE

    def writeCollectionEnd(self):
        assert self.state == CONTAINER_WRITE, self.state
        self.state = self.__containers.pop()
    writeMapEnd = writeCollectionEnd
    writeSetEnd = writeCollectionEnd
    writeListEnd = writeCollectionEnd

    def writeBool(self, bool):
        """Write a bool, either as a field header or as a container element."""
        if self.state == BOOL_WRITE:
            if bool:
                ctype = CompactType.TRUE
            else:
                ctype = CompactType.FALSE
            self.__writeFieldHeader(ctype, self.__bool_fid)
        elif self.state == CONTAINER_WRITE:
            if bool:
                self.__writeByte(CompactType.TRUE)
            else:
                self.__writeByte(CompactType.FALSE)
        else:
            raise AssertionError("Invalid state in compact protocol")

    writeByte = writer(__writeByte)
    writeI16 = writer(__writeI16)

    @writer
    def writeI32(self, i32):
        self.__writeVarint(makeZigZag(i32, 32))

    @writer
    def writeI64(self, i64):
        self.__writeVarint(makeZigZag(i64, 64))

    @writer
    def writeDouble(self, dub):
        # The compact protocol puts doubles on the wire little-endian (unlike
        # the binary protocol); big-endian '!d' did not interoperate.
        self.trans.write(pack('<d', dub))

    def __writeString(self, s):
        self.__writeSize(len(s))
        self.trans.write(s)
    writeString = writer(__writeString)

    def readFieldBegin(self):
        """Read a field header; return (None, ttype, fid).

        At the end of a struct this returns (None, 0, 0); 0 is the STOP type.
        """
        assert self.state == FIELD_READ, self.state
        header = self.__readUByte()
        ctype = header & 0x0f
        if ctype == CompactType.STOP:
            return (None, 0, 0)
        delta = header >> 4
        if delta == 0:
            # Long form: the absolute field id follows as a zigzag varint.
            fid = self.__readI16()
        else:
            fid = self.__last_fid + delta
        self.__last_fid = fid
        if ctype == CompactType.TRUE:
            self.state = BOOL_READ
            self.__bool_value = True
        elif ctype == CompactType.FALSE:
            self.state = BOOL_READ
            self.__bool_value = False
        else:
            self.state = VALUE_READ
        return (None, self.__getTType(ctype), fid)

    def readFieldEnd(self):
        assert self.state in (VALUE_READ, BOOL_READ), self.state
        self.state = FIELD_READ

    def __readUByte(self):
        result, = unpack('!B', self.trans.readAll(1))
        return result

    def __readByte(self):
        result, = unpack('!b', self.trans.readAll(1))
        return result

    def __readVarint(self):
        return readVarint(self.trans)

    def __readZigZag(self):
        return fromZigZag(self.__readVarint())

    def __readSize(self):
        result = self.__readVarint()
        if result < 0:
            raise TException("Length < 0")
        return result

    def readMessageBegin(self):
        """Read a message header; return (name, type, seqid)."""
        assert self.state == CLEAR
        proto_id = self.__readUByte()
        if proto_id != self.PROTOCOL_ID:
            raise TProtocolException(TProtocolException.BAD_VERSION,
                'Bad protocol id in the message: %d' % proto_id)
        ver_type = self.__readUByte()
        mtype = (ver_type & self.TYPE_MASK) >> self.TYPE_SHIFT_AMOUNT
        version = ver_type & self.VERSION_MASK
        if version != self.VERSION:
            raise TProtocolException(TProtocolException.BAD_VERSION,
                'Bad version: %d (expect %d)' % (version, self.VERSION))
        seqid = self.__readVarint()
        # The varint is unsigned; reinterpret it as the signed i32 seqid.
        if seqid > 0x7fffffff:
            seqid -= 0x100000000
        name = self.__readString()
        return (name, mtype, seqid)

    def readMessageEnd(self):
        assert self.state == CLEAR
        assert len(self.__structs) == 0

    def readStructBegin(self):
        assert self.state in (CLEAR, CONTAINER_READ, VALUE_READ), self.state
        self.__structs.append((self.state, self.__last_fid))
        self.state = FIELD_READ
        self.__last_fid = 0

    def readStructEnd(self):
        assert self.state == FIELD_READ
        self.state, self.__last_fid = self.__structs.pop()

    def readCollectionBegin(self):
        """Read a list/set header; return (element ttype, size)."""
        assert self.state in (VALUE_READ, CONTAINER_READ), self.state
        size_type = self.__readUByte()
        size = size_type >> 4
        etype = self.__getTType(size_type)
        if size == 15:
            # Long form: the real size follows as a varint.
            size = self.__readSize()
        self.__containers.append(self.state)
        self.state = CONTAINER_READ
        return etype, size
    readSetBegin = readCollectionBegin
    readListBegin = readCollectionBegin

    def readMapBegin(self):
        """Read a map header; return (key ttype, value ttype, size).

        An empty map carries no type byte, so both types decode from 0.
        """
        assert self.state in (VALUE_READ, CONTAINER_READ), self.state
        size = self.__readSize()
        types = 0
        if size > 0:
            types = self.__readUByte()
        vtype = self.__getTType(types)
        ktype = self.__getTType(types >> 4)
        self.__containers.append(self.state)
        self.state = CONTAINER_READ
        return (ktype, vtype, size)

    def readCollectionEnd(self):
        assert self.state == CONTAINER_READ, self.state
        self.state = self.__containers.pop()
    readSetEnd = readCollectionEnd
    readListEnd = readCollectionEnd
    readMapEnd = readCollectionEnd

    def readBool(self):
        """Return a bool from a field header or the next container element."""
        if self.state == BOOL_READ:
            # Already decoded from the field header by readFieldBegin.
            return self.__bool_value
        elif self.state == CONTAINER_READ:
            return self.__readByte() == CompactType.TRUE
        else:
            raise AssertionError("Invalid state in compact protocol: %d" % self.state)

    readByte = reader(__readByte)
    __readI16 = __readZigZag
    readI16 = reader(__readZigZag)
    readI32 = reader(__readZigZag)
    readI64 = reader(__readZigZag)

    @reader
    def readDouble(self):
        # Little-endian, matching writeDouble and the compact protocol spec.
        buff = self.trans.readAll(8)
        val, = unpack('<d', buff)
        return val

    def __readString(self):
        size = self.__readSize()
        return self.trans.readAll(size)
    readString = reader(__readString)

    def __getTType(self, byte):
        return TTYPES[byte & 0x0f]
388
389
class TCompactProtocolFactory:
    """Factory producing a fresh TCompactProtocol for each transport."""

    def __init__(self):
        # Stateless: nothing to configure.
        pass

    def getProtocol(self, trans):
        """Return a new TCompactProtocol wrapping ``trans``."""
        protocol = TCompactProtocol(trans)
        return protocol