Implement chained exception tracebacks

Patch from Antoine Pitrou, issue #3112.
diff --git a/Lib/test/test_raise.py b/Lib/test/test_raise.py
index 3072c14..ba9cfc5 100644
--- a/Lib/test/test_raise.py
+++ b/Lib/test/test_raise.py
@@ -278,6 +278,30 @@
         else:
             self.fail("No exception raised")
 
+    def test_cycle_broken(self):
+        # Re-raising a caught exception must not make it its own __context__
+        try:
+            try:
+                1/0
+            except ZeroDivisionError as e:
+                raise e
+        except ZeroDivisionError as e:
+            self.failUnless(e.__context__ is None, e.__context__)
+
+    def test_reraise_cycle_broken(self):
+        # Re-raising the outer NameError (a) while handling ZeroDivisionError
+        # would form the cycle a -> ZeroDivisionError -> a; it must be cut.
+        try:
+            try:
+                xyzzy
+            except NameError as a:
+                try:
+                    1/0
+                except ZeroDivisionError:
+                    raise a
+        except NameError as e:
+            self.failUnless(e.__context__.__context__ is None)  # cycle cut
+
 
 class TestRemovedFunctionality(unittest.TestCase):
     def test_tuples(self):
diff --git a/Lib/test/test_traceback.py b/Lib/test/test_traceback.py
index 3f89e6a..3f69e5e 100644
--- a/Lib/test/test_traceback.py
+++ b/Lib/test/test_traceback.py
@@ -1,10 +1,11 @@
 """Test cases for traceback module"""
 
-from _testcapi import traceback_print
+from _testcapi import traceback_print, exception_print
 from io import StringIO
 import sys
 import unittest
-from test.support import run_unittest, is_jython, Error
+import re
+from test.support import run_unittest, is_jython, Error, captured_output
 
 import traceback
 
@@ -19,7 +20,7 @@
     raise Error("unable to create test traceback string")
 
 
-class TracebackCases(unittest.TestCase):
+class SyntaxTracebackCases(unittest.TestCase):
     # For now, a very minimal set of tests.  I want to be sure that
     # formatting of SyntaxErrors works based on changes for 2.1.
 
@@ -99,12 +100,135 @@
         banner, location, source_line = tb_lines
         self.assert_(banner.startswith('Traceback'))
         self.assert_(location.startswith('  File'))
