blob: d6c8d8424519e26ee24c175ade9c044e9a17dbd7 [file] [log] [blame]
Benjamin Petersond7b0eeb2009-07-19 20:18:21 +00001"""Loading unittests."""
2
3import os
Michael Foorde91ea562009-09-13 19:07:03 +00004import re
Benjamin Petersond7b0eeb2009-07-19 20:18:21 +00005import sys
Michael Foorde91ea562009-09-13 19:07:03 +00006import traceback
Benjamin Petersond7b0eeb2009-07-19 20:18:21 +00007import types
8
9from fnmatch import fnmatch
10
11from . import case, suite
12
# NOTE(review): module-level marker that nothing in this file reads;
# presumably other unittest modules look for it in a frame's globals to
# recognise framework code (e.g. when trimming tracebacks) -- confirm.
__unittest = True
14
Benjamin Petersond7b0eeb2009-07-19 20:18:21 +000015
16def _CmpToKey(mycmp):
17 'Convert a cmp= function into a key= function'
18 class K(object):
19 def __init__(self, obj):
20 self.obj = obj
21 def __lt__(self, other):
22 return mycmp(self.obj, other.obj) == -1
23 return K
24
25
# Only '.py' source files are considered as candidate test modules.
# TODO: what about .pyc or .pyo (etc)?  Supporting them would require
# avoiding loading the same tests multiple times from '.py', '.pyc'
# *and* '.pyo'.
#
# The pattern accepts a file name made of a letter or underscore followed by
# any number of word characters and ending in '.py' (case-insensitively).
VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', re.IGNORECASE)
30
31
def _make_failed_import_test(name, suiteClass):
    """Return a suite holding one test that reports a failed module import.

    name is the dotted name of the module that could not be imported; it is
    also used as the name of the generated test method.  The test raises an
    ImportError whose message names the module and, when available, carries
    the traceback of the exception currently being handled.
    """
    format_exc = getattr(traceback, 'format_exc', None)
    message = 'Failed to import test module: %s' % name
    if format_exc is not None:
        # Python 2.3 compatibility: format_exc may be missing there.
        # Note that format_exc returns two frames of discover.py as well.
        message = message + '\n%s' % format_exc()
    error = ImportError(message)
    return _make_failed_test('ModuleImportFailure', name, error, suiteClass)
Michael Foorde91ea562009-09-13 19:07:03 +000040
def _make_failed_load_tests(name, exception, suiteClass):
    """Return a suite holding one test that re-raises a load_tests failure.

    name becomes the generated test method's name and exception is the
    object that test raises when run.
    """
    failure_suite = _make_failed_test('LoadTestsFailure', name, exception,
                                      suiteClass)
    return failure_suite
