blob: 31c343b49b1e5f9d4badda68755e584b6ea02a17 [file] [log] [blame]
Benjamin Petersond7b0eeb2009-07-19 20:18:21 +00001"""Loading unittests."""
2
3import os
Michael Foorde91ea562009-09-13 19:07:03 +00004import re
Benjamin Petersond7b0eeb2009-07-19 20:18:21 +00005import sys
Michael Foorde91ea562009-09-13 19:07:03 +00006import traceback
Benjamin Petersond7b0eeb2009-07-19 20:18:21 +00007import types
8
9from fnmatch import fnmatch
10
11from . import case, suite
12
13
14def _CmpToKey(mycmp):
15 'Convert a cmp= function into a key= function'
16 class K(object):
17 def __init__(self, obj):
18 self.obj = obj
19 def __lt__(self, other):
20 return mycmp(self.obj, other.obj) == -1
21 return K
22
23
# Only plain ".py" source files are considered.  Also accepting .pyc/.pyo
# (etc.) would require making sure the same tests are not loaded several
# times from 'foo.py', 'foo.pyc' *and* 'foo.pyo'.
#
# A candidate file name must start with a letter or underscore, contain only
# word characters before the extension, and end in ".py" (case-insensitive).
VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', flags=re.IGNORECASE)
28
29
def _make_failed_import_test(name, suiteClass):
    """Build a suite containing one test that reports a failed import.

    ``name`` is the dotted name of the module that could not be imported.
    It is also used as the name of the generated test method, whose body
    raises ImportError.  When traceback.format_exc is available, the
    traceback of the exception being handled at call time is appended to the
    message, so this is meant to be called from inside an ``except`` block
    (as _find_tests does).  The result is ``suiteClass`` called with a
    one-element tuple holding that test.
    """
    pieces = ['Failed to import test module: %s' % name]
    # Python 2.3 compatibility: traceback.format_exc may be missing.
    # format_exc returns two frames of discover.py as well.
    if hasattr(traceback, 'format_exc'):
        pieces.append(traceback.format_exc())
    message = '\n'.join(pieces)

    def testImportFailure(self):
        raise ImportError(message)

    failure_class = type('ModuleImportFailure', (case.TestCase,),
                         {name: testImportFailure})
    return suiteClass((failure_class(name),))
