blob: eea5c13366a1f14976a9730e333400cfa6995595 [file] [log] [blame]
Benjamin Petersonbed7d042009-07-19 21:01:52 +00001"""Loading unittests."""
2
3import os
Benjamin Peterson4ac9ce42009-10-04 14:49:41 +00004import re
Benjamin Petersonbed7d042009-07-19 21:01:52 +00005import sys
Benjamin Peterson4ac9ce42009-10-04 14:49:41 +00006import traceback
Benjamin Petersonbed7d042009-07-19 21:01:52 +00007import types
8
9from fnmatch import fnmatch
10
11from . import case, suite, util
12
# Module marker; presumably lets unittest leave this module's frames out of
# reported failure tracebacks -- TODO confirm against unittest.result.
__unittest = True

# Test files must be plain ``.py`` files with an identifier-like name.
# Compiled files (.pyc/.pyo etc.) are not matched: supporting them would mean
# avoiding loading the same tests several times from '.py', '.pyc' and '.pyo'.
VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', flags=re.IGNORECASE)
19
20
def _make_failed_import_test(name, suiteClass):
    """Return a suite whose single test reports that module *name* failed to import.

    Intended to be called from inside an ``except`` block, because the
    currently handled exception's traceback is embedded in the ImportError
    that the synthetic test raises when run.

    :param name: dotted name of the module that could not be imported; also
        used as the synthetic test method's name.
    :param suiteClass: suite type used to wrap the synthetic test.
    """
    # traceback.format_exc() exists on every supported Python, so the old
    # Python 2.3 hasattr() guard was dead code.  The formatted traceback may
    # also include frames from the loader itself.
    message = 'Failed to import test module: %s\n%s' % (
        name, traceback.format_exc())
    return _make_failed_test('ModuleImportFailure', name, ImportError(message),
                             suiteClass)
Benjamin Peterson4ac9ce42009-10-04 14:49:41 +000029
def _make_failed_load_tests(name, exception, suiteClass):
    """Wrap *exception*, raised by a ``load_tests`` hook, in a failing test.

    The returned suite holds one ``LoadTestsFailure`` test whose method is
    named *name* and which re-raises *exception* when run.
    """
    failure_class_name = 'LoadTestsFailure'
    return _make_failed_test(failure_class_name, name, exception, suiteClass)
32
def _make_failed_test(classname, methodname, exception, suiteClass):
    """Build a suite containing one synthetic test that raises *exception*.

    A ``case.TestCase`` subclass called *classname* is created on the fly
    with a single method *methodname*; the suite wraps one instance of it.
    """
    def testFailure(self):
        raise exception

    failing_case_class = type(classname, (case.TestCase,),
                              {methodname: testFailure})
    return suiteClass((failing_case_class(methodname),))
