blob: 022ed5781605a1f491a7f8a15df457f4be7468d6 [file] [log] [blame]
Benjamin Petersond7b0eeb2009-07-19 20:18:21 +00001"""Loading unittests."""
2
3import os
Michael Foorde91ea562009-09-13 19:07:03 +00004import re
Benjamin Petersond7b0eeb2009-07-19 20:18:21 +00005import sys
Michael Foorde91ea562009-09-13 19:07:03 +00006import traceback
Benjamin Petersond7b0eeb2009-07-19 20:18:21 +00007import types
8
9from fnmatch import fnmatch
10
11from . import case, suite
12
# NOTE(review): module-level marker flag.  Presumably other unittest code
# looks for this name in a frame's globals to hide this module's frames from
# reported failure tracebacks -- confirm against the result/traceback code.
__unittest = True
14
Benjamin Petersond7b0eeb2009-07-19 20:18:21 +000015
16def _CmpToKey(mycmp):
17 'Convert a cmp= function into a key= function'
18 class K(object):
19 def __init__(self, obj):
20 self.obj = obj
21 def __lt__(self, other):
22 return mycmp(self.obj, other.obj) == -1
23 return K
24
25
# Discovery only considers files named like importable source modules: a
# leading letter or underscore, then word characters, ending in '.py'
# (matched case-insensitively).  Compiled variants (.pyc, .pyo, ...) are not
# matched; supporting them would require avoiding loading the same tests
# several times, once per file variant of a module.
VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', flags=re.I)
30
31
def _make_failed_import_test(name, suiteClass):
    """Build a suite with one test that fails with the current import error.

    Intended to be called from an ``except`` clause: the traceback of the
    exception being handled (via ``traceback.format_exc()``) is embedded in
    the ImportError the generated test raises.  The test lives on a class
    named 'ModuleImportFailure' and its method is named after ``name``.
    """
    details = traceback.format_exc()
    error = ImportError('Failed to import test module: %s\n%s' % (name, details))
    return _make_failed_test('ModuleImportFailure', name, error, suiteClass)
Michael Foorde91ea562009-09-13 19:07:03 +000036
def _make_failed_load_tests(name, exception, suiteClass):
    """Build a suite with one test that re-raises ``exception`` when run.

    Used when a ``load_tests`` function raised: the generated test lives on a
    class named 'LoadTestsFailure' and its method is named after ``name``.
    """
    failure_class_name = 'LoadTestsFailure'
    return _make_failed_test(failure_class_name, name, exception, suiteClass)
Michael Foord73dbe042010-03-21 00:53:39 +000039
def _make_failed_test(classname, methodname, exception, suiteClass):
    """Return ``suiteClass`` holding a single test that raises ``exception``.

    A throw-away ``case.TestCase`` subclass called ``classname`` is created
    with one method, ``methodname``; running that test re-raises
    ``exception``, so the problem is reported when the suite runs.
    """
    def testFailure(self):
        raise exception

    FailureCase = type(classname, (case.TestCase,), {methodname: testFailure})
    only_test = FailureCase(methodname)
    return suiteClass((only_test,))