42
43
class TestLoader(object):
    """
    This class is responsible for loading tests according to various criteria
    and returning them wrapped in a TestSuite
    """
    # Attribute-name prefix that marks a method as a test (see
    # getTestCaseNames).
    testMethodPrefix = 'test'
    # Three-way comparison used to order test method names; a false value
    # leaves them in dir() order.
    sortTestMethodsUsing = cmp
    # Callable used to wrap every collection of loaded tests.
    suiteClass = suite.TestSuite
    # Set by discover() so that a nested loader.discover() call (e.g. from a
    # package's load_tests) can omit top_level_dir.
    _top_level_dir = None

    def loadTestsFromTestCase(self, testCaseClass):
        """Return a suite of all tests cases contained in testCaseClass.

        A class with no test methods but a ``runTest`` attribute yields a
        single ``runTest`` test.

        Raises TypeError if testCaseClass derives from TestSuite.
        """
        if issubclass(testCaseClass, suite.TestSuite):
            raise TypeError("Test cases should not be derived from TestSuite."
                            " Maybe you meant to derive from TestCase?")
        testCaseNames = self.getTestCaseNames(testCaseClass)
        if not testCaseNames and hasattr(testCaseClass, 'runTest'):
            testCaseNames = ['runTest']
        return self.suiteClass(map(testCaseClass, testCaseNames))

    def loadTestsFromModule(self, module, use_load_tests=True):
        """Return a suite of all tests cases contained in the given module.

        Every TestCase subclass found in ``dir(module)`` is loaded.  If
        use_load_tests is true and the module defines ``load_tests``, the
        result of ``load_tests(self, standard_tests, None)`` is returned
        instead, where standard_tests is the suite of tests found above.
        """
        tests = []
        for name in dir(module):
            obj = getattr(module, name)
            if isinstance(obj, type) and issubclass(obj, case.TestCase):
                tests.append(self.loadTestsFromTestCase(obj))

        # Always hand load_tests a suite (not a bare list), matching what
        # _find_tests passes to a package's load_tests.
        tests = self.suiteClass(tests)
        load_tests = getattr(module, 'load_tests', None)
        if use_load_tests and load_tests is not None:
            return load_tests(self, tests, None)
        return tests

    def loadTestsFromName(self, name, module=None):
        """Return a suite of all tests cases given a string specifier.

        The name may resolve either to a module, a test case class, a
        test method within a test case class, or a callable object which
        returns a TestCase or TestSuite instance.

        The method optionally resolves the names relative to a given module.

        Raises ImportError if no prefix of the name can be imported,
        AttributeError if a component cannot be found, and TypeError if the
        resolved object cannot be turned into tests.
        """
        parts = name.split('.')
        if module is None:
            # Import the longest importable dotted prefix of the name; the
            # remaining parts are resolved by attribute lookup below.
            parts_copy = parts[:]
            while parts_copy:
                try:
                    module = __import__('.'.join(parts_copy))
                    break
                except ImportError:
                    del parts_copy[-1]
                    if not parts_copy:
                        raise
            # __import__('a.b') returns the top-level package 'a', so every
            # part after the first is looked up as an attribute from there.
            parts = parts[1:]
        obj = module
        for part in parts:
            parent, obj = obj, getattr(obj, part)

        if isinstance(obj, types.ModuleType):
            return self.loadTestsFromModule(obj)
        elif isinstance(obj, type) and issubclass(obj, case.TestCase):
            return self.loadTestsFromTestCase(obj)
        elif (isinstance(obj, types.UnboundMethodType) and
              isinstance(parent, type) and
              issubclass(parent, case.TestCase)):
            # Use the attribute name that was actually looked up rather than
            # obj.__name__: for an alias such as ``test_b = test_a`` the
            # function's __name__ would select the wrong test.
            return self.suiteClass([parent(parts[-1])])
        elif isinstance(obj, suite.TestSuite):
            return obj
        elif hasattr(obj, '__call__'):
            test = obj()
            if isinstance(test, suite.TestSuite):
                return test
            elif isinstance(test, case.TestCase):
                return self.suiteClass([test])
            else:
                raise TypeError("calling %s returned %s, not a test" %
                                (obj, test))
        else:
            raise TypeError("don't know how to make test from: %s" % obj)

    def loadTestsFromNames(self, names, module=None):
        """Return a suite of all tests cases found using the given sequence
        of string specifiers. See 'loadTestsFromName()'.
        """
        suites = [self.loadTestsFromName(name, module) for name in names]
        return self.suiteClass(suites)

    def getTestCaseNames(self, testCaseClass):
        """Return a sorted sequence of method names found within testCaseClass.

        A name qualifies if it starts with testMethodPrefix and the attribute
        is callable.  Names are ordered with sortTestMethodsUsing when it is
        set.
        """
        def isTestMethod(attrname, testCaseClass=testCaseClass,
                         prefix=self.testMethodPrefix):
            return attrname.startswith(prefix) and \
                hasattr(getattr(testCaseClass, attrname), '__call__')
        testFnNames = [attrname for attrname in dir(testCaseClass)
                       if isTestMethod(attrname)]
        if self.sortTestMethodsUsing:
            testFnNames.sort(key=_CmpToKey(self.sortTestMethodsUsing))
        return testFnNames

    def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
        """Find and return all test modules from the specified start
        directory, recursing into subdirectories to find them. Only test files
        that match the pattern will be loaded. (Using shell style pattern
        matching.)

        All test modules must be importable from the top level of the project.
        If the start directory is not the top level directory then the top
        level directory must be specified separately.

        If a test package name (directory with '__init__.py') matches the
        pattern then the package will be checked for a 'load_tests' function. If
        this exists then it will be called with loader, tests, pattern.

        If load_tests exists then discovery does *not* recurse into the package,
        load_tests is responsible for loading all tests in the package.

        The pattern is deliberately not stored as a loader attribute so that
        packages can continue discovery themselves. top_level_dir is stored so
        load_tests does not need to pass this argument in to loader.discover().

        Side effects: top_level_dir is appended to sys.path if missing and
        remembered on the loader.  Raises ImportError if start_dir differs
        from top_level_dir and has no '__init__.py'.
        """
        if top_level_dir is None and self._top_level_dir is not None:
            # make top_level_dir optional if called from load_tests in a package
            top_level_dir = self._top_level_dir
        elif top_level_dir is None:
            top_level_dir = start_dir

        top_level_dir = os.path.abspath(os.path.normpath(top_level_dir))
        start_dir = os.path.abspath(os.path.normpath(start_dir))

        if top_level_dir not in sys.path:
            # all test modules must be importable from the top level directory
            sys.path.append(top_level_dir)
        self._top_level_dir = top_level_dir

        if start_dir != top_level_dir and not os.path.isfile(os.path.join(start_dir, '__init__.py')):
            # what about __init__.pyc or pyo (etc)
            raise ImportError('Start directory is not importable: %r' % start_dir)

        tests = list(self._find_tests(start_dir, pattern))
        return self.suiteClass(tests)

    def _get_name_from_path(self, path):
        """Convert a path below _top_level_dir into a dotted module name.

        The file extension (if any) is dropped and path separators become
        dots.  An AssertionError is raised (unless running under -O) when the
        path lies outside the top level directory.
        """
        path = os.path.splitext(os.path.normpath(path))[0]

        _relpath = os.path.relpath(path, self._top_level_dir)
        assert not os.path.isabs(_relpath), "Path must be within the project"
        assert not _relpath.startswith('..'), "Path must be within the project"

        name = _relpath.replace(os.path.sep, '.')
        return name

    def _get_module_from_name(self, name):
        """Import the module with the given dotted name and return it.

        The module is fetched from sys.modules because __import__ returns
        the top-level package for dotted names.
        """
        __import__(name)
        return sys.modules[name]

    def _find_tests(self, start_dir, pattern):
        """Used by discovery. Yields test suites it loads.

        Files are loaded when their name is a valid module name and matches
        pattern.  Directories are only considered when they contain
        '__init__.py'.  A matching package with a load_tests function
        delegates loading of everything inside it to that function; any
        other package is recursed into.
        """
        # Sort so that discovery (and therefore test run) order does not
        # depend on the arbitrary order returned by os.listdir.
        paths = sorted(os.listdir(start_dir))

        for path in paths:
            full_path = os.path.join(start_dir, path)
            if os.path.isfile(full_path):
                if not VALID_MODULE_NAME.match(path):
                    # valid Python identifiers only
                    continue
                if not fnmatch(path, pattern):
                    continue

                # the test file matches, load it
                name = self._get_name_from_path(full_path)
                try:
                    module = self._get_module_from_name(name)
                except:
                    # Bare except: any failure while importing (even
                    # SystemExit) is reported as a failing test rather than
                    # aborting the whole discovery run.
                    yield _make_failed_import_test(name, self.suiteClass)
                else:
                    yield self.loadTestsFromModule(module)
            elif os.path.isdir(full_path):
                if not os.path.isfile(os.path.join(full_path, '__init__.py')):
                    continue

                load_tests = None
                tests = None
                if fnmatch(path, pattern):
                    # only check load_tests if the package directory itself
                    # matches the filter
                    name = self._get_name_from_path(full_path)
                    try:
                        package = self._get_module_from_name(name)
                    except:
                        # Report a broken package the same way as a broken
                        # module, and do not recurse into it.
                        yield _make_failed_import_test(name, self.suiteClass)
                        continue
                    load_tests = getattr(package, 'load_tests', None)
                    tests = self.loadTestsFromModule(package, use_load_tests=False)

                if load_tests is None:
                    if tests is not None:
                        # tests loaded from package file
                        yield tests
                    # recurse into the package
                    for test in self._find_tests(full_path, pattern):
                        yield test
                else:
                    yield load_tests(self, tests, pattern)
