bpo-40636: PEP 618: add strict parameter to zip() (GH-20921)

zip() now supports PEP 618's strict parameter, which raises a
ValueError if the arguments are exhausted at different lengths.
Patch by Brandt Bucher.

Co-authored-by: Brandt Bucher <brandtbucher@gmail.com>
Co-authored-by: Ram Rachum <ram@rachum.com>
diff --git a/Lib/test/test_builtin.py b/Lib/test/test_builtin.py
index 290ba2c..40df7b6 100644
--- a/Lib/test/test_builtin.py
+++ b/Lib/test/test_builtin.py
@@ -1521,6 +1521,14 @@
         self.assertRaises(TypeError, vars, 42)
         self.assertEqual(vars(self.C_get_vars()), {'a':2})
 
+    def iter_error(self, iterable, error):
+        """Collect `iterable` into a list, catching an expected `error`."""
+        items = []
+        with self.assertRaises(error):
+            for item in iterable:
+                items.append(item)
+        return items
+
     def test_zip(self):
         a = (1, 2, 3)
         b = (4, 5, 6)
@@ -1573,6 +1581,66 @@
             z1 = zip(a, b)
             self.check_iter_pickle(z1, t, proto)
 
+    def test_zip_pickle_strict(self):
+        a = (1, 2, 3)
+        b = (4, 5, 6)
+        t = [(1, 4), (2, 5), (3, 6)]
+        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+            z1 = zip(a, b, strict=True)
+            self.check_iter_pickle(z1, t, proto)
+
+    def test_zip_pickle_strict_fail(self):
+        a = (1, 2, 3)
+        b = (4, 5, 6, 7)
+        t = [(1, 4), (2, 5), (3, 6)]
+        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+            z1 = zip(a, b, strict=True)
+            z2 = pickle.loads(pickle.dumps(z1, proto))
+            self.assertEqual(self.iter_error(z1, ValueError), t)
+            self.assertEqual(self.iter_error(z2, ValueError), t)
+
+    def test_zip_pickle_stability(self):
+        # Pickles of zip((1, 2, 3), (4, 5, 6)) dumped from 3.9:
+        pickles = [
+            b'citertools\nizip\np0\n(c__builtin__\niter\np1\n((I1\nI2\nI3\ntp2\ntp3\nRp4\nI0\nbg1\n((I4\nI5\nI6\ntp5\ntp6\nRp7\nI0\nbtp8\nRp9\n.',
+            b'citertools\nizip\nq\x00(c__builtin__\niter\nq\x01((K\x01K\x02K\x03tq\x02tq\x03Rq\x04K\x00bh\x01((K\x04K\x05K\x06tq\x05tq\x06Rq\x07K\x00btq\x08Rq\t.',
+            b'\x80\x02citertools\nizip\nq\x00c__builtin__\niter\nq\x01K\x01K\x02K\x03\x87q\x02\x85q\x03Rq\x04K\x00bh\x01K\x04K\x05K\x06\x87q\x05\x85q\x06Rq\x07K\x00b\x86q\x08Rq\t.',
+            b'\x80\x03cbuiltins\nzip\nq\x00cbuiltins\niter\nq\x01K\x01K\x02K\x03\x87q\x02\x85q\x03Rq\x04K\x00bh\x01K\x04K\x05K\x06\x87q\x05\x85q\x06Rq\x07K\x00b\x86q\x08Rq\t.',
+            b'\x80\x04\x95L\x00\x00\x00\x00\x00\x00\x00\x8c\x08builtins\x94\x8c\x03zip\x94\x93\x94\x8c\x08builtins\x94\x8c\x04iter\x94\x93\x94K\x01K\x02K\x03\x87\x94\x85\x94R\x94K\x00bh\x05K\x04K\x05K\x06\x87\x94\x85\x94R\x94K\x00b\x86\x94R\x94.',
+            b'\x80\x05\x95L\x00\x00\x00\x00\x00\x00\x00\x8c\x08builtins\x94\x8c\x03zip\x94\x93\x94\x8c\x08builtins\x94\x8c\x04iter\x94\x93\x94K\x01K\x02K\x03\x87\x94\x85\x94R\x94K\x00bh\x05K\x04K\x05K\x06\x87\x94\x85\x94R\x94K\x00b\x86\x94R\x94.',
+        ]
+        for protocol, dump in enumerate(pickles):
+            z1 = zip((1, 2, 3), (4, 5, 6))
+            z2 = zip((1, 2, 3), (4, 5, 6), strict=False)
+            z3 = pickle.loads(dump)
+            l3 = list(z3)
+            self.assertEqual(type(z3), zip)
+            self.assertEqual(pickle.dumps(z1, protocol), dump)
+            self.assertEqual(pickle.dumps(z2, protocol), dump)
+            self.assertEqual(list(z1), l3)
+            self.assertEqual(list(z2), l3)
+
+    def test_zip_pickle_strict_stability(self):
+        # Pickles of zip((1, 2, 3), (4, 5), strict=True) dumped from 3.10:
+        pickles = [
+            b'citertools\nizip\np0\n(c__builtin__\niter\np1\n((I1\nI2\nI3\ntp2\ntp3\nRp4\nI0\nbg1\n((I4\nI5\ntp5\ntp6\nRp7\nI0\nbtp8\nRp9\nI01\nb.',
+            b'citertools\nizip\nq\x00(c__builtin__\niter\nq\x01((K\x01K\x02K\x03tq\x02tq\x03Rq\x04K\x00bh\x01((K\x04K\x05tq\x05tq\x06Rq\x07K\x00btq\x08Rq\tI01\nb.',
+            b'\x80\x02citertools\nizip\nq\x00c__builtin__\niter\nq\x01K\x01K\x02K\x03\x87q\x02\x85q\x03Rq\x04K\x00bh\x01K\x04K\x05\x86q\x05\x85q\x06Rq\x07K\x00b\x86q\x08Rq\t\x88b.',
+            b'\x80\x03cbuiltins\nzip\nq\x00cbuiltins\niter\nq\x01K\x01K\x02K\x03\x87q\x02\x85q\x03Rq\x04K\x00bh\x01K\x04K\x05\x86q\x05\x85q\x06Rq\x07K\x00b\x86q\x08Rq\t\x88b.',
+            b'\x80\x04\x95L\x00\x00\x00\x00\x00\x00\x00\x8c\x08builtins\x94\x8c\x03zip\x94\x93\x94\x8c\x08builtins\x94\x8c\x04iter\x94\x93\x94K\x01K\x02K\x03\x87\x94\x85\x94R\x94K\x00bh\x05K\x04K\x05\x86\x94\x85\x94R\x94K\x00b\x86\x94R\x94\x88b.',
+            b'\x80\x05\x95L\x00\x00\x00\x00\x00\x00\x00\x8c\x08builtins\x94\x8c\x03zip\x94\x93\x94\x8c\x08builtins\x94\x8c\x04iter\x94\x93\x94K\x01K\x02K\x03\x87\x94\x85\x94R\x94K\x00bh\x05K\x04K\x05\x86\x94\x85\x94R\x94K\x00b\x86\x94R\x94\x88b.',
+        ]
+        a = (1, 2, 3)
+        b = (4, 5)
+        t = [(1, 4), (2, 5)]
+        for protocol, dump in enumerate(pickles):
+            z1 = zip(a, b, strict=True)
+            z2 = pickle.loads(dump)
+            self.assertEqual(pickle.dumps(z1, protocol), dump)
+            self.assertEqual(type(z2), zip)
+            self.assertEqual(self.iter_error(z1, ValueError), t)
+            self.assertEqual(self.iter_error(z2, ValueError), t)
+
     def test_zip_bad_iterable(self):
         exception = TypeError()
 
