Merged revisions 71799 via svnmerge from
svn+ssh://pythondev@svn.python.org/python/trunk

........
  r71799 | nick.coghlan | 2009-04-23 01:26:04 +1000 (Thu, 23 Apr 2009) | 1 line

  Issue 5354: Change API for import_fresh_module() to better support the test_warnings use case (also fixes some bugs in the original implementation)
........
diff --git a/Lib/test/support.py b/Lib/test/support.py
index 28823ae..ebb3495 100644
--- a/Lib/test/support.py
+++ b/Lib/test/support.py
@@ -69,12 +69,43 @@
             raise unittest.SkipTest(str(msg))
 
 
-def import_fresh_module(name, blocked_names=None, deprecated=False):
+def _save_and_remove_module(name, orig_modules):
+    """Helper function to save and remove a module from sys.modules
+
+       Return value is True if the module was in sys.modules and
+       False otherwise."""
+    saved = True
+    try:
+        orig_modules[name] = sys.modules[name]
+    except KeyError:
+        saved = False
+    else:
+        del sys.modules[name]
+    return saved
+
+
+def _save_and_block_module(name, orig_modules):
+    """Helper function to save and block a module in sys.modules
+
+       Return value is True if the module was in sys.modules and
+       False otherwise."""
+    saved = True
+    try:
+        orig_modules[name] = sys.modules[name]
+    except KeyError:
+        saved = False
+    sys.modules[name] = 0
+    return saved
+
+
+def import_fresh_module(name, fresh=(), blocked=(), deprecated=False):
     """Imports and returns a module, deliberately bypassing the sys.modules cache
     and importing a fresh copy of the module. Once the import is complete,
     the sys.modules cache is restored to its original state.
 
-    Importing of modules named in blocked_names is prevented while the fresh import
+    Modules named in fresh are also imported anew if needed by the import.
+
+    Importing of modules named in blocked is prevented while the fresh import
     takes place.
 
     If deprecated is True, any module or package deprecation messages
@@ -82,21 +113,24 @@
     # NOTE: test_heapq and test_warnings include extra sanity checks to make
     # sure that this utility function is working as expected
     with _ignore_deprecated_imports(deprecated):
-        if blocked_names is None:
-            blocked_names = ()
+        # Keep track of modules saved for later restoration as well
+        # as those which just need a blocking entry removed
         orig_modules = {}
-        if name in sys.modules:
-            orig_modules[name] = sys.modules[name]
-            del sys.modules[name]
+        names_to_remove = []
+        _save_and_remove_module(name, orig_modules)
         try:
-            for blocked in blocked_names:
-                orig_modules[blocked] = sys.modules[blocked]
-                sys.modules[blocked] = 0
-            py_module = importlib.import_module(name)
+            for fresh_name in fresh:
+                _save_and_remove_module(fresh_name, orig_modules)
+            for blocked_name in blocked:
+                if not _save_and_block_module(blocked_name, orig_modules):
+                    names_to_remove.append(blocked_name)
+            fresh_module = importlib.import_module(name)
         finally:
-            for blocked, module in orig_modules.items():
-                sys.modules[blocked] = module
-        return py_module
+            for orig_name, module in orig_modules.items():
+                sys.modules[orig_name] = module
+            for name_to_remove in names_to_remove:
+                del sys.modules[name_to_remove]
+        return fresh_module
 
 
 def get_attribute(obj, name):