Make py::iterator compatible with std algorithms
The added type aliases are required by `std::iterator_traits`.
Python iterators satisfy the `InputIterator` concept in C++.
diff --git a/include/pybind11/pytypes.h b/include/pybind11/pytypes.h
index 3579da1..f09b5fe 100644
--- a/include/pybind11/pytypes.h
+++ b/include/pybind11/pytypes.h
@@ -602,6 +602,12 @@
\endrst */
class iterator : public object {
public:
+ using iterator_category = std::input_iterator_tag;
+ using difference_type = ssize_t;
+ using value_type = handle;
+ using reference = const handle;
+ using pointer = const handle *;
+
PYBIND11_OBJECT_DEFAULT(iterator, object, PyIter_Check)
iterator& operator++() {
@@ -615,7 +621,7 @@
return rv;
}
- handle operator*() const {
+ reference operator*() const {
if (m_ptr && !value.ptr()) {
auto& self = const_cast<iterator &>(*this);
self.advance();
@@ -623,7 +629,7 @@
return value;
}
- const handle *operator->() const { operator*(); return &value; }
+ pointer operator->() const { operator*(); return &value; }
/** \rst
The value which marks the end of the iteration. ``it == iterator::sentinel()``
diff --git a/tests/test_sequences_and_iterators.cpp b/tests/test_sequences_and_iterators.cpp
index 6401882..cda0af4 100644
--- a/tests/test_sequences_and_iterators.cpp
+++ b/tests/test_sequences_and_iterators.cpp
@@ -290,4 +290,14 @@
}
return l;
});
+
+ // Make sure that py::iterator works with std algorithms
+ m.def("count_none", [](py::object o) {
+ return std::count_if(o.begin(), o.end(), [](py::handle h) { return h.is_none(); });
+ });
+
+ m.def("find_none", [](py::object o) {
+ auto it = std::find_if(o.begin(), o.end(), [](py::handle h) { return h.is_none(); });
+ return it->is_none();
+ });
});
diff --git a/tests/test_sequences_and_iterators.py b/tests/test_sequences_and_iterators.py
index b340451..3066647 100644
--- a/tests/test_sequences_and_iterators.py
+++ b/tests/test_sequences_and_iterators.py
@@ -113,3 +113,7 @@
with pytest.raises(RuntimeError) as excinfo:
m.iterator_to_list(iter(bad_next_call, None))
assert str(excinfo.value) == "py::iterator::advance() should propagate errors"
+
+ l = [1, None, 0, None]
+ assert m.count_none(l) == 2
+ assert m.find_none(l) is True