Improve consistency of array and array_t with regard to other pytypes
* `array_t(const object &)` now throws on error
* `array_t::ensure()` is intended for casters —- old constructor is
deprecated
* `array` and `array_t` get default constructors (empty array)
* `array` gets a converting constructor
* `py::isinstance<array_T<T>>()` checks the type (but not flags)
There is only one special thing which must remain: `array_t` gets
its own `type_caster` specialization which uses `ensure` instead
of a simple check.
diff --git a/include/pybind11/eigen.h b/include/pybind11/eigen.h
index 729d0f6..0a1208e 100644
--- a/include/pybind11/eigen.h
+++ b/include/pybind11/eigen.h
@@ -54,7 +54,7 @@
static constexpr bool isVector = Type::IsVectorAtCompileTime;
bool load(handle src, bool) {
- array_t<Scalar> buf(src, true);
+ auto buf = array_t<Scalar>::ensure(src);
if (!buf)
return false;
diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h
index 3cbea01..77006c8 100644
--- a/include/pybind11/numpy.h
+++ b/include/pybind11/numpy.h
@@ -305,7 +305,7 @@
class array : public buffer {
public:
- PYBIND11_OBJECT_DEFAULT(array, buffer, detail::npy_api::get().PyArray_Check_)
+ PYBIND11_OBJECT_CVT(array, buffer, detail::npy_api::get().PyArray_Check_, raw_array)
enum {
c_style = detail::npy_api::NPY_C_CONTIGUOUS_,
@@ -313,6 +313,8 @@
forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
};
+ array() : array(0, static_cast<const double *>(nullptr)) {}
+
array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
const std::vector<size_t> &strides, const void *ptr = nullptr,
handle base = handle()) {
@@ -478,10 +480,12 @@
}
/// Ensure that the argument is a NumPy array
- static array ensure(object input, int ExtraFlags = 0) {
- auto& api = detail::npy_api::get();
- return reinterpret_steal<array>(api.PyArray_FromAny_(
- input.release().ptr(), nullptr, 0, 0, detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr));
+ /// In case of an error, nullptr is returned and the Python error is cleared.
+ static array ensure(handle h, int ExtraFlags = 0) {
+ auto result = reinterpret_steal<array>(raw_array(h.ptr(), ExtraFlags));
+ if (!result)
+ PyErr_Clear();
+ return result;
}
protected:
@@ -520,8 +524,6 @@
return strides;
}
-protected:
-
template<typename... Ix> void check_dimensions(Ix... index) const {
check_dimensions_impl(size_t(0), shape(), size_t(index)...);
}
@@ -536,15 +538,31 @@
}
check_dimensions_impl(axis + 1, shape + 1, index...);
}
+
+ /// Create array from any object -- always returns a new reference
+ static PyObject *raw_array(PyObject *ptr, int ExtraFlags = 0) {
+ if (ptr == nullptr)
+ return nullptr;
+ return detail::npy_api::get().PyArray_FromAny_(
+ ptr, nullptr, 0, 0, detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr);
+ }
};
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
public:
- array_t() : array() { }
+ array_t() : array(0, static_cast<const T *>(nullptr)) {}
+ array_t(handle h, borrowed_t) : array(h, borrowed) { }
+ array_t(handle h, stolen_t) : array(h, stolen) { }
- array_t(handle h, bool is_borrowed) : array(h, is_borrowed) { m_ptr = ensure_(m_ptr); }
+ PYBIND11_DEPRECATED("Use array_t<T>::ensure() instead")
+ array_t(handle h, bool is_borrowed) : array(raw_array_t(h.ptr()), stolen) {
+ if (!m_ptr) PyErr_Clear();
+ if (!is_borrowed) Py_XDECREF(h.ptr());
+ }
- array_t(const object &o) : array(o) { m_ptr = ensure_(m_ptr); }
+ array_t(const object &o) : array(raw_array_t(o.ptr()), stolen) {
+ if (!m_ptr) throw error_already_set();
+ }
explicit array_t(const buffer_info& info) : array(info) { }
@@ -590,17 +608,30 @@
return *(static_cast<T*>(array::mutable_data()) + byte_offset(size_t(index)...) / itemsize());
}
- static PyObject *ensure_(PyObject *ptr) {
- if (ptr == nullptr)
- return nullptr;
- auto& api = detail::npy_api::get();
- PyObject *result = api.PyArray_FromAny_(ptr, pybind11::dtype::of<T>().release().ptr(), 0, 0,
- detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr);
+ /// Ensure that the argument is a NumPy array of the correct dtype.
+ /// In case of an error, nullptr is returned and the Python error is cleared.
+ static array_t ensure(handle h) {
+ auto result = reinterpret_steal<array_t>(raw_array_t(h.ptr()));
if (!result)
PyErr_Clear();
- Py_DECREF(ptr);
return result;
}
+
+ static bool _check(handle h) {
+ const auto &api = detail::npy_api::get();
+ return api.PyArray_Check_(h.ptr())
+ && api.PyArray_EquivTypes_(PyArray_GET_(h.ptr(), descr), dtype::of<T>().ptr());
+ }
+
+protected:
+ /// Create array from any object -- always returns a new reference
+ static PyObject *raw_array_t(PyObject *ptr) {
+ if (ptr == nullptr)
+ return nullptr;
+ return detail::npy_api::get().PyArray_FromAny_(
+ ptr, dtype::of<T>().release().ptr(), 0, 0,
+ detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr);
+ }
};
template <typename T>
@@ -631,7 +662,7 @@
using type = array_t<T, ExtraFlags>;
bool load(handle src, bool /* convert */) {
- value = type(src, true);
+ value = type::ensure(src);
return static_cast<bool>(value);
}
diff --git a/tests/test_numpy_array.cpp b/tests/test_numpy_array.cpp
index df6377e..14c4c29 100644
--- a/tests/test_numpy_array.cpp
+++ b/tests/test_numpy_array.cpp
@@ -126,4 +126,28 @@
);
sm.def("function_taking_uint64", [](uint64_t) { });
+
+ sm.def("isinstance_untyped", [](py::object yes, py::object no) {
+ return py::isinstance<py::array>(yes) && !py::isinstance<py::array>(no);
+ });
+
+ sm.def("isinstance_typed", [](py::object o) {
+ return py::isinstance<py::array_t<double>>(o) && !py::isinstance<py::array_t<int>>(o);
+ });
+
+ sm.def("default_constructors", []() {
+ return py::dict(
+ "array"_a=py::array(),
+ "array_t<int32>"_a=py::array_t<std::int32_t>(),
+ "array_t<double>"_a=py::array_t<double>()
+ );
+ });
+
+ sm.def("converting_constructors", [](py::object o) {
+ return py::dict(
+ "array"_a=py::array(o),
+ "array_t<int32>"_a=py::array_t<std::int32_t>(o),
+ "array_t<double>"_a=py::array_t<double>(o)
+ );
+ });
});
diff --git a/tests/test_numpy_array.py b/tests/test_numpy_array.py
index 40682ef..cec0054 100644
--- a/tests/test_numpy_array.py
+++ b/tests/test_numpy_array.py
@@ -245,3 +245,30 @@
from pybind11_tests.array import function_taking_uint64
function_taking_uint64(123)
function_taking_uint64(np.uint64(123))
+
+
+@pytest.requires_numpy
+def test_isinstance():
+ from pybind11_tests.array import isinstance_untyped, isinstance_typed
+
+ assert isinstance_untyped(np.array([1, 2, 3]), "not an array")
+ assert isinstance_typed(np.array([1.0, 2.0, 3.0]))
+
+
+@pytest.requires_numpy
+def test_constructors():
+ from pybind11_tests.array import default_constructors, converting_constructors
+
+ defaults = default_constructors()
+ for a in defaults.values():
+ assert a.size == 0
+ assert defaults["array"].dtype == np.array([]).dtype
+ assert defaults["array_t<int32>"].dtype == np.int32
+ assert defaults["array_t<double>"].dtype == np.float64
+
+ results = converting_constructors([1, 2, 3])
+ for a in results.values():
+ np.testing.assert_array_equal(a, [1, 2, 3])
+ assert results["array"].dtype == np.int_
+ assert results["array_t<int32>"].dtype == np.int32
+ assert results["array_t<double>"].dtype == np.float64