Close #15573: use value-based memoryview comparisons (patch by Stefan Krah)
diff --git a/Objects/memoryobject.c b/Objects/memoryobject.c
index 46a8416..f547983 100644
--- a/Objects/memoryobject.c
+++ b/Objects/memoryobject.c
@@ -246,7 +246,7 @@
     (view->suboffsets && view->suboffsets[dest->ndim-1] >= 0)
 
 Py_LOCAL_INLINE(int)
-last_dim_is_contiguous(Py_buffer *dest, Py_buffer *src)
+last_dim_is_contiguous(const Py_buffer *dest, const Py_buffer *src)
 {
     assert(dest->ndim > 0 && src->ndim > 0);
     return (!HAVE_SUBOFFSETS_IN_LAST_DIM(dest) &&
@@ -255,37 +255,63 @@
             src->strides[src->ndim-1] == src->itemsize);
 }
 
-/* Check that the logical structure of the destination and source buffers
-   is identical. */
-static int
-cmp_structure(Py_buffer *dest, Py_buffer *src)
+/* This is not a general function for determining format equivalence.
+   It is used in copy_single() and copy_buffer() to weed out non-matching
+   formats. Skipping the '@' character is specifically used in slice
+   assignments, where the lvalue is already known to have a single character
+   format. This is a performance hack that could be rewritten (if properly
+   benchmarked). */
+Py_LOCAL_INLINE(int)
+equiv_format(const Py_buffer *dest, const Py_buffer *src)
 {
     const char *dfmt, *sfmt;
-    int i;
 
     assert(dest->format && src->format);
     dfmt = dest->format[0] == '@' ? dest->format+1 : dest->format;
     sfmt = src->format[0] == '@' ? src->format+1 : src->format;
 
     if (strcmp(dfmt, sfmt) != 0 ||
-        dest->itemsize != src->itemsize ||
-        dest->ndim != src->ndim) {
-        goto value_error;
+        dest->itemsize != src->itemsize) {
+        return 0;
     }
 
+    return 1;
+}
+
+/* Two shapes are equivalent if they are either equal or identical up
+   to a zero element at the same position. For example, in NumPy arrays
+   the shapes [1, 0, 5] and [1, 0, 7] are equivalent. */
+Py_LOCAL_INLINE(int)
+equiv_shape(const Py_buffer *dest, const Py_buffer *src)
+{
+    int i;
+
+    if (dest->ndim != src->ndim)
+        return 0;
+
     for (i = 0; i < dest->ndim; i++) {
         if (dest->shape[i] != src->shape[i])
-            goto value_error;
+            return 0;
         if (dest->shape[i] == 0)
             break;
     }
 
-    return 0;
+    return 1;
+}
 
-value_error:
-    PyErr_SetString(PyExc_ValueError,
-        "ndarray assignment: lvalue and rvalue have different structures");
-    return -1;
+/* Check that the logical structure of the destination and source buffers
+   is identical. Return 1 if so; otherwise set ValueError and return 0. */
+static int
+equiv_structure(const Py_buffer *dest, const Py_buffer *src)
+{
+    if (!equiv_format(dest, src) ||
+        !equiv_shape(dest, src)) {
+        PyErr_SetString(PyExc_ValueError,
+            "ndarray assignment: lvalue and rvalue have different structures");
+        return 0;
+    }
+
+    return 1;
 }
 
 /* Base case for recursive multi-dimensional copying. Contiguous arrays are
@@ -358,7 +384,7 @@
 
     assert(dest->ndim == 1);
 
-    if (cmp_structure(dest, src) < 0)
+    if (!equiv_structure(dest, src))
         return -1;
 
     if (!last_dim_is_contiguous(dest, src)) {
@@ -390,7 +416,7 @@
 
     assert(dest->ndim > 0);
 
-    if (cmp_structure(dest, src) < 0)
+    if (!equiv_structure(dest, src))
         return -1;
 
     if (!last_dim_is_contiguous(dest, src)) {
@@ -1828,6 +1854,131 @@
 
 
 /****************************************************************************/
