Fix py::make_iterator's __next__() for past-the-end calls
Fixes #896.
From Python docs: "Once an iterator’s `__next__()` method raises
`StopIteration`, it must continue to do so on subsequent calls.
Implementations that do not obey this property are deemed broken."
diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h
index 786e36f..6d25fa5 100644
--- a/include/pybind11/pybind11.h
+++ b/include/pybind11/pybind11.h
@@ -1353,7 +1353,7 @@
struct iterator_state {
Iterator it;
Sentinel end;
- bool first;
+ bool first_or_done;
};
NAMESPACE_END(detail)
@@ -1374,17 +1374,19 @@
class_<state>(handle(), "iterator")
.def("__iter__", [](state &s) -> state& { return s; })
.def("__next__", [](state &s) -> ValueType {
- if (!s.first)
+ if (!s.first_or_done)
++s.it;
else
- s.first = false;
- if (s.it == s.end)
+ s.first_or_done = false;
+ if (s.it == s.end) {
+ s.first_or_done = true;
throw stop_iteration();
+ }
return *s.it;
}, std::forward<Extra>(extra)..., Policy);
}
- return (iterator) cast(state { first, last, true });
+ return cast(state{first, last, true});
}
/// Makes an python iterator over the keys (`.first`) of a iterator over pairs from a
@@ -1401,17 +1403,19 @@
class_<state>(handle(), "iterator")
.def("__iter__", [](state &s) -> state& { return s; })
.def("__next__", [](state &s) -> KeyType {
- if (!s.first)
+ if (!s.first_or_done)
++s.it;
else
- s.first = false;
- if (s.it == s.end)
+ s.first_or_done = false;
+ if (s.it == s.end) {
+ s.first_or_done = true;
throw stop_iteration();
+ }
return (*s.it).first;
}, std::forward<Extra>(extra)..., Policy);
}
- return (iterator) cast(state { first, last, true });
+ return cast(state{first, last, true});
}
/// Makes an iterator over values of an stl container or other container supporting
diff --git a/tests/test_sequences_and_iterators.py b/tests/test_sequences_and_iterators.py
index 30b6aaf..e04c579 100644
--- a/tests/test_sequences_and_iterators.py
+++ b/tests/test_sequences_and_iterators.py
@@ -21,6 +21,17 @@
assert list(IntPairs([(1, 2), (2, 0), (0, 3), (4, 5)]).nonzero_keys()) == [1]
assert list(IntPairs([(0, 3), (1, 2), (3, 4)]).nonzero_keys()) == []
+ # __next__ must continue to raise StopIteration
+ it = IntPairs([(0, 0)]).nonzero()
+ for _ in range(3):
+ with pytest.raises(StopIteration):
+ next(it)
+
+ it = IntPairs([(0, 0)]).nonzero_keys()
+ for _ in range(3):
+ with pytest.raises(StopIteration):
+ next(it)
+
def test_sequence():
from pybind11_tests import ConstructorStats
@@ -45,6 +56,12 @@
rev2 = s[::-1]
assert cstats.values() == ['of size', '5']
+ it = iter(Sequence(0))
+ for _ in range(3): # __next__ must continue to raise StopIteration
+ with pytest.raises(StopIteration):
+ next(it)
+ assert cstats.values() == ['of size', '0']
+
expected = [0, 56.78, 0, 0, 12.34]
assert allclose(rev, expected)
assert allclose(rev2, expected)
@@ -55,6 +72,8 @@
assert allclose(rev, [2, 56.78, 2, 0, 2])
+ assert cstats.alive() == 4
+ del it
assert cstats.alive() == 3
del s
assert cstats.alive() == 2
@@ -90,6 +109,11 @@
for k, v in m.items():
assert v == expected[k]
+ it = iter(StringMap({}))
+ for _ in range(3): # __next__ must continue to raise StopIteration
+ with pytest.raises(StopIteration):
+ next(it)
+
def test_python_iterator_in_cpp():
import pybind11_tests.sequences_and_iterators as m