blob: 4e9e152869ff0beadc7f25404080546914b2272a [file] [log] [blame]
Benjamin Petersonbed7d042009-07-19 21:01:52 +00001"""Loading unittests."""
2
3import os
Benjamin Peterson4ac9ce42009-10-04 14:49:41 +00004import re
Benjamin Petersonbed7d042009-07-19 21:01:52 +00005import sys
Benjamin Peterson4ac9ce42009-10-04 14:49:41 +00006import traceback
Benjamin Petersonbed7d042009-07-19 21:01:52 +00007import types
8
9from fnmatch import fnmatch
10
11from . import case, suite, util
12
13
# Discovery imports only plain '.py' source files whose stem starts with a
# letter or underscore followed by word characters (i.e. looks like a Python
# identifier).  Compiled variants such as '.pyc' / '.pyo' are deliberately not
# accepted: matching them as well would load the same tests multiple times.
VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', flags=re.IGNORECASE)
18
19
def _make_failed_import_test(name, suiteClass):
    """Return a suite holding one test that reports a failed module import.

    Intended to be called from inside an ``except`` clause: the traceback of
    the exception currently being handled is appended to the message.  The
    returned suite (built with *suiteClass*) contains a single test named
    *name* in a generated ``ModuleImportFailure`` class; running it raises
    ImportError carrying that message.
    """
    # traceback.format_exc() is available on every Python this module can run
    # on (the module already uses 'except ... as' syntax), so the old
    # Python 2.3 hasattr() guard was dead code.  The formatted traceback also
    # contains the loader's own frames, not only those of the failing module.
    message = 'Failed to import test module: %s\n%s' % (
        name, traceback.format_exc())
    return _make_failed_test('ModuleImportFailure', name, ImportError(message),
                             suiteClass)
Benjamin Peterson4ac9ce42009-10-04 14:49:41 +000028
def _make_failed_load_tests(name, exception, suiteClass):
    """Return a suite with one test that re-raises *exception*.

    Used when a module's or package's ``load_tests`` hook raised: the failing
    test is called *name* and lives in a generated ``LoadTestsFailure`` class.
    """
    return _make_failed_test(classname='LoadTestsFailure',
                             methodname=name,
                             exception=exception,
                             suiteClass=suiteClass)
31
def _make_failed_test(classname, methodname, exception, suiteClass):
    """Build a one-test suite whose only test raises *exception*.

    A TestCase subclass named *classname* is created on the fly with a single
    method stored under the attribute *methodname*.  That method just raises
    *exception*, so the problem shows up as an ordinary test error.  The
    instance is wrapped in a one-element tuple passed to *suiteClass*.
    """
    def testFailure(self):
        raise exception

    failing_case_class = type(classname, (case.TestCase,),
                              {methodname: testFailure})
    failing_case = failing_case_class(methodname)
    return suiteClass((failing_case,))