-        self.assert_(source_line.startswith('raise'))
+        self.assert_(source_line.startswith('    raise'))
+
+
+cause_message = (  # full separator text as it appears in a report
+    "\nThe above exception was the direct cause "
+    "of the following exception:\n\n")
+
+context_message = (  # likewise, for implicit __context__ chaining
+    "\nDuring handling of the above exception, "
+    "another exception occurred:\n\n")
+
+boundaries = re.compile(  # capturing group: split() keeps the separators
+    '(%s|%s)' % (re.escape(cause_message), re.escape(context_message)))
+
+
+class BaseExceptionReportingTests:  # mixin; subclasses define get_report()
+
+    def get_exception(self, exception_or_callable):
+        if isinstance(exception_or_callable, Exception):
+            return exception_or_callable
+        try:
+            exception_or_callable()
+        except Exception as e:
+            return e
+        self.fail("%r did not raise an exception" % (exception_or_callable,))
+
+    def zero_div(self):
+        1/0 # In zero_div
+
+    def check_zero_div(self, msg):
+        lines = msg.splitlines()
+        self.assert_(lines[-3].startswith('  File'))
+        self.assert_('1/0 # In zero_div' in lines[-2], lines[-2])
+        self.assert_(lines[-1].startswith('ZeroDivisionError'), lines[-1])
+
+    def test_simple(self):
+        try:
+            1/0 # Marker
+        except ZeroDivisionError as _:
+            e = _
+        lines = self.get_report(e).splitlines()
+        self.assertEquals(len(lines), 4)
+        self.assert_(lines[0].startswith('Traceback'))
+        self.assert_(lines[1].startswith('  File'))
+        self.assert_('1/0 # Marker' in lines[2])
+        self.assert_(lines[3].startswith('ZeroDivisionError'))
+
+    def test_cause(self):
+        def inner_raise():
+            try:
+                self.zero_div()
+            except ZeroDivisionError as e:
+                raise KeyError from e
+        def outer_raise():
+            inner_raise() # Marker
+        blocks = boundaries.split(self.get_report(outer_raise))
+        self.assertEquals(len(blocks), 3)
+        self.assertEquals(blocks[1], cause_message)
+        self.check_zero_div(blocks[0])
+        self.assert_('inner_raise() # Marker' in blocks[2])
+
+    def test_context(self):
+        def inner_raise():
+            try:
+                self.zero_div()
+            except ZeroDivisionError:
+                raise KeyError
+        def outer_raise():
+            inner_raise() # Marker
+        blocks = boundaries.split(self.get_report(outer_raise))
+        self.assertEquals(len(blocks), 3)
+        self.assertEquals(blocks[1], context_message)
+        self.check_zero_div(blocks[0])
+        self.assert_('inner_raise() # Marker' in blocks[2])
+
+    def test_cause_recursive(self):
+        def inner_raise():
+            try:
+                try:
+                    self.zero_div()
+                except ZeroDivisionError as e:
+                    z = e
+                    raise KeyError from e
+            except KeyError as e:
+                raise z from e
+        def outer_raise():
+            inner_raise() # Marker
+        blocks = boundaries.split(self.get_report(outer_raise))
+        self.assertEquals(len(blocks), 3)
+        self.assertEquals(blocks[1], cause_message)
+        # The first block is the KeyError raised from the ZeroDivisionError
+        self.assert_('raise KeyError from e' in blocks[0])
+        self.assert_('1/0' not in blocks[0])
+        # The second block (apart from the boundary) is the ZeroDivisionError
+        # re-raised from the KeyError
+        self.assert_('inner_raise() # Marker' in blocks[2])
+        self.check_zero_div(blocks[2])
+
+
+class PyExcReportingTests(BaseExceptionReportingTests, unittest.TestCase):
+    #
+    # This checks reporting through the 'traceback' module, with both
+    # format_exception() and print_exception(), which must agree exactly.
+    #
+
+    def get_report(self, e):
+        e = self.get_exception(e)
+        s = ''.join(
+            traceback.format_exception(type(e), e, e.__traceback__))
+        with captured_output("stderr") as sio:
+            traceback.print_exception(type(e), e, e.__traceback__)
+        self.assertEquals(sio.getvalue(), s)
+        return s
+
+
+class CExcReportingTests(BaseExceptionReportingTests, unittest.TestCase):
+    #
+    # This checks the interpreter's built-in reporting via exception_print().
+    #
+
+    def get_report(self, e):
+        e = self.get_exception(e)
+        with captured_output("stderr") as s:
+            exception_print(e)
+        return s.getvalue()
 
 
 def test_main():
-    run_unittest(TracebackCases, TracebackFormatTests)
-
+    run_unittest(__name__)
 
 if __name__ == "__main__":
     test_main()
diff --git a/Lib/traceback.py b/Lib/traceback.py
index fb1c5ad..b7130d8 100644
--- a/Lib/traceback.py
+++ b/Lib/traceback.py
@@ -3,6 +3,7 @@
 import linecache
 import sys
 import types
