THRIFT-1115: Python TBase class for dynamic (de)serialization, and __slots__ option for memory savings
Patch: Will Pierce
git-svn-id: https://svn.apache.org/repos/asf/thrift/trunk@1169492 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/py/src/Thrift.py b/lib/py/src/Thrift.py
index af6f58d..1d271fc 100644
--- a/lib/py/src/Thrift.py
+++ b/lib/py/src/Thrift.py
@@ -38,6 +38,27 @@
   UTF8 = 16
   UTF16 = 17
 
+  # Name of each TType value, stored at the index equal to that value;
+  # None marks ids that have no TType name.
+  _VALUES_TO_NAMES = ( 'STOP',
+                       'VOID',
+                       'BOOL',
+                       'BYTE',
+                       'DOUBLE',
+                       None,
+                       'I16',
+                       None,
+                       'I32',
+                       None,
+                       'I64',
+                       'STRING',
+                       'STRUCT',
+                       'MAP',
+                       'SET',
+                       'LIST',
+                       'UTF8',
+                       'UTF16' )
+
 class TMessageType:
   CALL = 1
   REPLY = 2
diff --git a/lib/py/src/protocol/TBase.py b/lib/py/src/protocol/TBase.py
new file mode 100644
index 0000000..e675c7d
--- /dev/null
+++ b/lib/py/src/protocol/TBase.py
@@ -0,0 +1,89 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from thrift.Thrift import *
+from thrift.protocol import TBinaryProtocol
+from thrift.transport import TTransport
+
+try:
+  from thrift.protocol import fastbinary
+except ImportError:
+  # The C accelerator is optional; without it, read() and write() below
+  # fall back to the protocol's readStruct() and writeStruct().
+  fastbinary = None
+
+class TBase(object):
+  """Base class for Thrift structs that (de)serialize via thrift_spec.
+
+  Subclasses provide __slots__ (the field names walked by __repr__ and
+  __eq__) and a thrift_spec describing the fields for read() and write().
+  """
+  __slots__ = []
+
+  def __repr__(self):
+    L = ['%s=%r' % (key, getattr(self, key))
+      for key in self.__slots__]
+    return '%s(%s)' % (self.__class__.__name__, ', '.join(L))
+
+  def __eq__(self, other):
+    # Compare field by field over __slots__; an object that is not an
+    # instance of this class (or a subclass) is never equal.
+    if not isinstance(other, self.__class__):
+      return False
+    for attr in self.__slots__:
+      my_val = getattr(self, attr)
+      other_val = getattr(other, attr)
+      if my_val != other_val:
+        return False
+    return True
+
+  def __ne__(self, other):
+    return not (self == other)
+
+  def read(self, iprot):
+    """Populate this object's fields from the protocol iprot."""
+    # The C decoder is used only for the accelerated binary protocol over
+    # a CReadableTransport, and only when a thrift_spec is available.
+    if (iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and
+        isinstance(iprot.trans, TTransport.CReadableTransport) and
+        self.thrift_spec is not None and fastbinary is not None):
+      fastbinary.decode_binary(self, iprot.trans, (self.__class__, self.thrift_spec))
+      return
+    iprot.readStruct(self, self.thrift_spec)
+
+  def write(self, oprot):
+    """Serialize this object's fields to the protocol oprot."""
+    if (oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and
+        self.thrift_spec is not None and fastbinary is not None):
+      oprot.trans.write(fastbinary.encode_binary(self, (self.__class__, self.thrift_spec)))
+      return
+    oprot.writeStruct(self, self.thrift_spec)
+
+class TExceptionBase(Exception):
+  # On Python 2.4 Exception is an old-style class, and this class must stay
+  # old-style there so it can be raised; hence it cannot inherit from the
+  # new-style TBase and instead borrows TBase's methods as plain functions.
+  __slots__ = []
+
+  __repr__ = TBase.__repr__.im_func
+  __eq__ = TBase.__eq__.im_func
+  __ne__ = TBase.__ne__.im_func
+  read = TBase.read.im_func
+  write = TBase.write.im_func
+
diff --git a/lib/py/src/protocol/TCompactProtocol.py b/lib/py/src/protocol/TCompactProtocol.py
index 6d57aeb..016a331 100644
--- a/lib/py/src/protocol/TCompactProtocol.py
+++ b/lib/py/src/protocol/TCompactProtocol.py
@@ -1,3 +1,22 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
 from TProtocol import *
 from struct import pack, unpack
 
