THRIFT-1115 python TBase class for dynamic (de)serialization, and __slots__ option for memory savings
Patch: Will Pierce

git-svn-id: https://svn.apache.org/repos/asf/thrift/trunk@1169492 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/py/src/Thrift.py b/lib/py/src/Thrift.py
index af6f58d..1d271fc 100644
--- a/lib/py/src/Thrift.py
+++ b/lib/py/src/Thrift.py
@@ -38,6 +38,25 @@
   UTF8   = 16
   UTF16  = 17
 
+  _VALUES_TO_NAMES = ( 'STOP',  # 0 -- the tuple index is the numeric type value
+                      'VOID',  # 1
+                      'BOOL',  # 2
+                      'BYTE',  # 3
+                      'DOUBLE',  # 4
+                      None,  # 5 -- gap: no name for this value
+                      'I16',  # 6
+                      None,  # 7 -- gap: no name for this value
+                      'I32',  # 8
+                      None,  # 9 -- gap: no name for this value
+                       'I64',  # 10
+                       'STRING',  # 11
+                       'STRUCT',  # 12
+                       'MAP',  # 13
+                       'SET',  # 14
+                       'LIST',  # 15
+                       'UTF8',  # 16
+                       'UTF16' )  # 17
+
 class TMessageType:
   CALL  = 1
   REPLY = 2
diff --git a/lib/py/src/protocol/TBase.py b/lib/py/src/protocol/TBase.py
new file mode 100644
index 0000000..e675c7d
--- /dev/null
+++ b/lib/py/src/protocol/TBase.py
@@ -0,0 +1,72 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from thrift.Thrift import *
+from thrift.protocol import TBinaryProtocol
+from thrift.transport import TTransport
+
+try:
+  from thrift.protocol import fastbinary
+except ImportError:  # extension missing or unloadable: fall back to the pure-Python path
+  fastbinary = None
+
+class TBase(object):
+  __slots__ = []  # iterated by __repr__ and __eq__ as the list of attribute names
+
+  def __repr__(self):  # renders ClassName(attr=value, ...) in __slots__ order
+    shown = ', '.join(['%s=%r' % (name, getattr(self, name)) for name in self.__slots__])
+    return '%s(%s)' % (self.__class__.__name__, shown)
+
+  def __eq__(self, other):  # same class (or a subclass) and every slot compares equal
+    if not isinstance(other, self.__class__):
+      return False
+    for name in self.__slots__:
+      if getattr(self, name) != getattr(other, name):
+        return False
+    return True
+
+  def __ne__(self, other):
+    return not self == other
+
+  def read(self, iprot):  # fastbinary path when its preconditions hold, else readStruct
+    if (iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and
+        isinstance(iprot.trans, TTransport.CReadableTransport) and
+        self.thrift_spec is not None and fastbinary is not None):
+      fastbinary.decode_binary(self, iprot.trans, (self.__class__, self.thrift_spec))
+    else:
+      iprot.readStruct(self, self.thrift_spec)
+
+  def write(self, oprot):  # fastbinary path when its preconditions hold, else writeStruct
+    if (oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and
+        self.thrift_spec is not None and fastbinary is not None):
+      oprot.trans.write(fastbinary.encode_binary(self, (self.__class__, self.thrift_spec)))
+    else:
+      oprot.writeStruct(self, self.thrift_spec)
+
+class TExceptionBase(Exception):
+  # Derives from Exception (an old-style class on python2.4) so that subclasses can be
+  # raised there; because of that limitation it cannot inherit from TBase.
+  __slots__ = []
+  
+  __repr__ = TBase.__repr__.im_func  # im_func: the plain function behind the unbound method
+  __eq__ = TBase.__eq__.im_func
+  __ne__ = TBase.__ne__.im_func
+  read = TBase.read.im_func
+  write = TBase.write.im_func
+  
diff --git a/lib/py/src/protocol/TCompactProtocol.py b/lib/py/src/protocol/TCompactProtocol.py
index 6d57aeb..016a331 100644
--- a/lib/py/src/protocol/TCompactProtocol.py
+++ b/lib/py/src/protocol/TCompactProtocol.py
@@ -1,3 +1,22 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
 from TProtocol import *
 from struct import pack, unpack
 
diff --git a/lib/py/src/protocol/TProtocol.py b/lib/py/src/protocol/TProtocol.py
index be3cb14..7338ff6 100644
--- a/lib/py/src/protocol/TProtocol.py
+++ b/lib/py/src/protocol/TProtocol.py
@@ -200,6 +200,205 @@
         self.skip(etype)
       self.readListEnd()
 
