bpo-29302: Implement contextlib.AsyncExitStack. (#4790)

diff --git a/Lib/contextlib.py b/Lib/contextlib.py
index 96c8c22..ef8f8c9 100644
--- a/Lib/contextlib.py
+++ b/Lib/contextlib.py
@@ -7,7 +7,7 @@

 __all__ = ["asynccontextmanager", "contextmanager", "closing", "nullcontext",
            "AbstractContextManager", "AbstractAsyncContextManager",
-           "ContextDecorator", "ExitStack",
+           "AsyncExitStack", "ContextDecorator", "ExitStack",
            "redirect_stdout", "redirect_stderr", "suppress"]


@@ -365,85 +365,102 @@
         return exctype is not None and issubclass(exctype, self._exceptions)


-# Inspired by discussions on http://bugs.python.org/issue13585
-class ExitStack(AbstractContextManager):
-    """Context manager for dynamic management of a stack of exit callbacks
+class _BaseExitStack:
+    """A base class for ExitStack and AsyncExitStack."""

-    For example:
+    @staticmethod
+    def _create_exit_wrapper(cm, cm_exit):
+        def _exit_wrapper(exc_type, exc, tb):
+            return cm_exit(cm, exc_type, exc, tb)
+        return _exit_wrapper

-        with ExitStack() as stack:
-            files = [stack.enter_context(open(fname)) for fname in filenames]
-            # All opened files will automatically be closed at the end of
-            # the with statement, even if attempts to open files later
-            # in the list raise an exception
+    @staticmethod
+    def _create_cb_wrapper(callback, *args, **kwds):
+        def _exit_wrapper(exc_type, exc, tb):
+            callback(*args, **kwds)
+        return _exit_wrapper

-    """
     def __init__(self):
         self._exit_callbacks = deque()

     def pop_all(self):
-        """Preserve the context stack by transferring it to a new instance"""
+        """Preserve the context stack by transferring it to a new instance."""
         new_stack = type(self)()
         new_stack._exit_callbacks = self._exit_callbacks
         self._exit_callbacks = deque()
         return new_stack

-    def _push_cm_exit(self, cm, cm_exit):
-        """Helper to correctly register callbacks to __exit__ methods"""
-        def _exit_wrapper(*exc_details):
-            return cm_exit(cm, *exc_details)
-        _exit_wrapper.__self__ = cm
-        self.push(_exit_wrapper)
-
     def push(self, exit):
-        """Registers a callback with the standard __exit__ method signature
+        """Registers a callback with the standard __exit__ method signature.

-        Can suppress exceptions the same way __exit__ methods can.
-
+        Can suppress exceptions the same way an __exit__ method can.
         Also accepts any object with an __exit__ method (registering a call
-        to the method instead of the object itself)
+        to the method instead of the object itself).
         """
         # We use an unbound method rather than a bound method to follow
-        # the standard lookup behaviour for special methods
+        # the standard lookup behaviour for special methods.
         _cb_type = type(exit)
+
         try:
             exit_method = _cb_type.__exit__
         except AttributeError:
-            # Not a context manager, so assume its a callable
-            self._exit_callbacks.append(exit)
+            # Not a context manager, so assume it's a callable.
+            self._push_exit_callback(exit)
         else:
             self._push_cm_exit(exit, exit_method)
-        return exit # Allow use as a decorator
-
-    def callback(self, callback, *args, **kwds):
-        """Registers an arbitrary callback and arguments.
-
-        Cannot suppress exceptions.
-        """
-        def _exit_wrapper(exc_type, exc, tb):
-            callback(*args, **kwds)
-        # We changed the signature, so using @wraps is not appropriate, but
-        # setting __wrapped__ may still help with introspection
-        _exit_wrapper.__wrapped__ = callback
-        self.push(_exit_wrapper)
-        return callback # Allow use as a decorator
+        return exit  # Allow use as a decorator.

     def enter_context(self, cm):
-        """Enters the supplied context manager
+        """Enters the supplied context manager.

         If successful, also pushes its __exit__ method as a callback and
         returns the result of the __enter__ method.
         """
