THRIFT-162 Thrift structures are unhashable, preventing them from being used as set elements
Client: Python
Patch: David Reiss, Nobuaki Sukegawa
This closes #714
diff --git a/compiler/cpp/src/generate/t_py_generator.cc b/compiler/cpp/src/generate/t_py_generator.cc
index 44816ab..49c0b57 100644
--- a/compiler/cpp/src/generate/t_py_generator.cc
+++ b/compiler/cpp/src/generate/t_py_generator.cc
@@ -65,8 +65,9 @@
if (gen_dynamic_) {
gen_newstyle_ = 0; // dynamic is newstyle
gen_dynbaseclass_ = "TBase";
+ gen_dynbaseclass_frozen_ = "TFrozenBase";
gen_dynbaseclass_exc_ = "TExceptionBase";
- import_dynbase_ = "from thrift.protocol.TBase import TBase, TExceptionBase, TTransport\n";
+ import_dynbase_ = "from thrift.protocol.TBase import TBase, TFrozenBase, TExceptionBase, TTransport\n";
}
iter = parsed_options.find("dynbase");
@@ -75,6 +76,11 @@
gen_dynbaseclass_ = (iter->second);
}
+ iter = parsed_options.find("dynfrozen");
+ if (iter != parsed_options.end()) {
+ gen_dynbaseclass_frozen_ = (iter->second);
+ }
+
iter = parsed_options.find("dynexc");
if (iter != parsed_options.end()) {
gen_dynbaseclass_exc_ = (iter->second);
@@ -142,8 +148,7 @@
void generate_py_struct(t_struct* tstruct, bool is_exception);
void generate_py_struct_definition(std::ofstream& out,
t_struct* tstruct,
- bool is_xception = false,
- bool is_result = false);
+ bool is_xception = false);
void generate_py_struct_reader(std::ofstream& out, t_struct* tstruct);
void generate_py_struct_writer(std::ofstream& out, t_struct* tstruct);
void generate_py_struct_required_validator(std::ofstream& out, t_struct* tstruct);
@@ -166,8 +171,7 @@
void generate_deserialize_field(std::ofstream& out,
t_field* tfield,
- std::string prefix = "",
- bool inclass = false);
+ std::string prefix = "");
void generate_deserialize_struct(std::ofstream& out, t_struct* tstruct, std::string prefix = "");
@@ -244,7 +248,12 @@
return real_module;
}
+ static bool is_immutable(t_type* ttype) {
+ return ttype->annotations_.find("python.immutable") != ttype->annotations_.end();
+ }
+
private:
+
/**
* True if we should generate new-style classes.
*/
@@ -257,6 +266,7 @@
bool gen_dynbase_;
std::string gen_dynbaseclass_;
+ std::string gen_dynbaseclass_frozen_;
std::string gen_dynbaseclass_exc_;
std::string import_dynbase_;
@@ -353,14 +363,12 @@
py_autogen_comment() << endl <<
py_imports() << endl <<
render_includes() << endl <<
- render_fastbinary_includes() <<
- endl << endl;
+ render_fastbinary_includes();
f_consts_ <<
py_autogen_comment() << endl <<
py_imports() << endl <<
- "from .ttypes import *" << endl <<
- endl;
+ "from .ttypes import *" << endl;
}
/**
@@ -372,9 +380,6 @@
for (size_t i = 0; i < includes.size(); ++i) {
result += "import " + get_real_py_module(includes[i], gen_twisted_) + ".ttypes\n";
}
- if (includes.size() > 0) {
- result += "\n";
- }
return result;
}
@@ -413,7 +418,7 @@
* Prints standard thrift imports
*/
string t_py_generator::py_imports() {
- return string("from thrift.Thrift import TType, TMessageType, TException, TApplicationException");
+ return string("from thrift.Thrift import TType, TMessageType, TFrozenDict, TException, TApplicationException");
}
/**
@@ -443,7 +448,7 @@
void t_py_generator::generate_enum(t_enum* tenum) {
std::ostringstream to_string_mapping, from_string_mapping;
- f_types_ << "class " << tenum->get_name() << (gen_newstyle_ ? "(object)" : "")
+ f_types_ << endl << endl << "class " << tenum->get_name() << (gen_newstyle_ ? "(object)" : "")
<< (gen_dynamic_ ? "(" + gen_dynbaseclass_ + ")" : "") << ":" << endl;
indent_up();
generate_python_docstring(f_types_, tenum);
@@ -468,7 +473,7 @@
indent_down();
f_types_ << endl;
- f_types_ << to_string_mapping.str() << endl << from_string_mapping.str() << endl;
+ f_types_ << to_string_mapping.str() << endl << from_string_mapping.str();
}
/**
@@ -547,6 +552,9 @@
} else if (type->is_map()) {
t_type* ktype = ((t_map*)type)->get_key_type();
t_type* vtype = ((t_map*)type)->get_val_type();
+ if (is_immutable(type)) {
+ out << "TFrozenDict(";
+ }
out << "{" << endl;
indent_up();
const map<t_const_value*, t_const_value*>& val = value->get_map();
@@ -560,6 +568,9 @@
}
indent_down();
indent(out) << "}";
+ if (is_immutable(type)) {
+ out << ")";
+ }
} else if (type->is_list() || type->is_set()) {
t_type* etype;
if (type->is_list()) {
@@ -568,9 +579,16 @@
etype = ((t_set*)type)->get_elem_type();
}
if (type->is_set()) {
+ if (is_immutable(type)) {
+ out << "frozen";
+ }
out << "set(";
}
- out << "[" << endl;
+ if (is_immutable(type) || type->is_set()) {
+ out << "(" << endl;
+ } else {
+ out << "[" << endl;
+ }
indent_up();
const vector<t_const_value*>& val = value->get_list();
vector<t_const_value*>::const_iterator v_iter;
@@ -580,7 +598,11 @@
out << "," << endl;
}
indent_down();
- indent(out) << "]";
+ if (is_immutable(type) || type->is_set()) {
+ indent(out) << ")";
+ } else {
+ indent(out) << "]";
+ }
if (type->is_set()) {
out << ")";
}
@@ -622,26 +644,26 @@
*/
void t_py_generator::generate_py_struct_definition(ofstream& out,
t_struct* tstruct,
- bool is_exception,
- bool is_result) {
- (void)is_result;
+ bool is_exception) {
const vector<t_field*>& members = tstruct->get_members();
const vector<t_field*>& sorted_members = tstruct->get_sorted_members();
vector<t_field*>::const_iterator m_iter;
- out << std::endl << "class " << tstruct->get_name();
+ out << endl << endl << "class " << tstruct->get_name();
if (is_exception) {
if (gen_dynamic_) {
out << "(" << gen_dynbaseclass_exc_ << ")";
} else {
out << "(TException)";
}
- } else {
- if (gen_newstyle_) {
- out << "(object)";
- } else if (gen_dynamic_) {
+ } else if (gen_dynamic_) {
+ if (is_immutable(tstruct)) {
+ out << "(" << gen_dynbaseclass_frozen_ << ")";
+ } else {
out << "(" << gen_dynbaseclass_ << ")";
}
+ } else if (gen_newstyle_) {
+ out << "(object)";
}
out << ":" << endl;
indent_up();
@@ -670,13 +692,13 @@
*/
if (gen_slots_) {
- indent(out) << "__slots__ = [ " << endl;
+ indent(out) << "__slots__ = (" << endl;
indent_up();
for (m_iter = sorted_members.begin(); m_iter != sorted_members.end(); ++m_iter) {
indent(out) << "'" << (*m_iter)->get_name() << "'," << endl;
}
indent_down();
- indent(out) << " ]" << endl << endl;
+ indent(out) << ")" << endl << endl;
}
// TODO(dreiss): Look into generating an empty tuple instead of None
@@ -691,7 +713,7 @@
for (m_iter = sorted_members.begin(); m_iter != sorted_members.end(); ++m_iter) {
for (; sorted_keys_pos != (*m_iter)->get_key(); sorted_keys_pos++) {
- indent(out) << "None, # " << sorted_keys_pos << endl;
+ indent(out) << "None, # " << sorted_keys_pos << endl;
}
indent(out) << "(" << (*m_iter)->get_key() << ", " << type_to_enum((*m_iter)->get_type())
@@ -700,16 +722,17 @@
<< ", " << type_to_spec_args((*m_iter)->get_type()) << ", "
<< render_field_default_value(*m_iter) << ", "
<< "),"
- << " # " << sorted_keys_pos << endl;
+ << " # " << sorted_keys_pos << endl;
sorted_keys_pos++;
}
indent_down();
- indent(out) << ")" << endl << endl;
+ indent(out) << ")" << endl;
} else {
indent(out) << "thrift_spec = None" << endl;
}
+ out << endl;
if (members.size() > 0) {
out << indent() << "def __init__(self,";
@@ -729,10 +752,23 @@
if (!type->is_base_type() && !type->is_enum() && (*m_iter)->get_value() != NULL) {
indent(out) << "if " << (*m_iter)->get_name() << " is "
<< "self.thrift_spec[" << (*m_iter)->get_key() << "][4]:" << endl;
- indent(out) << " " << (*m_iter)->get_name() << " = " << render_field_default_value(*m_iter)
+ indent_up();
+ indent(out) << (*m_iter)->get_name() << " = " << render_field_default_value(*m_iter)
<< endl;
+ indent_down();
}
- indent(out) << "self." << (*m_iter)->get_name() << " = " << (*m_iter)->get_name() << endl;
+
+ if (is_immutable(tstruct)) {
+ if (gen_newstyle_ || gen_dynamic_) {
+ indent(out) << "super(" << tstruct->get_name() << ", self).__setattr__('"
+ << (*m_iter)->get_name() << "', " << (*m_iter)->get_name() << ")" << endl;
+ } else {
+ indent(out) << "self.__dict__['" << (*m_iter)->get_name()
+ << "'] = " << (*m_iter)->get_name() << endl;
+ }
+ } else {
+ indent(out) << "self." << (*m_iter)->get_name() << " = " << (*m_iter)->get_name() << endl;
+ }
}
indent_down();
@@ -740,6 +776,26 @@
out << endl;
}
+ if (is_immutable(tstruct)) {
+ out << indent() << "def __setattr__(self, *args):" << endl
+ << indent() << " raise TypeError(\"can't modify immutable instance\")" << endl
+ << endl;
+ out << indent() << "def __delattr__(self, *args):" << endl
+ << indent() << " raise TypeError(\"can't modify immutable instance\")" << endl
+ << endl;
+
+ // Hash all of the members in order, and also hash in the class
+ // to avoid collisions for stuff like single-field structures.
+ out << indent() << "def __hash__(self):" << endl
+ << indent() << " return hash(self.__class__) ^ hash((";
+
+ for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
+ out << "self." << (*m_iter)->get_name() << ", ";
+ }
+
+ out << "))" << endl << endl;
+ }
+
if (!gen_dynamic_) {
generate_py_struct_reader(out, tstruct);
generate_py_struct_writer(out, tstruct);
@@ -759,7 +815,7 @@
out <<
indent() << "def __repr__(self):" << endl <<
indent() << " L = ['%s=%r' % (key, value)" << endl <<
- indent() << " for key, value in self.__dict__.items()]" << endl <<
+ indent() << " for key, value in self.__dict__.items()]" << endl <<
indent() << " return '%s(%s)' % (self.__class__.__name__, ', '.join(L))" << endl <<
endl;
@@ -794,7 +850,7 @@
<< indent() << " return True" << endl << endl;
out << indent() << "def __ne__(self, other):" << endl << indent()
- << " return not (self == other)" << endl << endl;
+ << " return not (self == other)" << endl;
}
indent_down();
}
@@ -806,18 +862,30 @@
const vector<t_field*>& fields = tstruct->get_members();
vector<t_field*>::const_iterator f_iter;
- indent(out) << "def read(self, iprot):" << endl;
+ if (is_immutable(tstruct)) {
+ out << indent() << "@classmethod" << endl << indent() << "def read(cls, iprot):" << endl;
+ } else {
+ indent(out) << "def read(self, iprot):" << endl;
+ }
indent_up();
+ const char* id = is_immutable(tstruct) ? "cls" : "self";
+
indent(out) << "if iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated "
"and isinstance(iprot.trans, TTransport.CReadableTransport) "
- "and self.thrift_spec is not None "
+ "and " << id << ".thrift_spec is not None "
"and fastbinary is not None:" << endl;
indent_up();
- indent(out) << "fastbinary.decode_binary(self, iprot.trans, (self.__class__, self.thrift_spec))"
- << endl;
- indent(out) << "return" << endl;
+ if (is_immutable(tstruct)) {
+ indent(out)
+ << "return fastbinary.decode_binary(None, iprot.trans, (cls, cls.thrift_spec))"
+ << endl;
+ } else {
+ indent(out) << "fastbinary.decode_binary(self, iprot.trans, (self.__class__, self.thrift_spec))"
+ << endl;
+ indent(out) << "return" << endl;
+ }
indent_down();
indent(out) << "iprot.readStructBegin()" << endl;
@@ -850,7 +918,11 @@
indent_up();
indent(out) << "if ftype == " << type_to_enum((*f_iter)->get_type()) << ":" << endl;
indent_up();
- generate_deserialize_field(out, *f_iter, "self.");
+ if (is_immutable(tstruct)) {
+ generate_deserialize_field(out, *f_iter);
+ } else {
+ generate_deserialize_field(out, *f_iter, "self.");
+ }
indent_down();
out << indent() << "else:" << endl << indent() << " iprot.skip(ftype)" << endl;
indent_down();
@@ -866,6 +938,16 @@
indent(out) << "iprot.readStructEnd()" << endl;
+ if (is_immutable(tstruct)) {
+ indent(out) << "return cls(" << endl;
+ indent_up();
+ for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
+ indent(out) << (*f_iter)->get_name() << "=" << (*f_iter)->get_name() << "," << endl;
+ }
+ indent_down();
+ indent(out) << ")" << endl;
+ }
+
indent_down();
out << endl;
}
@@ -916,7 +998,6 @@
indent_down();
generate_py_struct_required_validator(out, tstruct);
- out << endl;
}
void t_py_generator::generate_py_struct_required_validator(ofstream& out, t_struct* tstruct) {
@@ -962,7 +1043,7 @@
f_service_ << "import logging" << endl
<< "from .ttypes import *" << endl
<< "from thrift.Thrift import TProcessor" << endl
- << render_fastbinary_includes() << endl;
+ << render_fastbinary_includes();
if (gen_twisted_) {
f_service_ << "from zope.interface import Interface, implements" << endl
@@ -974,8 +1055,6 @@
f_service_ << "from thrift.transport import TTransport" << endl;
}
- f_service_ << endl;
-
// Generate the three main parts of the service
generate_service_interface(tservice);
generate_service_client(tservice);
@@ -996,7 +1075,7 @@
vector<t_function*> functions = tservice->get_functions();
vector<t_function*>::iterator f_iter;
- f_service_ << "# HELPER FUNCTIONS AND STRUCTURES" << endl;
+ f_service_ << endl << "# HELPER FUNCTIONS AND STRUCTURES" << endl;
for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) {
t_struct* ts = (*f_iter)->get_arglist();
@@ -1024,7 +1103,7 @@
for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
result.append(*f_iter);
}
- generate_py_struct_definition(f_service_, &result, false, true);
+ generate_py_struct_definition(f_service_, &result, false);
}
}
@@ -1047,7 +1126,7 @@
}
}
- f_service_ << "class Iface" << extends_if << ":" << endl;
+ f_service_ << endl << endl << "class Iface" << extends_if << ":" << endl;
indent_up();
generate_python_docstring(f_service_, tservice);
vector<t_function*> functions = tservice->get_functions();
@@ -1055,17 +1134,22 @@
f_service_ << indent() << "pass" << endl;
} else {
vector<t_function*>::iterator f_iter;
+ bool first = true;
for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) {
+ if (first) {
+ first = false;
+ } else {
+ f_service_ << endl;
+ }
f_service_ << indent() << "def " << function_signature(*f_iter, true) << ":" << endl;
indent_up();
generate_python_docstring(f_service_, (*f_iter));
- f_service_ << indent() << "pass" << endl << endl;
+ f_service_ << indent() << "pass" << endl;
indent_down();
}
}
indent_down();
- f_service_ << endl;
}
/**
@@ -1089,6 +1173,8 @@
}
}
+ f_service_ << endl << endl;
+
if (gen_twisted_) {
f_service_ << "class Client" << extends_client << ":" << endl << " implements(Iface)" << endl
<< endl;
@@ -1111,7 +1197,7 @@
if (gen_twisted_) {
f_service_ << indent() << " self._transport = transport" << endl << indent()
<< " self._oprot_factory = oprot_factory" << endl << indent()
- << " self._seqid = 0" << endl << indent() << " self._reqs = {}" << endl << endl;
+ << " self._seqid = 0" << endl << indent() << " self._reqs = {}" << endl;
} else if (gen_tornado_) {
f_service_ << indent() << " self._transport = transport" << endl << indent()
<< " self._iprot_factory = iprot_factory" << endl << indent()
@@ -1119,28 +1205,26 @@
<< indent() << " else iprot_factory)" << endl << indent()
<< " self._seqid = 0" << endl << indent() << " self._reqs = {}" << endl
<< indent() << " self._transport.io_loop.spawn_callback(self._start_receiving)"
- << endl << endl;
+ << endl;
} else {
f_service_ << indent() << " self._iprot = self._oprot = iprot" << endl << indent()
<< " if oprot is not None:" << endl << indent() << " self._oprot = oprot"
- << endl << indent() << " self._seqid = 0" << endl << endl;
+ << endl << indent() << " self._seqid = 0" << endl;
}
} else {
if (gen_twisted_) {
f_service_ << indent() << " " << extends
- << ".Client.__init__(self, transport, oprot_factory)" << endl << endl;
+ << ".Client.__init__(self, transport, oprot_factory)" << endl;
} else if (gen_tornado_) {
f_service_ << indent() << " " << extends
- << ".Client.__init__(self, transport, iprot_factory, oprot_factory)" << endl
- << endl;
+ << ".Client.__init__(self, transport, iprot_factory, oprot_factory)" << endl;
} else {
- f_service_ << indent() << " " << extends << ".Client.__init__(self, iprot, oprot)" << endl
- << endl;
+ f_service_ << indent() << " " << extends << ".Client.__init__(self, iprot, oprot)" << endl;
}
}
if (gen_tornado_ && extends.empty()) {
- f_service_ <<
+ f_service_ << endl <<
indent() << "@gen.engine" << endl <<
indent() << "def _start_receiving(self):" << endl <<
indent() << " while True:" << endl <<
@@ -1164,8 +1248,7 @@
indent() << " except Exception as e:" << endl <<
indent() << " future.set_exception(e)" << endl <<
indent() << " else:" << endl <<
- indent() << " future.set_result(result)" << endl <<
- endl;
+ indent() << " future.set_result(result)" << endl;
}
// Generate client method implementations
@@ -1177,6 +1260,7 @@
vector<t_field*>::const_iterator fld_iter;
string funname = (*f_iter)->get_name();
+ f_service_ << endl;
// Open function
indent(f_service_) << "def " << function_signature(*f_iter, false) << ":" << endl;
indent_up();
@@ -1386,12 +1470,10 @@
// Close function
indent_down();
- f_service_ << endl;
}
}
indent_down();
- f_service_ << endl;
}
/**
@@ -1560,6 +1642,8 @@
extends_processor = extends + ".Processor, ";
}
+ f_service_ << endl << endl;
+
// Generate the header portion
if (gen_twisted_) {
f_service_ << "class Processor(" << extends_processor << "TProcessor):" << endl
@@ -1630,15 +1714,14 @@
}
indent_down();
- f_service_ << endl;
// Generate the process subfunctions
for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) {
+ f_service_ << endl;
generate_process_function(tservice, *f_iter);
}
indent_down();
- f_service_ << endl;
}
/**
@@ -1860,7 +1943,10 @@
}
f_service_ << "args." << (*f_iter)->get_name();
}
- f_service_ << ")" << endl << indent() << "msg_type = TMessageType.REPLY" << endl;
+ f_service_ << ")" << endl;
+ if (!tfunction->is_oneway()) {
+ f_service_ << indent() << "msg_type = TMessageType.REPLY" << endl;
+ }
indent_down();
f_service_ << indent()
@@ -1900,7 +1986,6 @@
// Close function
indent_down();
- f_service_ << endl;
}
}
@@ -1909,9 +1994,7 @@
*/
void t_py_generator::generate_deserialize_field(ofstream& out,
t_field* tfield,
- string prefix,
- bool inclass) {
- (void)inclass;
+ string prefix) {
t_type* type = get_true_type(tfield->get_type());
if (type->is_void()) {
@@ -1932,7 +2015,6 @@
switch (tbase) {
case t_base_type::TYPE_VOID:
throw "compiler error: cannot serialize void field in a struct: " + name;
- break;
case t_base_type::TYPE_STRING:
if (((t_base_type*)type)->is_binary()) {
out << "readBinary()";
@@ -1979,8 +2061,12 @@
* Generates an unserializer for a struct, calling read()
*/
void t_py_generator::generate_deserialize_struct(ofstream& out, t_struct* tstruct, string prefix) {
- out << indent() << prefix << " = " << type_name(tstruct) << "()" << endl << indent() << prefix
- << ".read(iprot)" << endl;
+ if (is_immutable(tstruct)) {
+ out << indent() << prefix << " = " << type_name(tstruct) << ".read(iprot)" << endl;
+ } else {
+ out << indent() << prefix << " = " << type_name(tstruct) << "()" << endl
+ << indent() << prefix << ".read(iprot)" << endl;
+ }
}
/**
@@ -2030,9 +2116,18 @@
// Read container end
if (ttype->is_map()) {
indent(out) << "iprot.readMapEnd()" << endl;
+ if (is_immutable(ttype)) {
+ indent(out) << prefix << " = TFrozenDict(" << prefix << ")" << endl;
+ }
} else if (ttype->is_set()) {
indent(out) << "iprot.readSetEnd()" << endl;
+ if (is_immutable(ttype)) {
+ indent(out) << prefix << " = frozenset(" << prefix << ")" << endl;
+ }
} else if (ttype->is_list()) {
+ if (is_immutable(ttype)) {
+ indent(out) << prefix << " = tuple(" << prefix << ")" << endl;
+ }
indent(out) << "iprot.readListEnd()" << endl;
}
}
@@ -2178,7 +2273,7 @@
if (ttype->is_map()) {
string kiter = tmp("kiter");
string viter = tmp("viter");
- indent(out) << "for " << kiter << "," << viter << " in " << prefix << ".items():" << endl;
+ indent(out) << "for " << kiter << ", " << viter << " in " << prefix << ".items():" << endl;
indent_up();
generate_serialize_map_element(out, (t_map*)ttype, kiter, viter);
indent_down();
@@ -2456,18 +2551,21 @@
} else if (ttype->is_struct() || ttype->is_xception()) {
return "(" + type_name(ttype) + ", " + type_name(ttype) + ".thrift_spec)";
} else if (ttype->is_map()) {
- return "(" + type_to_enum(((t_map*)ttype)->get_key_type()) + ","
- + type_to_spec_args(((t_map*)ttype)->get_key_type()) + ","
- + type_to_enum(((t_map*)ttype)->get_val_type()) + ","
- + type_to_spec_args(((t_map*)ttype)->get_val_type()) + ")";
+ return "(" + type_to_enum(((t_map*)ttype)->get_key_type()) + ", "
+ + type_to_spec_args(((t_map*)ttype)->get_key_type()) + ", "
+ + type_to_enum(((t_map*)ttype)->get_val_type()) + ", "
+ + type_to_spec_args(((t_map*)ttype)->get_val_type()) + ", "
+ + (is_immutable(ttype) ? "True" : "False") + ")";
} else if (ttype->is_set()) {
- return "(" + type_to_enum(((t_set*)ttype)->get_elem_type()) + ","
- + type_to_spec_args(((t_set*)ttype)->get_elem_type()) + ")";
+ return "(" + type_to_enum(((t_set*)ttype)->get_elem_type()) + ", "
+ + type_to_spec_args(((t_set*)ttype)->get_elem_type()) + ", "
+ + (is_immutable(ttype) ? "True" : "False") + ")";
} else if (ttype->is_list()) {
- return "(" + type_to_enum(((t_list*)ttype)->get_elem_type()) + ","
- + type_to_spec_args(((t_list*)ttype)->get_elem_type()) + ")";
+ return "(" + type_to_enum(((t_list*)ttype)->get_elem_type()) + ", "
+ + type_to_spec_args(((t_list*)ttype)->get_elem_type()) + ", "
+ + (is_immutable(ttype) ? "True" : "False") + ")";
}
throw "INVALID TYPE IN type_to_spec_args: " + ttype->get_name();
@@ -2484,6 +2582,7 @@
" slots: Generate code using slots for instance members.\n"
" dynamic: Generate dynamic code, less code generated but slower.\n"
" dynbase=CLS Derive generated classes from class CLS instead of TBase.\n"
+ " dynfrozen=CLS Derive generated immutable classes from class CLS instead of TFrozenBase.\n"
" dynexc=CLS Derive generated exceptions from CLS instead of TExceptionBase.\n"
" dynimport='from foo.bar import CLS'\n"
" Add an import line to generated code to find the dynbase class.\n")
diff --git a/lib/py/src/Thrift.py b/lib/py/src/Thrift.py
index 9890af7..cbb9184 100644
--- a/lib/py/src/Thrift.py
+++ b/lib/py/src/Thrift.py
@@ -168,3 +168,23 @@
oprot.writeFieldEnd()
oprot.writeFieldStop()
oprot.writeStructEnd()
+
+
+class TFrozenDict(dict):
+ """A dictionary that is "frozen" like a frozenset"""
+
+ def __init__(self, *args, **kwargs):
+ super(TFrozenDict, self).__init__(*args, **kwargs)
+ # Sort the items so they will be in a consistent order.
+ # XOR in the hash of the class so we don't collide with
+ # the hash of a list of tuples.
+ self.__hashval = hash(TFrozenDict) ^ hash(tuple(sorted(self.items())))
+
+ def __setitem__(self, *args):
+ raise TypeError("Can't modify frozen TFreezableDict")
+
+ def __delitem__(self, *args):
+ raise TypeError("Can't modify frozen TFreezableDict")
+
+ def __hash__(self):
+ return self.__hashval
diff --git a/lib/py/src/protocol/TBase.py b/lib/py/src/protocol/TBase.py
index 118a679..4f71e11 100644
--- a/lib/py/src/protocol/TBase.py
+++ b/lib/py/src/protocol/TBase.py
@@ -27,7 +27,7 @@
class TBase(object):
- __slots__ = []
+ __slots__ = ()
def __repr__(self):
L = ['%s=%r' % (key, getattr(self, key)) for key in self.__slots__]
@@ -68,4 +68,27 @@
class TExceptionBase(TBase, Exception):
- __slots__ = []
+ pass
+
+
+class TFrozenBase(TBase):
+ def __setitem__(self, *args):
+ raise TypeError("Can't modify frozen struct")
+
+ def __delitem__(self, *args):
+ raise TypeError("Can't modify frozen struct")
+
+ def __hash__(self, *args):
+ return hash(self.__class__) ^ hash(self.__slots__)
+
+ @classmethod
+ def read(cls, iprot):
+ if (iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and
+ isinstance(iprot.trans, TTransport.CReadableTransport) and
+ cls.thrift_spec is not None and
+ fastbinary is not None):
+ self = cls()
+ return fastbinary.decode_binary(None,
+ iprot.trans,
+ (self.__class__, self.thrift_spec))
+ return iprot.readStruct(cls, cls.thrift_spec, True)
diff --git a/lib/py/src/protocol/TProtocol.py b/lib/py/src/protocol/TProtocol.py
index ca22c48..1d703e3 100644
--- a/lib/py/src/protocol/TProtocol.py
+++ b/lib/py/src/protocol/TProtocol.py
@@ -17,7 +17,7 @@
# under the License.
#
-from thrift.Thrift import TException, TType
+from thrift.Thrift import TException, TType, TFrozenDict
import six
from ..compat import binary_to_str, str_to_binary
@@ -108,9 +108,6 @@
def writeBinary(self, str_val):
pass
- def writeBinary(self, str_val):
- return self.writeString(str_val)
-
def readMessageBegin(self):
pass
@@ -171,9 +168,6 @@
def readBinary(self):
pass
- def readBinary(self):
- return self.readString()
-
def skip(self, ttype):
if ttype == TType.STOP:
return
@@ -264,6 +258,7 @@
def readContainerList(self, spec):
results = []
ttype, tspec = spec[0], spec[1]
+ is_immutable = spec[2]
r_handler = self._ttype_handlers(ttype, spec)[0]
reader = getattr(self, r_handler)
(list_type, list_len) = self.readListBegin()
@@ -279,11 +274,12 @@
val = val_reader(tspec)
results.append(val)
self.readListEnd()
- return results
+ return tuple(results) if is_immutable else results
def readContainerSet(self, spec):
results = set()
ttype, tspec = spec[0], spec[1]
+ is_immutable = spec[2]
r_handler = self._ttype_handlers(ttype, spec)[0]
reader = getattr(self, r_handler)
(set_type, set_len) = self.readSetBegin()
@@ -297,7 +293,7 @@
for idx in range(set_len):
results.add(val_reader(tspec))
self.readSetEnd()
- return results
+ return frozenset(results) if is_immutable else results
def readContainerStruct(self, spec):
(obj_class, obj_spec) = spec
@@ -309,6 +305,7 @@
results = dict()
key_ttype, key_spec = spec[0], spec[1]
val_ttype, val_spec = spec[2], spec[3]
+ is_immutable = spec[4]
(map_ktype, map_vtype, map_len) = self.readMapBegin()
# TODO: compare types we just decoded with thrift_spec and
# abort/skip if types disagree
@@ -328,9 +325,11 @@
# i.e. this fails: d=dict(); d[[0,1]] = 2
results[k_val] = v_val
self.readMapEnd()
- return results
+ return TFrozenDict(results) if is_immutable else results
- def readStruct(self, obj, thrift_spec):
+ def readStruct(self, obj, thrift_spec, is_immutable=False):
+ if is_immutable:
+ fields = {}
self.readStructBegin()
while True:
(fname, ftype, fid) = self.readFieldBegin()
@@ -345,11 +344,16 @@
fname = field[2]
fspec = field[3]
val = self.readFieldByTType(ftype, fspec)
- setattr(obj, fname, val)
+ if is_immutable:
+ fields[fname] = val
+ else:
+ setattr(obj, fname, val)
else:
self.skip(ftype)
self.readFieldEnd()
self.readStructEnd()
+ if is_immutable:
+ return obj(**fields)
def writeContainerStruct(self, val, spec):
val.write(self)
diff --git a/lib/py/src/protocol/fastbinary.c b/lib/py/src/protocol/fastbinary.c
index 93c4911..a17019b 100644
--- a/lib/py/src/protocol/fastbinary.c
+++ b/lib/py/src/protocol/fastbinary.c
@@ -124,11 +124,6 @@
#define INT_CONV_ERROR_OCCURRED(v) ( ((v) == -1) && PyErr_Occurred() )
#define CHECK_RANGE(v, min, max) ( ((v) <= (max)) && ((v) >= (min)) )
-// Py_ssize_t was not defined before Python 2.5
-#if (PY_VERSION_HEX < 0x02050000)
-typedef int Py_ssize_t;
-#endif
-
/**
* A cache of the spec_args for a set or list,
* so we don't have to keep calling PyTuple_GET_ITEM.
@@ -136,6 +131,7 @@
typedef struct {
TType element_type;
PyObject* typeargs;
+ bool immutable;
} SetListTypeArgs;
/**
@@ -147,6 +143,7 @@
TType vtag;
PyObject* ktypeargs;
PyObject* vtypeargs;
+ bool immutable;
} MapTypeArgs;
/**
@@ -156,6 +153,7 @@
typedef struct {
PyObject* klass;
PyObject* spec;
+ bool immutable;
} StructTypeArgs;
/**
@@ -233,8 +231,8 @@
static bool
parse_set_list_args(SetListTypeArgs* dest, PyObject* typeargs) {
- if (PyTuple_Size(typeargs) != 2) {
- PyErr_SetString(PyExc_TypeError, "expecting tuple of size 2 for list/set type args");
+ if (PyTuple_Size(typeargs) != 3) {
+ PyErr_SetString(PyExc_TypeError, "expecting tuple of size 3 for list/set type args");
return false;
}
@@ -245,13 +243,15 @@
dest->typeargs = PyTuple_GET_ITEM(typeargs, 1);
+ dest->immutable = Py_True == PyTuple_GET_ITEM(typeargs, 2);
+
return true;
}
static bool
parse_map_args(MapTypeArgs* dest, PyObject* typeargs) {
- if (PyTuple_Size(typeargs) != 4) {
- PyErr_SetString(PyExc_TypeError, "expecting 4 arguments for typeargs to map");
+ if (PyTuple_Size(typeargs) != 5) {
+ PyErr_SetString(PyExc_TypeError, "expecting 5 arguments for typeargs to map");
return false;
}
@@ -267,6 +267,7 @@
dest->ktypeargs = PyTuple_GET_ITEM(typeargs, 1);
dest->vtypeargs = PyTuple_GET_ITEM(typeargs, 3);
+ dest->immutable = Py_True == PyTuple_GET_ITEM(typeargs, 4);
return true;
}
@@ -289,7 +290,7 @@
// i'd like to use ParseArgs here, but it seems to be a bottleneck.
if (PyTuple_Size(spec_tuple) != 5) {
- PyErr_SetString(PyExc_TypeError, "expecting 5 arguments for spec tuple");
+ PyErr_Format(PyExc_TypeError, "expecting 5 arguments for spec tuple but got %d", PyTuple_Size(spec_tuple));
return false;
}
@@ -885,11 +886,21 @@
static PyObject*
decode_val(DecodeBuffer* input, TType type, PyObject* typeargs);
-static bool
-decode_struct(DecodeBuffer* input, PyObject* output, PyObject* spec_seq) {
+static PyObject*
+decode_struct(DecodeBuffer* input, PyObject* output, PyObject* klass, PyObject* spec_seq) {
int spec_seq_len = PyTuple_Size(spec_seq);
+ bool immutable = output == Py_None;
+ PyObject* kwargs = NULL;
if (spec_seq_len == -1) {
- return false;
+ return NULL;
+ }
+
+ if (immutable) {
+ kwargs = PyDict_New();
+ if (!kwargs) {
+ PyErr_SetString(PyExc_TypeError, "failed to prepare kwargument storage");
+ return NULL;
+ }
}
while (true) {
@@ -901,14 +912,14 @@
type = readByte(input);
if (type == -1) {
- return false;
+ goto error;
}
if (type == T_STOP) {
break;
}
tag = readI16(input);
if (INT_CONV_ERROR_OCCURRED(tag)) {
- return false;
+ goto error;
}
if (tag >= 0 && tag < spec_seq_len) {
item_spec = PyTuple_GET_ITEM(spec_seq, tag);
@@ -918,19 +929,19 @@
if (item_spec == Py_None) {
if (!skip(input, type)) {
- return false;
+ goto error;
} else {
continue;
}
}
if (!parse_struct_item_spec(&parsedspec, item_spec)) {
- return false;
+ goto error;
}
if (parsedspec.type != type) {
if (!skip(input, type)) {
PyErr_SetString(PyExc_TypeError, "struct field had wrong type while reading and can't be skipped");
- return false;
+ goto error;
} else {
continue;
}
@@ -938,16 +949,34 @@
fieldval = decode_val(input, parsedspec.type, parsedspec.typeargs);
if (fieldval == NULL) {
- return false;
+ goto error;
}
- if (PyObject_SetAttr(output, parsedspec.attrname, fieldval) == -1) {
+ if ((immutable && PyDict_SetItem(kwargs, parsedspec.attrname, fieldval) == -1)
+ || (!immutable && PyObject_SetAttr(output, parsedspec.attrname, fieldval) == -1)) {
Py_DECREF(fieldval);
- return false;
+ goto error;
}
Py_DECREF(fieldval);
}
- return true;
+ if (immutable) {
+ PyObject* args = PyTuple_New(0);
+ PyObject* ret = NULL;
+ if (!args) {
+ PyErr_SetString(PyExc_TypeError, "failed to prepare argument storage");
+ goto error;
+ }
+ ret = PyObject_Call(klass, args, kwargs);
+ Py_DECREF(kwargs);
+ Py_DECREF(args);
+ return ret;
+ }
+ Py_INCREF(output);
+ return output;
+
+ error:
+ Py_XDECREF(kwargs);
+ return NULL;
}
@@ -1033,6 +1062,7 @@
int32_t len;
PyObject* ret = NULL;
int i;
+ bool use_tuple = false;
if (!parse_set_list_args(&parsedargs, typeargs)) {
return NULL;
@@ -1047,7 +1077,8 @@
return NULL;
}
- ret = PyList_New(len);
+ use_tuple = type == T_LIST && parsedargs.immutable;
+ ret = use_tuple ? PyTuple_New(len) : PyList_New(len);
if (!ret) {
return NULL;
}
@@ -1058,20 +1089,18 @@
Py_DECREF(ret);
return NULL;
}
- PyList_SET_ITEM(ret, i, item);
+ if (use_tuple) {
+ PyTuple_SET_ITEM(ret, i, item);
+ } else {
+ PyList_SET_ITEM(ret, i, item);
+ }
}
// TODO(dreiss): Consider biting the bullet and making two separate cases
// for list and set, avoiding this post facto conversion.
if (type == T_SET) {
PyObject* setret;
-#if (PY_VERSION_HEX < 0x02050000)
- // hack needed for older versions
- setret = PyObject_CallFunctionObjArgs((PyObject*)&PySet_Type, ret, NULL);
-#else
- // official version
- setret = PySet_New(ret);
-#endif
+ setret = parsedargs.immutable ? PyFrozenSet_New(ret) : PySet_New(ret);
Py_DECREF(ret);
return setret;
}
@@ -1131,6 +1160,22 @@
goto error;
}
+ if (parsedargs.immutable) {
+ PyObject* thrift = PyImport_ImportModule("thrift.Thrift");
+ PyObject* cls = NULL;
+ PyObject* arg = NULL;
+ if (!thrift) {
+ goto error;
+ }
+ cls = PyObject_GetAttrString(thrift, "TFrozenDict");
+ if (!cls) {
+ goto error;
+ }
+ arg = PyTuple_New(1);
+ PyTuple_SET_ITEM(arg, 0, ret);
+ return PyObject_CallObject(cls, arg);
+ }
+
return ret;
error:
@@ -1140,22 +1185,12 @@
case T_STRUCT: {
StructTypeArgs parsedargs;
- PyObject* ret;
+ PyObject* ret;
if (!parse_struct_args(&parsedargs, typeargs)) {
return NULL;
}
- ret = PyObject_CallObject(parsedargs.klass, NULL);
- if (!ret) {
- return NULL;
- }
-
- if (!decode_struct(input, ret, parsedargs.spec)) {
- Py_DECREF(ret);
- return NULL;
- }
-
- return ret;
+ return decode_struct(input, Py_None, parsedargs.klass, parsedargs.spec);
}
case T_STOP:
@@ -1179,6 +1214,7 @@
PyObject* typeargs = NULL;
StructTypeArgs parsedargs;
DecodeBuffer input = {0, 0};
+ PyObject* ret = NULL;
if (!PyArg_ParseTuple(args, "OOO", &output_obj, &transport, &typeargs)) {
return NULL;
@@ -1192,14 +1228,9 @@
return NULL;
}
- if (!decode_struct(&input, output_obj, parsedargs.spec)) {
- free_decodebuf(&input);
- return NULL;
- }
-
+ ret = decode_struct(&input, output_obj, parsedargs.klass, parsedargs.spec);
free_decodebuf(&input);
-
- Py_RETURN_NONE;
+ return ret;
}
/* ====== END READING FUNCTIONS ====== */
diff --git a/test/DebugProtoTest.thrift b/test/DebugProtoTest.thrift
index 50ae4c1..e7119c4 100644
--- a/test/DebugProtoTest.thrift
+++ b/test/DebugProtoTest.thrift
@@ -72,11 +72,15 @@
}
struct Empty {
-}
+} (
+ python.immutable = "",
+)
struct Wrapper {
1: Empty foo
-}
+} (
+ python.immutable = "",
+)
struct RandomStuff {
1: i32 a,
@@ -153,9 +157,9 @@
42: map<byte, binary> byte_binary_map;
43: map<byte, bool> byte_boolean_map;
// collections as keys
- 44: map<list<byte>, byte> list_byte_map;
- 45: map<set<byte>, byte> set_byte_map;
- 46: map<map<byte,byte>, byte> map_byte_map;
+ 44: map<list<byte> (python.immutable = ""), byte> list_byte_map;
+ 45: map<set<byte> (python.immutable = ""), byte> set_byte_map;
+ 46: map<map<byte,byte> (python.immutable = ""), byte> map_byte_map;
// collections as values
47: map<byte, map<byte,byte>> byte_map_map;
48: map<byte, set<byte>> byte_set_map;
diff --git a/test/ThriftTest.thrift b/test/ThriftTest.thrift
index 414f9a54..a58ed97 100644
--- a/test/ThriftTest.thrift
+++ b/test/ThriftTest.thrift
@@ -105,12 +105,13 @@
{
1: map<Numberz, UserId> userMap,
2: list<Xtruct> xtructs
-}
+} (python.immutable= "")
struct CrazyNesting {
1: string string_field,
2: optional set<Insanity> set_field,
- 3: required list< map<set<i32>,map<i32,set<list<map<Insanity,string>>>>>> list_field,
+ 3: required list<map<set<i32> (python.immutable = ""),
+ map<i32,set<list<map<Insanity,string>(python.immutable = "")> (python.immutable = "")>>>> list_field,
4: binary binary_field
}
diff --git a/test/py/RunClientServer.py b/test/py/RunClientServer.py
index fa2a264..f084a41 100755
--- a/test/py/RunClientServer.py
+++ b/test/py/RunClientServer.py
@@ -37,6 +37,7 @@
DEFAULT_LIBDIR_PY3 = os.path.join(ROOT_DIR, 'lib', 'py', 'build', 'lib')
SCRIPTS = [
+ 'TestFrozen.py',
'TSimpleJSONProtocolTest.py',
'SerializationTest.py',
'TestEof.py',
diff --git a/test/py/TestFrozen.py b/test/py/TestFrozen.py
new file mode 100755
index 0000000..76750ad
--- /dev/null
+++ b/test/py/TestFrozen.py
@@ -0,0 +1,116 @@
+#!/usr/bin/env python
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from DebugProtoTest.ttypes import CompactProtoTestStruct, Empty, Wrapper
+from thrift.Thrift import TFrozenDict
+from thrift.transport import TTransport
+from thrift.protocol import TBinaryProtocol
+import collections
+import unittest
+
+
+class TestFrozenBase(unittest.TestCase):
+ def _roundtrip(self, src, dst):
+ otrans = TTransport.TMemoryBuffer()
+ optoro = self.protocol(otrans)
+ src.write(optoro)
+ itrans = TTransport.TMemoryBuffer(otrans.getvalue())
+ iproto = self.protocol(itrans)
+ return dst.read(iproto) or dst
+
+ def test_dict_is_hashable_only_after_frozen(self):
+ d0 = {}
+ self.assertFalse(isinstance(d0, collections.Hashable))
+ d1 = TFrozenDict(d0)
+ self.assertTrue(isinstance(d1, collections.Hashable))
+
+ def test_struct_with_collection_fields(self):
+ pass
+
+ def test_set(self):
+ """Test that annotated set field can be serialized and deserialized"""
+ x = CompactProtoTestStruct(set_byte_map={
+ frozenset([42, 100, -100]): 99,
+ frozenset([0]): 100,
+ frozenset([]): 0,
+ })
+ x2 = self._roundtrip(x, CompactProtoTestStruct())
+ self.assertEqual(x2.set_byte_map[frozenset([42, 100, -100])], 99)
+ self.assertEqual(x2.set_byte_map[frozenset([0])], 100)
+ self.assertEqual(x2.set_byte_map[frozenset([])], 0)
+
+ def test_map(self):
+ """Test that annotated map field can be serialized and deserialized"""
+ x = CompactProtoTestStruct(map_byte_map={
+ TFrozenDict({42: 42, 100: -100}): 99,
+ TFrozenDict({0: 0}): 100,
+ TFrozenDict({}): 0,
+ })
+ x2 = self._roundtrip(x, CompactProtoTestStruct())
+ self.assertEqual(x2.map_byte_map[TFrozenDict({42: 42, 100: -100})], 99)
+ self.assertEqual(x2.map_byte_map[TFrozenDict({0: 0})], 100)
+ self.assertEqual(x2.map_byte_map[TFrozenDict({})], 0)
+
+ def test_list(self):
+ """Test that annotated list field can be serialized and deserialized"""
+ x = CompactProtoTestStruct(list_byte_map={
+ (42, 100, -100): 99,
+ (0,): 100,
+ (): 0,
+ })
+ x2 = self._roundtrip(x, CompactProtoTestStruct())
+ self.assertEqual(x2.list_byte_map[(42, 100, -100)], 99)
+ self.assertEqual(x2.list_byte_map[(0,)], 100)
+ self.assertEqual(x2.list_byte_map[()], 0)
+
+ def test_empty_struct(self):
+ """Test that annotated empty struct can be serialized and deserialized"""
+ x = CompactProtoTestStruct(empty_struct_field=Empty())
+ x2 = self._roundtrip(x, CompactProtoTestStruct())
+ self.assertEqual(x2.empty_struct_field, Empty())
+
+ def test_struct(self):
+ """Test that annotated struct can be serialized and deserialized"""
+ x = Wrapper(foo=Empty())
+ self.assertEqual(x.foo, Empty())
+ x2 = self._roundtrip(x, Wrapper)
+ self.assertEqual(x2.foo, Empty())
+
+
+class TestFrozen(TestFrozenBase):
+ def protocol(self, trans):
+ return TBinaryProtocol.TBinaryProtocolFactory().getProtocol(trans)
+
+
+class TestFrozenAccelerated(TestFrozenBase):
+ def protocol(self, trans):
+ return TBinaryProtocol.TBinaryProtocolAcceleratedFactory().getProtocol(trans)
+
+
+def suite():
+ suite = unittest.TestSuite()
+ loader = unittest.TestLoader()
+ suite.addTest(loader.loadTestsFromTestCase(TestFrozen))
+ suite.addTest(loader.loadTestsFromTestCase(TestFrozenAccelerated))
+ return suite
+
+if __name__ == "__main__":
+ unittest.main(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=2))