+import itertools
 
 __all__ = ['extract_stack', 'extract_tb', 'format_exception',
            'format_exception_only', 'format_list', 'format_stack',
@@ -107,7 +108,32 @@
     return list
 
 
-def print_exception(etype, value, tb, limit=None, file=None):
+_cause_message = (  # printed after an exception's __cause__ block
+    "\nThe above exception was the direct cause "
+    "of the following exception:\n")
+
+_context_message = (  # printed after an exception's __context__ block
+    "\nDuring handling of the above exception, "
+    "another exception occurred:\n")
+
+def _iter_chain(exc, custom_tb=None, seen=None):  # root of the chain first
+    if seen is None:
+        seen = set()
+    seen.add(exc)  # visited set: stops __cause__/__context__ cycles
+    its = []
+    cause = exc.__cause__
+    context = exc.__context__
+    if cause is not None and cause not in seen:
+        its.append(_iter_chain(cause, None, seen))
+        its.append([(_cause_message, None)])
+    if context is not None and context is not cause and context not in seen:
+        its.append(_iter_chain(context, None, seen))
+        its.append([(_context_message, None)])
+    its.append([(exc, custom_tb or exc.__traceback__)])  # exc itself last
+    return itertools.chain(*its)  # (exception or separator, tb) pairs
+
+
+def print_exception(etype, value, tb, limit=None, file=None, chain=True):
     """Print exception up to 'limit' stack trace entries from 'tb' to 'file'.
 
     This differs from print_tb() in the following ways: (1) if
@@ -120,15 +146,23 @@
     """
     if file is None:
         file = sys.stderr
-    if tb:
-        _print(file, 'Traceback (most recent call last):')
-        print_tb(tb, limit, file)
-    lines = format_exception_only(etype, value)
-    for line in lines[:-1]:
-        _print(file, line, ' ')
-    _print(file, lines[-1], '')
+    if chain and isinstance(value, BaseException):  # only chain exceptions
+        values = _iter_chain(value, tb)
+    else:
+        values = [(value, tb)]
+    for value, tb in values:
+        if value is _cause_message or value is _context_message:
+            _print(file, value)
+            continue
+        if tb:
+            _print(file, 'Traceback (most recent call last):')
+            print_tb(tb, limit, file)
+        lines = format_exception_only(type(value), value)
+        for line in lines[:-1]:
+            _print(file, line, ' ')
+        _print(file, lines[-1], '')
 
-def format_exception(etype, value, tb, limit = None):
+def format_exception(etype, value, tb, limit=None, chain=True):
     """Format a stack trace and the exception information.
 
     The arguments have the same meaning as the corresponding arguments
@@ -137,12 +171,19 @@
     these lines are concatenated and printed, exactly the same text is
     printed as does print_exception().
     """
-    if tb:
-        list = ['Traceback (most recent call last):\n']
-        list = list + format_tb(tb, limit)
+    list = []
+    if chain and isinstance(value, BaseException):  # only chain exceptions
+        values = _iter_chain(value, tb)
     else:
-        list = []
-    list = list + format_exception_only(etype, value)
+        values = [(value, tb)]
+    for value, tb in values:
+        if value is _cause_message or value is _context_message:
+            list.append(value + '\n')
+            continue
+        if tb:
+            list.append('Traceback (most recent call last):\n')
+            list.extend(format_tb(tb, limit))
+        list.extend(format_exception_only(type(value), value))
     return list
 
 def format_exception_only(etype, value):
@@ -208,33 +249,34 @@
         return '<unprintable %s object>' % type(value).__name__
 
 
-def print_exc(limit=None, file=None):
+def print_exc(limit=None, file=None, chain=True):
     """Shorthand for 'print_exception(*sys.exc_info(), limit, file)'."""
     if file is None:
         file = sys.stderr
     try:
         etype, value, tb = sys.exc_info()
-        print_exception(etype, value, tb, limit, file)
+        print_exception(etype, value, tb, limit, file, chain)
     finally:
         etype = value = tb = None
 
 
-def format_exc(limit=None):
+def format_exc(limit=None, chain=True):
     """Like print_exc() but return a string."""
     try:
         etype, value, tb = sys.exc_info()
-        return ''.join(format_exception(etype, value, tb, limit))
+        return ''.join(
+            format_exception(etype, value, tb, limit, chain))
     finally:
         etype = value = tb = None
 
 
-def print_last(limit=None, file=None):
+def print_last(limit=None, file=None, chain=True):
     """This is a shorthand for 'print_exception(sys.last_type,
     sys.last_value, sys.last_traceback, limit, file)'."""
     if file is None:
         file = sys.stderr
     print_exception(sys.last_type, sys.last_value, sys.last_traceback,
-                    limit, file)
+                    limit, file, chain)
 
 
 def print_stack(f=None, limit=None, file=None):