| David Reiss | abafd79 | 2010-09-27 17:28:15 +0000 | [diff] [blame] | 1 | from TProtocol import * | 
|  | 2 | from struct import pack, unpack | 
|  | 3 |  | 
|  | 4 | __all__ = ['TCompactProtocol', 'TCompactProtocolFactory'] | 
|  | 5 |  | 
# Values taken by TCompactProtocol.state.  They are consecutive integers
# starting at zero, in this order:
#   CLEAR            - initial state; also the state between messages
#   FIELD_WRITE      - inside a struct, a field header may be written next
#   VALUE_WRITE      - a value may be written next
#   CONTAINER_WRITE  - inside a list/set/map being written
#   BOOL_WRITE       - a bool field was begun; its header is still pending
#   FIELD_READ       - inside a struct, a field header may be read next
#   CONTAINER_READ   - inside a list/set/map being read
#   VALUE_READ       - a value may be read next
#   BOOL_READ        - a bool field header was read; its value is cached
(CLEAR,
 FIELD_WRITE,
 VALUE_WRITE,
 CONTAINER_WRITE,
 BOOL_WRITE,
 FIELD_READ,
 CONTAINER_READ,
 VALUE_READ,
 BOOL_READ) = range(9)
|  | 15 |  | 
def make_helper(v_from, container):
    """Build a decorator that guards a method with a protocol-state check.

    The decorated method asserts that ``self.state`` equals ``v_from`` or
    ``container`` and then delegates to the wrapped function, returning its
    result.  On a mismatch the AssertionError carries the tuple
    ``(self.state, v_from, container)``.

    The wrapper keeps the wrapped function's name and docstring so that
    tracebacks and introspection show the real method rather than ``nested``.
    """
    # Local import: keeps this helper self-contained without touching the
    # module's import block.
    from functools import wraps

    allowed = (v_from, container)

    def helper(func):
        @wraps(func)
        def nested(self, *args, **kwargs):
            assert self.state in allowed, (self.state, v_from, container)
            return func(self, *args, **kwargs)
        return nested
    return helper
# Decorators for the value-level methods: a value may be written (read) either
# directly as a field/message value or as an element of a container.
writer = make_helper(v_from=VALUE_WRITE, container=CONTAINER_WRITE)
reader = make_helper(v_from=VALUE_READ, container=CONTAINER_READ)
|  | 25 |  | 
def makeZigZag(n, bits):
    """Zigzag-encode the signed integer ``n`` of width ``bits``.

    The value is shifted left by one and XOR-ed with its sign, replicated by
    an arithmetic right shift of ``bits - 1``, so 0, -1, 1, -2, ... map to
    0, 1, 2, 3, ...
    """
    doubled = n << 1
    sign_fill = n >> (bits - 1)
    return doubled ^ sign_fill
|  | 28 |  | 
def fromZigZag(n):
    """Decode a zigzag-encoded non-negative integer back to a signed one.

    The lowest bit carries the sign: when it is set, every bit of the halved
    value is flipped by XOR-ing with -1.
    """
    halved = n >> 1
    sign_mask = -(n & 1)
    return halved ^ sign_mask
|  | 31 |  | 
def writeVarint(trans, n):
    """Write the non-negative integer ``n`` to ``trans`` as a varint.

    The value is emitted seven bits at a time, least-significant group first;
    every byte except the last has its high bit set.  The whole encoding is
    handed to ``trans.write`` in a single call as a byte string.

    Raises:
        ValueError: if ``n`` is negative.  Right-shifting a negative Python
            integer never reaches zero, so encoding one would loop forever.
    """
    if n < 0:
        raise ValueError('cannot encode negative value as varint: %r' % (n,))
    out = bytearray()
    while n & ~0x7f:
        # Low seven bits plus the continuation flag.
        out.append((n & 0x7f) | 0x80)
        n >>= 7
    out.append(n)
    # bytes(bytearray) is a byte string on both Python 2.6+ and Python 3.
    trans.write(bytes(out))
