Close #10042: functools.total_ordering now propagates NotImplemented

(Patch by Katie Miller)
diff --git a/Lib/functools.py b/Lib/functools.py
index 19f88c7..6a6974f 100644
--- a/Lib/functools.py
+++ b/Lib/functools.py
@@ -89,21 +89,91 @@
 ### total_ordering class decorator
 ################################################################################
 
+# The correct way to indicate that a comparison operation doesn't
+# recognise the other type is to return NotImplemented and let the
+# interpreter handle raising TypeError if both operands return
+# NotImplemented from their respective comparison methods.
+#
+# This makes the implementation of total_ordering more complicated, since
+# we need to be careful not to trigger infinite recursion when two
+# different types that both use this decorator encounter each other.
+#
+# For example, if a type implements __lt__, it's natural to define
+# __gt__ as something like:
+#
+#    lambda self, other: not self < other and not self == other
+#
+# However, using the operator syntax like that ends up invoking the full
+# type checking machinery again and means we can end up bouncing back and
+# forth between the two operands until we run out of stack space.
+#
+# The solution is to define helper functions that invoke the appropriate
+# magic methods directly, ensuring we only try each operand once, and
+# return NotImplemented immediately if it is returned from the
+# underlying user provided method. Using this scheme, the __gt__ derived
+# from a user provided __lt__ becomes:
+#
+#    lambda self, other: _not_op_and_not_eq(self.__lt__, self, other)
+
+def _not_op(op, other):
+    """Return not op(other), propagating NotImplemented."""
+    # "not a < b" handles "a >= b"; "not a <= b" handles "a > b"
+    # "not a >= b" handles "a < b"; "not a > b" handles "a <= b"
+    # (these equivalences hold only because the ordering is total)
+    op_result = op(other)
+    if op_result is NotImplemented:  # defer to the other operand
+        return NotImplemented
+    return not op_result
+
+def _op_or_eq(op, self, other):
+    """Return op(other) or self == other, propagating NotImplemented."""
+    # "a < b or a == b" handles "a <= b"; "a > b or a == b" handles "a >= b"
+    op_result = op(other)
+    if op_result is NotImplemented:  # defer to the other operand
+        return NotImplemented
+    return op_result or self == other  # == runs only if op_result is falsy
+
+def _not_op_and_not_eq(op, self, other):
+    """Return not op(other) and self != other, propagating NotImplemented."""
+    # "not (a < b or a == b)", i.e. "not a < b and a != b", handles "a > b"
+    # "not (a > b or a == b)", i.e. "not a > b and a != b", handles "a < b"
+    # The "and" skips the != call whenever op_result is truthy.
+    op_result = op(other)
+    if op_result is NotImplemented:  # defer to the other operand
+        return NotImplemented
+    return not op_result and self != other
+
+def _not_op_or_eq(op, self, other):
+    """Return not op(other) or self == other, propagating NotImplemented."""
+    # "not a <= b or a == b" -> "a >= b"; "not a >= b or a == b" -> "a <= b"
+    op_result = op(other)
+    if op_result is NotImplemented:  # defer to the other operand
+        return NotImplemented
+    return not op_result or self == other
+
+def _op_and_not_eq(op, self, other):
+    """Return op(other) and self != other, propagating NotImplemented."""
+    # "a <= b and a != b" handles "a < b"; "a >= b and a != b" handles "a > b"
+    op_result = op(other)
+    if op_result is NotImplemented:  # defer to the other operand
+        return NotImplemented
+    return op_result and self != other
+
 def total_ordering(cls):
     """Class decorator that fills in missing ordering methods"""
     convert = {
-        '__lt__': [('__gt__', lambda self, other: not (self < other or self == other)),
-                   ('__le__', lambda self, other: self < other or self == other),
-                   ('__ge__', lambda self, other: not self < other)],
-        '__le__': [('__ge__', lambda self, other: not self <= other or self == other),
-                   ('__lt__', lambda self, other: self <= other and not self == other),
-                   ('__gt__', lambda self, other: not self <= other)],
-        '__gt__': [('__lt__', lambda self, other: not (self > other or self == other)),
-                   ('__ge__', lambda self, other: self > other or self == other),
-                   ('__le__', lambda self, other: not self > other)],
-        '__ge__': [('__le__', lambda self, other: (not self >= other) or self == other),
-                   ('__gt__', lambda self, other: self >= other and not self == other),
-                   ('__lt__', lambda self, other: not self >= other)]
+        '__lt__': [('__gt__', lambda self, other: _not_op_and_not_eq(self.__lt__, self, other)),
+                   ('__le__', lambda self, other: _op_or_eq(self.__lt__, self, other)),
+                   ('__ge__', lambda self, other: _not_op(self.__lt__, other))],
+        '__le__': [('__ge__', lambda self, other: _not_op_or_eq(self.__le__, self, other)),
+                   ('__lt__', lambda self, other: _op_and_not_eq(self.__le__, self, other)),
+                   ('__gt__', lambda self, other: _not_op(self.__le__, other))],
+        '__gt__': [('__lt__', lambda self, other: _not_op_and_not_eq(self.__gt__, self, other)),
+                   ('__ge__', lambda self, other: _op_or_eq(self.__gt__, self, other)),
+                   ('__le__', lambda self, other: _not_op(self.__gt__, other))],
+        '__ge__': [('__le__', lambda self, other: _not_op_or_eq(self.__ge__, self, other)),
+                   ('__gt__', lambda self, other: _op_and_not_eq(self.__ge__, self, other)),
+                   ('__lt__', lambda self, other: _not_op(self.__ge__, other))]
     }
     # Find user-defined comparisons (not those inherited from object).
     roots = [op for op in convert if getattr(cls, op, None) is not getattr(object, op, None)]