blob: 9eede8b51a48b17aafa4423bfc260f2936f8871b [file] [log] [blame]
Benjamin Petersond7b0eeb2009-07-19 20:18:21 +00001"""Loading unittests."""
2
3import os
Michael Foorde91ea562009-09-13 19:07:03 +00004import re
Benjamin Petersond7b0eeb2009-07-19 20:18:21 +00005import sys
Michael Foorde91ea562009-09-13 19:07:03 +00006import traceback
Benjamin Petersond7b0eeb2009-07-19 20:18:21 +00007import types
8
9from fnmatch import fnmatch
10
11from . import case, suite
12
13
def _CmpToKey(mycmp):
    """Convert a cmp= function into a key= function.

    Returns a wrapper class: each element to be sorted is wrapped in an
    instance, and ``<`` between two instances is decided by calling
    ``mycmp`` on the wrapped objects.

    A cmp-style function signals "less than" with *any* negative number,
    not only -1 (``lambda a, b: a - b`` is a common example), so the
    result is tested with ``< 0``.
    """
    class K(object):
        def __init__(self, obj):
            self.obj = obj

        def __lt__(self, other):
            # Any negative result means self.obj sorts before other.obj.
            return mycmp(self.obj, other.obj) < 0
    return K
22
23
# Only plain '.py' source files whose base name looks like a Python
# identifier are treated as candidate test modules. Compiled variants
# ('.pyc', '.pyo', ...) are not matched; supporting them would require
# avoiding loading the same tests several times, once for each of '.py',
# '.pyc' and '.pyo'.
VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', re.I)
28
29
def _make_failed_import_test(name, suiteClass):
    """Return a suite with one test reporting that module *name* failed to import.

    Intended to be called while the import exception is being handled: when
    ``traceback.format_exc`` exists, the current traceback is appended to the
    message. The generated test raises an ImportError carrying that message.
    """
    pieces = ['Failed to import test module: %s' % name]
    # Python 2.3 compatibility: format_exc() does not exist there.
    # The formatted traceback also contains frames from the discovery
    # machinery itself, not only from the failing module.
    if hasattr(traceback, 'format_exc'):
        pieces.append(traceback.format_exc())
    message = '\n'.join(pieces)
    return _make_failed_test('ModuleImportFailure', name, ImportError(message),
                             suiteClass)
Michael Foorde91ea562009-09-13 19:07:03 +000038
def _make_failed_load_tests(name, exception, suiteClass):
    """Return a suite whose single test re-raises *exception*.

    Used when a ``load_tests`` hook raised. The generated test class is
    named 'LoadTestsFailure' and its test method is named *name*.
    """
    failure_class_name = 'LoadTestsFailure'
    return _make_failed_test(failure_class_name, name, exception, suiteClass)
Michael Foord73dbe042010-03-21 00:53:39 +000041
def _make_failed_test(classname, methodname, exception, suiteClass):
    """Synthesize a one-test suite whose test raises *exception*.

    A TestCase subclass called *classname* is created on the fly with a
    single method *methodname*. The instance for that method is wrapped by
    *suiteClass* inside a one-element tuple.
    """
    def testFailure(self):
        raise exception

    failure_case = type(classname, (case.TestCase,),
                        {methodname: testFailure})
    only_test = failure_case(methodname)
    return suiteClass((only_test,))
