bpo-41428: Implementation for PEP 604 (GH-21515)

See https://www.python.org/dev/peps/pep-0604/ for more information.

Co-authored-by: Pablo Galindo <pablogsal@gmail.com>
diff --git a/Objects/abstract.c b/Objects/abstract.c
index 7bd72c9..c471f18 100644
--- a/Objects/abstract.c
+++ b/Objects/abstract.c
@@ -1,6 +1,7 @@
 /* Abstract Object Interface (many thanks to Jim Fulton) */
 
 #include "Python.h"
+#include "pycore_unionobject.h"      // _Py_UnionType && _Py_Union()
 #include "pycore_abstract.h"      // _PyIndex_Check()
 #include "pycore_ceval.h"         // _Py_EnterRecursiveCall()
 #include "pycore_pyerrors.h"      // _PyErr_Occurred()
@@ -839,7 +840,6 @@
                 Py_TYPE(w)->tp_name);
             return NULL;
         }
-
         return binop_type_error(v, w, op_name);
     }
     return result;
@@ -2412,7 +2412,6 @@
     PyObject *icls;
     int retval;
     _Py_IDENTIFIER(__class__);
-
     if (PyType_Check(cls)) {
         retval = PyObject_TypeCheck(inst, (PyTypeObject *)cls);
         if (retval == 0) {
@@ -2432,7 +2431,7 @@
     }
     else {
         if (!check_class(cls,
-            "isinstance() arg 2 must be a type or tuple of types"))
+            "isinstance() arg 2 must be a type, a tuple of types or a union"))
             return -1;
         retval = _PyObject_LookupAttrId(inst, &PyId___class__, &icls);
         if (icls != NULL) {
@@ -2525,10 +2524,14 @@
     if (!check_class(derived,
                      "issubclass() arg 1 must be a class"))
         return -1;
-    if (!check_class(cls,
-                    "issubclass() arg 2 must be a class"
-                    " or tuple of classes"))
+
+    // A types.Union is accepted as arg 2 without the check_class() test.
+    int is_union = (Py_TYPE(cls) == &_Py_UnionType);
+    if (!is_union && !check_class(cls,
+                            "issubclass() arg 2 must be a class,"
+                            " a tuple of classes, or a union")) {
         return -1;
+    }
 
     return abstract_issubclass(derived, cls);
 }
diff --git a/Objects/typeobject.c b/Objects/typeobject.c
index 7404075..3bb2c33 100644
--- a/Objects/typeobject.c
+++ b/Objects/typeobject.c
@@ -6,6 +6,7 @@
 #include "pycore_object.h"
 #include "pycore_pyerrors.h"
 #include "pycore_pystate.h"       // _PyThreadState_GET()
+#include "pycore_unionobject.h"   // _Py_Union()
 #include "frameobject.h"
 #include "structmember.h"         // PyMemberDef
 
@@ -3753,6 +3754,26 @@
     return type->tp_flags & Py_TPFLAGS_HEAPTYPE;
 }
 
+/* nb_or slot of type: `self | param` builds a PEP 604 union via
+   _Py_Union(), which returns NotImplemented if an operand cannot be a
+   union member.  The slot receives (left, right) even when found on the
+   right operand, e.g. None | int, so self is not necessarily a type. */
+static PyObject *
+type_or(PyObject *self, PyObject *param)
+{
+    PyObject *tuple = PyTuple_Pack(2, self, param);
+    if (tuple == NULL) {
+        return NULL;
+    }
+    PyObject *new_union = _Py_Union(tuple);
+    Py_DECREF(tuple);
+    return new_union;
+}
+
+static PyNumberMethods type_as_number = {
+    .nb_or = type_or,  // __or__ / __ror__
+};
+
 PyTypeObject PyType_Type = {
     PyVarObject_HEAD_INIT(&PyType_Type, 0)
     "type",                                     /* tp_name */
@@ -3764,7 +3780,7 @@
     0,                                          /* tp_setattr */
     0,                                          /* tp_as_async */
     (reprfunc)type_repr,                        /* tp_repr */
-    0,                                          /* tp_as_number */
+    &type_as_number,                            /* tp_as_number */
     0,                                          /* tp_as_sequence */
     0,                                          /* tp_as_mapping */
     0,                                          /* tp_hash */
@@ -5598,7 +5614,6 @@
             add_subclass((PyTypeObject *)b, type) < 0)
             goto error;
     }
