blob: f00f38d1a163d9cf342800bcb5ee61be5c0828c5 [file] [log] [blame]
Benjamin Petersonbed7d042009-07-19 21:01:52 +00001"""Loading unittests."""
2
3import os
Benjamin Peterson4ac9ce42009-10-04 14:49:41 +00004import re
Benjamin Petersonbed7d042009-07-19 21:01:52 +00005import sys
Benjamin Peterson4ac9ce42009-10-04 14:49:41 +00006import traceback
Benjamin Petersonbed7d042009-07-19 21:01:52 +00007import types
Raymond Hettingerc50846a2010-04-05 18:56:31 +00008import functools
Benjamin Petersonbed7d042009-07-19 21:01:52 +00009
10from fnmatch import fnmatch
11
12from . import case, suite, util
13
# NOTE(review): presumably a marker that lets the unittest result machinery
# leave this module's frames out of reported tracebacks -- confirm in result.py.
__unittest = True

# Discovery only considers files whose stem starts with a letter or an
# underscore, continues with word characters, and ends in '.py' (matched
# case-insensitively). Compiled files (.pyc/.pyo) are intentionally not
# matched: accepting them too could load the same tests more than once.
VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', flags=re.IGNORECASE)
20
21
def _make_failed_import_test(name, suiteClass):
    """Return a suite whose only test reports that module *name* failed to import.

    The traceback of the exception currently being handled is captured with
    traceback.format_exc(), so this is meant to be called from inside an
    ``except`` clause. The generated test class is named
    'ModuleImportFailure'; its single test raises an ImportError whose
    message embeds that traceback text.
    """
    tb_text = traceback.format_exc()
    failure = ImportError('Failed to import test module: %s\n%s' % (name, tb_text))
    return _make_failed_test('ModuleImportFailure', name, failure, suiteClass)
Benjamin Peterson4ac9ce42009-10-04 14:49:41 +000026
def _make_failed_load_tests(name, exception, suiteClass):
    """Return a suite with one test that re-raises *exception*.

    Used when a ``load_tests`` hook itself raised. The synthetic test class
    is named 'LoadTestsFailure' and its single test method is named *name*.
    """
    failure_class_name = 'LoadTestsFailure'
    return _make_failed_test(failure_class_name, name, exception, suiteClass)
29
def _make_failed_test(classname, methodname, exception, suiteClass):
    """Return a *suiteClass* suite holding one synthetic, always-failing test.

    A TestCase subclass named *classname* is created on the fly with a single
    test method called *methodname*. Running that test raises *exception*.
    """
    def testFailure(self):
        raise exception

    failing_case = type(classname, (case.TestCase,), {methodname: testFailure})
    return suiteClass((failing_case(methodname),))
