blob: 9163a1a00d36d4c23bf735180de27bc253b32b51 [file] [log] [blame]
Benjamin Petersond7b0eeb2009-07-19 20:18:21 +00001"""Loading unittests."""
2
3import os
Michael Foorde91ea562009-09-13 19:07:03 +00004import re
Benjamin Petersond7b0eeb2009-07-19 20:18:21 +00005import sys
Michael Foorde91ea562009-09-13 19:07:03 +00006import traceback
Benjamin Petersond7b0eeb2009-07-19 20:18:21 +00007import types
8
Raymond Hettingerbb006cf2010-04-04 21:45:01 +00009from functools import cmp_to_key as _CmpToKey
Benjamin Petersond7b0eeb2009-07-19 20:18:21 +000010from fnmatch import fnmatch
11
12from . import case, suite
13
# Module-level marker read elsewhere in the unittest package -- presumably to
# recognise this module's frames (e.g. when trimming failure tracebacks).
__unittest = True

# File names that discovery treats as candidate test modules: a Python
# identifier followed by a '.py' extension, compared case-insensitively.
# Compiled files ('.pyc', '.pyo', ...) are deliberately not accepted: doing so
# would require avoiding loading the same tests more than once from the
# '.py', '.pyc' *and* '.pyo' variants of a single module.
VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', re.I)
20
21
def _make_failed_import_test(name, suiteClass):
    """Return a suite holding one test that reports a failed module import.

    Intended to be called from inside an ``except`` clause: the traceback text
    embedded in the error comes from ``traceback.format_exc()``, i.e. the
    exception currently being handled. Running the produced
    'ModuleImportFailure' test raises an ImportError carrying that text.

    :param name: dotted name of the module that could not be imported; it is
        also used as the synthetic test method name.
    :param suiteClass: callable used to wrap the single generated test.
    """
    formatted_tb = traceback.format_exc()
    import_error = ImportError(
        'Failed to import test module: %s\n%s' % (name, formatted_tb))
    return _make_failed_test('ModuleImportFailure', name, import_error,
                             suiteClass)
Michael Foorde91ea562009-09-13 19:07:03 +000026
def _make_failed_load_tests(name, exception, suiteClass):
    """Return a suite with one 'LoadTestsFailure' test that re-raises exception.

    Used when a module's or package's ``load_tests`` function itself raised.

    :param name: module/package name, used as the synthetic test method name.
    :param exception: the exception object raised by ``load_tests``.
    :param suiteClass: callable used to wrap the single generated test.
    """
    failure_class_name = 'LoadTestsFailure'
    return _make_failed_test(failure_class_name, name, exception, suiteClass)