Michael Foorde91ea562009-09-13 19:07:03 +000048
49
class TestLoader(object):
    """
    This class is responsible for loading tests according to various criteria
    and returning them wrapped in a TestSuite
    """
    # Attribute-name prefix that marks a method as a test method.
    testMethodPrefix = 'test'
    # cmp-style function used to order test method names; a false value
    # disables sorting (see getTestCaseNames).
    sortTestMethodsUsing = cmp
    # Callable used to wrap a collection of tests into a suite.
    suiteClass = suite.TestSuite
    # Set by discover(); lets nested discover() calls (e.g. from a package's
    # load_tests) omit the top_level_dir argument.
    _top_level_dir = None

    def loadTestsFromTestCase(self, testCaseClass):
        """Return a suite of all tests cases contained in testCaseClass.

        Raises TypeError if testCaseClass derives from TestSuite. If no
        prefixed test methods exist but the class defines runTest, the
        suite contains a single 'runTest' test.
        """
        if issubclass(testCaseClass, suite.TestSuite):
            raise TypeError("Test cases should not be derived from TestSuite."
                            " Maybe you meant to derive from TestCase?")
        testCaseNames = self.getTestCaseNames(testCaseClass)
        if not testCaseNames and hasattr(testCaseClass, 'runTest'):
            testCaseNames = ['runTest']
        loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))
        return loaded_suite

    def loadTestsFromModule(self, module, use_load_tests=True):
        """Return a suite of all tests cases contained in the given module.

        Every TestCase subclass found via dir(module) is loaded. If
        use_load_tests is true and the module defines load_tests, its result
        (called as load_tests(self, tests, None)) is returned instead. An
        exception raised by load_tests is turned into a failing test rather
        than propagated.
        """
        tests = []
        for name in dir(module):
            obj = getattr(module, name)
            if isinstance(obj, type) and issubclass(obj, case.TestCase):
                tests.append(self.loadTestsFromTestCase(obj))

        load_tests = getattr(module, 'load_tests', None)
        tests = self.suiteClass(tests)
        if use_load_tests and load_tests is not None:
            try:
                return load_tests(self, tests, None)
            except Exception as e:
                return _make_failed_load_tests(module.__name__, e,
                                               self.suiteClass)
        return tests

    def loadTestsFromName(self, name, module=None):
        """Return a suite of all tests cases given a string specifier.

        The name may resolve either to a module, a test case class, a
        test method within a test case class, or a callable object which
        returns a TestCase or TestSuite instance.

        The method optionally resolves the names relative to a given module.

        Raises ImportError if no dotted prefix of name can be imported, and
        TypeError if the resolved object cannot be turned into tests.
        """
        parts = name.split('.')
        if module is None:
            # Import the longest importable dotted prefix; the remaining
            # parts are resolved as attributes below.
            parts_copy = parts[:]
            while parts_copy:
                try:
                    module = __import__('.'.join(parts_copy))
                    break
                except ImportError:
                    del parts_copy[-1]
                    if not parts_copy:
                        raise
            # __import__ returns the top-level package, so attribute lookup
            # continues from the second dotted component.
            parts = parts[1:]
        obj = module
        for part in parts:
            parent, obj = obj, getattr(obj, part)

        if isinstance(obj, types.ModuleType):
            return self.loadTestsFromModule(obj)
        elif isinstance(obj, type) and issubclass(obj, case.TestCase):
            return self.loadTestsFromTestCase(obj)
        elif (isinstance(obj, types.UnboundMethodType) and
              isinstance(parent, type) and
              issubclass(parent, case.TestCase)):
            return self.suiteClass([parent(obj.__name__)])
        elif isinstance(obj, suite.TestSuite):
            return obj
        elif hasattr(obj, '__call__'):
            test = obj()
            if isinstance(test, suite.TestSuite):
                return test
            elif isinstance(test, case.TestCase):
                return self.suiteClass([test])
            else:
                raise TypeError("calling %s returned %s, not a test" %
                                (obj, test))
        else:
            raise TypeError("don't know how to make test from: %s" % obj)

    def loadTestsFromNames(self, names, module=None):
        """Return a suite of all tests cases found using the given sequence
        of string specifiers. See 'loadTestsFromName()'.
        """
        suites = [self.loadTestsFromName(name, module) for name in names]
        return self.suiteClass(suites)

    def getTestCaseNames(self, testCaseClass):
        """Return a sorted sequence of method names found within testCaseClass

        A name qualifies if it starts with testMethodPrefix and the attribute
        is callable. Sorting uses sortTestMethodsUsing when it is set.
        """
        def isTestMethod(attrname, testCaseClass=testCaseClass,
                         prefix=self.testMethodPrefix):
            return attrname.startswith(prefix) and \
                hasattr(getattr(testCaseClass, attrname), '__call__')
        testFnNames = filter(isTestMethod, dir(testCaseClass))
        if self.sortTestMethodsUsing:
            testFnNames.sort(key=_CmpToKey(self.sortTestMethodsUsing))
        return testFnNames

    def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
        """Find and return all test modules from the specified start
        directory, recursing into subdirectories to find them. Only test files
        that match the pattern will be loaded. (Using shell style pattern
        matching.)

        All test modules must be importable from the top level of the project.
        If the start directory is not the top level directory then the top
        level directory must be specified separately.

        If a test package name (directory with '__init__.py') matches the
        pattern then the package will be checked for a 'load_tests' function. If
        this exists then it will be called with loader, tests, pattern.

        If load_tests exists then discovery does *not* recurse into the package,
        load_tests is responsible for loading all tests in the package.

        The pattern is deliberately not stored as a loader attribute so that
        packages can continue discovery themselves. top_level_dir is stored so
        load_tests does not need to pass this argument in to loader.discover().

        Side effect: top_level_dir is appended to sys.path if it is not
        already there. Raises ImportError if start_dir differs from the top
        level directory and has no '__init__.py'.
        """
        if top_level_dir is None and self._top_level_dir is not None:
            # make top_level_dir optional if called from load_tests in a package
            top_level_dir = self._top_level_dir
        elif top_level_dir is None:
            top_level_dir = start_dir

        top_level_dir = os.path.abspath(os.path.normpath(top_level_dir))
        start_dir = os.path.abspath(os.path.normpath(start_dir))

        if top_level_dir not in sys.path:
            # all test modules must be importable from the top level directory
            sys.path.append(top_level_dir)
        self._top_level_dir = top_level_dir

        if start_dir != top_level_dir and not os.path.isfile(os.path.join(start_dir, '__init__.py')):
            # what about __init__.pyc or pyo (etc)
            raise ImportError('Start directory is not importable: %r' % start_dir)

        tests = list(self._find_tests(start_dir, pattern))
        return self.suiteClass(tests)

    def _get_name_from_path(self, path):
        """Convert a file or directory path to a dotted module name.

        The path is made relative to self._top_level_dir and its extension
        is dropped. The asserts reject paths outside the project (note that
        they are skipped when Python runs with -O).
        """
        path = os.path.splitext(os.path.normpath(path))[0]

        _relpath = os.path.relpath(path, self._top_level_dir)
        assert not os.path.isabs(_relpath), "Path must be within the project"
        assert not _relpath.startswith('..'), "Path must be within the project"

        name = _relpath.replace(os.path.sep, '.')
        return name

    def _get_module_from_name(self, name):
        """Import the dotted module name and return the module object itself
        (not the top-level package that __import__ returns)."""
        __import__(name)
        return sys.modules[name]

    def _find_tests(self, start_dir, pattern):
        """Used by discovery. Yields test suites it loads.

        Modules and packages that fail to import are reported as failing
        'ModuleImportFailure' tests instead of aborting discovery.
        """
        paths = os.listdir(start_dir)

        for path in paths:
            full_path = os.path.join(start_dir, path)
            if os.path.isfile(full_path):
                if not VALID_MODULE_NAME.match(path):
                    # valid Python identifiers only
                    continue

                if fnmatch(path, pattern):
                    # if the test file matches, load it
                    name = self._get_name_from_path(full_path)
                    try:
                        module = self._get_module_from_name(name)
                    except:
                        yield _make_failed_import_test(name, self.suiteClass)
                    else:
                        yield self.loadTestsFromModule(module)
            elif os.path.isdir(full_path):
                if not os.path.isfile(os.path.join(full_path, '__init__.py')):
                    continue

                load_tests = None
                tests = None
                if fnmatch(path, pattern):
                    # only check load_tests if the package directory itself matches the filter
                    name = self._get_name_from_path(full_path)
                    try:
                        package = self._get_module_from_name(name)
                    except:
                        # Same treatment as a failing plain module: report it
                        # as a failing test. Do not recurse, because every
                        # submodule import would need this package anyway.
                        yield _make_failed_import_test(name, self.suiteClass)
                        continue
                    load_tests = getattr(package, 'load_tests', None)
                    tests = self.loadTestsFromModule(package, use_load_tests=False)

                if load_tests is None:
                    if tests is not None:
                        # tests loaded from package file
                        yield tests
                    # recurse into the package
                    for test in self._find_tests(full_path, pattern):
                        yield test
                else:
                    try:
                        yield load_tests(self, tests, pattern)
                    except Exception as e:
                        yield _make_failed_load_tests(package.__name__, e,
                                                      self.suiteClass)
