"""Loading unittests."""
2
import os
import sys
import types

from fnmatch import fnmatch

from . import case, suite
10
11
12def _CmpToKey(mycmp):
13 'Convert a cmp= function into a key= function'
14 class K(object):
15 def __init__(self, obj):
16 self.obj = obj
17 def __lt__(self, other):
18 return mycmp(self.obj, other.obj) == -1
19 return K
20
21
class TestLoader(object):
    """
    This class is responsible for loading tests according to various criteria
    and returning them wrapped in a TestSuite
    """
    # Only attributes whose name starts with this prefix are test methods.
    testMethodPrefix = 'test'
    # cmp-style function used to order test method names; a false value
    # disables sorting (see getTestCaseNames).
    sortTestMethodsUsing = cmp
    # Class used to build the suites returned by the load* methods.
    suiteClass = suite.TestSuite
    # Set by discover(); lets a nested discover() call omit top_level_dir.
    _top_level_dir = None

    def loadTestsFromTestCase(self, testCaseClass):
        """Return a suite of all tests cases contained in testCaseClass

        Raises TypeError if testCaseClass derives from TestSuite.  If no
        test method names are found but the class has a 'runTest' attribute,
        a single 'runTest' test is loaded.
        """
        if issubclass(testCaseClass, suite.TestSuite):
            raise TypeError("Test cases should not be derived from TestSuite."
                            " Maybe you meant to derive from TestCase?")
        testCaseNames = self.getTestCaseNames(testCaseClass)
        if not testCaseNames and hasattr(testCaseClass, 'runTest'):
            testCaseNames = ['runTest']
        return self.suiteClass([testCaseClass(name) for name in testCaseNames])

    def loadTestsFromModule(self, module, use_load_tests=True):
        """Return a suite of all tests cases contained in the given module

        If use_load_tests is true and the module defines 'load_tests', the
        result of load_tests(self, tests, None) is returned instead, where
        tests is the list of per-class suites collected from the module.
        """
        tests = []
        for name in dir(module):
            obj = getattr(module, name)
            if isinstance(obj, type) and issubclass(obj, case.TestCase):
                tests.append(self.loadTestsFromTestCase(obj))

        load_tests = getattr(module, 'load_tests', None)
        if use_load_tests and load_tests is not None:
            return load_tests(self, tests, None)
        return self.suiteClass(tests)

    def loadTestsFromName(self, name, module=None):
        """Return a suite of all tests cases given a string specifier.

        The name may resolve either to a module, a test case class, a
        test method within a test case class, or a callable object which
        returns a TestCase or TestSuite instance.

        The method optionally resolves the names relative to a given module.

        Raises TypeError if the resolved object cannot be turned into a
        test, and ImportError if no leading part of the name is importable.
        """
        parts = name.split('.')
        if module is None:
            # Import the longest importable dotted prefix of the name.
            parts_copy = parts[:]
            while parts_copy:
                try:
                    module = __import__('.'.join(parts_copy))
                    break
                except ImportError:
                    del parts_copy[-1]
                    if not parts_copy:
                        raise
            # __import__ returned the top-level package; walk the rest.
            parts = parts[1:]

        parent = None
        obj = module
        for part in parts:
            parent, obj = obj, getattr(obj, part)

        if isinstance(obj, types.ModuleType):
            return self.loadTestsFromModule(obj)
        elif isinstance(obj, type) and issubclass(obj, case.TestCase):
            return self.loadTestsFromTestCase(obj)
        elif (isinstance(obj, types.UnboundMethodType) and
              isinstance(parent, type) and
              issubclass(parent, case.TestCase)):
            # Use the attribute name that was requested rather than
            # obj.__name__, which differs for aliased or wrapped methods.
            return self.suiteClass([parent(parts[-1])])
        elif isinstance(obj, suite.TestSuite):
            return obj
        elif hasattr(obj, '__call__'):
            test = obj()
            if isinstance(test, suite.TestSuite):
                return test
            elif isinstance(test, case.TestCase):
                return self.suiteClass([test])
            else:
                raise TypeError("calling %s returned %s, not a test" %
                                (obj, test))
        else:
            raise TypeError("don't know how to make test from: %s" % obj)

    def loadTestsFromNames(self, names, module=None):
        """Return a suite of all tests cases found using the given sequence
        of string specifiers. See 'loadTestsFromName()'.
        """
        suites = [self.loadTestsFromName(name, module) for name in names]
        return self.suiteClass(suites)

    def getTestCaseNames(self, testCaseClass):
        """Return a sorted sequence of method names found within testCaseClass

        A name qualifies when it starts with testMethodPrefix and the
        attribute is callable.  Sorting is skipped when sortTestMethodsUsing
        is false.
        """
        prefix = self.testMethodPrefix

        def isTestMethod(attrname):
            return (attrname.startswith(prefix) and
                    hasattr(getattr(testCaseClass, attrname), '__call__'))

        testFnNames = [name for name in dir(testCaseClass)
                       if isTestMethod(name)]
        if self.sortTestMethodsUsing:
            testFnNames.sort(key=_CmpToKey(self.sortTestMethodsUsing))
        return testFnNames

    def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
        """Find and return all test modules from the specified start
        directory, recursing into subdirectories to find them. Only test files
        that match the pattern will be loaded. (Using shell style pattern
        matching.)

        All test modules must be importable from the top level of the project.
        If the start directory is not the top level directory then the top
        level directory must be specified separately.

        If a test package name (directory with '__init__.py') matches the
        pattern then the package will be checked for a 'load_tests' function. If
        this exists then it will be called with loader, tests, pattern.

        If load_tests exists then discovery does *not* recurse into the package,
        load_tests is responsible for loading all tests in the package.

        The pattern is deliberately not stored as a loader attribute so that
        packages can continue discovery themselves. top_level_dir is stored so
        load_tests does not need to pass this argument in to loader.discover().

        Raises ImportError if start_dir differs from the top level directory
        and does not contain an '__init__.py' file.
        """
        if top_level_dir is None and self._top_level_dir is not None:
            # make top_level_dir optional if called from load_tests in a package
            top_level_dir = self._top_level_dir
        elif top_level_dir is None:
            top_level_dir = start_dir

        top_level_dir = os.path.abspath(os.path.normpath(top_level_dir))
        start_dir = os.path.abspath(os.path.normpath(start_dir))

        if top_level_dir not in sys.path:
            # all test modules must be importable from the top level directory
            sys.path.append(top_level_dir)
        self._top_level_dir = top_level_dir

        if (start_dir != top_level_dir and
                not os.path.isfile(os.path.join(start_dir, '__init__.py'))):
            # what about __init__.pyc or pyo (etc)
            raise ImportError('Start directory is not importable: %r' % start_dir)

        tests = list(self._find_tests(start_dir, pattern))
        return self.suiteClass(tests)

    def _get_module_from_path(self, path):
        """Load a module from a path relative to the top-level directory
        of a project. Used by discovery.

        The file extension is dropped, the path is made relative to
        self._top_level_dir, and path separators become dots to form the
        module name that is imported and returned.
        """
        path = os.path.splitext(os.path.normpath(path))[0]

        relpath = os.path.relpath(path, self._top_level_dir)
        assert not os.path.isabs(relpath), "Path must be within the project"
        assert not relpath.startswith('..'), "Path must be within the project"

        name = relpath.replace(os.path.sep, '.')
        __import__(name)
        # __import__ returns the top-level package, so fetch the submodule
        # itself from sys.modules.
        return sys.modules[name]

    def _find_tests(self, start_dir, pattern):
        """Used by discovery. Yields test suites it loads."""
        # os.listdir returns entries in arbitrary order; sort them so that
        # discovery order is deterministic.
        paths = sorted(os.listdir(start_dir))

        for path in paths:
            full_path = os.path.join(start_dir, path)
            # what about __init__.pyc or pyo (etc)
            # we would need to avoid loading the same tests multiple times
            # from '.py', '.pyc' *and* '.pyo'
            if os.path.isfile(full_path) and path.lower().endswith('.py'):
                if fnmatch(path, pattern):
                    # if the test file matches, load it
                    module = self._get_module_from_path(full_path)
                    yield self.loadTestsFromModule(module)
            elif os.path.isdir(full_path):
                if not os.path.isfile(os.path.join(full_path, '__init__.py')):
                    # not a package: do not descend
                    continue

                load_tests = None
                tests = None
                if fnmatch(path, pattern):
                    # only check load_tests if the package directory itself
                    # matches the filter
                    package = self._get_module_from_path(full_path)
                    load_tests = getattr(package, 'load_tests', None)
                    tests = self.loadTestsFromModule(package,
                                                     use_load_tests=False)

                if load_tests is None:
                    if tests is not None:
                        # tests loaded from package file
                        yield tests
                    # recurse into the package
                    for test in self._find_tests(full_path, pattern):
                        yield test
                else:
                    yield load_tests(self, tests, pattern)
