blob: c04de062b291cb3f872bc87abcbcb7d33d361093 [file] [log] [blame]
Benjamin Petersond7b0eeb2009-07-19 20:18:21 +00001"""Loading unittests."""
2
3import os
Michael Foorde91ea562009-09-13 19:07:03 +00004import re
Benjamin Petersond7b0eeb2009-07-19 20:18:21 +00005import sys
Michael Foorde91ea562009-09-13 19:07:03 +00006import traceback
Benjamin Petersond7b0eeb2009-07-19 20:18:21 +00007import types
8
9from fnmatch import fnmatch
10
11from . import case, suite
12
13
14def _CmpToKey(mycmp):
15 'Convert a cmp= function into a key= function'
16 class K(object):
17 def __init__(self, obj):
18 self.obj = obj
19 def __lt__(self, other):
20 return mycmp(self.obj, other.obj) == -1
21 return K
22
23
# Only '.py' source files are matched; compiled files ('.pyc', '.pyo', etc.)
# are not.  Supporting them would require avoiding loading the same tests
# multiple times from '.py', '.pyc' *and* '.pyo'.
# The pattern accepts a file name that starts with a letter or underscore,
# continues with word characters and ends in '.py', ignoring case.
VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', re.IGNORECASE)
28
29
def _make_failed_import_test(name, suiteClass):
    """Return a suite holding one test that raises ImportError when run.

    name is the module name that could not be imported; it is also used as
    the name of the generated test method.  suiteClass is called with a
    one-element tuple containing the generated test case instance.
    """
    pieces = ['Failed to import test module: %s' % name]
    if hasattr(traceback, 'format_exc'):
        # format_exc is missing on Python 2.3, hence the guard.  The
        # formatted traceback also includes two frames from discovery itself.
        pieces.append(traceback.format_exc())
    message = '\n'.join(pieces)

    def testImportFailure(self):
        raise ImportError(message)

    failure_class = type('ModuleImportFailure', (case.TestCase,),
                         {name: testImportFailure})
    failing_test = failure_class(name)
    return suiteClass((failing_test,))
