blob: 070b5783d1e59b51c323e5584b12e2d508c26507 [file] [log] [blame]
Just van Rossum52e14d62002-12-30 22:08:05 +00001import sys
2import imp
3import os
4import unittest
5from test import test_support
6
7
# Source shared by every fake module that TestImporter serves.  get_name()
# and get_file() report the executing module's __name__ and __file__, which
# lets the tests check which name a module was actually loaded under.
test_src = """\
def get_name():
    return __name__
def get_file():
    return __file__
"""

# Compiled once and executed in each fake module; "<???>" is just the
# filename recorded in the code object.
test_co = compile(test_src, "<???>", "exec")
# Sentinel sys.path item: TestImporter raises ImportError for any other
# path item, so as a path hook it only ever claims this one.
test_path = "!!!_test_!!!"
17
18
class ImportTracker:
    """Meta-path finder that only records the imports it is asked about.

    It never claims a module: find_module always returns None, which in
    the PEP 302 protocol means "not handled by this finder".
    """

    def __init__(self):
        # Full module names passed to find_module, in call order.
        self.imports = list()

    def find_module(self, fullname, path=None):
        """Record *fullname* and decline to handle it."""
        seen = self.imports
        seen.append(fullname)
        return None
26
27
class TestImporter:
    """PEP 302 finder and loader for a fixed set of in-memory test modules.

    Every module listed in ``modules`` executes the shared ``test_co``
    code object.  Subclasses decide what a package's ``__path__`` is by
    overriding ``_get__path__``.
    """

    # fullname -> (is_package, code object executed in the module namespace)
    modules = {
        "hooktestmodule": (False, test_co),
        "hooktestpackage": (True, test_co),
        "hooktestpackage.sub": (True, test_co),
        "hooktestpackage.sub.subber": (False, test_co),
    }

    def __init__(self, path=test_path):
        if path != test_path:
            # If our class is on sys.path_hooks, we must raise
            # ImportError for any path item that we can't handle.
            raise ImportError
        self.path = path

    def _get__path__(self):
        """Return the ``__path__`` list for packages; subclasses override."""
        raise NotImplementedError

    def find_module(self, fullname, path=None):
        """Return this importer if it serves *fullname*, otherwise None."""
        if fullname in self.modules:
            return self
        return None

    def load_module(self, fullname):
        """Create (or reuse) module *fullname*, run its code and return it.

        Follows the PEP 302 loader contract. An existing ``sys.modules``
        entry is reused. A new module is registered in ``sys.modules``
        before its code runs. A module registered by this call is removed
        again if loading fails, so no half-initialised module is left
        behind.

        Raises KeyError if *fullname* is not one of ``modules``.
        """
        ispkg, code = self.modules[fullname]
        mod = sys.modules.get(fullname)
        created = mod is None
        if created:
            mod = imp.new_module(fullname)
            sys.modules[fullname] = mod
        succeeded = False
        try:
            mod.__file__ = "<%s>" % self.__class__.__name__
            mod.__loader__ = self
            if ispkg:
                # May raise NotImplementedError in the base class.
                mod.__path__ = self._get__path__()
            # Function-call form of exec; valid in Python 2 and 3.
            exec(code, mod.__dict__)
            succeeded = True
        finally:
            if created and not succeeded:
                sys.modules.pop(fullname, None)
        return mod
63
64
class MetaImporter(TestImporter):
    """TestImporter variant intended for sys.meta_path.

    Packages it loads get an empty ``__path__`` list.
    """

    def _get__path__(self):
        return list()
68
class PathImporter(TestImporter):
    """TestImporter variant intended for sys.path_hooks.

    A package's ``__path__`` is a one-item list holding the path item
    this importer instance was created for.
    """

    def _get__path__(self):
        return list((self.path,))
72
73
class ImportBlocker:
    """Place an ImportBlocker instance on sys.meta_path and you
    can be sure the modules you specified can't be imported, even
    if it's a builtin."""

    def __init__(self, *namestoblock):
        # dict used as a set of blocked full module names
        self.namestoblock = dict.fromkeys(namestoblock)

    def find_module(self, fullname, path=None):
        """Claim *fullname* if it is blocked, otherwise return None."""
        if fullname in self.namestoblock:
            return self
        return None

    def load_module(self, fullname):
        """Refuse to load: always raise ImportError."""
        # Call form of raise; the old ``raise E, msg`` statement is
        # Python 2-only syntax.
        raise ImportError("I dare you")
86
87
class ImpWrapper:
    """PEP 302 finder built on imp.find_module.

    Created without a path, it only looks up top-level (undotted) module
    names on the default search path. Created with a directory, it
    searches only that directory.
    """

    def __init__(self, path=None):
        # Reject any path item that is not an existing directory.
        if path is None or os.path.isdir(path):
            self.path = path
        else:
            raise ImportError

    def find_module(self, fullname, path=None):
        """Return an ImpLoader for *fullname*, or None if not found.

        The *path* argument is accepted for protocol compatibility but is
        not used; the search location comes from ``self.path``.
        """
        tail = fullname[fullname.rfind(".") + 1:]
        if self.path is None:
            if tail != fullname:
                # Dotted names are not handled by a path-less instance.
                return None
            search_dirs = None
        else:
            search_dirs = [self.path]
        try:
            fh, filename, description = imp.find_module(tail, search_dirs)
        except ImportError:
            return None
        return ImpLoader(fh, filename, description)
