blob: 0035cc6e3a7ad27133d44d13aeda0755fca83b5b [file] [log] [blame]
Wenzel Jakob9e0a0562016-05-05 20:33:54 +02001/*
2 pybind11/eigen.h: Transparent conversion for dense and sparse Eigen matrices
3
4 Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
5
6 All rights reserved. Use of this source code is governed by a
7 BSD-style license that can be found in the LICENSE file.
8*/
9
10#pragma once
11
12#include "numpy.h"
Wenzel Jakob0a078052016-05-29 13:40:40 +020013
14#if defined(__GNUG__) || defined(__clang__)
15# pragma GCC diagnostic push
16# pragma GCC diagnostic ignored "-Wconversion"
Wenzel Jakobb5692722016-05-30 11:37:03 +020017# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
Wenzel Jakob0a078052016-05-29 13:40:40 +020018#endif
19
Wenzel Jakob9e0a0562016-05-05 20:33:54 +020020#include <Eigen/Core>
21#include <Eigen/SparseCore>
22
Wenzel Jakob0a078052016-05-29 13:40:40 +020023#if defined(__GNUG__) || defined(__clang__)
24# pragma GCC diagnostic pop
25#endif
26
Wenzel Jakob9e0a0562016-05-05 20:33:54 +020027#if defined(_MSC_VER)
28#pragma warning(push)
29#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
30#endif
31
32NAMESPACE_BEGIN(pybind11)
33NAMESPACE_BEGIN(detail)
34
35template <typename T> class is_eigen_dense {
36private:
37 template<typename Derived> static std::true_type test(const Eigen::DenseBase<Derived> &);
38 static std::false_type test(...);
39public:
40 static constexpr bool value = decltype(test(std::declval<T>()))::value;
41};
42
Jason Rhinelander8657f302016-08-04 13:21:39 -040043// Eigen::Ref<Derived> satisfies is_eigen_dense, but isn't constructible, so it needs a special
44// type_caster to handle argument copying/forwarding.
Jason Rhinelander5fd50742016-08-03 16:50:22 -040045template <typename T> class is_eigen_ref {
46private:
47 template<typename Derived> static typename std::enable_if<
48 std::is_same<typename std::remove_const<T>::type, Eigen::Ref<Derived>>::value,
49 Derived>::type test(const Eigen::Ref<Derived> &);
50 static void test(...);
51public:
52 typedef decltype(test(std::declval<T>())) Derived;
53 static constexpr bool value = !std::is_void<Derived>::value;
54};
55
Wenzel Jakob9e0a0562016-05-05 20:33:54 +020056template <typename T> class is_eigen_sparse {
57private:
58 template<typename Derived> static std::true_type test(const Eigen::SparseMatrixBase<Derived> &);
59 static std::false_type test(...);
60public:
61 static constexpr bool value = decltype(test(std::declval<T>()))::value;
62};
63
Jason Rhinelander9ffb3dd2016-08-04 15:24:41 -040064// Test for objects inheriting from EigenBase<Derived> that aren't captured by the above. This
65// basically covers anything that can be assigned to a dense matrix but that don't have a typical
66// matrix data layout that can be copied from their .data(). For example, DiagonalMatrix and
67// SelfAdjointView fall into this category.
68template <typename T> class is_eigen_base {
69private:
70 template<typename Derived> static std::true_type test(const Eigen::EigenBase<Derived> &);
71 static std::false_type test(...);
72public:
73 static constexpr bool value = !is_eigen_dense<T>::value && !is_eigen_sparse<T>::value &&
74 decltype(test(std::declval<T>()))::value;
75};
76
Wenzel Jakob9e0a0562016-05-05 20:33:54 +020077template<typename Type>
Jason Rhinelander5fd50742016-08-03 16:50:22 -040078struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value && !is_eigen_ref<Type>::value>::type> {
Wenzel Jakob9e0a0562016-05-05 20:33:54 +020079 typedef typename Type::Scalar Scalar;
80 static constexpr bool rowMajor = Type::Flags & Eigen::RowMajorBit;
Wenzel Jakoba970a572016-05-20 12:00:56 +020081 static constexpr bool isVector = Type::IsVectorAtCompileTime;
Wenzel Jakob9e0a0562016-05-05 20:33:54 +020082
83 bool load(handle src, bool) {
84 array_t<Scalar> buffer(src, true);
85 if (!buffer.check())
86 return false;
87
Ivan Smirnov6956b652016-08-15 01:24:59 +010088 auto info = buffer.request();
Wenzel Jakob9e0a0562016-05-05 20:33:54 +020089 if (info.ndim == 1) {
Ben North93594a32016-07-05 20:05:10 +010090 typedef Eigen::InnerStride<> Strides;
Wenzel Jakoba970a572016-05-20 12:00:56 +020091 if (!isVector &&
Wenzel Jakob9e0a0562016-05-05 20:33:54 +020092 !(Type::RowsAtCompileTime == Eigen::Dynamic &&
93 Type::ColsAtCompileTime == Eigen::Dynamic))
94 return false;
95
96 if (Type::SizeAtCompileTime != Eigen::Dynamic &&
97 info.shape[0] != (size_t) Type::SizeAtCompileTime)
98 return false;
99
Ben North93594a32016-07-05 20:05:10 +0100100 auto strides = Strides(info.strides[0] / sizeof(Scalar));
101
Wenzel Jakob5ba89c32016-07-09 15:44:54 +0200102 Strides::Index n_elts = (Strides::Index) info.shape[0];
Ben North93594a32016-07-05 20:05:10 +0100103 Strides::Index unity = 1;
Wenzel Jakob9e0a0562016-05-05 20:33:54 +0200104
105 value = Eigen::Map<Type, 0, Strides>(
Ben North93594a32016-07-05 20:05:10 +0100106 (Scalar *) info.ptr, rowMajor ? unity : n_elts, rowMajor ? n_elts : unity, strides);
Wenzel Jakob9e0a0562016-05-05 20:33:54 +0200107 } else if (info.ndim == 2) {
108 typedef Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic> Strides;
109
110 if ((Type::RowsAtCompileTime != Eigen::Dynamic && info.shape[0] != (size_t) Type::RowsAtCompileTime) ||
111 (Type::ColsAtCompileTime != Eigen::Dynamic && info.shape[1] != (size_t) Type::ColsAtCompileTime))
112 return false;
113
114 auto strides = Strides(
115 info.strides[rowMajor ? 0 : 1] / sizeof(Scalar),
116 info.strides[rowMajor ? 1 : 0] / sizeof(Scalar));
117
118 value = Eigen::Map<Type, 0, Strides>(
Wenzel Jakob0a078052016-05-29 13:40:40 +0200119 (Scalar *) info.ptr,
120 typename Strides::Index(info.shape[0]),
121 typename Strides::Index(info.shape[1]), strides);
Wenzel Jakob9e0a0562016-05-05 20:33:54 +0200122 } else {
123 return false;
124 }
125 return true;
126 }
127
Wenzel Jakob9e0a0562016-05-05 20:33:54 +0200128 static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) {
Wenzel Jakoba970a572016-05-20 12:00:56 +0200129 if (isVector) {
Ivan Smirnov6956b652016-08-15 01:24:59 +0100130 return array(
131 { (size_t) src.size() }, // shape
132 { sizeof(Scalar) * static_cast<size_t>(src.innerStride()) }, // strides
133 src.data() // data
134 ).release();
Wenzel Jakoba970a572016-05-20 12:00:56 +0200135 } else {
Ivan Smirnov6956b652016-08-15 01:24:59 +0100136 return array(
137 { (size_t) src.rows(), // shape
Wenzel Jakoba970a572016-05-20 12:00:56 +0200138 (size_t) src.cols() },
Ivan Smirnov6956b652016-08-15 01:24:59 +0100139 { sizeof(Scalar) * static_cast<size_t>(src.rowStride()), // strides
140 sizeof(Scalar) * static_cast<size_t>(src.colStride()) },
141 src.data() // data
142 ).release();
Wenzel Jakoba970a572016-05-20 12:00:56 +0200143 }
Wenzel Jakob9e0a0562016-05-05 20:33:54 +0200144 }
145
Dean Moldovaned23dda2016-08-04 01:40:40 +0200146 PYBIND11_TYPE_CASTER(Type, _("numpy.ndarray[") + npy_format_descriptor<Scalar>::name() +
147 _("[") + rows() + _(", ") + cols() + _("]]"));
Wenzel Jakob9e0a0562016-05-05 20:33:54 +0200148
Wenzel Jakobb4378672016-05-24 21:39:41 +0200149protected:
Wenzel Jakob9e0a0562016-05-05 20:33:54 +0200150 template <typename T = Type, typename std::enable_if<T::RowsAtCompileTime == Eigen::Dynamic, int>::type = 0>
151 static PYBIND11_DESCR rows() { return _("m"); }
152 template <typename T = Type, typename std::enable_if<T::RowsAtCompileTime != Eigen::Dynamic, int>::type = 0>
153 static PYBIND11_DESCR rows() { return _<T::RowsAtCompileTime>(); }
154 template <typename T = Type, typename std::enable_if<T::ColsAtCompileTime == Eigen::Dynamic, int>::type = 0>
155 static PYBIND11_DESCR cols() { return _("n"); }
156 template <typename T = Type, typename std::enable_if<T::ColsAtCompileTime != Eigen::Dynamic, int>::type = 0>
157 static PYBIND11_DESCR cols() { return _<T::ColsAtCompileTime>(); }
Wenzel Jakob9e0a0562016-05-05 20:33:54 +0200158};
159
160template<typename Type>
Jason Rhinelander5fd50742016-08-03 16:50:22 -0400161struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value && is_eigen_ref<Type>::value>::type> {
Jason Rhinelander9ffb3dd2016-08-04 15:24:41 -0400162protected:
Jason Rhinelander5fd50742016-08-03 16:50:22 -0400163 using Derived = typename std::remove_const<typename is_eigen_ref<Type>::Derived>::type;
164 using DerivedCaster = type_caster<Derived>;
165 DerivedCaster derived_caster;
Jason Rhinelander5fd50742016-08-03 16:50:22 -0400166 std::unique_ptr<Type> value;
167public:
168 bool load(handle src, bool convert) { if (derived_caster.load(src, convert)) { value.reset(new Type(derived_caster.operator Derived&())); return true; } return false; }
169 static handle cast(const Type &src, return_value_policy policy, handle parent) { return DerivedCaster::cast(src, policy, parent); }
170 static handle cast(const Type *src, return_value_policy policy, handle parent) { return DerivedCaster::cast(*src, policy, parent); }
171
172 static PYBIND11_DESCR name() { return DerivedCaster::name(); }
173
174 operator Type*() { return value.get(); }
175 operator Type&() { if (!value) pybind11_fail("Eigen::Ref<...> value not loaded"); return *value; }
176 template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>;
177};
178
Jason Rhinelander9ffb3dd2016-08-04 15:24:41 -0400179// type_caster for special matrix types (e.g. DiagonalMatrix): load() is not supported, but we can
180// cast them into the python domain by first copying to a regular Eigen::Matrix, then casting that.
181template <typename Type>
182struct type_caster<Type, typename std::enable_if<is_eigen_base<Type>::value && !is_eigen_ref<Type>::value>::type> {
183protected:
184 using Matrix = Eigen::Matrix<typename Type::Scalar, Eigen::Dynamic, Eigen::Dynamic>;
185 using MatrixCaster = type_caster<Matrix>;
186public:
187 [[noreturn]] bool load(handle, bool) { pybind11_fail("Unable to load() into specialized EigenBase object"); }
188 static handle cast(const Type &src, return_value_policy policy, handle parent) { return MatrixCaster::cast(Matrix(src), policy, parent); }
189 static handle cast(const Type *src, return_value_policy policy, handle parent) { return MatrixCaster::cast(Matrix(*src), policy, parent); }
190
191 static PYBIND11_DESCR name() { return MatrixCaster::name(); }
192
193 [[noreturn]] operator Type*() { pybind11_fail("Loading not supported for specialized EigenBase object"); }
194 [[noreturn]] operator Type&() { pybind11_fail("Loading not supported for specialized EigenBase object"); }
195 template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>;
196};
197
Jason Rhinelander5fd50742016-08-03 16:50:22 -0400198template<typename Type>
Wenzel Jakob9e0a0562016-05-05 20:33:54 +0200199struct type_caster<Type, typename std::enable_if<is_eigen_sparse<Type>::value>::type> {
200 typedef typename Type::Scalar Scalar;
201 typedef typename std::remove_reference<decltype(*std::declval<Type>().outerIndexPtr())>::type StorageIndex;
202 typedef typename Type::Index Index;
203 static constexpr bool rowMajor = Type::Flags & Eigen::RowMajorBit;
204
205 bool load(handle src, bool) {
Wenzel Jakob178c8a82016-05-10 15:59:01 +0100206 if (!src)
207 return false;
208
Wenzel Jakob9e0a0562016-05-05 20:33:54 +0200209 object obj(src, true);
210 object sparse_module = module::import("scipy.sparse");
211 object matrix_type = sparse_module.attr(
212 rowMajor ? "csr_matrix" : "csc_matrix");
213
214 if (obj.get_type() != matrix_type.ptr()) {
215 try {
Wenzel Jakob6c03beb2016-05-08 14:34:09 +0200216 obj = matrix_type(obj);
Wenzel Jakob9e0a0562016-05-05 20:33:54 +0200217 } catch (const error_already_set &) {
Ivan Smirnov42ad3282016-06-19 14:39:41 +0100218 PyErr_Clear();
Wenzel Jakob9e0a0562016-05-05 20:33:54 +0200219 return false;
220 }
221 }
222
223 auto valuesArray = array_t<Scalar>((object) obj.attr("data"));
224 auto innerIndicesArray = array_t<StorageIndex>((object) obj.attr("indices"));
225 auto outerIndicesArray = array_t<StorageIndex>((object) obj.attr("indptr"));
226 auto shape = pybind11::tuple((pybind11::object) obj.attr("shape"));
227 auto nnz = obj.attr("nnz").cast<Index>();
228
229 if (!valuesArray.check() || !innerIndicesArray.check() ||
230 !outerIndicesArray.check())
231 return false;
232
Ivan Smirnov6956b652016-08-15 01:24:59 +0100233 auto outerIndices = outerIndicesArray.request();
234 auto innerIndices = innerIndicesArray.request();
235 auto values = valuesArray.request();
Wenzel Jakob9e0a0562016-05-05 20:33:54 +0200236
237 value = Eigen::MappedSparseMatrix<Scalar, Type::Flags, StorageIndex>(
238 shape[0].cast<Index>(),
239 shape[1].cast<Index>(),
240 nnz,
241 static_cast<StorageIndex *>(outerIndices.ptr),
242 static_cast<StorageIndex *>(innerIndices.ptr),
243 static_cast<Scalar *>(values.ptr)
244 );
245
246 return true;
247 }
248
Wenzel Jakob9e0a0562016-05-05 20:33:54 +0200249 static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) {
250 const_cast<Type&>(src).makeCompressed();
251
252 object matrix_type = module::import("scipy.sparse").attr(
253 rowMajor ? "csr_matrix" : "csc_matrix");
254
Ivan Smirnov6956b652016-08-15 01:24:59 +0100255 array data((size_t) src.nonZeros(), src.valuePtr());
256 array outerIndices((size_t) (rowMajor ? src.rows() : src.cols()) + 1, src.outerIndexPtr());
257 array innerIndices((size_t) src.nonZeros(), src.innerIndexPtr());
Wenzel Jakob9e0a0562016-05-05 20:33:54 +0200258
Wenzel Jakob6c03beb2016-05-08 14:34:09 +0200259 return matrix_type(
Wenzel Jakob9e0a0562016-05-05 20:33:54 +0200260 std::make_tuple(data, innerIndices, outerIndices),
261 std::make_pair(src.rows(), src.cols())
262 ).release();
263 }
264
Dean Moldovaned23dda2016-08-04 01:40:40 +0200265 PYBIND11_TYPE_CASTER(Type, _<(Type::Flags & Eigen::RowMajorBit) != 0>("scipy.sparse.csr_matrix[", "scipy.sparse.csc_matrix[")
Jason Rhinelander8469f752016-07-06 00:40:54 -0400266 + npy_format_descriptor<Scalar>::name() + _("]"));
Wenzel Jakob9e0a0562016-05-05 20:33:54 +0200267};
268
269NAMESPACE_END(detail)
270NAMESPACE_END(pybind11)
271
272#if defined(_MSC_VER)
273#pragma warning(pop)
274#endif