42
43
class TestLoader(object):
    """
    This class is responsible for loading tests according to various criteria
    and returning them wrapped in a TestSuite
    """
    # Callable attributes whose names start with this prefix are test methods.
    testMethodPrefix = 'test'
    # cmp-style function used to order test method names; a false value
    # leaves the names in dir() order.
    sortTestMethodsUsing = cmp
    # Callable used to wrap loaded tests into a suite.
    suiteClass = suite.TestSuite
    # Set by discover() so nested discover() calls can omit top_level_dir.
    _top_level_dir = None

    def loadTestsFromTestCase(self, testCaseClass):
        """Return a suite of all test cases contained in testCaseClass.

        If no method matches the test prefix but the class has a 'runTest'
        attribute, a single 'runTest' test is loaded instead.

        Raises TypeError if testCaseClass is derived from TestSuite.
        """
        if issubclass(testCaseClass, suite.TestSuite):
            raise TypeError("Test cases should not be derived from TestSuite."
                            " Maybe you meant to derive from TestCase?")
        testCaseNames = self.getTestCaseNames(testCaseClass)
        if not testCaseNames and hasattr(testCaseClass, 'runTest'):
            testCaseNames = ['runTest']
        loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))
        return loaded_suite

    def loadTestsFromModule(self, module, use_load_tests=True):
        """Return a suite of all test cases contained in the given module.

        Every TestCase subclass found via dir(module) is loaded.  If the
        module defines 'load_tests' and use_load_tests is true, the result
        of load_tests(loader, tests, None) is returned instead of the
        collected suite.
        """
        tests = []
        for name in dir(module):
            obj = getattr(module, name)
            if isinstance(obj, type) and issubclass(obj, case.TestCase):
                tests.append(self.loadTestsFromTestCase(obj))

        load_tests = getattr(module, 'load_tests', None)
        tests = self.suiteClass(tests)
        if use_load_tests and load_tests is not None:
            return load_tests(self, tests, None)
        return tests

    def loadTestsFromName(self, name, module=None):
        """Return a suite of all test cases given a string specifier.

        The name may resolve either to a module, a test case class, a
        test method within a test case class, or a callable object which
        returns a TestCase or TestSuite instance.

        The method optionally resolves the names relative to a given module.

        Raises ImportError if no leading part of the name can be imported,
        and TypeError if the resolved object cannot be turned into a test.
        """
        parts = name.split('.')
        if module is None:
            # Import the longest importable dotted prefix of the name.
            parts_copy = parts[:]
            while parts_copy:
                try:
                    module = __import__('.'.join(parts_copy))
                    break
                except ImportError:
                    del parts_copy[-1]
                    if not parts_copy:
                        raise
            # __import__ returns the top-level package, so only the first
            # part is consumed; the rest is resolved with getattr below.
            parts = parts[1:]
        obj = module
        for part in parts:
            parent, obj = obj, getattr(obj, part)

        if isinstance(obj, types.ModuleType):
            return self.loadTestsFromModule(obj)
        elif isinstance(obj, type) and issubclass(obj, case.TestCase):
            return self.loadTestsFromTestCase(obj)
        elif (isinstance(obj, types.UnboundMethodType) and
              isinstance(parent, type) and
              issubclass(parent, case.TestCase)):
            return self.suiteClass([parent(obj.__name__)])
        elif isinstance(obj, suite.TestSuite):
            return obj
        elif hasattr(obj, '__call__'):
            test = obj()
            if isinstance(test, suite.TestSuite):
                return test
            elif isinstance(test, case.TestCase):
                return self.suiteClass([test])
            else:
                raise TypeError("calling %s returned %s, not a test" %
                                (obj, test))
        else:
            raise TypeError("don't know how to make test from: %s" % obj)

    def loadTestsFromNames(self, names, module=None):
        """Return a suite of all test cases found using the given sequence
        of string specifiers. See 'loadTestsFromName()'.
        """
        suites = [self.loadTestsFromName(name, module) for name in names]
        return self.suiteClass(suites)

    def getTestCaseNames(self, testCaseClass):
        """Return a sorted sequence of method names found within testCaseClass.

        A name qualifies when it starts with testMethodPrefix and the
        attribute is callable.  Names are sorted with sortTestMethodsUsing
        unless that attribute is false.
        """
        def isTestMethod(attrname, testCaseClass=testCaseClass,
                         prefix=self.testMethodPrefix):
            return attrname.startswith(prefix) and \
                hasattr(getattr(testCaseClass, attrname), '__call__')
        testFnNames = filter(isTestMethod, dir(testCaseClass))
        if self.sortTestMethodsUsing:
            testFnNames.sort(key=_CmpToKey(self.sortTestMethodsUsing))
        return testFnNames

    def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
        """Find and return all test modules from the specified start
        directory, recursing into subdirectories to find them. Only test files
        that match the pattern will be loaded. (Using shell style pattern
        matching.)

        All test modules must be importable from the top level of the project.
        If the start directory is not the top level directory then the top
        level directory must be specified separately.

        If a test package name (directory with '__init__.py') matches the
        pattern then the package will be checked for a 'load_tests' function. If
        this exists then it will be called with loader, tests, pattern.

        If load_tests exists then discovery does *not* recurse into the package,
        load_tests is responsible for loading all tests in the package.

        The pattern is deliberately not stored as a loader attribute so that
        packages can continue discovery themselves. top_level_dir is stored so
        load_tests does not need to pass this argument in to loader.discover().

        Raises ImportError if start_dir differs from the top level directory
        and does not contain an '__init__.py' file.
        """
        if top_level_dir is None and self._top_level_dir is not None:
            # make top_level_dir optional if called from load_tests in a package
            top_level_dir = self._top_level_dir
        elif top_level_dir is None:
            top_level_dir = start_dir

        top_level_dir = os.path.abspath(os.path.normpath(top_level_dir))
        start_dir = os.path.abspath(os.path.normpath(start_dir))

        if top_level_dir not in sys.path:
            # all test modules must be importable from the top level directory
            sys.path.append(top_level_dir)
        self._top_level_dir = top_level_dir

        if start_dir != top_level_dir and not os.path.isfile(os.path.join(start_dir, '__init__.py')):
            # what about __init__.pyc or pyo (etc)
            raise ImportError('Start directory is not importable: %r' % start_dir)

        tests = list(self._find_tests(start_dir, pattern))
        return self.suiteClass(tests)

    def _get_name_from_path(self, path):
        """Return the dotted module name for path, relative to the top level
        directory recorded by discover().

        The file extension is dropped and path separators become dots.
        """
        path = os.path.splitext(os.path.normpath(path))[0]

        _relpath = os.path.relpath(path, self._top_level_dir)
        assert not os.path.isabs(_relpath), "Path must be within the project"
        assert not _relpath.startswith('..'), "Path must be within the project"

        name = _relpath.replace(os.path.sep, '.')
        return name

    def _get_module_from_name(self, name):
        """Import the module called name and return it from sys.modules."""
        # __import__ returns the top-level package for dotted names, so the
        # module itself is looked up in sys.modules.
        __import__(name)
        return sys.modules[name]

    def _find_tests(self, start_dir, pattern):
        """Used by discovery. Yields test suites it loads.

        Directory entries are visited in sorted order so that the order of
        the yielded suites does not depend on the file system.  A module or
        package that fails to import yields a suite with a single failing
        test instead of aborting discovery.
        """
        # os.listdir returns entries in arbitrary order; sort them so that
        # discovery is deterministic.
        paths = sorted(os.listdir(start_dir))

        for path in paths:
            full_path = os.path.join(start_dir, path)
            if os.path.isfile(full_path):
                if not VALID_MODULE_NAME.match(path):
                    # valid Python identifiers only
                    continue

                if fnmatch(path, pattern):
                    # if the test file matches, load it
                    name = self._get_name_from_path(full_path)
                    try:
                        module = self._get_module_from_name(name)
                    except:
                        # Deliberately broad: whatever the import raised is
                        # reported as a failing test for this module.
                        yield _make_failed_import_test(name, self.suiteClass)
                    else:
                        yield self.loadTestsFromModule(module)
            elif os.path.isdir(full_path):
                if not os.path.isfile(os.path.join(full_path, '__init__.py')):
                    continue

                load_tests = None
                tests = None
                if fnmatch(path, pattern):
                    # only check load_tests if the package directory itself matches the filter
                    name = self._get_name_from_path(full_path)
                    try:
                        package = self._get_module_from_name(name)
                    except:
                        # Report a broken package the same way as a broken
                        # module and do not descend into it.
                        yield _make_failed_import_test(name, self.suiteClass)
                        continue
                    load_tests = getattr(package, 'load_tests', None)
                    tests = self.loadTestsFromModule(package, use_load_tests=False)

                if load_tests is None:
                    if tests is not None:
                        # tests loaded from package file
                        yield tests
                    # recurse into the package
                    for test in self._find_tests(full_path, pattern):
                        yield test
                else:
                    yield load_tests(self, tests, pattern)
