blob: 827d3caa2bf7c11a3bf9be4aa0abf997700be2a6 [file] [log] [blame]
Jens Geyer98034d12026-03-19 20:24:46 +01001#!/usr/bin/env python
2
3#
4# Licensed to the Apache Software Foundation (ASF) under one
5# or more contributor license agreements. See the NOTICE file
6# distributed with this work for additional information
7# regarding copyright ownership. The ASF licenses this file
8# to you under the Apache License, Version 2.0 (the
9# "License"); you may not use this file except in compliance
10# with the License. You may obtain a copy of the License at
11#
12# http://www.apache.org/licenses/LICENSE-2.0
13#
14# Unless required by applicable law or agreed to in writing,
15# software distributed under the License is distributed on an
16# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
17# KIND, either express or implied. See the License for the
18# specific language governing permissions and limitations
19# under the License.
20#
21
22"""
23Test cases for THRIFT-4002: Immutable exception deserialization.
24
25This test verifies that immutable structs (including exceptions, which are immutable
26by default since Thrift 0.14.0) can be properly deserialized without triggering
27the __setattr__ TypeError.
28
29The bug manifests when:
301. A struct class is marked immutable (has __setattr__ that raises TypeError)
312. Thrift's deserialization tries to set attributes via setattr instead of
32 using the kwargs constructor
33
These tests ensure that the pure-Python and (when available) C-extension
deserialization paths of the binary and compact protocols correctly handle
immutable structs.
36"""
37
38import unittest
39from collections.abc import Hashable
40
41import glob
42import os
43import sys
44
SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__))
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(SCRIPT_DIR)))


def _local_thrift_paths(root_dir, version_info=sys.version_info):
    """Return the local Thrift library directories to put on ``sys.path``.

    Directories matching ``lib/py/build/lib.*`` whose name ends with the
    running interpreter's version (``-X.Y`` or ``-XY``) are preferred.
    Only when none of them match is the pure-Python source tree
    ``lib/py/src`` used, and only if it exists.

    :param root_dir: root directory of the Thrift checkout.
    :param version_info: interpreter version; only the first two items are used.
    :returns: list of directory paths in glob order (possibly empty).
    """
    postfixes = tuple(
        pattern % (version_info[0], version_info[1])
        for pattern in ('-%d.%d', '-%d%d')
    )
    build_glob = os.path.join(root_dir, 'lib', 'py', 'build', 'lib.*')
    build_libs = [path for path in glob.glob(build_glob) if path.endswith(postfixes)]
    if build_libs:
        return build_libs
    # Fall back to the source tree only when no matching build dir exists;
    # otherwise it would shadow the build dir (and its compiled extension).
    src_path = os.path.join(root_dir, 'lib', 'py', 'src')
    return [src_path] if os.path.exists(src_path) else []


# Insert at the front so the local library shadows any installed thrift package.
for _path in _local_thrift_paths(ROOT_DIR):
    sys.path.insert(0, _path)
58from thrift.Thrift import TException
59from thrift.transport import TTransport
60from thrift.protocol import TBinaryProtocol, TCompactProtocol
61
62
class ImmutableException(TException):
    """Exception fixture written the way Thrift generates immutable exceptions.

    Assigning or deleting any attribute raises ``TypeError``, so a decoder
    cannot populate an instance field by field after constructing it.
    """

    thrift_spec = (
        None,  # 0
        (1, 11, 'message', 'UTF8', None, ),  # 1: string
    )

    def __init__(self, message=None):
        # TException.__init__ presumably stores ``message`` without routing
        # through the __setattr__ override below (otherwise construction
        # would raise) -- TODO confirm against thrift.Thrift.TException.
        super(ImmutableException, self).__init__(message)

    def _refuse_mutation(self, *args):
        """Shared body of __setattr__ and __delattr__: always reject the change."""
        raise TypeError("can't modify immutable instance")

    __setattr__ = _refuse_mutation
    __delattr__ = _refuse_mutation

    def __hash__(self):
        field_values = (self.message,)
        return hash(type(self)) ^ hash(field_values)

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False
        return self.message == other.message

    def write(self, oprot):
        """Serialize this exception to ``oprot``, preferring its C encoder."""
        encode = oprot._fast_encode
        if encode is not None and self.thrift_spec is not None:
            oprot.trans.write(encode(self, [self.__class__, self.thrift_spec]))
            return
        # Hand-written encoding: the single optional string field, then STOP.
        oprot.writeStructBegin('ImmutableException')
        if self.message is not None:
            oprot.writeFieldBegin('message', 11, 1)
            oprot.writeString(self.message)
            oprot.writeFieldEnd()
        oprot.writeFieldStop()
        oprot.writeStructEnd()

    @classmethod
    def read(cls, iprot):
        """Deserialize and return a new instance using the struct-constructing decode paths."""
        can_decode_natively = (
            iprot._fast_decode is not None
            and isinstance(iprot.trans, TTransport.CReadableTransport)
            and cls.thrift_spec is not None
        )
        if not can_decode_natively:
            # is_immutable=True asks readStruct to build the result from the
            # class rather than assign fields onto an existing object.
            return iprot.readStruct(cls, cls.thrift_spec, True)
        # A None output object asks the C decoder to construct the instance
        # itself (presumably via keyword arguments).
        return iprot._fast_decode(None, iprot, [cls, cls.thrift_spec])
