blob: 360a41ee4a7d40c36a110368f00ada692339dfc4 [file] [log] [blame]
Benjamin Petersond7b0eeb2009-07-19 20:18:21 +00001"""Loading unittests."""
2
3import os
Michael Foorde91ea562009-09-13 19:07:03 +00004import re
Benjamin Petersond7b0eeb2009-07-19 20:18:21 +00005import sys
Michael Foorde91ea562009-09-13 19:07:03 +00006import traceback
Benjamin Petersond7b0eeb2009-07-19 20:18:21 +00007import types
8
Raymond Hettingerbb006cf2010-04-04 21:45:01 +00009from functools import cmp_to_key as _CmpToKey
Benjamin Petersond7b0eeb2009-07-19 20:18:21 +000010from fnmatch import fnmatch
11
12from . import case, suite
13
# Marker flag set at module level; presumably consulted by the unittest
# result machinery to hide this module's frames from failure tracebacks
# (TODO confirm against result.py).
__unittest = True

# Only plain ``.py`` files whose base name is a valid Python identifier are
# treated as test modules.  Compiled ``.pyc``/``.pyo`` files are deliberately
# not matched: accepting them as well would make discovery load the same
# test module several times over.
VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', flags=re.IGNORECASE)
20
21
def _make_failed_import_test(name, suiteClass):
    """Return a suite holding one test that reports a failed module import.

    traceback.format_exc() describes the exception currently being handled,
    so this is meant to be called from inside an ``except`` block.  The
    formatted traceback is embedded in an ImportError that the generated
    ``ModuleImportFailure`` test raises when run.
    """
    details = traceback.format_exc()
    error = ImportError('Failed to import test module: %s\n%s' % (name, details))
    return _make_failed_test('ModuleImportFailure', name, error, suiteClass)
Michael Foorde91ea562009-09-13 19:07:03 +000026
def _make_failed_load_tests(name, exception, suiteClass):
    """Wrap an exception raised by a ``load_tests`` hook in a failing suite.

    The generated test class is called ``LoadTestsFailure``; its single test
    method is named *name* and re-raises *exception* when run.
    """
    return _make_failed_test(
        classname='LoadTestsFailure',
        methodname=name,
        exception=exception,
        suiteClass=suiteClass,
    )
Michael Foord73dbe042010-03-21 00:53:39 +000029
def _make_failed_test(classname, methodname, exception, suiteClass):
    """Synthesize a one-test suite whose test just raises *exception*.

    A TestCase subclass named *classname* is created on the fly with a single
    method stored under *methodname*.  An instance of it for that method is
    wrapped in *suiteClass*.
    """
    def testFailure(self):
        raise exception

    failing_class = type(classname, (case.TestCase,), {methodname: testFailure})
    return suiteClass((failing_class(methodname),))