-        # We look up the special methods on the type to match the with statement
+        # We look up the special methods on the type to match the with
+        # statement.
         _cm_type = type(cm)
         _exit = _cm_type.__exit__
         result = _cm_type.__enter__(cm)
         self._push_cm_exit(cm, _exit)
         return result

-    def close(self):
-        """Immediately unwind the context stack"""
-        self.__exit__(None, None, None)
+    def callback(self, callback, *args, **kwds):
+        """Registers an arbitrary callback and arguments.
+
+        Cannot suppress exceptions.
+        """
+        _exit_wrapper = self._create_cb_wrapper(callback, *args, **kwds)
+
+        # We changed the signature, so using @wraps is not appropriate, but
+        # setting __wrapped__ may still help with introspection.
+        _exit_wrapper.__wrapped__ = callback
+        self._push_exit_callback(_exit_wrapper)
+        return callback  # Allow use as a decorator.
+
+    def _push_cm_exit(self, cm, cm_exit):
+        """Helper to correctly register callbacks to __exit__ methods."""
+        _exit_wrapper = self._create_exit_wrapper(cm, cm_exit)
+        _exit_wrapper.__self__ = cm
+        self._push_exit_callback(_exit_wrapper, True)
+
+    def _push_exit_callback(self, callback, is_sync=True):
+        self._exit_callbacks.append((is_sync, callback))
+
+
+# Inspired by discussions on http://bugs.python.org/issue13585
+class ExitStack(_BaseExitStack, AbstractContextManager):
+    """Context manager for dynamic management of a stack of exit callbacks.
+
+    For example:
+        with ExitStack() as stack:
+            files = [stack.enter_context(open(fname)) for fname in filenames]
+            # All opened files will automatically be closed at the end of
+            # the with statement, even if attempts to open files later
+            # in the list raise an exception.
+    """
+
+    def __enter__(self):
+        return self

     def __exit__(self, *exc_details):
         received_exc = exc_details[0] is not None
@@ -470,7 +487,8 @@
         suppressed_exc = False
         pending_raise = False
         while self._exit_callbacks:
-            cb = self._exit_callbacks.pop()
+            is_sync, cb = self._exit_callbacks.pop()
+            assert is_sync  # ExitStack's API only pushes sync callbacks
             try:
                 if cb(*exc_details):
                     suppressed_exc = True
@@ -493,6 +511,147 @@
                 raise
         return received_exc and suppressed_exc