+/*                       unpack using the struct module                     */
+/****************************************************************************/
+
+/* For reasonable performance it is necessary to cache all objects required
+   for unpacking. An unpacker handles only the format that was passed to
+   struct_get_unpacker(). Invariant: All pointer fields of the struct should
+   either be NULL or valid pointers. */
+struct unpacker {
+    PyObject *unpack_from; /* Struct.unpack_from(format) */
+    PyObject *mview;       /* cached memoryview */
+    char *item;            /* buffer for mview */
+    Py_ssize_t itemsize;   /* len(item) */
+};
+
+static struct unpacker *
+unpacker_new(void)
+{
+    struct unpacker *x = PyMem_Malloc(sizeof *x);
+
+    if (x == NULL) {
+        PyErr_NoMemory();
+        return NULL;
+    }
+
+    x->unpack_from = NULL;
+    x->mview = NULL;
+    x->item = NULL;
+    x->itemsize = 0;
+
+    return x;
+}
+
+static void
+unpacker_free(struct unpacker *x)
+{
+    if (x) {
+        Py_XDECREF(x->unpack_from);
+        Py_XDECREF(x->mview);
+        PyMem_Free(x->item);
+        PyMem_Free(x);
+    }
+}
+
+/* Return a new unpacker for the given format, or NULL on failure. */
+static struct unpacker *
+struct_get_unpacker(const char *fmt, Py_ssize_t itemsize)
+{
+    PyObject *structmodule;     /* XXX cache these two */
+    PyObject *Struct = NULL;    /* XXX in globals?     */
+    PyObject *structobj = NULL;
+    PyObject *format = NULL;
+    struct unpacker *x = NULL;
+
+    structmodule = PyImport_ImportModule("struct");
+    if (structmodule == NULL)
+        return NULL;
+
+    Struct = PyObject_GetAttrString(structmodule, "Struct");
+    Py_DECREF(structmodule);
+    if (Struct == NULL)
+        return NULL;
+
+    x = unpacker_new();
+    if (x == NULL)
+        goto error;
+
+    format = PyBytes_FromString(fmt);
+    if (format == NULL)
+        goto error;
+
+    structobj = PyObject_CallFunctionObjArgs(Struct, format, NULL);
+    if (structobj == NULL)
+        goto error;
+
+    x->unpack_from = PyObject_GetAttrString(structobj, "unpack_from");
+    if (x->unpack_from == NULL)
+        goto error;
+
+    x->item = PyMem_Malloc(itemsize);
+    if (x->item == NULL) {
+        PyErr_NoMemory();
+        goto error;
+    }
+    x->itemsize = itemsize;
+
+    x->mview = PyMemoryView_FromMemory(x->item, itemsize, PyBUF_WRITE);
+    if (x->mview == NULL)
+        goto error;
+
+
+out:
+    Py_XDECREF(Struct);
+    Py_XDECREF(format);
+    Py_XDECREF(structobj);
+    return x;
+
+error:
+    unpacker_free(x);
+    x = NULL;
+    goto out;
+}
+
+/* Unpack a single item; a 1-tuple result is replaced by its only element. */
+static PyObject *
+struct_unpack_single(const char *ptr, struct unpacker *x)
+{
+    PyObject *v;
+
+    memcpy(x->item, ptr, x->itemsize);
+    v = PyObject_CallFunctionObjArgs(x->unpack_from, x->mview, NULL);
+    if (v == NULL)
+        return NULL;
+
+    if (PyTuple_GET_SIZE(v) == 1) {
+        PyObject *tmp = PyTuple_GET_ITEM(v, 0);
+        Py_INCREF(tmp);
+        Py_DECREF(v);
+        return tmp;
+    }
+
+    return v;
+}
+
+
+/****************************************************************************/
 /*                              Representations                             */
 /****************************************************************************/
 
@@ -2261,6 +2412,58 @@
 /*                             Comparisons                                */
 /**************************************************************************/
 
