Wenzel Jakob | d4258ba | 2015-07-26 16:33:49 +0200 | [diff] [blame] | 1 | /* |
| 2 | pybind/numpy.h: Basic NumPy support, auto-vectorization support |
| 3 | |
| 4 | Copyright (c) 2015 Wenzel Jakob <wenzel@inf.ethz.ch> |
| 5 | |
| 6 | All rights reserved. Use of this source code is governed by a |
| 7 | BSD-style license that can be found in the LICENSE file. |
| 8 | */ |
| 9 | |
| 10 | #pragma once |
| 11 | |
| 12 | #include <pybind/pybind.h> |
| 13 | #if defined(_MSC_VER) |
| 14 | #pragma warning(push) |
| 15 | #pragma warning(disable: 4127) // warning C4127: Conditional expression is constant |
| 16 | #endif |
| 17 | |
| 18 | NAMESPACE_BEGIN(pybind) |
| 19 | |
/// Maps a C++ element type to a NumPy type number, exposed as the nested
/// enumerator `value`. The primary template is deliberately empty: only the
/// explicit specializations declared later in this header define `value`,
/// so naming `npy_format_descriptor<T>::value` for an unsupported element
/// type is a compile-time error.
template <typename T>
struct npy_format_descriptor {
};
| 21 | |
Wenzel Jakob | d4258ba | 2015-07-26 16:33:49 +0200 | [diff] [blame] | 22 | class array : public buffer { |
Wenzel Jakob | 43398a8 | 2015-07-28 16:12:20 +0200 | [diff] [blame^] | 23 | public: |
Wenzel Jakob | d4258ba | 2015-07-26 16:33:49 +0200 | [diff] [blame] | 24 | struct API { |
| 25 | enum Entries { |
| 26 | API_PyArray_Type = 2, |
| 27 | API_PyArray_DescrFromType = 45, |
| 28 | API_PyArray_FromAny = 69, |
| 29 | API_PyArray_NewCopy = 85, |
| 30 | API_PyArray_NewFromDescr = 94, |
Wenzel Jakob | 43398a8 | 2015-07-28 16:12:20 +0200 | [diff] [blame^] | 31 | NPY_C_CONTIGUOUS = 0x0001, |
| 32 | NPY_F_CONTIGUOUS = 0x0002, |
| 33 | NPY_NPY_ARRAY_FORCECAST = 0x0010, |
| 34 | NPY_ENSURE_ARRAY = 0x0040, |
| 35 | NPY_BOOL=0, |
| 36 | NPY_BYTE, NPY_UBYTE, |
| 37 | NPY_SHORT, NPY_USHORT, |
| 38 | NPY_INT, NPY_UINT, |
| 39 | NPY_LONG, NPY_ULONG, |
| 40 | NPY_LONGLONG, NPY_ULONGLONG, |
| 41 | NPY_FLOAT, NPY_DOUBLE, NPY_LONGDOUBLE, |
| 42 | NPY_CFLOAT, NPY_CDOUBLE, NPY_CLONGDOUBLE |
Wenzel Jakob | d4258ba | 2015-07-26 16:33:49 +0200 | [diff] [blame] | 43 | }; |
| 44 | |
| 45 | static API lookup() { |
| 46 | PyObject *numpy = PyImport_ImportModule("numpy.core.multiarray"); |
| 47 | PyObject *capsule = numpy ? PyObject_GetAttrString(numpy, "_ARRAY_API") : nullptr; |
| 48 | void **api_ptr = (void **) (capsule ? PyCapsule_GetPointer(capsule, NULL) : nullptr); |
| 49 | Py_XDECREF(capsule); |
| 50 | Py_XDECREF(numpy); |
| 51 | if (api_ptr == nullptr) |
| 52 | throw std::runtime_error("Could not acquire pointer to NumPy API!"); |
| 53 | API api; |
| 54 | api.PyArray_Type = (decltype(api.PyArray_Type)) api_ptr[API_PyArray_Type]; |
| 55 | api.PyArray_DescrFromType = (decltype(api.PyArray_DescrFromType)) api_ptr[API_PyArray_DescrFromType]; |
| 56 | api.PyArray_FromAny = (decltype(api.PyArray_FromAny)) api_ptr[API_PyArray_FromAny]; |
| 57 | api.PyArray_NewCopy = (decltype(api.PyArray_NewCopy)) api_ptr[API_PyArray_NewCopy]; |
| 58 | api.PyArray_NewFromDescr = (decltype(api.PyArray_NewFromDescr)) api_ptr[API_PyArray_NewFromDescr]; |
| 59 | return api; |
| 60 | } |
| 61 | |
| 62 | bool PyArray_Check(PyObject *obj) const { return (bool) PyObject_TypeCheck(obj, PyArray_Type); } |
| 63 | |
| 64 | PyObject *(*PyArray_DescrFromType)(int); |
| 65 | PyObject *(*PyArray_NewFromDescr) |
| 66 | (PyTypeObject *, PyObject *, int, Py_intptr_t *, |
| 67 | Py_intptr_t *, void *, int, PyObject *); |
| 68 | PyObject *(*PyArray_NewCopy)(PyObject *, int); |
| 69 | PyTypeObject *PyArray_Type; |
| 70 | PyObject *(*PyArray_FromAny) (PyObject *, PyObject *, int, int, int, PyObject *); |
| 71 | }; |
Wenzel Jakob | 43398a8 | 2015-07-28 16:12:20 +0200 | [diff] [blame^] | 72 | |
Wenzel Jakob | d4258ba | 2015-07-26 16:33:49 +0200 | [diff] [blame] | 73 | PYBIND_OBJECT_DEFAULT(array, buffer, lookup_api().PyArray_Check) |
| 74 | |
| 75 | template <typename Type> array(size_t size, const Type *ptr) { |
| 76 | API& api = lookup_api(); |
Wenzel Jakob | 43398a8 | 2015-07-28 16:12:20 +0200 | [diff] [blame^] | 77 | PyObject *descr = api.PyArray_DescrFromType(npy_format_descriptor<Type>::value); |
Wenzel Jakob | d4258ba | 2015-07-26 16:33:49 +0200 | [diff] [blame] | 78 | if (descr == nullptr) |
| 79 | throw std::runtime_error("NumPy: unsupported buffer format!"); |
| 80 | Py_intptr_t shape = (Py_intptr_t) size; |
| 81 | PyObject *tmp = api.PyArray_NewFromDescr( |
| 82 | api.PyArray_Type, descr, 1, &shape, nullptr, (void *) ptr, 0, nullptr); |
| 83 | if (tmp == nullptr) |
| 84 | throw std::runtime_error("NumPy: unable to create array!"); |
| 85 | m_ptr = api.PyArray_NewCopy(tmp, -1 /* any order */); |
| 86 | Py_DECREF(tmp); |
| 87 | if (m_ptr == nullptr) |
| 88 | throw std::runtime_error("NumPy: unable to copy array!"); |
| 89 | } |
| 90 | |
| 91 | array(const buffer_info &info) { |
| 92 | API& api = lookup_api(); |
| 93 | if (info.format.size() != 1) |
| 94 | throw std::runtime_error("Unsupported buffer format!"); |
Wenzel Jakob | 43398a8 | 2015-07-28 16:12:20 +0200 | [diff] [blame^] | 95 | int fmt = (int) info.format[0]; |
| 96 | if (info.format == "Zd") |
| 97 | fmt = API::NPY_CDOUBLE; |
| 98 | else if (info.format == "Zf") |
| 99 | fmt = API::NPY_CFLOAT; |
| 100 | PyObject *descr = api.PyArray_DescrFromType(fmt); |
Wenzel Jakob | d4258ba | 2015-07-26 16:33:49 +0200 | [diff] [blame] | 101 | if (descr == nullptr) |
| 102 | throw std::runtime_error("NumPy: unsupported buffer format '" + info.format + "'!"); |
| 103 | PyObject *tmp = api.PyArray_NewFromDescr( |
| 104 | api.PyArray_Type, descr, info.ndim, (Py_intptr_t *) &info.shape[0], |
| 105 | (Py_intptr_t *) &info.strides[0], info.ptr, 0, nullptr); |
| 106 | if (tmp == nullptr) |
| 107 | throw std::runtime_error("NumPy: unable to create array!"); |
| 108 | m_ptr = api.PyArray_NewCopy(tmp, -1 /* any order */); |
| 109 | Py_DECREF(tmp); |
| 110 | if (m_ptr == nullptr) |
| 111 | throw std::runtime_error("NumPy: unable to copy array!"); |
| 112 | } |
| 113 | |
| 114 | protected: |
| 115 | static API &lookup_api() { |
| 116 | static API api = API::lookup(); |
| 117 | return api; |
| 118 | } |
| 119 | }; |
| 120 | |
| 121 | template <typename T> class array_dtype : public array { |
| 122 | public: |
| 123 | PYBIND_OBJECT_CVT(array_dtype, array, is_non_null, m_ptr = ensure(m_ptr)); |
| 124 | array_dtype() : array() { } |
| 125 | static bool is_non_null(PyObject *ptr) { return ptr != nullptr; } |
Wenzel Jakob | 43398a8 | 2015-07-28 16:12:20 +0200 | [diff] [blame^] | 126 | PyObject *ensure(PyObject *ptr) { |
Wenzel Jakob | d4258ba | 2015-07-26 16:33:49 +0200 | [diff] [blame] | 127 | API &api = lookup_api(); |
Wenzel Jakob | 43398a8 | 2015-07-28 16:12:20 +0200 | [diff] [blame^] | 128 | PyObject *descr = api.PyArray_DescrFromType(npy_format_descriptor<T>::value); |
Wenzel Jakob | d4258ba | 2015-07-26 16:33:49 +0200 | [diff] [blame] | 129 | return api.PyArray_FromAny(ptr, descr, 0, 0, |
Wenzel Jakob | 43398a8 | 2015-07-28 16:12:20 +0200 | [diff] [blame^] | 130 | API::NPY_C_CONTIGUOUS | API::NPY_ENSURE_ARRAY | |
| 131 | API::NPY_NPY_ARRAY_FORCECAST, nullptr); |
Wenzel Jakob | d4258ba | 2015-07-26 16:33:49 +0200 | [diff] [blame] | 132 | } |
| 133 | }; |
| 134 | |
| 135 | NAMESPACE_BEGIN(detail) |
| 136 | PYBIND_TYPE_CASTER_PYTYPE(array) |
| 137 | PYBIND_TYPE_CASTER_PYTYPE(array_dtype<int8_t>) PYBIND_TYPE_CASTER_PYTYPE(array_dtype<uint8_t>) |
| 138 | PYBIND_TYPE_CASTER_PYTYPE(array_dtype<int16_t>) PYBIND_TYPE_CASTER_PYTYPE(array_dtype<uint16_t>) |
| 139 | PYBIND_TYPE_CASTER_PYTYPE(array_dtype<int32_t>) PYBIND_TYPE_CASTER_PYTYPE(array_dtype<uint32_t>) |
| 140 | PYBIND_TYPE_CASTER_PYTYPE(array_dtype<int64_t>) PYBIND_TYPE_CASTER_PYTYPE(array_dtype<uint64_t>) |
| 141 | PYBIND_TYPE_CASTER_PYTYPE(array_dtype<float>) PYBIND_TYPE_CASTER_PYTYPE(array_dtype<double>) |
Wenzel Jakob | 43398a8 | 2015-07-28 16:12:20 +0200 | [diff] [blame^] | 142 | PYBIND_TYPE_CASTER_PYTYPE(array_dtype<std::complex<float>>) |
| 143 | PYBIND_TYPE_CASTER_PYTYPE(array_dtype<std::complex<double>>) |
| 144 | PYBIND_TYPE_CASTER_PYTYPE(array_dtype<bool>) |
Wenzel Jakob | d4258ba | 2015-07-26 16:33:49 +0200 | [diff] [blame] | 145 | NAMESPACE_END(detail) |
| 146 | |
Wenzel Jakob | 43398a8 | 2015-07-28 16:12:20 +0200 | [diff] [blame^] | 147 | #define DECL_FMT(t, n) template<> struct npy_format_descriptor<t> { enum { value = array::API::n }; } |
| 148 | DECL_FMT(int8_t, NPY_BYTE); DECL_FMT(uint8_t, NPY_UBYTE); DECL_FMT(int16_t, NPY_SHORT); |
| 149 | DECL_FMT(uint16_t, NPY_USHORT); DECL_FMT(int32_t, NPY_INT); DECL_FMT(uint32_t, NPY_UINT); |
| 150 | DECL_FMT(int64_t, NPY_LONGLONG); DECL_FMT(uint64_t, NPY_ULONGLONG); DECL_FMT(float, NPY_FLOAT); |
| 151 | DECL_FMT(double, NPY_DOUBLE); DECL_FMT(bool, NPY_BOOL); DECL_FMT(std::complex<float>, NPY_CFLOAT); |
| 152 | DECL_FMT(std::complex<double>, NPY_CDOUBLE); |
| 153 | #undef DECL_FMT |
| 154 | |
Wenzel Jakob | d4258ba | 2015-07-26 16:33:49 +0200 | [diff] [blame] | 155 | template <typename func_type, typename return_type, typename... args_type, size_t... Index> |
| 156 | std::function<object(array_dtype<args_type>...)> |
| 157 | vectorize(func_type &&f, return_type (*) (args_type ...), |
| 158 | detail::index_sequence<Index...>) { |
| 159 | |
| 160 | return [f](array_dtype<args_type>... args) -> array { |
| 161 | /* Request buffers from all parameters */ |
| 162 | const size_t N = sizeof...(args_type); |
| 163 | std::array<buffer_info, N> buffers {{ args.request()... }}; |
| 164 | |
| 165 | /* Determine dimensions parameters of output array */ |
| 166 | int ndim = 0; size_t count = 0; |
| 167 | std::vector<size_t> shape; |
| 168 | for (size_t i=0; i<N; ++i) { |
| 169 | if (buffers[i].count > count) { |
| 170 | ndim = buffers[i].ndim; |
| 171 | shape = buffers[i].shape; |
| 172 | count = buffers[i].count; |
| 173 | } |
| 174 | } |
| 175 | std::vector<size_t> strides(ndim); |
| 176 | if (ndim > 0) { |
| 177 | strides[ndim-1] = sizeof(return_type); |
| 178 | for (int i=ndim-1; i>0; --i) |
| 179 | strides[i-1] = strides[i] * shape[i]; |
| 180 | } |
| 181 | |
| 182 | /* Check if the parameters are actually compatible */ |
| 183 | for (size_t i=0; i<N; ++i) { |
| 184 | if (buffers[i].count != 1 && (buffers[i].ndim != ndim || buffers[i].shape != shape)) |
| 185 | throw std::runtime_error("pybind::vectorize: incompatible size/dimension of inputs!"); |
| 186 | } |
| 187 | |
| 188 | /* Call the function */ |
| 189 | std::vector<return_type> result(count); |
| 190 | for (size_t i=0; i<count; ++i) |
| 191 | result[i] = f((buffers[Index].count == 1 |
| 192 | ? *((args_type *) buffers[Index].ptr) |
| 193 | : ((args_type *) buffers[Index].ptr)[i])...); |
| 194 | |
| 195 | if (count == 1) |
| 196 | return cast(result[0]); |
| 197 | |
| 198 | /* Return the result */ |
Wenzel Jakob | 43398a8 | 2015-07-28 16:12:20 +0200 | [diff] [blame^] | 199 | return array(buffer_info(result.data(), sizeof(return_type), |
Wenzel Jakob | d4258ba | 2015-07-26 16:33:49 +0200 | [diff] [blame] | 200 | format_descriptor<return_type>::value(), |
| 201 | ndim, shape, strides)); |
| 202 | }; |
| 203 | } |
| 204 | |
| 205 | template <typename func_type, typename return_type, typename... args_type> |
| 206 | std::function<object(array_dtype<args_type>...)> |
| 207 | vectorize(func_type &&f, return_type (*f_) (args_type ...) = nullptr) { |
| 208 | return vectorize(f, f_, typename detail::make_index_sequence<sizeof...(args_type)>::type()); |
| 209 | } |
| 210 | |
| 211 | template <typename return_type, typename... args_type> |
| 212 | std::function<object(array_dtype<args_type>...)> vectorize(return_type (*f) (args_type ...)) { |
| 213 | return vectorize(f, f); |
| 214 | } |
| 215 | |
| 216 | template <typename func> auto vectorize(func &&f) -> decltype( |
| 217 | vectorize(std::forward<func>(f), (typename detail::remove_class<decltype(&std::remove_reference<func>::type::operator())>::type *) nullptr)) { |
| 218 | return vectorize(std::forward<func>(f), (typename detail::remove_class<decltype( |
| 219 | &std::remove_reference<func>::type::operator())>::type *) nullptr); |
| 220 | } |
| 221 | |
| 222 | NAMESPACE_END(pybind) |
| 223 | |
| 224 | #if defined(_MSC_VER) |
| 225 | #pragma warning(pop) |
| 226 | #endif |