|  | 42 |  | 
def readVarint(trans):
    """Read one varint from ``trans`` and return it as a non-negative integer.

    Bytes are fetched one at a time with ``trans.readAll(1)``.  The low seven
    bits of each byte are accumulated least-significant group first, and a
    byte whose high bit is clear ends the value.
    """
    value = 0
    offset = 0
    more = True
    while more:
        octet = ord(trans.readAll(1))
        value |= (octet & 0x7f) << offset
        offset += 7
        more = (octet >> 7) != 0
    return value
|  | 53 |  | 
class CompactType:
    # 4-bit type codes used on the wire by the compact protocol.  A bool has
    # two codes so that a struct field's value can travel in its header.
    TRUE = 0x01
    FALSE = 0x02
    BYTE = 0x03
    I16 = 0x04
    I32 = 0x05
    I64 = 0x06
    DOUBLE = 0x07
    BINARY = 0x08
    LIST = 0x09
    SET = 0x0A
    MAP = 0x0B
    STRUCT = 0x0C
|  | 67 |  | 
# Thrift TType -> compact wire type code.
CTYPES = dict([
    (TType.BOOL, CompactType.TRUE),  # used for collection
    (TType.BYTE, CompactType.BYTE),
    (TType.I16, CompactType.I16),
    (TType.I32, CompactType.I32),
    (TType.I64, CompactType.I64),
    (TType.DOUBLE, CompactType.DOUBLE),
    (TType.STRING, CompactType.BINARY),
    (TType.STRUCT, CompactType.STRUCT),
    (TType.LIST, CompactType.LIST),
    (TType.SET, CompactType.SET),
    (TType.MAP, CompactType.MAP),
])

