blob: 2077fa3cdb6ab0cbc46c9d0f22bee1eed08fae65 [file] [log] [blame]
Benjamin Petersonbed7d042009-07-19 21:01:52 +00001"""Loading unittests."""
2
3import os
Benjamin Peterson4ac9ce42009-10-04 14:49:41 +00004import re
Benjamin Petersonbed7d042009-07-19 21:01:52 +00005import sys
Benjamin Peterson4ac9ce42009-10-04 14:49:41 +00006import traceback
Benjamin Petersonbed7d042009-07-19 21:01:52 +00007import types
Raymond Hettingerc50846a2010-04-05 18:56:31 +00008import functools
Benjamin Petersonbed7d042009-07-19 21:01:52 +00009
10from fnmatch import fnmatch
11
12from . import case, suite, util
13
# NOTE(review): module-level marker flag; presumably lets unittest recognise
# (and hide) this module's own frames in failure tracebacks -- confirm.
__unittest = True

# Pattern a file name must match to be treated as a test module during
# discovery: an identifier-like stem (ASCII letter or underscore first, then
# word characters) followed by ".py", matched case-insensitively.
# Compiled files (.pyc, .pyo, ...) are deliberately NOT matched: accepting
# them as well would risk loading the same tests multiple times
# (once each from '.py', '.pyc' *and* '.pyo').
VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', re.IGNORECASE)
20
21
def _make_failed_import_test(name, suiteClass):
    """Return a suite with one test that reports a failed import of *name*.

    Intended to be called while an exception is being handled: the text of
    ``traceback.format_exc()`` (the current exception's traceback) is
    embedded in the ImportError the synthetic test raises.
    """
    formatted_tb = traceback.format_exc()
    failure = ImportError('Failed to import test module: %s\n%s'
                          % (name, formatted_tb))
    return _make_failed_test('ModuleImportFailure', name, failure, suiteClass)
Benjamin Peterson4ac9ce42009-10-04 14:49:41 +000026
def _make_failed_load_tests(name, exception, suiteClass):
    """Return a suite with one test that re-raises a ``load_tests`` error.

    *name* (the module name) becomes the synthetic test's method name and
    the synthetic class is called ``LoadTestsFailure``.
    """
    return _make_failed_test(classname='LoadTestsFailure',
                             methodname=name,
                             exception=exception,
                             suiteClass=suiteClass)