108
109
class ImpLoader:
    """PEP 302 loader that finishes a lookup started by imp.find_module."""

    def __init__(self, file, filename, stuff):
        # The three parts of an imp.find_module() result: the open file
        # (or None), the pathname, and the description tuple.
        self.file = file
        self.filename = filename
        self.stuff = stuff

    def load_module(self, fullname):
        """Load *fullname* via imp.load_module and tag it with this loader."""
        try:
            mod = imp.load_module(fullname, self.file, self.filename,
                                  self.stuff)
        finally:
            # imp.load_module leaves closing the file to the caller, and
            # that must happen even when loading raises.
            if self.file:
                self.file.close()
        mod.__loader__ = self  # for introspection
        return mod
123
124
class ImportHooksBaseTestCase(unittest.TestCase):
    """Base class that saves and restores import-system state around tests.

    setUp records sys.path, sys.meta_path, sys.path_hooks and
    sys.modules. It also clears sys.path_importer_cache and puts an
    ImportTracker at the front of sys.meta_path. tearDown restores all of
    that state, so tests may freely add hooks and may add or delete
    sys.modules entries.
    """

    def setUp(self):
        self.path = sys.path[:]
        self.meta_path = sys.meta_path[:]
        self.path_hooks = sys.path_hooks[:]
        # Shallow copy keeps references to the original module objects, so
        # entries a test deletes from sys.modules can be put back.
        self.modules_before = sys.modules.copy()
        sys.path_importer_cache.clear()
        self.tracker = ImportTracker()
        sys.meta_path.insert(0, self.tracker)

    def tearDown(self):
        sys.path[:] = self.path
        sys.meta_path[:] = self.meta_path
        sys.path_hooks[:] = self.path_hooks
        sys.path_importer_cache.clear()
        # Drop every module that appeared during the test (this includes
        # all new imports the tracker recorded)...
        for name in list(sys.modules):
            if name not in self.modules_before:
                del sys.modules[name]
        # ...and restore the original objects for entries the test deleted
        # or replaced.
        sys.modules.update(self.modules_before)
143
144
class ImportHooksTestCase(ImportHooksBaseTestCase):
    """Tests for sys.meta_path importers, sys.path_hooks and blockers."""

    def doTestImports(self, importer=None):
        """Import the hooktest modules and check what they report.

        Each module's get_name() must return its full dotted name. When
        *importer* is given, it must also be every module's __loader__.
        """
        import hooktestmodule
        import hooktestpackage
        import hooktestpackage.sub
        import hooktestpackage.sub.subber
        self.assertEqual(hooktestmodule.get_name(),
                         "hooktestmodule")
        self.assertEqual(hooktestpackage.get_name(),
                         "hooktestpackage")
        self.assertEqual(hooktestpackage.sub.get_name(),
                         "hooktestpackage.sub")
        self.assertEqual(hooktestpackage.sub.subber.get_name(),
                         "hooktestpackage.sub.subber")
        if importer:
            self.assertEqual(hooktestmodule.__loader__, importer)
            self.assertEqual(hooktestpackage.__loader__, importer)
            self.assertEqual(hooktestpackage.sub.__loader__, importer)
            self.assertEqual(hooktestpackage.sub.subber.__loader__, importer)

    def testMetaPath(self):
        i = MetaImporter()
        sys.meta_path.append(i)
        self.doTestImports(i)

    def testPathHook(self):
        # The hook registered here is the class itself, so no single
        # importer instance is known in advance; only names are checked.
        sys.path_hooks.append(PathImporter)
        sys.path.append(test_path)
        self.doTestImports()

    def testBlocker(self):
        mname = "exceptions"  # an arbitrary harmless builtin module
        if mname in sys.modules:
            del sys.modules[mname]
        sys.meta_path.append(ImportBlocker(mname))
        try:
            __import__(mname)
        except ImportError:
            pass
        else:
            self.fail("'%s' was not supposed to be importable" % mname)

    def testImpWrapper(self):
        i = ImpWrapper()
        sys.meta_path.append(i)
        sys.path_hooks.append(ImpWrapper)
        mnames = ("colorsys", "urlparse", "distutils.core", "compiler.misc")
        for mname in mnames:
            parent = mname.split(".")[0]
            # Forget the top-level module and its submodules so the imports
            # below go through the hooks. Match the exact name or a dotted
            # child only; a bare prefix test would also drop unrelated
            # modules such as "distutils2" for "distutils".  Iterate over a
            # copy because entries are deleted inside the loop.
            for n in list(sys.modules):
                if n == parent or n.startswith(parent + "."):
                    del sys.modules[n]
        for mname in mnames:
            m = __import__(mname, globals(), locals(), ["__dummy__"])
            m.__loader__  # to make sure we actually handled the import
201
202
def test_main():
    """Run this module's tests.

    This is the conventional entry point for Python regression-test
    modules, so the suite also runs when the module is imported by the
    test driver rather than executed as a script.
    """
    test_support.run_unittest(ImportHooksTestCase)


if __name__ == "__main__":
    test_main()