blob: 033679470c161d14c1e582204a67d8eb4f6fffc8 [file] [log] [blame]
Wenzel Jakobd4258ba2015-07-26 16:33:49 +02001/*
2 pybind/numpy.h: Basic NumPy support, auto-vectorization support
3
4 Copyright (c) 2015 Wenzel Jakob <wenzel@inf.ethz.ch>
5
6 All rights reserved. Use of this source code is governed by a
7 BSD-style license that can be found in the LICENSE file.
8*/
9
10#pragma once
11
12#include <pybind/pybind.h>
13#if defined(_MSC_VER)
14#pragma warning(push)
15#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
16#endif
17
18NAMESPACE_BEGIN(pybind)
19
/* Maps a C++ element type to its NumPy type number, exposed as the nested
   enumerator `value`. The primary template is deliberately empty: naming
   `npy_format_descriptor<T>::value` for a type without a specialization
   (the DECL_FMT list further below) fails to compile. */
template <typename T> struct npy_format_descriptor {};
21
Wenzel Jakobd4258ba2015-07-26 16:33:49 +020022class array : public buffer {
Wenzel Jakob43398a82015-07-28 16:12:20 +020023public:
Wenzel Jakobd4258ba2015-07-26 16:33:49 +020024 struct API {
25 enum Entries {
26 API_PyArray_Type = 2,
27 API_PyArray_DescrFromType = 45,
28 API_PyArray_FromAny = 69,
29 API_PyArray_NewCopy = 85,
30 API_PyArray_NewFromDescr = 94,
Wenzel Jakob43398a82015-07-28 16:12:20 +020031 NPY_C_CONTIGUOUS = 0x0001,
32 NPY_F_CONTIGUOUS = 0x0002,
33 NPY_NPY_ARRAY_FORCECAST = 0x0010,
34 NPY_ENSURE_ARRAY = 0x0040,
35 NPY_BOOL=0,
36 NPY_BYTE, NPY_UBYTE,
37 NPY_SHORT, NPY_USHORT,
38 NPY_INT, NPY_UINT,
39 NPY_LONG, NPY_ULONG,
40 NPY_LONGLONG, NPY_ULONGLONG,
41 NPY_FLOAT, NPY_DOUBLE, NPY_LONGDOUBLE,
42 NPY_CFLOAT, NPY_CDOUBLE, NPY_CLONGDOUBLE
Wenzel Jakobd4258ba2015-07-26 16:33:49 +020043 };
44
45 static API lookup() {
46 PyObject *numpy = PyImport_ImportModule("numpy.core.multiarray");
47 PyObject *capsule = numpy ? PyObject_GetAttrString(numpy, "_ARRAY_API") : nullptr;
48 void **api_ptr = (void **) (capsule ? PyCapsule_GetPointer(capsule, NULL) : nullptr);
49 Py_XDECREF(capsule);
50 Py_XDECREF(numpy);
51 if (api_ptr == nullptr)
52 throw std::runtime_error("Could not acquire pointer to NumPy API!");
53 API api;
54 api.PyArray_Type = (decltype(api.PyArray_Type)) api_ptr[API_PyArray_Type];
55 api.PyArray_DescrFromType = (decltype(api.PyArray_DescrFromType)) api_ptr[API_PyArray_DescrFromType];
56 api.PyArray_FromAny = (decltype(api.PyArray_FromAny)) api_ptr[API_PyArray_FromAny];
57 api.PyArray_NewCopy = (decltype(api.PyArray_NewCopy)) api_ptr[API_PyArray_NewCopy];
58 api.PyArray_NewFromDescr = (decltype(api.PyArray_NewFromDescr)) api_ptr[API_PyArray_NewFromDescr];
59 return api;
60 }
61
62 bool PyArray_Check(PyObject *obj) const { return (bool) PyObject_TypeCheck(obj, PyArray_Type); }
63
64 PyObject *(*PyArray_DescrFromType)(int);
65 PyObject *(*PyArray_NewFromDescr)
66 (PyTypeObject *, PyObject *, int, Py_intptr_t *,
67 Py_intptr_t *, void *, int, PyObject *);
68 PyObject *(*PyArray_NewCopy)(PyObject *, int);
69 PyTypeObject *PyArray_Type;
70 PyObject *(*PyArray_FromAny) (PyObject *, PyObject *, int, int, int, PyObject *);
71 };
Wenzel Jakob43398a82015-07-28 16:12:20 +020072
Wenzel Jakobd4258ba2015-07-26 16:33:49 +020073 PYBIND_OBJECT_DEFAULT(array, buffer, lookup_api().PyArray_Check)
74
75 template <typename Type> array(size_t size, const Type *ptr) {
76 API& api = lookup_api();
Wenzel Jakob43398a82015-07-28 16:12:20 +020077 PyObject *descr = api.PyArray_DescrFromType(npy_format_descriptor<Type>::value);
Wenzel Jakobd4258ba2015-07-26 16:33:49 +020078 if (descr == nullptr)
79 throw std::runtime_error("NumPy: unsupported buffer format!");
80 Py_intptr_t shape = (Py_intptr_t) size;
81 PyObject *tmp = api.PyArray_NewFromDescr(
82 api.PyArray_Type, descr, 1, &shape, nullptr, (void *) ptr, 0, nullptr);
83 if (tmp == nullptr)
84 throw std::runtime_error("NumPy: unable to create array!");
85 m_ptr = api.PyArray_NewCopy(tmp, -1 /* any order */);
86 Py_DECREF(tmp);
87 if (m_ptr == nullptr)
88 throw std::runtime_error("NumPy: unable to copy array!");
89 }
90
91 array(const buffer_info &info) {
92 API& api = lookup_api();
93 if (info.format.size() != 1)
94 throw std::runtime_error("Unsupported buffer format!");
Wenzel Jakob43398a82015-07-28 16:12:20 +020095 int fmt = (int) info.format[0];
96 if (info.format == "Zd")
97 fmt = API::NPY_CDOUBLE;
98 else if (info.format == "Zf")
99 fmt = API::NPY_CFLOAT;
100 PyObject *descr = api.PyArray_DescrFromType(fmt);
Wenzel Jakobd4258ba2015-07-26 16:33:49 +0200101 if (descr == nullptr)
102 throw std::runtime_error("NumPy: unsupported buffer format '" + info.format + "'!");
103 PyObject *tmp = api.PyArray_NewFromDescr(
104 api.PyArray_Type, descr, info.ndim, (Py_intptr_t *) &info.shape[0],
105 (Py_intptr_t *) &info.strides[0], info.ptr, 0, nullptr);
106 if (tmp == nullptr)
107 throw std::runtime_error("NumPy: unable to create array!");
108 m_ptr = api.PyArray_NewCopy(tmp, -1 /* any order */);
109 Py_DECREF(tmp);
110 if (m_ptr == nullptr)
111 throw std::runtime_error("NumPy: unable to copy array!");
112 }
113
114protected:
115 static API &lookup_api() {
116 static API api = API::lookup();
117 return api;
118 }
119};
120
121template <typename T> class array_dtype : public array {
122public:
123 PYBIND_OBJECT_CVT(array_dtype, array, is_non_null, m_ptr = ensure(m_ptr));
124 array_dtype() : array() { }
125 static bool is_non_null(PyObject *ptr) { return ptr != nullptr; }
Wenzel Jakob43398a82015-07-28 16:12:20 +0200126 PyObject *ensure(PyObject *ptr) {
Wenzel Jakobd4258ba2015-07-26 16:33:49 +0200127 API &api = lookup_api();
Wenzel Jakob43398a82015-07-28 16:12:20 +0200128 PyObject *descr = api.PyArray_DescrFromType(npy_format_descriptor<T>::value);
Wenzel Jakobd4258ba2015-07-26 16:33:49 +0200129 return api.PyArray_FromAny(ptr, descr, 0, 0,
Wenzel Jakob43398a82015-07-28 16:12:20 +0200130 API::NPY_C_CONTIGUOUS | API::NPY_ENSURE_ARRAY |
131 API::NPY_NPY_ARRAY_FORCECAST, nullptr);
Wenzel Jakobd4258ba2015-07-26 16:33:49 +0200132 }
133};
134
135NAMESPACE_BEGIN(detail)
136PYBIND_TYPE_CASTER_PYTYPE(array)
137PYBIND_TYPE_CASTER_PYTYPE(array_dtype<int8_t>) PYBIND_TYPE_CASTER_PYTYPE(array_dtype<uint8_t>)
138PYBIND_TYPE_CASTER_PYTYPE(array_dtype<int16_t>) PYBIND_TYPE_CASTER_PYTYPE(array_dtype<uint16_t>)
139PYBIND_TYPE_CASTER_PYTYPE(array_dtype<int32_t>) PYBIND_TYPE_CASTER_PYTYPE(array_dtype<uint32_t>)
140PYBIND_TYPE_CASTER_PYTYPE(array_dtype<int64_t>) PYBIND_TYPE_CASTER_PYTYPE(array_dtype<uint64_t>)
141PYBIND_TYPE_CASTER_PYTYPE(array_dtype<float>) PYBIND_TYPE_CASTER_PYTYPE(array_dtype<double>)
Wenzel Jakob43398a82015-07-28 16:12:20 +0200142PYBIND_TYPE_CASTER_PYTYPE(array_dtype<std::complex<float>>)
143PYBIND_TYPE_CASTER_PYTYPE(array_dtype<std::complex<double>>)
144PYBIND_TYPE_CASTER_PYTYPE(array_dtype<bool>)
Wenzel Jakobd4258ba2015-07-26 16:33:49 +0200145NAMESPACE_END(detail)
146
Wenzel Jakob43398a82015-07-28 16:12:20 +0200147#define DECL_FMT(t, n) template<> struct npy_format_descriptor<t> { enum { value = array::API::n }; }
148DECL_FMT(int8_t, NPY_BYTE); DECL_FMT(uint8_t, NPY_UBYTE); DECL_FMT(int16_t, NPY_SHORT);
149DECL_FMT(uint16_t, NPY_USHORT); DECL_FMT(int32_t, NPY_INT); DECL_FMT(uint32_t, NPY_UINT);
150DECL_FMT(int64_t, NPY_LONGLONG); DECL_FMT(uint64_t, NPY_ULONGLONG); DECL_FMT(float, NPY_FLOAT);
151DECL_FMT(double, NPY_DOUBLE); DECL_FMT(bool, NPY_BOOL); DECL_FMT(std::complex<float>, NPY_CFLOAT);
152DECL_FMT(std::complex<double>, NPY_CDOUBLE);
153#undef DECL_FMT
154
Wenzel Jakobd4258ba2015-07-26 16:33:49 +0200155template <typename func_type, typename return_type, typename... args_type, size_t... Index>
156 std::function<object(array_dtype<args_type>...)>
157 vectorize(func_type &&f, return_type (*) (args_type ...),
158 detail::index_sequence<Index...>) {
159
160 return [f](array_dtype<args_type>... args) -> array {
161 /* Request buffers from all parameters */
162 const size_t N = sizeof...(args_type);
163 std::array<buffer_info, N> buffers {{ args.request()... }};
164
165 /* Determine dimensions parameters of output array */
166 int ndim = 0; size_t count = 0;
167 std::vector<size_t> shape;
168 for (size_t i=0; i<N; ++i) {
169 if (buffers[i].count > count) {
170 ndim = buffers[i].ndim;
171 shape = buffers[i].shape;
172 count = buffers[i].count;
173 }
174 }
175 std::vector<size_t> strides(ndim);
176 if (ndim > 0) {
177 strides[ndim-1] = sizeof(return_type);
178 for (int i=ndim-1; i>0; --i)
179 strides[i-1] = strides[i] * shape[i];
180 }
181
182 /* Check if the parameters are actually compatible */
183 for (size_t i=0; i<N; ++i) {
184 if (buffers[i].count != 1 && (buffers[i].ndim != ndim || buffers[i].shape != shape))
185 throw std::runtime_error("pybind::vectorize: incompatible size/dimension of inputs!");
186 }
187
188 /* Call the function */
189 std::vector<return_type> result(count);
190 for (size_t i=0; i<count; ++i)
191 result[i] = f((buffers[Index].count == 1
192 ? *((args_type *) buffers[Index].ptr)
193 : ((args_type *) buffers[Index].ptr)[i])...);
194
195 if (count == 1)
196 return cast(result[0]);
197
198 /* Return the result */
Wenzel Jakob43398a82015-07-28 16:12:20 +0200199 return array(buffer_info(result.data(), sizeof(return_type),
Wenzel Jakobd4258ba2015-07-26 16:33:49 +0200200 format_descriptor<return_type>::value(),
201 ndim, shape, strides));
202 };
203}
204
205template <typename func_type, typename return_type, typename... args_type>
206 std::function<object(array_dtype<args_type>...)>
207 vectorize(func_type &&f, return_type (*f_) (args_type ...) = nullptr) {
208 return vectorize(f, f_, typename detail::make_index_sequence<sizeof...(args_type)>::type());
209}
210
211template <typename return_type, typename... args_type>
212std::function<object(array_dtype<args_type>...)> vectorize(return_type (*f) (args_type ...)) {
213 return vectorize(f, f);
214}
215
216template <typename func> auto vectorize(func &&f) -> decltype(
217 vectorize(std::forward<func>(f), (typename detail::remove_class<decltype(&std::remove_reference<func>::type::operator())>::type *) nullptr)) {
218 return vectorize(std::forward<func>(f), (typename detail::remove_class<decltype(
219 &std::remove_reference<func>::type::operator())>::type *) nullptr);
220}
221
222NAMESPACE_END(pybind)
223
224#if defined(_MSC_VER)
225#pragma warning(pop)
226#endif