Michael Foorde91ea562009-09-13 19:07:03 +000036
37
class TestLoader(object):
    """
    This class is responsible for loading tests according to various criteria
    and returning them wrapped in a TestSuite
    """
    # Attribute-name prefix that identifies test methods on a TestCase.
    testMethodPrefix = 'test'
    # Comparison function used to order test method names; a false value
    # disables sorting (see getTestCaseNames).
    sortTestMethodsUsing = cmp
    # Callable used to wrap every collection of tests this loader returns.
    suiteClass = suite.TestSuite
    # Set by discover(); reused when discover() is called again without an
    # explicit top_level_dir (e.g. from a package's load_tests hook).
    _top_level_dir = None

    def loadTestsFromTestCase(self, testCaseClass):
        """Return a suite of all tests cases contained in testCaseClass.

        If the class has no prefixed test methods but defines ``runTest``,
        a single ``runTest`` test is loaded instead.

        Raises TypeError if testCaseClass is a TestSuite subclass.
        """
        if issubclass(testCaseClass, suite.TestSuite):
            raise TypeError("Test cases should not be derived from TestSuite."
                            " Maybe you meant to derive from TestCase?")
        testCaseNames = self.getTestCaseNames(testCaseClass)
        if not testCaseNames and hasattr(testCaseClass, 'runTest'):
            testCaseNames = ['runTest']
        loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))
        return loaded_suite

    def loadTestsFromModule(self, module, use_load_tests=True):
        """Return a suite of all tests cases contained in the given module.

        Every TestCase subclass found among the module's attributes is
        loaded.  If use_load_tests is true and the module defines
        ``load_tests``, the hook is called as load_tests(loader, tests, None)
        and its result is returned instead.  An exception raised by the hook
        is converted into a failing test rather than propagated.
        """
        tests = []
        for name in dir(module):
            obj = getattr(module, name)
            if isinstance(obj, type) and issubclass(obj, case.TestCase):
                tests.append(self.loadTestsFromTestCase(obj))

        load_tests = getattr(module, 'load_tests', None)
        tests = self.suiteClass(tests)
        if use_load_tests and load_tests is not None:
            try:
                return load_tests(self, tests, None)
            except Exception as e:
                return _make_failed_load_tests(module.__name__, e,
                                               self.suiteClass)
        return tests

    def loadTestsFromName(self, name, module=None):
        """Return a suite of all tests cases given a string specifier.

        The name may resolve either to a module, a test case class, a
        test method within a test case class, or a callable object which
        returns a TestCase or TestSuite instance.

        The method optionally resolves the names relative to a given module.

        Raises ImportError if no prefix of name can be imported (when module
        is None), AttributeError if a component cannot be resolved, and
        TypeError if the resolved object cannot be turned into tests.
        """
        parts = name.split('.')
        if module is None:
            # Import the longest importable dotted prefix of name; the
            # remaining components are resolved as attributes below.
            parts_copy = parts[:]
            while parts_copy:
                try:
                    module = __import__('.'.join(parts_copy))
                    break
                except ImportError:
                    del parts_copy[-1]
                    if not parts_copy:
                        raise
            # __import__ returns the top-level package, so attribute lookup
            # starts again from the second component.
            parts = parts[1:]
        obj = module
        for part in parts:
            parent, obj = obj, getattr(obj, part)

        if isinstance(obj, types.ModuleType):
            return self.loadTestsFromModule(obj)
        elif isinstance(obj, type) and issubclass(obj, case.TestCase):
            return self.loadTestsFromTestCase(obj)
        elif (isinstance(obj, types.UnboundMethodType) and
              isinstance(parent, type) and
              issubclass(parent, case.TestCase)):
            # Use the attribute name that was actually looked up on parent,
            # not obj.__name__: for an aliased or decorated method the
            # function's __name__ can differ from (or not exist as) an
            # attribute of parent.
            return self.suiteClass([parent(parts[-1])])
        elif isinstance(obj, suite.TestSuite):
            return obj
        elif hasattr(obj, '__call__'):
            test = obj()
            if isinstance(test, suite.TestSuite):
                return test
            elif isinstance(test, case.TestCase):
                return self.suiteClass([test])
            else:
                raise TypeError("calling %s returned %s, not a test" %
                                (obj, test))
        else:
            raise TypeError("don't know how to make test from: %s" % obj)

    def loadTestsFromNames(self, names, module=None):
        """Return a suite of all tests cases found using the given sequence
        of string specifiers. See 'loadTestsFromName()'.
        """
        suites = [self.loadTestsFromName(name, module) for name in names]
        return self.suiteClass(suites)

    def getTestCaseNames(self, testCaseClass):
        """Return a sorted sequence of method names found within testCaseClass.

        A name qualifies when it starts with testMethodPrefix and the
        attribute it names has a ``__call__``.  The names are sorted with
        sortTestMethodsUsing unless that attribute is false.
        """
        def isTestMethod(attrname, testCaseClass=testCaseClass,
                         prefix=self.testMethodPrefix):
            return attrname.startswith(prefix) and \
                hasattr(getattr(testCaseClass, attrname), '__call__')
        # A list comprehension guarantees a sortable list regardless of
        # whether filter() returns a list or an iterator.
        testFnNames = [attrname for attrname in dir(testCaseClass)
                       if isTestMethod(attrname)]
        if self.sortTestMethodsUsing:
            testFnNames.sort(key=_CmpToKey(self.sortTestMethodsUsing))
        return testFnNames

    def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
        """Find and return all test modules from the specified start
        directory, recursing into subdirectories to find them. Only test files
        that match the pattern will be loaded. (Using shell style pattern
        matching.)

        All test modules must be importable from the top level of the project.
        If the start directory is not the top level directory then the top
        level directory must be specified separately.

        If a test package name (directory with '__init__.py') matches the
        pattern then the package will be checked for a 'load_tests' function. If
        this exists then it will be called with loader, tests, pattern.

        If load_tests exists then discovery does *not* recurse into the package,
        load_tests is responsible for loading all tests in the package.

        The pattern is deliberately not stored as a loader attribute so that
        packages can continue discovery themselves. top_level_dir is stored so
        load_tests does not need to pass this argument in to loader.discover().

        start_dir may also be a dotted module name; discovery then starts in
        the directory containing that module.

        Raises ImportError if start_dir is not importable.
        """
        set_implicit_top = False
        if top_level_dir is None and self._top_level_dir is not None:
            # make top_level_dir optional if called from load_tests in a package
            top_level_dir = self._top_level_dir
        elif top_level_dir is None:
            set_implicit_top = True
            top_level_dir = start_dir

        top_level_dir = os.path.abspath(top_level_dir)

        # Remember whether this call added top_level_dir to sys.path, so the
        # dotted-name branch below only undoes our own change and never
        # removes an entry the caller had already put there.
        added_to_path = False
        if top_level_dir not in sys.path:
            # all test modules must be importable from the top level directory
            sys.path.append(top_level_dir)
            added_to_path = True
        self._top_level_dir = top_level_dir

        is_not_importable = False
        if os.path.isdir(os.path.abspath(start_dir)):
            start_dir = os.path.abspath(start_dir)
            if start_dir != top_level_dir:
                is_not_importable = not os.path.isfile(os.path.join(start_dir, '__init__.py'))
        else:
            # support for discovery from dotted module names
            try:
                __import__(start_dir)
            except ImportError:
                is_not_importable = True
            else:
                the_module = sys.modules[start_dir]
                top_part = start_dir.split('.')[0]
                start_dir = os.path.abspath(os.path.dirname((the_module.__file__)))
                if set_implicit_top:
                    # The implicit top_level_dir was derived from the dotted
                    # name, not from a real directory; replace it with the
                    # directory that contains the top-level package.
                    self._top_level_dir = self._get_directory_containing_module(top_part)
                    if added_to_path:
                        sys.path.remove(top_level_dir)

        if is_not_importable:
            raise ImportError('Start directory is not importable: %r' % start_dir)

        tests = list(self._find_tests(start_dir, pattern))
        return self.suiteClass(tests)

    def _get_directory_containing_module(self, module_name):
        """Return the directory from which module_name (already imported) is
        importable: the parent of a package directory, or the directory of a
        plain module."""
        module = sys.modules[module_name]
        full_path = os.path.abspath(module.__file__)

        # startswith (not ==) so __init__.pyc / __init__.pyo also match.
        if os.path.basename(full_path).lower().startswith('__init__.py'):
            return os.path.dirname(os.path.dirname(full_path))
        else:
            # here we have been given a module rather than a package - so
            # all we can do is search the *same* directory the module is in
            # should an exception be raised instead
            return os.path.dirname(full_path)

    def _get_name_from_path(self, path):
        """Convert a file or package path below _top_level_dir into a dotted
        module name (extension stripped, separators replaced by dots)."""
        path = os.path.splitext(os.path.normpath(path))[0]

        _relpath = os.path.relpath(path, self._top_level_dir)
        # NOTE(review): these asserts are skipped under ``python -O``.
        assert not os.path.isabs(_relpath), "Path must be within the project"
        assert not _relpath.startswith('..'), "Path must be within the project"

        name = _relpath.replace(os.path.sep, '.')
        return name

    def _get_module_from_name(self, name):
        """Import the dotted module name and return the module object itself
        (``__import__`` alone would return the top-level package)."""
        __import__(name)
        return sys.modules[name]

    def _find_tests(self, start_dir, pattern):
        """Used by discovery. Yields test suites it loads.

        Files whose names are valid module names and match pattern are
        imported and loaded.  Directories are descended into only if they
        contain ``__init__.py``.  A package directory that itself matches
        pattern is imported, and its ``load_tests`` hook, when present,
        replaces recursion into the package.  Import failures of modules and
        packages are yielded as failing tests instead of aborting discovery.
        """
        paths = os.listdir(start_dir)

        for path in paths:
            full_path = os.path.join(start_dir, path)
            if os.path.isfile(full_path):
                if not VALID_MODULE_NAME.match(path):
                    # valid Python identifiers only
                    continue

                if fnmatch(path, pattern):
                    # if the test file matches, load it
                    name = self._get_name_from_path(full_path)
                    try:
                        module = self._get_module_from_name(name)
                    except:
                        # Deliberately broad: anything raised while importing
                        # a test module is reported as a failing test.
                        yield _make_failed_import_test(name, self.suiteClass)
                    else:
                        yield self.loadTestsFromModule(module)
            elif os.path.isdir(full_path):
                if not os.path.isfile(os.path.join(full_path, '__init__.py')):
                    continue

                load_tests = None
                tests = None
                if fnmatch(path, pattern):
                    # only check load_tests if the package directory itself matches the filter
                    name = self._get_name_from_path(full_path)
                    try:
                        package = self._get_module_from_name(name)
                    except:
                        # Report a broken package like a broken module file,
                        # and skip recursing into it: its submodules cannot
                        # be imported without the package itself.
                        yield _make_failed_import_test(name, self.suiteClass)
                        continue
                    load_tests = getattr(package, 'load_tests', None)
                    tests = self.loadTestsFromModule(package, use_load_tests=False)

                if load_tests is None:
                    if tests is not None:
                        # tests loaded from package file
                        yield tests
                    # recurse into the package
                    for test in self._find_tests(full_path, pattern):
                        yield test
                else:
                    try:
                        yield load_tests(self, tests, pattern)
                    except Exception as e:
                        yield _make_failed_load_tests(package.__name__, e,
                                                      self.suiteClass)
