blob: ab364002e20f4a9c4ce228145260d1c58d22ab8f [file] [log] [blame]
Benjamin Petersonbed7d042009-07-19 21:01:52 +00001"""Loading unittests."""
2
3import os
Benjamin Peterson4ac9ce42009-10-04 14:49:41 +00004import re
Benjamin Petersonbed7d042009-07-19 21:01:52 +00005import sys
Benjamin Peterson4ac9ce42009-10-04 14:49:41 +00006import traceback
Benjamin Petersonbed7d042009-07-19 21:01:52 +00007import types
Raymond Hettingerc50846a2010-04-05 18:56:31 +00008import functools
Benjamin Petersonbed7d042009-07-19 21:01:52 +00009
10from fnmatch import fnmatch
11
12from . import case, suite, util
13
# Module-level marker flag.
# NOTE(review): presumably read by unittest's result/traceback machinery to
# hide frames belonging to the framework itself -- confirm against the code
# that inspects '__unittest' in frame globals.
__unittest = True

# Only '.py' source files are considered during discovery.  Supporting
# '.pyc' or '.pyo' (etc.) as well would require avoiding loading the same
# tests multiple times from '.py', '.pyc' *and* '.pyo'.
# The pattern accepts a file name made of a letter or underscore followed by
# word characters and ending in '.py', matched case-insensitively.
VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', re.IGNORECASE)
20
21
def _make_failed_import_test(name, suiteClass):
    """Return a suite with one test that reports a failed module import.

    The generated test raises an ImportError whose message contains *name*
    and the text of ``traceback.format_exc()``, so this helper is meant to
    be called while the import exception is still being handled.
    """
    details = traceback.format_exc()
    error = ImportError('Failed to import test module: %s\n%s' % (name, details))
    return _make_failed_test('ModuleImportFailure', name, error, suiteClass)
Benjamin Peterson4ac9ce42009-10-04 14:49:41 +000026
def _make_failed_load_tests(name, exception, suiteClass):
    """Return a suite with one test that raises *exception*.

    Used when a ``load_tests`` hook raised; the generated test case class is
    called 'LoadTestsFailure' and its single test method is called *name*.
    """
    return _make_failed_test(
        classname='LoadTestsFailure',
        methodname=name,
        exception=exception,
        suiteClass=suiteClass,
    )
29
def _make_failed_test(classname, methodname, exception, suiteClass):
    """Return ``suiteClass`` wrapping one synthetic test that raises *exception*.

    A new ``case.TestCase`` subclass named *classname* is created on the fly
    with a single method named *methodname*; an instance bound to that method
    is placed in a one-element tuple and handed to *suiteClass*.
    """
    def testFailure(self):
        raise exception

    # Build the class through type() so the method name can be any string,
    # not just a literal known when this module was written.
    failing_case = type(classname, (case.TestCase,), {methodname: testFailure})
    return suiteClass((failing_case(methodname),))