@@ -1585,6 +1653,88 @@
 
         self.assertIs(cm.exception, exception)
 
+    def test_zip_strict(self):
+        self.assertEqual(tuple(zip((1, 2, 3), 'abc', strict=True)),
+                         ((1, 'a'), (2, 'b'), (3, 'c')))
+        self.assertRaises(ValueError, tuple,
+                          zip((1, 2, 3, 4), 'abc', strict=True))
+        self.assertRaises(ValueError, tuple,
+                          zip((1, 2), 'abc', strict=True))
+        self.assertRaises(ValueError, tuple,
+                          zip((1, 2), (1, 2), 'abc', strict=True))
+
+    def test_zip_strict_iterators(self):
+        x = iter(range(5))
+        y = [0]
+        z = iter(range(5))
+        self.assertRaises(ValueError, list,
+                          (zip(x, y, z, strict=True)))
+        self.assertEqual(next(x), 2)
+        self.assertEqual(next(z), 1)
+
+    def test_zip_strict_error_handling(self):
+
+        class Error(Exception):
+            pass
+
+        class Iter:
+            def __init__(self, size):
+                self.size = size
+            def __iter__(self):
+                return self
+            def __next__(self):
+                self.size -= 1
+                if self.size < 0:
+                    raise Error
+                return self.size
+
+        l1 = self.iter_error(zip("AB", Iter(1), strict=True), Error)
+        self.assertEqual(l1, [("A", 0)])
+        l2 = self.iter_error(zip("AB", Iter(2), "A", strict=True), ValueError)
+        self.assertEqual(l2, [("A", 1, "A")])
+        l3 = self.iter_error(zip("AB", Iter(2), "ABC", strict=True), Error)
+        self.assertEqual(l3, [("A", 1, "A"), ("B", 0, "B")])
+        l4 = self.iter_error(zip("AB", Iter(3), strict=True), ValueError)
+        self.assertEqual(l4, [("A", 2), ("B", 1)])
+        l5 = self.iter_error(zip(Iter(1), "AB", strict=True), Error)
+        self.assertEqual(l5, [(0, "A")])
+        l6 = self.iter_error(zip(Iter(2), "A", strict=True), ValueError)
+        self.assertEqual(l6, [(1, "A")])
+        l7 = self.iter_error(zip(Iter(2), "ABC", strict=True), Error)
+        self.assertEqual(l7, [(1, "A"), (0, "B")])
+        l8 = self.iter_error(zip(Iter(3), "AB", strict=True), ValueError)
+        self.assertEqual(l8, [(2, "A"), (1, "B")])
+
+    def test_zip_strict_error_handling_stopiteration(self):
+
+        class Iter:
+            def __init__(self, size):
+                self.size = size
+            def __iter__(self):
+                return self
+            def __next__(self):
+                self.size -= 1
+                if self.size < 0:
+                    raise StopIteration
+                return self.size
+
+        l1 = self.iter_error(zip("AB", Iter(1), strict=True), ValueError)
+        self.assertEqual(l1, [("A", 0)])
+        l2 = self.iter_error(zip("AB", Iter(2), "A", strict=True), ValueError)
+        self.assertEqual(l2, [("A", 1, "A")])
+        l3 = self.iter_error(zip("AB", Iter(2), "ABC", strict=True), ValueError)
+        self.assertEqual(l3, [("A", 1, "A"), ("B", 0, "B")])
+        l4 = self.iter_error(zip("AB", Iter(3), strict=True), ValueError)
+        self.assertEqual(l4, [("A", 2), ("B", 1)])
+        l5 = self.iter_error(zip(Iter(1), "AB", strict=True), ValueError)
+        self.assertEqual(l5, [(0, "A")])
+        l6 = self.iter_error(zip(Iter(2), "A", strict=True), ValueError)
+        self.assertEqual(l6, [(1, "A")])
+        l7 = self.iter_error(zip(Iter(2), "ABC", strict=True), ValueError)
+        self.assertEqual(l7, [(1, "A"), (0, "B")])
+        l8 = self.iter_error(zip(Iter(3), "AB", strict=True), ValueError)
+        self.assertEqual(l8, [(2, "A"), (1, "B")])
+
     def test_format(self):
         # Test the basic machinery of the format() builtin.  Don't test
         #  the specifics of the various formatters
