THRIFT-162 Thrift structures are unhashable, preventing them from being used as set elements
Client: Python
Patch: David Reiss, Nobuaki Sukegawa

This closes #714
diff --git a/lib/py/src/Thrift.py b/lib/py/src/Thrift.py
index 9890af7..cbb9184 100644
--- a/lib/py/src/Thrift.py
+++ b/lib/py/src/Thrift.py
@@ -168,3 +168,39 @@
       oprot.writeFieldEnd()
     oprot.writeFieldStop()
     oprot.writeStructEnd()
+
+
+class TFrozenDict(dict):
+  """A dictionary that is "frozen" like a frozenset.
+
+  Every mutating method raises TypeError and instances are hashable, so a
+  TFrozenDict can be a map field of an immutable struct, a set element,
+  or a dict key.
+  """
+
+  def __init__(self, *args, **kwargs):
+    super(TFrozenDict, self).__init__(*args, **kwargs)
+    # Hash the items as a frozenset: the result does not depend on
+    # insertion order and, unlike sorting, does not require the keys to
+    # be mutually orderable (Python 3 cannot sort mixed-type keys or keys
+    # without an ordering). XOR in the hash of the class so we don't
+    # collide with the hash of a plain frozenset of the same pairs.
+    self.__hashval = hash(TFrozenDict) ^ hash(frozenset(self.items()))
+
+  def _readonly(self, *args, **kwargs):
+    raise TypeError("Can't modify frozen TFrozenDict")
+
+  # dict's own mutators are implemented in C and do not go through
+  # __setitem__/__delitem__, so each of them is blocked explicitly.
+  __setitem__ = __delitem__ = _readonly
+  clear = pop = popitem = setdefault = update = _readonly
+  __ior__ = _readonly
+
+  def __reduce__(self):
+    # The default pickle/copy protocol for dict subclasses refills the new
+    # instance through __setitem__, which is blocked; rebuild through the
+    # constructor instead (this also recomputes the hash).
+    return (TFrozenDict, (dict(self),))
+
+  def __hash__(self):
+    return self.__hashval
diff --git a/lib/py/src/protocol/TBase.py b/lib/py/src/protocol/TBase.py
index 118a679..4f71e11 100644
--- a/lib/py/src/protocol/TBase.py
+++ b/lib/py/src/protocol/TBase.py
@@ -27,7 +27,7 @@
 
 
 class TBase(object):
-  __slots__ = []
+  __slots__ = ()
 
   def __repr__(self):
     L = ['%s=%r' % (key, getattr(self, key)) for key in self.__slots__]
@@ -68,4 +68,41 @@
 
 
 class TExceptionBase(TBase, Exception):
-  __slots__ = []
+  pass
+
+
+class TFrozenBase(TBase):
+  """Base class for structs generated as immutable (frozen).
+
+  read() decodes every field first and builds the instance in a single
+  constructor call; __hash__ lets instances be set elements or dict keys.
+  """
+
+  def __setitem__(self, *args):
+    raise TypeError("Can't modify frozen struct")
+
+  def __delitem__(self, *args):
+    raise TypeError("Can't modify frozen struct")
+
+  def __hash__(self):
+    # Fallback hash built from the class and its field *names* only, so
+    # every instance of a class gets the same value. NOTE(review):
+    # presumably generated frozen structs override this with a value-based
+    # hash; confirm. tuple() accepts __slots__ declared as a list or a
+    # tuple (hash() of a list raises TypeError).
+    return hash(self.__class__) ^ hash(tuple(self.__slots__))
+
+  @classmethod
+  def read(cls, iprot):
+    """Decode a struct from iprot and return it as a new instance of cls."""
+    if (iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and
+       isinstance(iprot.trans, TTransport.CReadableTransport) and
+       cls.thrift_spec is not None and
+       fastbinary is not None):
+      # A None output object makes the C decoder construct and return a
+      # new cls(**fields) instead of filling an existing instance, so no
+      # throwaway instance has to be created here.
+      return fastbinary.decode_binary(None,
+                                      iprot.trans,
+                                      (cls, cls.thrift_spec))
+    return iprot.readStruct(cls, cls.thrift_spec, True)
diff --git a/lib/py/src/protocol/TProtocol.py b/lib/py/src/protocol/TProtocol.py
index ca22c48..1d703e3 100644
--- a/lib/py/src/protocol/TProtocol.py
+++ b/lib/py/src/protocol/TProtocol.py
@@ -17,7 +17,7 @@
 # under the License.
 #
 
