Add numpy wrappers for char[] and std::array<char>
diff --git a/example/example20.cpp b/example/example20.cpp
index 8b18c05..07849a0 100644
--- a/example/example20.cpp
+++ b/example/example20.cpp
@@ -65,6 +65,19 @@
struct UnboundStruct { };
+struct StringStruct {
+ char a[3];
+ std::array<char, 3> b;
+};
+
+std::ostream& operator<<(std::ostream& os, const StringStruct& v) {
+ os << "a='";
+ for (size_t i = 0; i < 3 && v.a[i]; i++) os << v.a[i];
+ os << "',b='";
+ for (size_t i = 0; i < 3 && v.b[i]; i++) os << v.b[i];
+ return os << "'";
+}
+
template <typename T>
py::array mkarray_via_buffer(size_t n) {
return py::array(py::buffer_info(nullptr, sizeof(T),
@@ -108,6 +121,25 @@
return arr;
}
+py::array_t<StringStruct, 0> create_string_array(bool non_empty) {
+ auto arr = mkarray_via_buffer<StringStruct>(non_empty ? 4 : 0);
+ if (non_empty) {
+ auto req = arr.request();
+ auto ptr = static_cast<StringStruct*>(req.ptr);
+ for (size_t i = 0; i < req.size * req.itemsize; i++)
+ static_cast<char*>(req.ptr)[i] = 0;
+ ptr[1].a[0] = 'a'; ptr[1].b[0] = 'a';
+ ptr[2].a[0] = 'a'; ptr[2].b[0] = 'a';
+ ptr[3].a[0] = 'a'; ptr[3].b[0] = 'a';
+
+ ptr[2].a[1] = 'b'; ptr[2].b[1] = 'b';
+ ptr[3].a[1] = 'b'; ptr[3].b[1] = 'b';
+
+ ptr[3].a[2] = 'c'; ptr[3].b[2] = 'c';
+ }
+ return arr;
+}
+
template <typename S>
void print_recarray(py::array_t<S, 0> arr) {
auto req = arr.request();
@@ -122,6 +154,7 @@
std::cout << py::format_descriptor<NestedStruct>::format() << std::endl;
std::cout << py::format_descriptor<PartialStruct>::format() << std::endl;
std::cout << py::format_descriptor<PartialNestedStruct>::format() << std::endl;
+ std::cout << py::format_descriptor<StringStruct>::format() << std::endl;
}
void print_dtypes() {
@@ -133,6 +166,7 @@
std::cout << to_str(py::dtype_of<NestedStruct>()) << std::endl;
std::cout << to_str(py::dtype_of<PartialStruct>()) << std::endl;
std::cout << to_str(py::dtype_of<PartialNestedStruct>()) << std::endl;
+ std::cout << to_str(py::dtype_of<StringStruct>()) << std::endl;
}
void init_ex20(py::module &m) {
@@ -141,6 +175,7 @@
PYBIND11_NUMPY_DTYPE(NestedStruct, a, b);
PYBIND11_NUMPY_DTYPE(PartialStruct, x, y, z);
PYBIND11_NUMPY_DTYPE(PartialNestedStruct, a);
+ PYBIND11_NUMPY_DTYPE(StringStruct, a, b);
m.def("create_rec_simple", &create_recarray<SimpleStruct>);
m.def("create_rec_packed", &create_recarray<PackedStruct>);
@@ -153,6 +188,8 @@
m.def("print_rec_nested", &print_recarray<NestedStruct>);
m.def("print_dtypes", &print_dtypes);
m.def("get_format_unbound", &get_format_unbound);
+ m.def("create_string_array", &create_string_array);
+ m.def("print_string_array", &print_recarray<StringStruct>);
}
#undef PYBIND11_PACKED
diff --git a/example/example20.py b/example/example20.py
index bb57590..34dfd83 100644
--- a/example/example20.py
+++ b/example/example20.py
@@ -6,7 +6,7 @@
from example import (
create_rec_simple, create_rec_packed, create_rec_nested, print_format_descriptors,
print_rec_simple, print_rec_packed, print_rec_nested, print_dtypes, get_format_unbound,
- create_rec_partial, create_rec_partial_nested
+ create_rec_partial, create_rec_partial_nested, create_string_array, print_string_array
)
@@ -72,3 +72,12 @@
print_rec_nested(arr)
assert create_rec_nested.__doc__.strip().endswith('numpy.ndarray[dtype=NestedStruct]')
+
+arr = create_string_array(True)
+print(arr.dtype)
+print_string_array(arr)
+dtype = arr.dtype
+assert arr['a'].tolist() == [b'', b'a', b'ab', b'abc']
+assert arr['b'].tolist() == [b'', b'a', b'ab', b'abc']
+arr = create_string_array(False)
+assert dtype == arr.dtype
diff --git a/example/example20.ref b/example/example20.ref
index 72a6c18..4f07ce4 100644
--- a/example/example20.ref
+++ b/example/example20.ref
@@ -3,11 +3,13 @@
T{=T{=?:x:3x=I:y:=f:z:}:a:=T{=?:x:=I:y:=f:z:}:b:}
T{=?:x:3x=I:y:=f:z:12x}
T{8x=T{=?:x:3x=I:y:=f:z:12x}:a:8x}
+T{=3s:a:=3s:b:}
{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':12}
[('x', '?'), ('y', '<u4'), ('z', '<f4')]
[('a', {'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':12}), ('b', [('x', '?'), ('y', '<u4'), ('z', '<f4')])]
{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':24}
{'names':['a'], 'formats':[{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':24}], 'offsets':[8], 'itemsize':40}
+[('a', 'S3'), ('b', 'S3')]
s:0,0,0
s:1,1,1.5
s:0,2,3
@@ -18,4 +20,9 @@
{'names':['a'], 'formats':[{'names':['x','y','z'], 'formats':['?','<u4','<f4'], 'offsets':[0,4,8], 'itemsize':24}], 'offsets':[8], 'itemsize':40}
n:a=s:0,0,0;b=p:1,1,1.5
n:a=s:1,1,1.5;b=p:0,2,3
-n:a=s:0,2,3;b=p:1,3,4.5
\ No newline at end of file
+n:a=s:0,2,3;b=p:1,3,4.5
+[('a', 'S3'), ('b', 'S3')]
+a='',b=''
+a='a',b='a'
+a='ab',b='ab'
+a='abc',b='abc'
diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h
index 1dc3de2..20fbfd8 100644
--- a/include/pybind11/numpy.h
+++ b/include/pybind11/numpy.h
@@ -13,6 +13,7 @@
#include "complex.h"
#include <numeric>
#include <algorithm>
+#include <array>
#include <cstdlib>
#include <cstring>
#include <sstream>
@@ -27,10 +28,14 @@
namespace detail {
template <typename type, typename SFINAE = void> struct npy_format_descriptor { };
+template <typename T> struct is_std_array : std::false_type { };
+template <typename T, size_t N> struct is_std_array<std::array<T, N>> : std::true_type { };
template <typename T>
struct is_pod_struct {
enum { value = std::is_pod<T>::value && // offsetof only works correctly for POD types
+ !std::is_array<T>::value &&
+ !is_std_array<T>::value &&
!std::is_integral<T>::value &&
!std::is_same<T, float>::value &&
!std::is_same<T, double>::value &&
@@ -221,9 +226,14 @@
template <typename T>
struct format_descriptor<T, typename std::enable_if<detail::is_pod_struct<T>::value>::type> {
- static const char *format() {
- return detail::npy_format_descriptor<T>::format();
- }
+ static const char *format() { return detail::npy_format_descriptor<T>::format(); }
+};
+
+template <size_t N> struct format_descriptor<char[N]> {
+ static const char *format() { PYBIND11_DESCR s = detail::_<N>() + detail::_("s"); return s.text(); }
+};
+template <size_t N> struct format_descriptor<std::array<char, N>> {
+ static const char *format() { PYBIND11_DESCR s = detail::_<N>() + detail::_("s"); return s.text(); }
};
template <typename T>
@@ -268,6 +278,22 @@
DECL_FMT(std::complex<double>, NPY_CDOUBLE_, "complex128");
#undef DECL_FMT
+#define DECL_CHAR_FMT \
+ static PYBIND11_DESCR name() { return _("S") + _<N>(); } \
+ static object dtype() { \
+ auto& api = array::lookup_api(); \
+ PyObject *descr = nullptr; \
+ PYBIND11_DESCR fmt = _("S") + _<N>(); \
+ pybind11::str py_fmt(fmt.text()); \
+ if (!api.PyArray_DescrConverter_(py_fmt.release().ptr(), &descr) || !descr) \
+ pybind11_fail("NumPy: failed to create string dtype"); \
+ return object(descr, false); \
+ } \
+ static const char *format() { PYBIND11_DESCR s = _<N>() + _("s"); return s.text(); }
+template <size_t N> struct npy_format_descriptor<char[N]> { DECL_CHAR_FMT };
+template <size_t N> struct npy_format_descriptor<std::array<char, N>> { DECL_CHAR_FMT };
+#undef DECL_CHAR_FMT
+
struct field_descriptor {
const char *name;
size_t offset;