diff --git a/Misc/NEWS.d/next/Core and Builtins/2020-06-17-10-27-17.bpo-40636.MYaCIe.rst b/Misc/NEWS.d/next/Core and Builtins/2020-06-17-10-27-17.bpo-40636.MYaCIe.rst
new file mode 100644
index 0000000..ba26ad9
--- /dev/null
+++ b/Misc/NEWS.d/next/Core and Builtins/2020-06-17-10-27-17.bpo-40636.MYaCIe.rst
@@ -0,0 +1,3 @@
+:func:`zip` now supports :pep:`618`'s ``strict`` parameter, which raises a
+:exc:`ValueError` if the arguments are exhausted at different lengths.
+Patch by Brandt Bucher.
diff --git a/Python/bltinmodule.c b/Python/bltinmodule.c
index 65f9528..c6ede1c 100644
--- a/Python/bltinmodule.c
+++ b/Python/bltinmodule.c
@@ -2517,9 +2517,10 @@
 
 typedef struct {
     PyObject_HEAD
-    Py_ssize_t          tuplesize;
-    PyObject *ittuple;                  /* tuple of iterators */
+    Py_ssize_t tuplesize;
+    PyObject *ittuple;     /* tuple of iterators */
     PyObject *result;
+    int strict;
 } zipobject;
 
 static PyObject *
@@ -2530,9 +2531,21 @@
     PyObject *ittuple;  /* tuple of iterators */
     PyObject *result;
     Py_ssize_t tuplesize;
+    int strict = 0;
 
-    if (type == &PyZip_Type && !_PyArg_NoKeywords("zip", kwds))
-        return NULL;
+    if (kwds) {
+        PyObject *empty = PyTuple_New(0);
+        if (empty == NULL) {
+            return NULL;
+        }
+        static char *kwlist[] = {"strict", NULL};
+        int parsed = PyArg_ParseTupleAndKeywords(
+                empty, kwds, "|$p:zip", kwlist, &strict);
+        Py_DECREF(empty);
+        if (!parsed) {
+            return NULL;
+        }
+    }
 
     /* args must be a tuple */
     assert(PyTuple_Check(args));
