blob: 5d11b6e8ffd1c8f9c9bcafac9981946f9e7ab6e8 [file] [log] [blame]
Benjamin Petersonbed7d042009-07-19 21:01:52 +00001"""Loading unittests."""
2
3import os
Benjamin Peterson4ac9ce42009-10-04 14:49:41 +00004import re
Benjamin Petersonbed7d042009-07-19 21:01:52 +00005import sys
Benjamin Peterson4ac9ce42009-10-04 14:49:41 +00006import traceback
Benjamin Petersonbed7d042009-07-19 21:01:52 +00007import types
8
9from fnmatch import fnmatch
10
11from . import case, suite, util
12
# Module-level marker flag; presumably consulted elsewhere in the unittest
# package to recognise this module's frames (not visible here -- confirm).
__unittest = True

# File names eligible for discovery: plain ``.py`` sources whose stem starts
# with a letter or underscore (case-insensitive).  Compiled files such as
# .pyc/.pyo are intentionally not matched; supporting them would require
# avoiding loading the same tests once per file extension.
VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', flags=re.IGNORECASE)
19
20
def _make_failed_import_test(name, suiteClass):
    """Build a suite holding one test that reports a failed module import.

    Must be called while the import exception is being handled: the text
    of the current traceback (``traceback.format_exc()``) is embedded in
    the message of the ImportError that the generated test raises.  The
    test class is named 'ModuleImportFailure' and its method is *name*.
    """
    details = traceback.format_exc()
    message = 'Failed to import test module: {0}\n{1}'.format(name, details)
    failure = ImportError(message)
    return _make_failed_test('ModuleImportFailure', name, failure, suiteClass)
Benjamin Peterson4ac9ce42009-10-04 14:49:41 +000025
def _make_failed_load_tests(name, exception, suiteClass):
    """Build a suite whose single test re-raises *exception*.

    Used when a module's or package's ``load_tests`` function raised.  The
    generated test class is named 'LoadTestsFailure' and its method *name*.
    """
    classname = 'LoadTestsFailure'
    return _make_failed_test(classname, name, exception, suiteClass)
28
def _make_failed_test(classname, methodname, exception, suiteClass):
    """Return a suite containing one synthetic test that raises *exception*.

    A ``case.TestCase`` subclass called *classname* is created on the fly
    with a single method called *methodname*; running that test re-raises
    the stored exception.  The suite is built by calling *suiteClass* with
    a one-element tuple holding the test instance.
    """
    def testFailure(self):
        raise exception
    FailureCase = type(classname, (case.TestCase,), {methodname: testFailure})
    failing_test = FailureCase(methodname)
    return suiteClass((failing_test,))