Benjamin Petersond7b0eeb2009-07-19 20:18:21 +0000257
# Module-level TestLoader instance with the class's default settings.
# Note: discover() stores _top_level_dir on the instance it is called on, so
# later discover() calls on this object without top_level_dir will reuse it.
defaultTestLoader = TestLoader()
259
260
def _makeLoader(prefix, sortUsing, suiteClass=None):
    """Build a fresh TestLoader configured with the given options.

    suiteClass overrides the loader's default suite class only when it is
    truthy; otherwise the class default is kept.
    """
    configured = TestLoader()
    configured.testMethodPrefix = prefix
    configured.sortTestMethodsUsing = sortUsing
    if not suiteClass:
        return configured
    configured.suiteClass = suiteClass
    return configured
268
def getTestCaseNames(testCaseClass, prefix, sortUsing=cmp):
    """Legacy helper: return testCaseClass's test method names.

    Names must start with prefix and are ordered with the cmp-style
    function sortUsing.
    """
    loader = _makeLoader(prefix, sortUsing)
    return loader.getTestCaseNames(testCaseClass)
271
def makeSuite(testCaseClass, prefix='test', sortUsing=cmp,
              suiteClass=suite.TestSuite):
    """Legacy helper: return a suite of all tests in testCaseClass."""
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    return loader.loadTestsFromTestCase(testCaseClass)
275
def findTestCases(module, prefix='test', sortUsing=cmp,
                  suiteClass=suite.TestSuite):
    """Legacy helper: return a suite of all TestCase tests found in module."""
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    return loader.loadTestsFromModule(module)