diff --git a/lib/py/src/protocol/TProtocol.py b/lib/py/src/protocol/TProtocol.py
index be3cb14..7338ff6 100644
--- a/lib/py/src/protocol/TProtocol.py
+++ b/lib/py/src/protocol/TProtocol.py
@@ -200,6 +200,246 @@
         self.skip(etype)
       self.readListEnd()
 
+  # Indexed by TType value; each entry is a tuple of
+  #   (reader method name, writer method name, is_container boolean).
+  # Container handlers also take the field's spec; a None method name
+  # marks a type id this code cannot (de)serialize.
+  _TTYPE_HANDLERS = (
+    (None, None, False),  # 0 == TType.STOP
+    (None, None, False),  # 1 == TType.VOID # TODO: handle void?
+    ('readBool', 'writeBool', False),  # 2 == TType.BOOL
+    ('readByte', 'writeByte', False),  # 3 == TType.BYTE and I08
+    ('readDouble', 'writeDouble', False),  # 4 == TType.DOUBLE
+    (None, None, False),  # 5, undefined
+    ('readI16', 'writeI16', False),  # 6 == TType.I16
+    (None, None, False),  # 7, undefined
+    ('readI32', 'writeI32', False),  # 8 == TType.I32
+    (None, None, False),  # 9, undefined
+    ('readI64', 'writeI64', False),  # 10 == TType.I64
+    ('readString', 'writeString', False),  # 11 == TType.STRING and UTF7
+    ('readContainerStruct', 'writeContainerStruct', True),  # 12 == TType.STRUCT
+    ('readContainerMap', 'writeContainerMap', True),  # 13 == TType.MAP
+    ('readContainerSet', 'writeContainerSet', True),  # 14 == TType.SET
+    ('readContainerList', 'writeContainerList', True),  # 15 == TType.LIST
+    (None, None, False),  # 16 == TType.UTF8 # TODO: handle utf8 types?
+    (None, None, False)  # 17 == TType.UTF16 # TODO: handle utf16 types?
+  )
+
+  def readFieldByTType(self, ttype, spec):
+    """Read and return one value of type ttype.
+
+    spec is passed to container readers (struct/map/set/list) and ignored
+    for simple types. Raises TProtocolException(INVALID_DATA) when ttype
+    has no reader.
+    """
+    # Range-check explicitly: a negative ttype would otherwise index
+    # _TTYPE_HANDLERS from the end and silently pick the wrong reader.
+    if ttype < 0 or ttype >= len(self._TTYPE_HANDLERS):
+      raise TProtocolException(type=TProtocolException.INVALID_DATA,
+                               message='Invalid field type %d' % (ttype))
+    (r_handler, w_handler, is_container) = self._TTYPE_HANDLERS[ttype]
+    if r_handler is None:
+      raise TProtocolException(type=TProtocolException.INVALID_DATA,
+                               message='Invalid field type %d' % (ttype))
+    reader = getattr(self, r_handler)
+    if not is_container:
+      return reader()
+    return reader(spec)
+
+  def readContainerList(self, spec):
+    """Read a list; spec is (element ttype, element spec or None)."""
+    results = []
+    ttype, tspec = spec[0], spec[1]
+    r_handler = self._TTYPE_HANDLERS[ttype][0]
+    reader = getattr(self, r_handler)
+    (list_type, list_len) = self.readListBegin()
+    if tspec is None:
+      # list values are simple types
+      for idx in xrange(list_len):
+        results.append(reader())
+    else:
+      # this is like an inlined readFieldByTType
+      # NOTE(review): this branch dispatches on the wire type (list_type)
+      # while the branch above uses the declared type (ttype).
+      container_reader = self._TTYPE_HANDLERS[list_type][0]
+      val_reader = getattr(self, container_reader)
+      for idx in xrange(list_len):
+        val = val_reader(tspec)
+        results.append(val)
+    self.readListEnd()
+    return results
+
+  def readContainerSet(self, spec):
+    """Read a set; spec is (member ttype, member spec or None)."""
+    results = set()
+    ttype, tspec = spec[0], spec[1]
+    r_handler = self._TTYPE_HANDLERS[ttype][0]
+    reader = getattr(self, r_handler)
+    (set_type, set_len) = self.readSetBegin()
+    if tspec is None:
+      # set members are simple types
+      for idx in xrange(set_len):
+        results.add(reader())
+    else:
+      # NOTE(review): dispatches on the wire type (set_type), like
+      # readContainerList.
+      container_reader = self._TTYPE_HANDLERS[set_type][0]
+      val_reader = getattr(self, container_reader)
+      for idx in xrange(set_len):
+        results.add(val_reader(tspec))
+    self.readSetEnd()
+    return results
+
+  def readContainerStruct(self, spec):
+    """Read a nested struct; spec is (struct class, struct spec).
+
+    Only the class is used: a fresh instance decodes itself via read().
+    """
+    (obj_class, obj_spec) = spec
+    obj = obj_class()
+    obj.read(self)
+    return obj
+
+  def readContainerMap(self, spec):
+    """Read a map; spec is (key ttype, key spec, value ttype, value spec)."""
+    results = dict()
+    key_ttype, key_spec = spec[0], spec[1]
+    val_ttype, val_spec = spec[2], spec[3]
+    (map_ktype, map_vtype, map_len) = self.readMapBegin()
+    # TODO: compare types we just decoded with thrift_spec and abort/skip if types disagree
+    key_reader = getattr(self, self._TTYPE_HANDLERS[key_ttype][0])
+    val_reader = getattr(self, self._TTYPE_HANDLERS[val_ttype][0])
+    for idx in xrange(map_len):
+      # Keys/values without a spec are simple types read directly;
+      # the rest go through readFieldByTType with their spec.
+      if key_spec is None:
+        k_val = key_reader()
+      else:
+        k_val = self.readFieldByTType(key_ttype, key_spec)
+      if val_spec is None:
+        v_val = val_reader()
+      else:
+        v_val = self.readFieldByTType(val_ttype, val_spec)
+      # this raises a TypeError with unhashable keys types. i.e. d=dict(); d[[0,1]] = 2 fails
+      results[k_val] = v_val
+    self.readMapEnd()
+    return results
+
+  def readStruct(self, obj, thrift_spec):
+    """Read a struct from the protocol and set its fields on obj.
+
+    thrift_spec is indexed by field id; each non-None entry starts with
+    (field id, ttype, attribute name, spec). Unknown fields and fields
+    whose wire type differs from the spec are skipped.
+    """
+    self.readStructBegin()
+    while True:
+      (fname, ftype, fid) = self.readFieldBegin()
+      if ftype == TType.STOP:
+        break
+      # fid comes from the peer and may be negative; bounds-check it
+      # explicitly instead of relying on IndexError, since a negative
+      # index would silently select a field from the end of thrift_spec.
+      field = None
+      if 0 <= fid < len(thrift_spec):
+        field = thrift_spec[fid]
+      if field is not None and ftype == field[1]:
+        fname = field[2]
+        fspec = field[3]
+        val = self.readFieldByTType(ftype, fspec)
+        setattr(obj, fname, val)
+      else:
+        self.skip(ftype)
+      self.readFieldEnd()
+    self.readStructEnd()
+
+  def writeContainerStruct(self, val, spec):
+    # The nested struct serializes itself; spec is not needed here.
+    val.write(self)
+
+  def writeContainerList(self, val, spec):
+    """Write list val; spec is (element ttype, element spec or None)."""
+    self.writeListBegin(spec[0], len(val))
+    r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]]
+    e_writer = getattr(self, w_handler)
+    if not is_container:
+      for elem in val:
+        e_writer(elem)
+    else:
+      for elem in val:
+        e_writer(elem, spec[1])
+    self.writeListEnd()
+
+  def writeContainerSet(self, val, spec):
+    """Write set val; spec is (member ttype, member spec or None)."""
+    self.writeSetBegin(spec[0], len(val))
+    r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]]
+    e_writer = getattr(self, w_handler)
+    if not is_container:
+      for elem in val:
+        e_writer(elem)
+    else:
+      for elem in val:
+        e_writer(elem, spec[1])
+    self.writeSetEnd()
+
+  def writeContainerMap(self, val, spec):
+    """Write dict val; spec is (key ttype, key spec, value ttype, value spec)."""
+    k_type = spec[0]
+    v_type = spec[2]
+    ignore, ktype_name, k_is_container = self._TTYPE_HANDLERS[k_type]
+    ignore, vtype_name, v_is_container = self._TTYPE_HANDLERS[v_type]
+    k_writer = getattr(self, ktype_name)
+    v_writer = getattr(self, vtype_name)
+    self.writeMapBegin(k_type, v_type, len(val))
+    for m_key, m_val in val.iteritems():
+      if not k_is_container:
+        k_writer(m_key)
+      else:
+        k_writer(m_key, spec[1])
+      if not v_is_container:
+        v_writer(m_val)
+      else:
+        v_writer(m_val, spec[3])
+    self.writeMapEnd()
+
+  def writeStruct(self, obj, thrift_spec):
+    """Write obj's set fields in thrift_spec order; None-valued fields are omitted."""
+    self.writeStructBegin(obj.__class__.__name__)
+    for field in thrift_spec:
+      if field is None:
+        continue
+      fname = field[2]
+      val = getattr(obj, fname)
+      if val is None:
+        # skip writing out unset fields
+        continue
+      fid = field[0]
+      ftype = field[1]
+      fspec = field[3]
+      self.writeFieldBegin(fname, ftype, fid)
+      self.writeFieldByTType(ftype, val, fspec)
+      self.writeFieldEnd()
+    self.writeFieldStop()
+    self.writeStructEnd()
+
+  def writeFieldByTType(self, ttype, val, spec):
+    """Write val as type ttype; spec is passed to container writers.
+
+    Raises TProtocolException(INVALID_DATA), like readFieldByTType, when
+    ttype's handler entry has no writer.
+    """
+    r_handler, w_handler, is_container = self._TTYPE_HANDLERS[ttype]
+    if w_handler is None:
+      raise TProtocolException(type=TProtocolException.INVALID_DATA,
+                               message='Invalid field type %d' % (ttype))
+    writer = getattr(self, w_handler)
+    if is_container:
+      writer(val, spec)
+    else:
+      writer(val)
+
 class TProtocolFactory:
   def getProtocol(self, trans):
     pass
+
diff --git a/lib/py/src/protocol/__init__.py b/lib/py/src/protocol/__init__.py
index 01bfe18..d53359b 100644
--- a/lib/py/src/protocol/__init__.py
+++ b/lib/py/src/protocol/__init__.py
@@ -17,4 +17,4 @@
 # under the License.
 #
 
-__all__ = ['TProtocol', 'TBinaryProtocol', 'fastbinary']
+__all__ = ['TProtocol', 'TBinaryProtocol', 'fastbinary', 'TBase']