103
104
class MutableException(TException):
    """Test exception that mimics generated mutable exception behavior.

    Unlike ``ImmutableException`` in this module, instances accept attribute
    assignment, so ``read`` decodes into a freshly constructed instance
    instead of passing the fields to the constructor.
    """

    thrift_spec = (
        None,  # 0
        (1, 11, 'message', 'UTF8', None, ),  # 1: string
    )

    def __init__(self, message=None):
        super(MutableException, self).__init__(message)

    def write(self, oprot):
        """Serialize this exception to ``oprot``.

        Uses the protocol's C encoder when one is available; otherwise writes
        the single optional ``message`` field by hand.
        """
        if oprot._fast_encode is not None and self.thrift_spec is not None:
            oprot.trans.write(oprot._fast_encode(self, [self.__class__, self.thrift_spec]))
            return
        oprot.writeStructBegin('MutableException')
        if self.message is not None:
            oprot.writeFieldBegin('message', 11, 1)
            oprot.writeString(self.message)
            oprot.writeFieldEnd()
        oprot.writeFieldStop()
        oprot.writeStructEnd()

    @classmethod
    def read(cls, iprot):
        """Deserialize a ``MutableException`` from ``iprot`` and return it.

        A mutable struct is decoded field by field into an existing object,
        so an instance is created up front and handed to the decoder. Handing
        the class itself to ``readStruct(..., False)`` would assign the
        decoded fields onto the class object and yield no instance.
        """
        obj = cls()
        if (iprot._fast_decode is not None
                and isinstance(iprot.trans, TTransport.CReadableTransport)
                and cls.thrift_spec is not None):
            # With a non-None output object the C decoder presumably fills
            # in that object's attributes in place.
            iprot._fast_decode(obj, iprot, [cls, cls.thrift_spec])
        else:
            iprot.readStruct(obj, cls.thrift_spec, False)
        return obj
133
134
class TestImmutableExceptionDeserialization(unittest.TestCase):
    """Test that immutable exceptions can be properly deserialized."""

    def _factories(self):
        """Return the pure-Python protocol factories exercised by these tests."""
        return (
            TBinaryProtocol.TBinaryProtocolFactory(),
            TCompactProtocol.TCompactProtocolFactory(),
        )

    def _roundtrip(self, exc, protocol_factory):
        """Serialize ``exc`` and read it back with the same protocol.

        :param exc: exception instance, encoded with ``exc.write``.
        :param protocol_factory: factory whose ``getProtocol`` wraps a transport.
        :returns: the object produced by ``exc.__class__.read``.
        """
        otrans = TTransport.TMemoryBuffer()
        exc.write(protocol_factory.getProtocol(otrans))
        itrans = TTransport.TMemoryBuffer(otrans.getvalue())
        return exc.__class__.read(protocol_factory.getProtocol(itrans))

    def test_immutable_exception_is_hashable(self):
        """Verify that immutable exceptions are hashable (required for caching/logging)."""
        exc = ImmutableException(message="test")
        self.assertIsInstance(exc, Hashable)
        self.assertEqual(hash(exc), hash(ImmutableException(message="test")))

    def test_immutable_exception_blocks_modification(self):
        """Verify that immutable exceptions raise TypeError on attribute modification."""
        exc = ImmutableException(message="test")
        with self.assertRaises(TypeError) as cm:
            exc.message = "modified"
        self.assertIn("immutable", str(cm.exception))
        # The rejected assignment must leave the stored value untouched.
        self.assertEqual(exc.message, "test")

    def test_immutable_exception_blocks_deletion(self):
        """Verify that immutable exceptions raise TypeError on attribute deletion."""
        exc = ImmutableException(message="test")
        with self.assertRaises(TypeError) as cm:
            del exc.message
        self.assertIn("immutable", str(cm.exception))
        self.assertEqual(exc.message, "test")

    def test_immutable_exception_binary_protocol(self):
        """Test immutable exception deserialization with TBinaryProtocol."""
        exc = ImmutableException(message="test error")
        deserialized = self._roundtrip(exc, TBinaryProtocol.TBinaryProtocolFactory())
        self.assertEqual(exc.message, deserialized.message)
        self.assertEqual(exc, deserialized)

    def test_immutable_exception_compact_protocol(self):
        """Test immutable exception deserialization with TCompactProtocol."""
        exc = ImmutableException(message="test error")
        deserialized = self._roundtrip(exc, TCompactProtocol.TCompactProtocolFactory())
        self.assertEqual(exc.message, deserialized.message)
        self.assertEqual(exc, deserialized)

    def test_immutable_exception_without_message(self):
        """Test that an immutable exception with no fields set round-trips."""
        for factory in self._factories():
            with self.subTest(protocol=type(factory).__name__):
                exc = ImmutableException()
                deserialized = self._roundtrip(exc, factory)
                self.assertIsNone(deserialized.message)
                self.assertEqual(exc, deserialized)

    def test_deserialized_immutable_exception_stays_immutable(self):
        """Test that deserialization yields a real, still-immutable instance."""
        exc = ImmutableException(message="test error")
        for factory in self._factories():
            with self.subTest(protocol=type(factory).__name__):
                deserialized = self._roundtrip(exc, factory)
                self.assertIs(type(deserialized), ImmutableException)
                with self.assertRaises(TypeError):
                    deserialized.message = "modified"

    def test_mutable_exception_can_be_modified(self):
        """Verify that mutable exceptions can be modified (control test)."""
        exc = MutableException(message="original")
        exc.message = "modified"
        self.assertEqual(exc.message, "modified")
186
187
class TestImmutableExceptionAccelerated(unittest.TestCase):
    """Round-trip immutable exceptions through the accelerated (C extension) protocols.

    Each test is skipped when ``thrift.protocol.fastbinary`` cannot be imported.
    """

    def setUp(self):
        # The import is only a probe for the compiled extension; the module
        # object itself is never used, hence the noqa marker.
        try:
            from thrift.protocol import fastbinary  # noqa: F401
        except ImportError:
            self._has_c_extension = False
        else:
            self._has_c_extension = True

    def _roundtrip(self, exc, protocol_factory):
        """Encode ``exc`` with ``protocol_factory`` and decode the bytes again."""
        out_buffer = TTransport.TMemoryBuffer()
        exc.write(protocol_factory.getProtocol(out_buffer))
        in_buffer = TTransport.TMemoryBuffer(out_buffer.getvalue())
        return type(exc).read(protocol_factory.getProtocol(in_buffer))

    def _check_accelerated_roundtrip(self, factory_cls):
        """Skip without the C extension; else round-trip via ``factory_cls(fallback=False)``."""
        if not self._has_c_extension:
            self.skipTest("C extension not available")
        original = ImmutableException(message="test error")
        restored = self._roundtrip(original, factory_cls(fallback=False))
        self.assertEqual(original.message, restored.message)
        self.assertEqual(original, restored)

    def test_immutable_exception_binary_accelerated(self):
        """Test immutable exception with TBinaryProtocolAccelerated."""
        self._check_accelerated_roundtrip(TBinaryProtocol.TBinaryProtocolAcceleratedFactory)

    def test_immutable_exception_compact_accelerated(self):
        """Test immutable exception with TCompactProtocolAccelerated."""
        self._check_accelerated_roundtrip(TCompactProtocol.TCompactProtocolAcceleratedFactory)
232
233
def suite():
    """Build a TestSuite holding every test case class defined in this module."""
    loader = unittest.TestLoader()
    test_cases = (
        TestImmutableExceptionDeserialization,
        TestImmutableExceptionAccelerated,
    )
    return unittest.TestSuite(
        loader.loadTestsFromTestCase(case) for case in test_cases
    )
240
241
if __name__ == "__main__":
    # Run the module-level suite(); verbosity=2 prints one line per test.
    verbose_runner = unittest.TextTestRunner(verbosity=2)
    unittest.main(defaultTest="suite", testRunner=verbose_runner)