Michael Foorde91ea562009-09-13 19:07:03 +000046
47
class TestLoader(object):
    """
    This class is responsible for loading tests according to various criteria
    and returning them wrapped in a TestSuite
    """
    # Attribute-name prefix that marks a callable as a test method
    # (used by getTestCaseNames).
    testMethodPrefix = 'test'
    # cmp-style function used to order test method names; set it to a false
    # value to keep the dir() order instead.
    sortTestMethodsUsing = cmp
    # Factory used to build every suite this loader returns.
    suiteClass = suite.TestSuite
    # Remembered by discover() so that later discover() calls on the same
    # loader (e.g. from a package's load_tests) need not pass top_level_dir.
    _top_level_dir = None

    def loadTestsFromTestCase(self, testCaseClass):
        """Return a suite of all tests cases contained in testCaseClass.

        Raises TypeError if testCaseClass derives from TestSuite.  If no
        method names match testMethodPrefix but the class has a ``runTest``
        attribute, a single test built from ``runTest`` is used.
        """
        if issubclass(testCaseClass, suite.TestSuite):
            raise TypeError("Test cases should not be derived from TestSuite."
                            " Maybe you meant to derive from TestCase?")
        testCaseNames = self.getTestCaseNames(testCaseClass)
        if not testCaseNames and hasattr(testCaseClass, 'runTest'):
            testCaseNames = ['runTest']
        loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))
        return loaded_suite

    def loadTestsFromModule(self, module, use_load_tests=True):
        """Return a suite of all tests cases contained in the given module.

        Every TestCase subclass among the module's attributes is loaded with
        loadTestsFromTestCase().  If use_load_tests is true and the module
        has a ``load_tests`` attribute, it is called as
        ``load_tests(self, tests, None)`` and its result is returned instead.
        An exception raised by load_tests is turned into a failing test
        rather than propagated.
        """
        tests = []
        for name in dir(module):
            obj = getattr(module, name)
            if isinstance(obj, type) and issubclass(obj, case.TestCase):
                tests.append(self.loadTestsFromTestCase(obj))

        load_tests = getattr(module, 'load_tests', None)
        tests = self.suiteClass(tests)
        if use_load_tests and load_tests is not None:
            try:
                return load_tests(self, tests, None)
            except Exception as e:
                return _make_failed_load_tests(module.__name__, e,
                                               self.suiteClass)
        return tests

    def loadTestsFromName(self, name, module=None):
        """Return a suite of all tests cases given a string specifier.

        The name may resolve either to a module, a test case class, a
        test method within a test case class, or a callable object which
        returns a TestCase or TestSuite instance.

        The method optionally resolves the names relative to a given module.

        Without a module, the longest importable dotted prefix of name is
        imported and the remaining parts are looked up as attributes.
        Raises ImportError if no prefix can be imported, AttributeError if a
        later part is missing, and TypeError if the resolved object (or the
        value returned by calling it) is not a test.
        """
        parts = name.split('.')
        # Owner of the final attribute; only consulted for unbound methods.
        parent = None
        if module is None:
            parts_copy = parts[:]
            while parts_copy:
                try:
                    module = __import__('.'.join(parts_copy))
                    break
                except ImportError:
                    # Drop the last component and retry with a shorter name.
                    del parts_copy[-1]
                    if not parts_copy:
                        raise
            # __import__('a.b.c') returns the top-level package 'a', so every
            # component after the first is resolved as an attribute below.
            parts = parts[1:]
        obj = module
        for part in parts:
            parent, obj = obj, getattr(obj, part)

        if isinstance(obj, types.ModuleType):
            return self.loadTestsFromModule(obj)
        elif isinstance(obj, type) and issubclass(obj, case.TestCase):
            return self.loadTestsFromTestCase(obj)
        elif (isinstance(obj, types.UnboundMethodType) and
              isinstance(parent, type) and
              issubclass(parent, case.TestCase)):
            return self.suiteClass([parent(obj.__name__)])
        elif isinstance(obj, suite.TestSuite):
            return obj
        elif hasattr(obj, '__call__'):
            test = obj()
            if isinstance(test, suite.TestSuite):
                return test
            elif isinstance(test, case.TestCase):
                return self.suiteClass([test])
            else:
                raise TypeError("calling %s returned %s, not a test" %
                                (obj, test))
        else:
            raise TypeError("don't know how to make test from: %s" % obj)

    def loadTestsFromNames(self, names, module=None):
        """Return a suite of all tests cases found using the given sequence
        of string specifiers. See 'loadTestsFromName()'.
        """
        suites = [self.loadTestsFromName(name, module) for name in names]
        return self.suiteClass(suites)

    def getTestCaseNames(self, testCaseClass):
        """Return a sorted sequence of method names found within testCaseClass

        A name qualifies when it starts with testMethodPrefix and the
        attribute it names is callable.  The list is ordered with
        sortTestMethodsUsing unless that attribute is a false value.
        """
        def isTestMethod(attrname, testCaseClass=testCaseClass,
                         prefix=self.testMethodPrefix):
            return (attrname.startswith(prefix) and
                    hasattr(getattr(testCaseClass, attrname), '__call__'))
        # A list comprehension (rather than filter) guarantees a real list,
        # which the in-place sort below requires.
        testFnNames = [attrname for attrname in dir(testCaseClass)
                       if isTestMethod(attrname)]
        if self.sortTestMethodsUsing:
            testFnNames.sort(key=_CmpToKey(self.sortTestMethodsUsing))
        return testFnNames

    def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
        """Find and return all test modules from the specified start
        directory, recursing into subdirectories to find them. Only test files
        that match the pattern will be loaded. (Using shell style pattern
        matching.)

        All test modules must be importable from the top level of the project.
        If the start directory is not the top level directory then the top
        level directory must be specified separately.

        If a test package name (directory with '__init__.py') matches the
        pattern then the package will be checked for a 'load_tests' function. If
        this exists then it will be called with loader, tests, pattern.

        If load_tests exists then discovery does *not* recurse into the package,
        load_tests is responsible for loading all tests in the package.

        The pattern is deliberately not stored as a loader attribute so that
        packages can continue discovery themselves. top_level_dir is stored so
        load_tests does not need to pass this argument in to loader.discover().

        Returns a suiteClass instance containing everything found.  Raises
        ImportError if start_dir differs from the top level directory and
        has no '__init__.py'.
        """
        if top_level_dir is None and self._top_level_dir is not None:
            # make top_level_dir optional if called from load_tests in a package
            top_level_dir = self._top_level_dir
        elif top_level_dir is None:
            top_level_dir = start_dir

        top_level_dir = os.path.abspath(os.path.normpath(top_level_dir))
        start_dir = os.path.abspath(os.path.normpath(start_dir))

        if top_level_dir not in sys.path:
            # all test modules must be importable from the top level directory
            sys.path.append(top_level_dir)
        self._top_level_dir = top_level_dir

        if start_dir != top_level_dir and not os.path.isfile(os.path.join(start_dir, '__init__.py')):
            # what about __init__.pyc or pyo (etc)
            raise ImportError('Start directory is not importable: %r' % start_dir)

        tests = list(self._find_tests(start_dir, pattern))
        return self.suiteClass(tests)

    def _get_name_from_path(self, path):
        """Return the dotted module name for path, relative to _top_level_dir.

        The file extension is dropped and path separators become dots.  The
        asserts reject paths outside the top level directory; note they are
        skipped entirely when Python runs with -O.
        """
        path = os.path.splitext(os.path.normpath(path))[0]

        _relpath = os.path.relpath(path, self._top_level_dir)
        assert not os.path.isabs(_relpath), "Path must be within the project"
        assert not _relpath.startswith('..'), "Path must be within the project"

        name = _relpath.replace(os.path.sep, '.')
        return name

    def _get_module_from_name(self, name):
        """Import the dotted module name and return the module itself.

        __import__ returns the top-level package for dotted names, so the
        module is fetched from sys.modules instead.
        """
        __import__(name)
        return sys.modules[name]

    def _find_tests(self, start_dir, pattern):
        """Used by discovery. Yields test suites it loads."""
        # Sort the entries: os.listdir order is arbitrary, and sorting makes
        # the discovery order the same on every filesystem.
        paths = sorted(os.listdir(start_dir))

        for path in paths:
            full_path = os.path.join(start_dir, path)
            if os.path.isfile(full_path):
                if not VALID_MODULE_NAME.match(path):
                    # valid Python identifiers only
                    continue

                if fnmatch(path, pattern):
                    # if the test file matches, load it
                    name = self._get_name_from_path(full_path)
                    try:
                        module = self._get_module_from_name(name)
                    except:
                        # A bare except catches anything the import raises;
                        # format_exc() inside the helper records it, and it
                        # is reported as a failing test instead of aborting
                        # discovery.
                        yield _make_failed_import_test(name, self.suiteClass)
                    else:
                        yield self.loadTestsFromModule(module)
            elif os.path.isdir(full_path):
                if not os.path.isfile(os.path.join(full_path, '__init__.py')):
                    continue

                load_tests = None
                tests = None
                if fnmatch(path, pattern):
                    # only check load_tests if the package directory itself matches the filter
                    name = self._get_name_from_path(full_path)
                    try:
                        package = self._get_module_from_name(name)
                    except:
                        # Report a broken package __init__ exactly like a
                        # broken test module.  Its contents are skipped: they
                        # cannot be imported without importing the package.
                        yield _make_failed_import_test(name, self.suiteClass)
                        continue
                    load_tests = getattr(package, 'load_tests', None)
                    tests = self.loadTestsFromModule(package, use_load_tests=False)

                if load_tests is None:
                    if tests is not None:
                        # tests loaded from package file
                        yield tests
                    # recurse into the package
                    for test in self._find_tests(full_path, pattern):
                        yield test
                else:
                    # The package's load_tests takes over loading of the
                    # whole package; no recursion happens here.
                    try:
                        yield load_tests(self, tests, pattern)
                    except Exception as e:
                        yield _make_failed_load_tests(package.__name__, e,
                                                      self.suiteClass)