Benjamin Peterson4ac9ce42009-10-04 14:49:41 +000035
36
class TestLoader(object):
    """
    This class is responsible for loading tests according to various criteria
    and returning them wrapped in a TestSuite
    """
    # Attributes whose names start with this prefix (and are callable) are
    # collected as test methods by getTestCaseNames().
    testMethodPrefix = 'test'
    # Three-way comparison used to order test method names; a false value
    # disables sorting (see getTestCaseNames).
    sortTestMethodsUsing = staticmethod(util.three_way_cmp)
    # Callable used to wrap collections of tests into a suite.
    suiteClass = suite.TestSuite
    # Remembered by discover() so that nested discover() calls (e.g. from a
    # package's load_tests) may omit top_level_dir.
    _top_level_dir = None

    def loadTestsFromTestCase(self, testCaseClass):
        """Return a suite of all tests cases contained in testCaseClass

        Raises TypeError if testCaseClass derives from TestSuite.  If no
        prefixed test methods exist but the class defines ``runTest``, a
        single test running ``runTest`` is produced.
        """
        if issubclass(testCaseClass, suite.TestSuite):
            raise TypeError("Test cases should not be derived from TestSuite."
                            " Maybe you meant to derive from TestCase?")
        testCaseNames = self.getTestCaseNames(testCaseClass)
        if not testCaseNames and hasattr(testCaseClass, 'runTest'):
            testCaseNames = ['runTest']
        loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))
        return loaded_suite

    def loadTestsFromModule(self, module, use_load_tests=True):
        """Return a suite of all tests cases contained in the given module

        Every TestCase subclass found among the module's attributes is
        loaded.  If use_load_tests is true and the module defines
        ``load_tests``, it is called as ``load_tests(self, tests, None)`` and
        its result is returned instead; an exception raised by it is turned
        into a single failing 'LoadTestsFailure' test.
        """
        tests = []
        for name in dir(module):
            obj = getattr(module, name)
            if isinstance(obj, type) and issubclass(obj, case.TestCase):
                tests.append(self.loadTestsFromTestCase(obj))

        load_tests = getattr(module, 'load_tests', None)
        tests = self.suiteClass(tests)
        if use_load_tests and load_tests is not None:
            try:
                return load_tests(self, tests, None)
            except Exception as e:
                return _make_failed_load_tests(module.__name__, e,
                                               self.suiteClass)
        return tests

    def loadTestsFromName(self, name, module=None):
        """Return a suite of all tests cases given a string specifier.

        The name may resolve either to a module, a test case class, a
        test method within a test case class, or a callable object which
        returns a TestCase or TestSuite instance.

        The method optionally resolves the names relative to a given module.

        Raises ImportError when module is None and no dotted prefix of name
        can be imported, AttributeError when a component cannot be looked
        up, and TypeError when the resolved object is not usable as a test.
        """
        parts = name.split('.')
        if module is None:
            # Import the longest importable dotted prefix; the remaining
            # components are resolved as attributes below.
            parts_copy = parts[:]
            while parts_copy:
                try:
                    module = __import__('.'.join(parts_copy))
                    break
                except ImportError:
                    del parts_copy[-1]
                    if not parts_copy:
                        raise
            # __import__ returns the top-level package for a dotted name, so
            # attribute resolution restarts after the first component.
            parts = parts[1:]
        obj = module
        parent = None
        for part in parts:
            parent, obj = obj, getattr(obj, part)

        if isinstance(obj, types.ModuleType):
            return self.loadTestsFromModule(obj)
        elif isinstance(obj, type) and issubclass(obj, case.TestCase):
            return self.loadTestsFromTestCase(obj)
        elif (isinstance(obj, types.FunctionType) and
              isinstance(parent, type) and
              issubclass(parent, case.TestCase)):
            name = obj.__name__
            inst = parent(name)
            # static methods follow a different path: accessed through the
            # instance they are still plain functions, so they fall through
            # to the generic callable handling below.
            if not isinstance(getattr(inst, name), types.FunctionType):
                return self.suiteClass([inst])
        elif isinstance(obj, suite.TestSuite):
            return obj
        if hasattr(obj, '__call__'):
            test = obj()
            if isinstance(test, suite.TestSuite):
                return test
            elif isinstance(test, case.TestCase):
                return self.suiteClass([test])
            else:
                raise TypeError("calling %s returned %s, not a test" %
                                (obj, test))
        else:
            raise TypeError("don't know how to make test from: %s" % obj)

    def loadTestsFromNames(self, names, module=None):
        """Return a suite of all tests cases found using the given sequence
        of string specifiers. See 'loadTestsFromName()'.
        """
        suites = [self.loadTestsFromName(name, module) for name in names]
        return self.suiteClass(suites)

    def getTestCaseNames(self, testCaseClass):
        """Return a sorted sequence of method names found within testCaseClass

        A name qualifies when it starts with testMethodPrefix and the
        attribute has ``__call__``.  Sorting uses sortTestMethodsUsing and
        is skipped when that attribute is false.
        """
        def isTestMethod(attrname, testCaseClass=testCaseClass,
                         prefix=self.testMethodPrefix):
            return attrname.startswith(prefix) and \
                hasattr(getattr(testCaseClass, attrname), '__call__')
        testFnNames = list(filter(isTestMethod, dir(testCaseClass)))
        if self.sortTestMethodsUsing:
            testFnNames.sort(key=util.CmpToKey(self.sortTestMethodsUsing))
        return testFnNames

    def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
        """Find and return all test modules from the specified start
        directory, recursing into subdirectories to find them. Only test files
        that match the pattern will be loaded. (Using shell style pattern
        matching.)

        All test modules must be importable from the top level of the project.
        If the start directory is not the top level directory then the top
        level directory must be specified separately.

        If a test package name (directory with '__init__.py') matches the
        pattern then the package will be checked for a 'load_tests' function. If
        this exists then it will be called with loader, tests, pattern.

        If load_tests exists then discovery does *not* recurse into the package,
        load_tests is responsible for loading all tests in the package.

        The pattern is deliberately not stored as a loader attribute so that
        packages can continue discovery themselves. top_level_dir is stored so
        load_tests does not need to pass this argument in to loader.discover().

        Side effect: the top level directory is appended to sys.path if it
        is not already present.  Raises ImportError if start_dir differs
        from the top level directory and contains no '__init__.py'.
        """
        if top_level_dir is None and self._top_level_dir is not None:
            # make top_level_dir optional if called from load_tests in a package
            top_level_dir = self._top_level_dir
        elif top_level_dir is None:
            top_level_dir = start_dir

        top_level_dir = os.path.abspath(os.path.normpath(top_level_dir))
        start_dir = os.path.abspath(os.path.normpath(start_dir))

        if top_level_dir not in sys.path:
            # all test modules must be importable from the top level directory
            sys.path.append(top_level_dir)
        self._top_level_dir = top_level_dir

        if start_dir != top_level_dir and not os.path.isfile(os.path.join(start_dir, '__init__.py')):
            # what about __init__.pyc or pyo (etc)
            raise ImportError('Start directory is not importable: %r' % start_dir)

        tests = list(self._find_tests(start_dir, pattern))
        return self.suiteClass(tests)

    def _get_name_from_path(self, path):
        """Convert a file/directory path under the top level directory into
        a dotted module name (extension stripped, separators -> dots)."""
        path = os.path.splitext(os.path.normpath(path))[0]

        _relpath = os.path.relpath(path, self._top_level_dir)
        assert not os.path.isabs(_relpath), "Path must be within the project"
        assert not _relpath.startswith('..'), "Path must be within the project"

        name = _relpath.replace(os.path.sep, '.')
        return name

    def _get_module_from_name(self, name):
        """Import the dotted module name and return that module object."""
        __import__(name)
        # __import__ returns the top-level package for dotted names, so the
        # actual submodule is fetched from sys.modules.
        return sys.modules[name]

    def _find_tests(self, start_dir, pattern):
        """Used by discovery. Yields test suites it loads."""
        # Sort so that discovery (and therefore test run) order does not
        # depend on the arbitrary order in which os.listdir returns entries.
        paths = sorted(os.listdir(start_dir))

        for path in paths:
            full_path = os.path.join(start_dir, path)
            if os.path.isfile(full_path):
                if not VALID_MODULE_NAME.match(path):
                    # valid Python identifiers only
                    continue

                if fnmatch(path, pattern):
                    # if the test file matches, load it
                    name = self._get_name_from_path(full_path)
                    try:
                        module = self._get_module_from_name(name)
                    except:
                        # Deliberately broad: any failure while importing a
                        # test module is reported as a failing test instead
                        # of aborting the whole discovery run.
                        yield _make_failed_import_test(name, self.suiteClass)
                    else:
                        yield self.loadTestsFromModule(module)
            elif os.path.isdir(full_path):
                if not os.path.isfile(os.path.join(full_path, '__init__.py')):
                    continue

                load_tests = None
                tests = None
                if fnmatch(path, pattern):
                    # only check load_tests if the package directory itself matches the filter
                    name = self._get_name_from_path(full_path)
                    try:
                        package = self._get_module_from_name(name)
                    except:
                        # Report a broken package exactly like a broken test
                        # module rather than letting the error escape and
                        # abort discovery.  Its submodules cannot be imported
                        # either, so do not recurse into it.
                        yield _make_failed_import_test(name, self.suiteClass)
                        continue
                    load_tests = getattr(package, 'load_tests', None)
                    tests = self.loadTestsFromModule(package, use_load_tests=False)

                if load_tests is None:
                    if tests is not None:
                        # tests loaded from package file
                        yield tests
                    # recurse into the package
                    for test in self._find_tests(full_path, pattern):
                        yield test
                else:
                    try:
                        yield load_tests(self, tests, pattern)
                    except Exception as e:
                        yield _make_failed_load_tests(package.__name__, e,
                                                      self.suiteClass)