+#define MV_COMPARE_EX -1       /* exception */
+#define MV_COMPARE_NOT_IMPL -2 /* not implemented */
+
+/* Translate errors to "not equal", but preserve ImportError/MemoryError. */
+static int
+fix_struct_error_int(void)
+{
+    assert(PyErr_Occurred());
+    /* XXX Cannot get at StructError directly? */
+    if (PyErr_ExceptionMatches(PyExc_ImportError) ||
+        PyErr_ExceptionMatches(PyExc_MemoryError)) {
+        return MV_COMPARE_EX;
+    }
+    /* StructError: invalid or unknown format -> not equal */
+    PyErr_Clear();
+    return 0;
+}
+
+/* Unpack and compare single items of p and q using the struct module. */
+static int
+struct_unpack_cmp(const char *p, const char *q,
+                  struct unpacker *unpack_p, struct unpacker *unpack_q)
+{
+    PyObject *v, *w;
+    int ret;
+
+    /* At this point any exception from the struct module should not be
+       StructError, since both formats have been accepted already. */
+    v = struct_unpack_single(p, unpack_p);
+    if (v == NULL)
+        return MV_COMPARE_EX;
+
+    w = struct_unpack_single(q, unpack_q);
+    if (w == NULL) {
+        Py_DECREF(v);
+        return MV_COMPARE_EX;
+    }
+
+    /* MV_COMPARE_EX == -1: exceptions are preserved */
+    ret = PyObject_RichCompareBool(v, w, Py_EQ);
+    Py_DECREF(v);
+    Py_DECREF(w);
+
+    return ret;
+}
+
+/* Unpack and compare single items of p and q. If both p and q have the same
+   single element native format, the comparison uses a fast path (gcc creates
+   a jump table and converts memcpy into simple assignments on x86/x64).
+
+   Otherwise, the comparison is delegated to the struct module, which is
+   30-60x slower. */
 #define CMP_SINGLE(p, q, type) \
     do {                                 \
         type x;                          \
@@ -2271,11 +2474,12 @@
     } while (0)
 
 Py_LOCAL_INLINE(int)
-unpack_cmp(const char *p, const char *q, const char *fmt)
+unpack_cmp(const char *p, const char *q, char fmt,
+           struct unpacker *unpack_p, struct unpacker *unpack_q)
 {
     int equal;
 
-    switch (fmt[0]) {
+    switch (fmt) {
 
     /* signed integers and fast path for 'B' */
     case 'B': return *((unsigned char *)p) == *((unsigned char *)q);
@@ -2317,9 +2521,17 @@
     /* pointer */
     case 'P': CMP_SINGLE(p, q, void *); return equal;
 
-    /* Py_NotImplemented */
-    default: return -1;
+    /* use the struct module */
+    case '_':
+        assert(unpack_p);
+        assert(unpack_q);
+        return struct_unpack_cmp(p, q, unpack_p, unpack_q);
     }
+
+    /* NOT REACHED */
+    PyErr_SetString(PyExc_RuntimeError,
+        "memoryview: internal error in richcompare");
+    return MV_COMPARE_EX;
 }
 
 /* Base case for recursive array comparisons. Assumption: ndim == 1. */
@@ -2327,7 +2539,7 @@
 cmp_base(const char *p, const char *q, const Py_ssize_t *shape,
          const Py_ssize_t *pstrides, const Py_ssize_t *psuboffsets,
          const Py_ssize_t *qstrides, const Py_ssize_t *qsuboffsets,
-         const char *fmt)
+         char fmt, struct unpacker *unpack_p, struct unpacker *unpack_q)
 {
     Py_ssize_t i;
     int equal;
@@ -2335,7 +2547,7 @@
     for (i = 0; i < shape[0]; p+=pstrides[0], q+=qstrides[0], i++) {
         const char *xp = ADJUST_PTR(p, psuboffsets);
         const char *xq = ADJUST_PTR(q, qsuboffsets);
-        equal = unpack_cmp(xp, xq, fmt);
+        equal = unpack_cmp(xp, xq, fmt, unpack_p, unpack_q);
         if (equal <= 0)
             return equal;
     }
@@ -2350,7 +2562,7 @@
         Py_ssize_t ndim, const Py_ssize_t *shape,
         const Py_ssize_t *pstrides, const Py_ssize_t *psuboffsets,
         const Py_ssize_t *qstrides, const Py_ssize_t *qsuboffsets,
-        const char *fmt)
+        char fmt, struct unpacker *unpack_p, struct unpacker *unpack_q)
 {
     Py_ssize_t i;
     int equal;
@@ -2364,7 +2576,7 @@
         return cmp_base(p, q, shape,
                         pstrides, psuboffsets,
                         qstrides, qsuboffsets,
-                        fmt);
+                        fmt, unpack_p, unpack_q);
     }
 
     for (i = 0; i < shape[0]; p+=pstrides[0], q+=qstrides[0], i++) {
@@ -2373,7 +2585,7 @@
         equal = cmp_rec(xp, xq, ndim-1, shape+1,
                         pstrides+1, psuboffsets ? psuboffsets+1 : NULL,
                         qstrides+1, qsuboffsets ? qsuboffsets+1 : NULL,
-                        fmt);
+                        fmt, unpack_p, unpack_q);
         if (equal <= 0)
             return equal;
     }