Michael Foord73dbe042010-03-21 00:53:39 +000043
Michael Foord8cb253f2010-03-21 00:55:58 +000044def _make_failed_test(classname, methodname, exception, suiteClass):
Michael Foord73dbe042010-03-21 00:53:39 +000045 def testFailure(self):
46 raise exception
47 attrs = {methodname: testFailure}
48 TestClass = type(classname, (case.TestCase,), attrs)
49 return suiteClass((TestClass(methodname),))
Michael Foorde91ea562009-09-13 19:07:03 +000050
51
class TestLoader(object):
    """
    This class is responsible for loading tests according to various criteria
    and returning them wrapped in a TestSuite
    """
    # Only callable attributes whose name starts with this prefix are
    # collected as test methods.
    testMethodPrefix = 'test'
    # cmp-style function used to order test method names; a false value
    # leaves the names in the order produced by dir().
    sortTestMethodsUsing = cmp
    # Callable used to wrap every collection of loaded tests.
    suiteClass = suite.TestSuite
    # Absolute top-level project directory, recorded by discover().
    _top_level_dir = None

    def loadTestsFromTestCase(self, testCaseClass):
        """Return a suite of all tests cases contained in testCaseClass

        Raises TypeError if testCaseClass derives from TestSuite.  When no
        test method names are found but the class defines 'runTest', a
        single 'runTest' test is loaded instead.
        """
        if issubclass(testCaseClass, suite.TestSuite):
            raise TypeError("Test cases should not be derived from TestSuite."
                            " Maybe you meant to derive from TestCase?")
        testCaseNames = self.getTestCaseNames(testCaseClass)
        if not testCaseNames and hasattr(testCaseClass, 'runTest'):
            testCaseNames = ['runTest']
        loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))
        return loaded_suite

    def loadTestsFromModule(self, module, use_load_tests=True):
        """Return a suite of all tests cases contained in the given module

        Every TestCase subclass found among the module's attributes is
        loaded.  If the module defines 'load_tests' and use_load_tests is
        true, the result of load_tests(loader, tests, None) is returned
        instead; an exception raised by load_tests is turned into a suite
        holding a single failing test.
        """
        tests = []
        for name in dir(module):
            obj = getattr(module, name)
            if isinstance(obj, type) and issubclass(obj, case.TestCase):
                tests.append(self.loadTestsFromTestCase(obj))

        load_tests = getattr(module, 'load_tests', None)
        tests = self.suiteClass(tests)
        if use_load_tests and load_tests is not None:
            try:
                return load_tests(self, tests, None)
            except Exception as e:
                return _make_failed_load_tests(module.__name__, e,
                                               self.suiteClass)
        return tests

    def loadTestsFromName(self, name, module=None):
        """Return a suite of all tests cases given a string specifier.

        The name may resolve either to a module, a test case class, a
        test method within a test case class, or a callable object which
        returns a TestCase or TestSuite instance.

        The method optionally resolves the names relative to a given module.

        Raises TypeError when the resolved object cannot be turned into a
        test, and re-raises the ImportError if no prefix of the dotted name
        can be imported.
        """
        parts = name.split('.')
        if module is None:
            # Import the longest importable prefix of the dotted name; the
            # remaining parts are resolved as attributes below.
            parts_copy = parts[:]
            while parts_copy:
                try:
                    module = __import__('.'.join(parts_copy))
                    break
                except ImportError:
                    del parts_copy[-1]
                    if not parts_copy:
                        raise
            # __import__ returns the top-level package, so only the first
            # component has been consumed.
            parts = parts[1:]
        obj = module
        for part in parts:
            parent, obj = obj, getattr(obj, part)

        if isinstance(obj, types.ModuleType):
            return self.loadTestsFromModule(obj)
        elif isinstance(obj, type) and issubclass(obj, case.TestCase):
            return self.loadTestsFromTestCase(obj)
        elif (isinstance(obj, types.UnboundMethodType) and
              isinstance(parent, type) and
              issubclass(parent, case.TestCase)):
            return self.suiteClass([parent(obj.__name__)])
        elif isinstance(obj, suite.TestSuite):
            return obj
        elif hasattr(obj, '__call__'):
            test = obj()
            if isinstance(test, suite.TestSuite):
                return test
            elif isinstance(test, case.TestCase):
                return self.suiteClass([test])
            else:
                raise TypeError("calling %s returned %s, not a test" %
                                (obj, test))
        else:
            raise TypeError("don't know how to make test from: %s" % obj)

    def loadTestsFromNames(self, names, module=None):
        """Return a suite of all tests cases found using the given sequence
        of string specifiers. See 'loadTestsFromName()'.
        """
        suites = [self.loadTestsFromName(name, module) for name in names]
        return self.suiteClass(suites)

    def getTestCaseNames(self, testCaseClass):
        """Return a sorted sequence of method names found within testCaseClass

        A name qualifies when it starts with testMethodPrefix and the
        attribute is callable.  Sorting is skipped when sortTestMethodsUsing
        is false.
        """
        def isTestMethod(attrname, testCaseClass=testCaseClass,
                         prefix=self.testMethodPrefix):
            return attrname.startswith(prefix) and \
                hasattr(getattr(testCaseClass, attrname), '__call__')
        testFnNames = filter(isTestMethod, dir(testCaseClass))
        if self.sortTestMethodsUsing:
            testFnNames.sort(key=_CmpToKey(self.sortTestMethodsUsing))
        return testFnNames

    def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
        """Find and return all test modules from the specified start
        directory, recursing into subdirectories to find them. Only test files
        that match the pattern will be loaded. (Using shell style pattern
        matching.)

        All test modules must be importable from the top level of the project.
        If the start directory is not the top level directory then the top
        level directory must be specified separately.

        If a test package name (directory with '__init__.py') matches the
        pattern then the package will be checked for a 'load_tests' function. If
        this exists then it will be called with loader, tests, pattern.

        If load_tests exists then discovery does *not* recurse into the package,
        load_tests is responsible for loading all tests in the package.

        The pattern is deliberately not stored as a loader attribute so that
        packages can continue discovery themselves. top_level_dir is stored so
        load_tests does not need to pass this argument in to loader.discover().

        Raises ImportError if start_dir differs from the top level directory
        and does not contain an '__init__.py' file.
        """
        if top_level_dir is None and self._top_level_dir is not None:
            # make top_level_dir optional if called from load_tests in a package
            top_level_dir = self._top_level_dir
        elif top_level_dir is None:
            top_level_dir = start_dir

        top_level_dir = os.path.abspath(os.path.normpath(top_level_dir))
        start_dir = os.path.abspath(os.path.normpath(start_dir))

        if top_level_dir not in sys.path:
            # all test modules must be importable from the top level directory
            sys.path.append(top_level_dir)
        self._top_level_dir = top_level_dir

        if (start_dir != top_level_dir and
                not os.path.isfile(os.path.join(start_dir, '__init__.py'))):
            # what about __init__.pyc or pyo (etc)
            raise ImportError('Start directory is not importable: %r' % start_dir)

        tests = list(self._find_tests(start_dir, pattern))
        return self.suiteClass(tests)

    def _get_name_from_path(self, path):
        """Return the dotted module name for a path below the top level dir.

        The extension is dropped, the path is made relative to
        self._top_level_dir and path separators are replaced by dots.
        """
        path = os.path.splitext(os.path.normpath(path))[0]

        _relpath = os.path.relpath(path, self._top_level_dir)
        assert not os.path.isabs(_relpath), "Path must be within the project"
        assert not _relpath.startswith('..'), "Path must be within the project"

        name = _relpath.replace(os.path.sep, '.')
        return name

    def _get_module_from_name(self, name):
        """Import the module with the given dotted name and return it."""
        __import__(name)
        # __import__ returns the top-level package; sys.modules holds the
        # module that was actually named.
        return sys.modules[name]

    def _find_tests(self, start_dir, pattern):
        """Used by discovery. Yields test suites it loads.

        Files in start_dir that look like Python modules and match pattern
        are imported and loaded; sub-packages are searched recursively
        unless they match pattern and define 'load_tests'.  A module or
        package that cannot be imported yields a suite with one failing
        test instead of aborting discovery.
        """
        # Sort the entries so the discovery order does not depend on the
        # order in which the file system happens to list them.
        paths = sorted(os.listdir(start_dir))

        for path in paths:
            full_path = os.path.join(start_dir, path)
            if os.path.isfile(full_path):
                if not VALID_MODULE_NAME.match(path):
                    # valid Python identifiers only
                    continue

                if fnmatch(path, pattern):
                    # if the test file matches, load it
                    name = self._get_name_from_path(full_path)
                    # NOTE(review): the bare except also traps
                    # KeyboardInterrupt/SystemExit raised while importing --
                    # confirm that this is intended.
                    try:
                        module = self._get_module_from_name(name)
                    except:
                        yield _make_failed_import_test(name, self.suiteClass)
                    else:
                        yield self.loadTestsFromModule(module)
            elif os.path.isdir(full_path):
                if not os.path.isfile(os.path.join(full_path, '__init__.py')):
                    continue

                load_tests = None
                tests = None
                if fnmatch(path, pattern):
                    # only check load_tests if the package directory itself matches the filter
                    name = self._get_name_from_path(full_path)
                    # Report a broken package the same way as a broken
                    # module, and skip it: its submodules could not be
                    # imported either.
                    try:
                        package = self._get_module_from_name(name)
                    except:
                        yield _make_failed_import_test(name, self.suiteClass)
                        continue
                    load_tests = getattr(package, 'load_tests', None)
                    tests = self.loadTestsFromModule(package, use_load_tests=False)

                if load_tests is None:
                    if tests is not None:
                        # tests loaded from package file
                        yield tests
                    # recurse into the package
                    for test in self._find_tests(full_path, pattern):
                        yield test
                else:
                    try:
                        yield load_tests(self, tests, pattern)
                    except Exception as e:
                        yield _make_failed_load_tests(package.__name__, e,
                                                      self.suiteClass)
