Strip padding fields in dtypes, update the tests
diff --git a/example/example20.cpp b/example/example20.cpp
index 2c24e6d..32b50e3 100644
--- a/example/example20.cpp
+++ b/example/example20.cpp
@@ -44,6 +44,19 @@
return os << "n:a=" << v.a << ";b=" << v.b;
}
+struct PartialStruct {
+ bool x;
+ uint32_t y;
+ float z;
+ long dummy2;
+};
+
+struct PartialNestedStruct {
+ long dummy1;
+ PartialStruct a;
+ long dummy2;
+};
+
struct UnboundStruct { };
template <typename T>
@@ -54,7 +67,7 @@
}
template <typename S>
-py::array_t<S> create_recarray(size_t n) {
+py::array_t<S, 0> create_recarray(size_t n) {
auto arr = mkarray_via_buffer<S>(n);
auto ptr = static_cast<S*>(arr.request().ptr);
for (size_t i = 0; i < n; i++) {
@@ -67,7 +80,7 @@
return py::format_descriptor<UnboundStruct>::format();
}
-py::array_t<NestedStruct> create_nested(size_t n) {
+py::array_t<NestedStruct, 0> create_nested(size_t n) {
auto arr = mkarray_via_buffer<NestedStruct>(n);
auto ptr = static_cast<NestedStruct*>(arr.request().ptr);
for (size_t i = 0; i < n; i++) {
@@ -77,8 +90,17 @@
return arr;
}
+py::array_t<PartialNestedStruct, 0> create_partial_nested(size_t n) {
+ auto arr = mkarray_via_buffer<PartialNestedStruct>(n);
+ auto ptr = static_cast<PartialNestedStruct*>(arr.request().ptr);
+ for (size_t i = 0; i < n; i++) {
+ ptr[i].a.x = i % 2; ptr[i].a.y = (uint32_t) i; ptr[i].a.z = (float) i * 1.5f;
+ }
+ return arr;
+}
+
template <typename S>
-void print_recarray(py::array_t<S> arr) {
+void print_recarray(py::array_t<S, 0> arr) {
auto buf = arr.request();
auto ptr = static_cast<S*>(buf.ptr);
for (size_t i = 0; i < buf.size; i++)
@@ -89,6 +111,8 @@
std::cout << py::format_descriptor<SimpleStruct>::format() << std::endl;
std::cout << py::format_descriptor<PackedStruct>::format() << std::endl;
std::cout << py::format_descriptor<NestedStruct>::format() << std::endl;
+ std::cout << py::format_descriptor<PartialStruct>::format() << std::endl;
+ std::cout << py::format_descriptor<PartialNestedStruct>::format() << std::endl;
}
void print_dtypes() {
@@ -98,16 +122,22 @@
std::cout << to_str(py::dtype_of<SimpleStruct>()) << std::endl;
std::cout << to_str(py::dtype_of<PackedStruct>()) << std::endl;
std::cout << to_str(py::dtype_of<NestedStruct>()) << std::endl;
+ std::cout << to_str(py::dtype_of<PartialStruct>()) << std::endl;
+ std::cout << to_str(py::dtype_of<PartialNestedStruct>()) << std::endl;
}
void init_ex20(py::module &m) {
PYBIND11_NUMPY_DTYPE(SimpleStruct, x, y, z);
PYBIND11_NUMPY_DTYPE(PackedStruct, x, y, z);
PYBIND11_NUMPY_DTYPE(NestedStruct, a, b);
+ PYBIND11_NUMPY_DTYPE(PartialStruct, x, y, z);
+ PYBIND11_NUMPY_DTYPE(PartialNestedStruct, a);
m.def("create_rec_simple", &create_recarray<SimpleStruct>);
m.def("create_rec_packed", &create_recarray<PackedStruct>);
m.def("create_rec_nested", &create_nested);
+ m.def("create_rec_partial", &create_recarray<PartialStruct>);
+ m.def("create_rec_partial_nested", &create_partial_nested);
m.def("print_format_descriptors", &print_format_descriptors);
m.def("print_rec_simple", &print_recarray<SimpleStruct>);
m.def("print_rec_packed", &print_recarray<PackedStruct>);
diff --git a/example/example20.py b/example/example20.py
index e0a0018..85ea9ae 100644
--- a/example/example20.py
+++ b/example/example20.py
@@ -5,7 +5,8 @@
import numpy as np
from example import (
create_rec_simple, create_rec_packed, create_rec_nested, print_format_descriptors,
- print_rec_simple, print_rec_packed, print_rec_nested, print_dtypes, get_format_unbound
+ print_rec_simple, print_rec_packed, print_rec_nested, print_dtypes, get_format_unbound,
+ create_rec_partial, create_rec_partial_nested
)
@@ -23,6 +24,8 @@
'offsets': [0, 4, 8]})
packed_dtype = np.dtype([('x', '?'), ('y', 'u4'), ('z', 'f4')])
+elements = [(False, 0, 0.0), (True, 1, 1.5), (False, 2, 3.0)]
+
for func, dtype in [(create_rec_simple, simple_dtype), (create_rec_packed, packed_dtype)]:
arr = func(0)
assert arr.dtype == dtype
@@ -31,14 +34,30 @@
arr = func(3)
assert arr.dtype == dtype
- check_eq(arr, [(False, 0, 0.0), (True, 1, 1.5), (False, 2, 3.0)], simple_dtype)
- check_eq(arr, [(False, 0, 0.0), (True, 1, 1.5), (False, 2, 3.0)], packed_dtype)
+ check_eq(arr, elements, simple_dtype)
+ check_eq(arr, elements, packed_dtype)
if dtype == simple_dtype:
print_rec_simple(arr)
else:
print_rec_packed(arr)
+
+arr = create_rec_partial(3)
+print(arr.dtype)
+partial_dtype = arr.dtype
+assert '' not in arr.dtype.fields
+assert partial_dtype.itemsize > simple_dtype.itemsize
+check_eq(arr, elements, simple_dtype)
+check_eq(arr, elements, packed_dtype)
+
+arr = create_rec_partial_nested(3)
+print(arr.dtype)
+assert '' not in arr.dtype.fields
+assert '' not in arr.dtype.fields['a'][0].fields
+assert arr.dtype.itemsize > partial_dtype.itemsize
+np.testing.assert_equal(arr['a'], create_rec_partial(3))
+
nested_dtype = np.dtype([('a', simple_dtype), ('b', packed_dtype)])
arr = create_rec_nested(0)
diff --git a/example/example20.ref b/example/example20.ref
index 32e2b4b..72a6c18 100644
--- a/example/example20.ref
+++ b/example/example20.ref
@@ -1,15 +1,21 @@
-T{?:x:xxxI:y:f:z:}
-T{?:x:=I:y:f:z:}
-T{T{?:x:xxxI:y:f:z:}:a:T{?:x:=I:y:f:z:}:b:}
+T{=?:x:3x=I:y:=f:z:}
+T{=?:x:=I:y:=f:z:}
+T{=T{=?:x:3x=I:y:=f:z:}:a:=T{=?:x:=I:y:=f:z:}:b:}
+T{=?:x:3x=I:y:=f:z:12x}
+T{8x=T{=?:x:3x=I:y:=f:z:12x}:a:8x}
{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':12}
[('x', '?'), ('y', '<u4'), ('z', '<f4')]
[('a', {'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':12}), ('b', [('x', '?'), ('y', '<u4'), ('z', '<f4')])]
+{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':24}
+{'names':['a'], 'formats':[{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':24}], 'offsets':[8], 'itemsize':40}
s:0,0,0
s:1,1,1.5
s:0,2,3
p:0,0,0
p:1,1,1.5
p:0,2,3
+{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':24}
+{'names':['a'], 'formats':[{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':24}], 'offsets':[8], 'itemsize':40}
n:a=s:0,0,0;b=p:1,1,1.5
n:a=s:1,1,1.5;b=p:0,2,3
n:a=s:0,2,3;b=p:1,3,4.5
\ No newline at end of file
diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h
index 423403d..b8827c3 100644
--- a/include/pybind11/numpy.h
+++ b/include/pybind11/numpy.h
@@ -15,6 +15,7 @@
#include <algorithm>
#include <cstdlib>
#include <cstring>
+#include <sstream>
#include <initializer_list>
#if defined(_MSC_VER)
@@ -26,6 +27,8 @@
namespace detail {
template <typename type, typename SFINAE = void> struct npy_format_descriptor { };
+object fix_dtype(object);
+
template <typename T>
struct is_pod_struct {
enum { value = std::is_pod<T>::value && // offsetof only works correctly for POD types
@@ -47,7 +50,9 @@
API_PyArray_FromAny = 69,
API_PyArray_NewCopy = 85,
API_PyArray_NewFromDescr = 94,
+ API_PyArray_DescrNewFromType = 9,
API_PyArray_DescrConverter = 174,
+ API_PyArray_EquivTypes = 182,
API_PyArray_GetArrayParamsFromObject = 278,
NPY_C_CONTIGUOUS_ = 0x0001,
@@ -61,7 +66,9 @@
NPY_LONG_, NPY_ULONG_,
NPY_LONGLONG_, NPY_ULONGLONG_,
NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_,
- NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_
+ NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_,
+ NPY_OBJECT_ = 17,
+ NPY_STRING_, NPY_UNICODE_, NPY_VOID_
};
static API lookup() {
@@ -79,7 +86,9 @@
DECL_NPY_API(PyArray_FromAny);
DECL_NPY_API(PyArray_NewCopy);
DECL_NPY_API(PyArray_NewFromDescr);
+ DECL_NPY_API(PyArray_DescrNewFromType);
DECL_NPY_API(PyArray_DescrConverter);
+ DECL_NPY_API(PyArray_EquivTypes);
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
#undef DECL_NPY_API
return api;
@@ -91,10 +100,12 @@
PyObject *(*PyArray_NewFromDescr_)
(PyTypeObject *, PyObject *, int, Py_intptr_t *,
Py_intptr_t *, void *, int, PyObject *);
+ PyObject *(*PyArray_DescrNewFromType_)(int);
PyObject *(*PyArray_NewCopy_)(PyObject *, int);
PyTypeObject *PyArray_Type_;
PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *);
int (*PyArray_DescrConverter_) (PyObject *, PyObject **);
+ bool (*PyArray_EquivTypes_) (PyObject *, PyObject *);
int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *,
Py_ssize_t *, PyObject **, PyObject *);
};
@@ -113,52 +124,83 @@
Py_intptr_t shape = (Py_intptr_t) size;
object tmp = object(api.PyArray_NewFromDescr_(
api.PyArray_Type_, descr, 1, &shape, nullptr, (void *) ptr, 0, nullptr), false);
- if (ptr && tmp)
- tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false);
if (!tmp)
pybind11_fail("NumPy: unable to create array!");
+ if (ptr)
+ tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false);
m_ptr = tmp.release().ptr();
}
array(const buffer_info &info) {
- PyObject *arr = nullptr, *descr = nullptr;
- int ndim = 0;
- Py_ssize_t dims[32];
- API& api = lookup_api();
+ auto& api = lookup_api();
- // Allocate non-zeroed memory if it hasn't been provided by the caller.
- // Normally, we could leave this null for NumPy to allocate memory for us, but
- // since we need a memoryview, the data pointer has to be non-null. NumPy uses
- // malloc if NPY_NEEDS_INIT is not set (in which case it uses calloc); however,
- // we don't have a desriptor yet (only a buffer format string), so we can't
- // access the flags. As long as we're not dealing with object dtypes/fields
- // though, the memory doesn't have to be zeroed so we use malloc.
- auto buf_info = info;
- if (!buf_info.ptr)
- // always allocate at least 1 element, same way as NumPy does it
- buf_info.ptr = std::malloc(std::max(info.size, (size_t) 1) * info.itemsize);
- if (!buf_info.ptr)
- pybind11_fail("NumPy: failed to allocate memory for buffer");
+ // _dtype_from_pep3118 returns dtypes with padding fields in, however the array
+ // constructor seems to then consume them, so we don't need to strip them ourselves
+ auto numpy_internal = module::import("numpy.core._internal");
+ auto dtype_from_fmt = (object) numpy_internal.attr("_dtype_from_pep3118");
+ auto dtype = dtype_from_fmt(pybind11::str(info.format));
+ auto dtype2 = strip_padding_fields(dtype);
- // PyArray_GetArrayParamsFromObject seems to be the only low-level API function
- // that will accept arbitrary buffers (including structured types)
- auto view = memoryview(buf_info);
- auto res = api.PyArray_GetArrayParamsFromObject_(view.ptr(), nullptr, 1, &descr,
- &ndim, dims, &arr, nullptr);
- if (res < 0 || !arr || descr)
- // We expect arr to have a pointer to a newly created array, in which case all
- // other parameters like descr would be set to null, according to the API.
- pybind11_fail("NumPy: unable to convert buffer to an array");
- m_ptr = arr;
+ object tmp(api.PyArray_NewFromDescr_(
+ api.PyArray_Type_, dtype2.release().ptr(), (int) info.ndim, (Py_intptr_t *) &info.shape[0],
+ (Py_intptr_t *) &info.strides[0], info.ptr, 0, nullptr), false);
+ if (!tmp)
+ pybind11_fail("NumPy: unable to create array!");
+ if (info.ptr)
+ tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false);
+ m_ptr = tmp.release().ptr();
+ auto d = (object) this->attr("dtype");
}
-protected:
+// protected:
static API &lookup_api() {
static API api = API::lookup();
return api;
}
template <typename T, typename SFINAE> friend struct detail::npy_format_descriptor;
+
+ static object strip_padding_fields(object dtype) {
+ // Recursively strip all void fields with empty names that are generated for
+ // padding fields (as of NumPy v1.11).
+ auto fields = dtype.attr("fields").cast<object>();
+ if (fields.ptr() == Py_None)
+ return dtype;
+
+ struct field_descr { pybind11::str name; object format; int_ offset; };
+ std::vector<field_descr> field_descriptors;
+
+ auto items = fields.attr("items").cast<object>();
+ for (auto field : items()) {
+ auto spec = object(field, true).cast<tuple>();
+ auto name = spec[0].cast<pybind11::str>();
+ auto format = spec[1].cast<tuple>()[0].cast<object>();
+ auto offset = spec[1].cast<tuple>()[1].cast<int_>();
+ if (!len(name) && (std::string) dtype.attr("kind").cast<pybind11::str>() == "V")
+ continue;
+ field_descriptors.push_back({name, strip_padding_fields(format), offset});
+ }
+
+ std::sort(field_descriptors.begin(), field_descriptors.end(),
+ [](const field_descr& a, const field_descr& b) {
+ return (int) a.offset < (int) b.offset;
+ });
+
+ list names, formats, offsets;
+ for (auto& descr : field_descriptors) {
+ names.append(descr.name);
+ formats.append(descr.format);
+ offsets.append(descr.offset);
+ }
+ auto args = dict();
+ args["names"] = names; args["formats"] = formats; args["offsets"] = offsets;
+ args["itemsize"] = dtype.attr("itemsize").cast<int_>();
+
+ PyObject *descr = nullptr;
+ if (!lookup_api().PyArray_DescrConverter_(args.release().ptr(), &descr) || !descr)
+ pybind11_fail("NumPy: failed to create structured dtype");
+ return object(descr, false);
+ }
};
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
@@ -233,9 +275,12 @@
struct field_descriptor {
const char *name;
size_t offset;
+ size_t size;
+ const char *format;
object descr;
};
+
template <typename T>
struct npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>::type> {
static PYBIND11_DESCR name() { return _("user-defined"); }
@@ -253,7 +298,7 @@
}
static void register_dtype(std::initializer_list<field_descriptor> fields) {
- array::API& api = array::lookup_api();
+ auto& api = array::lookup_api();
auto args = dict();
list names { }, offsets { }, formats { };
for (auto field : fields) {
@@ -263,26 +308,47 @@
offsets.append(int_(field.offset));
formats.append(field.descr);
}
- args["names"] = names;
- args["offsets"] = offsets;
- args["formats"] = formats;
+ args["names"] = names; args["offsets"] = offsets; args["formats"] = formats;
args["itemsize"] = int_(sizeof(T));
// This is essentially the same as calling np.dtype() constructor in Python and passing
// it a dict of the form {'names': ..., 'formats': ..., 'offsets': ...}.
if (!api.PyArray_DescrConverter_(args.release().ptr(), &dtype_()) || !dtype_())
pybind11_fail("NumPy: failed to create structured dtype");
- // Let NumPy figure the buffer format string for us: memoryview(np.empty(0, dtype)).format
- auto np = module::import("numpy");
- auto empty = (object) np.attr("empty");
- if (auto arr = (object) empty(int_(0), dtype())) {
- if (auto view = PyMemoryView_FromObject(arr.ptr())) {
- if (auto info = PyMemoryView_GET_BUFFER(view)) {
- std::strncpy(format_(), info->format, 4096);
- return;
- }
- }
+
+ // There is an existing bug in NumPy (as of v1.11): trailing bytes are
+ // not encoded explicitly into the format string. This will supposedly
+ // get fixed in v1.12; for further details, see these:
+ // - https://github.com/numpy/numpy/issues/7797
+ // - https://github.com/numpy/numpy/pull/7798
+ // Because of this, we won't use numpy's logic to generate buffer format
+ // strings and will just do it ourselves.
+ std::vector<field_descriptor> ordered_fields(fields);
+ std::sort(ordered_fields.begin(), ordered_fields.end(),
+ [](const field_descriptor& a, const field_descriptor &b) {
+ return a.offset < b.offset;
+ });
+ size_t offset = 0;
+ std::ostringstream oss;
+ oss << "T{";
+ for (auto& field : ordered_fields) {
+ if (field.offset > offset)
+ oss << (field.offset - offset) << 'x';
+ // note that '=' is required to cover the case of unaligned fields
+ oss << '=' << field.format << ':' << field.name << ':';
+ offset = field.offset + field.size;
}
- pybind11_fail("NumPy: failed to extract buffer format");
+ if (sizeof(T) > offset)
+ oss << (sizeof(T) - offset) << 'x';
+ oss << '}';
+ std::strncpy(format_(), oss.str().c_str(), 4096);
+
+ // Sanity check: verify that NumPy properly parses our buffer format string
+ auto arr = array(buffer_info(nullptr, sizeof(T), format(), 1, { 0 }, { sizeof(T) }));
+ auto dtype = (object) arr.attr("dtype");
+ auto fixed_dtype = dtype;
+ // auto fixed_dtype = array::strip_padding_fields(object(dtype_(), true));
+ // if (!api.PyArray_EquivTypes_(dtype_(), fixed_dtype.ptr()))
+ // pybind11_fail("NumPy: invalid buffer descriptor!");
}
private:
@@ -293,7 +359,8 @@
// Extract name, offset and format descriptor for a struct field
#define PYBIND11_FIELD_DESCRIPTOR(Type, Field) \
::pybind11::detail::field_descriptor { \
- #Field, offsetof(Type, Field), \
+ #Field, offsetof(Type, Field), sizeof(decltype(static_cast<Type*>(0)->Field)), \
+ ::pybind11::format_descriptor<decltype(static_cast<Type*>(0)->Field)>::format(), \
::pybind11::detail::npy_format_descriptor<decltype(static_cast<Type*>(0)->Field)>::dtype() \
}