THRIFT-4002: Make generated exception classes immutable by default
Currently, the generated exception classes are not hashable under
Python 3 because of the generated `__eq__` method. Exception objects
are generally expected to be hashable by the Python standard library.
Post-construction mutation of an exception object seems like a very
unlikely case, so enable hashing for all exceptions by making them
immutable by default. This also adds a way to opt-out of immutability
by setting the `python.immutable` annotation to `"false"`.
diff --git a/compiler/cpp/src/thrift/generate/t_py_generator.cc b/compiler/cpp/src/thrift/generate/t_py_generator.cc
index 982bca1..e93bbe1 100644
--- a/compiler/cpp/src/thrift/generate/t_py_generator.cc
+++ b/compiler/cpp/src/thrift/generate/t_py_generator.cc
@@ -65,6 +65,7 @@
coding_ = "";
gen_dynbaseclass_ = "";
gen_dynbaseclass_exc_ = "";
+ gen_dynbaseclass_frozen_exc_ = "";
gen_dynbaseclass_frozen_ = "";
import_dynbase_ = "";
package_prefix_ = "";
@@ -94,8 +95,11 @@
if( gen_dynbaseclass_exc_.empty()) {
gen_dynbaseclass_exc_ = "TExceptionBase";
}
+ if( gen_dynbaseclass_frozen_exc_.empty()) {
+ gen_dynbaseclass_frozen_exc_ = "TFrozenExceptionBase";
+ }
if( import_dynbase_.empty()) {
- import_dynbase_ = "from thrift.protocol.TBase import TBase, TFrozenBase, TExceptionBase, TTransport\n";
+ import_dynbase_ = "from thrift.protocol.TBase import TBase, TFrozenBase, TExceptionBase, TFrozenExceptionBase, TTransport\n";
}
} else if( iter->first.compare("dynbase") == 0) {
gen_dynbase_ = true;
@@ -104,6 +108,8 @@
gen_dynbaseclass_frozen_ = (iter->second);
} else if( iter->first.compare("dynexc") == 0) {
gen_dynbaseclass_exc_ = (iter->second);
+ } else if( iter->first.compare("dynfrozenexc") == 0) {
+ gen_dynbaseclass_frozen_exc_ = (iter->second);
} else if( iter->first.compare("dynimport") == 0) {
gen_dynbase_ = true;
import_dynbase_ = (iter->second);
@@ -269,7 +275,16 @@
}
static bool is_immutable(t_type* ttype) {
- return ttype->annotations_.find("python.immutable") != ttype->annotations_.end();
+ std::map<std::string, std::string>::iterator it = ttype->annotations_.find("python.immutable");
+
+ if (it == ttype->annotations_.end()) {
+ // Exceptions are immutable by default.
+ return ttype->is_xception();
+ } else if (it->second == "false") {
+ return false;
+ } else {
+ return true;
+ }
}
private:
@@ -288,6 +303,7 @@
std::string gen_dynbaseclass_;
std::string gen_dynbaseclass_frozen_;
std::string gen_dynbaseclass_exc_;
+ std::string gen_dynbaseclass_frozen_exc_;
std::string import_dynbase_;
@@ -742,7 +758,11 @@
out << endl << endl << "class " << tstruct->get_name();
if (is_exception) {
if (gen_dynamic_) {
- out << "(" << gen_dynbaseclass_exc_ << ")";
+ if (is_immutable(tstruct)) {
+ out << "(" << gen_dynbaseclass_frozen_exc_ << ")";
+ } else {
+ out << "(" << gen_dynbaseclass_exc_ << ")";
+ }
} else {
out << "(TException)";
}
@@ -2774,6 +2794,7 @@
" dynbase=CLS Derive generated classes from class CLS instead of TBase.\n"
" dynfrozen=CLS Derive generated immutable classes from class CLS instead of TFrozenBase.\n"
" dynexc=CLS Derive generated exceptions from CLS instead of TExceptionBase.\n"
+ " dynfrozenexc=CLS Derive generated immutable exceptions from CLS instead of TFrozenExceptionBase.\n"
" dynimport='from foo.bar import CLS'\n"
" Add an import line to generated code to find the dynbase class.\n"
" package_prefix='top.package.'\n"
diff --git a/lib/py/src/protocol/TBase.py b/lib/py/src/protocol/TBase.py
index 9ae1b11..6c6ef18 100644
--- a/lib/py/src/protocol/TBase.py
+++ b/lib/py/src/protocol/TBase.py
@@ -80,3 +80,7 @@
[self.__class__, self.thrift_spec])
else:
return iprot.readStruct(cls, cls.thrift_spec, True)
+
+
+class TFrozenExceptionBase(TFrozenBase, TExceptionBase):
+ pass
diff --git a/lib/py/src/protocol/TProtocol.py b/lib/py/src/protocol/TProtocol.py
index 3456e8f..339a283 100644
--- a/lib/py/src/protocol/TProtocol.py
+++ b/lib/py/src/protocol/TProtocol.py
@@ -303,8 +303,14 @@
def readContainerStruct(self, spec):
(obj_class, obj_spec) = spec
- obj = obj_class()
- obj.read(self)
+
+ # If obj_class.read is a classmethod (e.g. in frozen structs),
+ # call it as such.
+ if getattr(obj_class.read, '__self__', None) is obj_class:
+ obj = obj_class.read(self)
+ else:
+ obj = obj_class()
+ obj.read(self)
return obj
def readContainerMap(self, spec):
diff --git a/test/DebugProtoTest.thrift b/test/DebugProtoTest.thrift
index de47ea7..1ab0f6a 100644
--- a/test/DebugProtoTest.thrift
+++ b/test/DebugProtoTest.thrift
@@ -241,6 +241,10 @@
2: map<string, string> map_field;
}
+exception MutableException {
+ 1: string msg;
+} (python.immutable = "false")
+
service ServiceForExceptionWithAMap {
void methodThatThrowsAnException() throws (1: ExceptionWithAMap xwamap);
}
diff --git a/test/py/TestFrozen.py b/test/py/TestFrozen.py
index 6d2595c..ce7425f 100755
--- a/test/py/TestFrozen.py
+++ b/test/py/TestFrozen.py
@@ -19,7 +19,9 @@
# under the License.
#
+from DebugProtoTest import Srv
from DebugProtoTest.ttypes import CompactProtoTestStruct, Empty, Wrapper
+from DebugProtoTest.ttypes import ExceptionWithAMap, MutableException
from thrift.Thrift import TFrozenDict
from thrift.transport import TTransport
from thrift.protocol import TBinaryProtocol, TCompactProtocol
@@ -94,6 +96,21 @@
x2 = self._roundtrip(x, Wrapper)
self.assertEqual(x2.foo, Empty())
+ def test_frozen_exception(self):
+ exc = ExceptionWithAMap(blah='foo')
+ with self.assertRaises(TypeError):
+ exc.blah = 'bar'
+ mutexc = MutableException(msg='foo')
+ mutexc.msg = 'bar'
+ self.assertEqual(mutexc.msg, 'bar')
+
+ def test_frozen_exception_serialization(self):
+ result = Srv.declaredExceptionMethod_result(
+ xwamap=ExceptionWithAMap(blah="error"))
+ deserialized = self._roundtrip(
+ result, Srv.declaredExceptionMethod_result())
+ self.assertEqual(result, deserialized)
+
class TestFrozen(TestFrozenBase):
def protocol(self, trans):