242
# Module-level TestLoader instance that uses the class defaults (method
# prefix 'test', names sorted with cmp, tests wrapped in suite.TestSuite).
defaultTestLoader = TestLoader()
244
245
def _makeLoader(prefix, sortUsing, suiteClass=None):
    """Return a fresh TestLoader configured with the given settings.

    prefix becomes testMethodPrefix and sortUsing becomes
    sortTestMethodsUsing.  suiteClass replaces the default suite class only
    when it is truthy.
    """
    settings = {
        'testMethodPrefix': prefix,
        'sortTestMethodsUsing': sortUsing,
    }
    if suiteClass:
        settings['suiteClass'] = suiteClass

    loader = TestLoader()
    for attribute, value in settings.items():
        setattr(loader, attribute, value)
    return loader
253
def getTestCaseNames(testCaseClass, prefix, sortUsing=cmp):
    """Return the names of testCaseClass's callable attributes that start
    with prefix, ordered with the comparison function sortUsing (left
    unsorted when sortUsing is false)."""
    configured = _makeLoader(prefix, sortUsing)
    return configured.getTestCaseNames(testCaseClass)
256
def makeSuite(testCaseClass, prefix='test', sortUsing=cmp,
              suiteClass=suite.TestSuite):
    """Return a suite, built with suiteClass, of the tests in testCaseClass
    whose method names start with prefix (ordered with sortUsing)."""
    configured = _makeLoader(prefix, sortUsing, suiteClass)
    return configured.loadTestsFromTestCase(testCaseClass)
260
def findTestCases(module, prefix='test', sortUsing=cmp,
                  suiteClass=suite.TestSuite):
    """Return a suite, built with suiteClass, of the tests in every TestCase
    subclass of module, honouring the module's load_tests hook if present."""
    configured = _makeLoader(prefix, sortUsing, suiteClass)
    return configured.loadTestsFromModule(module)