+  # tuple of: ( 'reader method' name, 'writer method' name, is_container boolean )
+  _TTYPE_HANDLERS = (  # indexed by TType value; is_container handlers also take a spec
+       (None, None, False), # 0 == TType.STOP
+       (None, None, False), # 1 == TType.VOID # TODO: handle void?
+       ('readBool', 'writeBool', False), # 2 == TType.BOOL
+       ('readByte',  'writeByte', False), # 3 == TType.BYTE and I08
+       ('readDouble', 'writeDouble', False), # 4 == TType.DOUBLE
+       (None, None, False), # 5, undefined
+       ('readI16', 'writeI16', False), # 6 == TType.I16
+       (None, None, False), # 7, undefined
+       ('readI32', 'writeI32', False), # 8 == TType.I32
+       (None, None, False), # 9, undefined
+       ('readI64', 'writeI64', False), # 10 == TType.I64
+       ('readString', 'writeString', False), # 11 == TType.STRING and UTF7
+       ('readContainerStruct', 'writeContainerStruct', True), # 12 == TType.STRUCT
+       ('readContainerMap', 'writeContainerMap', True), # 13 == TType.MAP
+       ('readContainerSet', 'writeContainerSet', True), # 14 == TType.SET
+       ('readContainerList', 'writeContainerList', True), # 15 == TType.LIST
+       (None, None, False), # 16 == TType.UTF8 # TODO: handle utf8 types?
+       (None, None, False) # 17 == TType.UTF16 # TODO: handle utf16 types?
+      )
+
+  def readFieldByTType(self, ttype, spec):
+    # Reads one value of type `ttype`; `spec` is only forwarded to container readers.
+    try:
+      entry = self._TTYPE_HANDLERS[ttype]
+    except IndexError:
+      entry = (None, None, False)  # out of range: rejected just below
+    if entry[0] is None:
+      raise TProtocolException(type=TProtocolException.INVALID_DATA,
+                               message='Invalid field type %d' % (ttype))
+    read_value = getattr(self, entry[0])
+    if entry[2]:
+      return read_value(spec)
+    return read_value()
+
+  def readContainerList(self, spec):
+    """Read a list described by spec = (element ttype, element spec).
+
+    An element spec of None marks a simple type whose reader takes no argument.
+    """
+    elem_ttype = spec[0]
+    elem_spec = spec[1]
+    # The spec-declared reader is resolved before the list header is consumed.
+    read_simple = getattr(self, self._TTYPE_HANDLERS[elem_ttype][0])
+    (wire_type, count) = self.readListBegin()
+    if elem_spec is not None:
+      # NOTE: nested elements use the reader for the type announced on the wire,
+      # not the one named by the spec.
+      read_nested = getattr(self, self._TTYPE_HANDLERS[wire_type][0])
+      items = [read_nested(elem_spec) for _ in xrange(count)]
+    else:
+      items = [read_simple() for _ in xrange(count)]
+    self.readListEnd()
+    return items
+
+  def readContainerSet(self, spec):
+    """Read a set described by spec = (member ttype, member spec); None spec = simple type."""
+    member_ttype = spec[0]
+    member_spec = spec[1]
+    read_simple = getattr(self, self._TTYPE_HANDLERS[member_ttype][0])
+    (wire_type, count) = self.readSetBegin()
+    members = set()
+    if member_spec is not None:
+      # nested members: the reader is chosen from the type announced on the wire
+      read_nested = getattr(self, self._TTYPE_HANDLERS[wire_type][0])
+      for _ in xrange(count):
+        members.add(read_nested(member_spec))
+    else:
+      for _ in xrange(count):
+        members.add(read_simple())
+    self.readSetEnd()
+    return members
+
+  def readContainerStruct(self, spec):
+    (struct_class, _unused_spec) = spec  # two-item pair; only the class is used here
+    instance = struct_class()
+    instance.read(self)  # the new object decodes itself from this protocol
+    return instance
+  
+  def readContainerMap(self, spec):
+    # spec is (key ttype, key spec, value ttype, value spec); a spec of None
+    # marks a simple type whose reader takes no argument.
+    key_ttype, key_spec = spec[0], spec[1]
+    val_ttype, val_spec = spec[2], spec[3]
+    (wire_ktype, wire_vtype, count) = self.readMapBegin()
+    # TODO: compare types we just decoded with thrift_spec and abort/skip if types disagree
+    read_key = getattr(self, self._TTYPE_HANDLERS[key_ttype][0])
+    read_val = getattr(self, self._TTYPE_HANDLERS[val_ttype][0])
+    # Sides that carry a nested spec are routed through readFieldByTType instead.
+    if key_spec is not None:
+      read_key = lambda: self.readFieldByTType(key_ttype, key_spec)
+    if val_spec is not None:
+      read_val = lambda: self.readFieldByTType(val_ttype, val_spec)
+    mapping = {}
+    for _ in xrange(count):
+      entry_key = read_key()
+      entry_val = read_val()
+      # this raises a TypeError with unhashable keys types. i.e. d=dict(); d[[0,1]] = 2 fails
+      mapping[entry_key] = entry_val
+    self.readMapEnd()
+    return mapping
+
+  def readStruct(self, obj, thrift_spec):
+    # Fills obj's attributes from the wire until the STOP field. thrift_spec is
+    # indexed by field id; each entry is None or a tuple whose items 1, 2 and 3
+    # are the field's ttype, attribute name and nested type spec.
+    self.readStructBegin()
+    while True:
+      (fname, ftype, fid) = self.readFieldBegin()
+      if ftype == TType.STOP:
+        break
+      field = None
+      # Explicit bounds check: a negative id would otherwise wrap around and
+      # alias an unrelated entry counted from the end of the spec.
+      if 0 <= fid < len(thrift_spec):
+        field = thrift_spec[fid]
+      if field is not None and ftype == field[1]:
+        setattr(obj, field[2], self.readFieldByTType(ftype, field[3]))
+      else:
+        self.skip(ftype)  # unknown id, empty slot, or type mismatch
+      self.readFieldEnd()
+    self.readStructEnd()
+
+  def writeContainerStruct(self, val, spec):  # spec is unused here
+    val.write(self)  # the struct serializes itself onto this protocol
+
+  def writeContainerList(self, val, spec):
+    # spec is (element ttype, element spec); the list header is written first.
+    self.writeListBegin(spec[0], len(val))
+    handler = self._TTYPE_HANDLERS[spec[0]]
+    write_elem = getattr(self, handler[1])
+    for item in val:
+      if handler[2]:
+        write_elem(item, spec[1])  # container/struct writers also take the element spec
+      else:
+        write_elem(item)
+    self.writeListEnd()
+
+  def writeContainerSet(self, val, spec):
+    # spec is (member ttype, member spec); the set header is written first.
+    self.writeSetBegin(spec[0], len(val))
+    handler = self._TTYPE_HANDLERS[spec[0]]
+    write_member = getattr(self, handler[1])
+    for member in val:
+      if handler[2]:
+        write_member(member, spec[1])  # container/struct writers also take the member spec
+      else:
+        write_member(member)
+    self.writeSetEnd()
+
+  def writeContainerMap(self, val, spec):
+    # spec is (key ttype, key spec, value ttype, value spec).
+    key_ttype, val_ttype = spec[0], spec[2]
+    key_entry = self._TTYPE_HANDLERS[key_ttype]
+    val_entry = self._TTYPE_HANDLERS[val_ttype]
+    write_key = getattr(self, key_entry[1])
+    write_val = getattr(self, val_entry[1])
+    self.writeMapBegin(key_ttype, val_ttype, len(val))
+    for entry_key, entry_val in val.iteritems():
+      if key_entry[2]:
+        write_key(entry_key, spec[1])  # container/struct writers also take the nested spec
+      else:
+        write_key(entry_key)
+      if val_entry[2]:
+        write_val(entry_val, spec[3])
+      else:
+        write_val(entry_val)
+    self.writeMapEnd()
+
+  def writeStruct(self, obj, thrift_spec):
+    """Serialize obj's fields in thrift_spec order.
+
+    Spec slots that are None, and attributes whose value is None (unset),
+    produce no output.
+    """
+    self.writeStructBegin(obj.__class__.__name__)
+    for entry in thrift_spec:
+      if entry is not None:
+        name = entry[2]
+        value = getattr(obj, name)
+        if value is not None:
+          # entry items 0, 1 and 3 are the field id, ttype and nested type spec
+          fid, ttype, type_spec = entry[0], entry[1], entry[3]
+          self.writeFieldBegin(name, ttype, fid)
+          self.writeFieldByTType(ttype, value, type_spec)
+          self.writeFieldEnd()
+    self.writeFieldStop()
+    self.writeStructEnd()
+
+  def writeFieldByTType(self, ttype, val, spec):
+    entry = self._TTYPE_HANDLERS[ttype]  # (reader name, writer name, is_container)
+    write_value = getattr(self, entry[1])
+    if not entry[2]:
+      write_value(val)
+      return
+    write_value(val, spec)  # container/struct writers also take the spec
+
 class TProtocolFactory:
   def getProtocol(self, trans):
     pass
+
diff --git a/lib/py/src/protocol/__init__.py b/lib/py/src/protocol/__init__.py
index 01bfe18..d53359b 100644
--- a/lib/py/src/protocol/__init__.py
+++ b/lib/py/src/protocol/__init__.py
@@ -17,4 +17,4 @@
 # under the License.
 #
 
-__all__ = ['TProtocol', 'TBinaryProtocol', 'fastbinary']
+__all__ = ['TProtocol', 'TBinaryProtocol', 'fastbinary', 'TBase']