Regression test for THRIFT-4002: Immutable exception deserialization
Client: py
Patch: Jens Geyer
Generated-by: Opencode big-pickle
This test verifies that immutable structs (including exceptions, which are immutable by default since Thrift 0.14.0) can be properly deserialized without triggering the __setattr__ TypeError.
The bug manifests when:
1. A struct class is marked immutable (has __setattr__ that raises TypeError)
2. Thrift's deserialization tries to set attributes via setattr instead of using the kwargs constructor
Test coverage:
- Immutable exception creation and hashability
- Immutable exception blocks modification/deletion
- Round-trip serialization/deserialization with TBinaryProtocol
- Round-trip serialization/deserialization with TCompactProtocol
- Accelerated protocol tests (C extension) when available
Related: THRIFT-4002, THRIFT-5715
diff --git a/lib/py/test/test_immutable_exception.py b/lib/py/test/test_immutable_exception.py
new file mode 100644
index 0000000..827d3ca
--- /dev/null
+++ b/lib/py/test/test_immutable_exception.py
@@ -0,0 +1,243 @@
+#!/usr/bin/env python
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+"""
+Test cases for THRIFT-4002: Immutable exception deserialization.
+
+This test verifies that immutable structs (including exceptions, which are immutable
+by default since Thrift 0.14.0) can be properly deserialized without triggering
+the __setattr__ TypeError.
+
+The bug manifests when:
+1. A struct class is marked immutable (has __setattr__ that raises TypeError)
+2. Thrift's deserialization tries to set attributes via setattr instead of
+ using the kwargs constructor
+
+This test ensures that all deserialization paths (C extension, pure Python,
+all protocols) correctly handle immutable structs.
+"""
+
+import unittest
+from collections.abc import Hashable
+
+import glob
+import os
+import sys
+
+SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__))
+ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(SCRIPT_DIR)))
+
+for libpath in glob.glob(os.path.join(ROOT_DIR, 'lib', 'py', 'build', 'lib.*')):
+ for pattern in ('-%d.%d', '-%d%d'):
+ postfix = pattern % (sys.version_info[0], sys.version_info[1])
+ if libpath.endswith(postfix):
+ sys.path.insert(0, libpath)
+ break
+else:
+ src_path = os.path.join(ROOT_DIR, 'lib', 'py', 'src')
+ if os.path.exists(src_path):
+ sys.path.insert(0, src_path)
+from thrift.Thrift import TException
+from thrift.transport import TTransport
+from thrift.protocol import TBinaryProtocol, TCompactProtocol
+
+
+class ImmutableException(TException):
+ """Test exception that mimics generated immutable exception behavior."""
+
+ thrift_spec = (
+ None, # 0
+ (1, 11, 'message', 'UTF8', None, ), # 1: string
+ )
+
+ def __init__(self, message=None):
+ super(ImmutableException, self).__init__(message)
+
+ def __setattr__(self, *args):
+ raise TypeError("can't modify immutable instance")
+
+ def __delattr__(self, *args):
+ raise TypeError("can't modify immutable instance")
+
+ def __hash__(self):
+ return hash(self.__class__) ^ hash((self.message,))
+
+ def __eq__(self, other):
+ return isinstance(other, self.__class__) and self.message == other.message
+
+ def write(self, oprot):
+ if oprot._fast_encode is not None and self.thrift_spec is not None:
+ oprot.trans.write(oprot._fast_encode(self, [self.__class__, self.thrift_spec]))
+ return
+ oprot.writeStructBegin('ImmutableException')
+ if self.message is not None:
+ oprot.writeFieldBegin('message', 11, 1)
+ oprot.writeString(self.message)
+ oprot.writeFieldEnd()
+ oprot.writeFieldStop()
+ oprot.writeStructEnd()
+
+ @classmethod
+ def read(cls, iprot):
+ if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and cls.thrift_spec is not None:
+ return iprot._fast_decode(None, iprot, [cls, cls.thrift_spec])
+ return iprot.readStruct(cls, cls.thrift_spec, True)
+
+
+class MutableException(TException):
+ """Test exception that mimics generated mutable exception behavior."""
+
+ thrift_spec = (
+ None, # 0
+ (1, 11, 'message', 'UTF8', None, ), # 1: string
+ )
+
+ def __init__(self, message=None):
+ super(MutableException, self).__init__(message)
+
+ def write(self, oprot):
+ if oprot._fast_encode is not None and self.thrift_spec is not None:
+ oprot.trans.write(oprot._fast_encode(self, [self.__class__, self.thrift_spec]))
+ return
+ oprot.writeStructBegin('MutableException')
+ if self.message is not None:
+ oprot.writeFieldBegin('message', 11, 1)
+ oprot.writeString(self.message)
+ oprot.writeFieldEnd()
+ oprot.writeFieldStop()
+ oprot.writeStructEnd()
+
+ @classmethod
+ def read(cls, iprot):
+ if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and cls.thrift_spec is not None:
+ return iprot._fast_decode(None, iprot, [cls, cls.thrift_spec])
+ return iprot.readStruct(cls, cls.thrift_spec, False)
+
+
+class TestImmutableExceptionDeserialization(unittest.TestCase):
+ """Test that immutable exceptions can be properly deserialized."""
+
+ def _roundtrip(self, exc, protocol_class):
+ """Serialize and deserialize an exception."""
+ otrans = TTransport.TMemoryBuffer()
+ oproto = protocol_class.getProtocol(otrans)
+ exc.write(oproto)
+ itrans = TTransport.TMemoryBuffer(otrans.getvalue())
+ iproto = protocol_class.getProtocol(itrans)
+ return exc.__class__.read(iproto)
+
+ def test_immutable_exception_is_hashable(self):
+ """Verify that immutable exceptions are hashable (required for caching/logging)."""
+ exc = ImmutableException(message="test")
+ self.assertTrue(isinstance(exc, Hashable))
+ self.assertEqual(hash(exc), hash(ImmutableException(message="test")))
+
+ def test_immutable_exception_blocks_modification(self):
+ """Verify that immutable exceptions raise TypeError on attribute modification."""
+ exc = ImmutableException(message="test")
+ with self.assertRaises(TypeError) as cm:
+ exc.message = "modified"
+ self.assertIn("immutable", str(cm.exception))
+
+ def test_immutable_exception_blocks_deletion(self):
+ """Verify that immutable exceptions raise TypeError on attribute deletion."""
+ exc = ImmutableException(message="test")
+ with self.assertRaises(TypeError) as cm:
+ del exc.message
+ self.assertIn("immutable", str(cm.exception))
+
+ def test_immutable_exception_binary_protocol(self):
+ """Test immutable exception deserialization with TBinaryProtocol."""
+ exc = ImmutableException(message="test error")
+ deserialized = self._roundtrip(exc, TBinaryProtocol.TBinaryProtocolFactory())
+ self.assertEqual(exc.message, deserialized.message)
+ self.assertEqual(exc, deserialized)
+
+ def test_immutable_exception_compact_protocol(self):
+ """Test immutable exception deserialization with TCompactProtocol."""
+ exc = ImmutableException(message="test error")
+ deserialized = self._roundtrip(exc, TCompactProtocol.TCompactProtocolFactory())
+ self.assertEqual(exc.message, deserialized.message)
+ self.assertEqual(exc, deserialized)
+
+ def test_mutable_exception_can_be_modified(self):
+ """Verify that mutable exceptions can be modified (control test)."""
+ exc = MutableException(message="original")
+ exc.message = "modified"
+ self.assertEqual(exc.message, "modified")
+
+
+class TestImmutableExceptionAccelerated(unittest.TestCase):
+ """Test immutable exception deserialization with accelerated protocols (C extension)."""
+
+ def setUp(self):
+ try:
+ # The import is intentionally unused - it only checks if the C extension
+ # is available by catching ImportError. The noqa comment documents this.
+ from thrift.protocol import fastbinary # noqa: F401
+ self._has_c_extension = True
+ except ImportError:
+ self._has_c_extension = False
+
+ def _roundtrip(self, exc, protocol_class):
+ """Serialize and deserialize an exception."""
+ otrans = TTransport.TMemoryBuffer()
+ oproto = protocol_class.getProtocol(otrans)
+ exc.write(oproto)
+ itrans = TTransport.TMemoryBuffer(otrans.getvalue())
+ iproto = protocol_class.getProtocol(itrans)
+ return exc.__class__.read(iproto)
+
+ def test_immutable_exception_binary_accelerated(self):
+ """Test immutable exception with TBinaryProtocolAccelerated."""
+ if not self._has_c_extension:
+ self.skipTest("C extension not available")
+ exc = ImmutableException(message="test error")
+ deserialized = self._roundtrip(
+ exc,
+ TBinaryProtocol.TBinaryProtocolAcceleratedFactory(fallback=False)
+ )
+ self.assertEqual(exc.message, deserialized.message)
+ self.assertEqual(exc, deserialized)
+
+ def test_immutable_exception_compact_accelerated(self):
+ """Test immutable exception with TCompactProtocolAccelerated."""
+ if not self._has_c_extension:
+ self.skipTest("C extension not available")
+ exc = ImmutableException(message="test error")
+ deserialized = self._roundtrip(
+ exc,
+ TCompactProtocol.TCompactProtocolAcceleratedFactory(fallback=False)
+ )
+ self.assertEqual(exc.message, deserialized.message)
+ self.assertEqual(exc, deserialized)
+
+
+def suite():
+ suite = unittest.TestSuite()
+ loader = unittest.TestLoader()
+ suite.addTest(loader.loadTestsFromTestCase(TestImmutableExceptionDeserialization))
+ suite.addTest(loader.loadTestsFromTestCase(TestImmutableExceptionAccelerated))
+ return suite
+
+
+if __name__ == "__main__":
+ unittest.main(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=2))