Benjamin Peterson4ac9ce42009-10-04 14:49:41 +000039
40
class TestLoader(object):
    """
    This class is responsible for loading tests according to various criteria
    and returning them wrapped in a TestSuite (or ``suiteClass``).
    """
    # Attribute-name prefix that marks a TestCase method as a test.
    testMethodPrefix = 'test'
    # Comparison function used to order test method names; a false value
    # disables sorting.
    sortTestMethodsUsing = staticmethod(util.three_way_cmp)
    # Suite type used to wrap every collection of loaded tests.
    suiteClass = suite.TestSuite
    # Set by discover(); reused when discover() is called again (e.g. from a
    # package's load_tests hook) without an explicit top_level_dir.
    _top_level_dir = None

    def loadTestsFromTestCase(self, testCaseClass):
        """Return a suite of all tests cases contained in testCaseClass.

        Falls back to a single ``runTest`` test when the class has no
        prefixed test methods but does have a ``runTest`` attribute.

        Raises TypeError if testCaseClass derives from TestSuite.
        """
        if issubclass(testCaseClass, suite.TestSuite):
            raise TypeError("Test cases should not be derived from TestSuite."
                            " Maybe you meant to derive from TestCase?")
        testCaseNames = self.getTestCaseNames(testCaseClass)
        if not testCaseNames and hasattr(testCaseClass, 'runTest'):
            testCaseNames = ['runTest']
        return self.suiteClass(map(testCaseClass, testCaseNames))

    def loadTestsFromModule(self, module, use_load_tests=True):
        """Return a suite of all tests cases contained in the given module.

        Every TestCase subclass among the module's attributes is loaded. If
        use_load_tests is true and the module defines ``load_tests``, it is
        called as ``load_tests(self, tests, None)`` and its result is
        returned instead; an exception raised by it is converted into a
        failing ``LoadTestsFailure`` test.
        """
        tests = []
        for name in dir(module):
            obj = getattr(module, name)
            if isinstance(obj, type) and issubclass(obj, case.TestCase):
                tests.append(self.loadTestsFromTestCase(obj))

        load_tests = getattr(module, 'load_tests', None)
        tests = self.suiteClass(tests)
        if use_load_tests and load_tests is not None:
            try:
                return load_tests(self, tests, None)
            except Exception as e:
                return _make_failed_load_tests(module.__name__, e,
                                               self.suiteClass)
        return tests

    def loadTestsFromName(self, name, module=None):
        """Return a suite of all tests cases given a string specifier.

        The name may resolve either to a module, a test case class, a
        test method within a test case class, or a callable object which
        returns a TestCase or TestSuite instance.

        The method optionally resolves the names relative to a given module.

        Raises ImportError if no prefix of name can be imported (when module
        is None), AttributeError if a part cannot be resolved, and TypeError
        if the resolved object cannot be turned into a test.
        """
        parts = name.split('.')
        if module is None:
            # Import the longest importable dotted prefix of name.
            parts_copy = parts[:]
            while parts_copy:
                try:
                    module = __import__('.'.join(parts_copy))
                    break
                except ImportError:
                    del parts_copy[-1]
                    if not parts_copy:
                        raise
            # __import__ returns the top-level package, so every remaining
            # part is resolved by attribute lookup starting from it.
            parts = parts[1:]
        obj = module
        for part in parts:
            parent, obj = obj, getattr(obj, part)

        if isinstance(obj, types.ModuleType):
            return self.loadTestsFromModule(obj)
        elif isinstance(obj, type) and issubclass(obj, case.TestCase):
            return self.loadTestsFromTestCase(obj)
        elif (isinstance(obj, types.FunctionType) and
              isinstance(parent, type) and
              issubclass(parent, case.TestCase)):
            # Use the attribute name that was actually looked up rather than
            # obj.__name__: a decorated method (or an alias) may carry a
            # different __name__, and TestCase(methodName) needs the real
            # attribute name.
            name = parts[-1]
            inst = parent(name)
            # static methods follow a different path
            if not isinstance(getattr(inst, name), types.FunctionType):
                return self.suiteClass([inst])
        elif isinstance(obj, suite.TestSuite):
            return obj
        if hasattr(obj, '__call__'):
            test = obj()
            if isinstance(test, suite.TestSuite):
                return test
            elif isinstance(test, case.TestCase):
                return self.suiteClass([test])
            else:
                raise TypeError("calling %s returned %s, not a test" %
                                (obj, test))
        else:
            raise TypeError("don't know how to make test from: %s" % obj)

    def loadTestsFromNames(self, names, module=None):
        """Return a suite of all tests cases found using the given sequence
        of string specifiers. See 'loadTestsFromName()'.
        """
        suites = [self.loadTestsFromName(name, module) for name in names]
        return self.suiteClass(suites)

    def getTestCaseNames(self, testCaseClass):
        """Return a list of the callable attributes of testCaseClass whose
        names start with testMethodPrefix, ordered by sortTestMethodsUsing
        (left in dir() order when that is false).
        """
        def isTestMethod(attrname, testCaseClass=testCaseClass,
                         prefix=self.testMethodPrefix):
            return attrname.startswith(prefix) and \
                hasattr(getattr(testCaseClass, attrname), '__call__')
        testFnNames = list(filter(isTestMethod, dir(testCaseClass)))
        if self.sortTestMethodsUsing:
            testFnNames.sort(key=util.CmpToKey(self.sortTestMethodsUsing))
        return testFnNames

    def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
        """Find and return all test modules from the specified start
        directory, recursing into subdirectories to find them. Only test files
        that match the pattern will be loaded. (Using shell style pattern
        matching.)

        All test modules must be importable from the top level of the project.
        If the start directory is not the top level directory then the top
        level directory must be specified separately.

        If a test package name (directory with '__init__.py') matches the
        pattern then the package will be checked for a 'load_tests' function. If
        this exists then it will be called with loader, tests, pattern.

        If load_tests exists then discovery does *not* recurse into the package,
        load_tests is responsible for loading all tests in the package.

        The pattern is deliberately not stored as a loader attribute so that
        packages can continue discovery themselves. top_level_dir is stored so
        load_tests does not need to pass this argument in to loader.discover().

        Raises ImportError if start_dir is neither the top level directory nor
        a package directory.
        """
        if top_level_dir is None and self._top_level_dir is not None:
            # make top_level_dir optional if called from load_tests in a package
            top_level_dir = self._top_level_dir
        elif top_level_dir is None:
            top_level_dir = start_dir

        top_level_dir = os.path.abspath(os.path.normpath(top_level_dir))
        start_dir = os.path.abspath(os.path.normpath(start_dir))

        if top_level_dir not in sys.path:
            # all test modules must be importable from the top level directory
            sys.path.append(top_level_dir)
        self._top_level_dir = top_level_dir

        if start_dir != top_level_dir and not os.path.isfile(os.path.join(start_dir, '__init__.py')):
            # what about __init__.pyc or pyo (etc)
            raise ImportError('Start directory is not importable: %r' % start_dir)

        tests = list(self._find_tests(start_dir, pattern))
        return self.suiteClass(tests)

    def _get_name_from_path(self, path):
        """Convert a path below the top level directory into a dotted module
        name (any file extension is stripped).
        """
        path = os.path.splitext(os.path.normpath(path))[0]

        _relpath = os.path.relpath(path, self._top_level_dir)
        assert not os.path.isabs(_relpath), "Path must be within the project"
        assert not _relpath.startswith('..'), "Path must be within the project"

        name = _relpath.replace(os.path.sep, '.')
        return name

    def _get_module_from_name(self, name):
        """Import the dotted module name and return that module itself.

        __import__ returns the top-level package for dotted names, so the
        module is fetched from sys.modules instead.
        """
        __import__(name)
        return sys.modules[name]

    def _find_tests(self, start_dir, pattern):
        """Used by discovery. Yields test suites it loads.

        Files must have identifier-like ``.py`` names matching *pattern*.
        Directories are only considered when they contain ``__init__.py``; a
        package whose own name matches *pattern* is imported, and if it
        defines ``load_tests`` that hook loads the package instead of
        recursion. Import failures are yielded as failing tests, not raised.
        """
        # os.listdir() returns entries in arbitrary order; sort them so the
        # discovered tests come out in a deterministic order.
        paths = sorted(os.listdir(start_dir))

        for path in paths:
            full_path = os.path.join(start_dir, path)
            if os.path.isfile(full_path):
                if not VALID_MODULE_NAME.match(path):
                    # valid Python identifiers only
                    continue

                if fnmatch(path, pattern):
                    # if the test file matches, load it
                    name = self._get_name_from_path(full_path)
                    try:
                        module = self._get_module_from_name(name)
                    except:
                        # Deliberately bare: whatever importing the module
                        # raises (even SystemExit) is reported as a failing
                        # test instead of aborting discovery.
                        yield _make_failed_import_test(name, self.suiteClass)
                    else:
                        yield self.loadTestsFromModule(module)
            elif os.path.isdir(full_path):
                if not os.path.isfile(os.path.join(full_path, '__init__.py')):
                    continue

                load_tests = None
                tests = None
                if fnmatch(path, pattern):
                    # only check load_tests if the package directory itself
                    # matches the filter
                    name = self._get_name_from_path(full_path)
                    try:
                        package = self._get_module_from_name(name)
                    except:
                        # Report a broken package the same way as a broken
                        # test module rather than letting the error abort
                        # discovery.  Importing its submodules would re-run
                        # the failing __init__, so do not recurse into it.
                        yield _make_failed_import_test(name, self.suiteClass)
                        continue
                    load_tests = getattr(package, 'load_tests', None)
                    tests = self.loadTestsFromModule(package, use_load_tests=False)

                if load_tests is None:
                    if tests is not None:
                        # tests loaded from package file
                        yield tests
                    # recurse into the package
                    for test in self._find_tests(full_path, pattern):
                        yield test
                else:
                    try:
                        yield load_tests(self, tests, pattern)
                    except Exception as e:
                        yield _make_failed_load_tests(package.__name__, e,
                                                      self.suiteClass)