-from thrift.Thrift import TException, TType
+from thrift.Thrift import TException, TType, TFrozenDict
 import six
 
 from ..compat import binary_to_str, str_to_binary
@@ -108,9 +108,6 @@
   def writeBinary(self, str_val):
     pass
 
-  def writeBinary(self, str_val):
-    return self.writeString(str_val)
-
   def readMessageBegin(self):
     pass
 
@@ -171,9 +168,6 @@
   def readBinary(self):
     pass
 
-  def readBinary(self):
-    return self.readString()
-
   def skip(self, ttype):
     if ttype == TType.STOP:
       return
@@ -264,6 +258,7 @@
   def readContainerList(self, spec):
     results = []
     ttype, tspec = spec[0], spec[1]
+    is_immutable = spec[2]  # spec is (elem_ttype, elem_spec, is_immutable)
     r_handler = self._ttype_handlers(ttype, spec)[0]
     reader = getattr(self, r_handler)
     (list_type, list_len) = self.readListBegin()
@@ -279,11 +274,12 @@
         val = val_reader(tspec)
         results.append(val)
     self.readListEnd()
-    return results
+    return tuple(results) if is_immutable else results  # immutable lists decode to tuples
 
   def readContainerSet(self, spec):
     results = set()
     ttype, tspec = spec[0], spec[1]
+    is_immutable = spec[2]  # spec is (elem_ttype, elem_spec, is_immutable)
     r_handler = self._ttype_handlers(ttype, spec)[0]
     reader = getattr(self, r_handler)
     (set_type, set_len) = self.readSetBegin()
@@ -297,7 +293,7 @@
       for idx in range(set_len):
         results.add(val_reader(tspec))
     self.readSetEnd()
-    return results
+    return frozenset(results) if is_immutable else results  # immutable sets decode to frozensets
 
   def readContainerStruct(self, spec):
     (obj_class, obj_spec) = spec
@@ -309,6 +305,7 @@
     results = dict()
     key_ttype, key_spec = spec[0], spec[1]
     val_ttype, val_spec = spec[2], spec[3]
+    is_immutable = spec[4]  # spec is (key_ttype, key_spec, val_ttype, val_spec, is_immutable)
     (map_ktype, map_vtype, map_len) = self.readMapBegin()
     # TODO: compare types we just decoded with thrift_spec and
     # abort/skip if types disagree
@@ -328,9 +325,11 @@
       # i.e. this fails: d=dict(); d[[0,1]] = 2
       results[k_val] = v_val
     self.readMapEnd()
-    return results
+    return TFrozenDict(results) if is_immutable else results  # hashable TFrozenDict copy
 
-  def readStruct(self, obj, thrift_spec):
+  def readStruct(self, obj, thrift_spec, is_immutable=False):  # obj is the class itself when is_immutable
+    if is_immutable:
+      fields = {}  # decoded values keyed by field name, passed to obj(**fields)
     self.readStructBegin()
     while True:
       (fname, ftype, fid) = self.readFieldBegin()
@@ -345,11 +344,16 @@
           fname = field[2]
           fspec = field[3]
           val = self.readFieldByTType(ftype, fspec)
