bpo-42128: Structural Pattern Matching (PEP 634) (GH-22917)

Co-authored-by: Guido van Rossum <guido@python.org>
Co-authored-by: Talin <viridia@gmail.com>
Co-authored-by: Pablo Galindo <pablogsal@gmail.com>
diff --git a/Python/ceval.c b/Python/ceval.c
index e1a8f15..7862643 100644
--- a/Python/ceval.c
+++ b/Python/ceval.c
@@ -880,6 +880,230 @@ _Py_CheckRecursiveCall(PyThreadState *tstate, const char *where)
     return 0;
 }
 
+
+// PEP 634: Structural Pattern Matching
+
+
+// Return a new tuple of map's values for keys (in order), or a new reference
+// to None if any key is missing. Return NULL with an error set on failure,
+// which includes duplicate keys.
+static PyObject*
+match_keys(PyThreadState *tstate, PyObject *map, PyObject *keys)
+{
+    assert(PyTuple_CheckExact(keys));
+    Py_ssize_t nkeys = PyTuple_GET_SIZE(keys);
+    if (!nkeys) {
+        // No keys means no items.
+        return PyTuple_New(0);
+    }
+    PyObject *seen = NULL;
+    PyObject *dummy = NULL;
+    PyObject *values = NULL;
+    // We use the two argument form of map.get(key, default) for two reasons:
+    // - Atomically check for a key and get its value without error handling.
+    // - Don't cause key creation or resizing in dict subclasses like
+    //   collections.defaultdict that define __missing__ (or similar).
+    _Py_IDENTIFIER(get);
+    PyObject *get = _PyObject_GetAttrId(map, &PyId_get);
+    if (get == NULL) {
+        goto fail;
+    }
+    seen = PySet_New(NULL);
+    if (seen == NULL) {
+        goto fail;
+    }
+    dummy = _PyObject_CallNoArg((PyObject *)&PyBaseObject_Type);  // object()
+    if (dummy == NULL) {
+        goto fail;
+    }
+    values = PyList_New(0);
+    if (values == NULL) {
+        goto fail;
+    }
+    for (Py_ssize_t i = 0; i < nkeys; i++) {
+        PyObject *key = PyTuple_GET_ITEM(keys, i);
+        if (PySet_Contains(seen, key) || PySet_Add(seen, key)) {
+            if (!_PyErr_Occurred(tstate)) {
+                // Seen it before!
+                _PyErr_Format(tstate, PyExc_ValueError,
+                              "mapping pattern checks duplicate key (%R)", key);
+            }
+            goto fail;
+        }
+        PyObject *value = PyObject_CallFunctionObjArgs(get, key, dummy, NULL);
+        if (value == NULL) {
+            goto fail;
+        }
+        if (value == dummy) {
+            // Key not in map: the match fails, so hand back None instead.
+            Py_DECREF(value);
+            Py_INCREF(Py_None);
+            Py_SETREF(values, Py_None);
+            goto done;
+        }
+        int err = PyList_Append(values, value);
+        Py_DECREF(value);
+        if (err) {
+            goto fail;
+        }
+    }
+    Py_SETREF(values, PyList_AsTuple(values));  // NULL (error set) on failure
+done:
+    Py_DECREF(get);
+    Py_DECREF(seen);
+    Py_DECREF(dummy);
+    return values;
+fail:
+    Py_XDECREF(get);
+    Py_XDECREF(seen);
+    Py_XDECREF(dummy);
+    Py_XDECREF(values);
+    return NULL;
+}
+
+// Fetch attribute `name` from subject on behalf of a class pattern, recording
+// it in `seen` so that a repeated name raises TypeError. NULL is returned both
+// on error and for a missing attribute; check _PyErr_Occurred(tstate).
+static PyObject*
+match_class_attr(PyThreadState *tstate, PyObject *subject, PyObject *type,
+                 PyObject *name, PyObject *seen)
+{
+    assert(PyUnicode_CheckExact(name));
+    assert(PySet_CheckExact(seen));
+    if (PySet_Contains(seen, name) == 0 && PySet_Add(seen, name) == 0) {
+        PyObject *found = PyObject_GetAttr(subject, name);
+        if (found == NULL &&
+            _PyErr_ExceptionMatches(tstate, PyExc_AttributeError)) {
+            _PyErr_Clear(tstate);  // Missing attribute: no match, no error.
+        }
+        return found;
+    }
+    if (!_PyErr_Occurred(tstate)) {
+        _PyErr_Format(tstate, PyExc_TypeError,  // Duplicate name.
+                      "%s() got multiple sub-patterns for attribute %R",
+                      ((PyTypeObject*)type)->tp_name, name);
+    }
+    return NULL;
+}
+
+// Match subject against a class pattern with nargs positional sub-patterns
+// and one keyword sub-pattern per name in kwargs (a tuple). Return a new
+// tuple of the extracted attributes on a match, or NULL on failure (no match
+// or an error). Use _PyErr_Occurred(tstate) to disambiguate.
+static PyObject*
+match_class(PyThreadState *tstate, PyObject *subject, PyObject *type,
+            Py_ssize_t nargs, PyObject *kwargs)
+{
+    if (!PyType_Check(type)) {
+        _PyErr_Format(tstate, PyExc_TypeError,
+                      "called match pattern must be a type");
+        return NULL;
+    }
+    assert(PyTuple_CheckExact(kwargs));
+    // First, an isinstance check (on error, an exception is left set):
+    if (PyObject_IsInstance(subject, type) <= 0) {
+        return NULL;
+    }
+    PyObject *seen = PySet_New(NULL);
+    if (seen == NULL) {
+        return NULL;
+    }
+    PyObject *attrs = PyList_New(0);
+    if (attrs == NULL) {
+        Py_DECREF(seen);
+        return NULL;
+    }
+    // NOTE: From this point on, goto fail on failure:
+    PyObject *match_args = NULL;
+    PyObject *attr = NULL;  // Owned between lookup and append; freed at fail.
+    // First, the positional subpatterns:
+    if (nargs) {
+        int match_self = 0;
+        match_args = PyObject_GetAttrString(type, "__match_args__");
+        if (match_args) {
+            if (PyList_CheckExact(match_args)) {
+                Py_SETREF(match_args, PyList_AsTuple(match_args));
+            }
+        }
+        else if (_PyErr_ExceptionMatches(tstate, PyExc_AttributeError)) {
+            _PyErr_Clear(tstate);
+            // _Py_TPFLAGS_MATCH_SELF is only acknowledged if the type does not
+            // define __match_args__. This is natural behavior for subclasses:
+            // it's as if __match_args__ is some "magic" value that is lost as
+            // soon as they redefine it.
+            match_args = PyTuple_New(0);
+            match_self = PyType_HasFeature((PyTypeObject*)type,
+                                            _Py_TPFLAGS_MATCH_SELF);
+        }
+        // NULL here means the lookup, the list-to-tuple conversion, or the
+        // empty-tuple allocation failed; an exception is set in every case.
+        if (match_args == NULL) {
+            goto fail;
+        }
+        if (!PyTuple_CheckExact(match_args)) {
+            _PyErr_Format(tstate, PyExc_TypeError,
+                          "%s.__match_args__ must be a list or tuple "
+                          "(got %s)", ((PyTypeObject *)type)->tp_name,
+                          Py_TYPE(match_args)->tp_name);
+            goto fail;
+        }
+        Py_ssize_t allowed = match_self ? 1 : PyTuple_GET_SIZE(match_args);
+        if (allowed < nargs) {
+            const char *plural = (allowed == 1) ? "" : "s";
+            // %zd (not %d): allowed and nargs are both Py_ssize_t.
+            _PyErr_Format(tstate, PyExc_TypeError,
+                          "%s() accepts %zd positional sub-pattern%s "
+                          "(%zd given)", ((PyTypeObject*)type)->tp_name,
+                          allowed, plural, nargs);
+            goto fail;
+        }
+        if (match_self) {
+            // Easy. Copy the subject itself, and move on to kwargs.
+            if (PyList_Append(attrs, subject)) {
+                goto fail;
+            }
+        }
+        else {
+            for (Py_ssize_t i = 0; i < nargs; i++) {
+                PyObject *name = PyTuple_GET_ITEM(match_args, i);
+                if (!PyUnicode_CheckExact(name)) {
+                    _PyErr_Format(tstate, PyExc_TypeError,
+                                  "__match_args__ elements must be strings "
+                                  "(got %s)", Py_TYPE(name)->tp_name);
+                    goto fail;
+                }
+                attr = match_class_attr(tstate, subject, type, name, seen);
+                if (attr == NULL || PyList_Append(attrs, attr)) {
+                    goto fail;
+                }
+                Py_CLEAR(attr);
+            }
+        }
+        Py_CLEAR(match_args);
+    }
+    // Finally, the keyword subpatterns:
+    for (Py_ssize_t i = 0; i < PyTuple_GET_SIZE(kwargs); i++) {
+        PyObject *name = PyTuple_GET_ITEM(kwargs, i);
+        attr = match_class_attr(tstate, subject, type, name, seen);
+        if (attr == NULL || PyList_Append(attrs, attr)) {
+            goto fail;
+        }
+        Py_CLEAR(attr);
+    }
+    Py_SETREF(attrs, PyList_AsTuple(attrs));
+    Py_DECREF(seen);
+    return attrs;
+fail:
+    // We really don't care whether an error was raised or not... that's our
+    // caller's problem. All we know is that the match failed.
+    Py_XDECREF(match_args);
+    Py_XDECREF(attr);
+    Py_DECREF(seen);
+    Py_DECREF(attrs);
+    return NULL;
+}
+
+
 static int do_raise(PyThreadState *tstate, PyObject *exc, PyObject *cause);
 static int unpack_iterable(PyThreadState *, PyObject *, int, int, PyObject **);
 
