bpo-43977: Use tp_flags for collection matching (GH-25723)
* Add Py_TPFLAGS_SEQUENCE and Py_TPFLAGS_MAPPING, and add them to all relevant standard builtin classes.
* Set relevant flags on collections.abc.Sequence and Mapping.
* Use flags in MATCH_SEQUENCE and MATCH_MAPPING opcodes.
* Make subclasses inherit Py_TPFLAGS_SEQUENCE and Py_TPFLAGS_MAPPING.
* Add NEWS entry.
* Remove interpreter-state map_abc and seq_abc fields.
diff --git a/Modules/_abc.c b/Modules/_abc.c
index 0ddc2ab..39261dd 100644
--- a/Modules/_abc.c
+++ b/Modules/_abc.c
@@ -15,6 +15,7 @@ PyDoc_STRVAR(_abc__doc__,
_Py_IDENTIFIER(__abstractmethods__);
_Py_IDENTIFIER(__class__);
_Py_IDENTIFIER(__dict__);
+_Py_IDENTIFIER(__abc_tpflags__);
_Py_IDENTIFIER(__bases__);
_Py_IDENTIFIER(_abc_impl);
_Py_IDENTIFIER(__subclasscheck__);
@@ -417,6 +418,8 @@ compute_abstract_methods(PyObject *self)
return ret;
}
+#define COLLECTION_FLAGS (Py_TPFLAGS_SEQUENCE | Py_TPFLAGS_MAPPING)
+
/*[clinic input]
_abc._abc_init
@@ -446,6 +449,31 @@ _abc__abc_init(PyObject *module, PyObject *self)
return NULL;
}
Py_DECREF(data);
+ /* If __abc_tpflags__ & COLLECTION_FLAGS is set, then set the corresponding bit(s)
+ * in the new class.
+ * Used by collections.abc.Sequence and collections.abc.Mapping to indicate
+ * their special status w.r.t. pattern matching. */
+ if (PyType_Check(self)) {
+ PyTypeObject *cls = (PyTypeObject *)self;
+ PyObject *flags = _PyDict_GetItemIdWithError(cls->tp_dict, &PyId___abc_tpflags__);
+ if (flags == NULL) {
+ if (PyErr_Occurred()) {
+ return NULL;
+ }
+ }
+ else {
+ if (PyLong_CheckExact(flags)) {
+ long val = PyLong_AsLong(flags);
+ if (val == -1 && PyErr_Occurred()) {
+ return NULL;
+ }
+ ((PyTypeObject *)self)->tp_flags |= (val & COLLECTION_FLAGS);
+ }
+ if (_PyDict_DelItemId(cls->tp_dict, &PyId___abc_tpflags__) < 0) {
+ return NULL;
+ }
+ }
+ }
Py_RETURN_NONE;
}
@@ -499,6 +527,11 @@ _abc__abc_register_impl(PyObject *module, PyObject *self, PyObject *subclass)
/* Invalidate negative cache */
get_abc_state(module)->abc_invalidation_counter++;
+ if (PyType_Check(subclass) && PyType_Check(self) &&
+ !PyType_HasFeature((PyTypeObject *)subclass, Py_TPFLAGS_IMMUTABLETYPE))
+ {
+ ((PyTypeObject *)subclass)->tp_flags |= (((PyTypeObject *)self)->tp_flags & COLLECTION_FLAGS);
+ }
Py_INCREF(subclass);
return subclass;
}
diff --git a/Modules/_collectionsmodule.c b/Modules/_collectionsmodule.c
index 8b01a7f..9c8701a 100644
--- a/Modules/_collectionsmodule.c
+++ b/Modules/_collectionsmodule.c
@@ -1662,7 +1662,8 @@ static PyTypeObject deque_type = {
PyObject_GenericGetAttr, /* tp_getattro */
0, /* tp_setattro */
0, /* tp_as_buffer */
- Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC,
+ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE |
+ Py_TPFLAGS_HAVE_GC | Py_TPFLAGS_SEQUENCE,
/* tp_flags */
deque_doc, /* tp_doc */
(traverseproc)deque_traverse, /* tp_traverse */
diff --git a/Modules/arraymodule.c b/Modules/arraymodule.c
index 367621f..d65c144 100644
--- a/Modules/arraymodule.c
+++ b/Modules/arraymodule.c
@@ -2848,7 +2848,8 @@ static PyType_Spec array_spec = {
.name = "array.array",
.basicsize = sizeof(arrayobject),
.flags = (Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE |
- Py_TPFLAGS_IMMUTABLETYPE),
+ Py_TPFLAGS_IMMUTABLETYPE |
+ Py_TPFLAGS_SEQUENCE),
.slots = array_slots,
};