blob: 837da4800909e873e933706ce1b498b3af821155 [file] [log] [blame]
Chris Hoge296558c2015-02-19 00:29:49 -06001# Copyright 2014 Mirantis, Inc.
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may
4# not use this file except in compliance with the License. You may obtain
5# a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations
13# under the License.
14
15import argparse
16import ast
17import importlib
18import inspect
19import os
20import sys
21import unittest
22import urllib
23import uuid
24
# The decorator every test must carry is spelled
#     @<DECORATOR_MODULE>.<DECORATOR_NAME>('<uuid>')
# i.e. @test.idempotent_id('<uuid>').
DECORATOR_MODULE = 'test'
DECORATOR_NAME = 'idempotent_id'
# Dotted name compared against the import names collected from a module's AST.
DECORATOR_IMPORT = 'tempest.' + DECORATOR_MODULE
# Statement inserted into modules that do not yet import the decorator module.
IMPORT_LINE = 'from tempest import ' + DECORATOR_MODULE
# Source line of the decorator; the remaining '%s' receives the UUID.
DECORATOR_TEMPLATE = '@' + DECORATOR_MODULE + '.' + DECORATOR_NAME + "('%s')"
31
32
class SourcePatcher(object):

    """Lazy patcher for python source files.

    Patches are registered with ``add_patch`` and written to disk only by
    ``apply_patches``. Each patch is recorded as a ``str.format``
    placeholder prepended to its target line, so every ``line_no`` refers
    to the file as it was first read, no matter how many patches precede it.
    """

    def __init__(self):
        # filename -> URL-quoted source text with '{<patch_id>:s}' markers.
        self.source_files = None
        # patch_id -> URL-quoted patch text substituted into those markers.
        self.patches = None
        self.clear()

    def clear(self):
        """Clear inner state"""
        self.source_files = {}
        self.patches = {}

    @staticmethod
    def _quote(s):
        """URL-quote ``s``.

        Quoting escapes every '{' and '}' in the source. The patch markers
        are then the only replacement fields ``str.format`` sees.
        """
        try:  # Python 3
            from urllib.parse import quote
        except ImportError:  # Python 2
            from urllib import quote
        return quote(s)

    @staticmethod
    def _unquote(s):
        """Reverse ``_quote``."""
        try:  # Python 3
            from urllib.parse import unquote
        except ImportError:  # Python 2
            from urllib import unquote
        return unquote(s)

    def add_patch(self, filename, patch, line_no):
        """Add lazy patch: insert ``patch`` before line ``line_no``.

        :param filename: source file to patch; it is read on the first
            patch registered for it.
        :param patch: text to insert; a trailing newline is added if missing.
        :param line_no: 1-based line number, in the original file, before
            which ``patch`` is inserted.
        """
        if filename not in self.source_files:
            with open(filename) as f:
                self.source_files[filename] = self._quote(f.read())
        patch_id = str(uuid.uuid4())
        if not patch.endswith('\n'):
            patch += '\n'
        self.patches[patch_id] = self._quote(patch)
        quoted_newline = self._quote('\n')
        lines = self.source_files[filename].split(quoted_newline)
        # Prepend a format placeholder; apply_patches() substitutes the patch.
        lines[line_no - 1] = ''.join(('{%s:s}' % patch_id, lines[line_no - 1]))
        self.source_files[filename] = quoted_newline.join(lines)

    def _save_changes(self, filename, source):
        """Overwrite ``filename`` with ``source`` and report it."""
        print('%s fixed' % filename)
        with open(filename, 'w') as f:
            f.write(source)

    def apply_patches(self):
        """Apply all patches, write every patched file, and reset state."""
        for filename in self.source_files:
            patched_source = self._unquote(
                self.source_files[filename].format(**self.patches)
            )
            self._save_changes(filename, patched_source)
        self.clear()
