Move register_dtype() outside of the template
(avoid code bloat if possible)
diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h
index b180cb2..af465a1 100644
--- a/include/pybind11/numpy.h
+++ b/include/pybind11/numpy.h
@@ -81,14 +81,18 @@
struct numpy_internals {
std::unordered_map<std::type_index, numpy_type_info> registered_dtypes;
- template<typename T> numpy_type_info *get_type_info(bool throw_if_missing = true) {
- auto it = registered_dtypes.find(std::type_index(typeid(T)));
+ numpy_type_info *get_type_info(const std::type_info& tinfo, bool throw_if_missing = true) {
+ auto it = registered_dtypes.find(std::type_index(tinfo));
if (it != registered_dtypes.end())
return &(it->second);
if (throw_if_missing)
- pybind11_fail(std::string("NumPy type info missing for ") + typeid(T).name());
+ pybind11_fail(std::string("NumPy type info missing for ") + tinfo.name());
return nullptr;
}
+
+ template<typename T> numpy_type_info *get_type_info(bool throw_if_missing = true) {
+ return get_type_info(typeid(typename std::remove_cv<T>::type), throw_if_missing);
+ }
};
inline PYBIND11_NOINLINE void load_numpy_internals(numpy_internals* &ptr) {
@@ -686,6 +690,62 @@
dtype descr;
};
+inline PYBIND11_NOINLINE void register_structured_dtype(
+ const std::initializer_list<field_descriptor>& fields,
+ const std::type_info& tinfo, size_t itemsize,
+ bool (*direct_converter)(PyObject *, void *&))
+{
+ auto& numpy_internals = get_numpy_internals();
+ if (numpy_internals.get_type_info(tinfo, false))
+ pybind11_fail("NumPy: dtype is already registered");
+
+ list names, formats, offsets;
+ for (auto field : fields) {
+ if (!field.descr)
+ pybind11_fail(std::string("NumPy: unsupported field dtype: `") +
+ field.name + "` @ " + tinfo.name());
+ names.append(PYBIND11_STR_TYPE(field.name));
+ formats.append(field.descr);
+ offsets.append(pybind11::int_(field.offset));
+ }
+ auto dtype_ptr = pybind11::dtype(names, formats, offsets, itemsize).release().ptr();
+
+ // There is an existing bug in NumPy (as of v1.11): trailing bytes are
+ // not encoded explicitly into the format string. This will supposedly
+ // get fixed in v1.12; for further details, see these:
+ // - https://github.com/numpy/numpy/issues/7797
+ // - https://github.com/numpy/numpy/pull/7798
+ // Because of this, we won't use numpy's logic to generate buffer format
+ // strings and will just do it ourselves.
+ std::vector<field_descriptor> ordered_fields(fields);
+ std::sort(ordered_fields.begin(), ordered_fields.end(),
+ [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; });
+ size_t offset = 0;
+ std::ostringstream oss;
+ oss << "T{";
+ for (auto& field : ordered_fields) {
+ if (field.offset > offset)
+ oss << (field.offset - offset) << 'x';
+ // note that '=' is required to cover the case of unaligned fields
+ oss << '=' << field.format << ':' << field.name << ':';
+ offset = field.offset + field.size;
+ }
+ if (itemsize > offset)
+ oss << (itemsize - offset) << 'x';
+ oss << '}';
+ auto format_str = oss.str();
+
+ // Sanity check: verify that NumPy properly parses our buffer format string
+ auto& api = npy_api::get();
+ auto arr = array(buffer_info(nullptr, itemsize, format_str, 1));
+ if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr()))
+ pybind11_fail("NumPy: invalid buffer descriptor!");
+
+ auto tindex = std::type_index(tinfo);
+ numpy_internals.registered_dtypes[tindex] = { dtype_ptr, format_str };
+ get_internals().direct_conversions[tindex].push_back(direct_converter);
+}
+
template <typename T>
struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
static PYBIND11_DESCR name() { return _("struct"); }
@@ -699,56 +759,9 @@
return format_str;
}
- static void register_dtype(std::initializer_list<field_descriptor> fields) {
- auto& numpy_internals = get_numpy_internals();
- if (numpy_internals.get_type_info<T>(false))
- pybind11_fail("NumPy: dtype is already registered");
-
- list names, formats, offsets;
- for (auto field : fields) {
- if (!field.descr)
- pybind11_fail(std::string("NumPy: unsupported field dtype: `") +
- field.name + "` @ " + typeid(T).name());
- names.append(PYBIND11_STR_TYPE(field.name));
- formats.append(field.descr);
- offsets.append(pybind11::int_(field.offset));
- }
- auto dtype_ptr = pybind11::dtype(names, formats, offsets, sizeof(T)).release().ptr();
-
- // There is an existing bug in NumPy (as of v1.11): trailing bytes are
- // not encoded explicitly into the format string. This will supposedly
- // get fixed in v1.12; for further details, see these:
- // - https://github.com/numpy/numpy/issues/7797
- // - https://github.com/numpy/numpy/pull/7798
- // Because of this, we won't use numpy's logic to generate buffer format
- // strings and will just do it ourselves.
- std::vector<field_descriptor> ordered_fields(fields);
- std::sort(ordered_fields.begin(), ordered_fields.end(),
- [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; });
- size_t offset = 0;
- std::ostringstream oss;
- oss << "T{";
- for (auto& field : ordered_fields) {
- if (field.offset > offset)
- oss << (field.offset - offset) << 'x';
- // note that '=' is required to cover the case of unaligned fields
- oss << '=' << field.format << ':' << field.name << ':';
- offset = field.offset + field.size;
- }
- if (sizeof(T) > offset)
- oss << (sizeof(T) - offset) << 'x';
- oss << '}';
- auto format_str = oss.str();
-
- // Sanity check: verify that NumPy properly parses our buffer format string
- auto& api = npy_api::get();
- auto arr = array(buffer_info(nullptr, sizeof(T), format_str, 1));
- if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr()))
- pybind11_fail("NumPy: invalid buffer descriptor!");
-
- auto tindex = std::type_index(typeid(T));
- numpy_internals.registered_dtypes[tindex] = { dtype_ptr, format_str };
- get_internals().direct_conversions[tindex].push_back(direct_converter);
+ static void register_dtype(const std::initializer_list<field_descriptor>& fields) {
+ register_structured_dtype(fields, typeid(typename std::remove_cv<T>::type),
+ sizeof(T), &direct_converter);
}
private: