blob: bfee3dcc4117b17ab6db972942ae7ef4245b5de3 [file] [log] [blame]
Benjamin Petersonbed7d042009-07-19 21:01:52 +00001"""Loading unittests."""
2
3import os
Benjamin Peterson4ac9ce42009-10-04 14:49:41 +00004import re
Benjamin Petersonbed7d042009-07-19 21:01:52 +00005import sys
Benjamin Peterson4ac9ce42009-10-04 14:49:41 +00006import traceback
Benjamin Petersonbed7d042009-07-19 21:01:52 +00007import types
8
9from fnmatch import fnmatch
10
11from . import case, suite, util
12
13
# Only plain ".py" source files whose base name is a valid Python identifier
# are candidates for discovery.  Compiled variants (".pyc", ".pyo", ...) are
# intentionally not matched: accepting them as well would risk loading the
# same tests more than once.
VALID_MODULE_NAME = re.compile(
    r'[_a-z]\w*\.py$',
    re.I,
)
18
19
def _make_failed_import_test(name, suiteClass):
    """Build a suite holding one test that reports a failed module import.

    Must be called while an exception is being handled (inside an
    ``except`` block): the traceback of that exception is captured into the
    failure message via ``traceback.format_exc()``.

    :param name: dotted name of the module that could not be imported; it is
        also used as the name of the generated test method.
    :param suiteClass: callable used to wrap the generated test case.
    :returns: a ``suiteClass`` instance containing a single test that raises
        ``ImportError`` (carrying the original traceback) when run.
    """
    # traceback.format_exc always exists on Python 3, so the old
    # Python 2.3 ``hasattr`` guard is unnecessary.
    message = 'Failed to import test module: %s\n%s' % (
        name, traceback.format_exc())

    def testImportFailure(self):
        raise ImportError(message)

    # The test method is named after the module so the failure report
    # identifies which import broke.
    attrs = {name: testImportFailure}
    ModuleImportFailure = type('ModuleImportFailure', (case.TestCase,), attrs)
    return suiteClass((ModuleImportFailure(name),))
