blob: f4720477c5b85c2339bbefefc7d8bb3e71a928db [file] [log] [blame]
Hasnain Lakhanid2743002025-08-25 14:22:15 -07001import glob
2import sys
3import os
4import atheris
5
def setup_thrift_imports():
    """Set up the Python path to include Thrift libraries and generated code.

    Inside a PyInstaller bundle (the OSS-Fuzz packaging), the relative
    directories ``thrift_lib`` and ``gen-py`` are put at the front of
    ``sys.path``.  Otherwise, every ``lib/py/build/lib.*`` directory under
    the presumed repository root (four directories above this file's real
    location) whose name ends in this interpreter's version (``-X.Y`` or
    ``-XY``) is put at the front of ``sys.path``, and
    ``<this file's dir>/../../gen-py`` is appended to the end.

    Safe to call more than once: an entry that is already on ``sys.path`` is
    moved to the front (for prepended entries) or left in place (for the
    appended entry) instead of being added again, so the effective import
    order matches a single call and duplicates do not accumulate.
    """

    def prepend(path):
        # Move the entry to the front instead of inserting another copy; this
        # gives the same lookup priority as a fresh insert(0, ...) without
        # piling up duplicates when the function is called repeatedly.
        while path in sys.path:
            sys.path.remove(path)
        sys.path.insert(0, path)

    # For oss-fuzz, we need to package it using pyinstaller and set up paths properly
    if getattr(sys, 'frozen', False) and hasattr(sys, '_MEIPASS'):
        print('running in a PyInstaller bundle')
        # NOTE(review): these are resolved relative to the current working
        # directory, not sys._MEIPASS -- confirm that matches the bundle layout.
        prepend("thrift_lib")
        prepend("gen-py")
    else:
        print('running in a normal Python process')
        script_dir = os.path.realpath(os.path.dirname(__file__))
        root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(script_dir))))

        # Build directories carry the interpreter version either dotted
        # ("-3.11") or packed ("-311"); accept both spellings.
        postfixes = tuple(pattern % sys.version_info[:2] for pattern in ('-%d.%d', '-%d%d'))
        for libpath in glob.glob(os.path.join(root_dir, 'lib', 'py', 'build', 'lib.*')):
            if libpath.endswith(postfixes):
                prepend(libpath)

        gen_path = os.path.join(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "..", "gen-py"
        )
        # Appended (lowest priority); skip if a previous call already added it.
        if gen_path not in sys.path:
            sys.path.append(gen_path)
        print(sys.path)
31
# Configure sys.path at import time, before the thrift / generated-code
# imports below, which rely on the entries this call adds.
setup_thrift_imports()
33
34from thrift.transport import TTransport
35from thrift.TSerialization import serialize, deserialize
36from fuzz.ttypes import FuzzTest
37
def create_parser_fuzzer(protocol_factory_class):
    """
    Create a parser fuzzer function for a specific protocol.

    The returned function tries to deserialize each input into a
    ``FuzzTest``; any ``Exception`` raised while parsing is treated as an
    expected rejection of malformed input and ignored.

    The protocol factory is instantiated once, here, rather than inside the
    per-input ``try``. A factory class that rejects the constructor
    arguments therefore fails immediately. Otherwise every fuzz iteration
    would silently swallow the resulting error and the fuzzer would exercise
    nothing.

    Args:
        protocol_factory_class: The Thrift protocol factory class to use. It
            is constructed with
            ``string_length_limit=1000, container_length_limit=1000``.

    Returns:
        A function that can be used with atheris.Setup()

    Raises:
        Whatever ``protocol_factory_class`` raises when constructed, e.g.
        ``TypeError`` if it does not accept the length-limit keywords.
    """
    # Shared across inputs: TSerialization.deserialize asks the factory for a
    # fresh protocol on each call.
    factory = protocol_factory_class(string_length_limit=1000, container_length_limit=1000)

    def TestOneInput(data):
        if len(data) < 2:
            return

        try:
            # Try to deserialize the fuzzed data into the test class; the
            # resulting object is not needed, only whether parsing survives.
            deserialize(FuzzTest(), data, factory)
        except Exception:
            # We expect various exceptions during fuzzing
            pass

    return TestOneInput
