THRIFT-3525 py:dynamic fails to handle binary list/set/map element
This closes #775
diff --git a/lib/py/src/protocol/TProtocol.py b/lib/py/src/protocol/TProtocol.py
index be2fcea..9679ba0 100644
--- a/lib/py/src/protocol/TProtocol.py
+++ b/lib/py/src/protocol/TProtocol.py
@@ -18,10 +18,13 @@
#
from thrift.Thrift import TException, TType, TFrozenDict
-import six
-
from ..compat import binary_to_str, str_to_binary
+import six
+import sys
+from itertools import islice
+from six.moves import zip
+
class TProtocolException(TException):
"""Custom Protocol Exception class"""
@@ -239,61 +242,38 @@
raise TProtocolException(type=TProtocolException.INVALID_DATA,
message='Invalid binary field type %d' % ttype)
return ('readBinary', 'writeBinary', False)
- return self._TTYPE_HANDLERS[ttype]
+ return self._TTYPE_HANDLERS[ttype] if ttype < len(self._TTYPE_HANDLERS) else (None, None, False)
+
+ def _read_by_ttype(self, ttype, spec, espec):
+ # Endless generator: each step reads and yields one more value of type
+ # `ttype`. It never stops on its own, so callers must bound it (islice for
+ # containers, a single step for a field).
+ # Raises TProtocolException (INVALID_DATA) on the first step when no reader
+ # is registered for `ttype`.
+ # NOTE(review): the handler lookup receives `spec`, while `espec` is only
+ # forwarded to container readers -- confirm `_ttype_handlers` should not be
+ # given `espec` when this is called for container elements.
+ reader_name, _, is_container = self._ttype_handlers(ttype, spec)
+ if reader_name is None:
+ raise TProtocolException(type=TProtocolException.INVALID_DATA,
+ message='Invalid type %d' % (ttype))
+ reader_func = getattr(self, reader_name)
+ # Container readers take the element spec; scalar readers take no argument.
+ read = (lambda: reader_func(espec)) if is_container else reader_func
+ while True:
+ yield read()
def readFieldByTType(self, ttype, spec):
- try:
- (r_handler, w_handler, is_container) = self._ttype_handlers(ttype, spec)
- except IndexError:
- raise TProtocolException(type=TProtocolException.INVALID_DATA,
- message='Invalid field type %d' % (ttype))
- if r_handler is None:
- raise TProtocolException(type=TProtocolException.INVALID_DATA,
- message='Invalid field type %d' % (ttype))
- reader = getattr(self, r_handler)
- if not is_container:
- return reader()
- return reader(spec)
+ # Read exactly one value of type `ttype` and return it. Step the generator
+ # with the next() builtin: generator objects have no .next() method on
+ # Python 3, and the builtin works on Python 2.6+ as well.
+ return next(self._read_by_ttype(ttype, spec, spec))
def readContainerList(self, spec):
- results = []
- ttype, tspec = spec[0], spec[1]
- is_immutable = spec[2]
- r_handler = self._ttype_handlers(ttype, spec)[0]
- reader = getattr(self, r_handler)
+ # spec is (element ttype, element spec, is_immutable); returns a tuple when
+ # is_immutable is true, otherwise a list.
+ ttype, tspec, is_immutable = spec
(list_type, list_len) = self.readListBegin()
- if tspec is None:
- # list values are simple types
- for idx in range(list_len):
- results.append(reader())
- else:
- # this is like an inlined readFieldByTType
- container_reader = self._ttype_handlers(list_type, tspec)[0]
- val_reader = getattr(self, container_reader)
- for idx in range(list_len):
- val = val_reader(tspec)
- results.append(val)
+ # TODO: compare types we just decoded with thrift_spec
+ # _read_by_ttype never stops by itself; islice takes exactly list_len reads.
+ elems = islice(self._read_by_ttype(ttype, spec, tspec), list_len)
+ # Building the container drains elems, i.e. performs the reads, so this
+ # must stay ahead of readListEnd().
+ results = (tuple if is_immutable else list)(elems)
self.readListEnd()
- return tuple(results) if is_immutable else results
+ return results
def readContainerSet(self, spec):
- results = set()
- ttype, tspec = spec[0], spec[1]
- is_immutable = spec[2]
- r_handler = self._ttype_handlers(ttype, spec)[0]
- reader = getattr(self, r_handler)
+ # spec is (element ttype, element spec, is_immutable); returns a frozenset
+ # when is_immutable is true, otherwise a set.
+ ttype, tspec, is_immutable = spec
(set_type, set_len) = self.readSetBegin()
- if tspec is None:
- # set members are simple types
- for idx in range(set_len):
- results.add(reader())
- else:
- container_reader = self._ttype_handlers(set_type, tspec)[0]
- val_reader = getattr(self, container_reader)
- for idx in range(set_len):
- results.add(val_reader(tspec))
+ # TODO: compare types we just decoded with thrift_spec
+ # _read_by_ttype never stops by itself; islice takes exactly set_len reads.
+ elems = islice(self._read_by_ttype(ttype, spec, tspec), set_len)
+ # Building the container drains elems, i.e. performs the reads, so this
+ # must stay ahead of readSetEnd().
+ results = (frozenset if is_immutable else set)(elems)
self.readSetEnd()
- return frozenset(results) if is_immutable else results
+ return results
def readContainerStruct(self, spec):
(obj_class, obj_spec) = spec
@@ -302,30 +282,16 @@
return obj
def readContainerMap(self, spec):
- results = dict()
- key_ttype, key_spec = spec[0], spec[1]
- val_ttype, val_spec = spec[2], spec[3]
- is_immutable = spec[4]
+ # spec is (key ttype, key spec, value ttype, value spec, is_immutable);
+ # returns a TFrozenDict when is_immutable is true, otherwise a dict.
+ ktype, kspec, vtype, vspec, is_immutable = spec
(map_ktype, map_vtype, map_len) = self.readMapBegin()
# TODO: compare types we just decoded with thrift_spec and
# abort/skip if types disagree
- key_reader = getattr(self, self._ttype_handlers(key_ttype, key_spec)[0])
- val_reader = getattr(self, self._ttype_handlers(val_ttype, val_spec)[0])
- # list values are simple types
- for idx in range(map_len):
- if key_spec is None:
- k_val = key_reader()
- else:
- k_val = self.readFieldByTType(key_ttype, key_spec)
- if val_spec is None:
- v_val = val_reader()
- else:
- v_val = self.readFieldByTType(val_ttype, val_spec)
- # this raises a TypeError with unhashable keys types
- # i.e. this fails: d=dict(); d[[0,1]] = 2
- results[k_val] = v_val
+ keys = self._read_by_ttype(ktype, spec, kspec)
+ vals = self._read_by_ttype(vtype, spec, vspec)
+ # Both generators are endless, so this relies on zip being lazy (this file
+ # takes zip from six.moves). Each pair reads one key, then one value, and
+ # islice stops after map_len pairs.
+ keyvals = islice(zip(keys, vals), map_len)
+ # Building the mapping drains keyvals, i.e. performs the reads.
+ results = (TFrozenDict if is_immutable else dict)(keyvals)
self.readMapEnd()
- return TFrozenDict(results) if is_immutable else results
+ return results
def readStruct(self, obj, thrift_spec, is_immutable=False):
if is_immutable:
@@ -359,46 +325,25 @@
val.write(self)
def writeContainerList(self, val, spec):
- self.writeListBegin(spec[0], len(val))
- r_handler, w_handler, is_container = self._ttype_handlers(spec[0], spec)
- e_writer = getattr(self, w_handler)
- if not is_container:
- for elem in val:
- e_writer(elem)
- else:
- for elem in val:
- e_writer(elem, spec[1])
+ # spec is (element ttype, element spec, <unused third item>).
+ ttype, tspec, _ = spec
+ self.writeListBegin(ttype, len(val))
+ # _write_by_ttype is a generator; the empty loop exists only to drive it so
+ # that every element actually gets written.
+ for _ in self._write_by_ttype(ttype, val, spec, tspec):
+ pass
self.writeListEnd()
def writeContainerSet(self, val, spec):
- self.writeSetBegin(spec[0], len(val))
- r_handler, w_handler, is_container = self._ttype_handlers(spec[0], spec)
- e_writer = getattr(self, w_handler)
- if not is_container:
- for elem in val:
- e_writer(elem)
- else:
- for elem in val:
- e_writer(elem, spec[1])
+ # spec is (element ttype, element spec, <unused third item>).
+ ttype, tspec, _ = spec
+ self.writeSetBegin(ttype, len(val))
+ # _write_by_ttype is a generator; the empty loop exists only to drive it so
+ # that every element actually gets written.
+ for _ in self._write_by_ttype(ttype, val, spec, tspec):
+ pass
self.writeSetEnd()
def writeContainerMap(self, val, spec):
- k_type = spec[0]
- v_type = spec[2]
- ignore, ktype_name, k_is_container = self._ttype_handlers(k_type, spec)
- ignore, vtype_name, v_is_container = self._ttype_handlers(v_type, spec)
- k_writer = getattr(self, ktype_name)
- v_writer = getattr(self, vtype_name)
- self.writeMapBegin(k_type, v_type, len(val))
- for m_key, m_val in six.iteritems(val):
- if not k_is_container:
- k_writer(m_key)
- else:
- k_writer(m_key, spec[1])
- if not v_is_container:
- v_writer(m_val)
- else:
- v_writer(m_val, spec[3])
+ # spec is (key ttype, key spec, value ttype, value spec, <unused fifth item>).
+ ktype, kspec, vtype, vspec, _ = spec
+ self.writeMapBegin(ktype, vtype, len(val))
+ # The two generators are stepped in lockstep, so the output is key, value,
+ # key, value, ... This relies on zip being lazy (this file takes zip from
+ # six.moves); an eager zip would write every key before any value.
+ for _ in zip(self._write_by_ttype(ktype, six.iterkeys(val), spec, kspec),
+ self._write_by_ttype(vtype, six.itervalues(val), spec, vspec)):
+ pass
self.writeMapEnd()
def writeStruct(self, obj, thrift_spec):
@@ -414,20 +359,21 @@
fid = field[0]
ftype = field[1]
fspec = field[3]
- # get the writer method for this value
self.writeFieldBegin(fname, ftype, fid)
self.writeFieldByTType(ftype, val, fspec)
self.writeFieldEnd()
self.writeFieldStop()
self.writeStructEnd()
+ def _write_by_ttype(self, ttype, vals, spec, espec):
+ # Generator that writes the items of `vals` one per step; nothing is
+ # written until the caller iterates it. `espec` is handed to container
+ # writers only; scalar writers receive just the value.
+ # Raises TProtocolException (INVALID_DATA) on the first step when no writer
+ # is registered for `ttype`.
+ _, writer_name, is_container = self._ttype_handlers(ttype, spec)
+ if writer_name is None:
+ # Same check as _read_by_ttype: without it getattr(self, None) fails with
+ # an unrelated TypeError instead of a protocol error.
+ raise TProtocolException(type=TProtocolException.INVALID_DATA,
+ message='Invalid type %d' % (ttype))
+ writer_func = getattr(self, writer_name)
+ write = (lambda v: writer_func(v, espec)) if is_container else writer_func
+ for v in vals:
+ yield write(v)
+
def writeFieldByTType(self, ttype, val, spec):
- r_handler, w_handler, is_container = self._ttype_handlers(ttype, spec)
- writer = getattr(self, w_handler)
- if is_container:
- writer(val, spec)
- else:
- writer(val)
+ # Write the single value `val`. Step the one-element generator with the
+ # next() builtin: generator objects have no .next() method on Python 3,
+ # and the builtin works on Python 2.6+ as well.
+ next(self._write_by_ttype(ttype, [val], spec, spec))
def checkIntegerLimits(i, bits):