Merge pull request #437 from dean0x7d/dynamic-attrs
Add dynamic attribute support
diff --git a/include/pybind11/common.h b/include/pybind11/common.h
index 84035eb..2ee4691 100644
--- a/include/pybind11/common.h
+++ b/include/pybind11/common.h
@@ -455,6 +455,7 @@
PYBIND11_RUNTIME_EXCEPTION(index_error, PyExc_IndexError)
PYBIND11_RUNTIME_EXCEPTION(key_error, PyExc_KeyError)
PYBIND11_RUNTIME_EXCEPTION(value_error, PyExc_ValueError)
+PYBIND11_RUNTIME_EXCEPTION(import_error, PyExc_ImportError)
PYBIND11_RUNTIME_EXCEPTION(type_error, PyExc_TypeError)
PYBIND11_RUNTIME_EXCEPTION(cast_error, PyExc_RuntimeError) /// Thrown when pybind11::cast or handle::call fail due to a type casting error
PYBIND11_RUNTIME_EXCEPTION(reference_cast_error, PyExc_RuntimeError) /// Used internally
diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h
index 1125fd7..d437c92 100644
--- a/include/pybind11/numpy.h
+++ b/include/pybind11/numpy.h
@@ -22,8 +22,8 @@
#include <functional>
#if defined(_MSC_VER)
-#pragma warning(push)
-#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
+# pragma warning(push)
+# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
#endif
/* This will be true on all flat address space platforms and allows us to reduce the
@@ -156,8 +156,10 @@
(reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->attr)
#define PyArrayDescr_GET_(ptr, attr) \
(reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr)
+#define PyArray_FLAGS_(ptr) \
+ PyArray_GET_(ptr, flags)
#define PyArray_CHKFLAGS_(ptr, flag) \
- (flag == (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->flags & flag))
+ (flag == (PyArray_FLAGS_(ptr) & flag))
class dtype : public object {
public:
@@ -258,38 +260,62 @@
forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
};
- array(const pybind11::dtype& dt, const std::vector<size_t>& shape,
- const std::vector<size_t>& strides, const void *ptr = nullptr) {
+ array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
+ const std::vector<size_t> &strides, const void *ptr = nullptr,
+ handle base = handle()) {
auto& api = detail::npy_api::get();
auto ndim = shape.size();
if (shape.size() != strides.size())
pybind11_fail("NumPy: shape ndim doesn't match strides ndim");
auto descr = dt;
+
+ int flags = 0;
+ if (base && ptr) {
+ array base_array(base, true);
+ if (base_array.check())
+ /* Copy flags from base (except ownership bit) */
+ flags = base_array.flags() & ~detail::npy_api::NPY_ARRAY_OWNDATA_;
+ else
+ /* Writable by default, easy to downgrade later on if needed */
+ flags = detail::npy_api::NPY_ARRAY_WRITEABLE_;
+ }
+
object tmp(api.PyArray_NewFromDescr_(
api.PyArray_Type_, descr.release().ptr(), (int) ndim, (Py_intptr_t *) shape.data(),
- (Py_intptr_t *) strides.data(), const_cast<void *>(ptr), 0, nullptr), false);
+ (Py_intptr_t *) strides.data(), const_cast<void *>(ptr), flags, nullptr), false);
if (!tmp)
pybind11_fail("NumPy: unable to create array!");
- if (ptr)
- tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false);
+ if (ptr) {
+ if (base) {
+ PyArray_GET_(tmp.ptr(), base) = base.inc_ref().ptr();
+ } else {
+ tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false);
+ }
+ }
m_ptr = tmp.release().ptr();
}
- array(const pybind11::dtype& dt, const std::vector<size_t>& shape, const void *ptr = nullptr)
- : array(dt, shape, default_strides(shape, dt.itemsize()), ptr) { }
+ array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
+ const void *ptr = nullptr, handle base = handle())
+ : array(dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { }
- array(const pybind11::dtype& dt, size_t count, const void *ptr = nullptr)
- : array(dt, std::vector<size_t> { count }, ptr) { }
+ array(const pybind11::dtype &dt, size_t count, const void *ptr = nullptr,
+ handle base = handle())
+ : array(dt, std::vector<size_t>{ count }, ptr, base) { }
template<typename T> array(const std::vector<size_t>& shape,
- const std::vector<size_t>& strides, const T* ptr)
- : array(pybind11::dtype::of<T>(), shape, strides, (void *) ptr) { }
+ const std::vector<size_t>& strides,
+ const T* ptr, handle base = handle())
+ : array(pybind11::dtype::of<T>(), shape, strides, (void *) ptr, base) { }
- template<typename T> array(const std::vector<size_t>& shape, const T* ptr)
- : array(shape, default_strides(shape, sizeof(T)), ptr) { }
+ template <typename T>
+ array(const std::vector<size_t> &shape, const T *ptr,
+ handle base = handle())
+ : array(shape, default_strides(shape, sizeof(T)), ptr, base) { }
- template<typename T> array(size_t count, const T* ptr)
- : array(std::vector<size_t> { count }, ptr) { }
+ template <typename T>
+ array(size_t count, const T *ptr, handle base = handle())
+ : array(std::vector<size_t>{ count }, ptr, base) { }
array(const buffer_info &info)
: array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
@@ -319,6 +345,11 @@
return (size_t) PyArray_GET_(m_ptr, nd);
}
+ /// Base object
+ object base() const {
+ return object(PyArray_GET_(m_ptr, base), true);
+ }
+
/// Dimensions of the array
const size_t* shape() const {
return reinterpret_cast<const size_t *>(PyArray_GET_(m_ptr, dimensions));
@@ -343,6 +374,11 @@
return strides()[dim];
}
+ /// Return the NumPy array flags
+ int flags() const {
+ return PyArray_FLAGS_(m_ptr);
+ }
+
/// If set, the array is writeable (otherwise the buffer is read-only)
bool writeable() const {
return PyArray_CHKFLAGS_(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_);
@@ -389,6 +425,13 @@
return array(api.PyArray_Squeeze_(m_ptr), false);
}
+ /// Ensure that the argument is a NumPy array
+ static array ensure(object input, int ExtraFlags = 0) {
+ auto& api = detail::npy_api::get();
+ return array(api.PyArray_FromAny_(
+ input.release().ptr(), nullptr, 0, 0, detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr), false);
+ }
+
protected:
template<typename, typename> friend struct detail::npy_format_descriptor;
@@ -430,20 +473,23 @@
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
public:
- PYBIND11_OBJECT_CVT(array_t, array, is_non_null, m_ptr = ensure(m_ptr));
+ PYBIND11_OBJECT_CVT(array_t, array, is_non_null, m_ptr = ensure_(m_ptr));
array_t() : array() { }
array_t(const buffer_info& info) : array(info) { }
- array_t(const std::vector<size_t>& shape, const std::vector<size_t>& strides, const T* ptr = nullptr)
- : array(shape, strides, ptr) { }
+ array_t(const std::vector<size_t> &shape,
+ const std::vector<size_t> &strides, const T *ptr = nullptr,
+ handle base = handle())
+ : array(shape, strides, ptr, base) { }
- array_t(const std::vector<size_t>& shape, const T* ptr = nullptr)
- : array(shape, ptr) { }
+ array_t(const std::vector<size_t> &shape, const T *ptr = nullptr,
+ handle base = handle())
+ : array(shape, ptr, base) { }
- array_t(size_t count, const T* ptr = nullptr)
- : array(count, ptr) { }
+ array_t(size_t count, const T *ptr = nullptr, handle base = handle())
+ : array(count, ptr, base) { }
constexpr size_t itemsize() const {
return sizeof(T);
@@ -479,7 +525,7 @@
static bool is_non_null(PyObject *ptr) { return ptr != nullptr; }
- static PyObject *ensure(PyObject *ptr) {
+ static PyObject *ensure_(PyObject *ptr) {
if (ptr == nullptr)
return nullptr;
auto& api = detail::npy_api::get();
diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h
index 9cc57c1..f19435d 100644
--- a/include/pybind11/pybind11.h
+++ b/include/pybind11/pybind11.h
@@ -567,7 +567,7 @@
static module import(const char *name) {
PyObject *obj = PyImport_ImportModule(name);
if (!obj)
- pybind11_fail("Module \"" + std::string(name) + "\" not found!");
+ throw import_error("Module \"" + std::string(name) + "\" not found!");
return module(obj, false);
}
};
@@ -1398,15 +1398,27 @@
auto sep = kwargs.contains("sep") ? kwargs["sep"] : cast(" ");
auto line = sep.attr("join")(strings);
- auto file = kwargs.contains("file") ? kwargs["file"].cast<object>()
- : module::import("sys").attr("stdout");
+ object file;
+ if (kwargs.contains("file")) {
+ file = kwargs["file"].cast<object>();
+ } else {
+ try {
+ file = module::import("sys").attr("stdout");
+ } catch (const import_error &) {
+ /* If print() is called from code that is executed as
+ part of garbage collection during interpreter shutdown,
+ importing 'sys' can fail. Give up rather than crashing the
+ interpreter in this case. */
+ return;
+ }
+ }
+
auto write = file.attr("write");
write(line);
write(kwargs.contains("end") ? kwargs["end"] : cast("\n"));
- if (kwargs.contains("flush") && kwargs["flush"].cast<bool>()) {
+ if (kwargs.contains("flush") && kwargs["flush"].cast<bool>())
file.attr("flush")();
- }
}
NAMESPACE_END(detail)
diff --git a/tests/test_numpy_array.cpp b/tests/test_numpy_array.cpp
index 0614f57..ec4ddac 100644
--- a/tests/test_numpy_array.cpp
+++ b/tests/test_numpy_array.cpp
@@ -91,4 +91,37 @@
def_index_fn(mutate_data_t, arr_t&);
def_index_fn(at_t, const arr_t&);
def_index_fn(mutate_at_t, arr_t&);
+
+ sm.def("make_f_array", [] {
+ return py::array_t<float>({ 2, 2 }, { 4, 8 });
+ });
+
+ sm.def("make_c_array", [] {
+ return py::array_t<float>({ 2, 2 }, { 8, 4 });
+ });
+
+ sm.def("wrap", [](py::array a) {
+ return py::array(
+ a.dtype(),
+ std::vector<size_t>(a.shape(), a.shape() + a.ndim()),
+ std::vector<size_t>(a.strides(), a.strides() + a.ndim()),
+ a.data(),
+ a
+ );
+ });
+
+ struct ArrayClass {
+ int data[2] = { 1, 2 };
+ ArrayClass() { py::print("ArrayClass()"); }
+ ~ArrayClass() { py::print("~ArrayClass()"); }
+ };
+
+ py::class_<ArrayClass>(sm, "ArrayClass")
+ .def(py::init<>())
+ .def("numpy_view", [](py::object &obj) {
+ py::print("ArrayClass::numpy_view()");
+ ArrayClass &a = obj.cast<ArrayClass&>();
+ return py::array_t<int>({2}, {4}, a.data, obj);
+ }
+ );
});
diff --git a/tests/test_numpy_array.py b/tests/test_numpy_array.py
index 4a6af5e..ae1954a 100644
--- a/tests/test_numpy_array.py
+++ b/tests/test_numpy_array.py
@@ -1,4 +1,5 @@
import pytest
+import gc
with pytest.suppress(ImportError):
import numpy as np
@@ -148,3 +149,92 @@
with pytest.raises(IndexError) as excinfo:
index_at(arr, 0, 4)
assert str(excinfo.value) == 'index 4 is out of bounds for axis 1 with size 3'
+
+
+@pytest.requires_numpy
+def test_make_c_f_array():
+ from pybind11_tests.array import (
+ make_c_array, make_f_array
+ )
+ assert make_c_array().flags.c_contiguous
+ assert not make_c_array().flags.f_contiguous
+ assert make_f_array().flags.f_contiguous
+ assert not make_f_array().flags.c_contiguous
+
+
+@pytest.requires_numpy
+def test_wrap():
+ from pybind11_tests.array import wrap
+
+ def assert_references(A, B):
+ assert A is not B
+ assert A.__array_interface__['data'][0] == \
+ B.__array_interface__['data'][0]
+ assert A.shape == B.shape
+ assert A.strides == B.strides
+ assert A.flags.c_contiguous == B.flags.c_contiguous
+ assert A.flags.f_contiguous == B.flags.f_contiguous
+ assert A.flags.writeable == B.flags.writeable
+ assert A.flags.aligned == B.flags.aligned
+ assert A.flags.updateifcopy == B.flags.updateifcopy
+ assert np.all(A == B)
+ assert not B.flags.owndata
+ assert B.base is A
+ if A.flags.writeable and A.ndim == 2:
+ A[0, 0] = 1234
+ assert B[0, 0] == 1234
+
+ A1 = np.array([1, 2], dtype=np.int16)
+ assert A1.flags.owndata and A1.base is None
+ A2 = wrap(A1)
+ assert_references(A1, A2)
+
+ A1 = np.array([[1, 2], [3, 4]], dtype=np.float32, order='F')
+ assert A1.flags.owndata and A1.base is None
+ A2 = wrap(A1)
+ assert_references(A1, A2)
+
+ A1 = np.array([[1, 2], [3, 4]], dtype=np.float32, order='C')
+ A1.flags.writeable = False
+ A2 = wrap(A1)
+ assert_references(A1, A2)
+
+ A1 = np.random.random((4, 4, 4))
+ A2 = wrap(A1)
+ assert_references(A1, A2)
+
+ A1 = A1.transpose()
+ A2 = wrap(A1)
+ assert_references(A1, A2)
+
+ A1 = A1.diagonal()
+ A2 = wrap(A1)
+ assert_references(A1, A2)
+
+
+@pytest.requires_numpy
+def test_numpy_view(capture):
+ from pybind11_tests.array import ArrayClass
+ with capture:
+ ac = ArrayClass()
+ ac_view_1 = ac.numpy_view()
+ ac_view_2 = ac.numpy_view()
+ assert np.all(ac_view_1 == np.array([1, 2], dtype=np.int32))
+ del ac
+ gc.collect()
+ assert capture == """
+ ArrayClass()
+ ArrayClass::numpy_view()
+ ArrayClass::numpy_view()
+ """
+ ac_view_1[0] = 4
+ ac_view_1[1] = 3
+ assert ac_view_2[0] == 4
+ assert ac_view_2[1] == 3
+ with capture:
+ del ac_view_1
+ del ac_view_2
+ gc.collect()
+ assert capture == """
+ ~ArrayClass()
+ """