Michael Foord73dbe042010-03-21 00:53:39 +000029
Michael Foord8cb253f2010-03-21 00:55:58 +000030def _make_failed_test(classname, methodname, exception, suiteClass):
Michael Foord73dbe042010-03-21 00:53:39 +000031 def testFailure(self):
32 raise exception
33 attrs = {methodname: testFailure}
34 TestClass = type(classname, (case.TestCase,), attrs)
35 return suiteClass((TestClass(methodname),))
Michael Foorde91ea562009-09-13 19:07:03 +000036
37
class TestLoader(object):
    """
    This class is responsible for loading tests according to various criteria
    and returning them wrapped in a TestSuite
    """
    # Only attributes whose names start with this prefix are collected as test
    # methods by getTestCaseNames().
    testMethodPrefix = 'test'
    # Two-argument comparison function (wrapped with cmp_to_key) used to order
    # test method names; a false value disables sorting.
    sortTestMethodsUsing = cmp
    # Callable used to wrap every collection of loaded tests.
    suiteClass = suite.TestSuite
    # Absolute directory discover() resolves module names against. Set by
    # discover() and reused when discover() is called again without an
    # explicit top_level_dir (e.g. from a package's load_tests).
    _top_level_dir = None

    def loadTestsFromTestCase(self, testCaseClass):
        """Return a suite of all test cases contained in testCaseClass.

        Falls back to a single 'runTest' test when no prefixed test methods
        exist but the class defines runTest.

        :raises TypeError: if testCaseClass derives from suite.TestSuite.
        """
        if issubclass(testCaseClass, suite.TestSuite):
            raise TypeError("Test cases should not be derived from TestSuite."
                            " Maybe you meant to derive from TestCase?")
        testCaseNames = self.getTestCaseNames(testCaseClass)
        if not testCaseNames and hasattr(testCaseClass, 'runTest'):
            testCaseNames = ['runTest']
        loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))
        return loaded_suite

    def loadTestsFromModule(self, module, use_load_tests=True):
        """Return a suite of all test cases contained in the given module.

        Every TestCase subclass found in dir(module) is loaded. If
        use_load_tests is true and the module defines ``load_tests``, it is
        called as ``load_tests(self, tests, None)`` and its result returned
        instead; an exception raised by it is turned into a suite holding a
        single failing 'LoadTestsFailure' test.
        """
        tests = []
        for name in dir(module):
            obj = getattr(module, name)
            if isinstance(obj, type) and issubclass(obj, case.TestCase):
                tests.append(self.loadTestsFromTestCase(obj))

        load_tests = getattr(module, 'load_tests', None)
        tests = self.suiteClass(tests)
        if use_load_tests and load_tests is not None:
            try:
                return load_tests(self, tests, None)
            except Exception as e:
                return _make_failed_load_tests(module.__name__, e,
                                               self.suiteClass)
        return tests

    def loadTestsFromName(self, name, module=None):
        """Return a suite of all test cases given a string specifier.

        The name may resolve either to a module, a test case class, a
        test method within a test case class, or a callable object which
        returns a TestCase or TestSuite instance.

        The method optionally resolves the names relative to a given module.

        :raises ImportError: if no dotted prefix of name can be imported.
        :raises AttributeError: if a remaining part cannot be looked up.
        :raises TypeError: if the resolved object cannot be turned into tests.
        """
        parts = name.split('.')
        if module is None:
            # Import the longest importable dotted prefix; the remaining parts
            # are resolved as attributes below.
            parts_copy = parts[:]
            while parts_copy:
                try:
                    module = __import__('.'.join(parts_copy))
                    break
                except ImportError:
                    del parts_copy[-1]
                    if not parts_copy:
                        raise
            # __import__('a.b.c') returns the top-level package 'a', so the
            # attribute walk starts from the second part.
            parts = parts[1:]
        obj = module
        for part in parts:
            parent, obj = obj, getattr(obj, part)

        if isinstance(obj, types.ModuleType):
            return self.loadTestsFromModule(obj)
        elif isinstance(obj, type) and issubclass(obj, case.TestCase):
            return self.loadTestsFromTestCase(obj)
        elif (isinstance(obj, types.UnboundMethodType) and
              isinstance(parent, type) and
              issubclass(parent, case.TestCase)):
            name = parts[-1]
            inst = parent(name)
            return self.suiteClass([inst])
        elif isinstance(obj, suite.TestSuite):
            return obj
        elif hasattr(obj, '__call__'):
            test = obj()
            if isinstance(test, suite.TestSuite):
                return test
            elif isinstance(test, case.TestCase):
                return self.suiteClass([test])
            else:
                raise TypeError("calling %s returned %s, not a test" %
                                (obj, test))
        else:
            raise TypeError("don't know how to make test from: %s" % obj)

    def loadTestsFromNames(self, names, module=None):
        """Return a suite of all test cases found using the given sequence
        of string specifiers. See 'loadTestsFromName()'.
        """
        suites = [self.loadTestsFromName(name, module) for name in names]
        return self.suiteClass(suites)

    def getTestCaseNames(self, testCaseClass):
        """Return a sorted sequence of method names found within testCaseClass

        A name qualifies when it starts with testMethodPrefix and the
        attribute has a ``__call__``. Sorting uses sortTestMethodsUsing and is
        skipped when that attribute is false.
        """
        def isTestMethod(attrname, testCaseClass=testCaseClass,
                         prefix=self.testMethodPrefix):
            return attrname.startswith(prefix) and \
                hasattr(getattr(testCaseClass, attrname), '__call__')
        testFnNames = filter(isTestMethod, dir(testCaseClass))
        if self.sortTestMethodsUsing:
            testFnNames.sort(key=_CmpToKey(self.sortTestMethodsUsing))
        return testFnNames

    def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
        """Find and return all test modules from the specified start
        directory, recursing into subdirectories to find them. Only test files
        that match the pattern will be loaded. (Using shell style pattern
        matching.)

        All test modules must be importable from the top level of the project.
        If the start directory is not the top level directory then the top
        level directory must be specified separately.

        If a test package name (directory with '__init__.py') matches the
        pattern then the package will be checked for a 'load_tests' function. If
        this exists then it will be called with loader, tests, pattern.

        If load_tests exists then discovery does *not* recurse into the package,
        load_tests is responsible for loading all tests in the package.

        The pattern is deliberately not stored as a loader attribute so that
        packages can continue discovery themselves. top_level_dir is stored so
        load_tests does not need to pass this argument in to loader.discover().

        start_dir may also be a dotted module name. top_level_dir is inserted
        at the front of sys.path if it is not already present.

        :raises ImportError: if start_dir is not importable.
        """
        set_implicit_top = False
        if top_level_dir is None and self._top_level_dir is not None:
            # make top_level_dir optional if called from load_tests in a package
            top_level_dir = self._top_level_dir
        elif top_level_dir is None:
            set_implicit_top = True
            top_level_dir = start_dir

        top_level_dir = os.path.abspath(top_level_dir)

        # Track whether *this* call added top_level_dir to sys.path so that the
        # dotted-name branch below never removes an entry the caller owns.
        inserted_top_level_dir = False
        if top_level_dir not in sys.path:
            # all test modules must be importable from the top level directory
            # should we *unconditionally* put the start directory in first
            # in sys.path to minimise likelihood of conflicts between installed
            # modules and development versions?
            sys.path.insert(0, top_level_dir)
            inserted_top_level_dir = True
        self._top_level_dir = top_level_dir

        is_not_importable = False
        if os.path.isdir(os.path.abspath(start_dir)):
            start_dir = os.path.abspath(start_dir)
            if start_dir != top_level_dir:
                is_not_importable = not os.path.isfile(os.path.join(start_dir, '__init__.py'))
        else:
            # support for discovery from dotted module names
            try:
                __import__(start_dir)
            except ImportError:
                is_not_importable = True
            else:
                the_module = sys.modules[start_dir]
                top_part = start_dir.split('.')[0]
                start_dir = os.path.abspath(os.path.dirname((the_module.__file__)))
                if set_implicit_top:
                    self._top_level_dir = self._get_directory_containing_module(top_part)
                    # The implicit top_level_dir was derived from the dotted
                    # name, not a real directory; undo only our own insertion.
                    if inserted_top_level_dir:
                        sys.path.remove(top_level_dir)

        if is_not_importable:
            raise ImportError('Start directory is not importable: %r' % start_dir)

        tests = list(self._find_tests(start_dir, pattern))
        return self.suiteClass(tests)

    def _get_directory_containing_module(self, module_name):
        """Return the directory that discovery should treat as the top level
        for the already-imported module module_name.

        For a package (its __file__ is an __init__.py* file) this is the
        directory containing the package directory; for a plain module it is
        the module's own directory.
        """
        module = sys.modules[module_name]
        full_path = os.path.abspath(module.__file__)

        if os.path.basename(full_path).lower().startswith('__init__.py'):
            return os.path.dirname(os.path.dirname(full_path))
        else:
            # here we have been given a module rather than a package - so
            # all we can do is search the *same* directory the module is in
            # should an exception be raised instead
            return os.path.dirname(full_path)

    def _get_name_from_path(self, path):
        """Convert a file or package path under _top_level_dir into a dotted
        module name (extension dropped, path separators turned into dots)."""
        path = os.path.splitext(os.path.normpath(path))[0]

        _relpath = os.path.relpath(path, self._top_level_dir)
        assert not os.path.isabs(_relpath), "Path must be within the project"
        assert not _relpath.startswith('..'), "Path must be within the project"

        name = _relpath.replace(os.path.sep, '.')
        return name

    def _get_module_from_name(self, name):
        """Import the dotted module name and return the module object itself
        (``__import__`` alone would return the top-level package)."""
        __import__(name)
        return sys.modules[name]

    def _match_path(self, path, full_path, pattern):
        """Return whether the file name path matches the shell-style pattern.

        Override this method to use an alternative matching strategy. Note
        that _find_tests() consults it for files only; package directories are
        matched with fnmatch directly.
        """
        return fnmatch(path, pattern)

    def _find_tests(self, start_dir, pattern):
        """Used by discovery. Yields test suites it loads.

        Import failures of test modules and of pattern-matching packages are
        yielded as single failing tests instead of aborting discovery.

        :raises ImportError: if a test module is imported from a location
            other than the file that was discovered.
        """
        paths = os.listdir(start_dir)

        for path in paths:
            full_path = os.path.join(start_dir, path)
            if os.path.isfile(full_path):
                if not VALID_MODULE_NAME.match(path):
                    # valid Python identifiers only
                    continue
                if not self._match_path(path, full_path, pattern):
                    continue
                # if the test file matches, load it
                name = self._get_name_from_path(full_path)
                try:
                    module = self._get_module_from_name(name)
                except:
                    # Deliberately broad: whatever the test module raises at
                    # import time is reported as a failing test.
                    yield _make_failed_import_test(name, self.suiteClass)
                else:
                    mod_file = os.path.abspath(getattr(module, '__file__', full_path))
                    realpath = os.path.splitext(os.path.realpath(mod_file))[0]
                    fullpath_noext = os.path.splitext(os.path.realpath(full_path))[0]
                    if realpath.lower() != fullpath_noext.lower():
                        module_dir = os.path.dirname(realpath)
                        mod_name = os.path.splitext(os.path.basename(full_path))[0]
                        expected_dir = os.path.dirname(full_path)
                        msg = ("%r module incorrectly imported from %r. Expected %r. "
                               "Is this module globally installed?")
                        raise ImportError(msg % (mod_name, module_dir, expected_dir))
                    yield self.loadTestsFromModule(module)
            elif os.path.isdir(full_path):
                if not os.path.isfile(os.path.join(full_path, '__init__.py')):
                    continue

                load_tests = None
                tests = None
                if fnmatch(path, pattern):
                    # only check load_tests if the package directory itself matches the filter
                    name = self._get_name_from_path(full_path)
                    try:
                        package = self._get_module_from_name(name)
                    except Exception:
                        # Report a broken package like a broken module above
                        # rather than aborting the whole discovery run. Its
                        # contents cannot be imported either, so don't recurse.
                        yield _make_failed_import_test(name, self.suiteClass)
                        continue
                    load_tests = getattr(package, 'load_tests', None)
                    tests = self.loadTestsFromModule(package, use_load_tests=False)

                if load_tests is None:
                    if tests is not None:
                        # tests loaded from package file
                        yield tests
                    # recurse into the package
                    for test in self._find_tests(full_path, pattern):
                        yield test
                else:
                    try:
                        yield load_tests(self, tests, pattern)
                    except Exception as e:
                        yield _make_failed_load_tests(package.__name__, e,
                                                      self.suiteClass)
