THRIFT-162 Thrift structures are unhashable, preventing them from being used as set elements
Client: Python
Patch: David Reiss, Nobuaki Sukegawa
This closes #714
diff --git a/lib/py/src/Thrift.py b/lib/py/src/Thrift.py
index 9890af7..cbb9184 100644
--- a/lib/py/src/Thrift.py
+++ b/lib/py/src/Thrift.py
@@ -168,3 +168,23 @@
oprot.writeFieldEnd()
oprot.writeFieldStop()
oprot.writeStructEnd()
+
+
+class TFrozenDict(dict):
+    """A dictionary that is "frozen" like a frozenset.
+
+    The hash is computed once, at construction time, from the items, so
+    every in-place mutator raises TypeError to keep that cached hash valid.
+    Keys and values must all be hashable.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super(TFrozenDict, self).__init__(*args, **kwargs)
+        # A frozenset of the items gives an order-independent hash without
+        # sorting; sorting fails on Python 3 when keys are not mutually
+        # orderable (mixed types, None, objects without __lt__).
+        # XOR in the hash of the class so we don't collide with
+        # the hash of a plain frozenset of tuples.
+        self.__hashval = hash(TFrozenDict) ^ hash(frozenset(self.items()))
+
+    def _reject_mutation(self, *args, **kwargs):
+        """Raise TypeError; shared body for every in-place mutator."""
+        raise TypeError("Can't modify frozen TFrozenDict")
+
+    # dict.__init__ fills the mapping at C level and does not go through
+    # these overrides, so construction still works.
+    __setitem__ = _reject_mutation
+    __delitem__ = _reject_mutation
+    __ior__ = _reject_mutation
+    update = _reject_mutation
+    pop = _reject_mutation
+    popitem = _reject_mutation
+    clear = _reject_mutation
+    setdefault = _reject_mutation
+
+    def __hash__(self):
+        return self.__hashval
diff --git a/lib/py/src/protocol/TBase.py b/lib/py/src/protocol/TBase.py
index 118a679..4f71e11 100644
--- a/lib/py/src/protocol/TBase.py
+++ b/lib/py/src/protocol/TBase.py
@@ -27,7 +27,7 @@
class TBase(object):
- __slots__ = []
+ __slots__ = ()
def __repr__(self):
L = ['%s=%r' % (key, getattr(self, key)) for key in self.__slots__]
@@ -68,4 +68,27 @@
class TExceptionBase(TBase, Exception):
- __slots__ = []
+ """A TBase that is also a Python Exception; adds no members of its own."""
+ pass
+
+
+class TFrozenBase(TBase):
+    """Base class for immutable ("frozen") Thrift structs.
+
+    Item assignment and deletion raise TypeError, instances are hashable,
+    and ``read`` is a classmethod that returns a newly built instance
+    rather than populating an existing one.
+    """
+
+    def __setitem__(self, *args):
+        raise TypeError("Can't modify frozen struct")
+
+    def __delitem__(self, *args):
+        raise TypeError("Can't modify frozen struct")
+
+    def __hash__(self, *args):
+        # NOTE(review): this depends only on the class and its slot names,
+        # not on the field values, so it cannot distinguish instances by
+        # content. Subclasses wanting a value-based hash must override it.
+        return hash(self.__class__) ^ hash(self.__slots__)
+
+    @classmethod
+    def read(cls, iprot):
+        """Decode one struct of type ``cls`` from ``iprot`` and return it.
+
+        Uses the C ``fastbinary`` decoder when the protocol is exactly
+        TBinaryProtocolAccelerated over a CReadableTransport and the class
+        has a ``thrift_spec``; otherwise falls back to
+        ``iprot.readStruct`` in immutable mode.
+        """
+        if (iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and
+                isinstance(iprot.trans, TTransport.CReadableTransport) and
+                cls.thrift_spec is not None and
+                fastbinary is not None):
+            # Passing None as the output object asks the decoder to build
+            # the instance itself. The class is used directly: no throwaway
+            # cls() is needed just to reach __class__ / thrift_spec.
+            return fastbinary.decode_binary(None,
+                                            iprot.trans,
+                                            (cls, cls.thrift_spec))
+        return iprot.readStruct(cls, cls.thrift_spec, True)
diff --git a/lib/py/src/protocol/TProtocol.py b/lib/py/src/protocol/TProtocol.py
index ca22c48..1d703e3 100644
--- a/lib/py/src/protocol/TProtocol.py
+++ b/lib/py/src/protocol/TProtocol.py
@@ -17,7 +17,7 @@
# under the License.
#
-from thrift.Thrift import TException, TType
+from thrift.Thrift import TException, TType, TFrozenDict
import six
from ..compat import binary_to_str, str_to_binary
@@ -108,9 +108,6 @@
def writeBinary(self, str_val):
pass
- def writeBinary(self, str_val):
- return self.writeString(str_val)
-
def readMessageBegin(self):
pass
@@ -171,9 +168,6 @@
def readBinary(self):
pass
- def readBinary(self):
- return self.readString()
-
def skip(self, ttype):
if ttype == TType.STOP:
return
@@ -264,6 +258,7 @@
def readContainerList(self, spec):
results = []
ttype, tspec = spec[0], spec[1]
+ is_immutable = spec[2]
r_handler = self._ttype_handlers(ttype, spec)[0]
reader = getattr(self, r_handler)
(list_type, list_len) = self.readListBegin()
@@ -279,11 +274,12 @@
val = val_reader(tspec)
results.append(val)
self.readListEnd()
- return results
+ return tuple(results) if is_immutable else results
def readContainerSet(self, spec):
results = set()
ttype, tspec = spec[0], spec[1]
+ is_immutable = spec[2]
r_handler = self._ttype_handlers(ttype, spec)[0]
reader = getattr(self, r_handler)
(set_type, set_len) = self.readSetBegin()
@@ -297,7 +293,7 @@
for idx in range(set_len):
results.add(val_reader(tspec))
self.readSetEnd()
- return results
+ return frozenset(results) if is_immutable else results
def readContainerStruct(self, spec):
(obj_class, obj_spec) = spec
@@ -309,6 +305,7 @@
results = dict()
key_ttype, key_spec = spec[0], spec[1]
val_ttype, val_spec = spec[2], spec[3]
+ is_immutable = spec[4]
(map_ktype, map_vtype, map_len) = self.readMapBegin()
# TODO: compare types we just decoded with thrift_spec and
# abort/skip if types disagree
@@ -328,9 +325,11 @@
# i.e. this fails: d=dict(); d[[0,1]] = 2
results[k_val] = v_val
self.readMapEnd()
- return results
+ return TFrozenDict(results) if is_immutable else results
- def readStruct(self, obj, thrift_spec):
+ def readStruct(self, obj, thrift_spec, is_immutable=False):
+ if is_immutable:
+ fields = {}
self.readStructBegin()
while True:
(fname, ftype, fid) = self.readFieldBegin()
@@ -345,11 +344,16 @@
fname = field[2]
fspec = field[3]
val = self.readFieldByTType(ftype, fspec)
- setattr(obj, fname, val)
+ if is_immutable:
+ fields[fname] = val
+ else:
+ setattr(obj, fname, val)
else:
self.skip(ftype)
self.readFieldEnd()
self.readStructEnd()
+ if is_immutable:
+ return obj(**fields)
def writeContainerStruct(self, val, spec):
val.write(self)
diff --git a/lib/py/src/protocol/fastbinary.c b/lib/py/src/protocol/fastbinary.c
index 93c4911..a17019b 100644
--- a/lib/py/src/protocol/fastbinary.c
+++ b/lib/py/src/protocol/fastbinary.c
@@ -124,11 +124,6 @@
#define INT_CONV_ERROR_OCCURRED(v) ( ((v) == -1) && PyErr_Occurred() )
#define CHECK_RANGE(v, min, max) ( ((v) <= (max)) && ((v) >= (min)) )
-// Py_ssize_t was not defined before Python 2.5
-#if (PY_VERSION_HEX < 0x02050000)
-typedef int Py_ssize_t;
-#endif
-
/**
* A cache of the spec_args for a set or list,
* so we don't have to keep calling PyTuple_GET_ITEM.
@@ -136,6 +131,7 @@
typedef struct {
TType element_type;
PyObject* typeargs;
+ // True only when element 2 of the typeargs tuple is Py_True; decoding then
+ // produces a tuple (for lists) or a frozenset (for sets).
+ bool immutable;
} SetListTypeArgs;
/**
@@ -147,6 +143,7 @@
TType vtag;
PyObject* ktypeargs;
PyObject* vtypeargs;
+ bool immutable;
} MapTypeArgs;
/**
@@ -156,6 +153,7 @@
typedef struct {
PyObject* klass;
PyObject* spec;
+ // NOTE(review): nothing visible in this patch assigns or reads this flag;
+ // confirm parse_struct_args initializes it before relying on it.
+ bool immutable;
} StructTypeArgs;
/**
@@ -233,8 +231,8 @@
static bool
parse_set_list_args(SetListTypeArgs* dest, PyObject* typeargs) {
- if (PyTuple_Size(typeargs) != 2) {
- PyErr_SetString(PyExc_TypeError, "expecting tuple of size 2 for list/set type args");
+ if (PyTuple_Size(typeargs) != 3) {
+ PyErr_SetString(PyExc_TypeError, "expecting tuple of size 3 for list/set type args");
return false;
}
@@ -245,13 +243,15 @@
dest->typeargs = PyTuple_GET_ITEM(typeargs, 1);
+ dest->immutable = Py_True == PyTuple_GET_ITEM(typeargs, 2);
+
return true;
}
static bool
parse_map_args(MapTypeArgs* dest, PyObject* typeargs) {
- if (PyTuple_Size(typeargs) != 4) {
- PyErr_SetString(PyExc_TypeError, "expecting 4 arguments for typeargs to map");
+ if (PyTuple_Size(typeargs) != 5) {
+ PyErr_SetString(PyExc_TypeError, "expecting 5 arguments for typeargs to map");
return false;
}
@@ -267,6 +267,7 @@
dest->ktypeargs = PyTuple_GET_ITEM(typeargs, 1);
dest->vtypeargs = PyTuple_GET_ITEM(typeargs, 3);
+ dest->immutable = Py_True == PyTuple_GET_ITEM(typeargs, 4);
return true;
}
@@ -289,7 +290,7 @@
// i'd like to use ParseArgs here, but it seems to be a bottleneck.
if (PyTuple_Size(spec_tuple) != 5) {
- PyErr_SetString(PyExc_TypeError, "expecting 5 arguments for spec tuple");
+ // PyTuple_Size returns Py_ssize_t, which must be formatted with %zd;
+ // %d would read it as an int and is wrong where the two sizes differ.
+ PyErr_Format(PyExc_TypeError, "expecting 5 arguments for spec tuple but got %zd", PyTuple_Size(spec_tuple));
return false;
}
@@ -885,11 +886,21 @@
static PyObject*
decode_val(DecodeBuffer* input, TType type, PyObject* typeargs);
-static bool
-decode_struct(DecodeBuffer* input, PyObject* output, PyObject* spec_seq) {
+static PyObject*
+decode_struct(DecodeBuffer* input, PyObject* output, PyObject* klass, PyObject* spec_seq) {
int spec_seq_len = PyTuple_Size(spec_seq);
+ bool immutable = output == Py_None;
+ PyObject* kwargs = NULL;
if (spec_seq_len == -1) {
- return false;
+ return NULL;
+ }
+
+ if (immutable) {
+ kwargs = PyDict_New();
+ if (!kwargs) {
+ PyErr_SetString(PyExc_TypeError, "failed to prepare kwargument storage");
+ return NULL;
+ }
}
while (true) {
@@ -901,14 +912,14 @@
type = readByte(input);
if (type == -1) {
- return false;
+ goto error;
}
if (type == T_STOP) {
break;
}
tag = readI16(input);
if (INT_CONV_ERROR_OCCURRED(tag)) {
- return false;
+ goto error;
}
if (tag >= 0 && tag < spec_seq_len) {
item_spec = PyTuple_GET_ITEM(spec_seq, tag);
@@ -918,19 +929,19 @@
if (item_spec == Py_None) {
if (!skip(input, type)) {
- return false;
+ goto error;
} else {
continue;
}
}
if (!parse_struct_item_spec(&parsedspec, item_spec)) {
- return false;
+ goto error;
}
if (parsedspec.type != type) {
if (!skip(input, type)) {
PyErr_SetString(PyExc_TypeError, "struct field had wrong type while reading and can't be skipped");
- return false;
+ goto error;
} else {
continue;
}
@@ -938,16 +949,34 @@
fieldval = decode_val(input, parsedspec.type, parsedspec.typeargs);
if (fieldval == NULL) {
- return false;
+ goto error;
}
- if (PyObject_SetAttr(output, parsedspec.attrname, fieldval) == -1) {
+ if ((immutable && PyDict_SetItem(kwargs, parsedspec.attrname, fieldval) == -1)
+ || (!immutable && PyObject_SetAttr(output, parsedspec.attrname, fieldval) == -1)) {
Py_DECREF(fieldval);
- return false;
+ goto error;
}
Py_DECREF(fieldval);
}
- return true;
+ if (immutable) {
+ PyObject* args = PyTuple_New(0);
+ PyObject* ret = NULL;
+ if (!args) {
+ PyErr_SetString(PyExc_TypeError, "failed to prepare argument storage");
+ goto error;
+ }
+ ret = PyObject_Call(klass, args, kwargs);
+ Py_DECREF(kwargs);
+ Py_DECREF(args);
+ return ret;
+ }
+ Py_INCREF(output);
+ return output;
+
+ error:
+ Py_XDECREF(kwargs);
+ return NULL;
}
@@ -1033,6 +1062,7 @@
int32_t len;
PyObject* ret = NULL;
int i;
+ bool use_tuple = false;
if (!parse_set_list_args(&parsedargs, typeargs)) {
return NULL;
@@ -1047,7 +1077,8 @@
return NULL;
}
- ret = PyList_New(len);
+ use_tuple = type == T_LIST && parsedargs.immutable;
+ ret = use_tuple ? PyTuple_New(len) : PyList_New(len);
if (!ret) {
return NULL;
}
@@ -1058,20 +1089,18 @@
Py_DECREF(ret);
return NULL;
}
- PyList_SET_ITEM(ret, i, item);
+ if (use_tuple) {
+ PyTuple_SET_ITEM(ret, i, item);
+ } else {
+ PyList_SET_ITEM(ret, i, item);
+ }
}
// TODO(dreiss): Consider biting the bullet and making two separate cases
// for list and set, avoiding this post facto conversion.
if (type == T_SET) {
PyObject* setret;
-#if (PY_VERSION_HEX < 0x02050000)
- // hack needed for older versions
- setret = PyObject_CallFunctionObjArgs((PyObject*)&PySet_Type, ret, NULL);
-#else
- // official version
- setret = PySet_New(ret);
-#endif
+ setret = parsedargs.immutable ? PyFrozenSet_New(ret) : PySet_New(ret);
Py_DECREF(ret);
return setret;
}
@@ -1131,6 +1160,22 @@
goto error;
}
+    if (parsedargs.immutable) {
+      // Wrap the decoded dict in thrift.Thrift.TFrozenDict. Every new
+      // reference taken here is released before leaving the branch.
+      PyObject* thrift = PyImport_ImportModule("thrift.Thrift");
+      PyObject* cls = NULL;
+      PyObject* arg = NULL;
+      PyObject* frozen = NULL;
+      if (!thrift) {
+        goto error;
+      }
+      cls = PyObject_GetAttrString(thrift, "TFrozenDict");
+      // The module is only needed for the attribute lookup.
+      Py_DECREF(thrift);
+      if (!cls) {
+        goto error;
+      }
+      arg = PyTuple_New(1);
+      if (!arg) {
+        Py_DECREF(cls);
+        goto error;
+      }
+      // PyTuple_SET_ITEM steals the reference to ret, so dropping arg
+      // below also releases ret.
+      PyTuple_SET_ITEM(arg, 0, ret);
+      frozen = PyObject_CallObject(cls, arg);
+      Py_DECREF(cls);
+      Py_DECREF(arg);
+      return frozen;
+    }
+
return ret;
error:
@@ -1140,22 +1185,12 @@
case T_STRUCT: {
StructTypeArgs parsedargs;
- PyObject* ret;
+ PyObject* ret;
if (!parse_struct_args(&parsedargs, typeargs)) {
return NULL;
}
- ret = PyObject_CallObject(parsedargs.klass, NULL);
- if (!ret) {
- return NULL;
- }
-
- if (!decode_struct(input, ret, parsedargs.spec)) {
- Py_DECREF(ret);
- return NULL;
- }
-
- return ret;
+ return decode_struct(input, Py_None, parsedargs.klass, parsedargs.spec);
}
case T_STOP:
@@ -1179,6 +1214,7 @@
PyObject* typeargs = NULL;
StructTypeArgs parsedargs;
DecodeBuffer input = {0, 0};
+ PyObject* ret = NULL;
if (!PyArg_ParseTuple(args, "OOO", &output_obj, &transport, &typeargs)) {
return NULL;
@@ -1192,14 +1228,9 @@
return NULL;
}
- if (!decode_struct(&input, output_obj, parsedargs.spec)) {
- free_decodebuf(&input);
- return NULL;
- }
-
+ ret = decode_struct(&input, output_obj, parsedargs.klass, parsedargs.spec);
free_decodebuf(&input);
-
- Py_RETURN_NONE;
+ return ret;
}
/* ====== END READING FUNCTIONS ====== */