29
30def _make_failed_test(classname, methodname, exception, suiteClass):
31 def testFailure(self):
32 raise exception
33 attrs = {methodname: testFailure}
34 TestClass = type(classname, (case.TestCase,), attrs)
35 return suiteClass((TestClass(methodname),))
Benjamin Peterson4ac9ce42009-10-04 14:49:41 +000036
Ezio Melottieae2b382013-03-01 14:47:50 +020037def _make_skipped_test(methodname, exception, suiteClass):
38 @case.skip(str(exception))
39 def testSkipped(self):
40 pass
41 attrs = {methodname: testSkipped}
42 TestClass = type("ModuleSkipped", (case.TestCase,), attrs)
43 return suiteClass((TestClass(methodname),))
44
Michael Foorde01c62c2012-03-13 00:09:54 -070045def _jython_aware_splitext(path):
46 if path.lower().endswith('$py.class'):
47 return path[:-9]
48 return os.path.splitext(path)[0]
49
Benjamin Peterson4ac9ce42009-10-04 14:49:41 +000050
class TestLoader(object):
    """
    This class is responsible for loading tests according to various criteria
    and returning them wrapped in a TestSuite
    """
    # Only attributes whose names start with this prefix are test methods.
    testMethodPrefix = 'test'
    # Three-way comparison used to order test method names (falsy = no sort).
    sortTestMethodsUsing = staticmethod(util.three_way_cmp)
    # Callable used to wrap every collection of loaded tests.
    suiteClass = suite.TestSuite
    # Set by discover(); reused when discover() is called again without an
    # explicit top_level_dir (e.g. from a package's load_tests).
    _top_level_dir = None

    def loadTestsFromTestCase(self, testCaseClass):
        """Return a suite of all tests cases contained in testCaseClass.

        Falls back to a single ``runTest`` test when no prefixed test methods
        exist but the class defines ``runTest``.

        Raises:
            TypeError: if *testCaseClass* is a TestSuite subclass.
        """
        if issubclass(testCaseClass, suite.TestSuite):
            raise TypeError("Test cases should not be derived from TestSuite."
                            " Maybe you meant to derive from TestCase?")
        testCaseNames = self.getTestCaseNames(testCaseClass)
        if not testCaseNames and hasattr(testCaseClass, 'runTest'):
            testCaseNames = ['runTest']
        loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))
        return loaded_suite

    def loadTestsFromModule(self, module, use_load_tests=True):
        """Return a suite of all tests cases contained in the given module.

        Every TestCase subclass found via ``dir(module)`` is loaded. If
        *use_load_tests* is true and the module defines ``load_tests``, that
        hook is called as ``load_tests(self, tests, None)`` and its result is
        returned; an exception from the hook is turned into a failing
        ``LoadTestsFailure`` test rather than propagated.
        """
        tests = []
        for name in dir(module):
            obj = getattr(module, name)
            if isinstance(obj, type) and issubclass(obj, case.TestCase):
                tests.append(self.loadTestsFromTestCase(obj))

        load_tests = getattr(module, 'load_tests', None)
        tests = self.suiteClass(tests)
        if use_load_tests and load_tests is not None:
            try:
                return load_tests(self, tests, None)
            except Exception as e:
                return _make_failed_load_tests(module.__name__, e,
                                               self.suiteClass)
        return tests

    def loadTestsFromName(self, name, module=None):
        """Return a suite of all tests cases given a string specifier.

        The name may resolve either to a module, a test case class, a
        test method within a test case class, or a callable object which
        returns a TestCase or TestSuite instance.

        The method optionally resolves the names relative to a given module.

        Raises:
            ImportError: if no dotted prefix of *name* can be imported
                (only when *module* is None).
            AttributeError: if a component of *name* does not exist.
            TypeError: if the resolved object cannot be turned into a test.
        """
        parts = name.split('.')
        if module is None:
            # Import the longest importable dotted prefix of the name; the
            # remaining components are resolved as attributes below.
            parts_copy = parts[:]
            while parts_copy:
                try:
                    module = __import__('.'.join(parts_copy))
                    break
                except ImportError:
                    del parts_copy[-1]
                    if not parts_copy:
                        raise
            # __import__ returns the top-level package, so walk from there.
            parts = parts[1:]
        obj = module
        for part in parts:
            parent, obj = obj, getattr(obj, part)

        if isinstance(obj, types.ModuleType):
            return self.loadTestsFromModule(obj)
        elif isinstance(obj, type) and issubclass(obj, case.TestCase):
            return self.loadTestsFromTestCase(obj)
        elif (isinstance(obj, types.FunctionType) and
              isinstance(parent, type) and
              issubclass(parent, case.TestCase)):
            name = obj.__name__
            inst = parent(name)
            # static methods follow a different path: they stay plain
            # functions on the instance and fall through to the callable
            # handling below.
            if not isinstance(getattr(inst, name), types.FunctionType):
                return self.suiteClass([inst])
        elif isinstance(obj, suite.TestSuite):
            return obj
        if callable(obj):
            test = obj()
            if isinstance(test, suite.TestSuite):
                return test
            elif isinstance(test, case.TestCase):
                return self.suiteClass([test])
            else:
                raise TypeError("calling %s returned %s, not a test" %
                                (obj, test))
        else:
            raise TypeError("don't know how to make test from: %s" % obj)

    def loadTestsFromNames(self, names, module=None):
        """Return a suite of all tests cases found using the given sequence
        of string specifiers. See 'loadTestsFromName()'.
        """
        suites = [self.loadTestsFromName(name, module) for name in names]
        return self.suiteClass(suites)

    def getTestCaseNames(self, testCaseClass):
        """Return a sorted sequence of method names found within testCaseClass.

        A name qualifies when it starts with ``testMethodPrefix`` and the
        attribute is callable. Names are ordered with
        ``sortTestMethodsUsing`` unless that is falsy.
        """
        def isTestMethod(attrname, testCaseClass=testCaseClass,
                         prefix=self.testMethodPrefix):
            return attrname.startswith(prefix) and \
                callable(getattr(testCaseClass, attrname))
        testFnNames = list(filter(isTestMethod, dir(testCaseClass)))
        if self.sortTestMethodsUsing:
            testFnNames.sort(key=functools.cmp_to_key(self.sortTestMethodsUsing))
        return testFnNames

    def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
        """Find and return all test modules from the specified start
        directory, recursing into subdirectories to find them and return all
        tests found within them. Only test files that match the pattern will
        be loaded. (Using shell style pattern matching.)

        All test modules must be importable from the top level of the project.
        If the start directory is not the top level directory then the top
        level directory must be specified separately.

        If a test package name (directory with '__init__.py') matches the
        pattern then the package will be checked for a 'load_tests' function. If
        this exists then it will be called with loader, tests, pattern.

        If load_tests exists then discovery does *not* recurse into the package,
        load_tests is responsible for loading all tests in the package.

        Modules or packages that fail to import are reported as failing (or,
        for SkipTest, skipped) placeholder tests instead of aborting discovery.
        Directory entries are visited in sorted order.

        The pattern is deliberately not stored as a loader attribute so that
        packages can continue discovery themselves. top_level_dir is stored so
        load_tests does not need to pass this argument in to loader.discover().

        Raises:
            ImportError: if the start directory (or dotted start module) is
                not importable.
        """
        set_implicit_top = False
        if top_level_dir is None and self._top_level_dir is not None:
            # make top_level_dir optional if called from load_tests in a package
            top_level_dir = self._top_level_dir
        elif top_level_dir is None:
            set_implicit_top = True
            top_level_dir = start_dir

        top_level_dir = os.path.abspath(top_level_dir)

        # Remember whether *we* added the entry, so that only our own
        # insertion is undone below (never a caller's pre-existing entry).
        inserted_top_level_dir = False
        if top_level_dir not in sys.path:
            # all test modules must be importable from the top level directory
            # should we *unconditionally* put the start directory in first
            # in sys.path to minimise likelihood of conflicts between installed
            # modules and development versions?
            sys.path.insert(0, top_level_dir)
            inserted_top_level_dir = True
        self._top_level_dir = top_level_dir

        is_not_importable = False
        if os.path.isdir(os.path.abspath(start_dir)):
            start_dir = os.path.abspath(start_dir)
            if start_dir != top_level_dir:
                is_not_importable = not os.path.isfile(
                    os.path.join(start_dir, '__init__.py'))
        else:
            # support for discovery from dotted module names
            try:
                __import__(start_dir)
            except ImportError:
                is_not_importable = True
            else:
                the_module = sys.modules[start_dir]
                top_part = start_dir.split('.')[0]
                start_dir = os.path.abspath(os.path.dirname(the_module.__file__))
                if set_implicit_top:
                    self._top_level_dir = self._get_directory_containing_module(top_part)
                    # top_level_dir was derived from the dotted name, not a
                    # real directory; drop it again if we added it.
                    if inserted_top_level_dir:
                        sys.path.remove(top_level_dir)

        if is_not_importable:
            raise ImportError('Start directory is not importable: %r' % start_dir)

        tests = list(self._find_tests(start_dir, pattern))
        return self.suiteClass(tests)

    def _get_directory_containing_module(self, module_name):
        """Return the directory to use as top level for an imported module.

        For a package (``__file__`` basename starting with ``__init__.py``,
        which also covers e.g. ``__init__.pyc``) this is the directory that
        contains the package directory; for a plain module it is the module's
        own directory.
        """
        module = sys.modules[module_name]
        full_path = os.path.abspath(module.__file__)

        if os.path.basename(full_path).lower().startswith('__init__.py'):
            return os.path.dirname(os.path.dirname(full_path))
        else:
            # here we have been given a module rather than a package - so
            # all we can do is search the *same* directory the module is in
            # should an exception be raised instead
            return os.path.dirname(full_path)

    def _get_name_from_path(self, path):
        """Convert a file path under the top level dir to a dotted name."""
        path = _jython_aware_splitext(os.path.normpath(path))

        _relpath = os.path.relpath(path, self._top_level_dir)
        assert not os.path.isabs(_relpath), "Path must be within the project"
        assert not _relpath.startswith('..'), "Path must be within the project"

        name = _relpath.replace(os.path.sep, '.')
        return name

    def _get_module_from_name(self, name):
        """Import the dotted module *name* and return the module object."""
        __import__(name)
        return sys.modules[name]

    def _match_path(self, path, full_path, pattern):
        """Return True if the file name *path* matches the discovery pattern."""
        # override this method to use alternative matching strategy
        return fnmatch(path, pattern)

    def _find_tests(self, start_dir, pattern):
        """Used by discovery. Yields test suites it loads.

        Entries of *start_dir* are visited in sorted order so the result does
        not depend on the order ``os.listdir`` happens to return. A module or
        package raising SkipTest on import yields a skipped placeholder; any
        other exception during its import yields a failing placeholder.
        """
        paths = sorted(os.listdir(start_dir))

        for path in paths:
            full_path = os.path.join(start_dir, path)
            if os.path.isfile(full_path):
                if not VALID_MODULE_NAME.match(path):
                    # valid Python identifiers only
                    continue
                if not self._match_path(path, full_path, pattern):
                    continue
                # if the test file matches, load it
                name = self._get_name_from_path(full_path)
                try:
                    module = self._get_module_from_name(name)
                except case.SkipTest as e:
                    yield _make_skipped_test(name, e, self.suiteClass)
                except:
                    # NOTE(review): the bare except also traps SystemExit /
                    # KeyboardInterrupt raised at import time; presumably
                    # deliberate so one module cannot abort discovery.
                    yield _make_failed_import_test(name, self.suiteClass)
                else:
                    # Guard against picking up a same-named module from
                    # somewhere else on sys.path.
                    mod_file = os.path.abspath(getattr(module, '__file__', full_path))
                    realpath = _jython_aware_splitext(mod_file)
                    fullpath_noext = _jython_aware_splitext(full_path)
                    if realpath.lower() != fullpath_noext.lower():
                        module_dir = os.path.dirname(realpath)
                        mod_name = _jython_aware_splitext(os.path.basename(full_path))
                        expected_dir = os.path.dirname(full_path)
                        msg = ("%r module incorrectly imported from %r. Expected %r. "
                               "Is this module globally installed?")
                        raise ImportError(msg % (mod_name, module_dir, expected_dir))
                    yield self.loadTestsFromModule(module)
            elif os.path.isdir(full_path):
                if not os.path.isfile(os.path.join(full_path, '__init__.py')):
                    continue

                load_tests = None
                tests = None
                if fnmatch(path, pattern):
                    # only check load_tests if the package directory itself matches the filter
                    name = self._get_name_from_path(full_path)
                    try:
                        package = self._get_module_from_name(name)
                    except case.SkipTest as e:
                        # Skipped package: report it and do not recurse.
                        yield _make_skipped_test(name, e, self.suiteClass)
                        continue
                    except:
                        # Broken package: report it like a broken module and
                        # do not recurse (its submodules would fail too).
                        yield _make_failed_import_test(name, self.suiteClass)
                        continue
                    load_tests = getattr(package, 'load_tests', None)
                    tests = self.loadTestsFromModule(package, use_load_tests=False)

                if load_tests is None:
                    if tests is not None:
                        # tests loaded from package file
                        yield tests
                    # recurse into the package
                    yield from self._find_tests(full_path, pattern)
                else:
                    try:
                        yield load_tests(self, tests, pattern)
                    except Exception as e:
                        yield _make_failed_load_tests(package.__name__, e,
                                                      self.suiteClass)