213
# Module-level TestLoader instance using the class defaults.
defaultTestLoader = TestLoader()
215
216
def _makeLoader(prefix, sortUsing, suiteClass=None):
    """Return a new TestLoader configured with the given settings.

    prefix is stored as testMethodPrefix and sortUsing as
    sortTestMethodsUsing.  suiteClass overrides the loader's suiteClass
    only when one is supplied.
    """
    loader = TestLoader()
    loader.sortTestMethodsUsing = sortUsing
    loader.testMethodPrefix = prefix
    if suiteClass is not None:
        loader.suiteClass = suiteClass
    return loader
224
def getTestCaseNames(testCaseClass, prefix, sortUsing=cmp):
    """Return the test method names of testCaseClass.

    Delegates to TestLoader.getTestCaseNames on a loader built with the
    given prefix and cmp-style sortUsing function.
    """
    return _makeLoader(prefix, sortUsing).getTestCaseNames(testCaseClass)
227
def makeSuite(testCaseClass, prefix='test', sortUsing=cmp,
              suiteClass=suite.TestSuite):
    """Return a suite of the tests contained in testCaseClass.

    Delegates to TestLoader.loadTestsFromTestCase on a loader built with
    the given prefix, sortUsing and suiteClass.
    """
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    return loader.loadTestsFromTestCase(testCaseClass)
231
def findTestCases(module, prefix='test', sortUsing=cmp,
                  suiteClass=suite.TestSuite):
    """Return a suite of the test cases found in module.

    Delegates to TestLoader.loadTestsFromModule on a loader built with
    the given prefix, sortUsing and suiteClass.
    """
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    return loader.loadTestsFromModule(module)