blob: c687b1bb552bdff477395249c2addff86863bb1e [file] [log] [blame]
"""Loading unittests."""
2
3import os
4import sys
5import types
6
7from fnmatch import fnmatch
8
9from . import case, suite, util
10
11
class TestLoader(object):
    """
    This class is responsible for loading tests according to various criteria
    and returning them wrapped in a TestSuite
    """
    # Only callable attributes whose name starts with this prefix are
    # collected by getTestCaseNames().
    testMethodPrefix = 'test'
    # Old-style comparison function used to order test method names; a false
    # value disables sorting in getTestCaseNames().
    sortTestMethodsUsing = staticmethod(util.three_way_cmp)
    # Container type used for every suite this loader builds.
    suiteClass = suite.TestSuite
    # Set by discover() so that nested discover() calls (for example from a
    # package's load_tests) can omit top_level_dir.
    _top_level_dir = None

    def loadTestsFromTestCase(self, testCaseClass):
        """Return a suite of all tests cases contained in testCaseClass

        One instance of testCaseClass is created per name returned by
        getTestCaseNames().  If no names are found but the class has a
        'runTest' attribute, a single 'runTest' instance is used instead.

        Raises TypeError if testCaseClass derives from TestSuite.
        """
        if issubclass(testCaseClass, suite.TestSuite):
            raise TypeError("Test cases should not be derived from TestSuite."
                            " Maybe you meant to derive from TestCase?")
        testCaseNames = self.getTestCaseNames(testCaseClass)
        if not testCaseNames and hasattr(testCaseClass, 'runTest'):
            testCaseNames = ['runTest']
        loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))
        return loaded_suite

    def loadTestsFromModule(self, module, use_load_tests=True):
        """Return a suite of all tests cases contained in the given module

        Every attribute of the module that is a TestCase subclass is loaded
        with loadTestsFromTestCase().  If the module defines 'load_tests' and
        use_load_tests is true, the result of
        load_tests(loader, tests, None) is returned instead, where 'tests' is
        a suiteClass instance holding the tests collected from the module.
        """
        tests = []
        for name in dir(module):
            obj = getattr(module, name)
            if isinstance(obj, type) and issubclass(obj, case.TestCase):
                tests.append(self.loadTestsFromTestCase(obj))

        # Wrap before calling the hook so load_tests always receives a suite
        # (as it does when invoked from _find_tests), never a bare list.
        tests = self.suiteClass(tests)
        load_tests = getattr(module, 'load_tests', None)
        if use_load_tests and load_tests is not None:
            return load_tests(self, tests, None)
        return tests

    def loadTestsFromName(self, name, module=None):
        """Return a suite of all tests cases given a string specifier.

        The name may resolve either to a module, a test case class, a
        test method within a test case class, or a callable object which
        returns a TestCase or TestSuite instance.

        The method optionally resolves the names relative to a given module.

        Raises ImportError if no leading portion of the name can be imported
        (when module is None), and TypeError if the resolved object cannot be
        turned into a test.
        """
        parts = name.split('.')
        if module is None:
            # Import the longest importable dotted prefix of the name.
            parts_copy = parts[:]
            while parts_copy:
                try:
                    module = __import__('.'.join(parts_copy))
                    break
                except ImportError:
                    del parts_copy[-1]
                    if not parts_copy:
                        raise
            # __import__ returns the top-level package, so the remaining
            # components are resolved with getattr starting from it.
            parts = parts[1:]
        obj = module
        # Stays None when there is nothing left to traverse, so the checks
        # below never see an unbound name.
        parent = None
        for part in parts:
            parent, obj = obj, getattr(obj, part)

        if isinstance(obj, types.ModuleType):
            return self.loadTestsFromModule(obj)
        elif isinstance(obj, type) and issubclass(obj, case.TestCase):
            return self.loadTestsFromTestCase(obj)
        elif (isinstance(obj, types.FunctionType) and
              isinstance(parent, type) and
              issubclass(parent, case.TestCase)):
            name = obj.__name__
            inst = parent(name)
            # static methods follow a different path: they fall through to
            # the generic callable handling below
            if not isinstance(getattr(inst, name), types.FunctionType):
                return self.suiteClass([inst])
        elif isinstance(obj, suite.TestSuite):
            return obj
        if hasattr(obj, '__call__'):
            test = obj()
            if isinstance(test, suite.TestSuite):
                return test
            elif isinstance(test, case.TestCase):
                return self.suiteClass([test])
            else:
                raise TypeError("calling %s returned %s, not a test" %
                                (obj, test))
        else:
            raise TypeError("don't know how to make test from: %s" % obj)

    def loadTestsFromNames(self, names, module=None):
        """Return a suite of all tests cases found using the given sequence
        of string specifiers. See 'loadTestsFromName()'.
        """
        suites = [self.loadTestsFromName(name, module) for name in names]
        return self.suiteClass(suites)

    def getTestCaseNames(self, testCaseClass):
        """Return a sorted sequence of method names found within testCaseClass

        A name qualifies when it starts with testMethodPrefix and the
        attribute it refers to is callable.  The names are sorted with
        sortTestMethodsUsing unless that attribute is false.
        """
        prefix = self.testMethodPrefix
        testFnNames = [attrname for attrname in dir(testCaseClass)
                       if attrname.startswith(prefix) and
                       hasattr(getattr(testCaseClass, attrname), '__call__')]
        if self.sortTestMethodsUsing:
            # sortTestMethodsUsing is a cmp-style function; util.CmpToKey
            # wraps it so it can be passed as a sort key.
            testFnNames.sort(key=util.CmpToKey(self.sortTestMethodsUsing))
        return testFnNames

    def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
        """Find and return all test modules from the specified start
        directory, recursing into subdirectories to find them. Only test files
        that match the pattern will be loaded. (Using shell style pattern
        matching.)

        All test modules must be importable from the top level of the project.
        If the start directory is not the top level directory then the top
        level directory must be specified separately.

        If a test package name (directory with '__init__.py') matches the
        pattern then the package will be checked for a 'load_tests' function. If
        this exists then it will be called with loader, tests, pattern.

        If load_tests exists then discovery does *not* recurse into the package,
        load_tests is responsible for loading all tests in the package.

        The pattern is deliberately not stored as a loader attribute so that
        packages can continue discovery themselves. top_level_dir is stored so
        load_tests does not need to pass this argument in to loader.discover().

        Raises ImportError if start_dir differs from the top level directory
        and does not contain an '__init__.py' file.
        """
        if top_level_dir is None and self._top_level_dir is not None:
            # make top_level_dir optional if called from load_tests in a package
            top_level_dir = self._top_level_dir
        elif top_level_dir is None:
            top_level_dir = start_dir

        top_level_dir = os.path.abspath(os.path.normpath(top_level_dir))
        start_dir = os.path.abspath(os.path.normpath(start_dir))

        if top_level_dir not in sys.path:
            # all test modules must be importable from the top level directory
            sys.path.append(top_level_dir)
        self._top_level_dir = top_level_dir

        if (start_dir != top_level_dir and
                not os.path.isfile(os.path.join(start_dir, '__init__.py'))):
            # what about __init__.pyc or pyo (etc)
            raise ImportError('Start directory is not importable: %r' % start_dir)

        tests = list(self._find_tests(start_dir, pattern))
        return self.suiteClass(tests)

    def _get_module_from_path(self, path):
        """Load a module from a path relative to the top-level directory
        of a project. Used by discovery.

        The file extension is dropped, the path is made relative to
        self._top_level_dir and path separators become dots to form the
        module name, which is then imported and returned.
        """
        path = os.path.splitext(os.path.normpath(path))[0]

        relpath = os.path.relpath(path, self._top_level_dir)
        assert not os.path.isabs(relpath), "Path must be within the project"
        assert not relpath.startswith('..'), "Path must be within the project"

        name = relpath.replace(os.path.sep, '.')
        __import__(name)
        # __import__ returns the top-level package for dotted names, so the
        # requested module is fetched from sys.modules instead.
        return sys.modules[name]

    def _find_tests(self, start_dir, pattern):
        """Used by discovery. Yields test suites it loads."""
        paths = os.listdir(start_dir)

        for path in paths:
            full_path = os.path.join(start_dir, path)
            # what about __init__.pyc or pyo (etc)
            # we would need to avoid loading the same tests multiple times
            # from '.py', '.pyc' *and* '.pyo'
            if os.path.isfile(full_path) and path.lower().endswith('.py'):
                if fnmatch(path, pattern):
                    # if the test file matches, load it
                    module = self._get_module_from_path(full_path)
                    yield self.loadTestsFromModule(module)
            elif os.path.isdir(full_path):
                # only directories that are packages are considered
                if not os.path.isfile(os.path.join(full_path, '__init__.py')):
                    continue

                load_tests = None
                tests = None
                if fnmatch(path, pattern):
                    # only check load_tests if the package directory itself
                    # matches the filter
                    package = self._get_module_from_path(full_path)
                    load_tests = getattr(package, 'load_tests', None)
                    tests = self.loadTestsFromModule(package,
                                                     use_load_tests=False)

                if load_tests is None:
                    if tests is not None:
                        # tests loaded from package file
                        yield tests
                    # recurse into the package
                    for test in self._find_tests(full_path, pattern):
                        yield test
                else:
                    # the package takes over: no recursion into it
                    yield load_tests(self, tests, pattern)