+    def close(self):
+        """Immediately unwind the context stack."""
+        self.__exit__(None, None, None)
+
+
+# Inspired by discussions on https://bugs.python.org/issue29302
+class AsyncExitStack(_BaseExitStack, AbstractAsyncContextManager):
+    """Async context manager for dynamic management of a stack of exit
+    callbacks.
+
+    For example:
+        async with AsyncExitStack() as stack:
+            connections = [await stack.enter_async_context(get_connection())
+                for _ in range(5)]
+            # All opened connections will automatically be released at the
+            # end of the async with statement, even if attempts to open a
+            # connection later in the list raise an exception.
+    """
+
+    @staticmethod
+    def _create_async_exit_wrapper(cm, cm_exit):
+        async def _exit_wrapper(exc_type, exc, tb):
+            return await cm_exit(cm, exc_type, exc, tb)
+        return _exit_wrapper
+
+    @staticmethod
+    def _create_async_cb_wrapper(callback, *args, **kwds):
+        async def _exit_wrapper(exc_type, exc, tb):
+            await callback(*args, **kwds)
+        return _exit_wrapper
+
+    async def enter_async_context(self, cm):
+        """Enters the supplied async context manager.
+
+        If successful, also pushes its __aexit__ method as a callback and
+        returns the result of the __aenter__ method.
+        """
+        _cm_type = type(cm)  # special methods are looked up on the type
+        _exit = _cm_type.__aexit__
+        result = await _cm_type.__aenter__(cm)
+        self._push_async_cm_exit(cm, _exit)
+        return result
+
+    def push_async_exit(self, exit):
+        """Registers a coroutine function with the standard __aexit__ method
+        signature.
+
+        Can suppress exceptions the same way an __aexit__ method can.
+        Also accepts any object with an __aexit__ method (registering a call
+        to the method instead of the object itself).
+        """
+        _cb_type = type(exit)  # special methods are looked up on the type
+        try:
+            exit_method = _cb_type.__aexit__
+        except AttributeError:
+            # Not an async context manager, so assume it's a coroutine function
+            self._push_exit_callback(exit, False)
+        else:
+            self._push_async_cm_exit(exit, exit_method)
+        return exit  # Allow use as a decorator.
+
+    def push_async_callback(self, callback, *args, **kwds):
+        """Registers an arbitrary coroutine function and arguments.
+
+        Cannot suppress exceptions.
+        """
+        _exit_wrapper = self._create_async_cb_wrapper(callback, *args, **kwds)
+
+        # We changed the signature, so using @wraps is not appropriate, but
+        # setting __wrapped__ may still help with introspection.
+        _exit_wrapper.__wrapped__ = callback
+        self._push_exit_callback(_exit_wrapper, False)
+        return callback  # Allow use as a decorator.
+
+    async def aclose(self):
+        """Immediately unwind the context stack."""
+        await self.__aexit__(None, None, None)
+
+    def _push_async_cm_exit(self, cm, cm_exit):
+        """Helper to correctly register an async callback wrapping the
+        __aexit__ method of *cm*."""
+        _exit_wrapper = self._create_async_exit_wrapper(cm, cm_exit)
+        _exit_wrapper.__self__ = cm
+        self._push_exit_callback(_exit_wrapper, False)
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, *exc_details):
+        received_exc = exc_details[0] is not None
+
+        # We manipulate the exception state so it behaves as though
+        # we were actually nesting multiple with statements
+        frame_exc = sys.exc_info()[1]
+        def _fix_exception_context(new_exc, old_exc):
+            # Context may not be correct, so find the end of the chain
+            while 1:
+                exc_context = new_exc.__context__
+                if exc_context is old_exc:
+                    # Context is already set correctly (see issue 20317)
+                    return
+                if exc_context is None or exc_context is frame_exc:
+                    break
+                new_exc = exc_context
+            # Change the end of the chain to point to the exception
+            # we expect it to reference
+            new_exc.__context__ = old_exc
+
+        # Callbacks are invoked in LIFO order to match the behaviour of
+        # nested context managers
+        suppressed_exc = False
+        pending_raise = False
+        while self._exit_callbacks:
+            is_sync, cb = self._exit_callbacks.pop()
+            try:
+                if is_sync:  # from push(), callback() or enter_context()
+                    cb_suppress = cb(*exc_details)
+                else:
+                    cb_suppress = await cb(*exc_details)
+                # A truthy result suppresses the current exception, if any.
+                if cb_suppress:
+                    suppressed_exc = True
+                    pending_raise = False
+                    exc_details = (None, None, None)
+            except:
+                new_exc_details = sys.exc_info()
+                # simulate the stack of exceptions by setting the context
+                _fix_exception_context(new_exc_details[1], exc_details[1])
+                pending_raise = True
+                exc_details = new_exc_details
+        if pending_raise:
+            try:
+                # bare "raise exc_details[1]" replaces our carefully
+                # set-up context
+                fixed_ctx = exc_details[1].__context__
+                raise exc_details[1]
+            except BaseException:
+                exc_details[1].__context__ = fixed_ctx
+                raise
+        return received_exc and suppressed_exc
+

 class nullcontext(AbstractContextManager):
     """Context manager that does no additional processing.