@@ -2573,6 +2586,7 @@
     lz->ittuple = ittuple;
     lz->tuplesize = tuplesize;
     lz->result = result;
+    lz->strict = strict;
 
     return (PyObject *)lz;
 }
@@ -2613,6 +2627,9 @@
             item = (*Py_TYPE(it)->tp_iternext)(it);
             if (item == NULL) {
                 Py_DECREF(result);
+                if (lz->strict) {
+                    goto check;
+                }
                 return NULL;
             }
             olditem = PyTuple_GET_ITEM(result, i);
@@ -2628,28 +2645,85 @@
             item = (*Py_TYPE(it)->tp_iternext)(it);
             if (item == NULL) {
                 Py_DECREF(result);
+                if (lz->strict) {
+                    goto check;
+                }
                 return NULL;
             }
             PyTuple_SET_ITEM(result, i, item);
         }
     }
     return result;
+check:
+    if (PyErr_Occurred()) {
+        if (!PyErr_ExceptionMatches(PyExc_StopIteration)) {
+            // next() on argument i raised an exception (not StopIteration)
+            return NULL;
+        }
+        PyErr_Clear();
+    }
+    if (i) {
+        // ValueError: zip() argument 2 is shorter than argument 1
+        // ValueError: zip() argument 3 is shorter than arguments 1-2
+        const char* plural = i == 1 ? " " : "s 1-";
+        return PyErr_Format(PyExc_ValueError,
+                            "zip() argument %d is shorter than argument%s%d",
+                            i + 1, plural, i);
+    }
+    for (i = 1; i < tuplesize; i++) {
+        it = PyTuple_GET_ITEM(lz->ittuple, i);
+        item = (*Py_TYPE(it)->tp_iternext)(it);
+        if (item) {
+            Py_DECREF(item);
+            const char* plural = i == 1 ? " " : "s 1-";
+            return PyErr_Format(PyExc_ValueError,
+                                "zip() argument %d is longer than argument%s%d",
+                                i + 1, plural, i);
+        }
+        if (PyErr_Occurred()) {
+            if (!PyErr_ExceptionMatches(PyExc_StopIteration)) {
+                // next() on argument i raised an exception (not StopIteration)
+                return NULL;
+            }
+            PyErr_Clear();
+        }
+        // Argument i is exhausted. So far so good...
+    }
+    // All arguments are exhausted. Success!
+    return NULL;
 }
 
 static PyObject *
 zip_reduce(zipobject *lz, PyObject *Py_UNUSED(ignored))
 {
     /* Just recreate the zip with the internal iterator tuple */
-    return Py_BuildValue("OO", Py_TYPE(lz), lz->ittuple);
+    if (lz->strict) {
+        return PyTuple_Pack(3, Py_TYPE(lz), lz->ittuple, Py_True);
+    }
+    return PyTuple_Pack(2, Py_TYPE(lz), lz->ittuple);
+}
+
+PyDoc_STRVAR(setstate_doc, "Set state information for unpickling.");
+
+static PyObject *
+zip_setstate(zipobject *lz, PyObject *state)
+{
+    int strict = PyObject_IsTrue(state);
+    if (strict < 0) {
+        return NULL;
+    }
+    lz->strict = strict;
+    Py_RETURN_NONE;
 }
 
 static PyMethodDef zip_methods[] = {
     {"__reduce__",   (PyCFunction)zip_reduce,   METH_NOARGS, reduce_doc},
-    {NULL,           NULL}           /* sentinel */
+    {"__setstate__", (PyCFunction)zip_setstate, METH_O,      setstate_doc},
+    {NULL}  /* sentinel */
 };
 
 PyDoc_STRVAR(zip_doc,
-"zip(*iterables) --> A zip object yielding tuples until an input is exhausted.\n\
+"zip(*iterables, strict=False) --> Yield tuples until an input is exhausted.\n\
 \n\
    >>> list(zip('abcdefg', range(3), range(4)))\n\
    [('a', 0, 0), ('b', 1, 1), ('c', 2, 2)]\n\
@@ -2657,7 +2731,10 @@
 The zip object yields n-length tuples, where n is the number of iterables\n\
 passed as positional arguments to zip().  The i-th element in every tuple\n\
 comes from the i-th iterable argument to zip().  This continues until the\n\
-shortest argument is exhausted.");
+shortest argument is exhausted.\n\
+\n\
+If strict is true and one of the arguments is exhausted before the others,\n\
+raise a ValueError.");
 
 PyTypeObject PyZip_Type = {
     PyVarObject_HEAD_INIT(&PyType_Type, 0)