-          setattr(obj, fname, val)
+          if is_immutable:  # defer: the instance does not exist yet
+            fields[fname] = val
+          else:
+            setattr(obj, fname, val)
         else:
           self.skip(ftype)
       self.readFieldEnd()
     self.readStructEnd()
+    if is_immutable:
+      return obj(**fields)  # the mutable path falls through and returns None
 
   def writeContainerStruct(self, val, spec):
     val.write(self)
diff --git a/lib/py/src/protocol/fastbinary.c b/lib/py/src/protocol/fastbinary.c
index 93c4911..a17019b 100644
--- a/lib/py/src/protocol/fastbinary.c
+++ b/lib/py/src/protocol/fastbinary.c
@@ -124,11 +124,6 @@
 #define INT_CONV_ERROR_OCCURRED(v) ( ((v) == -1) && PyErr_Occurred() )
 #define CHECK_RANGE(v, min, max) ( ((v) <= (max)) && ((v) >= (min)) )
 
-// Py_ssize_t was not defined before Python 2.5
-#if (PY_VERSION_HEX < 0x02050000)
-typedef int Py_ssize_t;
-#endif
-
 /**
  * A cache of the spec_args for a set or list,
  * so we don't have to keep calling PyTuple_GET_ITEM.
@@ -136,6 +131,7 @@
 typedef struct {
   TType element_type;
   PyObject* typeargs;
+  bool immutable;  // typeargs[2] is Py_True: decode a list as tuple, a set as frozenset
 } SetListTypeArgs;
 
 /**
@@ -147,6 +143,7 @@
   TType vtag;
   PyObject* ktypeargs;
   PyObject* vtypeargs;
+  bool immutable;  // typeargs[4] is Py_True: wrap the decoded dict in thrift.Thrift.TFrozenDict
 } MapTypeArgs;
 
 /**
@@ -156,6 +153,7 @@
 typedef struct {
   PyObject* klass;
   PyObject* spec;
+  bool immutable;  // NOTE(review): not set or read in the hunks shown; T_STRUCT decoding ignores it
 } StructTypeArgs;
 
 /**
@@ -233,8 +231,8 @@
 
 static bool
 parse_set_list_args(SetListTypeArgs* dest, PyObject* typeargs) {
-  if (PyTuple_Size(typeargs) != 2) {
-    PyErr_SetString(PyExc_TypeError, "expecting tuple of size 2 for list/set type args");
+  if (PyTuple_Size(typeargs) != 3) {  // (element ttype, element typeargs, immutable flag)
+    PyErr_SetString(PyExc_TypeError, "expecting tuple of size 3 for list/set type args");
     return false;
   }
 
@@ -245,13 +243,15 @@
 
   dest->typeargs = PyTuple_GET_ITEM(typeargs, 1);
 
+  dest->immutable = Py_True == PyTuple_GET_ITEM(typeargs, 2);  // only the exact Py_True object counts
+
   return true;
 }
 
 static bool
 parse_map_args(MapTypeArgs* dest, PyObject* typeargs) {
-  if (PyTuple_Size(typeargs) != 4) {
-    PyErr_SetString(PyExc_TypeError, "expecting 4 arguments for typeargs to map");
+  if (PyTuple_Size(typeargs) != 5) {  // (ktype, ktypeargs, vtype, vtypeargs, immutable flag)
+    PyErr_SetString(PyExc_TypeError, "expecting 5 arguments for typeargs to map");
     return false;
   }
 
@@ -267,6 +267,7 @@
 
   dest->ktypeargs = PyTuple_GET_ITEM(typeargs, 1);
   dest->vtypeargs = PyTuple_GET_ITEM(typeargs, 3);
+  dest->immutable = Py_True == PyTuple_GET_ITEM(typeargs, 4);  // only the exact Py_True object counts
 
   return true;
 }
@@ -289,7 +290,7 @@
 
   // i'd like to use ParseArgs here, but it seems to be a bottleneck.
   if (PyTuple_Size(spec_tuple) != 5) {
-    PyErr_SetString(PyExc_TypeError, "expecting 5 arguments for spec tuple");
+    PyErr_Format(PyExc_TypeError, "expecting 5 arguments for spec tuple but got %zd", PyTuple_Size(spec_tuple));  // Py_ssize_t needs %zd
     return false;
   }
 
@@ -885,11 +886,21 @@
 static PyObject*
 decode_val(DecodeBuffer* input, TType type, PyObject* typeargs);
 
-static bool
-decode_struct(DecodeBuffer* input, PyObject* output, PyObject* spec_seq) {
+static PyObject*  // new reference to the decoded struct, or NULL with an exception set
+decode_struct(DecodeBuffer* input, PyObject* output, PyObject* klass, PyObject* spec_seq) {
   int spec_seq_len = PyTuple_Size(spec_seq);
+  bool immutable = output == Py_None;  // Py_None: collect fields, then build klass(**kwargs)
+  PyObject* kwargs = NULL;  // field name -> decoded value; allocated only when immutable
   if (spec_seq_len == -1) {
-    return false;
+    return NULL;
+  }
+
+  if (immutable) {
+    kwargs = PyDict_New();
+    if (!kwargs) {
+      // PyDict_New has already set MemoryError; don't mask it with a TypeError.
+      return NULL;
+    }
   }
 
   while (true) {
@@ -901,14 +912,14 @@
 
     type = readByte(input);
     if (type == -1) {
-      return false;
+      goto error;
     }
     if (type == T_STOP) {
       break;
     }
     tag = readI16(input);
     if (INT_CONV_ERROR_OCCURRED(tag)) {
-      return false;
+      goto error;
     }
     if (tag >= 0 && tag < spec_seq_len) {
       item_spec = PyTuple_GET_ITEM(spec_seq, tag);
@@ -918,19 +929,19 @@
 
     if (item_spec == Py_None) {
       if (!skip(input, type)) {
-        return false;
+        goto error;
       } else {
         continue;
       }
     }
 
     if (!parse_struct_item_spec(&parsedspec, item_spec)) {
-      return false;
+      goto error;
     }
     if (parsedspec.type != type) {
       if (!skip(input, type)) {
         PyErr_SetString(PyExc_TypeError, "struct field had wrong type while reading and can't be skipped");
-        return false;
+        goto error;
       } else {
         continue;
       }
@@ -938,16 +949,34 @@
 
     fieldval = decode_val(input, parsedspec.type, parsedspec.typeargs);
     if (fieldval == NULL) {
-      return false;
+      goto error;
     }
 
-    if (PyObject_SetAttr(output, parsedspec.attrname, fieldval) == -1) {
+    if ((immutable && PyDict_SetItem(kwargs, parsedspec.attrname, fieldval) == -1)
+        || (!immutable && PyObject_SetAttr(output, parsedspec.attrname, fieldval) == -1)) {
       Py_DECREF(fieldval);
-      return false;
+      goto error;
     }
     Py_DECREF(fieldval);
   }
-  return true;
+  if (immutable) {
+    PyObject* args = PyTuple_New(0);
+    PyObject* ret = NULL;
+    if (!args) {
+      // PyTuple_New has already set MemoryError; don't mask it with a TypeError.
+      goto error;
+    }
+    ret = PyObject_Call(klass, args, kwargs);  // construct klass(**kwargs) in one call
+    Py_DECREF(kwargs);
+    Py_DECREF(args);
+    return ret;
+  }
+  Py_INCREF(output);  // hand back a new reference to the caller-supplied object
+  return output;
+
+  error:
+  Py_XDECREF(kwargs);
+  return NULL;
 }
 
 
@@ -1033,6 +1062,7 @@
     int32_t len;
     PyObject* ret = NULL;
     int i;
+    bool use_tuple = false;  // true only for an immutable T_LIST
 
     if (!parse_set_list_args(&parsedargs, typeargs)) {
       return NULL;
@@ -1047,7 +1077,8 @@
       return NULL;
     }
 
-    ret = PyList_New(len);
+    use_tuple = type == T_LIST && parsedargs.immutable;  // immutable sets are converted after the loop
+    ret = use_tuple ? PyTuple_New(len) : PyList_New(len);
     if (!ret) {
       return NULL;
     }
@@ -1058,20 +1089,18 @@
         Py_DECREF(ret);
         return NULL;
       }
-      PyList_SET_ITEM(ret, i, item);
+      if (use_tuple) {  // ret was created as a tuple above
+        PyTuple_SET_ITEM(ret, i, item);
+      } else  {
+        PyList_SET_ITEM(ret, i, item);
+      }
     }
 
     // TODO(dreiss): Consider biting the bullet and making two separate cases
     //               for list and set, avoiding this post facto conversion.
     if (type == T_SET) {
       PyObject* setret;
-#if (PY_VERSION_HEX < 0x02050000)
-      // hack needed for older versions
-      setret = PyObject_CallFunctionObjArgs((PyObject*)&PySet_Type, ret, NULL);
-#else
-      // official version
-      setret = PySet_New(ret);
-#endif
+      setret = parsedargs.immutable ? PyFrozenSet_New(ret) : PySet_New(ret);  // immutable sets become frozensets
       Py_DECREF(ret);
       return setret;
     }
@@ -1131,6 +1160,22 @@
       goto error;
     }
 
+    if (parsedargs.immutable) {
+      // Wrap the plain dict in thrift.Thrift.TFrozenDict, releasing every
+      // temporary reference (module, class, plain dict) on all paths.
+      PyObject* thrift = PyImport_ImportModule("thrift.Thrift");
+      PyObject* cls = thrift ? PyObject_GetAttrString(thrift, "TFrozenDict") : NULL;
+      PyObject* frozen = NULL;
+      Py_XDECREF(thrift);
+      if (!cls) {
+        goto error;  // NOTE(review): relies on the error: label releasing ret
+      }
+      frozen = PyObject_CallFunctionObjArgs(cls, ret, NULL);  // borrows ret; TFrozenDict copies it
+      Py_DECREF(cls);
+      Py_DECREF(ret);
+      return frozen;
+    }
+
     return ret;
 
     error:
@@ -1140,22 +1185,11 @@
 
   case T_STRUCT: {
     StructTypeArgs parsedargs;
-	PyObject* ret;
     if (!parse_struct_args(&parsedargs, typeargs)) {
       return NULL;
     }
 
-    ret = PyObject_CallObject(parsedargs.klass, NULL);
-    if (!ret) {
-      return NULL;
-    }
-
-    if (!decode_struct(input, ret, parsedargs.spec)) {
-      Py_DECREF(ret);
-      return NULL;
-    }
-
-    return ret;
+    return decode_struct(input, Py_None, parsedargs.klass, parsedargs.spec);  // always built via klass(**fields)
   }
 
   case T_STOP:
@@ -1179,6 +1213,7 @@
   PyObject* typeargs = NULL;
   StructTypeArgs parsedargs;
   DecodeBuffer input = {0, 0};
+  PyObject* ret = NULL;  // decoded struct (new reference) or NULL with an exception set
 
   if (!PyArg_ParseTuple(args, "OOO", &output_obj, &transport, &typeargs)) {
     return NULL;
@@ -1192,14 +1227,9 @@
     return NULL;
   }
 
-  if (!decode_struct(&input, output_obj, parsedargs.spec)) {
-    free_decodebuf(&input);
-    return NULL;
-  }
-
+  ret = decode_struct(&input, output_obj, parsedargs.klass, parsedargs.spec);  // None output_obj => new instance
   free_decodebuf(&input);
-
-  Py_RETURN_NONE;
+  return ret;
 }
 
 /* ====== END READING FUNCTIONS ====== */