Benjamin Petersonbed7d042009-07-19 21:01:52 +0000311
# Module-level TestLoader instance using the class defaults (method prefix
# 'test', util.three_way_cmp ordering, suite.TestSuite as the suite class).
defaultTestLoader = TestLoader()
313
314
315def _makeLoader(prefix, sortUsing, suiteClass=None):
316 loader = TestLoader()
317 loader.sortTestMethodsUsing = sortUsing
318 loader.testMethodPrefix = prefix
319 if suiteClass:
320 loader.suiteClass = suiteClass
321 return loader
322
def getTestCaseNames(testCaseClass, prefix, sortUsing=util.three_way_cmp):
    """Return the test method names of *testCaseClass* matching *prefix*.

    Names are ordered with the three-way comparison *sortUsing*.
    """
    configured_loader = _makeLoader(prefix, sortUsing)
    return configured_loader.getTestCaseNames(testCaseClass)
325
def makeSuite(testCaseClass, prefix='test', sortUsing=util.three_way_cmp,
              suiteClass=suite.TestSuite):
    """Return a *suiteClass* holding every *prefix* test of *testCaseClass*."""
    configured_loader = _makeLoader(prefix, sortUsing, suiteClass)
    return configured_loader.loadTestsFromTestCase(testCaseClass)
330
def findTestCases(module, prefix='test', sortUsing=util.three_way_cmp,
                  suiteClass=suite.TestSuite):
    """Return a *suiteClass* holding every *prefix* test found in *module*."""
    configured_loader = _makeLoader(prefix, sortUsing, suiteClass)
    return configured_loader.loadTestsFromModule(module)