Benjamin Petersond7b0eeb2009-07-19 20:18:21 +0000295
# A ready-made TestLoader configured with the class-level defaults
# (testMethodPrefix 'test', cmp ordering, suite.TestSuite).
defaultTestLoader = TestLoader()
297
298
def _makeLoader(prefix, sortUsing, suiteClass=None):
    """Return a new TestLoader configured for the module-level helpers.

    :param prefix: value for the loader's testMethodPrefix.
    :param sortUsing: comparison function for sortTestMethodsUsing.
    :param suiteClass: overrides the loader's suiteClass when truthy;
        otherwise the class default is kept.
    """
    configured = TestLoader()
    configured.testMethodPrefix = prefix
    configured.sortTestMethodsUsing = sortUsing
    if suiteClass:
        configured.suiteClass = suiteClass
    return configured
306
def getTestCaseNames(testCaseClass, prefix, sortUsing=cmp):
    """Return the test method names of testCaseClass that start with prefix,
    ordered with sortUsing (convenience wrapper over
    TestLoader.getTestCaseNames)."""
    loader = _makeLoader(prefix, sortUsing)
    return loader.getTestCaseNames(testCaseClass)
309
def makeSuite(testCaseClass, prefix='test', sortUsing=cmp,
              suiteClass=suite.TestSuite):
    """Return a suite of the tests in testCaseClass, loaded by a TestLoader
    configured with prefix, sortUsing and suiteClass."""
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    return loader.loadTestsFromTestCase(testCaseClass)
313
def findTestCases(module, prefix='test', sortUsing=cmp,
                  suiteClass=suite.TestSuite):
    """Return a suite of the tests found in module, loaded by a TestLoader
    configured with prefix, sortUsing and suiteClass."""
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    return loader.loadTestsFromModule(module)