275 self.suiteClass)
Benjamin Petersond7b0eeb2009-07-19 20:18:21 +0000276
# Shared module-level TestLoader instance with the class defaults
# (prefix 'test', cmp ordering, suite.TestSuite).
defaultTestLoader = TestLoader()
278
279
def _makeLoader(prefix, sortUsing, suiteClass=None):
    """Build a new TestLoader using *prefix* and *sortUsing*.

    *suiteClass* replaces the loader's default suite class only when it is
    truthy; otherwise the class-level default is kept.
    """
    loader = TestLoader()
    loader.testMethodPrefix = prefix
    loader.sortTestMethodsUsing = sortUsing
    if suiteClass:
        loader.suiteClass = suiteClass
    return loader
287
def getTestCaseNames(testCaseClass, prefix, sortUsing=cmp):
    """Legacy helper: return the test method names of *testCaseClass*.

    Names must start with *prefix* and are ordered with *sortUsing*.
    """
    loader = _makeLoader(prefix, sortUsing)
    return loader.getTestCaseNames(testCaseClass)
290
def makeSuite(testCaseClass, prefix='test', sortUsing=cmp,
              suiteClass=suite.TestSuite):
    """Legacy helper: load every test of *testCaseClass* into a suite.

    Builds a loader configured with *prefix*, *sortUsing* and *suiteClass*
    and delegates to its loadTestsFromTestCase().
    """
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    return loader.loadTestsFromTestCase(testCaseClass)
294
def findTestCases(module, prefix='test', sortUsing=cmp,
                  suiteClass=suite.TestSuite):
    """Legacy helper: load every test case found in *module* into a suite.

    Builds a loader configured with *prefix*, *sortUsing* and *suiteClass*
    and delegates to its loadTestsFromModule() (``load_tests`` hooks are
    honoured, as that method's default allows).
    """
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    return loader.loadTestsFromModule(module)