248 self.suiteClass)
Benjamin Petersonbed7d042009-07-19 21:01:52 +0000249
# Module-level TestLoader instance using the class-level defaults
# (prefix 'test', util.three_way_cmp ordering, suite.TestSuite).
defaultTestLoader = TestLoader()
251
252
def _makeLoader(prefix, sortUsing, suiteClass=None):
    """Return a fresh TestLoader configured with the given options.

    *prefix* becomes testMethodPrefix and *sortUsing* becomes
    sortTestMethodsUsing.  *suiteClass* overrides the loader's suite class
    only when it is truthy; otherwise the class default is kept.
    """
    configured = TestLoader()
    configured.testMethodPrefix = prefix
    configured.sortTestMethodsUsing = sortUsing
    if suiteClass:
        configured.suiteClass = suiteClass
    return configured
260
def getTestCaseNames(testCaseClass, prefix, sortUsing=util.three_way_cmp):
    """Return the test method names of *testCaseClass* that start with
    *prefix*, ordered with *sortUsing* (see TestLoader.getTestCaseNames)."""
    loader = _makeLoader(prefix, sortUsing)
    names = loader.getTestCaseNames(testCaseClass)
    return names
263
def makeSuite(testCaseClass, prefix='test', sortUsing=util.three_way_cmp,
              suiteClass=suite.TestSuite):
    """Return a suite of the tests in *testCaseClass*, built by a loader
    configured with *prefix*, *sortUsing* and *suiteClass*."""
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    built = loader.loadTestsFromTestCase(testCaseClass)
    return built
268
def findTestCases(module, prefix='test', sortUsing=util.three_way_cmp,
                  suiteClass=suite.TestSuite):
    """Return a suite of the tests found in *module*, built by a loader
    configured with *prefix*, *sortUsing* and *suiteClass*."""
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    found = loader.loadTestsFromModule(module)
    return found