Benjamin Petersonbed7d042009-07-19 21:01:52 +0000253
# Shared loader instance with the default TestLoader configuration.
defaultTestLoader = TestLoader()
255
256
def _makeLoader(prefix, sortUsing, suiteClass=None):
    """Return a new TestLoader configured with the given options.

    *prefix* becomes testMethodPrefix and *sortUsing* becomes
    sortTestMethodsUsing. *suiteClass* replaces the loader's suiteClass only
    when it is truthy; otherwise the class default is kept.
    """
    configured = TestLoader()
    configured.testMethodPrefix = prefix
    configured.sortTestMethodsUsing = sortUsing
    if suiteClass:
        configured.suiteClass = suiteClass
    return configured
264
def getTestCaseNames(testCaseClass, prefix, sortUsing=util.three_way_cmp):
    """Return the names of testCaseClass's test methods starting with prefix.

    Builds a throwaway loader with the given prefix and ordering function
    and delegates to its getTestCaseNames().
    """
    loader = _makeLoader(prefix, sortUsing)
    return loader.getTestCaseNames(testCaseClass)
267
def makeSuite(testCaseClass, prefix='test', sortUsing=util.three_way_cmp,
              suiteClass=suite.TestSuite):
    """Return a suite of the tests found in testCaseClass.

    Convenience wrapper that loads testCaseClass with a throwaway loader
    configured from prefix, sortUsing and suiteClass.
    """
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    return loader.loadTestsFromTestCase(testCaseClass)
272
def findTestCases(module, prefix='test', sortUsing=util.three_way_cmp,
                  suiteClass=suite.TestSuite):
    """Return a suite of the tests found in module.

    Convenience wrapper that loads module with a throwaway loader configured
    from prefix, sortUsing and suiteClass.
    """
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    return loader.loadTestsFromModule(module)