Issue #20317: Don't create a reference loop in ExitStack
diff --git a/Lib/contextlib.py b/Lib/contextlib.py
index f8e026b..f878285 100644
--- a/Lib/contextlib.py
+++ b/Lib/contextlib.py
@@ -231,11 +231,19 @@
         # we were actually nesting multiple with statements
         frame_exc = sys.exc_info()[1]
         def _fix_exception_context(new_exc, old_exc):
+            # Context may not be correct, so find the end of the chain.
+            # Stop early if old_exc is already in the chain: linking it
+            # again would create a __context__ cycle (issue 20317).
             while 1:
                 exc_context = new_exc.__context__
-                if exc_context in (None, frame_exc):
+                if exc_context is old_exc:
+                    # Context is already set correctly (see issue 20317)
+                    return
+                if exc_context is None or exc_context is frame_exc:
                     break
                 new_exc = exc_context
+            # Change the end of the chain to point to the exception
+            # we expect it to reference
             new_exc.__context__ = old_exc
 
         # Callbacks are invoked in LIFO order to match the behaviour of
diff --git a/Lib/test/test_contextlib.py b/Lib/test/test_contextlib.py
index 9e45f70..e5365f7 100644
--- a/Lib/test/test_contextlib.py
+++ b/Lib/test/test_contextlib.py
@@ -600,6 +600,31 @@
         else:
             self.fail("Expected KeyError, but no exception was raised")
 
+    def test_exit_exception_with_correct_context(self):
+        # http://bugs.python.org/issue20317
+        @contextmanager
+        def gets_the_context_right():
+            try:
+                yield 6
+            finally:
+                1 / 0
+
+        # The contextmanager already fixes the context, so prior to the
+        # fix, ExitStack would try to fix it *again* and get into an
+        # infinite self-referential loop
+        try:
+            with ExitStack() as stack:
+                stack.enter_context(gets_the_context_right())
+                stack.enter_context(gets_the_context_right())
+                stack.enter_context(gets_the_context_right())
+        except ZeroDivisionError as exc:
+            self.assertIsInstance(exc.__context__, ZeroDivisionError)
+            self.assertIsInstance(exc.__context__.__context__, ZeroDivisionError)
+            self.assertIsNone(exc.__context__.__context__.__context__)
+        else:
+            self.fail("Expected ZeroDivisionError, "
+                      "but no exception was raised")
+
     def test_body_exception_suppress(self):
         def suppress_exc(*exc_details):
             return True