Implement chained exception tracebacks.

Patch from Antoine Pitrou (#3112).
diff --git a/Lib/test/test_raise.py b/Lib/test/test_raise.py
index 3072c14..ba9cfc5 100644
--- a/Lib/test/test_raise.py
+++ b/Lib/test/test_raise.py
@@ -278,6 +278,30 @@
         else:
             self.fail("No exception raised")
 
+    def test_cycle_broken(self):
+        # Self-cycles (when re-raising a caught exception) are broken
+        try:
+            try:
+                1/0
+            except ZeroDivisionError as e:
+                raise e
+        except ZeroDivisionError as e:
+            self.failUnless(e.__context__ is None, e.__context__)
+
+    def test_reraise_cycle_broken(self):
+        # Non-trivial context cycles (through re-raising a previous exception)
+        # are broken too.
+        try:
+            try:
+                xyzzy
+            except NameError as a:
+                try:
+                    1/0
+                except ZeroDivisionError:
+                    raise a
+        except NameError as e:
+            self.failUnless(e.__context__.__context__ is None)
+
 
 class TestRemovedFunctionality(unittest.TestCase):
     def test_tuples(self):
diff --git a/Lib/test/test_traceback.py b/Lib/test/test_traceback.py
index 3f89e6a..3f69e5e 100644
--- a/Lib/test/test_traceback.py
+++ b/Lib/test/test_traceback.py
@@ -1,10 +1,11 @@
 """Test cases for traceback module"""
 
-from _testcapi import traceback_print
+from _testcapi import traceback_print, exception_print
 from io import StringIO
 import sys
 import unittest
-from test.support import run_unittest, is_jython, Error
+import re
+from test.support import run_unittest, is_jython, Error, captured_output
 
 import traceback
 
@@ -19,7 +20,7 @@
     raise Error("unable to create test traceback string")
 
 
-class TracebackCases(unittest.TestCase):
+class SyntaxTracebackCases(unittest.TestCase):
     # For now, a very minimal set of tests.  I want to be sure that
     # formatting of SyntaxErrors works based on changes for 2.1.
 
@@ -99,12 +100,135 @@
         banner, location, source_line = tb_lines
         self.assert_(banner.startswith('Traceback'))
         self.assert_(location.startswith('  File'))
-        self.assert_(source_line.startswith('raise'))
+        self.assert_(source_line.startswith('    raise'))
+
+
+cause_message = (
+    "\nThe above exception was the direct cause "
+    "of the following exception:\n\n")
+
+context_message = (
+    "\nDuring handling of the above exception, "
+    "another exception occurred:\n\n")
+
+boundaries = re.compile(
+    '(%s|%s)' % (re.escape(cause_message), re.escape(context_message)))
+
+
+class BaseExceptionReportingTests:
+
+    def get_exception(self, exception_or_callable):
+        if isinstance(exception_or_callable, Exception):
+            return exception_or_callable
+        try:
+            exception_or_callable()
+        except Exception as e:
+            return e
+
+    def zero_div(self):
+        1/0 # In zero_div
+
+    def check_zero_div(self, msg):
+        lines = msg.splitlines()
+        self.assert_(lines[-3].startswith('  File'))
+        self.assert_('1/0 # In zero_div' in lines[-2], lines[-2])
+        self.assert_(lines[-1].startswith('ZeroDivisionError'), lines[-1])
+
+    def test_simple(self):
+        try:
+            1/0 # Marker
+        except ZeroDivisionError as _:
+            e = _
+        lines = self.get_report(e).splitlines()
+        self.assertEquals(len(lines), 4)
+        self.assert_(lines[0].startswith('Traceback'))
+        self.assert_(lines[1].startswith('  File'))
+        self.assert_('1/0 # Marker' in lines[2])
+        self.assert_(lines[3].startswith('ZeroDivisionError'))
+
+    def test_cause(self):
+        def inner_raise():
+            try:
+                self.zero_div()
+            except ZeroDivisionError as e:
+                raise KeyError from e
+        def outer_raise():
+            inner_raise() # Marker
+        blocks = boundaries.split(self.get_report(outer_raise))
+        self.assertEquals(len(blocks), 3)
+        self.assertEquals(blocks[1], cause_message)
+        self.check_zero_div(blocks[0])
+        self.assert_('inner_raise() # Marker' in blocks[2])
+
+    def test_context(self):
+        def inner_raise():
+            try:
+                self.zero_div()
+            except ZeroDivisionError:
+                raise KeyError
+        def outer_raise():
+            inner_raise() # Marker
+        blocks = boundaries.split(self.get_report(outer_raise))
+        self.assertEquals(len(blocks), 3)
+        self.assertEquals(blocks[1], context_message)
+        self.check_zero_div(blocks[0])
+        self.assert_('inner_raise() # Marker' in blocks[2])
+
+    def test_cause_recursive(self):
+        def inner_raise():
+            try:
+                try:
+                    self.zero_div()
+                except ZeroDivisionError as e:
+                    z = e
+                    raise KeyError from e
+            except KeyError as e:
+                raise z from e
+        def outer_raise():
+            inner_raise() # Marker
+        blocks = boundaries.split(self.get_report(outer_raise))
+        self.assertEquals(len(blocks), 3)
+        self.assertEquals(blocks[1], cause_message)
+        # The first block is the KeyError raised from the ZeroDivisionError
+        self.assert_('raise KeyError from e' in blocks[0])
+        self.assert_('1/0' not in blocks[0])
+        # The second block (apart from the boundary) is the ZeroDivisionError
+        # re-raised from the KeyError
+        self.assert_('inner_raise() # Marker' in blocks[2])
+        self.check_zero_div(blocks[2])
+
+
+
+class PyExcReportingTests(BaseExceptionReportingTests, unittest.TestCase):
+    #
+    # This checks reporting through the 'traceback' module, with both
+    # format_exception() and print_exception().
+    #
+
+    def get_report(self, e):
+        e = self.get_exception(e)
+        s = ''.join(
+            traceback.format_exception(type(e), e, e.__traceback__))
+        with captured_output("stderr") as sio:
+            traceback.print_exception(type(e), e, e.__traceback__)
+        self.assertEquals(sio.getvalue(), s)
+        return s
+
+
+class CExcReportingTests(BaseExceptionReportingTests, unittest.TestCase):
+    #
+    # This checks built-in reporting by the interpreter.
+    #
+
+    def get_report(self, e):
+        e = self.get_exception(e)
+        with captured_output("stderr") as s:
+            exception_print(e)
+        return s.getvalue()
 
 
 def test_main():
-    run_unittest(TracebackCases, TracebackFormatTests)
-
+    run_unittest(__name__)
 
 if __name__ == "__main__":
     test_main()