blob: 3c06854f0871cc0572d58ef7d774fe4cfae196b2 [file] [log] [blame]
Wenzel Jakobd4258ba2015-07-26 16:33:49 +02001/*
2 pybind/numpy.h: Basic NumPy support, auto-vectorization support
3
4 Copyright (c) 2015 Wenzel Jakob <wenzel@inf.ethz.ch>
5
6 All rights reserved. Use of this source code is governed by a
7 BSD-style license that can be found in the LICENSE file.
8*/
9
10#pragma once
11
#include <pybind/pybind.h>
#include <functional>
#include <memory>
14
Wenzel Jakobd4258ba2015-07-26 16:33:49 +020015#if defined(_MSC_VER)
16#pragma warning(push)
17#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
18#endif
19
20NAMESPACE_BEGIN(pybind)
21
/// Maps a C++ element type to a NumPy type number through a nested `value`
/// constant. The primary template is intentionally empty: only the explicit
/// specializations declared further down provide `value`.
template <typename T>
struct npy_format_descriptor {
};
23
Wenzel Jakobd4258ba2015-07-26 16:33:49 +020024class array : public buffer {
Wenzel Jakob43398a82015-07-28 16:12:20 +020025public:
Wenzel Jakobd4258ba2015-07-26 16:33:49 +020026 struct API {
27 enum Entries {
28 API_PyArray_Type = 2,
29 API_PyArray_DescrFromType = 45,
30 API_PyArray_FromAny = 69,
31 API_PyArray_NewCopy = 85,
32 API_PyArray_NewFromDescr = 94,
Wenzel Jakob43398a82015-07-28 16:12:20 +020033 NPY_C_CONTIGUOUS = 0x0001,
34 NPY_F_CONTIGUOUS = 0x0002,
35 NPY_NPY_ARRAY_FORCECAST = 0x0010,
36 NPY_ENSURE_ARRAY = 0x0040,
37 NPY_BOOL=0,
38 NPY_BYTE, NPY_UBYTE,
39 NPY_SHORT, NPY_USHORT,
40 NPY_INT, NPY_UINT,
41 NPY_LONG, NPY_ULONG,
42 NPY_LONGLONG, NPY_ULONGLONG,
43 NPY_FLOAT, NPY_DOUBLE, NPY_LONGDOUBLE,
44 NPY_CFLOAT, NPY_CDOUBLE, NPY_CLONGDOUBLE
Wenzel Jakobd4258ba2015-07-26 16:33:49 +020045 };
46
47 static API lookup() {
48 PyObject *numpy = PyImport_ImportModule("numpy.core.multiarray");
49 PyObject *capsule = numpy ? PyObject_GetAttrString(numpy, "_ARRAY_API") : nullptr;
50 void **api_ptr = (void **) (capsule ? PyCapsule_GetPointer(capsule, NULL) : nullptr);
51 Py_XDECREF(capsule);
52 Py_XDECREF(numpy);
53 if (api_ptr == nullptr)
54 throw std::runtime_error("Could not acquire pointer to NumPy API!");
55 API api;
56 api.PyArray_Type = (decltype(api.PyArray_Type)) api_ptr[API_PyArray_Type];
57 api.PyArray_DescrFromType = (decltype(api.PyArray_DescrFromType)) api_ptr[API_PyArray_DescrFromType];
58 api.PyArray_FromAny = (decltype(api.PyArray_FromAny)) api_ptr[API_PyArray_FromAny];
59 api.PyArray_NewCopy = (decltype(api.PyArray_NewCopy)) api_ptr[API_PyArray_NewCopy];
60 api.PyArray_NewFromDescr = (decltype(api.PyArray_NewFromDescr)) api_ptr[API_PyArray_NewFromDescr];
61 return api;
62 }
63
64 bool PyArray_Check(PyObject *obj) const { return (bool) PyObject_TypeCheck(obj, PyArray_Type); }
65
66 PyObject *(*PyArray_DescrFromType)(int);
67 PyObject *(*PyArray_NewFromDescr)
68 (PyTypeObject *, PyObject *, int, Py_intptr_t *,
69 Py_intptr_t *, void *, int, PyObject *);
70 PyObject *(*PyArray_NewCopy)(PyObject *, int);
71 PyTypeObject *PyArray_Type;
72 PyObject *(*PyArray_FromAny) (PyObject *, PyObject *, int, int, int, PyObject *);
73 };
Wenzel Jakob43398a82015-07-28 16:12:20 +020074
Wenzel Jakobd4258ba2015-07-26 16:33:49 +020075 PYBIND_OBJECT_DEFAULT(array, buffer, lookup_api().PyArray_Check)
76
77 template <typename Type> array(size_t size, const Type *ptr) {
78 API& api = lookup_api();
Wenzel Jakob43398a82015-07-28 16:12:20 +020079 PyObject *descr = api.PyArray_DescrFromType(npy_format_descriptor<Type>::value);
Wenzel Jakobd4258ba2015-07-26 16:33:49 +020080 if (descr == nullptr)
81 throw std::runtime_error("NumPy: unsupported buffer format!");
82 Py_intptr_t shape = (Py_intptr_t) size;
83 PyObject *tmp = api.PyArray_NewFromDescr(
84 api.PyArray_Type, descr, 1, &shape, nullptr, (void *) ptr, 0, nullptr);
85 if (tmp == nullptr)
86 throw std::runtime_error("NumPy: unable to create array!");
87 m_ptr = api.PyArray_NewCopy(tmp, -1 /* any order */);
88 Py_DECREF(tmp);
89 if (m_ptr == nullptr)
90 throw std::runtime_error("NumPy: unable to copy array!");
91 }
92
93 array(const buffer_info &info) {
94 API& api = lookup_api();
95 if (info.format.size() != 1)
96 throw std::runtime_error("Unsupported buffer format!");
Wenzel Jakob43398a82015-07-28 16:12:20 +020097 int fmt = (int) info.format[0];
98 if (info.format == "Zd")
99 fmt = API::NPY_CDOUBLE;
100 else if (info.format == "Zf")
101 fmt = API::NPY_CFLOAT;
102 PyObject *descr = api.PyArray_DescrFromType(fmt);
Wenzel Jakobd4258ba2015-07-26 16:33:49 +0200103 if (descr == nullptr)
104 throw std::runtime_error("NumPy: unsupported buffer format '" + info.format + "'!");
105 PyObject *tmp = api.PyArray_NewFromDescr(
106 api.PyArray_Type, descr, info.ndim, (Py_intptr_t *) &info.shape[0],
107 (Py_intptr_t *) &info.strides[0], info.ptr, 0, nullptr);
108 if (tmp == nullptr)
109 throw std::runtime_error("NumPy: unable to create array!");
110 m_ptr = api.PyArray_NewCopy(tmp, -1 /* any order */);
111 Py_DECREF(tmp);
112 if (m_ptr == nullptr)
113 throw std::runtime_error("NumPy: unable to copy array!");
114 }
115
116protected:
117 static API &lookup_api() {
118 static API api = API::lookup();
119 return api;
120 }
121};
122
123template <typename T> class array_dtype : public array {
124public:
125 PYBIND_OBJECT_CVT(array_dtype, array, is_non_null, m_ptr = ensure(m_ptr));
126 array_dtype() : array() { }
127 static bool is_non_null(PyObject *ptr) { return ptr != nullptr; }
Wenzel Jakob43398a82015-07-28 16:12:20 +0200128 PyObject *ensure(PyObject *ptr) {
Wenzel Jakobd4258ba2015-07-26 16:33:49 +0200129 API &api = lookup_api();
Wenzel Jakob43398a82015-07-28 16:12:20 +0200130 PyObject *descr = api.PyArray_DescrFromType(npy_format_descriptor<T>::value);
Wenzel Jakobd4258ba2015-07-26 16:33:49 +0200131 return api.PyArray_FromAny(ptr, descr, 0, 0,
Wenzel Jakob43398a82015-07-28 16:12:20 +0200132 API::NPY_C_CONTIGUOUS | API::NPY_ENSURE_ARRAY |
133 API::NPY_NPY_ARRAY_FORCECAST, nullptr);
Wenzel Jakobd4258ba2015-07-26 16:33:49 +0200134 }
135};
136
Wenzel Jakoba576e6a2015-07-29 17:51:54 +0200137#define DECL_FMT(t, n) template<> struct npy_format_descriptor<t> { enum { value = array::API::n }; }
138DECL_FMT(int8_t, NPY_BYTE); DECL_FMT(uint8_t, NPY_UBYTE); DECL_FMT(int16_t, NPY_SHORT);
139DECL_FMT(uint16_t, NPY_USHORT); DECL_FMT(int32_t, NPY_INT); DECL_FMT(uint32_t, NPY_UINT);
140DECL_FMT(int64_t, NPY_LONGLONG); DECL_FMT(uint64_t, NPY_ULONGLONG); DECL_FMT(float, NPY_FLOAT);
141DECL_FMT(double, NPY_DOUBLE); DECL_FMT(bool, NPY_BOOL); DECL_FMT(std::complex<float>, NPY_CFLOAT);
142DECL_FMT(std::complex<double>, NPY_CDOUBLE);
143#undef DECL_FMT
144
145
Wenzel Jakobd4258ba2015-07-26 16:33:49 +0200146NAMESPACE_BEGIN(detail)
147PYBIND_TYPE_CASTER_PYTYPE(array)
148PYBIND_TYPE_CASTER_PYTYPE(array_dtype<int8_t>) PYBIND_TYPE_CASTER_PYTYPE(array_dtype<uint8_t>)
149PYBIND_TYPE_CASTER_PYTYPE(array_dtype<int16_t>) PYBIND_TYPE_CASTER_PYTYPE(array_dtype<uint16_t>)
150PYBIND_TYPE_CASTER_PYTYPE(array_dtype<int32_t>) PYBIND_TYPE_CASTER_PYTYPE(array_dtype<uint32_t>)
151PYBIND_TYPE_CASTER_PYTYPE(array_dtype<int64_t>) PYBIND_TYPE_CASTER_PYTYPE(array_dtype<uint64_t>)
152PYBIND_TYPE_CASTER_PYTYPE(array_dtype<float>) PYBIND_TYPE_CASTER_PYTYPE(array_dtype<double>)
Wenzel Jakob43398a82015-07-28 16:12:20 +0200153PYBIND_TYPE_CASTER_PYTYPE(array_dtype<std::complex<float>>)
154PYBIND_TYPE_CASTER_PYTYPE(array_dtype<std::complex<double>>)
155PYBIND_TYPE_CASTER_PYTYPE(array_dtype<bool>)
Wenzel Jakobd4258ba2015-07-26 16:33:49 +0200156
Wenzel Jakoba576e6a2015-07-29 17:51:54 +0200157template <typename Func, typename Return, typename... Args>
158struct vectorize_helper {
159 typename std::remove_reference<Func>::type f;
Wenzel Jakob43398a82015-07-28 16:12:20 +0200160
Wenzel Jakoba576e6a2015-07-29 17:51:54 +0200161 vectorize_helper(const Func &f) : f(f) { }
Wenzel Jakobd4258ba2015-07-26 16:33:49 +0200162
Wenzel Jakoba576e6a2015-07-29 17:51:54 +0200163 object operator()(array_dtype<Args>... args) {
164 return run(args..., typename make_index_sequence<sizeof...(Args)>::type());
165 }
166
167 template <size_t ... Index> object run(array_dtype<Args>&... args, index_sequence<Index...>) {
Wenzel Jakobd4258ba2015-07-26 16:33:49 +0200168 /* Request buffers from all parameters */
Wenzel Jakoba576e6a2015-07-29 17:51:54 +0200169 const size_t N = sizeof...(Args);
Wenzel Jakobd4258ba2015-07-26 16:33:49 +0200170 std::array<buffer_info, N> buffers {{ args.request()... }};
171
172 /* Determine dimensions parameters of output array */
173 int ndim = 0; size_t count = 0;
174 std::vector<size_t> shape;
175 for (size_t i=0; i<N; ++i) {
176 if (buffers[i].count > count) {
177 ndim = buffers[i].ndim;
178 shape = buffers[i].shape;
179 count = buffers[i].count;
180 }
181 }
182 std::vector<size_t> strides(ndim);
183 if (ndim > 0) {
Wenzel Jakoba576e6a2015-07-29 17:51:54 +0200184 strides[ndim-1] = sizeof(Return);
Wenzel Jakobd4258ba2015-07-26 16:33:49 +0200185 for (int i=ndim-1; i>0; --i)
186 strides[i-1] = strides[i] * shape[i];
187 }
188
189 /* Check if the parameters are actually compatible */
190 for (size_t i=0; i<N; ++i) {
191 if (buffers[i].count != 1 && (buffers[i].ndim != ndim || buffers[i].shape != shape))
192 throw std::runtime_error("pybind::vectorize: incompatible size/dimension of inputs!");
193 }
194
195 /* Call the function */
Wenzel Jakoba576e6a2015-07-29 17:51:54 +0200196 std::vector<Return> result(count);
Wenzel Jakobd4258ba2015-07-26 16:33:49 +0200197 for (size_t i=0; i<count; ++i)
198 result[i] = f((buffers[Index].count == 1
Wenzel Jakoba576e6a2015-07-29 17:51:54 +0200199 ? *((Args *) buffers[Index].ptr)
200 : ((Args *) buffers[Index].ptr)[i])...);
Wenzel Jakobd4258ba2015-07-26 16:33:49 +0200201
202 if (count == 1)
203 return cast(result[0]);
204
205 /* Return the result */
Wenzel Jakoba576e6a2015-07-29 17:51:54 +0200206 return array(buffer_info(result.data(), sizeof(Return),
207 format_descriptor<Return>::value(),
Wenzel Jakobd4258ba2015-07-26 16:33:49 +0200208 ndim, shape, strides));
Wenzel Jakoba576e6a2015-07-29 17:51:54 +0200209 }
210};
211
212NAMESPACE_END(detail)
213
214template <typename Func, typename Return, typename... Args>
215detail::vectorize_helper<Func, Return, Args...> vectorize(const Func &f, Return (*) (Args ...)) {
216 return detail::vectorize_helper<Func, Return, Args...>(f);
Wenzel Jakobd4258ba2015-07-26 16:33:49 +0200217}
218
Wenzel Jakoba576e6a2015-07-29 17:51:54 +0200219template <typename Return, typename... Args>
220detail::vectorize_helper<Return (*) (Args ...), Return, Args...> vectorize(Return (*f) (Args ...)) {
221 return vectorize<Return (*) (Args ...), Return, Args...>(f, f);
Wenzel Jakobd4258ba2015-07-26 16:33:49 +0200222}
223
224template <typename func> auto vectorize(func &&f) -> decltype(
225 vectorize(std::forward<func>(f), (typename detail::remove_class<decltype(&std::remove_reference<func>::type::operator())>::type *) nullptr)) {
226 return vectorize(std::forward<func>(f), (typename detail::remove_class<decltype(
227 &std::remove_reference<func>::type::operator())>::type *) nullptr);
228}
229
230NAMESPACE_END(pybind)
231
232#if defined(_MSC_VER)
233#pragma warning(pop)
234#endif