66
def create_roundtrip_fuzzer(protocol_factory_class):
    """
    Create a roundtrip fuzzer function for a specific protocol.

    For each input, the returned function:

    1. deserializes the bytes into a ``FuzzTest``;
    2. serializes that object back;
    3. deserializes the result again;
    4. requires the two deserialized objects to compare equal.

    Any ``AssertionError`` raised in that sequence propagates to the fuzzer.
    That covers the mismatch check below and any ``AssertionError`` raised
    inside the Thrift calls themselves. Every other exception is treated as
    an expected rejection of malformed input and ignored.

    Args:
        protocol_factory_class: The Thrift protocol factory class to use. It
            is constructed once, here, with
            ``string_length_limit=1000, container_length_limit=1000``.

    Returns:
        A function that can be used with atheris.Setup()

    Raises:
        Whatever ``protocol_factory_class`` raises when constructed, e.g.
        ``TypeError`` if it does not accept the length-limit keywords.
    """
    # Built once, outside the per-input try. This way a misconfigured factory
    # fails immediately instead of raising on every iteration, where the broad
    # handler below would swallow it and the fuzzer would silently do nothing.
    factory = protocol_factory_class(string_length_limit=1000, container_length_limit=1000)

    def TestOneInput(data):
        if len(data) < 2:
            return

        try:
            # Try to deserialize the fuzzed data into the test class
            first = deserialize(FuzzTest(), data, factory)
            # If deserialization succeeds, serialize it back and parse it again
            serialized = serialize(first, factory)
            reparsed = deserialize(FuzzTest(), serialized, factory)
            # Explicit raise instead of `assert`: `python -O` strips asserts,
            # which would disable the only check this fuzzer performs.
            # NOTE(review): if FuzzTest has double fields, NaN values compare
            # unequal and would trip this check -- confirm against the IDL.
            if first != reparsed:
                raise AssertionError(
                    "roundtrip mismatch after serialize/deserialize", first, reparsed
                )
        except AssertionError:
            # Roundtrip mismatches are the failures this fuzzer reports.
            raise
        except Exception:
            # We expect various exceptions during fuzzing
            pass

    return TestOneInput
103
def _run_fuzzer(fuzzer_function):
    """
    Instrument the loaded modules and hand the fuzz target to atheris.

    sys.path is not configured here. setup_thrift_imports() already runs at
    module import time, because the module-level thrift imports depend on it.
    Calling it again only re-printed the path and re-added entries to
    sys.path.

    Args:
        fuzzer_function: The per-input fuzz target, e.g. the function
            returned by create_parser_fuzzer() or create_roundtrip_fuzzer().
    """
    # Instrument everything already imported (including the thrift library
    # and generated code) so atheris can collect coverage from it.
    atheris.instrument_all()
    atheris.Setup(sys.argv, fuzzer_function, enable_python_coverage=True)
    atheris.Fuzz()
115
116
def run_roundtrip_fuzzer(protocol_factory_class):
    """
    Run the serialize/deserialize roundtrip fuzzer for one protocol.

    Builds the fuzz target with create_roundtrip_fuzzer() and passes it to
    _run_fuzzer().

    Args:
        protocol_factory_class: The Thrift protocol factory class to use
    """
    target = create_roundtrip_fuzzer(protocol_factory_class)
    _run_fuzzer(target)
125
126
def run_parser_fuzzer(protocol_factory_class):
    """
    Run the deserialization-only parser fuzzer for one protocol.

    Builds the fuzz target with create_parser_fuzzer() and passes it to
    _run_fuzzer().

    Args:
        protocol_factory_class: The Thrift protocol factory class to use
    """
    target = create_parser_fuzzer(protocol_factory_class)
    _run_fuzzer(target)