Add complex number support to the NumPy bindings (std::complex<float>/<double>, plus bool arrays)
diff --git a/include/pybind/numpy.h b/include/pybind/numpy.h
index f4a4a74..0336794 100644
--- a/include/pybind/numpy.h
+++ b/include/pybind/numpy.h
@@ -17,8 +17,15 @@
NAMESPACE_BEGIN(pybind)
+/// Maps a C++ element type to its NumPy type number through a `value` member.
+/// Supported types are specialized further down in this header (DECL_FMT);
+/// any other type is rejected at compile time with a readable message.
+template <typename type> struct npy_format_descriptor {
+ static_assert(sizeof(type) == 0, "npy_format_descriptor: unsupported NumPy element type");
+};
+
class array : public buffer {
-protected:
+public:
struct API {
enum Entries {
API_PyArray_Type = 2,
@@ -26,10 +28,22 @@
API_PyArray_FromAny = 69,
API_PyArray_NewCopy = 85,
API_PyArray_NewFromDescr = 94,
- API_NPY_C_CONTIGUOUS = 0x0001,
- API_NPY_F_CONTIGUOUS = 0x0002,
- API_NPY_NPY_ARRAY_FORCECAST = 0x0010,
- API_NPY_ENSURE_ARRAY = 0x0040
+ // Array requirement flags for PyArray_FromAny (NumPy NPY_ARRAY_* values).
+ NPY_C_CONTIGUOUS = 0x0001,
+ NPY_F_CONTIGUOUS = 0x0002,
+ NPY_NPY_ARRAY_FORCECAST = 0x0010,
+ NPY_ENSURE_ARRAY = 0x0040,
+ // Type numbers for PyArray_DescrFromType, in NumPy's NPY_TYPES order
+ // (NPY_BOOL = 0 ... NPY_CLONGDOUBLE = 16). They overlap numerically with the
+ // flags above; harmless, since each group goes to a different API function.
+ NPY_BOOL = 0,
+ NPY_BYTE, NPY_UBYTE,
+ NPY_SHORT, NPY_USHORT,
+ NPY_INT, NPY_UINT,
+ NPY_LONG, NPY_ULONG,
+ NPY_LONGLONG, NPY_ULONGLONG,
+ NPY_FLOAT, NPY_DOUBLE, NPY_LONGDOUBLE,
+ NPY_CFLOAT, NPY_CDOUBLE, NPY_CLONGDOUBLE
};
static API lookup() {
@@ -59,13 +69,12 @@
PyTypeObject *PyArray_Type;
PyObject *(*PyArray_FromAny) (PyObject *, PyObject *, int, int, int, PyObject *);
};
-public:
+
PYBIND_OBJECT_DEFAULT(array, buffer, lookup_api().PyArray_Check)
template <typename Type> array(size_t size, const Type *ptr) {
API& api = lookup_api();
- PyObject *descr = api.PyArray_DescrFromType(
- (int) format_descriptor<Type>::value()[0]);
+ PyObject *descr = api.PyArray_DescrFromType(npy_format_descriptor<Type>::value);
if (descr == nullptr)
throw std::runtime_error("NumPy: unsupported buffer format!");
Py_intptr_t shape = (Py_intptr_t) size;
@@ -83,7 +92,17 @@
API& api = lookup_api();
- if (info.format.size() != 1)
- throw std::runtime_error("Unsupported buffer format!");
- PyObject *descr = api.PyArray_DescrFromType(info.format[0]);
+ // Map the buffer-protocol format string to a NumPy type number. The complex
+ // formats "Zd"/"Zf" are two characters long, so they must be recognized
+ // before single-character formats are passed through as type characters.
+ int fmt;
+ if (info.format == "Zd")
+ fmt = API::NPY_CDOUBLE;
+ else if (info.format == "Zf")
+ fmt = API::NPY_CFLOAT;
+ else if (info.format.size() == 1)
+ fmt = (int) info.format[0];
+ else
+ throw std::runtime_error("Unsupported buffer format!");
+ PyObject *descr = api.PyArray_DescrFromType(fmt);
if (descr == nullptr)
throw std::runtime_error("NumPy: unsupported buffer format '" + info.format + "'!");
PyObject *tmp = api.PyArray_NewFromDescr(
@@ -109,12 +123,14 @@
PYBIND_OBJECT_CVT(array_dtype, array, is_non_null, m_ptr = ensure(m_ptr));
array_dtype() : array() { }
static bool is_non_null(PyObject *ptr) { return ptr != nullptr; }
+ // Convert ptr into a C-contiguous NumPy array whose dtype is the one registered
+ // for T via npy_format_descriptor, force-casting the element type if necessary.
static PyObject *ensure(PyObject *ptr) {
API &api = lookup_api();
- PyObject *descr = api.PyArray_DescrFromType(format_descriptor<T>::value()[0]);
+ PyObject *descr = api.PyArray_DescrFromType(npy_format_descriptor<T>::value);
return api.PyArray_FromAny(ptr, descr, 0, 0,
- API::API_NPY_C_CONTIGUOUS | API::API_NPY_ENSURE_ARRAY |
- API::API_NPY_NPY_ARRAY_FORCECAST, nullptr);
+ API::NPY_C_CONTIGUOUS | API::NPY_ENSURE_ARRAY |
+ API::NPY_NPY_ARRAY_FORCECAST, nullptr);
}
};
@@ -125,8 +139,23 @@
PYBIND_TYPE_CASTER_PYTYPE(array_dtype<int32_t>) PYBIND_TYPE_CASTER_PYTYPE(array_dtype<uint32_t>)
PYBIND_TYPE_CASTER_PYTYPE(array_dtype<int64_t>) PYBIND_TYPE_CASTER_PYTYPE(array_dtype<uint64_t>)
PYBIND_TYPE_CASTER_PYTYPE(array_dtype<float>) PYBIND_TYPE_CASTER_PYTYPE(array_dtype<double>)
+PYBIND_TYPE_CASTER_PYTYPE(array_dtype<std::complex<float>>)
+PYBIND_TYPE_CASTER_PYTYPE(array_dtype<std::complex<double>>)
+PYBIND_TYPE_CASTER_PYTYPE(array_dtype<bool>)
NAMESPACE_END(detail)
+/* npy_format_descriptor specializations: C++ element type -> NumPy type number.
+   64-bit integers map to NPY_(U)LONGLONG because `long long` is 64 bits on all
+   mainstream ABIs, whereas `long` is only 32 bits on LLP64 (64-bit Windows). */
+#define DECL_FMT(t, n) template<> struct npy_format_descriptor<t> { enum { value = array::API::n }; }
+DECL_FMT(int8_t, NPY_BYTE); DECL_FMT(uint8_t, NPY_UBYTE); DECL_FMT(int16_t, NPY_SHORT);
+DECL_FMT(uint16_t, NPY_USHORT); DECL_FMT(int32_t, NPY_INT); DECL_FMT(uint32_t, NPY_UINT);
+DECL_FMT(int64_t, NPY_LONGLONG); DECL_FMT(uint64_t, NPY_ULONGLONG); DECL_FMT(float, NPY_FLOAT);
+DECL_FMT(double, NPY_DOUBLE); DECL_FMT(bool, NPY_BOOL); DECL_FMT(std::complex<float>, NPY_CFLOAT);
+DECL_FMT(std::complex<double>, NPY_CDOUBLE);
+DECL_FMT(long double, NPY_LONGDOUBLE); DECL_FMT(std::complex<long double>, NPY_CLONGDOUBLE);
+#undef DECL_FMT
+
template <typename func_type, typename return_type, typename... args_type, size_t... Index>
std::function<object(array_dtype<args_type>...)>
vectorize(func_type &&f, return_type (*) (args_type ...),
@@ -171,7 +196,7 @@
return cast(result[0]);
/* Return the result */
- return array(buffer_info(result.data(), sizeof(return_type),
+ return array(buffer_info(result.data(), sizeof(return_type),
format_descriptor<return_type>::value(),
ndim, shape, strides));
};