Have itertools.chain() consume its inputs lazily instead of building a tuple of iterators at the outset.
diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c
index 2ee947d..3b8339c 100644
--- a/Modules/itertoolsmodule.c
+++ b/Modules/itertoolsmodule.c
@@ -1601,92 +1601,92 @@
typedef struct {
PyObject_HEAD
- Py_ssize_t tuplesize;
- Py_ssize_t iternum; /* which iterator is active */
- PyObject *ittuple; /* tuple of iterators */
+ PyObject *source; /* iterator over the input iterables; set to NULL once it stops yielding or an input is not iterable */
+ PyObject *active; /* iterator for the input currently being consumed; NULL when the next input must be fetched from source */
} chainobject;
static PyTypeObject chain_type;
+static PyObject *
+chain_new_internal(PyTypeObject *type, PyObject *source)
+{
+ /* Wrap source in a new chain object, taking over the caller's reference
+ to it: on success the object owns it, on allocation failure it is
+ released here. active starts out NULL, so the first chain_next() call
+ pulls an input from source. */
+ chainobject *obj = (chainobject *)type->tp_alloc(type, 0);
+
+ if (obj == NULL) {
+ Py_DECREF(source);
+ } else {
+ obj->source = source;
+ obj->active = NULL;
+ }
+ return (PyObject *)obj;
+}
+
static PyObject *
chain_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
{
- chainobject *lz;
- Py_ssize_t tuplesize = PySequence_Length(args);
- Py_ssize_t i;
- PyObject *ittuple;
+ PyObject *source; /* iterator over the positional-argument tuple */
if (type == &chain_type && !_PyArg_NoKeywords("chain()", kwds))
return NULL;
-
- /* obtain iterators */
- assert(PyTuple_Check(args));
- ittuple = PyTuple_New(tuplesize);
- if (ittuple == NULL)
+ /* Only the argument tuple is iterated here; PyObject_GetIter() is not called on an individual argument until chain_next() reaches it. */
+ source = PyObject_GetIter(args);
+ if (source == NULL)
return NULL;
- for (i=0; i < tuplesize; ++i) {
- PyObject *item = PyTuple_GET_ITEM(args, i);
- PyObject *it = PyObject_GetIter(item);
- if (it == NULL) {
- if (PyErr_ExceptionMatches(PyExc_TypeError))
- PyErr_Format(PyExc_TypeError,
- "chain argument #%zd must support iteration",
- i+1);
- Py_DECREF(ittuple);
- return NULL;
- }
- PyTuple_SET_ITEM(ittuple, i, it);
- }
- /* create chainobject structure */
- lz = (chainobject *)type->tp_alloc(type, 0);
- if (lz == NULL) {
- Py_DECREF(ittuple);
- return NULL;
- }
-
- lz->ittuple = ittuple;
- lz->iternum = 0;
- lz->tuplesize = tuplesize;
-
- return (PyObject *)lz;
+ return chain_new_internal(type, source); /* takes ownership of source, even on failure */
}
static void
chain_dealloc(chainobject *lz)
{
PyObject_GC_UnTrack(lz);
- Py_XDECREF(lz->ittuple);
+ Py_XDECREF(lz->active); /* XDECREF: chain_next() may already have cleared either field */
+ Py_XDECREF(lz->source);
Py_TYPE(lz)->tp_free(lz);
}
static int
chain_traverse(chainobject *lz, visitproc visit, void *arg)
{
- Py_VISIT(lz->ittuple);
+ Py_VISIT(lz->source); /* Py_VISIT skips NULL, so fields cleared by chain_next() are fine */
+ Py_VISIT(lz->active);
return 0;
}
static PyObject *
chain_next(chainobject *lz)
{
- PyObject *it;
PyObject *item;
- while (lz->iternum < lz->tuplesize) {
- it = PyTuple_GET_ITEM(lz->ittuple, lz->iternum);
- item = PyIter_Next(it);
- if (item != NULL)
- return item;
- if (PyErr_Occurred()) {
- if (PyErr_ExceptionMatches(PyExc_StopIteration))
- PyErr_Clear();
- else
- return NULL;
+
+ /* lz->source yields the input iterables and is NULL once they have all
+ been consumed (or an input could not be iterated). lz->active is the
+ iterator for the current input; when it is NULL the next input is taken
+ from lz->source. Loop instead of recursing so that a long run of empty
+ inputs cannot exhaust the C stack. */
+ while (lz->source != NULL) {
+ if (lz->active == NULL) {
+ PyObject *iterable = PyIter_Next(lz->source);
+ if (iterable == NULL) {
+ Py_CLEAR(lz->source);
+ return NULL; /* no more input sources, or source raised */
+ }
+ lz->active = PyObject_GetIter(iterable);
+ Py_DECREF(iterable); /* active (if created) holds its own reference */
+ if (lz->active == NULL) {
+ Py_CLEAR(lz->source);
+ return NULL; /* input not iterable */
+ }
+ }
+ item = PyIter_Next(lz->active);
+ if (item != NULL)
+ return item;
+ if (PyErr_Occurred()) {
+ if (PyErr_ExceptionMatches(PyExc_StopIteration))
+ PyErr_Clear();
+ else
+ return NULL; /* input raised an exception */
}
- lz->iternum++;
+ /* the active iterator is exhausted: drop it and move on to the next input */
+ Py_CLEAR(lz->active);
}
- return NULL;
+ return NULL; /* every input has been consumed */
}
PyDoc_STRVAR(chain_doc,