Benjamin Peterson4ac9ce42009-10-04 14:49:41 +000036
37
class TestLoader(object):
    """
    This class is responsible for loading tests according to various criteria
    and returning them wrapped in a TestSuite
    """
    # Only methods whose name starts with this prefix are collected.
    testMethodPrefix = 'test'
    # Three-way comparison function used to order test method names; a false
    # value disables sorting (see getTestCaseNames).
    sortTestMethodsUsing = staticmethod(util.three_way_cmp)
    # Callable used to build every suite this loader returns.
    suiteClass = suite.TestSuite
    # Remembered by discover() so nested discover() calls can omit it.
    _top_level_dir = None

    def loadTestsFromTestCase(self, testCaseClass):
        """Return a suite of all tests cases contained in testCaseClass.

        If the class defines no methods matching the prefix but has a
        'runTest' attribute, a single 'runTest' test is loaded instead.

        Raises TypeError if testCaseClass derives from TestSuite.
        """
        if issubclass(testCaseClass, suite.TestSuite):
            raise TypeError("Test cases should not be derived from TestSuite." \
                                " Maybe you meant to derive from TestCase?")
        testCaseNames = self.getTestCaseNames(testCaseClass)
        if not testCaseNames and hasattr(testCaseClass, 'runTest'):
            testCaseNames = ['runTest']
        loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))
        return loaded_suite

    def loadTestsFromModule(self, module, use_load_tests=True):
        """Return a suite of all tests cases contained in the given module.

        Every attribute of the module that is a TestCase subclass is loaded.
        If use_load_tests is true and the module defines 'load_tests', that
        hook is called as load_tests(loader, tests, None) and its result is
        returned; an exception raised by the hook is converted into a suite
        holding a single failing test.
        """
        tests = []
        for name in dir(module):
            obj = getattr(module, name)
            if isinstance(obj, type) and issubclass(obj, case.TestCase):
                tests.append(self.loadTestsFromTestCase(obj))

        load_tests = getattr(module, 'load_tests', None)
        tests = self.suiteClass(tests)
        if use_load_tests and load_tests is not None:
            try:
                return load_tests(self, tests, None)
            except Exception as e:
                return _make_failed_load_tests(module.__name__, e,
                                               self.suiteClass)
        return tests

    def loadTestsFromName(self, name, module=None):
        """Return a suite of all tests cases given a string specifier.

        The name may resolve either to a module, a test case class, a
        test method within a test case class, or a callable object which
        returns a TestCase or TestSuite instance.

        The method optionally resolves the names relative to a given module.

        Raises ImportError if no leading portion of the name can be imported,
        AttributeError if a component cannot be found, and TypeError if the
        resolved object cannot be turned into a test.
        """
        parts = name.split('.')
        if module is None:
            # Import the longest importable dotted prefix of the name,
            # dropping one trailing component after each failure.
            parts_copy = parts[:]
            while parts_copy:
                try:
                    module = __import__('.'.join(parts_copy))
                    break
                except ImportError:
                    del parts_copy[-1]
                    if not parts_copy:
                        raise
            # __import__ returns the top-level package, so the remaining
            # components (everything after the first) are resolved below.
            parts = parts[1:]
        obj = module
        for part in parts:
            parent, obj = obj, getattr(obj, part)

        if isinstance(obj, types.ModuleType):
            return self.loadTestsFromModule(obj)
        elif isinstance(obj, type) and issubclass(obj, case.TestCase):
            return self.loadTestsFromTestCase(obj)
        elif (isinstance(obj, types.FunctionType) and
              isinstance(parent, type) and
              issubclass(parent, case.TestCase)):
            name = obj.__name__
            inst = parent(name)
            # static methods follow a different path
            if not isinstance(getattr(inst, name), types.FunctionType):
                return self.suiteClass([inst])
        elif isinstance(obj, suite.TestSuite):
            return obj
        if callable(obj):
            test = obj()
            if isinstance(test, suite.TestSuite):
                return test
            elif isinstance(test, case.TestCase):
                return self.suiteClass([test])
            else:
                raise TypeError("calling %s returned %s, not a test" %
                                (obj, test))
        else:
            raise TypeError("don't know how to make test from: %s" % obj)

    def loadTestsFromNames(self, names, module=None):
        """Return a suite of all tests cases found using the given sequence
        of string specifiers. See 'loadTestsFromName()'.
        """
        suites = [self.loadTestsFromName(name, module) for name in names]
        return self.suiteClass(suites)

    def getTestCaseNames(self, testCaseClass):
        """Return a sorted sequence of method names found within testCaseClass.

        A name is selected when it starts with testMethodPrefix and the
        corresponding attribute is callable.  The result is sorted with
        sortTestMethodsUsing unless that attribute is false.
        """
        def isTestMethod(attrname, testCaseClass=testCaseClass,
                         prefix=self.testMethodPrefix):
            return attrname.startswith(prefix) and \
                callable(getattr(testCaseClass, attrname))
        testFnNames = list(filter(isTestMethod, dir(testCaseClass)))
        if self.sortTestMethodsUsing:
            testFnNames.sort(key=functools.cmp_to_key(self.sortTestMethodsUsing))
        return testFnNames

    def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
        """Find and return all test modules from the specified start
        directory, recursing into subdirectories to find them and return all
        tests found within them. Only test files that match the pattern will
        be loaded. (Using shell style pattern matching.)

        All test modules must be importable from the top level of the project.
        If the start directory is not the top level directory then the top
        level directory must be specified separately.

        If a test package name (directory with '__init__.py') matches the
        pattern then the package will be checked for a 'load_tests' function. If
        this exists then it will be called with loader, tests, pattern.

        If load_tests exists then discovery does *not* recurse into the package,
        load_tests is responsible for loading all tests in the package.

        The pattern is deliberately not stored as a loader attribute so that
        packages can continue discovery themselves. top_level_dir is stored so
        load_tests does not need to pass this argument in to loader.discover().

        start_dir may also be a dotted module name instead of a directory.

        Raises ImportError if the start directory is not importable.
        """
        set_implicit_top = False
        if top_level_dir is None and self._top_level_dir is not None:
            # make top_level_dir optional if called from load_tests in a package
            top_level_dir = self._top_level_dir
        elif top_level_dir is None:
            set_implicit_top = True
            top_level_dir = start_dir

        top_level_dir = os.path.abspath(top_level_dir)

        if top_level_dir not in sys.path:
            # all test modules must be importable from the top level directory
            # should we *unconditionally* put the start directory in first
            # in sys.path to minimise likelihood of conflicts between installed
            # modules and development versions?
            sys.path.insert(0, top_level_dir)
        self._top_level_dir = top_level_dir

        is_not_importable = False
        if os.path.isdir(os.path.abspath(start_dir)):
            start_dir = os.path.abspath(start_dir)
            if start_dir != top_level_dir:
                # a start directory below the top level must be a package
                is_not_importable = not os.path.isfile(os.path.join(start_dir, '__init__.py'))
        else:
            # support for discovery from dotted module names
            try:
                __import__(start_dir)
            except ImportError:
                is_not_importable = True
            else:
                the_module = sys.modules[start_dir]
                top_part = start_dir.split('.')[0]
                start_dir = os.path.abspath(os.path.dirname((the_module.__file__)))
                if set_implicit_top:
                    # the implicit top level was derived from the dotted name,
                    # so replace it with the directory holding the top package
                    # and drop the path entry that was computed above
                    self._top_level_dir = self._get_directory_containing_module(top_part)
                    sys.path.remove(top_level_dir)

        if is_not_importable:
            raise ImportError('Start directory is not importable: %r' % start_dir)

        tests = list(self._find_tests(start_dir, pattern))
        return self.suiteClass(tests)

    def _get_directory_containing_module(self, module_name):
        """Return the directory from which module_name was imported.

        For a package (file name starting with '__init__.py') this is the
        directory containing the package directory; for a plain module it is
        the directory containing the module file.  The module must already be
        present in sys.modules.
        """
        module = sys.modules[module_name]
        full_path = os.path.abspath(module.__file__)

        if os.path.basename(full_path).lower().startswith('__init__.py'):
            return os.path.dirname(os.path.dirname(full_path))
        else:
            # here we have been given a module rather than a package - so
            # all we can do is search the *same* directory the module is in
            # should an exception be raised instead
            return os.path.dirname(full_path)

    def _get_name_from_path(self, path):
        """Convert a file or directory path into a dotted module name.

        The extension is stripped, the path is made relative to
        self._top_level_dir and path separators are replaced by dots.
        """
        path = os.path.splitext(os.path.normpath(path))[0]

        _relpath = os.path.relpath(path, self._top_level_dir)
        assert not os.path.isabs(_relpath), "Path must be within the project"
        assert not _relpath.startswith('..'), "Path must be within the project"

        name = _relpath.replace(os.path.sep, '.')
        return name

    def _get_module_from_name(self, name):
        """Import the dotted module *name* and return the module object."""
        # __import__ returns the top-level package for dotted names, so the
        # requested module is fetched from sys.modules instead.
        __import__(name)
        return sys.modules[name]

    def _match_path(self, path, full_path, pattern):
        """Return true if the file name *path* matches *pattern*."""
        # override this method to use alternative matching strategy
        return fnmatch(path, pattern)

    def _find_tests(self, start_dir, pattern):
        """Used by discovery. Yields test suites it loads.

        Entries of start_dir are visited in sorted name order so that the
        order of discovered tests does not depend on the file system.

        Raises ImportError if a module was imported from a location other
        than the file found during discovery.
        """
        # os.listdir() returns entries in arbitrary order; sort them to make
        # discovery order reproducible across platforms and file systems.
        paths = sorted(os.listdir(start_dir))

        for path in paths:
            full_path = os.path.join(start_dir, path)
            if os.path.isfile(full_path):
                if not VALID_MODULE_NAME.match(path):
                    # valid Python identifiers only
                    continue
                if not self._match_path(path, full_path, pattern):
                    continue
                # if the test file matches, load it
                name = self._get_name_from_path(full_path)
                try:
                    module = self._get_module_from_name(name)
                except:
                    # anything raised while importing the module is reported
                    # as a failing test rather than aborting discovery
                    yield _make_failed_import_test(name, self.suiteClass)
                else:
                    # guard against a same-named module having been imported
                    # from a different location than the discovered file
                    mod_file = os.path.abspath(getattr(module, '__file__', full_path))
                    realpath = os.path.splitext(mod_file)[0]
                    fullpath_noext = os.path.splitext(full_path)[0]
                    if realpath.lower() != fullpath_noext.lower():
                        module_dir = os.path.dirname(realpath)
                        mod_name = os.path.splitext(os.path.basename(full_path))[0]
                        expected_dir = os.path.dirname(full_path)
                        msg = ("%r module incorrectly imported from %r. Expected %r. "
                               "Is this module globally installed?")
                        raise ImportError(msg % (mod_name, module_dir, expected_dir))
                    yield self.loadTestsFromModule(module)
            elif os.path.isdir(full_path):
                if not os.path.isfile(os.path.join(full_path, '__init__.py')):
                    continue

                load_tests = None
                tests = None
                if fnmatch(path, pattern):
                    # only check load_tests if the package directory itself matches the filter
                    name = self._get_name_from_path(full_path)
                    package = self._get_module_from_name(name)
                    load_tests = getattr(package, 'load_tests', None)
                    tests = self.loadTestsFromModule(package, use_load_tests=False)

                if load_tests is None:
                    if tests is not None:
                        # tests loaded from package file
                        yield tests
                    # recurse into the package
                    for test in self._find_tests(full_path, pattern):
                        yield test
                else:
                    try:
                        yield load_tests(self, tests, pattern)
                    except Exception as e:
                        yield _make_failed_load_tests(package.__name__, e,
                                                      self.suiteClass)