@@ -3618,6 +3842,164 @@ _PyEval_EvalFrameDefault(PyThreadState *tstate, PyFrameObject *f, int throwflag)
 #endif
         }
 
+        case TARGET(GET_LEN): {
+            // PUSH(len(TOS)); TOS itself is left on the stack underneath.
+            PyObject *obj = TOP();
+            Py_ssize_t size = PyObject_Length(obj);
+            // A negative size means len() raised; otherwise box it as an int.
+            PyObject *boxed = (size < 0) ? NULL : PyLong_FromSsize_t(size);
+            if (boxed == NULL) {
+                // Either way an exception is pending.
+                goto error;
+            }
+            PUSH(boxed);
+            DISPATCH();
+        }
+
+        case TARGET(MATCH_CLASS): {
+            // Pop TOS (keyword names). On match: TOS1 (subject) becomes the
+            // attribute tuple and TOS (type) becomes True. Else TOS = False.
+            PyObject *names = POP();  // kwarg names (oparg = #positional)
+            PyObject *type = TOP();
+            PyObject *subject = SECOND();
+            assert(PyTuple_CheckExact(names));
+            PyObject *attrs = match_class(tstate, subject, type, oparg, names);
+            Py_DECREF(names);  // Popped above, so this reference is ours.
+            if (attrs) {
+                // Success! The tuple takes over the subject's stack slot.
+                assert(PyTuple_CheckExact(attrs));
+                Py_DECREF(subject);
+                SET_SECOND(attrs);
+            }
+            else if (_PyErr_Occurred(tstate)) {
+                goto error;  // match_class() raised, as opposed to "no match".
+            }
+            Py_DECREF(type);  // Its slot is overwritten by the bool below.
+            SET_TOP(PyBool_FromLong(!!attrs));
+            DISPATCH();
+        }
+
+        case TARGET(MATCH_MAPPING): {
+            // PUSH(isinstance(TOS, _collections_abc.Mapping))
+            PyObject *subject = TOP();
+            if (PyDict_Check(subject)) {  // Fast path for dicts.
+                Py_INCREF(Py_True);
+                PUSH(Py_True);
+                DISPATCH();
+            }
+            // Lazily import _collections_abc.Mapping, and keep it handy on the
+            // PyInterpreterState struct (it gets cleaned up at exit):
+            PyInterpreterState *interp = PyInterpreterState_Get();
+            if (interp->map_abc == NULL) {
+                PyObject *abc = PyImport_ImportModule("_collections_abc");
+                if (abc == NULL) {
+                    goto error;
+                }
+                interp->map_abc = PyObject_GetAttrString(abc, "Mapping");
+                Py_DECREF(abc);  // Keep the class only, not the module.
+                if (interp->map_abc == NULL) {
+                    goto error;
+                }
+            }
+            int match = PyObject_IsInstance(subject, interp->map_abc);
+            if (match < 0) {
+                goto error;
+            }
+            PUSH(PyBool_FromLong(match));
+            DISPATCH();
+        }
+
+        case TARGET(MATCH_SEQUENCE): {
+            // PUSH(not isinstance(TOS, (bytearray, bytes, str))
+            //      and isinstance(TOS, _collections_abc.Sequence))
+            PyObject *subject = TOP();
+            if (PyType_FastSubclass(Py_TYPE(subject),
+                                    Py_TPFLAGS_LIST_SUBCLASS |
+                                    Py_TPFLAGS_TUPLE_SUBCLASS))
+            {  // Fast path for lists and tuples.
+                Py_INCREF(Py_True);
+                PUSH(Py_True);
+                DISPATCH();
+            }
+            // Bail on some possible Sequences that we intentionally exclude:
+            if (PyType_FastSubclass(Py_TYPE(subject),
+                                    Py_TPFLAGS_BYTES_SUBCLASS |
+                                    Py_TPFLAGS_UNICODE_SUBCLASS) ||
+                PyByteArray_Check(subject))
+            {
+                Py_INCREF(Py_False);
+                PUSH(Py_False);
+                DISPATCH();
+            }
+            // Lazily import _collections_abc.Sequence, and keep it handy on the
+            // PyInterpreterState struct (it gets cleaned up at exit):
+            PyInterpreterState *interp = PyInterpreterState_Get();
+            if (interp->seq_abc == NULL) {
+                PyObject *abc = PyImport_ImportModule("_collections_abc");
+                if (abc == NULL) {
+                    goto error;
+                }
+                interp->seq_abc = PyObject_GetAttrString(abc, "Sequence");
+                Py_DECREF(abc);  // Keep the class only, not the module.
+                if (interp->seq_abc == NULL) {
+                    goto error;
+                }
+            }
+            int match = PyObject_IsInstance(subject, interp->seq_abc);
+            if (match < 0) {
+                goto error;
+            }
+            PUSH(PyBool_FromLong(match));
+            DISPATCH();
+        }
+
+        case TARGET(MATCH_KEYS): {
+            // Neither operand is consumed: subject (TOS1) and keys (TOS) stay
+            // put and two results are pushed on top of them. If every key is
+            // present, PUSH(values) and PUSH(True); if not, PUSH(None) and
+            // PUSH(False).
+            PyObject *wanted = TOP();
+            PyObject *mapping = SECOND();
+            PyObject *found = match_keys(tstate, mapping, wanted);
+            if (found == NULL) {
+                goto error;
+            }
+            // match_keys() yields None for a miss, else an exact tuple.
+            int matched = (found != Py_None);
+            assert(!matched || PyTuple_CheckExact(found));
+            PUSH(found);
+            PyObject *flag = matched ? Py_True : Py_False;
+            Py_INCREF(flag);
+            PUSH(flag);
+            DISPATCH();
+        }
+
+        case TARGET(COPY_DICT_WITHOUT_KEYS): {
+            // Replace TOS (a tuple of keys) with a new dict holding every item
+            // of TOS1 except those keys; TOS1 itself stays on the stack.
+            PyObject *source = SECOND();
+            PyObject *drop = TOP();
+            PyObject *copy = PyDict_New();
+            if (copy == NULL) {
+                goto error;
+            }
+            // Copy everything, then delete the unwanted keys one at a time.
+            // It looks wasteful, but keys is rarely big enough to matter.
+            int rc = PyDict_Update(copy, source);
+            assert(rc || PyTuple_CheckExact(drop));
+            for (Py_ssize_t i = 0; !rc && i < PyTuple_GET_SIZE(drop); i++) {
+                rc = PyDict_DelItem(copy, PyTuple_GET_ITEM(drop, i));
+            }
+            // On any failure the partial copy is discarded; stack untouched.
+            if (rc) {
+                Py_DECREF(copy);
+                goto error;
+            }
+            Py_DECREF(drop);
+            SET_TOP(copy);
+            DISPATCH();
+        }
+
         case TARGET(GET_ITER): {
             /* before: [obj]; after [getiter(obj)] */
             PyObject *iterable = TOP();