Benjamin Petersond7b0eeb2009-07-19 20:18:21 +0000259
# Shared module-level loader built with the class defaults.  Because
# discover() records the top-level directory on the instance, that state is
# kept between calls made through this object.
defaultTestLoader = TestLoader()
261
262
def _makeLoader(prefix, sortUsing, suiteClass=None):
    """Return a fresh TestLoader configured with the given settings.

    prefix and sortUsing always override the loader's method prefix and
    sort function; suiteClass replaces the default only when it is true.
    """
    configured = TestLoader()
    configured.testMethodPrefix = prefix
    configured.sortTestMethodsUsing = sortUsing
    if not suiteClass:
        return configured
    configured.suiteClass = suiteClass
    return configured
270
def getTestCaseNames(testCaseClass, prefix, sortUsing=cmp):
    """Return the test method names of testCaseClass.

    Convenience wrapper that configures a temporary loader with prefix and
    sortUsing and delegates to its getTestCaseNames() method.
    """
    loader = _makeLoader(prefix, sortUsing)
    return loader.getTestCaseNames(testCaseClass)
273
def makeSuite(testCaseClass, prefix='test', sortUsing=cmp,
              suiteClass=suite.TestSuite):
    """Return a suite of the tests defined by testCaseClass.

    Convenience wrapper that configures a temporary loader with prefix,
    sortUsing and suiteClass and delegates to loadTestsFromTestCase().
    """
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    return loader.loadTestsFromTestCase(testCaseClass)
277
def findTestCases(module, prefix='test', sortUsing=cmp,
                  suiteClass=suite.TestSuite):
    """Return a suite of the tests found in module.

    Convenience wrapper that configures a temporary loader with prefix,
    sortUsing and suiteClass and delegates to loadTestsFromModule().
    """
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    return loader.loadTestsFromModule(module)