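Implement chained exception tracebacks in the traceback module

print_exception(), format_exception(), print_exc(), format_exc() and
print_last() now also report the exceptions linked through __cause__
and __context__; pass chain=False to report only the given exception.

Patch from Antoine Pitrou (issue #3112).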
diff --git a/Lib/traceback.py b/Lib/traceback.py
index fb1c5ad..b7130d8 100644
--- a/Lib/traceback.py
+++ b/Lib/traceback.py
@@ -3,6 +3,7 @@
 import linecache
 import sys
 import types
+import itertools
 
 __all__ = ['extract_stack', 'extract_tb', 'format_exception',
            'format_exception_only', 'format_list', 'format_stack',
@@ -107,7 +108,46 @@
     return list
 
 
-def print_exception(etype, value, tb, limit=None, file=None):
+_cause_message = (
+    "\nThe above exception was the direct cause "
+    "of the following exception:\n")
+
+_context_message = (
+    "\nDuring handling of the above exception, "
+    "another exception occurred:\n")
+
+
+def _iter_chain(exc, custom_tb=None, seen=None):
+    """Return an iterator over the exception chain that ends at 'exc'.
+
+    Each item is a (value, tb) pair, oldest exception first.  Between
+    two exceptions comes a (message, None) pair whose message string is
+    _cause_message or _context_message.  'exc' itself comes last,
+    paired with 'custom_tb' if that is true, else exc.__traceback__;
+    earlier exceptions are paired with their own __traceback__.
+
+    An explicit __cause__ takes priority over the implicit __context__,
+    so at most one link is followed from each exception.  'seen' holds
+    the exceptions visited so far and keeps a cyclic chain from
+    recursing forever.
+    """
+    if seen is None:
+        seen = set()
+    seen.add(exc)
+    its = []
+    cause = exc.__cause__
+    context = exc.__context__
+    if cause is not None and cause not in seen:
+        its.append(_iter_chain(cause, None, seen))
+        its.append([(_cause_message, None)])
+    elif context is not None and context not in seen:
+        its.append(_iter_chain(context, None, seen))
+        its.append([(_context_message, None)])
+    its.append([(exc, custom_tb or exc.__traceback__)])
+    return itertools.chain(*its)
+
+
+def print_exception(etype, value, tb, limit=None, file=None, chain=True):
     """Print exception up to 'limit' stack trace entries from 'tb' to 'file'.
 
     This differs from print_tb() in the following ways: (1) if
@@ -120,15 +146,27 @@
     """
     if file is None:
         file = sys.stderr
-    if tb:
-        _print(file, 'Traceback (most recent call last):')
-        print_tb(tb, limit, file)
-    lines = format_exception_only(etype, value)
-    for line in lines[:-1]:
-        _print(file, line, ' ')
-    _print(file, lines[-1], '')
+    # value is None when there is no exception to report (e.g. when
+    # sys.exc_info() is empty); only an exception has a chain to walk.
+    if chain and value is not None:
+        values = _iter_chain(value, tb)
+    else:
+        values = [(value, tb)]
+    for exc, exc_tb in values:
+        # Separator messages between chained exceptions are plain strings.
+        if isinstance(exc, str):
+            _print(file, exc)
+            continue
+        if exc_tb:
+            _print(file, 'Traceback (most recent call last):')
+            print_tb(exc_tb, limit, file)
+        # Use each exception's own type; etype describes only 'value'.
+        lines = format_exception_only(type(exc), exc)
+        for line in lines[:-1]:
+            _print(file, line, ' ')
+        _print(file, lines[-1], '')
 
-def format_exception(etype, value, tb, limit = None):
+def format_exception(etype, value, tb, limit=None, chain=True):
     """Format a stack trace and the exception information.
 
     The arguments have the same meaning as the corresponding arguments
@@ -137,12 +171,23 @@
     these lines are concatenated and printed, exactly the same text is
     printed as does print_exception().
     """
-    if tb:
-        list = ['Traceback (most recent call last):\n']
-        list = list + format_tb(tb, limit)
+    list = []
+    # value is None when there is no exception to report (e.g. when
+    # sys.exc_info() is empty); only an exception has a chain to walk.
+    if chain and value is not None:
+        values = _iter_chain(value, tb)
     else:
-        list = []
-    list = list + format_exception_only(etype, value)
+        values = [(value, tb)]
+    for exc, exc_tb in values:
+        # Separator messages between chained exceptions are plain strings.
+        if isinstance(exc, str):
+            list.append(exc + '\n')
+            continue
+        if exc_tb:
+            list.append('Traceback (most recent call last):\n')
+            list.extend(format_tb(exc_tb, limit))
+        # Use each exception's own type; etype describes only 'value'.
+        list.extend(format_exception_only(type(exc), exc))
     return list
 
 def format_exception_only(etype, value):
@@ -208,33 +249,33 @@
         return '<unprintable %s object>' % type(value).__name__
 
 
-def print_exc(limit=None, file=None):
-    """Shorthand for 'print_exception(*sys.exc_info(), limit, file)'."""
+def print_exc(limit=None, file=None, chain=True):
+    """Shorthand for 'print_exception(*sys.exc_info(), limit, file, chain)'."""
     if file is None:
         file = sys.stderr
     try:
         etype, value, tb = sys.exc_info()
-        print_exception(etype, value, tb, limit, file)
+        print_exception(etype, value, tb, limit, file, chain)
     finally:
         etype = value = tb = None
 
 
-def format_exc(limit=None):
+def format_exc(limit=None, chain=True):
     """Like print_exc() but return a string."""
     try:
         etype, value, tb = sys.exc_info()
-        return ''.join(format_exception(etype, value, tb, limit))
+        return ''.join(format_exception(etype, value, tb, limit, chain))
     finally:
         etype = value = tb = None
 
 
-def print_last(limit=None, file=None):
+def print_last(limit=None, file=None, chain=True):
     """This is a shorthand for 'print_exception(sys.last_type,
-    sys.last_value, sys.last_traceback, limit, file)'."""
+    sys.last_value, sys.last_traceback, limit, file, chain)'."""
     if file is None:
         file = sys.stderr
     print_exception(sys.last_type, sys.last_value, sys.last_traceback,
-                    limit, file)
+                    limit, file, chain)
 
 
 def print_stack(f=None, limit=None, file=None):