blob: f4a4a74e783ef9311d4387218ea3cc9b51078c19 [file] [log] [blame]
Wenzel Jakobd4258ba2015-07-26 16:33:49 +02001/*
2 pybind/numpy.h: Basic NumPy support, auto-vectorization support
3
4 Copyright (c) 2015 Wenzel Jakob <wenzel@inf.ethz.ch>
5
6 All rights reserved. Use of this source code is governed by a
7 BSD-style license that can be found in the LICENSE file.
8*/
9
10#pragma once
11
12#include <pybind/pybind.h>
13#if defined(_MSC_VER)
14#pragma warning(push)
15#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
16#endif
17
18NAMESPACE_BEGIN(pybind)
19
20class array : public buffer {
21protected:
22 struct API {
23 enum Entries {
24 API_PyArray_Type = 2,
25 API_PyArray_DescrFromType = 45,
26 API_PyArray_FromAny = 69,
27 API_PyArray_NewCopy = 85,
28 API_PyArray_NewFromDescr = 94,
29 API_NPY_C_CONTIGUOUS = 0x0001,
30 API_NPY_F_CONTIGUOUS = 0x0002,
31 API_NPY_NPY_ARRAY_FORCECAST = 0x0010,
32 API_NPY_ENSURE_ARRAY = 0x0040
33 };
34
35 static API lookup() {
36 PyObject *numpy = PyImport_ImportModule("numpy.core.multiarray");
37 PyObject *capsule = numpy ? PyObject_GetAttrString(numpy, "_ARRAY_API") : nullptr;
38 void **api_ptr = (void **) (capsule ? PyCapsule_GetPointer(capsule, NULL) : nullptr);
39 Py_XDECREF(capsule);
40 Py_XDECREF(numpy);
41 if (api_ptr == nullptr)
42 throw std::runtime_error("Could not acquire pointer to NumPy API!");
43 API api;
44 api.PyArray_Type = (decltype(api.PyArray_Type)) api_ptr[API_PyArray_Type];
45 api.PyArray_DescrFromType = (decltype(api.PyArray_DescrFromType)) api_ptr[API_PyArray_DescrFromType];
46 api.PyArray_FromAny = (decltype(api.PyArray_FromAny)) api_ptr[API_PyArray_FromAny];
47 api.PyArray_NewCopy = (decltype(api.PyArray_NewCopy)) api_ptr[API_PyArray_NewCopy];
48 api.PyArray_NewFromDescr = (decltype(api.PyArray_NewFromDescr)) api_ptr[API_PyArray_NewFromDescr];
49 return api;
50 }
51
52 bool PyArray_Check(PyObject *obj) const { return (bool) PyObject_TypeCheck(obj, PyArray_Type); }
53
54 PyObject *(*PyArray_DescrFromType)(int);
55 PyObject *(*PyArray_NewFromDescr)
56 (PyTypeObject *, PyObject *, int, Py_intptr_t *,
57 Py_intptr_t *, void *, int, PyObject *);
58 PyObject *(*PyArray_NewCopy)(PyObject *, int);
59 PyTypeObject *PyArray_Type;
60 PyObject *(*PyArray_FromAny) (PyObject *, PyObject *, int, int, int, PyObject *);
61 };
62public:
63 PYBIND_OBJECT_DEFAULT(array, buffer, lookup_api().PyArray_Check)
64
65 template <typename Type> array(size_t size, const Type *ptr) {
66 API& api = lookup_api();
67 PyObject *descr = api.PyArray_DescrFromType(
68 (int) format_descriptor<Type>::value()[0]);
69 if (descr == nullptr)
70 throw std::runtime_error("NumPy: unsupported buffer format!");
71 Py_intptr_t shape = (Py_intptr_t) size;
72 PyObject *tmp = api.PyArray_NewFromDescr(
73 api.PyArray_Type, descr, 1, &shape, nullptr, (void *) ptr, 0, nullptr);
74 if (tmp == nullptr)
75 throw std::runtime_error("NumPy: unable to create array!");
76 m_ptr = api.PyArray_NewCopy(tmp, -1 /* any order */);
77 Py_DECREF(tmp);
78 if (m_ptr == nullptr)
79 throw std::runtime_error("NumPy: unable to copy array!");
80 }
81
82 array(const buffer_info &info) {
83 API& api = lookup_api();
84 if (info.format.size() != 1)
85 throw std::runtime_error("Unsupported buffer format!");
86 PyObject *descr = api.PyArray_DescrFromType(info.format[0]);
87 if (descr == nullptr)
88 throw std::runtime_error("NumPy: unsupported buffer format '" + info.format + "'!");
89 PyObject *tmp = api.PyArray_NewFromDescr(
90 api.PyArray_Type, descr, info.ndim, (Py_intptr_t *) &info.shape[0],
91 (Py_intptr_t *) &info.strides[0], info.ptr, 0, nullptr);
92 if (tmp == nullptr)
93 throw std::runtime_error("NumPy: unable to create array!");
94 m_ptr = api.PyArray_NewCopy(tmp, -1 /* any order */);
95 Py_DECREF(tmp);
96 if (m_ptr == nullptr)
97 throw std::runtime_error("NumPy: unable to copy array!");
98 }
99
100protected:
101 static API &lookup_api() {
102 static API api = API::lookup();
103 return api;
104 }
105};
106
107template <typename T> class array_dtype : public array {
108public:
109 PYBIND_OBJECT_CVT(array_dtype, array, is_non_null, m_ptr = ensure(m_ptr));
110 array_dtype() : array() { }
111 static bool is_non_null(PyObject *ptr) { return ptr != nullptr; }
112 static PyObject *ensure(PyObject *ptr) {
113 API &api = lookup_api();
114 PyObject *descr = api.PyArray_DescrFromType(format_descriptor<T>::value()[0]);
115 return api.PyArray_FromAny(ptr, descr, 0, 0,
116 API::API_NPY_C_CONTIGUOUS | API::API_NPY_ENSURE_ARRAY |
117 API::API_NPY_NPY_ARRAY_FORCECAST, nullptr);
118 }
119};
120
121NAMESPACE_BEGIN(detail)
122PYBIND_TYPE_CASTER_PYTYPE(array)
123PYBIND_TYPE_CASTER_PYTYPE(array_dtype<int8_t>) PYBIND_TYPE_CASTER_PYTYPE(array_dtype<uint8_t>)
124PYBIND_TYPE_CASTER_PYTYPE(array_dtype<int16_t>) PYBIND_TYPE_CASTER_PYTYPE(array_dtype<uint16_t>)
125PYBIND_TYPE_CASTER_PYTYPE(array_dtype<int32_t>) PYBIND_TYPE_CASTER_PYTYPE(array_dtype<uint32_t>)
126PYBIND_TYPE_CASTER_PYTYPE(array_dtype<int64_t>) PYBIND_TYPE_CASTER_PYTYPE(array_dtype<uint64_t>)
127PYBIND_TYPE_CASTER_PYTYPE(array_dtype<float>) PYBIND_TYPE_CASTER_PYTYPE(array_dtype<double>)
128NAMESPACE_END(detail)
129
130template <typename func_type, typename return_type, typename... args_type, size_t... Index>
131 std::function<object(array_dtype<args_type>...)>
132 vectorize(func_type &&f, return_type (*) (args_type ...),
133 detail::index_sequence<Index...>) {
134
135 return [f](array_dtype<args_type>... args) -> array {
136 /* Request buffers from all parameters */
137 const size_t N = sizeof...(args_type);
138 std::array<buffer_info, N> buffers {{ args.request()... }};
139
140 /* Determine dimensions parameters of output array */
141 int ndim = 0; size_t count = 0;
142 std::vector<size_t> shape;
143 for (size_t i=0; i<N; ++i) {
144 if (buffers[i].count > count) {
145 ndim = buffers[i].ndim;
146 shape = buffers[i].shape;
147 count = buffers[i].count;
148 }
149 }
150 std::vector<size_t> strides(ndim);
151 if (ndim > 0) {
152 strides[ndim-1] = sizeof(return_type);
153 for (int i=ndim-1; i>0; --i)
154 strides[i-1] = strides[i] * shape[i];
155 }
156
157 /* Check if the parameters are actually compatible */
158 for (size_t i=0; i<N; ++i) {
159 if (buffers[i].count != 1 && (buffers[i].ndim != ndim || buffers[i].shape != shape))
160 throw std::runtime_error("pybind::vectorize: incompatible size/dimension of inputs!");
161 }
162
163 /* Call the function */
164 std::vector<return_type> result(count);
165 for (size_t i=0; i<count; ++i)
166 result[i] = f((buffers[Index].count == 1
167 ? *((args_type *) buffers[Index].ptr)
168 : ((args_type *) buffers[Index].ptr)[i])...);
169
170 if (count == 1)
171 return cast(result[0]);
172
173 /* Return the result */
174 return array(buffer_info(result.data(), sizeof(return_type),
175 format_descriptor<return_type>::value(),
176 ndim, shape, strides));
177 };
178}
179
180template <typename func_type, typename return_type, typename... args_type>
181 std::function<object(array_dtype<args_type>...)>
182 vectorize(func_type &&f, return_type (*f_) (args_type ...) = nullptr) {
183 return vectorize(f, f_, typename detail::make_index_sequence<sizeof...(args_type)>::type());
184}
185
186template <typename return_type, typename... args_type>
187std::function<object(array_dtype<args_type>...)> vectorize(return_type (*f) (args_type ...)) {
188 return vectorize(f, f);
189}
190
191template <typename func> auto vectorize(func &&f) -> decltype(
192 vectorize(std::forward<func>(f), (typename detail::remove_class<decltype(&std::remove_reference<func>::type::operator())>::type *) nullptr)) {
193 return vectorize(std::forward<func>(f), (typename detail::remove_class<decltype(
194 &std::remove_reference<func>::type::operator())>::type *) nullptr);
195}
196
197NAMESPACE_END(pybind)
198
199#if defined(_MSC_VER)
200#pragma warning(pop)
201#endif