243
# Module-level TestLoader instance created with the class defaults.
defaultTestLoader = TestLoader()
245
246
def _makeLoader(prefix, sortUsing, suiteClass=None):
    """Return a new TestLoader configured with the given settings.

    prefix becomes testMethodPrefix and sortUsing becomes
    sortTestMethodsUsing.  suiteClass replaces the loader's suiteClass only
    when it is true.
    """
    configured = TestLoader()
    settings = [('sortTestMethodsUsing', sortUsing),
                ('testMethodPrefix', prefix)]
    if suiteClass:
        settings.append(('suiteClass', suiteClass))
    for attribute, value in settings:
        setattr(configured, attribute, value)
    return configured
254
def getTestCaseNames(testCaseClass, prefix, sortUsing=cmp):
    """Return the test method names of testCaseClass.

    A throwaway loader is built with the given prefix and sort function and
    asked for the names.
    """
    loader = _makeLoader(prefix, sortUsing)
    names = loader.getTestCaseNames(testCaseClass)
    return names
257
def makeSuite(testCaseClass, prefix='test', sortUsing=cmp,
              suiteClass=suite.TestSuite):
    """Return a suite of the tests in testCaseClass.

    A throwaway loader is built with the given prefix, sort function and
    suite class, then used to load the tests from the class.
    """
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    loaded = loader.loadTestsFromTestCase(testCaseClass)
    return loaded
261
def findTestCases(module, prefix='test', sortUsing=cmp,
                  suiteClass=suite.TestSuite):
    """Return a suite of the tests found in module.

    A throwaway loader is built with the given prefix, sort function and
    suite class, then used to load the tests from the module.
    """
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    found = loader.loadTestsFromModule(module)
    return found