numpy.h replace macros with functions (#514)
diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h
index 77006c8..72dd4b3 100644
--- a/include/pybind11/numpy.h
+++ b/include/pybind11/numpy.h
@@ -199,16 +199,28 @@
return api;
}
};
-NAMESPACE_END(detail)
-#define PyArray_GET_(ptr, attr) \
- (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->attr)
-#define PyArrayDescr_GET_(ptr, attr) \
- (reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr)
-#define PyArray_FLAGS_(ptr) \
- PyArray_GET_(ptr, flags)
-#define PyArray_CHKFLAGS_(ptr, flag) \
- (flag == (PyArray_FLAGS_(ptr) & flag))
+inline PyArray_Proxy* array_proxy(void* ptr) {
+ return reinterpret_cast<PyArray_Proxy*>(ptr);
+}
+
+inline const PyArray_Proxy* array_proxy(const void* ptr) {
+ return reinterpret_cast<const PyArray_Proxy*>(ptr);
+}
+
+inline PyArrayDescr_Proxy* array_descriptor_proxy(PyObject* ptr) {
+ return reinterpret_cast<PyArrayDescr_Proxy*>(ptr);
+}
+
+inline const PyArrayDescr_Proxy* array_descriptor_proxy(const PyObject* ptr) {
+ return reinterpret_cast<const PyArrayDescr_Proxy*>(ptr);
+}
+
+inline bool check_flags(const void* ptr, int flag) {
+ return (flag == (array_proxy(ptr)->flags & flag));
+}
+
+NAMESPACE_END(detail)
class dtype : public object {
public:
@@ -249,17 +261,17 @@
/// Size of the data type in bytes.
size_t itemsize() const {
- return (size_t) PyArrayDescr_GET_(m_ptr, elsize);
+ return (size_t) detail::array_descriptor_proxy(m_ptr)->elsize;
}
/// Returns true for structured data types.
bool has_fields() const {
- return PyArrayDescr_GET_(m_ptr, names) != nullptr;
+ return detail::array_descriptor_proxy(m_ptr)->names != nullptr;
}
/// Single-character type code.
char kind() const {
- return PyArrayDescr_GET_(m_ptr, kind);
+ return detail::array_descriptor_proxy(m_ptr)->kind;
}
private:
@@ -341,7 +353,7 @@
pybind11_fail("NumPy: unable to create array!");
if (ptr) {
if (base) {
- PyArray_GET_(tmp.ptr(), base) = base.inc_ref().ptr();
+ detail::array_proxy(tmp.ptr())->base = base.inc_ref().ptr();
} else {
tmp = reinterpret_steal<object>(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */));
}
@@ -376,7 +388,7 @@
/// Array descriptor (dtype)
pybind11::dtype dtype() const {
- return reinterpret_borrow<pybind11::dtype>(PyArray_GET_(m_ptr, descr));
+ return reinterpret_borrow<pybind11::dtype>(detail::array_proxy(m_ptr)->descr);
}
/// Total number of elements
@@ -386,7 +398,7 @@
/// Byte size of a single element
size_t itemsize() const {
- return (size_t) PyArrayDescr_GET_(PyArray_GET_(m_ptr, descr), elsize);
+ return (size_t) detail::array_descriptor_proxy(detail::array_proxy(m_ptr)->descr)->elsize;
}
/// Total number of bytes
@@ -396,17 +408,17 @@
/// Number of dimensions
size_t ndim() const {
- return (size_t) PyArray_GET_(m_ptr, nd);
+ return (size_t) detail::array_proxy(m_ptr)->nd;
}
/// Base object
object base() const {
- return reinterpret_borrow<object>(PyArray_GET_(m_ptr, base));
+ return reinterpret_borrow<object>(detail::array_proxy(m_ptr)->base);
}
/// Dimensions of the array
const size_t* shape() const {
- return reinterpret_cast<const size_t *>(PyArray_GET_(m_ptr, dimensions));
+ return reinterpret_cast<const size_t *>(detail::array_proxy(m_ptr)->dimensions);
}
/// Dimension along a given axis
@@ -418,7 +430,7 @@
/// Strides of the array
const size_t* strides() const {
- return reinterpret_cast<const size_t *>(PyArray_GET_(m_ptr, strides));
+ return reinterpret_cast<const size_t *>(detail::array_proxy(m_ptr)->strides);
}
/// Stride along a given axis
@@ -430,23 +442,23 @@
/// Return the NumPy array flags
int flags() const {
- return PyArray_FLAGS_(m_ptr);
+ return detail::array_proxy(m_ptr)->flags;
}
/// If set, the array is writeable (otherwise the buffer is read-only)
bool writeable() const {
- return PyArray_CHKFLAGS_(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_);
+ return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_);
}
/// If set, the array owns the data (will be freed when the array is deleted)
bool owndata() const {
- return PyArray_CHKFLAGS_(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_);
+ return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_);
}
/// Pointer to the contained data. If index is not provided, points to the
/// beginning of the buffer. May throw if the index would lead to out of bounds access.
template<typename... Ix> const void* data(Ix... index) const {
- return static_cast<const void *>(PyArray_GET_(m_ptr, data) + offset_at(index...));
+ return static_cast<const void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
}
/// Mutable pointer to the contained data. If index is not provided, points to the
@@ -454,7 +466,7 @@
/// May throw if the array is not writeable.
template<typename... Ix> void* mutable_data(Ix... index) {
check_writeable();
- return static_cast<void *>(PyArray_GET_(m_ptr, data) + offset_at(index...));
+ return static_cast<void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
}
/// Byte offset from beginning of the array to a given index (full or partial).
@@ -620,7 +632,7 @@
static bool _check(handle h) {
const auto &api = detail::npy_api::get();
return api.PyArray_Check_(h.ptr())
- && api.PyArray_EquivTypes_(PyArray_GET_(h.ptr(), descr), dtype::of<T>().ptr());
+ && api.PyArray_EquivTypes_(detail::array_proxy(h.ptr())->descr, dtype::of<T>().ptr());
}
protected: