blob: 33c8138b7104611a1ec2a5c2eebbd9eec617046a [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
19
Hasnain Lakhanid2743002025-08-25 14:22:15 -070020import glob
21import sys
22import os
23import atheris
24
def setup_thrift_imports():
    """Set up the Python path to include Thrift libraries and generated code.

    Inside a PyInstaller bundle (``sys.frozen`` is set and ``sys._MEIPASS``
    exists) the relative directories ``thrift_lib`` and ``gen-py`` are placed
    at the front of ``sys.path``.

    In a normal Python process, every ``lib/py/build/lib.*`` directory under
    the tree root whose name ends with the running interpreter's version
    (``-X.Y`` or ``-XY``) is placed at the front of ``sys.path``, and the
    ``gen-py`` directory located two levels above this script's directory is
    appended.

    The function may be called repeatedly (it runs at import time and again
    from ``_run_fuzzer``): a path that is already on ``sys.path`` is not
    added a second time.
    """

    def _add(path, front):
        # Skip entries that are already present so that repeated calls do not
        # pile up duplicates on sys.path.
        if path in sys.path:
            return
        if front:
            sys.path.insert(0, path)
        else:
            sys.path.append(path)

    # For oss-fuzz, we need to package it using pyinstaller and set up paths properly
    if getattr(sys, 'frozen', False) and hasattr(sys, '_MEIPASS'):
        print('running in a PyInstaller bundle')
        _add("thrift_lib", front=True)
        _add("gen-py", front=True)
        return

    print('running in a normal Python process')
    script_dir = os.path.realpath(os.path.dirname(__file__))
    # The tree root is four directory levels above the script directory.
    root_dir = script_dir
    for _ in range(4):
        root_dir = os.path.dirname(root_dir)

    # A build directory qualifies when its name ends with the interpreter
    # version, written either as "-X.Y" or as "-XY".
    version = tuple(sys.version_info[:2])
    suffixes = tuple(pattern % version for pattern in ('-%d.%d', '-%d%d'))
    for libpath in glob.glob(os.path.join(root_dir, 'lib', 'py', 'build', 'lib.*')):
        if libpath.endswith(suffixes):
            _add(libpath, front=True)

    gen_path = os.path.join(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "..", "gen-py"
    )
    _add(gen_path, front=False)
    # NOTE(review): the nesting of this print was not recoverable from the
    # reviewed copy; it is assumed to belong to the non-bundled branch - confirm.
    print(sys.path)
50
# The path setup has to run before the imports below: it adds the Thrift
# library build directory and the generated-code directory to sys.path, which
# is where the `thrift` package and the generated `fuzz` module are looked up.
setup_thrift_imports()

from thrift.transport import TTransport
from thrift.TSerialization import serialize, deserialize
from fuzz.ttypes import FuzzTest
56
def create_parser_fuzzer(protocol_factory_class):
    """
    Create a parser fuzzer function for a specific protocol.

    The returned callable passes the raw fuzzer bytes to ``deserialize``
    together with a fresh ``FuzzTest`` instance and a protocol factory built
    with string and container length limits of 1000.

    Args:
        protocol_factory_class: The Thrift protocol factory class to use

    Returns:
        A function that can be used with atheris.Setup()
    """
    def TestOneInput(data):
        # Inputs shorter than two bytes are skipped.
        if len(data) < 2:
            return

        try:
            factory = protocol_factory_class(string_length_limit=1000, container_length_limit=1000)

            # deserialize() is handed the raw bytes directly, so no transport
            # objects need to be built here.
            deserialize(FuzzTest(), data, factory)
        except Exception:
            # Fuzzed input is expected to make parsing raise; those errors
            # are not findings, so they are deliberately ignored.
            pass

    return TestOneInput
85
def create_roundtrip_fuzzer(protocol_factory_class):
    """
    Create a roundtrip fuzzer function for a specific protocol.

    The returned callable deserializes the fuzzer bytes into a ``FuzzTest``,
    serializes that object again, deserializes the result, and compares the
    two objects. A mismatch is reported by raising ``AssertionError``; every
    other exception raised along the way is ignored.

    Args:
        protocol_factory_class: The Thrift protocol factory class to use

    Returns:
        A function that can be used with atheris.Setup()
    """
    def TestOneInput(data):
        # Inputs shorter than two bytes are skipped.
        if len(data) < 2:
            return

        try:
            factory = protocol_factory_class(string_length_limit=1000, container_length_limit=1000)

            # Try to deserialize the fuzzed data into the test class
            test_instance = deserialize(FuzzTest(), data, factory)
            # If deserialization succeeds, try to serialize it back
            serialized = serialize(test_instance, factory)
            # Deserialize again
            deserialized = deserialize(FuzzTest(), serialized, factory)
            # Verify the objects are equal after a second deserialization.
            # An explicit raise is used instead of `assert` so the check is
            # still performed when Python runs with -O.
            if not (test_instance == deserialized):
                raise AssertionError(
                    "roundtrip mismatch: re-deserialized object differs from the original"
                )

        except AssertionError:
            # A failed roundtrip is a real finding: let it reach the fuzzer.
            raise
        except Exception:
            # Fuzzed input is expected to make parsing raise; those errors
            # are not findings, so they are deliberately ignored.
            pass

    return TestOneInput
122
def _run_fuzzer(fuzzer_function):
    """
    Prepare the process and hand control to atheris.

    The import paths are configured, ``atheris.instrument_all()`` is called,
    and atheris is then set up with the process arguments and started.

    Args:
        fuzzer_function: The callable atheris invokes with each test input
    """
    setup_thrift_imports()
    atheris.instrument_all()

    argv = sys.argv
    atheris.Setup(argv, fuzzer_function, enable_python_coverage=True)
    atheris.Fuzz()
134
135
def run_roundtrip_fuzzer(protocol_factory_class):
    """
    Build the roundtrip fuzzer for a protocol and run it.

    Args:
        protocol_factory_class: The Thrift protocol factory class to use
    """
    roundtrip_target = create_roundtrip_fuzzer(protocol_factory_class)
    _run_fuzzer(roundtrip_target)
144
145
def run_parser_fuzzer(protocol_factory_class):
    """
    Build the parser fuzzer for a protocol and run it.

    Args:
        protocol_factory_class: The Thrift protocol factory class to use
    """
    parser_target = create_parser_fuzzer(protocol_factory_class)
    _run_fuzzer(parser_target)