@@ -2385,9 +2597,12 @@
 memory_richcompare(PyObject *v, PyObject *w, int op)
 {
     PyObject *res;
-    Py_buffer wbuf, *vv, *ww = NULL;
-    const char *vfmt, *wfmt;
-    int equal = -1; /* Py_NotImplemented */
+    Py_buffer wbuf, *vv;
+    Py_buffer *ww = NULL;
+    struct unpacker *unpack_v = NULL;
+    struct unpacker *unpack_w = NULL;
+    char vfmt, wfmt;
+    int equal = MV_COMPARE_NOT_IMPL;
 
     if (op != Py_EQ && op != Py_NE)
         goto result; /* Py_NotImplemented */
@@ -2414,38 +2629,59 @@
         ww = &wbuf;
     }
 
-    vfmt = adjust_fmt(vv);
-    wfmt = adjust_fmt(ww);
-    if (vfmt == NULL || wfmt == NULL) {
-        PyErr_Clear();
-        goto result; /* Py_NotImplemented */
-    }
-
-    if (cmp_structure(vv, ww) < 0) {
+    if (!equiv_shape(vv, ww)) {
         PyErr_Clear();
         equal = 0;
         goto result;
     }
 
+    /* Use fast unpacking for identical primitive C type formats. */
+    if (get_native_fmtchar(&vfmt, vv->format) < 0)
+        vfmt = '_';
+    if (get_native_fmtchar(&wfmt, ww->format) < 0)
+        wfmt = '_';
+    if (vfmt == '_' || wfmt == '_' || vfmt != wfmt) {
+        /* Use struct module unpacking. NOTE: Even for equal format strings,
+           memcmp() cannot be used for item comparison since it would give
+           incorrect results in the case of NaNs or uninitialized padding
+           bytes. */
+        vfmt = '_';
+        unpack_v = struct_get_unpacker(vv->format, vv->itemsize);
+        if (unpack_v == NULL) {
+            equal = fix_struct_error_int();
+            goto result;
+        }
+        unpack_w = struct_get_unpacker(ww->format, ww->itemsize);
+        if (unpack_w == NULL) {
+            equal = fix_struct_error_int();
+            goto result;
+        }
+    }
+
     if (vv->ndim == 0) {
-        equal = unpack_cmp(vv->buf, ww->buf, vfmt);
+        equal = unpack_cmp(vv->buf, ww->buf,
+                           vfmt, unpack_v, unpack_w);
     }
     else if (vv->ndim == 1) {
         equal = cmp_base(vv->buf, ww->buf, vv->shape,
                          vv->strides, vv->suboffsets,
                          ww->strides, ww->suboffsets,
-                         vfmt);
+                         vfmt, unpack_v, unpack_w);
     }
     else {
         equal = cmp_rec(vv->buf, ww->buf, vv->ndim, vv->shape,
                         vv->strides, vv->suboffsets,
                         ww->strides, ww->suboffsets,
-                        vfmt);
+                        vfmt, unpack_v, unpack_w);
     }
 
 result:
-    if (equal < 0)
-        res = Py_NotImplemented; 
+    if (equal < 0) {
+        if (equal == MV_COMPARE_NOT_IMPL)
+            res = Py_NotImplemented;
+        else /* exception */
+            res = NULL;
+    }
     else if ((equal && op == Py_EQ) || (!equal && op == Py_NE))
         res = Py_True;
     else
@@ -2453,7 +2689,11 @@
 
     if (ww == &wbuf)
         PyBuffer_Release(ww);
-    Py_INCREF(res);
+
+    unpacker_free(unpack_v);
+    unpacker_free(unpack_w);
+
+    Py_XINCREF(res);
     return res;
 }