Benjamin Peterson4ac9ce42009-10-04 14:49:41 +000036
37
class TestLoader(object):
    """
    This class is responsible for loading tests according to various criteria
    and returning them wrapped in a TestSuite
    """
    # Method-name prefix that marks a TestCase attribute as a test method.
    testMethodPrefix = 'test'
    # Three-way comparison used to order test method names; a false value
    # disables sorting in getTestCaseNames().
    sortTestMethodsUsing = staticmethod(util.three_way_cmp)
    # Callable used to wrap collections of tests into a suite.
    suiteClass = suite.TestSuite
    # Remembered by discover() so that discover() calls made from a package's
    # load_tests hook can omit top_level_dir.
    _top_level_dir = None

    def loadTestsFromTestCase(self, testCaseClass):
        """Return a suite of all tests cases contained in testCaseClass

        If the class has no prefixed test methods but has a ``runTest``
        attribute, a single 'runTest' test is created.

        Raises TypeError if testCaseClass derives from TestSuite.
        """
        if issubclass(testCaseClass, suite.TestSuite):
            raise TypeError("Test cases should not be derived from TestSuite."
                            " Maybe you meant to derive from TestCase?")
        testCaseNames = self.getTestCaseNames(testCaseClass)
        if not testCaseNames and hasattr(testCaseClass, 'runTest'):
            testCaseNames = ['runTest']
        return self.suiteClass(map(testCaseClass, testCaseNames))

    def loadTestsFromModule(self, module, use_load_tests=True):
        """Return a suite of all tests cases contained in the given module

        Every TestCase subclass found in the module namespace is loaded. If
        use_load_tests is true and the module defines ``load_tests``, the
        hook is called as load_tests(self, tests, None) and its result is
        returned instead. An exception raised by the hook is converted into
        a failing 'LoadTestsFailure' test rather than propagated.
        """
        tests = []
        for name in dir(module):
            obj = getattr(module, name)
            if isinstance(obj, type) and issubclass(obj, case.TestCase):
                tests.append(self.loadTestsFromTestCase(obj))

        load_tests = getattr(module, 'load_tests', None)
        tests = self.suiteClass(tests)
        if use_load_tests and load_tests is not None:
            try:
                return load_tests(self, tests, None)
            except Exception as e:
                return _make_failed_load_tests(module.__name__, e,
                                               self.suiteClass)
        return tests

    def loadTestsFromName(self, name, module=None):
        """Return a suite of all tests cases given a string specifier.

        The name may resolve either to a module, a test case class, a
        test method within a test case class, or a callable object which
        returns a TestCase or TestSuite instance.

        The method optionally resolves the names relative to a given module.

        Raises ImportError when module is None and no dotted prefix of name
        can be imported, AttributeError when a dotted component is missing,
        and TypeError when the resolved object cannot be turned into tests.
        """
        parts = name.split('.')
        if module is None:
            # Import the longest importable dotted prefix. __import__ returns
            # the top-level package, so every part after the first is then
            # resolved by attribute lookup below.
            parts_copy = parts[:]
            while parts_copy:
                try:
                    module = __import__('.'.join(parts_copy))
                    break
                except ImportError:
                    del parts_copy[-1]
                    if not parts_copy:
                        raise
            parts = parts[1:]
        obj = module
        for part in parts:
            parent, obj = obj, getattr(obj, part)

        if isinstance(obj, types.ModuleType):
            return self.loadTestsFromModule(obj)
        elif isinstance(obj, type) and issubclass(obj, case.TestCase):
            return self.loadTestsFromTestCase(obj)
        elif (isinstance(obj, types.FunctionType) and
              isinstance(parent, type) and
              issubclass(parent, case.TestCase)):
            name = obj.__name__
            inst = parent(name)
            # static methods follow a different path: they are still plain
            # functions on the instance, so fall through and call them below
            if not isinstance(getattr(inst, name), types.FunctionType):
                return self.suiteClass([inst])
        elif isinstance(obj, suite.TestSuite):
            return obj
        if hasattr(obj, '__call__'):
            test = obj()
            if isinstance(test, suite.TestSuite):
                return test
            elif isinstance(test, case.TestCase):
                return self.suiteClass([test])
            else:
                raise TypeError("calling %s returned %s, not a test" %
                                (obj, test))
        else:
            raise TypeError("don't know how to make test from: %s" % obj)

    def loadTestsFromNames(self, names, module=None):
        """Return a suite of all tests cases found using the given sequence
        of string specifiers. See 'loadTestsFromName()'.
        """
        suites = [self.loadTestsFromName(name, module) for name in names]
        return self.suiteClass(suites)

    def getTestCaseNames(self, testCaseClass):
        """Return a sorted sequence of method names found within testCaseClass

        A name qualifies when it starts with testMethodPrefix and the
        attribute it names is callable. The names are ordered with
        sortTestMethodsUsing when that is set.
        """
        def isTestMethod(attrname, testCaseClass=testCaseClass,
                         prefix=self.testMethodPrefix):
            return attrname.startswith(prefix) and \
                hasattr(getattr(testCaseClass, attrname), '__call__')
        testFnNames = list(filter(isTestMethod, dir(testCaseClass)))
        if self.sortTestMethodsUsing:
            testFnNames.sort(key=functools.cmp_to_key(self.sortTestMethodsUsing))
        return testFnNames

    def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
        """Find and return all test modules from the specified start
        directory, recursing into subdirectories to find them. Only test files
        that match the pattern will be loaded. (Using shell style pattern
        matching.)

        All test modules must be importable from the top level of the project.
        If the start directory is not the top level directory then the top
        level directory must be specified separately.

        If a test package name (directory with '__init__.py') matches the
        pattern then the package will be checked for a 'load_tests' function. If
        this exists then it will be called with loader, tests, pattern.

        If load_tests exists then discovery does *not* recurse into the package,
        load_tests is responsible for loading all tests in the package.

        The pattern is deliberately not stored as a loader attribute so that
        packages can continue discovery themselves. top_level_dir is stored so
        load_tests does not need to pass this argument in to loader.discover().

        Raises ImportError if start_dir differs from the top level directory
        and contains no '__init__.py'.
        """
        if top_level_dir is None and self._top_level_dir is not None:
            # make top_level_dir optional if called from load_tests in a package
            top_level_dir = self._top_level_dir
        elif top_level_dir is None:
            top_level_dir = start_dir

        top_level_dir = os.path.abspath(os.path.normpath(top_level_dir))
        start_dir = os.path.abspath(os.path.normpath(start_dir))

        if top_level_dir not in sys.path:
            # all test modules must be importable from the top level directory
            sys.path.append(top_level_dir)
        self._top_level_dir = top_level_dir

        if start_dir != top_level_dir and not os.path.isfile(os.path.join(start_dir, '__init__.py')):
            # NOTE: only a source __init__.py is recognised here; a package
            # that ships only __init__.pyc/.pyo is rejected.
            raise ImportError('Start directory is not importable: %r' % start_dir)

        tests = list(self._find_tests(start_dir, pattern))
        return self.suiteClass(tests)

    def _get_name_from_path(self, path):
        """Convert a path below the top level directory into a dotted name.

        The extension is dropped and path separators become dots, e.g.
        '<top>/pkg/test_x.py' -> 'pkg.test_x'.

        Raises AssertionError if path is not inside self._top_level_dir.
        """
        path = os.path.splitext(os.path.normpath(path))[0]

        _relpath = os.path.relpath(path, self._top_level_dir)
        # Raise explicitly instead of using ``assert`` so the check is not
        # stripped when Python runs with -O.
        if os.path.isabs(_relpath) or _relpath.startswith('..'):
            raise AssertionError("Path must be within the project")

        name = _relpath.replace(os.path.sep, '.')
        return name

    def _get_module_from_name(self, name):
        """Import the dotted module name and return that module object.

        __import__('a.b') returns the top-level package 'a', so the fully
        qualified module is looked up in sys.modules afterwards.
        """
        __import__(name)
        return sys.modules[name]

    def _find_tests(self, start_dir, pattern):
        """Used by discovery. Yields test suites it loads.

        Directory entries are visited in sorted order so that the discovered
        suite has a reproducible order.

        Files are loaded when their name matches VALID_MODULE_NAME and
        pattern. A module that fails to import yields a failing test instead.

        Directories are considered only when they contain '__init__.py'. A
        package whose directory name matches pattern is imported; a failed
        package import also yields a failing test. If the package defines
        load_tests, that hook supplies the package's tests and discovery
        does not recurse into it. Otherwise, the package's own tests (if it
        was imported) are yielded and discovery recurses into it.
        """
        # os.listdir() returns entries in arbitrary order; sort them so the
        # result does not depend on the file system.
        for path in sorted(os.listdir(start_dir)):
            full_path = os.path.join(start_dir, path)
            if os.path.isfile(full_path):
                if not VALID_MODULE_NAME.match(path):
                    # valid Python identifiers only
                    continue

                if fnmatch(path, pattern):
                    # if the test file matches, load it
                    name = self._get_name_from_path(full_path)
                    try:
                        module = self._get_module_from_name(name)
                    except:
                        # Deliberately broad: any failure while importing a
                        # test module is reported as a failing test.
                        yield _make_failed_import_test(name, self.suiteClass)
                    else:
                        yield self.loadTestsFromModule(module)
            elif os.path.isdir(full_path):
                if not os.path.isfile(os.path.join(full_path, '__init__.py')):
                    continue

                load_tests = None
                tests = None
                if fnmatch(path, pattern):
                    # only check load_tests if the package directory itself matches the filter
                    name = self._get_name_from_path(full_path)
                    try:
                        package = self._get_module_from_name(name)
                    except:
                        # Report a broken package the same way as a broken
                        # module instead of aborting the whole discovery run.
                        # Its submodules would fail to import too, so do not
                        # recurse into it.
                        yield _make_failed_import_test(name, self.suiteClass)
                        continue
                    load_tests = getattr(package, 'load_tests', None)
                    tests = self.loadTestsFromModule(package, use_load_tests=False)

                if load_tests is None:
                    if tests is not None:
                        # tests loaded from package file
                        yield tests
                    # recurse into the package
                    for test in self._find_tests(full_path, pattern):
                        yield test
                else:
                    try:
                        yield load_tests(self, tests, pattern)
                    except Exception as e:
                        yield _make_failed_load_tests(package.__name__, e,
                                                      self.suiteClass)