Benjamin Peterson4ac9ce42009-10-04 14:49:41 +000038
39
class TestLoader(object):
    """
    This class is responsible for loading tests according to various criteria
    and returning them wrapped in a TestSuite
    """
    # Attributes whose names start with this prefix are collected as test
    # methods by getTestCaseNames().
    testMethodPrefix = 'test'
    # cmp-style function used to order test method names; a falsy value
    # leaves the names unsorted (in dir() order).
    sortTestMethodsUsing = staticmethod(util.three_way_cmp)
    # Container class used to wrap every group of loaded tests.
    suiteClass = suite.TestSuite
    # Set by discover(); lets a package's load_tests call discover() again
    # without repeating top_level_dir.
    _top_level_dir = None

    def loadTestsFromTestCase(self, testCaseClass):
        """Return a suite of all tests cases contained in testCaseClass.

        Raises TypeError if testCaseClass derives from TestSuite.
        """
        if issubclass(testCaseClass, suite.TestSuite):
            raise TypeError("Test cases should not be derived from TestSuite."
                            " Maybe you meant to derive from TestCase?")
        testCaseNames = self.getTestCaseNames(testCaseClass)
        if not testCaseNames and hasattr(testCaseClass, 'runTest'):
            # No prefixed test methods, but a runTest method: treat the class
            # as a single-test case.
            testCaseNames = ['runTest']
        return self.suiteClass(map(testCaseClass, testCaseNames))

    def loadTestsFromModule(self, module, use_load_tests=True):
        """Return a suite of all tests cases contained in the given module.

        Every TestCase subclass reachable through dir(module) is loaded
        (including classes merely imported into the module).  If
        use_load_tests is true and the module defines ``load_tests``, its
        result for ``load_tests(self, tests, None)`` is returned instead; an
        exception raised by that hook becomes a failing test.
        """
        tests = []
        for name in dir(module):
            obj = getattr(module, name)
            if isinstance(obj, type) and issubclass(obj, case.TestCase):
                tests.append(self.loadTestsFromTestCase(obj))

        load_tests = getattr(module, 'load_tests', None)
        tests = self.suiteClass(tests)
        if use_load_tests and load_tests is not None:
            try:
                return load_tests(self, tests, None)
            except Exception as e:
                return _make_failed_load_tests(module.__name__, e,
                                               self.suiteClass)
        return tests

    def loadTestsFromName(self, name, module=None):
        """Return a suite of all tests cases given a string specifier.

        The name may resolve either to a module, a test case class, a
        test method within a test case class, or a callable object which
        returns a TestCase or TestSuite instance.

        The method optionally resolves the names relative to a given module.

        Raises ImportError if no prefix of the dotted name is importable (when
        module is None), AttributeError if a later component does not exist,
        and TypeError if the resolved object cannot be turned into a test.
        """
        parts = name.split('.')
        if module is None:
            # Import the longest importable prefix of the dotted name; the
            # remaining components are resolved as attributes below.
            parts_copy = parts[:]
            while parts_copy:
                try:
                    module = __import__('.'.join(parts_copy))
                    break
                except ImportError:
                    del parts_copy[-1]
                    if not parts_copy:
                        raise
            # __import__ returns the top-level package, so every component
            # after the first is looked up with getattr.
            parts = parts[1:]
        obj = module
        for part in parts:
            parent, obj = obj, getattr(obj, part)

        if isinstance(obj, types.ModuleType):
            return self.loadTestsFromModule(obj)
        elif isinstance(obj, type) and issubclass(obj, case.TestCase):
            return self.loadTestsFromTestCase(obj)
        elif (isinstance(obj, types.FunctionType) and
              isinstance(parent, type) and
              issubclass(parent, case.TestCase)):
            # Use the attribute name that was asked for, not obj.__name__:
            # they differ for aliased methods (test_x = some_helper), and the
            # TestCase constructor looks the method up by attribute name.
            name = parts[-1]
            inst = parent(name)
            # static methods follow a different path
            if not isinstance(getattr(inst, name), types.FunctionType):
                return self.suiteClass([inst])
        elif isinstance(obj, suite.TestSuite):
            return obj
        if callable(obj):
            test = obj()
            if isinstance(test, suite.TestSuite):
                return test
            elif isinstance(test, case.TestCase):
                return self.suiteClass([test])
            else:
                raise TypeError("calling %s returned %s, not a test" %
                                (obj, test))
        else:
            raise TypeError("don't know how to make test from: %s" % obj)

    def loadTestsFromNames(self, names, module=None):
        """Return a suite of all tests cases found using the given sequence
        of string specifiers. See 'loadTestsFromName()'.
        """
        suites = [self.loadTestsFromName(name, module) for name in names]
        return self.suiteClass(suites)

    def getTestCaseNames(self, testCaseClass):
        """Return a sorted sequence of method names found within testCaseClass

        A name qualifies when it starts with testMethodPrefix and the class
        attribute is callable.  Sorting uses sortTestMethodsUsing, if set.
        """
        def isTestMethod(attrname, testCaseClass=testCaseClass,
                         prefix=self.testMethodPrefix):
            return (attrname.startswith(prefix) and
                    callable(getattr(testCaseClass, attrname)))
        testFnNames = list(filter(isTestMethod, dir(testCaseClass)))
        if self.sortTestMethodsUsing:
            testFnNames.sort(key=util.CmpToKey(self.sortTestMethodsUsing))
        return testFnNames

    def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
        """Find and return all test modules from the specified start
        directory, recursing into subdirectories to find them. Only test files
        that match the pattern will be loaded. (Using shell style pattern
        matching.)

        All test modules must be importable from the top level of the project.
        If the start directory is not the top level directory then the top
        level directory must be specified separately.

        If a test package name (directory with '__init__.py') matches the
        pattern then the package will be checked for a 'load_tests' function. If
        this exists then it will be called with loader, tests, pattern.

        If load_tests exists then discovery does *not* recurse into the package,
        load_tests is responsible for loading all tests in the package.

        The pattern is deliberately not stored as a loader attribute so that
        packages can continue discovery themselves. top_level_dir is stored so
        load_tests does not need to pass this argument in to loader.discover().

        Side effect: top_level_dir is appended to sys.path if it is not
        already there.  Raises ImportError if start_dir is neither the top
        level directory nor a package directory.
        """
        if top_level_dir is None and self._top_level_dir is not None:
            # make top_level_dir optional if called from load_tests in a package
            top_level_dir = self._top_level_dir
        elif top_level_dir is None:
            top_level_dir = start_dir

        top_level_dir = os.path.abspath(os.path.normpath(top_level_dir))
        start_dir = os.path.abspath(os.path.normpath(start_dir))

        if top_level_dir not in sys.path:
            # all test modules must be importable from the top level directory
            sys.path.append(top_level_dir)
        self._top_level_dir = top_level_dir

        if (start_dir != top_level_dir and
                not os.path.isfile(os.path.join(start_dir, '__init__.py'))):
            # what about __init__.pyc or pyo (etc)
            raise ImportError('Start directory is not importable: %r' % start_dir)

        tests = list(self._find_tests(start_dir, pattern))
        return self.suiteClass(tests)

    def _get_name_from_path(self, path):
        """Convert a file or directory path into a dotted module name.

        The path is made relative to self._top_level_dir and its extension is
        dropped.  NOTE(review): the "within the project" checks are asserts,
        so they are skipped under ``python -O``.
        """
        path = os.path.splitext(os.path.normpath(path))[0]

        _relpath = os.path.relpath(path, self._top_level_dir)
        assert not os.path.isabs(_relpath), "Path must be within the project"
        assert not _relpath.startswith('..'), "Path must be within the project"

        name = _relpath.replace(os.path.sep, '.')
        return name

    def _get_module_from_name(self, name):
        """Import the dotted module *name* and return the module itself.

        __import__ returns the top-level package for dotted names, so the
        submodule is fetched from sys.modules afterwards.
        """
        __import__(name)
        return sys.modules[name]

    def _find_tests(self, start_dir, pattern):
        """Used by discovery. Yields test suites it loads.

        Files whose names match *pattern* are imported and loaded.
        Subdirectories containing '__init__.py' are entered recursively,
        unless the package itself matches *pattern* and defines load_tests.
        Import failures and load_tests exceptions are yielded as failing tests
        instead of aborting discovery.
        """
        # os.listdir() order is arbitrary and filesystem-dependent; sort it so
        # discovery (and therefore execution) order is reproducible.
        paths = sorted(os.listdir(start_dir))

        for path in paths:
            full_path = os.path.join(start_dir, path)
            if os.path.isfile(full_path):
                if not VALID_MODULE_NAME.match(path):
                    # valid Python identifiers only
                    continue

                if fnmatch(path, pattern):
                    # if the test file matches, load it
                    name = self._get_name_from_path(full_path)
                    try:
                        module = self._get_module_from_name(name)
                    except:
                        # Deliberately bare: importing arbitrary test code can
                        # raise anything, and one broken module must not abort
                        # the whole discovery run.
                        yield _make_failed_import_test(name, self.suiteClass)
                    else:
                        yield self.loadTestsFromModule(module)
            elif os.path.isdir(full_path):
                if not os.path.isfile(os.path.join(full_path, '__init__.py')):
                    continue

                load_tests = None
                tests = None
                if fnmatch(path, pattern):
                    # only check load_tests if the package directory itself matches the filter
                    name = self._get_name_from_path(full_path)
                    try:
                        package = self._get_module_from_name(name)
                    except:
                        # Same policy as for modules above: report the broken
                        # package as a failing test.  Don't recurse into it;
                        # importing its submodules imports the package first
                        # and would fail the same way.
                        yield _make_failed_import_test(name, self.suiteClass)
                        continue
                    load_tests = getattr(package, 'load_tests', None)
                    tests = self.loadTestsFromModule(package, use_load_tests=False)

                if load_tests is None:
                    if tests is not None:
                        # tests loaded from package file
                        yield tests
                    # recurse into the package
                    for test in self._find_tests(full_path, pattern):
                        yield test
                else:
                    # The package's load_tests takes over loading for the
                    # whole package, so there is no recursion here.
                    try:
                        yield load_tests(self, tests, pattern)
                    except Exception as e:
                        yield _make_failed_load_tests(package.__name__, e,
                                                      self.suiteClass)