Benjamin Petersonbed7d042009-07-19 21:01:52 +0000297
# Module-level TestLoader instance, created once when this module is imported,
# with the class-level default prefix, sort function and suite class.
defaultTestLoader = TestLoader()
299
300
def _makeLoader(prefix, sortUsing, suiteClass=None):
    """Create a fresh TestLoader configured with the given settings.

    *prefix* becomes the loader's test method prefix and *sortUsing* its
    sort function.  *suiteClass* replaces the loader's suite class only when
    it is a true value; otherwise the class default is kept.
    """
    configured = TestLoader()
    configured.testMethodPrefix = prefix
    configured.sortTestMethodsUsing = sortUsing
    if suiteClass:
        configured.suiteClass = suiteClass
    return configured
308
def getTestCaseNames(testCaseClass, prefix, sortUsing=util.three_way_cmp):
    """Return the test method names of *testCaseClass*.

    Function-style wrapper: a temporary loader is built with *prefix* and
    *sortUsing*, and its getTestCaseNames() result is returned.
    """
    loader = _makeLoader(prefix, sortUsing)
    return loader.getTestCaseNames(testCaseClass)
311
def makeSuite(testCaseClass, prefix='test', sortUsing=util.three_way_cmp,
              suiteClass=suite.TestSuite):
    """Return a suite of the tests found in *testCaseClass*.

    Function-style wrapper: a temporary loader is built with *prefix*,
    *sortUsing* and *suiteClass*, and its loadTestsFromTestCase() result is
    returned.
    """
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    return loader.loadTestsFromTestCase(testCaseClass)
316
def findTestCases(module, prefix='test', sortUsing=util.three_way_cmp,
                  suiteClass=suite.TestSuite):
    """Return a suite of the tests found in *module*.

    Function-style wrapper: a temporary loader is built with *prefix*,
    *sortUsing* and *suiteClass*, and its loadTestsFromModule() result is
    returned.
    """
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    return loader.loadTestsFromModule(module)