208
# Module-level TestLoader instance created with the class defaults.
defaultTestLoader = TestLoader()
210
211
def _makeLoader(prefix, sortUsing, suiteClass=None):
    """Return a new TestLoader configured with the given settings.

    prefix becomes testMethodPrefix and sortUsing becomes
    sortTestMethodsUsing.  suiteClass replaces the loader's suiteClass only
    when it is a true value; otherwise the class default is kept.
    """
    configured = TestLoader()
    configured.testMethodPrefix = prefix
    configured.sortTestMethodsUsing = sortUsing
    if suiteClass:
        configured.suiteClass = suiteClass
    return configured
219
def getTestCaseNames(testCaseClass, prefix, sortUsing=util.three_way_cmp):
    """Return the test method names of testCaseClass.

    A temporary loader is built from prefix and sortUsing and its
    getTestCaseNames() result is returned.
    """
    loader = _makeLoader(prefix, sortUsing)
    return loader.getTestCaseNames(testCaseClass)
222
def makeSuite(testCaseClass, prefix='test', sortUsing=util.three_way_cmp,
              suiteClass=suite.TestSuite):
    """Return a suite holding the tests found in testCaseClass.

    A temporary loader is built from prefix, sortUsing and suiteClass and
    its loadTestsFromTestCase() result is returned.
    """
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    return loader.loadTestsFromTestCase(testCaseClass)
227
def findTestCases(module, prefix='test', sortUsing=util.three_way_cmp,
                  suiteClass=suite.TestSuite):
    """Return a suite holding the tests found in module.

    A temporary loader is built from prefix, sortUsing and suiteClass and
    its loadTestsFromModule() result is returned.
    """
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    return loader.loadTestsFromModule(module)