Wenzel Jakob | 9e0a056 | 2016-05-05 20:33:54 +0200 | [diff] [blame] | 1 | /* |
| 2 | pybind11/eigen.h: Transparent conversion for dense and sparse Eigen matrices |
| 3 | |
| 4 | Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch> |
| 5 | |
| 6 | All rights reserved. Use of this source code is governed by a |
| 7 | BSD-style license that can be found in the LICENSE file. |
| 8 | */ |
| 9 | |
| 10 | #pragma once |
| 11 | |
| 12 | #include "numpy.h" |
Wenzel Jakob | 0a07805 | 2016-05-29 13:40:40 +0200 | [diff] [blame] | 13 | |
| 14 | #if defined(__GNUG__) || defined(__clang__) |
| 15 | # pragma GCC diagnostic push |
| 16 | # pragma GCC diagnostic ignored "-Wconversion" |
Wenzel Jakob | b569272 | 2016-05-30 11:37:03 +0200 | [diff] [blame] | 17 | # pragma GCC diagnostic ignored "-Wdeprecated-declarations" |
Wenzel Jakob | 0a07805 | 2016-05-29 13:40:40 +0200 | [diff] [blame] | 18 | #endif |
| 19 | |
Wenzel Jakob | 9e0a056 | 2016-05-05 20:33:54 +0200 | [diff] [blame] | 20 | #include <Eigen/Core> |
| 21 | #include <Eigen/SparseCore> |
| 22 | |
Wenzel Jakob | 0a07805 | 2016-05-29 13:40:40 +0200 | [diff] [blame] | 23 | #if defined(__GNUG__) || defined(__clang__) |
| 24 | # pragma GCC diagnostic pop |
| 25 | #endif |
| 26 | |
Wenzel Jakob | 9e0a056 | 2016-05-05 20:33:54 +0200 | [diff] [blame] | 27 | #if defined(_MSC_VER) |
| 28 | #pragma warning(push) |
| 29 | #pragma warning(disable: 4127) // warning C4127: Conditional expression is constant |
| 30 | #endif |
| 31 | |
| 32 | NAMESPACE_BEGIN(pybind11) |
| 33 | NAMESPACE_BEGIN(detail) |
| 34 | |
| 35 | template <typename T> class is_eigen_dense { |
| 36 | private: |
| 37 | template<typename Derived> static std::true_type test(const Eigen::DenseBase<Derived> &); |
| 38 | static std::false_type test(...); |
| 39 | public: |
| 40 | static constexpr bool value = decltype(test(std::declval<T>()))::value; |
| 41 | }; |
| 42 | |
Jason Rhinelander | 8657f30 | 2016-08-04 13:21:39 -0400 | [diff] [blame] | 43 | // Eigen::Ref<Derived> satisfies is_eigen_dense, but isn't constructible, so it needs a special |
| 44 | // type_caster to handle argument copying/forwarding. |
Jason Rhinelander | 5fd5074 | 2016-08-03 16:50:22 -0400 | [diff] [blame] | 45 | template <typename T> class is_eigen_ref { |
| 46 | private: |
| 47 | template<typename Derived> static typename std::enable_if< |
| 48 | std::is_same<typename std::remove_const<T>::type, Eigen::Ref<Derived>>::value, |
| 49 | Derived>::type test(const Eigen::Ref<Derived> &); |
| 50 | static void test(...); |
| 51 | public: |
| 52 | typedef decltype(test(std::declval<T>())) Derived; |
| 53 | static constexpr bool value = !std::is_void<Derived>::value; |
| 54 | }; |
| 55 | |
Wenzel Jakob | 9e0a056 | 2016-05-05 20:33:54 +0200 | [diff] [blame] | 56 | template <typename T> class is_eigen_sparse { |
| 57 | private: |
| 58 | template<typename Derived> static std::true_type test(const Eigen::SparseMatrixBase<Derived> &); |
| 59 | static std::false_type test(...); |
| 60 | public: |
| 61 | static constexpr bool value = decltype(test(std::declval<T>()))::value; |
| 62 | }; |
| 63 | |
Jason Rhinelander | 9ffb3dd | 2016-08-04 15:24:41 -0400 | [diff] [blame] | 64 | // Test for objects inheriting from EigenBase<Derived> that aren't captured by the above. This |
| 65 | // basically covers anything that can be assigned to a dense matrix but that don't have a typical |
| 66 | // matrix data layout that can be copied from their .data(). For example, DiagonalMatrix and |
| 67 | // SelfAdjointView fall into this category. |
| 68 | template <typename T> class is_eigen_base { |
| 69 | private: |
| 70 | template<typename Derived> static std::true_type test(const Eigen::EigenBase<Derived> &); |
| 71 | static std::false_type test(...); |
| 72 | public: |
| 73 | static constexpr bool value = !is_eigen_dense<T>::value && !is_eigen_sparse<T>::value && |
| 74 | decltype(test(std::declval<T>()))::value; |
| 75 | }; |
| 76 | |
Wenzel Jakob | 9e0a056 | 2016-05-05 20:33:54 +0200 | [diff] [blame] | 77 | template<typename Type> |
Jason Rhinelander | 5fd5074 | 2016-08-03 16:50:22 -0400 | [diff] [blame] | 78 | struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value && !is_eigen_ref<Type>::value>::type> { |
Wenzel Jakob | 9e0a056 | 2016-05-05 20:33:54 +0200 | [diff] [blame] | 79 | typedef typename Type::Scalar Scalar; |
| 80 | static constexpr bool rowMajor = Type::Flags & Eigen::RowMajorBit; |
Wenzel Jakob | a970a57 | 2016-05-20 12:00:56 +0200 | [diff] [blame] | 81 | static constexpr bool isVector = Type::IsVectorAtCompileTime; |
Wenzel Jakob | 9e0a056 | 2016-05-05 20:33:54 +0200 | [diff] [blame] | 82 | |
| 83 | bool load(handle src, bool) { |
| 84 | array_t<Scalar> buffer(src, true); |
| 85 | if (!buffer.check()) |
| 86 | return false; |
| 87 | |
Ivan Smirnov | 6956b65 | 2016-08-15 01:24:59 +0100 | [diff] [blame^] | 88 | auto info = buffer.request(); |
Wenzel Jakob | 9e0a056 | 2016-05-05 20:33:54 +0200 | [diff] [blame] | 89 | if (info.ndim == 1) { |
Ben North | 93594a3 | 2016-07-05 20:05:10 +0100 | [diff] [blame] | 90 | typedef Eigen::InnerStride<> Strides; |
Wenzel Jakob | a970a57 | 2016-05-20 12:00:56 +0200 | [diff] [blame] | 91 | if (!isVector && |
Wenzel Jakob | 9e0a056 | 2016-05-05 20:33:54 +0200 | [diff] [blame] | 92 | !(Type::RowsAtCompileTime == Eigen::Dynamic && |
| 93 | Type::ColsAtCompileTime == Eigen::Dynamic)) |
| 94 | return false; |
| 95 | |
| 96 | if (Type::SizeAtCompileTime != Eigen::Dynamic && |
| 97 | info.shape[0] != (size_t) Type::SizeAtCompileTime) |
| 98 | return false; |
| 99 | |
Ben North | 93594a3 | 2016-07-05 20:05:10 +0100 | [diff] [blame] | 100 | auto strides = Strides(info.strides[0] / sizeof(Scalar)); |
| 101 | |
Wenzel Jakob | 5ba89c3 | 2016-07-09 15:44:54 +0200 | [diff] [blame] | 102 | Strides::Index n_elts = (Strides::Index) info.shape[0]; |
Ben North | 93594a3 | 2016-07-05 20:05:10 +0100 | [diff] [blame] | 103 | Strides::Index unity = 1; |
Wenzel Jakob | 9e0a056 | 2016-05-05 20:33:54 +0200 | [diff] [blame] | 104 | |
| 105 | value = Eigen::Map<Type, 0, Strides>( |
Ben North | 93594a3 | 2016-07-05 20:05:10 +0100 | [diff] [blame] | 106 | (Scalar *) info.ptr, rowMajor ? unity : n_elts, rowMajor ? n_elts : unity, strides); |
Wenzel Jakob | 9e0a056 | 2016-05-05 20:33:54 +0200 | [diff] [blame] | 107 | } else if (info.ndim == 2) { |
| 108 | typedef Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic> Strides; |
| 109 | |
| 110 | if ((Type::RowsAtCompileTime != Eigen::Dynamic && info.shape[0] != (size_t) Type::RowsAtCompileTime) || |
| 111 | (Type::ColsAtCompileTime != Eigen::Dynamic && info.shape[1] != (size_t) Type::ColsAtCompileTime)) |
| 112 | return false; |
| 113 | |
| 114 | auto strides = Strides( |
| 115 | info.strides[rowMajor ? 0 : 1] / sizeof(Scalar), |
| 116 | info.strides[rowMajor ? 1 : 0] / sizeof(Scalar)); |
| 117 | |
| 118 | value = Eigen::Map<Type, 0, Strides>( |
Wenzel Jakob | 0a07805 | 2016-05-29 13:40:40 +0200 | [diff] [blame] | 119 | (Scalar *) info.ptr, |
| 120 | typename Strides::Index(info.shape[0]), |
| 121 | typename Strides::Index(info.shape[1]), strides); |
Wenzel Jakob | 9e0a056 | 2016-05-05 20:33:54 +0200 | [diff] [blame] | 122 | } else { |
| 123 | return false; |
| 124 | } |
| 125 | return true; |
| 126 | } |
| 127 | |
Wenzel Jakob | 9e0a056 | 2016-05-05 20:33:54 +0200 | [diff] [blame] | 128 | static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) { |
Wenzel Jakob | a970a57 | 2016-05-20 12:00:56 +0200 | [diff] [blame] | 129 | if (isVector) { |
Ivan Smirnov | 6956b65 | 2016-08-15 01:24:59 +0100 | [diff] [blame^] | 130 | return array( |
| 131 | { (size_t) src.size() }, // shape |
| 132 | { sizeof(Scalar) * static_cast<size_t>(src.innerStride()) }, // strides |
| 133 | src.data() // data |
| 134 | ).release(); |
Wenzel Jakob | a970a57 | 2016-05-20 12:00:56 +0200 | [diff] [blame] | 135 | } else { |
Ivan Smirnov | 6956b65 | 2016-08-15 01:24:59 +0100 | [diff] [blame^] | 136 | return array( |
| 137 | { (size_t) src.rows(), // shape |
Wenzel Jakob | a970a57 | 2016-05-20 12:00:56 +0200 | [diff] [blame] | 138 | (size_t) src.cols() }, |
Ivan Smirnov | 6956b65 | 2016-08-15 01:24:59 +0100 | [diff] [blame^] | 139 | { sizeof(Scalar) * static_cast<size_t>(src.rowStride()), // strides |
| 140 | sizeof(Scalar) * static_cast<size_t>(src.colStride()) }, |
| 141 | src.data() // data |
| 142 | ).release(); |
Wenzel Jakob | a970a57 | 2016-05-20 12:00:56 +0200 | [diff] [blame] | 143 | } |
Wenzel Jakob | 9e0a056 | 2016-05-05 20:33:54 +0200 | [diff] [blame] | 144 | } |
| 145 | |
Dean Moldovan | ed23dda | 2016-08-04 01:40:40 +0200 | [diff] [blame] | 146 | PYBIND11_TYPE_CASTER(Type, _("numpy.ndarray[") + npy_format_descriptor<Scalar>::name() + |
| 147 | _("[") + rows() + _(", ") + cols() + _("]]")); |
Wenzel Jakob | 9e0a056 | 2016-05-05 20:33:54 +0200 | [diff] [blame] | 148 | |
Wenzel Jakob | b437867 | 2016-05-24 21:39:41 +0200 | [diff] [blame] | 149 | protected: |
Wenzel Jakob | 9e0a056 | 2016-05-05 20:33:54 +0200 | [diff] [blame] | 150 | template <typename T = Type, typename std::enable_if<T::RowsAtCompileTime == Eigen::Dynamic, int>::type = 0> |
| 151 | static PYBIND11_DESCR rows() { return _("m"); } |
| 152 | template <typename T = Type, typename std::enable_if<T::RowsAtCompileTime != Eigen::Dynamic, int>::type = 0> |
| 153 | static PYBIND11_DESCR rows() { return _<T::RowsAtCompileTime>(); } |
| 154 | template <typename T = Type, typename std::enable_if<T::ColsAtCompileTime == Eigen::Dynamic, int>::type = 0> |
| 155 | static PYBIND11_DESCR cols() { return _("n"); } |
| 156 | template <typename T = Type, typename std::enable_if<T::ColsAtCompileTime != Eigen::Dynamic, int>::type = 0> |
| 157 | static PYBIND11_DESCR cols() { return _<T::ColsAtCompileTime>(); } |
Wenzel Jakob | 9e0a056 | 2016-05-05 20:33:54 +0200 | [diff] [blame] | 158 | }; |
| 159 | |
| 160 | template<typename Type> |
Jason Rhinelander | 5fd5074 | 2016-08-03 16:50:22 -0400 | [diff] [blame] | 161 | struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value && is_eigen_ref<Type>::value>::type> { |
Jason Rhinelander | 9ffb3dd | 2016-08-04 15:24:41 -0400 | [diff] [blame] | 162 | protected: |
Jason Rhinelander | 5fd5074 | 2016-08-03 16:50:22 -0400 | [diff] [blame] | 163 | using Derived = typename std::remove_const<typename is_eigen_ref<Type>::Derived>::type; |
| 164 | using DerivedCaster = type_caster<Derived>; |
| 165 | DerivedCaster derived_caster; |
Jason Rhinelander | 5fd5074 | 2016-08-03 16:50:22 -0400 | [diff] [blame] | 166 | std::unique_ptr<Type> value; |
| 167 | public: |
| 168 | bool load(handle src, bool convert) { if (derived_caster.load(src, convert)) { value.reset(new Type(derived_caster.operator Derived&())); return true; } return false; } |
| 169 | static handle cast(const Type &src, return_value_policy policy, handle parent) { return DerivedCaster::cast(src, policy, parent); } |
| 170 | static handle cast(const Type *src, return_value_policy policy, handle parent) { return DerivedCaster::cast(*src, policy, parent); } |
| 171 | |
| 172 | static PYBIND11_DESCR name() { return DerivedCaster::name(); } |
| 173 | |
| 174 | operator Type*() { return value.get(); } |
| 175 | operator Type&() { if (!value) pybind11_fail("Eigen::Ref<...> value not loaded"); return *value; } |
| 176 | template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>; |
| 177 | }; |
| 178 | |
Jason Rhinelander | 9ffb3dd | 2016-08-04 15:24:41 -0400 | [diff] [blame] | 179 | // type_caster for special matrix types (e.g. DiagonalMatrix): load() is not supported, but we can |
| 180 | // cast them into the python domain by first copying to a regular Eigen::Matrix, then casting that. |
| 181 | template <typename Type> |
| 182 | struct type_caster<Type, typename std::enable_if<is_eigen_base<Type>::value && !is_eigen_ref<Type>::value>::type> { |
| 183 | protected: |
| 184 | using Matrix = Eigen::Matrix<typename Type::Scalar, Eigen::Dynamic, Eigen::Dynamic>; |
| 185 | using MatrixCaster = type_caster<Matrix>; |
| 186 | public: |
| 187 | [[noreturn]] bool load(handle, bool) { pybind11_fail("Unable to load() into specialized EigenBase object"); } |
| 188 | static handle cast(const Type &src, return_value_policy policy, handle parent) { return MatrixCaster::cast(Matrix(src), policy, parent); } |
| 189 | static handle cast(const Type *src, return_value_policy policy, handle parent) { return MatrixCaster::cast(Matrix(*src), policy, parent); } |
| 190 | |
| 191 | static PYBIND11_DESCR name() { return MatrixCaster::name(); } |
| 192 | |
| 193 | [[noreturn]] operator Type*() { pybind11_fail("Loading not supported for specialized EigenBase object"); } |
| 194 | [[noreturn]] operator Type&() { pybind11_fail("Loading not supported for specialized EigenBase object"); } |
| 195 | template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>; |
| 196 | }; |
| 197 | |
Jason Rhinelander | 5fd5074 | 2016-08-03 16:50:22 -0400 | [diff] [blame] | 198 | template<typename Type> |
Wenzel Jakob | 9e0a056 | 2016-05-05 20:33:54 +0200 | [diff] [blame] | 199 | struct type_caster<Type, typename std::enable_if<is_eigen_sparse<Type>::value>::type> { |
| 200 | typedef typename Type::Scalar Scalar; |
| 201 | typedef typename std::remove_reference<decltype(*std::declval<Type>().outerIndexPtr())>::type StorageIndex; |
| 202 | typedef typename Type::Index Index; |
| 203 | static constexpr bool rowMajor = Type::Flags & Eigen::RowMajorBit; |
| 204 | |
| 205 | bool load(handle src, bool) { |
Wenzel Jakob | 178c8a8 | 2016-05-10 15:59:01 +0100 | [diff] [blame] | 206 | if (!src) |
| 207 | return false; |
| 208 | |
Wenzel Jakob | 9e0a056 | 2016-05-05 20:33:54 +0200 | [diff] [blame] | 209 | object obj(src, true); |
| 210 | object sparse_module = module::import("scipy.sparse"); |
| 211 | object matrix_type = sparse_module.attr( |
| 212 | rowMajor ? "csr_matrix" : "csc_matrix"); |
| 213 | |
| 214 | if (obj.get_type() != matrix_type.ptr()) { |
| 215 | try { |
Wenzel Jakob | 6c03beb | 2016-05-08 14:34:09 +0200 | [diff] [blame] | 216 | obj = matrix_type(obj); |
Wenzel Jakob | 9e0a056 | 2016-05-05 20:33:54 +0200 | [diff] [blame] | 217 | } catch (const error_already_set &) { |
Ivan Smirnov | 42ad328 | 2016-06-19 14:39:41 +0100 | [diff] [blame] | 218 | PyErr_Clear(); |
Wenzel Jakob | 9e0a056 | 2016-05-05 20:33:54 +0200 | [diff] [blame] | 219 | return false; |
| 220 | } |
| 221 | } |
| 222 | |
| 223 | auto valuesArray = array_t<Scalar>((object) obj.attr("data")); |
| 224 | auto innerIndicesArray = array_t<StorageIndex>((object) obj.attr("indices")); |
| 225 | auto outerIndicesArray = array_t<StorageIndex>((object) obj.attr("indptr")); |
| 226 | auto shape = pybind11::tuple((pybind11::object) obj.attr("shape")); |
| 227 | auto nnz = obj.attr("nnz").cast<Index>(); |
| 228 | |
| 229 | if (!valuesArray.check() || !innerIndicesArray.check() || |
| 230 | !outerIndicesArray.check()) |
| 231 | return false; |
| 232 | |
Ivan Smirnov | 6956b65 | 2016-08-15 01:24:59 +0100 | [diff] [blame^] | 233 | auto outerIndices = outerIndicesArray.request(); |
| 234 | auto innerIndices = innerIndicesArray.request(); |
| 235 | auto values = valuesArray.request(); |
Wenzel Jakob | 9e0a056 | 2016-05-05 20:33:54 +0200 | [diff] [blame] | 236 | |
| 237 | value = Eigen::MappedSparseMatrix<Scalar, Type::Flags, StorageIndex>( |
| 238 | shape[0].cast<Index>(), |
| 239 | shape[1].cast<Index>(), |
| 240 | nnz, |
| 241 | static_cast<StorageIndex *>(outerIndices.ptr), |
| 242 | static_cast<StorageIndex *>(innerIndices.ptr), |
| 243 | static_cast<Scalar *>(values.ptr) |
| 244 | ); |
| 245 | |
| 246 | return true; |
| 247 | } |
| 248 | |
Wenzel Jakob | 9e0a056 | 2016-05-05 20:33:54 +0200 | [diff] [blame] | 249 | static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) { |
| 250 | const_cast<Type&>(src).makeCompressed(); |
| 251 | |
| 252 | object matrix_type = module::import("scipy.sparse").attr( |
| 253 | rowMajor ? "csr_matrix" : "csc_matrix"); |
| 254 | |
Ivan Smirnov | 6956b65 | 2016-08-15 01:24:59 +0100 | [diff] [blame^] | 255 | array data((size_t) src.nonZeros(), src.valuePtr()); |
| 256 | array outerIndices((size_t) (rowMajor ? src.rows() : src.cols()) + 1, src.outerIndexPtr()); |
| 257 | array innerIndices((size_t) src.nonZeros(), src.innerIndexPtr()); |
Wenzel Jakob | 9e0a056 | 2016-05-05 20:33:54 +0200 | [diff] [blame] | 258 | |
Wenzel Jakob | 6c03beb | 2016-05-08 14:34:09 +0200 | [diff] [blame] | 259 | return matrix_type( |
Wenzel Jakob | 9e0a056 | 2016-05-05 20:33:54 +0200 | [diff] [blame] | 260 | std::make_tuple(data, innerIndices, outerIndices), |
| 261 | std::make_pair(src.rows(), src.cols()) |
| 262 | ).release(); |
| 263 | } |
| 264 | |
Dean Moldovan | ed23dda | 2016-08-04 01:40:40 +0200 | [diff] [blame] | 265 | PYBIND11_TYPE_CASTER(Type, _<(Type::Flags & Eigen::RowMajorBit) != 0>("scipy.sparse.csr_matrix[", "scipy.sparse.csc_matrix[") |
Jason Rhinelander | 8469f75 | 2016-07-06 00:40:54 -0400 | [diff] [blame] | 266 | + npy_format_descriptor<Scalar>::name() + _("]")); |
Wenzel Jakob | 9e0a056 | 2016-05-05 20:33:54 +0200 | [diff] [blame] | 267 | }; |
| 268 | |
| 269 | NAMESPACE_END(detail) |
| 270 | NAMESPACE_END(pybind11) |
| 271 | |
| 272 | #if defined(_MSC_VER) |
| 273 | #pragma warning(pop) |
| 274 | #endif |