Benjamin Petersond7b0eeb2009-07-19 20:18:21 +0000255
# Module-level TestLoader instance shared by everything that uses it.
# NOTE: discover() stores _top_level_dir on the loader it is called on, and
# later discover() calls without an explicit top_level_dir reuse it, so that
# state carries over between calls on this shared instance.
defaultTestLoader = TestLoader()
257
258
def _makeLoader(prefix, sortUsing, suiteClass=None):
    """Create a fresh TestLoader configured with the given options.

    suiteClass replaces the loader's suite factory only when it is truthy;
    otherwise the TestLoader class default stays in effect.
    """
    overrides = {'sortTestMethodsUsing': sortUsing,
                 'testMethodPrefix': prefix}
    if suiteClass:
        overrides['suiteClass'] = suiteClass
    loader = TestLoader()
    for attribute, value in overrides.items():
        setattr(loader, attribute, value)
    return loader
266
def getTestCaseNames(testCaseClass, prefix, sortUsing=cmp):
    """Return the test method names of testCaseClass (legacy helper).

    Names must start with prefix and are ordered with sortUsing; this is
    TestLoader.getTestCaseNames() on a throw-away, configured loader.
    """
    loader = _makeLoader(prefix, sortUsing)
    return loader.getTestCaseNames(testCaseClass)
269
def makeSuite(testCaseClass, prefix='test', sortUsing=cmp,
              suiteClass=suite.TestSuite):
    """Return a suite of the tests in testCaseClass (legacy helper).

    Equivalent to TestLoader.loadTestsFromTestCase() on a throw-away loader
    configured with prefix, sortUsing and suiteClass.
    """
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    return loader.loadTestsFromTestCase(testCaseClass)
273
def findTestCases(module, prefix='test', sortUsing=cmp,
                  suiteClass=suite.TestSuite):
    """Return a suite of the tests found in module (legacy helper).

    Equivalent to TestLoader.loadTestsFromModule() on a throw-away loader
    configured with prefix, sortUsing and suiteClass.
    """
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    return loader.loadTestsFromModule(module)