-
     /* All done -- set the ready flag */
     type->tp_flags =
         (type->tp_flags & ~Py_TPFLAGS_READYING) | Py_TPFLAGS_READY;
diff --git a/Objects/unionobject.c b/Objects/unionobject.c
new file mode 100644
index 0000000..0ef7abb
--- /dev/null
+++ b/Objects/unionobject.c
@@ -0,0 +1,557 @@
+// types.Union -- used to represent e.g. Union[int, str], int | str
+#include "Python.h"
+#include "pycore_unionobject.h"
+#include "structmember.h"
+
+
+typedef struct {
+    PyObject_HEAD
+    PyObject *args;   // tuple of members, flattened and de-duplicated
+} unionobject;
+
+static void
+unionobject_dealloc(PyObject *self)
+{
+    unionobject *alias = (unionobject *)self;
+
+    PyObject_GC_UnTrack(self);
+    Py_XDECREF(alias->args);
+    Py_TYPE(self)->tp_free(self);
+}
+
+/* GC support: a union can be part of a reference cycle, e.g. a class
+   attribute holding a union that names the class itself (A.u = int | A). */
+static int
+union_traverse(PyObject *self, visitproc visit, void *arg)
+{
+    unionobject *alias = (unionobject *)self;
+    Py_VISIT(alias->args);
+    return 0;
+}
+
+/* Hash the members as an unordered set: union_richcompare() treats
+   int | str and str | int as equal, and equal objects must hash equal.
+   None is hashed as NoneType, the form typing.Union stores it in (see the
+   NoneType -> None mapping in union_richcompare()). */
+static Py_hash_t
+union_hash(PyObject *self)
+{
+    unionobject *alias = (unionobject *)self;
+    PyObject *members = PyFrozenSet_New(NULL);
+    if (members == NULL) {
+        return -1;
+    }
+    Py_ssize_t nargs = PyTuple_GET_SIZE(alias->args);
+    for (Py_ssize_t iarg = 0; iarg < nargs; iarg++) {
+        PyObject *arg = PyTuple_GET_ITEM(alias->args, iarg);
+        if (arg == Py_None) {
+            arg = (PyObject *)&_PyNone_Type;
+        }
+        if (PySet_Add(members, arg) < 0) {
+            Py_DECREF(members);
+            return -1;
+        }
+    }
+    Py_hash_t hash = PyObject_Hash(members);
+    Py_DECREF(members);
+    return hash;
+}
+
+/* Return 1 if the tuple args contains a parameterized generic (an instance
+   of types.GenericAlias, e.g. list[int]), else 0. */
+static int
+contains_generic_alias(PyObject *args)
+{
+    Py_ssize_t nargs = PyTuple_GET_SIZE(args);
+    for (Py_ssize_t iarg = 0; iarg < nargs; iarg++) {
+        if (Py_TYPE(PyTuple_GET_ITEM(args, iarg)) == &Py_GenericAliasType) {
+            return 1;
+        }
+    }
+    return 0;
+}
+
+/* __instancecheck__: true if instance is an instance of any member that is
+   a class; a None member stands for NoneType. */
+static PyObject *
+union_instancecheck(PyObject *self, PyObject *instance)
+{
+    unionobject *alias = (unionobject *)self;
+    if (contains_generic_alias(alias->args)) {
+        PyErr_SetString(PyExc_TypeError,
+            "isinstance() argument 2 cannot contain a parameterized generic");
+        return NULL;
+    }
+    Py_ssize_t nargs = PyTuple_GET_SIZE(alias->args);
+    for (Py_ssize_t iarg = 0; iarg < nargs; iarg++) {
+        PyObject *arg = PyTuple_GET_ITEM(alias->args, iarg);
+        if (arg == Py_None) {
+            arg = (PyObject *)&_PyNone_Type;
+        }
+        if (!PyType_Check(arg)) {
+            continue;
+        }
+        int res = PyObject_IsInstance(instance, arg);
+        if (res < 0) {
+            return NULL;  // e.g. an exception from a custom __instancecheck__
+        }
+        if (res) {
+            Py_RETURN_TRUE;
+        }
+    }
+    Py_RETURN_FALSE;
+}
+
+/* __subclasscheck__: true if the class instance is a subtype of any member
+   that is a class; a None member stands for NoneType. */
+static PyObject *
+union_subclasscheck(PyObject *self, PyObject *instance)
+{
+    if (!PyType_Check(instance)) {
+        PyErr_SetString(PyExc_TypeError, "issubclass() arg 1 must be a class");
+        return NULL;
+    }
+    unionobject *alias = (unionobject *)self;
+    if (contains_generic_alias(alias->args)) {
+        PyErr_SetString(PyExc_TypeError,
+            "issubclass() argument 2 cannot contain a parameterized generic");
+        return NULL;
+    }
+    Py_ssize_t nargs = PyTuple_GET_SIZE(alias->args);
+    for (Py_ssize_t iarg = 0; iarg < nargs; iarg++) {
+        PyObject *arg = PyTuple_GET_ITEM(alias->args, iarg);
+        if (arg == Py_None) {
+            arg = (PyObject *)&_PyNone_Type;
+        }
+        if (PyType_Check(arg) &&
+            PyType_IsSubtype((PyTypeObject *)instance, (PyTypeObject *)arg)) {
+            Py_RETURN_TRUE;
+        }
+    }
+    Py_RETURN_FALSE;
+}
+
+/* Return 1 if obj.__module__ is the str "typing", 0 if it is anything else
+   or is missing, -1 with an exception set if the lookup fails. */
+static int
+is_typing_module(PyObject *obj)
+{
+    _Py_IDENTIFIER(__module__);
+    PyObject *module;
+    if (_PyObject_LookupAttrId(obj, &PyId___module__, &module) < 0) {
+        return -1;
+    }
+    int is_typing = (module != NULL &&
+                     PyUnicode_Check(module) &&
+                     _PyUnicode_EqualToASCIIString(module, "typing"));
+    Py_XDECREF(module);
+    return is_typing;
+}
+
+/* Return 1 if obj is an instance of the class typing.<name>, 0 if not, -1
+   on error.  The class's __module__ is checked, not the instance's: an
+   instance may carry its own __module__ (a typing.TypeVar records the
+   module that created it). */
+static int
+is_typing_name(PyObject *obj, const char *name)
+{
+    PyTypeObject *type = Py_TYPE(obj);
+    if (strcmp(type->tp_name, name) != 0) {
+        return 0;
+    }
+    return is_typing_module((PyObject *)type);
+}
+
+/* ==/!= compares the sets of members.  The right operand may be another
+   types.Union, a typing.Union (whose NoneType members are mapped to None,
+   matching how this type stores None), or any other object, which is
+   treated as a one-member union (so int | int == int). */
+static PyObject *
+union_richcompare(PyObject *a, PyObject *b, int op)
+{
+    if (op != Py_EQ && op != Py_NE) {
+        Py_RETURN_NOTIMPLEMENTED;
+    }
+
+    PyObject *result = NULL;
+    PyObject *b_set = NULL;
+    PyObject *a_set = PySet_New(((unionobject *)a)->args);
+    if (a_set == NULL) {
+        return NULL;
+    }
+    b_set = PySet_New(NULL);
+    if (b_set == NULL) {
+        goto exit;
+    }
+
+    // Populate b_set with the members of the right-hand operand.
+    int is_typing_union = is_typing_name(b, "_UnionGenericAlias");
+    if (is_typing_union < 0) {
+        goto exit;
+    }
+    if (is_typing_union) {
+        PyObject *b_args = PyObject_GetAttrString(b, "__args__");
+        if (b_args == NULL) {
+            goto exit;
+        }
+        if (!PyTuple_CheckExact(b_args)) {
+            Py_DECREF(b_args);
+            PyErr_SetString(PyExc_TypeError, "__args__ argument of typing.Union object is not a tuple");
+            goto exit;
+        }
+        Py_ssize_t b_arg_length = PyTuple_GET_SIZE(b_args);
+        for (Py_ssize_t i = 0; i < b_arg_length; i++) {
+            PyObject *arg = PyTuple_GET_ITEM(b_args, i);
+            if (arg == (PyObject *)&_PyNone_Type) {
+                arg = Py_None;
+            }
+            if (PySet_Add(b_set, arg) < 0) {
+                Py_DECREF(b_args);
+                goto exit;
+            }
+        }
+        Py_DECREF(b_args);
+    }
+    else if (Py_TYPE(b) == &_Py_UnionType) {
+        PyObject *args = ((unionobject *)b)->args;
+        Py_ssize_t arg_length = PyTuple_GET_SIZE(args);
+        for (Py_ssize_t i = 0; i < arg_length; i++) {
+            if (PySet_Add(b_set, PyTuple_GET_ITEM(args, i)) < 0) {
+                goto exit;
+            }
+        }
+    }
+    else if (PySet_Add(b_set, b) < 0) {
+        if (!PyErr_ExceptionMatches(PyExc_TypeError)) {
+            goto exit;
+        }
+        // An unhashable object can never be a union member: return
+        // NotImplemented so == falls back to the default comparison
+        // instead of raising "unhashable type".
+        PyErr_Clear();
+        result = Py_NotImplemented;
+        Py_INCREF(result);
+        goto exit;
+    }
+    result = PyObject_RichCompare(a_set, b_set, op);
+exit:
+    Py_XDECREF(a_set);
+    Py_XDECREF(b_set);
+    return result;
+}
+
+/* Return a new tuple in which every types.Union in args is replaced by its
+   members.  One level is enough: a union's own args are already flat. */
+static PyObject *
+flatten_args(PyObject *args)
+{
+    Py_ssize_t arg_length = PyTuple_GET_SIZE(args);
+    Py_ssize_t total_args = 0;
+    // Get number of total args once it's flattened.
+    for (Py_ssize_t i = 0; i < arg_length; i++) {
+        PyObject *arg = PyTuple_GET_ITEM(args, i);
+        if (Py_TYPE(arg) == &_Py_UnionType) {
+            total_args += PyTuple_GET_SIZE(((unionobject *)arg)->args);
+        }
+        else {
+            total_args++;
+        }
+    }
+    // Create new tuple of flattened args.
+    PyObject *flattened_args = PyTuple_New(total_args);
+    if (flattened_args == NULL) {
+        return NULL;
+    }
+    Py_ssize_t pos = 0;
+    for (Py_ssize_t i = 0; i < arg_length; i++) {
+        PyObject *arg = PyTuple_GET_ITEM(args, i);
+        if (Py_TYPE(arg) == &_Py_UnionType) {
+            PyObject *nested_args = ((unionobject *)arg)->args;
+            Py_ssize_t nested_arg_length = PyTuple_GET_SIZE(nested_args);
+            for (Py_ssize_t j = 0; j < nested_arg_length; j++) {
+                PyObject *nested_arg = PyTuple_GET_ITEM(nested_args, j);
+                Py_INCREF(nested_arg);
+                PyTuple_SET_ITEM(flattened_args, pos, nested_arg);
+                pos++;
+            }
+        }
+        else {
+            Py_INCREF(arg);
+            PyTuple_SET_ITEM(flattened_args, pos, arg);
+            pos++;
+        }
+    }
+    return flattened_args;
+}
+
+/* Flatten args and drop duplicate members (compared by identity), keeping
+   the first occurrence of each so member order follows the source order:
+   int | str | int gives (int, str). */
+static PyObject *
+dedup_and_flatten_args(PyObject *args)
+{
+    PyObject *flat = flatten_args(args);
+    if (flat == NULL) {
+        return NULL;
+    }
+    Py_ssize_t arg_length = PyTuple_GET_SIZE(flat);
+    PyObject *new_args = PyTuple_New(arg_length);
+    if (new_args == NULL) {
+        Py_DECREF(flat);
+        return NULL;
+    }
+    Py_ssize_t added_items = 0;
+    for (Py_ssize_t i = 0; i < arg_length; i++) {
+        PyObject *candidate = PyTuple_GET_ITEM(flat, i);
+        int is_duplicate = 0;
+        for (Py_ssize_t j = 0; j < added_items; j++) {
+            if (PyTuple_GET_ITEM(new_args, j) == candidate) {
+                is_duplicate = 1;
+                break;
+            }
+        }
+        if (!is_duplicate) {
+            Py_INCREF(candidate);
+            PyTuple_SET_ITEM(new_args, added_items, candidate);
+            added_items++;
+        }
+    }
+    Py_DECREF(flat);
+    // Shrink to the number of unique members; on failure new_args is
+    // released and set to NULL with an exception set.
+    if (_PyTuple_Resize(&new_args, added_items) < 0) {
+        return NULL;
+    }
+    return new_args;
+}
+
+/* Return 1 if obj is a typing.TypeVar, 0 if not, -1 on error. */
+static int
+is_typevar(PyObject *obj)
+{
+    return is_typing_name(obj, "TypeVar");
+}
+
+/* Return 1 if obj is a typing._SpecialForm, 0 if not, -1 on error. */
+static int
+is_special_form(PyObject *obj)
+{
+    return is_typing_name(obj, "_SpecialForm");
+}
+
+/* Return 1 if obj is a plain function whose __module__ is "typing" (the
+   test used to recognise typing.NewType() results), 0 if not, -1 on error. */
+static int
+is_new_type(PyObject *obj)
+{
+    if (Py_TYPE(obj) != &PyFunction_Type) {
+        return 0;
+    }
+    return is_typing_module(obj);
+}
+
+/* Return 1 if obj may be a member of a PEP 604 union, 0 if not, -1 with an
+   exception set on error. */
+static int
+is_unionable(PyObject *obj)
+{
+    if (obj == Py_None ||
+        PyType_Check(obj) ||
+        Py_TYPE(obj) == &Py_GenericAliasType ||
+        Py_TYPE(obj) == &_Py_UnionType) {
+        return 1;
+    }
+    // The typing checks look up __module__ and may fail; propagate -1
+    // rather than letting it read as "true" in a || chain.
+    int res = is_typevar(obj);
+    if (res != 0) {
+        return res;
+    }
+    res = is_new_type(obj);
+    if (res != 0) {
+        return res;
+    }
+    return is_special_form(obj);
+}
+
+/* nb_or slot of types.Union.  It receives (left, right) even when found on
+   the right operand, e.g. None | (int | str), so self may not be a union. */
+static PyObject *
+union_or(PyObject *self, PyObject *param)
+{
+    PyObject *tuple = PyTuple_Pack(2, self, param);
+    if (tuple == NULL) {
+        return NULL;
+    }
+    PyObject *new_union = _Py_Union(tuple);
+    Py_DECREF(tuple);
+    return new_union;
+}
+
+/* Append one member's text to writer: "module.qualname" for classes (just
+   "qualname" for builtins), repr() for generic aliases and for anything
+   lacking __qualname__ or __module__.  Return 0 on success, -1 on error. */
+static int
+union_repr_item(_PyUnicodeWriter *writer, PyObject *p)
+{
+    _Py_IDENTIFIER(__module__);
+    _Py_IDENTIFIER(__qualname__);
+    _Py_IDENTIFIER(__origin__);
+    _Py_IDENTIFIER(__args__);
+    PyObject *qualname = NULL;
+    PyObject *module = NULL;
+    PyObject *r = NULL;
+    int err;
+
+    int has_origin = _PyObject_HasAttrId(p, &PyId___origin__);
+    if (has_origin < 0) {
+        goto exit;
+    }
+
+    if (has_origin) {
+        int has_args = _PyObject_HasAttrId(p, &PyId___args__);
+        if (has_args < 0) {
+            goto exit;
+        }
+        if (has_args) {
+            // It looks like a GenericAlias
+            goto use_repr;
+        }
+    }
+
+    if (_PyObject_LookupAttrId(p, &PyId___qualname__, &qualname) < 0) {
+        goto exit;
+    }
+    if (qualname == NULL) {
+        goto use_repr;
+    }
+    if (_PyObject_LookupAttrId(p, &PyId___module__, &module) < 0) {
+        goto exit;
+    }
+    if (module == NULL || module == Py_None) {
+        goto use_repr;
+    }
+
+    // Looks like a class
+    if (PyUnicode_Check(module) &&
+        _PyUnicode_EqualToASCIIString(module, "builtins"))
+    {
+        // builtins don't need a module name
+        r = PyObject_Str(qualname);
+        goto exit;
+    }
+    else {
+        r = PyUnicode_FromFormat("%S.%S", module, qualname);
+        goto exit;
+    }
+
+use_repr:
+    r = PyObject_Repr(p);
+exit:
+    Py_XDECREF(qualname);
+    Py_XDECREF(module);
+    if (r == NULL) {
+        return -1;
+    }
+    err = _PyUnicodeWriter_WriteStr(writer, r);
+    Py_DECREF(r);
+    return err;
+}
+
+/* repr(): the members joined with " | ", e.g. "int | str". */
+static PyObject *
+union_repr(PyObject *self)
+{
+    unionobject *alias = (unionobject *)self;
+    Py_ssize_t len = PyTuple_GET_SIZE(alias->args);
+
+    _PyUnicodeWriter writer;
+    _PyUnicodeWriter_Init(&writer);
+    for (Py_ssize_t i = 0; i < len; i++) {
+        if (i > 0 && _PyUnicodeWriter_WriteASCIIString(&writer, " | ", 3) < 0) {
+            goto error;
+        }
+        PyObject *p = PyTuple_GET_ITEM(alias->args, i);
+        if (union_repr_item(&writer, p) < 0) {
+            goto error;
+        }
+    }
+    return _PyUnicodeWriter_Finish(&writer);
+error:
+    _PyUnicodeWriter_Dealloc(&writer);
+    return NULL;
+}
+
+static PyMemberDef union_members[] = {
+    {"__args__", T_OBJECT, offsetof(unionobject, args), READONLY},
+    {0}
+};
+
+static PyMethodDef union_methods[] = {
+    {"__instancecheck__", union_instancecheck, METH_O},
+    {"__subclasscheck__", union_subclasscheck, METH_O},
+    {0}
+};
+
+static PyNumberMethods union_as_number = {
+    .nb_or = union_or,  // __or__ / __ror__
+};
+
+PyTypeObject _Py_UnionType = {
+    PyVarObject_HEAD_INIT(&PyType_Type, 0)
+    .tp_name = "types.Union",
+    .tp_doc = "Represent a PEP 604 union type\n"
+              "\n"
+              "E.g. for int | str",
+    .tp_basicsize = sizeof(unionobject),
+    .tp_dealloc = unionobject_dealloc,
+    .tp_alloc = PyType_GenericAlloc,
+    .tp_free = PyObject_GC_Del,
+    .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
+    .tp_traverse = union_traverse,
+    .tp_hash = union_hash,
+    .tp_getattro = PyObject_GenericGetAttr,
+    .tp_members = union_members,
+    .tp_methods = union_methods,
+    .tp_richcompare = union_richcompare,
+    .tp_as_number = &union_as_number,
+    .tp_repr = union_repr,
+};
+
+/* Build a types.Union from the tuple args, e.g. (int, str) for int | str.
+   Return a new reference to the union; NotImplemented (new reference) if
+   some member cannot be part of a union, so that an nb_or caller lets
+   Python try the other operand; or NULL with an exception set on error. */
+PyObject *
+_Py_Union(PyObject *args)
+{
+    assert(PyTuple_CheckExact(args));
+
+    Py_ssize_t nargs = PyTuple_GET_SIZE(args);
+    for (Py_ssize_t iarg = 0; iarg < nargs; iarg++) {
+        int is_arg_unionable = is_unionable(PyTuple_GET_ITEM(args, iarg));
+        if (is_arg_unionable < 0) {
+            return NULL;
+        }
+        if (!is_arg_unionable) {
+            Py_RETURN_NOTIMPLEMENTED;
+        }
+    }
+
+    // Build the member tuple before allocating, so a failure here leaves
+    // no half-initialised object behind.
+    PyObject *members = dedup_and_flatten_args(args);
+    if (members == NULL) {
+        return NULL;
+    }
+    unionobject *result = PyObject_GC_New(unionobject, &_Py_UnionType);
+    if (result == NULL) {
+        Py_DECREF(members);
+        return NULL;
+    }
+    result->args = members;
+    // Track only once every field is initialised, so GC traversal never
+    // sees garbage in args.
+    PyObject_GC_Track(result);
+    return (PyObject *)result;
+}