# Compact wire type code -> Thrift TType: the inverse of CTYPES, plus the
# second bool code (FALSE), which also decodes to TType.BOOL.  The generator
# expression leaves no loop variables behind in the module namespace.
TTYPES = dict((ctype, ttype) for ttype, ctype in CTYPES.items())
TTYPES[CompactType.FALSE] = TType.BOOL
|  | 87 |  | 
class TCompactProtocol(TProtocolBase):
    """Compact implementation of the Thrift protocol driver.

    Integers are written as zigzag varints, struct field ids as deltas from
    the previous field id of the enclosing struct, and the value of a bool
    struct field is folded into its field header.  ``self.state`` records what
    may legally be read or written next and is checked with assertions.
    """

    PROTOCOL_ID = 0x82
    VERSION = 1
    VERSION_MASK = 0x1f
    TYPE_MASK = 0xe0
    TYPE_SHIFT_AMOUNT = 5

    def __init__(self, trans):
        """Wrap the transport ``trans`` and start in the CLEAR state."""
        TProtocolBase.__init__(self, trans)
        self.state = CLEAR
        # Id of the last field header read or written in the current struct.
        self.__last_fid = 0
        # Field id of a bool field whose header is deferred until writeBool.
        self.__bool_fid = None
        # Value of a bool field decoded from its header; returned by readBool.
        self.__bool_value = None
        # Stack of (state, last_fid) pairs saved when a struct is entered.
        self.__structs = []
        # Stack of states saved when a container is entered.
        self.__containers = []

    def __writeVarint(self, n):
        writeVarint(self.trans, n)

    def writeMessageBegin(self, name, type, seqid):
        """Write the message header: protocol id, version/type, seqid, name."""
        assert self.state == CLEAR
        self.__writeUByte(self.PROTOCOL_ID)
        self.__writeUByte(self.VERSION | (type << self.TYPE_SHIFT_AMOUNT))
        # Varints are unsigned and a negative value would never terminate the
        # encoder, so a negative seqid is sent as its 32-bit two's complement.
        if seqid < 0:
            seqid &= 0xffffffff
        self.__writeVarint(seqid)
        self.__writeString(name)
        self.state = VALUE_WRITE

    def writeMessageEnd(self):
        """Finish the message and return to the CLEAR state."""
        assert self.state == VALUE_WRITE
        self.state = CLEAR

    def writeStructBegin(self, name):
        """Enter a struct; ``name`` is not written to the wire."""
        assert self.state in (CLEAR, CONTAINER_WRITE, VALUE_WRITE), self.state
        # Saved so writeStructEnd can restore the enclosing context.
        self.__structs.append((self.state, self.__last_fid))
        self.state = FIELD_WRITE
        self.__last_fid = 0

    def writeStructEnd(self):
        """Leave the current struct, restoring the enclosing state."""
        assert self.state == FIELD_WRITE
        self.state, self.__last_fid = self.__structs.pop()

    def writeFieldStop(self):
        """Write the zero byte that terminates a struct's field list."""
        self.__writeByte(0)

    def __writeFieldHeader(self, type, fid):
        # Short form: the delta from the previous field id in the high nibble,
        # the compact type in the low nibble.  Otherwise the type byte is
        # followed by the full field id as a zigzag varint.
        delta = fid - self.__last_fid
        if 0 < delta <= 15:
            self.__writeUByte(delta << 4 | type)
        else:
            self.__writeByte(type)
            self.__writeI16(fid)
        self.__last_fid = fid

    def writeFieldBegin(self, name, type, fid):
        """Begin a field; ``name`` is not written to the wire.

        For a bool field nothing is written yet: the header is emitted by
        writeBool, because it encodes the value itself.
        """
        assert self.state == FIELD_WRITE, self.state
        if type == TType.BOOL:
            self.state = BOOL_WRITE
            self.__bool_fid = fid
        else:
            self.state = VALUE_WRITE
            self.__writeFieldHeader(CTYPES[type], fid)

    def writeFieldEnd(self):
        """End the current field; another field header may follow."""
        assert self.state in (VALUE_WRITE, BOOL_WRITE), self.state
        self.state = FIELD_WRITE

    def __writeUByte(self, byte):
        self.trans.write(pack('!B', byte))

    def __writeByte(self, byte):
        self.trans.write(pack('!b', byte))

    def __writeI16(self, i16):
        self.__writeVarint(makeZigZag(i16, 16))

    def __writeSize(self, i32):
        self.__writeVarint(i32)

    def writeCollectionBegin(self, etype, size):
        """Write a list/set header and enter the container.

        Sizes up to 14 share one byte with the element type; larger sizes put
        0xf in the high nibble and follow with the size as a varint.
        """
        assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state
        if size <= 14:
            self.__writeUByte(size << 4 | CTYPES[etype])
        else:
            self.__writeUByte(0xf0 | CTYPES[etype])
            self.__writeSize(size)
        self.__containers.append(self.state)
        self.state = CONTAINER_WRITE
    writeSetBegin = writeCollectionBegin
    writeListBegin = writeCollectionBegin

    def writeMapBegin(self, ktype, vtype, size):
        """Write a map header and enter the container.

        An empty map is a single zero byte; otherwise the size varint is
        followed by one byte holding the key type (high nibble) and the value
        type (low nibble).
        """
        assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state
        if size == 0:
            self.__writeByte(0)
        else:
            self.__writeSize(size)
            self.__writeUByte(CTYPES[ktype] << 4 | CTYPES[vtype])
        self.__containers.append(self.state)
        self.state = CONTAINER_WRITE

    def writeCollectionEnd(self):
        """Leave the current container, restoring the enclosing state."""
        assert self.state == CONTAINER_WRITE, self.state
        self.state = self.__containers.pop()
    writeMapEnd = writeCollectionEnd
    writeSetEnd = writeCollectionEnd
    writeListEnd = writeCollectionEnd

    def writeBool(self, bool):
        """Write a boolean.

        As a struct field the value is carried by the deferred field header
        (compact type TRUE or FALSE).  As a container element it is one byte:
        1 for true, 0 for false.

        Raises:
            AssertionError: if called in any other state.
        """
        if self.state == BOOL_WRITE:
            if bool:
                ctype = CompactType.TRUE
            else:
                ctype = CompactType.FALSE
            self.__writeFieldHeader(ctype, self.__bool_fid)
        elif self.state == CONTAINER_WRITE:
            # Normalised so that any truthy argument is written as exactly 1,
            # the only byte readBool decodes as true.
            self.__writeByte(CompactType.TRUE if bool else 0)
        else:
            raise AssertionError("Invalid state in compact protocol")

    writeByte = writer(__writeByte)
    writeI16 = writer(__writeI16)

    @writer
    def writeI32(self, i32):
        """Write a 32-bit integer as a zigzag varint."""
        self.__writeVarint(makeZigZag(i32, 32))

    @writer
    def writeI64(self, i64):
        """Write a 64-bit integer as a zigzag varint."""
        self.__writeVarint(makeZigZag(i64, 64))

    @writer
    def writeDouble(self, dub):
        """Write a double as 8 bytes packed with the '!d' struct format."""
        self.trans.write(pack('!d', dub))

    def __writeString(self, s):
        self.__writeSize(len(s))
        self.trans.write(s)
    writeString = writer(__writeString)

    def readFieldBegin(self):
        """Read a field header and return ``(None, ttype, fid)``.

        Returns ``(None, 0, 0)`` when the stop marker is read.  For a bool
        field the value is taken from the header and cached for readBool.
        """
        assert self.state == FIELD_READ, self.state
        header = self.__readUByte()
        if header & 0x0f == TType.STOP:
            return (None, 0, 0)
        delta = header >> 4
        if delta == 0:
            # Long form: the full field id follows as a zigzag varint.
            fid = self.__readI16()
        else:
            fid = self.__last_fid + delta
        self.__last_fid = fid
        ctype = header & 0x0f
        if ctype == CompactType.TRUE:
            self.state = BOOL_READ
            self.__bool_value = True
        elif ctype == CompactType.FALSE:
            self.state = BOOL_READ
            self.__bool_value = False
        else:
            self.state = VALUE_READ
        return (None, self.__getTType(ctype), fid)

    def readFieldEnd(self):
        """End the current field; another field header may follow."""
        assert self.state in (VALUE_READ, BOOL_READ), self.state
        self.state = FIELD_READ

    def __readUByte(self):
        result, = unpack('!B', self.trans.readAll(1))
        return result

    def __readByte(self):
        result, = unpack('!b', self.trans.readAll(1))
        return result

    def __readVarint(self):
        return readVarint(self.trans)

    def __readZigZag(self):
        return fromZigZag(self.__readVarint())

    def __readSize(self):
        result = self.__readVarint()
        if result < 0:
            raise TException("Length < 0")
        return result

    def readMessageBegin(self):
        """Read a message header and return ``(name, type, seqid)``.

        Raises:
            TProtocolException: BAD_VERSION if the protocol id or the version
                does not match this implementation.
        """
        assert self.state == CLEAR
        proto_id = self.__readUByte()
        if proto_id != self.PROTOCOL_ID:
            raise TProtocolException(
                TProtocolException.BAD_VERSION,
                'Bad protocol id in the message: %d' % proto_id)
        ver_type = self.__readUByte()
        mtype = (ver_type & self.TYPE_MASK) >> self.TYPE_SHIFT_AMOUNT
        version = ver_type & self.VERSION_MASK
        if version != self.VERSION:
            raise TProtocolException(
                TProtocolException.BAD_VERSION,
                'Bad version: %d (expect %d)' % (version, self.VERSION))
        seqid = self.__readVarint()
        # Inverse of writeMessageBegin: values above the signed 32-bit range
        # are the two's-complement form of a negative seqid.
        if seqid > 0x7fffffff:
            seqid -= 1 << 32
        name = self.__readString()
        return (name, mtype, seqid)

    def readMessageEnd(self):
        """Finish reading a message.

        readMessageBegin leaves the state at CLEAR and readStructEnd restores
        it, so CLEAR is the state a well-formed read arrives here with.
        """
        assert self.state == CLEAR
        assert len(self.__structs) == 0
        self.state = CLEAR

    def readStructBegin(self):
        """Enter a struct; nothing is read from the wire."""
        assert self.state in (CLEAR, CONTAINER_READ, VALUE_READ), self.state
        self.__structs.append((self.state, self.__last_fid))
        self.state = FIELD_READ
        self.__last_fid = 0

    def readStructEnd(self):
        """Leave the current struct, restoring the enclosing state."""
        assert self.state == FIELD_READ
        self.state, self.__last_fid = self.__structs.pop()

    def readCollectionBegin(self):
        """Read a list/set header and return ``(element_ttype, size)``."""
        assert self.state in (VALUE_READ, CONTAINER_READ), self.state
        size_type = self.__readUByte()
        size = size_type >> 4
        etype = self.__getTType(size_type)
        if size == 15:
            # 0xf in the high nibble means the real size follows as a varint.
            size = self.__readSize()
        self.__containers.append(self.state)
        self.state = CONTAINER_READ
        return etype, size
    readSetBegin = readCollectionBegin
    readListBegin = readCollectionBegin

    def readMapBegin(self):
        """Read a map header and return ``(key_ttype, value_ttype, size)``.

        An empty map carries no type byte, so TType.STOP is reported for both
        types in that case.
        """
        assert self.state in (VALUE_READ, CONTAINER_READ), self.state
        size = self.__readSize()
        if size > 0:
            types = self.__readUByte()
            vtype = self.__getTType(types)
            ktype = self.__getTType(types >> 4)
        else:
            # Code 0 has no entry in TTYPES, so it must not be looked up.
            ktype = vtype = TType.STOP
        self.__containers.append(self.state)
        self.state = CONTAINER_READ
        return (ktype, vtype, size)

    def readCollectionEnd(self):
        """Leave the current container, restoring the enclosing state."""
        assert self.state == CONTAINER_READ, self.state
        self.state = self.__containers.pop()
    readSetEnd = readCollectionEnd
    readListEnd = readCollectionEnd
    readMapEnd = readCollectionEnd

    def readBool(self):
        """Read a boolean.

        As a struct field the value cached by readFieldBegin is returned.  As
        a container element one byte is read and only the value 1 is true, so
        both 0 and 2 decode as false.

        Raises:
            AssertionError: if called in any other state.
        """
        if self.state == BOOL_READ:
            return self.__bool_value
        elif self.state == CONTAINER_READ:
            return self.__readByte() == CompactType.TRUE
        else:
            raise AssertionError(
                "Invalid state in compact protocol: %d" % self.state)

    readByte = reader(__readByte)
    __readI16 = __readZigZag
    readI16 = reader(__readZigZag)
    readI32 = reader(__readZigZag)
    readI64 = reader(__readZigZag)

    @reader
    def readDouble(self):
        """Read 8 bytes and unpack them with the '!d' struct format."""
        buff = self.trans.readAll(8)
        val, = unpack('!d', buff)
        return val

    def __readString(self):
        size = self.__readSize()
        return self.trans.readAll(size)
    readString = reader(__readString)

    def __getTType(self, byte):
        # Only the low nibble carries the compact type code.
        return TTYPES[byte & 0x0f]
|  | 361 |  | 
|  | 362 |  | 
class TCompactProtocolFactory:
    # Stateless factory that hands out TCompactProtocol instances.

    def __init__(self):
        # No configuration to store.
        return None

    def getProtocol(self, trans):
        # A fresh protocol object is created for every transport.
        protocol = TCompactProtocol(trans)
        return protocol