Benjamin Petersonbed7d042009-07-19 21:01:52 +0000252
# Module-level TestLoader instance using the class-level default settings
# (testMethodPrefix, sortTestMethodsUsing, suiteClass).
defaultTestLoader = TestLoader()
254
255
def _makeLoader(prefix, sortUsing, suiteClass=None):
    """Return a new TestLoader configured with the given settings.

    *prefix* becomes testMethodPrefix and *sortUsing* becomes
    sortTestMethodsUsing.  *suiteClass* overrides the loader's suite class
    only when it is truthy; otherwise the class default is kept.
    """
    configured = TestLoader()
    configured.testMethodPrefix = prefix
    configured.sortTestMethodsUsing = sortUsing
    if suiteClass:
        configured.suiteClass = suiteClass
    return configured
263
def getTestCaseNames(testCaseClass, prefix, sortUsing=util.three_way_cmp):
    """Return the test method names of *testCaseClass*.

    Module-level convenience wrapper: builds a loader using *prefix* and
    *sortUsing* and delegates to its getTestCaseNames().
    """
    configured = _makeLoader(prefix, sortUsing)
    return configured.getTestCaseNames(testCaseClass)
266
def makeSuite(testCaseClass, prefix='test', sortUsing=util.three_way_cmp,
              suiteClass=suite.TestSuite):
    """Return a suite containing every test of *testCaseClass*.

    Builds a loader from *prefix*, *sortUsing* and *suiteClass* and delegates
    to its loadTestsFromTestCase().
    """
    configured = _makeLoader(prefix, sortUsing, suiteClass)
    return configured.loadTestsFromTestCase(testCaseClass)
271
def findTestCases(module, prefix='test', sortUsing=util.three_way_cmp,
                  suiteClass=suite.TestSuite):
    """Return a suite containing every test found in *module*.

    Builds a loader from *prefix*, *sortUsing* and *suiteClass* and delegates
    to its loadTestsFromModule().
    """
    configured = _makeLoader(prefix, sortUsing, suiteClass)
    return configured.loadTestsFromModule(module)