Support more natural syntax for vector extend
diff --git a/docs/changelog.rst b/docs/changelog.rst
index 606be41..28791c8 100644
--- a/docs/changelog.rst
+++ b/docs/changelog.rst
@@ -94,6 +94,11 @@
* A few minor typo fixes and improvements to the test suite, and
patches that silence compiler warnings.
+* Vectors now support construction from generators, as well as ``extend()`` from a
+ list or generator.
+ `#1496 <https://github.com/pybind/pybind11/pull/1496>`_.
+
+
v2.2.3 (April 29, 2018)
-----------------------------------------------------
diff --git a/include/pybind11/pytypes.h b/include/pybind11/pytypes.h
index b4f4be9..db7dfec 100644
--- a/include/pybind11/pytypes.h
+++ b/include/pybind11/pytypes.h
@@ -1346,6 +1346,21 @@
return (size_t) result;
}
+inline size_t len_hint(handle h) {
+#if PY_VERSION_HEX >= 0x03040000
+ ssize_t result = PyObject_LengthHint(h.ptr(), 0);
+#else
+ ssize_t result = PyObject_Length(h.ptr());
+#endif
+ if (result < 0) {
+ // Sometimes a length can't be determined at all (eg generators)
+ // In which case simply return 0
+ PyErr_Clear();
+ return 0;
+ }
+ return (size_t) result;
+}
+
inline str repr(handle h) {
PyObject *str_value = PyObject_Repr(h.ptr());
if (!str_value) throw error_already_set();
diff --git a/include/pybind11/stl_bind.h b/include/pybind11/stl_bind.h
index d6f4c63..1f87252 100644
--- a/include/pybind11/stl_bind.h
+++ b/include/pybind11/stl_bind.h
@@ -122,7 +122,7 @@
cl.def(init([](iterable it) {
auto v = std::unique_ptr<Vector>(new Vector());
- v->reserve(len(it));
+ v->reserve(len_hint(it));
for (handle h : it)
v->push_back(h.cast<T>());
return v.release();
@@ -136,6 +136,28 @@
"Extend the list by appending all the items in the given list"
);
+ cl.def("extend",
+ [](Vector &v, iterable it) {
+ const size_t old_size = v.size();
+ v.reserve(old_size + len_hint(it));
+ try {
+ for (handle h : it) {
+ v.push_back(h.cast<T>());
+ }
+ } catch (const cast_error &) {
+ v.erase(v.begin() + static_cast<typename Vector::difference_type>(old_size), v.end());
+ try {
+ v.shrink_to_fit();
+ } catch (const std::exception &) {
+ // Do nothing
+ }
+ throw;
+ }
+ },
+ arg("L"),
+ "Extend the list by appending all the items in the given list"
+ );
+
cl.def("insert",
[](Vector &v, SizeType i, const T &x) {
if (i > v.size())
diff --git a/tests/test_stl_binders.py b/tests/test_stl_binders.py
index 0030924..52c8ac0 100644
--- a/tests/test_stl_binders.py
+++ b/tests/test_stl_binders.py
@@ -11,6 +11,10 @@
assert len(v_int) == 2
assert bool(v_int) is True
+ # test construction from a generator
+ v_int1 = m.VectorInt(x for x in range(5))
+ assert v_int1 == m.VectorInt([0, 1, 2, 3, 4])
+
v_int2 = m.VectorInt([0, 0])
assert v_int == v_int2
v_int2[1] = 1
@@ -33,6 +37,22 @@
del v_int2[0]
assert v_int2 == m.VectorInt([0, 99, 2, 3])
+ v_int2.extend(m.VectorInt([4, 5]))
+ assert v_int2 == m.VectorInt([0, 99, 2, 3, 4, 5])
+
+ v_int2.extend([6, 7])
+ assert v_int2 == m.VectorInt([0, 99, 2, 3, 4, 5, 6, 7])
+
+ # test error handling, and that the vector is unchanged
+ with pytest.raises(RuntimeError):
+ v_int2.extend([8, 'a'])
+
+ assert v_int2 == m.VectorInt([0, 99, 2, 3, 4, 5, 6, 7])
+
+ # test extending from a generator
+ v_int2.extend(x for x in range(5))
+ assert v_int2 == m.VectorInt([0, 99, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4])
+
# related to the PyPy's buffer protocol.
@pytest.unsupported_on_pypy