blob: 541e6c30ce86c59ffca95df14241f79015d545f5 [file] [log] [blame]
Chris Hoge296558c2015-02-19 00:29:49 -06001# Copyright 2014 Mirantis, Inc.
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may
4# not use this file except in compliance with the License. You may obtain
5# a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations
13# under the License.
14
15import argparse
16import ast
17import importlib
18import inspect
19import os
20import sys
21import unittest
22import urllib
23import uuid
24
# Tempest module that provides the decorator, and the decorator's name.
DECORATOR_MODULE = 'test'
DECORATOR_NAME = 'idempotent_id'
# Dotted path of that module and the statement used to import it.
DECORATOR_IMPORT = 'tempest.{0}'.format(DECORATOR_MODULE)
IMPORT_LINE = 'from tempest import {0}'.format(DECORATOR_MODULE)
# Source text of the decorator; the remaining '%s' receives a fresh UUID.
DECORATOR_TEMPLATE = "@{0}.{1}('%s')".format(DECORATOR_MODULE, DECORATOR_NAME)
# Module-name prefix skipped during module discovery.
UNIT_TESTS_EXCLUDE = 'tempest.tests'
Chris Hoge296558c2015-02-19 00:29:49 -060032
33
class SourcePatcher(object):

    """Lazy patcher for python source files.

    Patches are queued with ``add_patch`` and written to disk all at once by
    ``apply_patches``. Queued sources are kept percent-quoted in memory so
    that braces already present in them cannot clash with the ``str.format``
    placeholders used to mark insertion points.
    """

    def __init__(self):
        self.source_files = None
        self.patches = None
        self.clear()

    def clear(self):
        """Clear inner state"""
        # filename -> quoted source text containing '{<patch_id>:s}' markers
        self.source_files = {}
        # patch_id -> quoted patch text
        self.patches = {}

    @staticmethod
    def _quote(s):
        """Percent-quote ``s`` on both Python 2 and Python 3."""
        try:
            from urllib.parse import quote
        except ImportError:  # Python 2: quote lives directly in urllib
            from urllib import quote
        return quote(s)

    @staticmethod
    def _unquote(s):
        """Reverse ``_quote`` on both Python 2 and Python 3."""
        try:
            from urllib.parse import unquote
        except ImportError:  # Python 2: unquote lives directly in urllib
            from urllib import unquote
        return unquote(s)

    def add_patch(self, filename, patch, line_no):
        """Queue insertion of ``patch`` before line ``line_no`` of a file.

        :param filename: path of the file to patch; read on first use.
        :param patch: text to insert; a trailing newline is added if missing.
        :param line_no: 1-based line number the patch is inserted in front of.
        """
        if filename not in self.source_files:
            with open(filename) as f:
                self.source_files[filename] = self._quote(f.read())
        patch_id = str(uuid.uuid4())
        if not patch.endswith('\n'):
            patch += '\n'
        self.patches[patch_id] = self._quote(patch)
        # The source is quoted, so split on the quoted form of '\n'.
        lines = self.source_files[filename].split(self._quote('\n'))
        # The marker is inserted unquoted so str.format can find it later.
        lines[line_no - 1] = ''.join(('{%s:s}' % patch_id, lines[line_no - 1]))
        self.source_files[filename] = self._quote('\n').join(lines)

    def _save_changes(self, filename, source):
        """Overwrite ``filename`` with ``source`` and report it on stdout."""
        print('%s fixed' % filename)
        with open(filename, 'w') as f:
            f.write(source)

    def apply_patches(self):
        """Apply all queued patches to disk and reset the patcher."""
        for filename in self.source_files:
            patched_source = self._unquote(
                self.source_files[filename].format(**self.patches)
            )
            self._save_changes(filename, patched_source)
        self.clear()