# Shared, default-configured loader instance.
defaultTestLoader = TestLoader()
252
253
def _makeLoader(prefix, sortUsing, suiteClass=None):
    """Build a fresh TestLoader with the given prefix and sort function.

    suiteClass overrides the loader's suite class only when it is truthy;
    otherwise the TestLoader default is kept.
    """
    configured = TestLoader()
    configured.testMethodPrefix, configured.sortTestMethodsUsing = prefix, sortUsing
    if suiteClass:
        configured.suiteClass = suiteClass
    return configured
261
def getTestCaseNames(testCaseClass, prefix, sortUsing=util.three_way_cmp):
    """Return the names of testCaseClass's test methods.

    Builds a one-off loader using prefix and sortUsing, then delegates to
    its getTestCaseNames().
    """
    throwaway = _makeLoader(prefix, sortUsing)
    return throwaway.getTestCaseNames(testCaseClass)
264
def makeSuite(testCaseClass, prefix='test', sortUsing=util.three_way_cmp,
              suiteClass=suite.TestSuite):
    """Return a suite of every test in testCaseClass.

    A one-off loader configured with prefix, sortUsing and suiteClass does
    the work via loadTestsFromTestCase().
    """
    throwaway = _makeLoader(prefix, sortUsing, suiteClass)
    return throwaway.loadTestsFromTestCase(testCaseClass)
269
def findTestCases(module, prefix='test', sortUsing=util.three_way_cmp,
                  suiteClass=suite.TestSuite):
    """Return a suite of every test found in module.

    A one-off loader configured with prefix, sortUsing and suiteClass does
    the work via loadTestsFromModule().
    """
    throwaway = _makeLoader(prefix, sortUsing, suiteClass)
    return throwaway.loadTestsFromModule(module)
273 module)