bpo-28685: Optimize sorted() and list.sort() with type-specialized comparisons (#582)

diff --git a/Lib/test/test_sort.py b/Lib/test/test_sort.py
index 98ccab5..f2f53cb 100644
--- a/Lib/test/test_sort.py
+++ b/Lib/test/test_sort.py
@@ -260,6 +260,120 @@
         self.assertEqual(data, copy2)
 
 #==============================================================================
+def check_against_PyObject_RichCompareBool(self, L):
+    ## The idea here is to exploit the fact that unsafe_tuple_compare uses
+    ## PyObject_RichCompareBool for the second elements of tuples. So we have,
+    ## for (most) L, sorted(L) == [y[1] for y in sorted([(0,x) for x in L])].
+    ## This will work as long as __eq__ => not __lt__ for all the objects in L,
+    ## which holds for all the types used below.
+    ##
+    ## Testing this way ensures that the optimized implementation remains consistent
+    ## with the naive implementation, even if changes are made to any of the
+    ## richcompares.
+    ##
+    ## This function tests sorting for three lists (all built from one shuffled copy of L):
+    ##                        1. L
+    ##                        2. [(x,) for x in L]
+    ##                        3. [((x,),) for x in L]
+
+    random.seed(0)
+    random.shuffle(L)
+    L_1 = L[:]
+    L_2 = [(x,) for x in L]
+    L_3 = [((x,),) for x in L]
+    for L in [L_1, L_2, L_3]:
+        optimized = sorted(L)
+        reference = [y[1] for y in sorted([(0,x) for x in L])]
+        for (opt, ref) in zip(optimized, reference):
+            self.assertIs(opt, ref)
+            #note: not assertEqual! We want to ensure *identical* behavior.
+
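
A standalone Python illustration of the identity the helper above relies on
(nothing beyond stock semantics is assumed): wrapping every element as (0, x)
forces the tie-break onto the second element, so the generic tuple path must
order the elements exactly as sorted(L) does, and stability makes the results
identical object-for-object.

    import random

    L = [random.randrange(50) for _ in range(200)]
    optimized = sorted(L)                                    # may take a specialized path
    reference = [y[1] for y in sorted([(0, x) for x in L])]  # generic tuple path
    assert all(opt is ref for opt, ref in zip(optimized, reference))
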
+class TestOptimizedCompares(unittest.TestCase):
+    def test_safe_object_compare(self):
+        heterogeneous_lists = [[0, 'foo'],
+                               [0.0, 'foo'],
+                               [('foo',), 'foo']]
+        for L in heterogeneous_lists:
+            self.assertRaises(TypeError, L.sort)
+            self.assertRaises(TypeError, [(x,) for x in L].sort)
+            self.assertRaises(TypeError, [((x,),) for x in L].sort)
+
+        float_int_lists = [[1,1.1],
+                           [1<<70,1.1],
+                           [1.1,1],
+                           [1.1,1<<70]]
+        for L in float_int_lists:
+            check_against_PyObject_RichCompareBool(self, L)
+
+    def test_unsafe_object_compare(self):
+
+        # This test is by ppperry. It ensures that unsafe_object_compare
+        # verifies ms->key_richcompare == ob_type->tp_richcompare before comparing.
+
+        class WackyComparator(int):
+            def __lt__(self, other):
+                elem.__class__ = WackyList2
+                return int.__lt__(self, other)
+
+        class WackyList1(list):
+            pass
+
+        class WackyList2(list):
+            def __lt__(self, other):
+                raise ValueError
+
+        L = [WackyList1([WackyComparator(i), i]) for i in range(10)]
+        elem = L[-1]
+        with self.assertRaises(ValueError):
+            L.sort()
+
+        L = [WackyList1([WackyComparator(i), i]) for i in range(10)]
+        elem = L[-1]
+        with self.assertRaises(ValueError):
+            [(x,) for x in L].sort()
+
+        # The following test is also by ppperry. It ensures that
+        # unsafe_object_compare handles Py_NotImplemented appropriately.
+        class PointlessComparator:
+            def __lt__(self, other):
+                return NotImplemented
+        L = [PointlessComparator(), PointlessComparator()]
+        self.assertRaises(TypeError, L.sort)
+        self.assertRaises(TypeError, [(x,) for x in L].sort)
+
+        # The following tests go through various types that would trigger
+        # ms->key_compare = unsafe_object_compare
+        lists = [list(range(100)) + [(1<<70)],
+                 [str(x) for x in range(100)] + ['\uffff'],
+                 [bytes(x) for x in range(100)],
+                 [cmp_to_key(lambda x,y: x<y)(x) for x in range(100)]]
+        for L in lists:
+            check_against_PyObject_RichCompareBool(self, L)
+
+    def test_unsafe_latin_compare(self):
+        check_against_PyObject_RichCompareBool(self, [str(x) for
+                                                      x in range(100)])
+
+    def test_unsafe_long_compare(self):
+        check_against_PyObject_RichCompareBool(self, [x for
+                                                      x in range(100)])
+
+    def test_unsafe_float_compare(self):
+        check_against_PyObject_RichCompareBool(self, [float(x) for
+                                                      x in range(100)])
+
+    def test_unsafe_tuple_compare(self):
+        # This test was suggested by Tim Peters. It verifies that the tuple
+        # comparison respects the current tuple compare semantics, which do not
+        # guarantee that x < x <=> (x,) < (x,)
+        #
+        # Note that we don't have to put anything in tuples here, because
+        # the check function does a tuple test automatically.
+
+        check_against_PyObject_RichCompareBool(self, [float('nan')]*100)
+        check_against_PyObject_RichCompareBool(self, [float('nan') for
+                                                      _ in range(100)])
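
A short Python illustration of the corner case the NaN lists above exercise:
PyObject_RichCompareBool short-circuits on identity for ==, so a tuple holding
the same NaN object compares equal to itself even though NaN is unordered, and
the optimized tuple compare has to reproduce exactly that behavior.

    nan = float('nan')
    assert not (nan < nan)        # NaN is unordered
    assert not (nan == nan)       # ...and not even equal to itself
    assert (nan,) == (nan,)       # identity shortcut inside the tuple comparison
    assert not ((nan,) < (nan,))  # equal prefixes and equal lengths: not "less"
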
+#==============================================================================
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/Misc/ACKS b/Misc/ACKS
index 20210e8..bfcdd33 100644
--- a/Misc/ACKS
+++ b/Misc/ACKS
@@ -554,6 +554,7 @@
 Chris Gonnerman
 Shelley Gooch
 David Goodger
+Elliot Gorokhovsky
 Hans de Graaff
 Tim Graham
 Kim Gräsman
diff --git a/Misc/NEWS.d/next/Core and Builtins/2018-01-28-15-09-33.bpo-28685.cHThLM.rst b/Misc/NEWS.d/next/Core and Builtins/2018-01-28-15-09-33.bpo-28685.cHThLM.rst
new file mode 100644
index 0000000..ccc3c08
--- /dev/null
+++ b/Misc/NEWS.d/next/Core and Builtins/2018-01-28-15-09-33.bpo-28685.cHThLM.rst
@@ -0,0 +1,2 @@
+Optimize list.sort() and sorted() by using type-specialized comparisons when
+possible.
diff --git a/Objects/listobject.c b/Objects/listobject.c
index 8794e37..9e32137 100644
--- a/Objects/listobject.c
+++ b/Objects/listobject.c
@@ -1081,11 +1081,12 @@
         slice->values += n;
 }
 
-/* Comparison function: PyObject_RichCompareBool with Py_LT.
+/* Comparison function: ms->key_compare, which is set at run-time in
+ * listsort_impl to optimize for various special cases.
  * Returns -1 on error, 1 if x < y, 0 if x >= y.
  */
 
-#define ISLT(X, Y) (PyObject_RichCompareBool(X, Y, Py_LT))
+#define ISLT(X, Y) (*(ms->key_compare))(X, Y, ms)
 
 /* Compare X to Y via "<".  Goto "fail" if the comparison raises an
    error.  Else "k" is set to true iff X<Y, and an "if (k)" block is
@@ -1094,6 +1095,75 @@
 #define IFLT(X, Y) if ((k = ISLT(X, Y)) < 0) goto fail;  \
            if (k)
 
+/* The maximum number of entries in a MergeState's pending-runs stack.
+ * This is enough to sort arrays of size up to about
+ *     32 * phi ** MAX_MERGE_PENDING
+ * where phi ~= 1.618.  85 is ridiculously large enough, good for an array
+ * with 2**64 elements.
+ */
+#define MAX_MERGE_PENDING 85
+
+/* When we get into galloping mode, we stay there until both runs win less
+ * often than MIN_GALLOP consecutive times.  See listsort.txt for more info.
+ */
+#define MIN_GALLOP 7
+
+/* Avoid malloc for small temp arrays. */
+#define MERGESTATE_TEMP_SIZE 256
+
+/* One MergeState exists on the stack per invocation of mergesort.  It's just
+ * a convenient way to pass state around among the helper functions.
+ */
+struct s_slice {
+    sortslice base;
+    Py_ssize_t len;
+};
+
+typedef struct s_MergeState MergeState;
+struct s_MergeState {
+    /* This controls when we get *into* galloping mode.  It's initialized
+     * to MIN_GALLOP.  merge_lo and merge_hi tend to nudge it higher for
+     * random data, and lower for highly structured data.
+     */
+    Py_ssize_t min_gallop;
+
+    /* 'a' is temp storage to help with merges.  It contains room for
+     * alloced entries.
+     */
+    sortslice a;        /* may point to temparray below */
+    Py_ssize_t alloced;
+
+    /* A stack of n pending runs yet to be merged.  Run #i starts at
+     * address base[i] and extends for len[i] elements.  It's always
+     * true (so long as the indices are in bounds) that
+     *
+     *     pending[i].base + pending[i].len == pending[i+1].base
+     *
+     * so we could cut the storage for this, but it's a minor amount,
+     * and keeping all the info explicit simplifies the code.
+     */
+    int n;
+    struct s_slice pending[MAX_MERGE_PENDING];
+
+    /* 'a' points to this when possible, rather than muck with malloc. */
+    PyObject *temparray[MERGESTATE_TEMP_SIZE];
+
+    /* This is the function we will use to compare two keys,
+     * even when none of our special cases apply and we have to use
+     * safe_object_compare. */
+    int (*key_compare)(PyObject *, PyObject *, MergeState *);
+
+    /* This function is used by unsafe_object_compare to optimize comparisons
+     * when we know our list is type-homogeneous but we can't assume anything else.
+     * In the pre-sort check it is set equal to key->ob_type->tp_richcompare */
+    PyObject *(*key_richcompare)(PyObject *, PyObject *, int);
+
+    /* This function is used by unsafe_tuple_compare to compare the first elements
+     * of tuples. It may be set to safe_object_compare, but the idea is that hopefully
+     * we can assume more, and use one of the special-case compares. */
+    int (*tuple_elem_compare)(PyObject *, PyObject *, MergeState *);
+};
+
 /* binarysort is the best method for sorting small arrays: it does
    few compares, but can do data movement quadratic in the number of
    elements.
@@ -1106,7 +1176,7 @@
    the input (nothing is lost or duplicated).
 */
 static int
-binarysort(sortslice lo, PyObject **hi, PyObject **start)
+binarysort(MergeState *ms, sortslice lo, PyObject **hi, PyObject **start)
 {
     Py_ssize_t k;
     PyObject **l, **p, **r;
@@ -1180,7 +1250,7 @@
 Returns -1 in case of error.
 */
 static Py_ssize_t
-count_run(PyObject **lo, PyObject **hi, int *descending)
+count_run(MergeState *ms, PyObject **lo, PyObject **hi, int *descending)
 {
     Py_ssize_t k;
     Py_ssize_t n;
@@ -1235,7 +1305,7 @@
 Returns -1 on error.  See listsort.txt for info on the method.
 */
 static Py_ssize_t
-gallop_left(PyObject *key, PyObject **a, Py_ssize_t n, Py_ssize_t hint)
+gallop_left(MergeState *ms, PyObject *key, PyObject **a, Py_ssize_t n, Py_ssize_t hint)
 {
     Py_ssize_t ofs;
     Py_ssize_t lastofs;
@@ -1326,7 +1396,7 @@
 written as one routine with yet another "left or right?" flag.
 */
 static Py_ssize_t
-gallop_right(PyObject *key, PyObject **a, Py_ssize_t n, Py_ssize_t hint)
+gallop_right(MergeState *ms, PyObject *key, PyObject **a, Py_ssize_t n, Py_ssize_t hint)
 {
     Py_ssize_t ofs;
     Py_ssize_t lastofs;
@@ -1402,59 +1472,6 @@
     return -1;
 }
 
-/* The maximum number of entries in a MergeState's pending-runs stack.
- * This is enough to sort arrays of size up to about
- *     32 * phi ** MAX_MERGE_PENDING
- * where phi ~= 1.618.  85 is ridiculouslylarge enough, good for an array
- * with 2**64 elements.
- */
-#define MAX_MERGE_PENDING 85
-
-/* When we get into galloping mode, we stay there until both runs win less
- * often than MIN_GALLOP consecutive times.  See listsort.txt for more info.
- */
-#define MIN_GALLOP 7
-
-/* Avoid malloc for small temp arrays. */
-#define MERGESTATE_TEMP_SIZE 256
-
-/* One MergeState exists on the stack per invocation of mergesort.  It's just
- * a convenient way to pass state around among the helper functions.
- */
-struct s_slice {
-    sortslice base;
-    Py_ssize_t len;
-};
-
-typedef struct s_MergeState {
-    /* This controls when we get *into* galloping mode.  It's initialized
-     * to MIN_GALLOP.  merge_lo and merge_hi tend to nudge it higher for
-     * random data, and lower for highly structured data.
-     */
-    Py_ssize_t min_gallop;
-
-    /* 'a' is temp storage to help with merges.  It contains room for
-     * alloced entries.
-     */
-    sortslice a;        /* may point to temparray below */
-    Py_ssize_t alloced;
-
-    /* A stack of n pending runs yet to be merged.  Run #i starts at
-     * address base[i] and extends for len[i] elements.  It's always
-     * true (so long as the indices are in bounds) that
-     *
-     *     pending[i].base + pending[i].len == pending[i+1].base
-     *
-     * so we could cut the storage for this, but it's a minor amount,
-     * and keeping all the info explicit simplifies the code.
-     */
-    int n;
-    struct s_slice pending[MAX_MERGE_PENDING];
-
-    /* 'a' points to this when possible, rather than muck with malloc. */
-    PyObject *temparray[MERGESTATE_TEMP_SIZE];
-} MergeState;
-
 /* Conceptually a MergeState's constructor. */
 static void
 merge_init(MergeState *ms, Py_ssize_t list_size, int has_keyfunc)
@@ -1514,11 +1531,11 @@
      * we don't care what's in the block.
      */
     merge_freemem(ms);
-    if ((size_t)need > PY_SSIZE_T_MAX / sizeof(PyObject*) / multiplier) {
+    if ((size_t)need > PY_SSIZE_T_MAX / sizeof(PyObject *) / multiplier) {
         PyErr_NoMemory();
         return -1;
     }
-    ms->a.keys = (PyObject**)PyMem_Malloc(multiplier * need
+    ms->a.keys = (PyObject **)PyMem_Malloc(multiplier * need
                                           * sizeof(PyObject *));
     if (ms->a.keys != NULL) {
         ms->alloced = need;
@@ -1607,7 +1624,7 @@
             assert(na > 1 && nb > 0);
             min_gallop -= min_gallop > 1;
             ms->min_gallop = min_gallop;
-            k = gallop_right(ssb.keys[0], ssa.keys, na, 0);
+            k = gallop_right(ms, ssb.keys[0], ssa.keys, na, 0);
             acount = k;
             if (k) {
                 if (k < 0)
@@ -1630,7 +1647,7 @@
             if (nb == 0)
                 goto Succeed;
 
-            k = gallop_left(ssa.keys[0], ssb.keys, nb, 0);
+            k = gallop_left(ms, ssa.keys[0], ssb.keys, nb, 0);
             bcount = k;
             if (k) {
                 if (k < 0)
@@ -1745,7 +1762,7 @@
             assert(na > 0 && nb > 1);
             min_gallop -= min_gallop > 1;
             ms->min_gallop = min_gallop;
-            k = gallop_right(ssb.keys[0], basea.keys, na, na-1);
+            k = gallop_right(ms, ssb.keys[0], basea.keys, na, na-1);
             if (k < 0)
                 goto Fail;
             k = na - k;
@@ -1763,7 +1780,7 @@
             if (nb == 1)
                 goto CopyA;
 
-            k = gallop_left(ssa.keys[0], baseb.keys, nb, nb-1);
+            k = gallop_left(ms, ssa.keys[0], baseb.keys, nb, nb-1);
             if (k < 0)
                 goto Fail;
             k = nb - k;
@@ -1840,7 +1857,7 @@
     /* Where does b start in a?  Elements in a before that can be
      * ignored (already in place).
      */
-    k = gallop_right(*ssb.keys, ssa.keys, na, 0);
+    k = gallop_right(ms, *ssb.keys, ssa.keys, na, 0);
     if (k < 0)
         return -1;
     sortslice_advance(&ssa, k);
@@ -1851,7 +1868,7 @@
     /* Where does a end in b?  Elements in b after that can be
      * ignored (already in place).
      */
-    nb = gallop_left(ssa.keys[na-1], ssb.keys, nb, nb-1);
+    nb = gallop_left(ms, ssa.keys[na-1], ssb.keys, nb, nb-1);
     if (nb <= 0)
         return nb;
 
@@ -1890,8 +1907,8 @@
                 return -1;
         }
         else if (p[n].len <= p[n+1].len) {
-                 if (merge_at(ms, n) < 0)
-                        return -1;
+            if (merge_at(ms, n) < 0)
+                return -1;
         }
         else
             break;
@@ -1951,6 +1968,170 @@
         reverse_slice(s->values, &s->values[n]);
 }
 
+/* Here we define custom comparison functions to optimize for the cases one commonly
+ * encounters in practice: homogeneous lists, often of one of the basic types. */
+
+/* The comparison function and helper functions selected in the pre-sort
+ * check are stored in the MergeState struct defined above. */
+
+/* These are the special case compare functions.
+ * ms->key_compare will always point to one of these: */
+
+/* Heterogeneous compare: default, always safe to fall back on. */
+static int
+safe_object_compare(PyObject *v, PyObject *w, MergeState *ms)
+{
+    /* No assumptions necessary! */
+    return PyObject_RichCompareBool(v, w, Py_LT);
+}
+
+/* Homogeneous compare: safe for any two comparable objects of the same type.
+ * (ms->key_richcompare is set to ob_type->tp_richcompare in the
+ *  pre-sort check.)
+ */
+static int
+unsafe_object_compare(PyObject *v, PyObject *w, MergeState *ms)
+{
+    PyObject *res_obj; int res;
+
+    /* No assumptions, because we check first: */
+    if (v->ob_type->tp_richcompare != ms->key_richcompare)
+        return PyObject_RichCompareBool(v, w, Py_LT);
+
+    assert(ms->key_richcompare != NULL);
+    res_obj = (*(ms->key_richcompare))(v, w, Py_LT);
+
+    if (res_obj == Py_NotImplemented) {
+        Py_DECREF(res_obj);
+        return PyObject_RichCompareBool(v, w, Py_LT);
+    }
+    if (res_obj == NULL)
+        return -1;
+
+    if (PyBool_Check(res_obj)) {
+        res = (res_obj == Py_True);
+    }
+    else {
+        res = PyObject_IsTrue(res_obj);
+    }
+    Py_DECREF(res_obj);
+
+    /* Note that we can't assert
+     *     res == PyObject_RichCompareBool(v, w, Py_LT);
+     * because of evil compare functions like this:
+     *     lambda a, b:  int(random.random() * 3) - 1
+     * (which is actually in test_sort.py) */
+    return res;
+}
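
The NotImplemented branch above preserves the ordinary Python outcome: when a
type's __lt__ gives up, falling back to PyObject_RichCompareBool must still end
in the usual TypeError rather than silently treating the result as false.  A
minimal Python sketch of that behavior:

    class PointlessComparator:
        def __lt__(self, other):
            return NotImplemented

    try:
        sorted([PointlessComparator(), PointlessComparator()])
    except TypeError:
        pass                      # the generic fallback raises the usual error
    else:
        raise AssertionError("expected TypeError")
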
+
+/* Latin string compare: safe for any two latin (one byte per char) strings. */
+static int
+unsafe_latin_compare(PyObject *v, PyObject *w, MergeState *ms)
+{
+    Py_ssize_t len;
+    int res;
+
+    /* Modified from Objects/unicodeobject.c:unicode_compare, assuming: */
+    assert(v->ob_type == w->ob_type);
+    assert(v->ob_type == &PyUnicode_Type);
+    assert(PyUnicode_KIND(v) == PyUnicode_KIND(w));
+    assert(PyUnicode_KIND(v) == PyUnicode_1BYTE_KIND);
+
+    len = Py_MIN(PyUnicode_GET_LENGTH(v), PyUnicode_GET_LENGTH(w));
+    res = memcmp(PyUnicode_DATA(v), PyUnicode_DATA(w), len);
+
+    res = (res != 0 ?
+           res < 0 :
+           PyUnicode_GET_LENGTH(v) < PyUnicode_GET_LENGTH(w));
+
+    assert(res == PyObject_RichCompareBool(v, w, Py_LT));
+    return res;
+}
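
Why memcmp gives the right answer here: for strings whose code points all fit
in one byte, code-point order and byte order coincide, and the length
tie-break matches the usual prefix rule.  A small Python check of that
equivalence (latin-1 maps each code point to the byte of the same value):

    samples = ["", "a", "ab", "abc", "abd", "zz", "\xe9", "\xe9a"]
    for x in samples:
        for y in samples:
            assert (x < y) == (x.encode('latin-1') < y.encode('latin-1'))
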
+
+/* Bounded int compare: compare any two longs that fit in a single machine word. */
+static int
+unsafe_long_compare(PyObject *v, PyObject *w, MergeState *ms)
+{
+    PyLongObject *vl, *wl; sdigit v0, w0; int res;
+
+    /* Modified from Objects/longobject.c:long_compare, assuming: */
+    assert(v->ob_type == w->ob_type);
+    assert(v->ob_type == &PyLong_Type);
+    assert(Py_ABS(Py_SIZE(v)) <= 1);
+    assert(Py_ABS(Py_SIZE(w)) <= 1);
+
+    vl = (PyLongObject*)v;
+    wl = (PyLongObject*)w;
+
+    v0 = Py_SIZE(vl) == 0 ? 0 : (sdigit)vl->ob_digit[0];
+    w0 = Py_SIZE(wl) == 0 ? 0 : (sdigit)wl->ob_digit[0];
+
+    if (Py_SIZE(vl) < 0)
+        v0 = -v0;
+    if (Py_SIZE(wl) < 0)
+        w0 = -w0;
+
+    res = v0 < w0;
+    assert(res == PyObject_RichCompareBool(v, w, Py_LT));
+    return res;
+}
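
What "bounded" means here: CPython stores an int as a sign plus base-2**k
digits, where k is available as sys.int_info.bits_per_digit (30 on most
64-bit builds, 15 on others), and this compare only fires when every key has
at most one digit.  A Python sketch of the bound:

    import sys

    k = sys.int_info.bits_per_digit
    single_digit_max = 2**k - 1           # largest magnitude handled here
    assert abs(-42) <= single_digit_max   # everyday ints qualify
    assert (1 << 70) > single_digit_max   # needs several digits, so a list
                                          # containing it falls back to
                                          # unsafe_object_compare
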
+
+/* Float compare: compare any two floats. */
+static int
+unsafe_float_compare(PyObject *v, PyObject *w, MergeState *ms)
+{
+    int res;
+
+    /* Modified from Objects/floatobject.c:float_richcompare, assuming: */
+    assert(v->ob_type == w->ob_type);
+    assert(v->ob_type == &PyFloat_Type);
+
+    res = PyFloat_AS_DOUBLE(v) < PyFloat_AS_DOUBLE(w);
+    assert(res == PyObject_RichCompareBool(v, w, Py_LT));
+    return res;
+}
+
+/* Tuple compare: compare *any* two tuples, using
+ * ms->tuple_elem_compare to compare the first elements, which is set
+ * using the same pre-sort check as we use for ms->key_compare,
+ * but run on the list [x[0] for x in L]. This allows us to optimize compares
+ * on two levels (as long as [x[0] for x in L] is type-homogeneous.) The idea is
+ * that most tuple compares don't involve x[1:]. */
+static int
+unsafe_tuple_compare(PyObject *v, PyObject *w, MergeState *ms)
+{
+    PyTupleObject *vt, *wt;
+    Py_ssize_t i, vlen, wlen;
+    int k;
+
+    /* Modified from Objects/tupleobject.c:tuplerichcompare, assuming: */
+    assert(v->ob_type == w->ob_type);
+    assert(v->ob_type == &PyTuple_Type);
+    assert(Py_SIZE(v) > 0);
+    assert(Py_SIZE(w) > 0);
+
+    vt = (PyTupleObject *)v;
+    wt = (PyTupleObject *)w;
+
+    vlen = Py_SIZE(vt);
+    wlen = Py_SIZE(wt);
+
+    for (i = 0; i < vlen && i < wlen; i++) {
+        k = PyObject_RichCompareBool(vt->ob_item[i], wt->ob_item[i], Py_EQ);
+        if (k < 0)
+            return -1;
+        if (!k)
+            break;
+    }
+
+    if (i >= vlen || i >= wlen)
+        return vlen < wlen;
+
+    if (i == 0)
+        return ms->tuple_elem_compare(vt->ob_item[i], wt->ob_item[i], ms);
+    else
+        return PyObject_RichCompareBool(vt->ob_item[i], wt->ob_item[i], Py_LT);
+}
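
The payoff of handing only the first elements to tuple_elem_compare: in
typical (key, payload) sorting, ties on the first element are rare, so the
later elements are usually never compared at all.  A Python illustration of
that property (Opaque is a made-up, unorderable placeholder class):

    class Opaque:                  # deliberately not orderable
        pass

    pairs = [(3, Opaque()), (1, Opaque()), (2, Opaque())]
    assert [k for k, _ in sorted(pairs)] == [1, 2, 3]   # payloads never compared

    pairs.append((1, Opaque()))    # introduce a tie on the first element...
    try:
        sorted(pairs)              # ...and now the unorderable payloads are reached
    except TypeError:
        pass
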
+
 /* An adaptive, stable, natural mergesort.  See listsort.txt.
  * Returns Py_None on success, NULL on error.  Even in case of error, the
  * list will be some permutation of its input state (nothing is lost or
@@ -2031,6 +2212,91 @@
         lo.values = saved_ob_item;
     }
 
+
+    /* The pre-sort check: here's where we decide which compare function to use.
+     * How much optimization is safe? We test for homogeneity with respect to
+     * several properties that are expensive to check at compare-time, and
+     * set ms appropriately. */
+    if (saved_ob_size > 1) {
+        /* Assume the first element is representative of the whole list. */
+        int keys_are_in_tuples = (lo.keys[0]->ob_type == &PyTuple_Type &&
+                                  Py_SIZE(lo.keys[0]) > 0);
+
+        PyTypeObject* key_type = (keys_are_in_tuples ?
+                                  PyTuple_GET_ITEM(lo.keys[0], 0)->ob_type :
+                                  lo.keys[0]->ob_type);
+
+        int keys_are_all_same_type = 1;
+        int strings_are_latin = 1;
+        int ints_are_bounded = 1;
+
+        /* Prove that assumption by checking every key. */
+        int i;
+        for (i=0; i < saved_ob_size; i++) {
+
+            if (keys_are_in_tuples &&
+                !(lo.keys[i]->ob_type == &PyTuple_Type && Py_SIZE(lo.keys[i]) != 0)) {
+                keys_are_in_tuples = 0;
+                keys_are_all_same_type = 0;
+                break;
+            }
+
+            /* Note: for lists of tuples, key is the first element of the tuple
+             * lo.keys[i], not lo.keys[i] itself! We verify type-homogeneity
+             * for lists of tuples in the if-statement directly above. */
+            PyObject *key = (keys_are_in_tuples ?
+                             PyTuple_GET_ITEM(lo.keys[i], 0) :
+                             lo.keys[i]);
+
+            if (key->ob_type != key_type) {
+                keys_are_all_same_type = 0;
+                break;
+            }
+
+            if (key_type == &PyLong_Type) {
+                if (ints_are_bounded && Py_ABS(Py_SIZE(key)) > 1)
+                    ints_are_bounded = 0;
+            }
+            else if (key_type == &PyUnicode_Type) {
+                if (strings_are_latin &&
+                    PyUnicode_KIND(key) != PyUnicode_1BYTE_KIND)
+                    strings_are_latin = 0;
+            }
+        }
+
+        /* Choose the best compare, given what we now know about the keys. */
+        if (keys_are_all_same_type) {
+
+            if (key_type == &PyUnicode_Type && strings_are_latin) {
+                ms.key_compare = unsafe_latin_compare;
+            }
+            else if (key_type == &PyLong_Type && ints_are_bounded) {
+                ms.key_compare = unsafe_long_compare;
+            }
+            else if (key_type == &PyFloat_Type) {
+                ms.key_compare = unsafe_float_compare;
+            }
+            else if ((ms.key_richcompare = key_type->tp_richcompare) != NULL) {
+                ms.key_compare = unsafe_object_compare;
+            }
+            else {
+                /* tp_richcompare can legitimately be NULL, so make sure
+                 * ms.key_compare is always initialized. */
+                ms.key_compare = safe_object_compare;
+            }
+        }
+        else {
+            ms.key_compare = safe_object_compare;
+        }
+
+        if (keys_are_in_tuples) {
+            /* Make sure we're not dealing with tuples of tuples
+             * (remember: here, key_type refers to the list [key[0] for key in keys]) */
+            if (key_type == &PyTuple_Type)
+                ms.tuple_elem_compare = safe_object_compare;
+            else
+                ms.tuple_elem_compare = ms.key_compare;
+
+            ms.key_compare = unsafe_tuple_compare;
+        }
+    }
+    /* End of pre-sort check: ms is now set properly! */
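
A rough Python rendering of the decision above (the function name and return
values are illustrative, not the C identifiers): scan the keys once, and only
pick a cheaper comparator when every key has the same type; otherwise stay
fully generic.  Tuple keys, which add the unsafe_tuple_compare layer, are
omitted for brevity.

    def choose_compare(keys):
        key_type = type(keys[0])
        if any(type(k) is not key_type for k in keys):
            return "safe_object_compare"   # heterogeneous: no assumptions allowed
        if key_type is str:
            return "unsafe_latin_compare"  # the C code also requires 1-byte kind
        if key_type is int:
            return "unsafe_long_compare"   # the C code also requires single-digit magnitude
        if key_type is float:
            return "unsafe_float_compare"
        return "unsafe_object_compare"     # same type; use its tp_richcompare directly

    assert choose_compare(list(range(5))) == "unsafe_long_compare"
    assert choose_compare([1, "x"]) == "safe_object_compare"
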
+
     merge_init(&ms, saved_ob_size, keys != NULL);
 
     nremaining = saved_ob_size;
@@ -2054,7 +2320,7 @@
         Py_ssize_t n;
 
         /* Identify next run. */
-        n = count_run(lo.keys, lo.keys + nremaining, &descending);
+        n = count_run(&ms, lo.keys, lo.keys + nremaining, &descending);
         if (n < 0)
             goto fail;
         if (descending)
@@ -2063,7 +2329,7 @@
         if (n < minrun) {
             const Py_ssize_t force = nremaining <= minrun ?
                               nremaining : minrun;
-            if (binarysort(lo, lo.keys + force, lo.keys + n) < 0)
+            if (binarysort(&ms, lo, lo.keys + force, lo.keys + n) < 0)
                 goto fail;
             n = force;
         }
diff --git a/Objects/listsort.txt b/Objects/listsort.txt
index 17d2797..8c87751 100644
--- a/Objects/listsort.txt
+++ b/Objects/listsort.txt
@@ -753,3 +753,11 @@
 locations:  before B[4], between B[4] and B[5], between B[5] and B[6], and
 after B[6].  In general, across 2**(k-1)-1 elements, there are 2**(k-1)
 locations.  That's why k-1 binary searches are necessary and sufficient.
+
+OPTIMIZATION OF INDIVIDUAL COMPARISONS
+As noted above, even the simplest Python comparison triggers a large pile of
+C-level pointer dereferences, conditionals, and function calls.  This can be
+partially mitigated by pre-scanning the data to determine whether it is
+homogeneous with respect to type.  If so, it is sometimes possible to
+substitute faster type-specific comparisons for the slower, generic
+PyObject_RichCompareBool.
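
One way to observe the effect locally is to time a homogeneous list against
one whose single out-of-type element defeats the pre-scan; absolute numbers
depend on the build and machine, so none are quoted here.

    import timeit

    homogeneous = list(range(10000))
    mixed = list(range(9999)) + [9999.0]   # one float forces the generic compare
    print(timeit.timeit(lambda: sorted(homogeneous), number=200))
    print(timeit.timeit(lambda: sorted(mixed), number=200))
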