32
33
class TestLoader(object):
    """
    This class is responsible for loading tests according to various criteria
    and returning them wrapped in a TestSuite
    """
    # Attribute-name prefix that marks a method as a test.
    testMethodPrefix = 'test'
    # Three-way comparison used to order test method names; a false value
    # disables sorting (see getTestCaseNames).
    sortTestMethodsUsing = staticmethod(util.three_way_cmp)
    # Callable used to wrap every collection of tests this loader returns.
    suiteClass = suite.TestSuite
    # Remembered by discover() so nested discover() calls (e.g. from a
    # package's load_tests) do not need to pass top_level_dir again.
    _top_level_dir = None

    def loadTestsFromTestCase(self, testCaseClass):
        """Return a suite of all tests cases contained in testCaseClass.

        If no method matches testMethodPrefix but the class has a ``runTest``
        attribute, a single ``runTest`` test is produced instead.

        :raises TypeError: if testCaseClass derives from TestSuite.
        """
        if issubclass(testCaseClass, suite.TestSuite):
            raise TypeError("Test cases should not be derived from TestSuite."
                            " Maybe you meant to derive from TestCase?")
        testCaseNames = self.getTestCaseNames(testCaseClass)
        if not testCaseNames and hasattr(testCaseClass, 'runTest'):
            testCaseNames = ['runTest']
        # One TestCase instance per selected method name.
        loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))
        return loaded_suite

    def loadTestsFromModule(self, module, use_load_tests=True):
        """Return a suite of all tests cases contained in the given module.

        Every TestCase subclass found among the module's attributes is
        loaded.  If use_load_tests is true and the module defines
        ``load_tests``, the result of ``load_tests(self, tests, None)`` is
        returned instead of the collected suite.
        """
        tests = []
        for name in dir(module):
            obj = getattr(module, name)
            if isinstance(obj, type) and issubclass(obj, case.TestCase):
                tests.append(self.loadTestsFromTestCase(obj))

        load_tests = getattr(module, 'load_tests', None)
        tests = self.suiteClass(tests)
        if use_load_tests and load_tests is not None:
            return load_tests(self, tests, None)
        return tests

    def loadTestsFromName(self, name, module=None):
        """Return a suite of all tests cases given a string specifier.

        The name may resolve either to a module, a test case class, a
        test method within a test case class, or a callable object which
        returns a TestCase or TestSuite instance.

        The method optionally resolves the names relative to a given module.

        :raises ImportError: if no dotted prefix of name can be imported.
        :raises AttributeError: if a component of name cannot be resolved.
        :raises TypeError: if the resolved object cannot be turned into tests.
        """
        parts = name.split('.')
        if module is None:
            # Import the longest importable dotted prefix; the remaining
            # components are resolved as attributes below.
            parts_copy = parts[:]
            while parts_copy:
                try:
                    module = __import__('.'.join(parts_copy))
                    break
                except ImportError:
                    del parts_copy[-1]
                    if not parts_copy:
                        raise
            # __import__ returns the top-level package, so walk from there.
            parts = parts[1:]
        obj = module
        for part in parts:
            parent, obj = obj, getattr(obj, part)

        if isinstance(obj, types.ModuleType):
            return self.loadTestsFromModule(obj)
        elif isinstance(obj, type) and issubclass(obj, case.TestCase):
            return self.loadTestsFromTestCase(obj)
        elif (isinstance(obj, types.FunctionType) and
              isinstance(parent, type) and
              issubclass(parent, case.TestCase)):
            name = obj.__name__
            inst = parent(name)
            # static methods follow a different path
            if not isinstance(getattr(inst, name), types.FunctionType):
                return self.suiteClass([inst])
        elif isinstance(obj, suite.TestSuite):
            return obj
        if hasattr(obj, '__call__'):
            test = obj()
            if isinstance(test, suite.TestSuite):
                return test
            elif isinstance(test, case.TestCase):
                return self.suiteClass([test])
            else:
                raise TypeError("calling %s returned %s, not a test" %
                                (obj, test))
        else:
            raise TypeError("don't know how to make test from: %s" % obj)

    def loadTestsFromNames(self, names, module=None):
        """Return a suite of all tests cases found using the given sequence
        of string specifiers. See 'loadTestsFromName()'.
        """
        suites = [self.loadTestsFromName(name, module) for name in names]
        return self.suiteClass(suites)

    def getTestCaseNames(self, testCaseClass):
        """Return a sorted sequence of method names found within testCaseClass.

        A name qualifies when it starts with testMethodPrefix and the
        attribute has ``__call__``.  Names are ordered with
        sortTestMethodsUsing unless that attribute is false.
        """
        def isTestMethod(attrname, testCaseClass=testCaseClass,
                         prefix=self.testMethodPrefix):
            return attrname.startswith(prefix) and \
                hasattr(getattr(testCaseClass, attrname), '__call__')
        testFnNames = list(filter(isTestMethod, dir(testCaseClass)))
        if self.sortTestMethodsUsing:
            testFnNames.sort(key=util.CmpToKey(self.sortTestMethodsUsing))
        return testFnNames

    def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
        """Find and return all test modules from the specified start
        directory, recursing into subdirectories to find them. Only test files
        that match the pattern will be loaded. (Using shell style pattern
        matching.)

        All test modules must be importable from the top level of the project.
        If the start directory is not the top level directory then the top
        level directory must be specified separately.

        If a test package name (directory with '__init__.py') matches the
        pattern then the package will be checked for a 'load_tests' function. If
        this exists then it will be called with loader, tests, pattern.

        If load_tests exists then discovery does *not* recurse into the package,
        load_tests is responsible for loading all tests in the package.

        The pattern is deliberately not stored as a loader attribute so that
        packages can continue discovery themselves. top_level_dir is stored so
        load_tests does not need to pass this argument in to loader.discover().

        :raises ImportError: if start_dir is not top_level_dir and contains
            no '__init__.py'.
        """
        if top_level_dir is None and self._top_level_dir is not None:
            # make top_level_dir optional if called from load_tests in a package
            top_level_dir = self._top_level_dir
        elif top_level_dir is None:
            top_level_dir = start_dir

        top_level_dir = os.path.abspath(os.path.normpath(top_level_dir))
        start_dir = os.path.abspath(os.path.normpath(start_dir))

        if top_level_dir not in sys.path:
            # all test modules must be importable from the top level directory
            sys.path.append(top_level_dir)
        self._top_level_dir = top_level_dir

        if start_dir != top_level_dir and not os.path.isfile(os.path.join(start_dir, '__init__.py')):
            # what about __init__.pyc or pyo (etc)
            raise ImportError('Start directory is not importable: %r' % start_dir)

        tests = list(self._find_tests(start_dir, pattern))
        return self.suiteClass(tests)

    def _get_name_from_path(self, path):
        """Convert a file or directory path into a dotted module name
        relative to the stored top-level directory (extension dropped)."""
        path = os.path.splitext(os.path.normpath(path))[0]

        _relpath = os.path.relpath(path, self._top_level_dir)
        assert not os.path.isabs(_relpath), "Path must be within the project"
        assert not _relpath.startswith('..'), "Path must be within the project"

        name = _relpath.replace(os.path.sep, '.')
        return name

    def _get_module_from_name(self, name):
        """Import the module called name and return it from sys.modules
        (__import__ itself returns the top-level package)."""
        __import__(name)
        return sys.modules[name]

    def _find_tests(self, start_dir, pattern):
        """Used by discovery. Yields test suites it loads.

        Failures to import a matching test module or a matching package are
        yielded as failing tests instead of aborting the whole discovery.
        """
        # Sort so discovery order does not depend on the filesystem.
        paths = sorted(os.listdir(start_dir))

        for path in paths:
            full_path = os.path.join(start_dir, path)
            if os.path.isfile(full_path):
                if not VALID_MODULE_NAME.match(path):
                    # valid Python identifiers only
                    continue

                if fnmatch(path, pattern):
                    # if the test file matches, load it
                    name = self._get_name_from_path(full_path)
                    try:
                        module = self._get_module_from_name(name)
                    except KeyboardInterrupt:
                        # never turn a user interrupt into a test failure
                        raise
                    except BaseException:
                        # anything else raised at import time (including
                        # SystemExit) is reported as a failing test
                        yield _make_failed_import_test(name, self.suiteClass)
                    else:
                        yield self.loadTestsFromModule(module)
            elif os.path.isdir(full_path):
                if not os.path.isfile(os.path.join(full_path, '__init__.py')):
                    continue

                load_tests = None
                tests = None
                if fnmatch(path, pattern):
                    # only check load_tests if the package directory itself matches the filter
                    name = self._get_name_from_path(full_path)
                    try:
                        package = self._get_module_from_name(name)
                    except KeyboardInterrupt:
                        raise
                    except BaseException:
                        # report the broken package like a broken module;
                        # its submodules cannot be imported either, so skip it
                        yield _make_failed_import_test(name, self.suiteClass)
                        continue
                    load_tests = getattr(package, 'load_tests', None)
                    tests = self.loadTestsFromModule(package, use_load_tests=False)

                if load_tests is None:
                    if tests is not None:
                        # tests loaded from package file
                        yield tests
                    # recurse into the package
                    for test in self._find_tests(full_path, pattern):
                        yield test
                else:
                    yield load_tests(self, tests, pattern)
238
# Shared module-level TestLoader instance that uses the class-level defaults
# (testMethodPrefix, sortTestMethodsUsing, suiteClass).
defaultTestLoader = TestLoader()
240
241
def _makeLoader(prefix, sortUsing, suiteClass=None):
    """Create a fresh TestLoader configured with the given options.

    suiteClass replaces the loader's default suite class only when it is a
    true value.
    """
    new_loader = TestLoader()
    new_loader.testMethodPrefix = prefix
    new_loader.sortTestMethodsUsing = sortUsing
    if suiteClass:
        new_loader.suiteClass = suiteClass
    return new_loader
249
def getTestCaseNames(testCaseClass, prefix, sortUsing=util.three_way_cmp):
    """Return the test method names of testCaseClass that start with prefix,
    ordered with sortUsing."""
    configured = _makeLoader(prefix, sortUsing)
    return configured.getTestCaseNames(testCaseClass)
252
def makeSuite(testCaseClass, prefix='test', sortUsing=util.three_way_cmp,
              suiteClass=suite.TestSuite):
    """Return a suiteClass instance with every test in testCaseClass whose
    method name starts with prefix."""
    configured = _makeLoader(prefix, sortUsing, suiteClass)
    return configured.loadTestsFromTestCase(testCaseClass)
257
def findTestCases(module, prefix='test', sortUsing=util.three_way_cmp,
                  suiteClass=suite.TestSuite):
    """Return a suite of all tests found in the TestCase classes of module
    (honouring the module's load_tests hook, if any)."""
    configured = _makeLoader(prefix, sortUsing, suiteClass)
    return configured.loadTestsFromModule(module)