81
82
83class TestChecker(object):
84
85 def __init__(self, package):
86 self.package = package
87 self.base_path = os.path.abspath(os.path.dirname(package.__file__))
88
89 def _path_to_package(self, path):
90 relative_path = path[len(self.base_path) + 1:]
91 if relative_path:
92 return '.'.join((self.package.__name__,) +
93 tuple(relative_path.split('/')))
94 else:
95 return self.package.__name__
96
97 def _modules_search(self):
98 """Recursive search for python modules in base package"""
99 modules = []
100 for root, dirs, files in os.walk(self.base_path):
101 if not os.path.exists(os.path.join(root, '__init__.py')):
102 continue
103 root_package = self._path_to_package(root)
104 for item in files:
105 if item.endswith('.py'):
106 modules.append('.'.join((root_package,
107 os.path.splitext(item)[0])))
108 return modules
109
110 @staticmethod
111 def _get_idempotent_id(test_node):
112 """
113 Return key-value dict with all metadata from @test.idempotent_id
114 decorators for test method
115 """
116 idempotent_id = None
117 for decorator in test_node.decorator_list:
118 if (hasattr(decorator, 'func') and
119 decorator.func.attr == DECORATOR_NAME and
120 decorator.func.value.id == DECORATOR_MODULE):
121 for arg in decorator.args:
122 idempotent_id = ast.literal_eval(arg)
123 return idempotent_id
124
125 @staticmethod
126 def _is_decorator(line):
127 return line.strip().startswith('@')
128
129 @staticmethod
130 def _is_def(line):
131 return line.strip().startswith('def ')
132
133 def _add_uuid_to_test(self, patcher, test_node, source_path):
134 with open(source_path) as src:
135 src_lines = src.read().split('\n')
136 lineno = test_node.lineno
137 insert_position = lineno
138 while True:
139 if (self._is_def(src_lines[lineno - 1]) or
140 (self._is_decorator(src_lines[lineno - 1]) and
141 (DECORATOR_TEMPLATE.split('(')[0] <=
142 src_lines[lineno - 1].strip().split('(')[0]))):
143 insert_position = lineno
144 break
145 lineno += 1
146 patcher.add_patch(
147 source_path,
148 ' ' * test_node.col_offset + DECORATOR_TEMPLATE % uuid.uuid4(),
149 insert_position
150 )
151
152 @staticmethod
153 def _is_test_case(module, node):
154 if (node.__class__ is ast.ClassDef and
155 hasattr(module, node.name) and
156 inspect.isclass(getattr(module, node.name))):
157 return issubclass(getattr(module, node.name), unittest.TestCase)
158
159 @staticmethod
160 def _is_test_method(node):
161 return (node.__class__ is ast.FunctionDef
162 and node.name.startswith('test_'))
163
164 @staticmethod
165 def _next_node(body, node):
166 if body.index(node) < len(body):
167 return body[body.index(node) + 1]
168
169 @staticmethod
170 def _import_name(node):
171 if type(node) == ast.Import:
172 return node.names[0].name
173 elif type(node) == ast.ImportFrom:
174 return '%s.%s' % (node.module, node.names[0].name)
175
176 def _add_import_for_test_uuid(self, patcher, src_parsed, source_path):
177 with open(source_path) as f:
178 src_lines = f.read().split('\n')
179 line_no = 0
180 tempest_imports = [node for node in src_parsed.body
181 if self._import_name(node) and
182 'tempest.' in self._import_name(node)]
183 if not tempest_imports:
184 import_snippet = '\n'.join(('', IMPORT_LINE, ''))
185 else:
186 for node in tempest_imports:
187 if self._import_name(node) < DECORATOR_IMPORT:
188 continue
189 else:
190 line_no = node.lineno
191 import_snippet = IMPORT_LINE
192 break
193 else:
194 line_no = tempest_imports[-1].lineno
195 while True:
196 if (not src_lines[line_no - 1] or
197 getattr(self._next_node(src_parsed.body,
198 tempest_imports[-1]),
199 'lineno') == line_no or
200 line_no == len(src_lines)):
201 break
202 line_no += 1
203 import_snippet = '\n'.join((IMPORT_LINE, ''))
204 patcher.add_patch(source_path, import_snippet, line_no)
205
206 def get_tests(self):
207 """Get test methods with sources from base package with metadata"""
208 tests = {}
209 for module_name in self._modules_search():
210 tests[module_name] = {}
211 module = importlib.import_module(module_name)
212 source_path = '.'.join(
213 (os.path.splitext(module.__file__)[0], 'py')
214 )
215 with open(source_path, 'r') as f:
216 source = f.read()
217 tests[module_name]['source_path'] = source_path
218 tests[module_name]['tests'] = {}
219 source_parsed = ast.parse(source)
220 tests[module_name]['ast'] = source_parsed
221 tests[module_name]['import_valid'] = (
222 hasattr(module, DECORATOR_MODULE) and
223 inspect.ismodule(getattr(module, DECORATOR_MODULE))
224 )
225 test_cases = (node for node in source_parsed.body
226 if self._is_test_case(module, node))
227 for node in test_cases:
228 for subnode in filter(self._is_test_method, node.body):
229 test_name = '%s.%s' % (node.name, subnode.name)
230 tests[module_name]['tests'][test_name] = subnode
231 return tests
232
233 @staticmethod
234 def _filter_tests(function, tests):
235 """Filter tests with condition 'function(test_node) == True'"""
236 result = {}
237 for module_name in tests:
238 for test_name in tests[module_name]['tests']:
239 if function(module_name, test_name, tests):
240 if module_name not in result:
241 result[module_name] = {
242 'ast': tests[module_name]['ast'],
243 'source_path': tests[module_name]['source_path'],
244 'import_valid': tests[module_name]['import_valid'],
245 'tests': {}
246 }
247 result[module_name]['tests'][test_name] = \
248 tests[module_name]['tests'][test_name]
249 return result
250
251 def find_untagged(self, tests):
252 """Filter all tests without uuid in metadata"""
253 def check_uuid_in_meta(module_name, test_name, tests):
254 idempotent_id = self._get_idempotent_id(
255 tests[module_name]['tests'][test_name])
256 return not idempotent_id
257 return self._filter_tests(check_uuid_in_meta, tests)
258
259 def report_collisions(self, tests):
260 """Reports collisions if there are any. Returns true if
261 collisions exist.
262 """
263 uuids = {}
264
265 def report(module_name, test_name, tests):
266 test_uuid = self._get_idempotent_id(
267 tests[module_name]['tests'][test_name])
268 if not test_uuid:
269 return
270 if test_uuid in uuids:
271 error_str = "%s:%s\n uuid %s collision: %s<->%s\n%s:%s\n" % (
272 tests[module_name]['source_path'],
273 tests[module_name]['tests'][test_name].lineno,
274 test_uuid,
275 test_name,
276 uuids[test_uuid]['test_name'],
277 uuids[test_uuid]['source_path'],
278 uuids[test_uuid]['test_node'].lineno,
279 )
280 print(error_str)
281 return True
282 else:
283 uuids[test_uuid] = {
284 'module': module_name,
285 'test_name': test_name,
286 'test_node': tests[module_name]['tests'][test_name],
287 'source_path': tests[module_name]['source_path']
288 }
289 return bool(self._filter_tests(report, tests))
290
291 def report_untagged(self, tests):
292 """Reports untagged tests if there are any. Returns true if
293 untagged tests exist.
294 """
295 def report(module_name, test_name, tests):
296 error_str = "%s:%s\nmissing @test.idempotent_id('...')\n%s\n" % (
297 tests[module_name]['source_path'],
298 tests[module_name]['tests'][test_name].lineno,
299 test_name
300 )
301 print(error_str)
302 return True
303 return bool(self._filter_tests(report, tests))
304
305 def fix_tests(self, tests):
306 """Add uuids to all tests specified in tests and
307 fix it in source files
308 """
309 patcher = SourcePatcher()
310 for module_name in tests:
311 add_import_once = True
312 for test_name in tests[module_name]['tests']:
313 if not tests[module_name]['import_valid'] and add_import_once:
314 self._add_import_for_test_uuid(
315 patcher,
316 tests[module_name]['ast'],
317 tests[module_name]['source_path']
318 )
319 add_import_once = False
320 self._add_uuid_to_test(
321 patcher, tests[module_name]['tests'][test_name],
322 tests[module_name]['source_path'])
323 patcher.apply_patches()
324
325
def run():
    """Command-line entry point.

    Imports the package named by ``--package`` and reports duplicate
    idempotent ids. It then either reports untagged tests or, with
    ``--fix``, adds ids to them. Exits with an error message if any
    reported check failed.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--package', action='store', dest='package',
                            default='tempest', type=str,
                            help='Package with tests')
    arg_parser.add_argument('--fix', action='store_true', dest='fix_tests',
                            help='Attempt to fix tests without UUIDs')
    options = arg_parser.parse_args()
    # Make the parent directory of this script importable before loading
    # the package.
    sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
    checker = TestChecker(importlib.import_module(options.package))
    all_tests = checker.get_tests()
    untagged = checker.find_untagged(all_tests)
    failed = checker.report_collisions(all_tests)
    if not (options.fix_tests and untagged):
        failed = checker.report_untagged(untagged) or failed
    else:
        checker.fix_tests(untagged)
    if failed:
        sys.exit('@test.idempotent_id existence and uniqueness checks failed')


if __name__ == '__main__':
    # Run the checks only when executed as a script, not when imported.
    run()