82
83
class TestChecker(object):
    """Discover tests in a package and check or fix their idempotent ids.

    Every module below ``package`` is imported and its source is parsed with
    :mod:`ast` to find test methods and their ``@test.idempotent_id``
    decorators.
    """

    def __init__(self, package):
        self.package = package
        # Directory containing the package's __init__ module.
        self.base_path = os.path.abspath(os.path.dirname(package.__file__))

    def _path_to_package(self, path):
        """Return the dotted package name for directory ``path``.

        ``path`` is expected to be ``base_path`` or a directory below it.
        """
        relative_path = path[len(self.base_path) + 1:]
        if not relative_path:
            return self.package.__name__
        # os.walk builds paths with os.sep, so split on it instead of '/'.
        return '.'.join((self.package.__name__,) +
                        tuple(relative_path.split(os.sep)))

    def _modules_search(self):
        """Recursively list dotted module names of python files in the package.

        Files in directories without an ``__init__.py`` are skipped, as are
        modules whose name starts with ``UNIT_TESTS_EXCLUDE``.
        """
        modules = []
        for root, dirs, files in os.walk(self.base_path):
            if not os.path.exists(os.path.join(root, '__init__.py')):
                continue
            root_package = self._path_to_package(root)
            for item in files:
                if item.endswith('.py'):
                    module_name = '.'.join((root_package,
                                            os.path.splitext(item)[0]))
                    if not module_name.startswith(UNIT_TESTS_EXCLUDE):
                        modules.append(module_name)
        return modules

    @staticmethod
    def _get_idempotent_id(test_node):
        """Return the value passed to ``@test.idempotent_id`` on a test.

        :param test_node: ``ast.FunctionDef`` of the test method.
        :returns: the literal argument of the decorator, or None if absent.

        Decorators of any other shape (bare names, calls on plain names,
        deeper attribute chains) are skipped instead of raising.
        """
        idempotent_id = None
        for decorator in test_node.decorator_list:
            func = getattr(decorator, 'func', None)
            if (isinstance(func, ast.Attribute) and
                    func.attr == DECORATOR_NAME and
                    isinstance(func.value, ast.Name) and
                    func.value.id == DECORATOR_MODULE):
                for arg in decorator.args:
                    idempotent_id = ast.literal_eval(arg)
        return idempotent_id

    @staticmethod
    def _is_decorator(line):
        return line.strip().startswith('@')

    @staticmethod
    def _is_def(line):
        return line.strip().startswith('def ')

    def _add_uuid_to_test(self, patcher, test_node, source_path):
        """Queue a new ``@test.idempotent_id`` decorator for ``test_node``.

        The decorator goes in front of the first existing decorator that sorts
        at or after it (compared up to the opening parenthesis), or directly
        above the ``def`` line, so an already sorted decorator list stays
        sorted.
        """
        with open(source_path) as src:
            src_lines = src.read().split('\n')
        # Since Python 3.8 FunctionDef.lineno is the ``def`` line rather than
        # the first decorator; start from the earliest decorator explicitly so
        # existing decorators are considered on every Python version.
        lineno = min([test_node.lineno] +
                     [decorator.lineno
                      for decorator in test_node.decorator_list])
        decorator_prefix = DECORATOR_TEMPLATE.split('(')[0]
        while True:
            line = src_lines[lineno - 1]
            if self._is_def(line):
                break
            if (self._is_decorator(line) and
                    decorator_prefix <= line.strip().split('(')[0]):
                break
            lineno += 1
        patcher.add_patch(
            source_path,
            ' ' * test_node.col_offset + DECORATOR_TEMPLATE % uuid.uuid4(),
            lineno
        )

    @staticmethod
    def _is_test_case(module, node):
        """Return True if ``node`` defines a unittest.TestCase in ``module``."""
        if (node.__class__ is ast.ClassDef and
                hasattr(module, node.name) and
                inspect.isclass(getattr(module, node.name))):
            return issubclass(getattr(module, node.name), unittest.TestCase)
        return False

    @staticmethod
    def _is_test_method(node):
        return (node.__class__ is ast.FunctionDef
                and node.name.startswith('test_'))

    @staticmethod
    def _next_node(body, node):
        """Return the statement after ``node`` in ``body``, or None if last."""
        index = body.index(node)
        if index + 1 < len(body):
            return body[index + 1]
        return None

    @staticmethod
    def _import_name(node):
        """Return the dotted name of the first name imported by ``node``.

        Returns None when ``node`` is not an import statement.
        """
        if isinstance(node, ast.Import):
            return node.names[0].name
        elif isinstance(node, ast.ImportFrom):
            return '%s.%s' % (node.module, node.names[0].name)

    def _line_after_node(self, src_lines, body, node):
        """Return the 1-based line at which to insert text after ``node``.

        Scans forward from ``node.lineno`` and stops at the first blank line,
        at the line of the next top-level statement, or at the last line.
        """
        next_lineno = getattr(self._next_node(body, node), 'lineno', None)
        line_no = node.lineno
        while (src_lines[line_no - 1] and
               line_no != next_lineno and
               line_no != len(src_lines)):
            line_no += 1
        return line_no

    def _add_import_for_test_uuid(self, patcher, src_parsed, source_path):
        """Queue insertion of ``IMPORT_LINE`` into a module's source.

        If the module already imports from tempest, the new import is placed
        in that group in sorted position. Otherwise it starts a new group
        after the last top-level import; it must precede the class bodies
        that use the decorator, or importing the module fails.
        """
        with open(source_path) as f:
            src_lines = f.read().split('\n')
        body = src_parsed.body
        tempest_imports = [node for node in body
                           if self._import_name(node) and
                           'tempest.' in self._import_name(node)]
        if tempest_imports:
            for node in tempest_imports:
                if self._import_name(node) >= DECORATOR_IMPORT:
                    # Insert before the first import that sorts after ours.
                    line_no = node.lineno
                    import_snippet = IMPORT_LINE
                    break
            else:
                line_no = self._line_after_node(src_lines, body,
                                                tempest_imports[-1])
                import_snippet = '\n'.join((IMPORT_LINE, ''))
        else:
            all_imports = [node for node in body
                           if isinstance(node, (ast.Import, ast.ImportFrom))]
            if all_imports:
                line_no = self._line_after_node(src_lines, body,
                                                all_imports[-1])
            else:
                line_no = 1
            # New import group, separated from its neighbours by blank lines.
            import_snippet = '\n'.join(('', IMPORT_LINE, ''))
        patcher.add_patch(source_path, import_snippet, line_no)

    def get_tests(self):
        """Get test methods with sources from base package with metadata.

        :returns: dict keyed by module name; each value holds ``source_path``,
            the parsed ``ast``, ``import_valid`` (whether the module already
            has a ``DECORATOR_MODULE`` module attribute) and ``tests``, a dict
            mapping ``'Class.test_name'`` to the method's FunctionDef node.
        """
        tests = {}
        for module_name in self._modules_search():
            tests[module_name] = {}
            module = importlib.import_module(module_name)
            # Map a possibly compiled module file back to its .py source.
            source_path = '.'.join(
                (os.path.splitext(module.__file__)[0], 'py')
            )
            with open(source_path, 'r') as f:
                source = f.read()
            tests[module_name]['source_path'] = source_path
            tests[module_name]['tests'] = {}
            source_parsed = ast.parse(source)
            tests[module_name]['ast'] = source_parsed
            tests[module_name]['import_valid'] = (
                hasattr(module, DECORATOR_MODULE) and
                inspect.ismodule(getattr(module, DECORATOR_MODULE))
            )
            test_cases = (node for node in source_parsed.body
                          if self._is_test_case(module, node))
            for node in test_cases:
                for subnode in filter(self._is_test_method, node.body):
                    test_name = '%s.%s' % (node.name, subnode.name)
                    tests[module_name]['tests'][test_name] = subnode
        return tests

    @staticmethod
    def _filter_tests(function, tests):
        """Keep tests for which ``function(module_name, test_name, tests)``
        is truthy, preserving the per-module metadata of ``get_tests``.
        """
        result = {}
        for module_name in tests:
            for test_name in tests[module_name]['tests']:
                if function(module_name, test_name, tests):
                    if module_name not in result:
                        result[module_name] = {
                            'ast': tests[module_name]['ast'],
                            'source_path': tests[module_name]['source_path'],
                            'import_valid': tests[module_name]['import_valid'],
                            'tests': {}
                        }
                    result[module_name]['tests'][test_name] = \
                        tests[module_name]['tests'][test_name]
        return result

    def find_untagged(self, tests):
        """Filter all tests without uuid in metadata"""
        def check_uuid_in_meta(module_name, test_name, tests):
            idempotent_id = self._get_idempotent_id(
                tests[module_name]['tests'][test_name])
            return not idempotent_id
        return self._filter_tests(check_uuid_in_meta, tests)

    def report_collisions(self, tests):
        """Reports collisions if there are any. Returns true if
        collisions exist.
        """
        # uuid -> details of the first test seen with that uuid
        uuids = {}

        def report(module_name, test_name, tests):
            test_uuid = self._get_idempotent_id(
                tests[module_name]['tests'][test_name])
            if not test_uuid:
                return
            if test_uuid in uuids:
                error_str = "%s:%s\n uuid %s collision: %s<->%s\n%s:%s\n" % (
                    tests[module_name]['source_path'],
                    tests[module_name]['tests'][test_name].lineno,
                    test_uuid,
                    test_name,
                    uuids[test_uuid]['test_name'],
                    uuids[test_uuid]['source_path'],
                    uuids[test_uuid]['test_node'].lineno,
                )
                print(error_str)
                return True
            else:
                uuids[test_uuid] = {
                    'module': module_name,
                    'test_name': test_name,
                    'test_node': tests[module_name]['tests'][test_name],
                    'source_path': tests[module_name]['source_path']
                }
        return bool(self._filter_tests(report, tests))

    def report_untagged(self, tests):
        """Reports untagged tests if there are any. Returns true if
        untagged tests exist.
        """
        def report(module_name, test_name, tests):
            error_str = "%s:%s\nmissing @test.idempotent_id('...')\n%s\n" % (
                tests[module_name]['source_path'],
                tests[module_name]['tests'][test_name].lineno,
                test_name
            )
            print(error_str)
            return True
        return bool(self._filter_tests(report, tests))

    def fix_tests(self, tests):
        """Add uuids to all tests specified in tests and
        fix it in source files
        """
        patcher = SourcePatcher()
        for module_name in tests:
            module = tests[module_name]
            if not module['import_valid'] and module['tests']:
                # One import per module, queued before its decorators.
                self._add_import_for_test_uuid(
                    patcher, module['ast'], module['source_path'])
            for test_name in module['tests']:
                self._add_uuid_to_test(
                    patcher, module['tests'][test_name],
                    module['source_path'])
        patcher.apply_patches()
327
328
def run():
    """Command-line entry point: check, and optionally fix, idempotent ids.

    Exits via ``sys.exit`` with a hint message when duplicate ids are found,
    or when tests lack ids and ``--fix`` was not given.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--package', action='store', dest='package',
                            default='tempest', type=str,
                            help='Package with tests')
    arg_parser.add_argument('--fix', action='store_true', dest='fix_tests',
                            help='Attempt to fix tests without UUIDs')
    options = arg_parser.parse_args()
    # Put the directory above this script on the import path.
    sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
    checker = TestChecker(importlib.import_module(options.package))
    all_tests = checker.get_tests()
    untagged = checker.find_untagged(all_tests)
    failed = checker.report_collisions(all_tests)
    if options.fix_tests and untagged:
        checker.fix_tests(untagged)
    elif checker.report_untagged(untagged):
        failed = True
    if failed:
        sys.exit("@test.idempotent_id existence and uniqueness checks failed\n"
                 "Run 'tox -v -euuidgen' to automatically fix tests with\n"
                 "missing @test.idempotent_id decorators.")
if __name__ == "__main__":
    # Run the checker only when executed as a script, not when imported.
    run()