/*
    pybind11/eigen.h: Transparent conversion for dense and sparse Eigen matrices

    Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>

    All rights reserved. Use of this source code is governed by a
    BSD-style license that can be found in the LICENSE file.
*/

#pragma once

#include "numpy.h"

#if defined(__INTEL_COMPILER)
#  pragma warning(disable: 1682) // implicit conversion of a 64-bit integral type to a smaller integral type (potential portability problem)
#elif defined(__GNUG__) || defined(__clang__)
#  pragma GCC diagnostic push
#  pragma GCC diagnostic ignored "-Wconversion"
#  pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#  if __GNUC__ >= 7
#    pragma GCC diagnostic ignored "-Wint-in-bool-context"
#  endif
#endif

#include <Eigen/Core>
#include <Eigen/SparseCore>

#if defined(_MSC_VER)
#  pragma warning(push)
#  pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
#endif

// Eigen prior to 3.2.7 doesn't have proper move constructors--but worse, some classes get implicit
// move constructors that break things. We could detect this and explicitly copy, but an extra copy
// of matrices seems highly undesirable.
static_assert(EIGEN_VERSION_AT_LEAST(3,2,7), "Eigen support in pybind11 requires Eigen >= 3.2.7");

NAMESPACE_BEGIN(pybind11)

// Provide a convenience alias for easier pass-by-ref usage with fully dynamic strides:
using EigenDStride = Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>;
template <typename MatrixType> using EigenDRef = Eigen::Ref<MatrixType, 0, EigenDStride>;
template <typename MatrixType> using EigenDMap = Eigen::Map<MatrixType, 0, EigenDStride>;
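//
// A minimal usage sketch: because EigenDRef uses fully dynamic strides, a bound function can
// accept non-contiguous numpy slices without copying. The module handle `m` and the function
// name below are assumptions for illustration only.
//
//     m.def("trace", [](const EigenDRef<const Eigen::MatrixXd> &A) {
//         return A.trace();  // also binds to views such as a[::2, ::3] without a copy
//     });
//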

NAMESPACE_BEGIN(detail)

#if EIGEN_VERSION_AT_LEAST(3,3,0)
using EigenIndex = Eigen::Index;
#else
using EigenIndex = EIGEN_DEFAULT_DENSE_INDEX_TYPE;
#endif

// Matches Eigen::Map, Eigen::Ref, blocks, etc:
template <typename T> using is_eigen_dense_map = all_of<is_template_base_of<Eigen::DenseBase, T>, std::is_base_of<Eigen::MapBase<T, Eigen::ReadOnlyAccessors>, T>>;
template <typename T> using is_eigen_mutable_map = std::is_base_of<Eigen::MapBase<T, Eigen::WriteAccessors>, T>;
template <typename T> using is_eigen_dense_plain = all_of<negation<is_eigen_dense_map<T>>, is_template_base_of<Eigen::PlainObjectBase, T>>;
template <typename T> using is_eigen_sparse = is_template_base_of<Eigen::SparseMatrixBase, T>;
// Test for objects inheriting from EigenBase<Derived> that aren't captured by the above. This
// basically covers anything that can be assigned to a dense matrix but that doesn't have a
// typical matrix data layout that can be copied from its .data(). For example, DiagonalMatrix
// and SelfAdjointView fall into this category.
template <typename T> using is_eigen_other = all_of<
    is_template_base_of<Eigen::EigenBase, T>,
    negation<any_of<is_eigen_dense_map<T>, is_eigen_dense_plain<T>, is_eigen_sparse<T>>>
>;

// Captures numpy/eigen conformability status (returned by EigenProps::conformable()):
template <bool EigenRowMajor> struct EigenConformable {
    bool conformable = false;
    EigenIndex rows = 0, cols = 0;
    EigenDStride stride{0, 0};

    EigenConformable(bool fits = false) : conformable{fits} {}
    // Matrix type:
    EigenConformable(EigenIndex r, EigenIndex c,
            EigenIndex rstride, EigenIndex cstride) :
        conformable{true}, rows{r}, cols{c},
        stride(EigenRowMajor ? rstride : cstride /* outer stride */,
               EigenRowMajor ? cstride : rstride /* inner stride */)
        {}
    // Vector type:
    EigenConformable(EigenIndex r, EigenIndex c, EigenIndex stride)
        : EigenConformable(r, c, r == 1 ? c*stride : stride, c == 1 ? r : r*stride) {}

    template <typename props> bool stride_compatible() const {
        // To have compatible strides, we need (on both dimensions) one of fully dynamic strides,
        // matching strides, or a dimension size of 1 (in which case the stride value is irrelevant)
        return
            (props::inner_stride == Eigen::Dynamic || props::inner_stride == stride.inner() ||
                (EigenRowMajor ? cols : rows) == 1) &&
            (props::outer_stride == Eigen::Dynamic || props::outer_stride == stride.outer() ||
                (EigenRowMajor ? rows : cols) == 1);
    }
    operator bool() const { return conformable; }
};
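//
// Worked example (hand-derived, for intuition only): a C-contiguous 5x3 float64 numpy array
// checked against a column-major Eigen type gives rows = 5, cols = 3, and element strides
// rstride = 3, cstride = 1, stored here as outer = 1, inner = 3. A default Eigen::Ref<MatrixXd>
// requires inner_stride == 1, so stride_compatible() reports false and a caster that cannot
// copy must reject such an array; EigenDStride (both strides Dynamic) accepts it.
//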

template <typename Type> struct eigen_extract_stride { using type = Type; };
template <typename PlainObjectType, int MapOptions, typename StrideType>
struct eigen_extract_stride<Eigen::Map<PlainObjectType, MapOptions, StrideType>> { using type = StrideType; };
template <typename PlainObjectType, int Options, typename StrideType>
struct eigen_extract_stride<Eigen::Ref<PlainObjectType, Options, StrideType>> { using type = StrideType; };

// Helper struct for extracting information from an Eigen type
template <typename Type_> struct EigenProps {
    using Type = Type_;
    using Scalar = typename Type::Scalar;
    using StrideType = typename eigen_extract_stride<Type>::type;
    static constexpr EigenIndex
        rows = Type::RowsAtCompileTime,
        cols = Type::ColsAtCompileTime,
        size = Type::SizeAtCompileTime;
    static constexpr bool
        row_major = Type::IsRowMajor,
        vector = Type::IsVectorAtCompileTime, // At least one dimension has fixed size 1
        fixed_rows = rows != Eigen::Dynamic,
        fixed_cols = cols != Eigen::Dynamic,
        fixed = size != Eigen::Dynamic, // Fully-fixed size
        dynamic = !fixed_rows && !fixed_cols; // Fully-dynamic size

    template <EigenIndex i, EigenIndex ifzero> using if_zero = std::integral_constant<EigenIndex, i == 0 ? ifzero : i>;
    static constexpr EigenIndex inner_stride = if_zero<StrideType::InnerStrideAtCompileTime, 1>::value,
                                outer_stride = if_zero<StrideType::OuterStrideAtCompileTime,
                                                       vector ? size : row_major ? cols : rows>::value;
    static constexpr bool dynamic_stride = inner_stride == Eigen::Dynamic && outer_stride == Eigen::Dynamic;
    static constexpr bool requires_row_major = !dynamic_stride && !vector && (row_major ? inner_stride : outer_stride) == 1;
    static constexpr bool requires_col_major = !dynamic_stride && !vector && (row_major ? outer_stride : inner_stride) == 1;

    // Takes an input array and determines whether we can make it fit into the Eigen type. If
    // the array is a vector, we attempt to fit it into either an Eigen 1xN or Nx1 vector
    // (preferring the latter if it will fit in either, i.e. for a fully dynamic matrix type).
    static EigenConformable<row_major> conformable(const array &a) {
        const auto dims = a.ndim();
        if (dims < 1 || dims > 2)
            return false;

        if (dims == 2) { // Matrix type: require exact match (or dynamic)

            EigenIndex
                np_rows = a.shape(0),
                np_cols = a.shape(1),
                np_rstride = a.strides(0) / sizeof(Scalar),
                np_cstride = a.strides(1) / sizeof(Scalar);
            if ((fixed_rows && np_rows != rows) || (fixed_cols && np_cols != cols))
                return false;

            return {np_rows, np_cols, np_rstride, np_cstride};
        }

        // Otherwise we're storing an n-vector. Only one of the strides will be used, but
        // whichever is used, we want the (single) numpy stride value.
        const EigenIndex n = a.shape(0),
              stride = a.strides(0) / sizeof(Scalar);

        if (vector) { // Eigen type is a compile-time vector
            if (fixed && size != n)
                return false; // Vector size mismatch
            return {rows == 1 ? 1 : n, cols == 1 ? 1 : n, stride};
        }
        else if (fixed) {
            // The type has a fixed size, but is not a vector: abort
            return false;
        }
        else if (fixed_cols) {
            // Since this isn't a vector, cols must be != 1. We allow this only if it exactly
            // equals the number of elements (rows is Dynamic, and so 1 row is allowed).
            if (cols != n) return false;
            return {1, n, stride};
        }
        else {
            // Otherwise it's either fully dynamic, or column dynamic; both become a column vector
            if (fixed_rows && rows != n) return false;
            return {n, 1, stride};
        }
    }

    static PYBIND11_DESCR descriptor() {
        constexpr bool show_writeable = is_eigen_dense_map<Type>::value && is_eigen_mutable_map<Type>::value;
        constexpr bool show_order = is_eigen_dense_map<Type>::value;
        constexpr bool show_c_contiguous = show_order && requires_row_major;
        constexpr bool show_f_contiguous = !show_c_contiguous && show_order && requires_col_major;

        return type_descr(_("numpy.ndarray[") + npy_format_descriptor<Scalar>::name() +
            _("[") + _<fixed_rows>(_<(size_t) rows>(), _("m")) +
            _(", ") + _<fixed_cols>(_<(size_t) cols>(), _("n")) +
            _("]") +
            // For a reference type (e.g. Ref<MatrixXd>) we have other constraints that might need to be
            // satisfied: writeable=True (for a mutable reference), and, depending on the map's stride
            // options, possibly f_contiguous or c_contiguous. We include them in the descriptor output
            // to provide some hint as to why a TypeError is occurring (otherwise it can be confusing to
            // see that a function accepts a 'numpy.ndarray[float64[3,2]]' and an error message that you
            // *gave* a numpy.ndarray of the right type and dimensions).
            _<show_writeable>(", flags.writeable", "") +
            _<show_c_contiguous>(", flags.c_contiguous", "") +
            _<show_f_contiguous>(", flags.f_contiguous", "") +
            _("]")
        );
    }
};
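//
// For reference, the descriptor for a writeable Eigen::Ref<Eigen::Matrix<float, 3, Eigen::Dynamic>>
// renders as something like
//
//     numpy.ndarray[float32[3, n], flags.writeable, flags.f_contiguous]
//
// (the exact text depends on the scalar's format descriptor); this is the signature users see in
// TypeError messages when an argument does not conform.
//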

// Casts an Eigen type to numpy array. If given a base, the numpy array references the src data,
// otherwise it'll make a copy. writeable lets you turn off the writeable flag for the array.
template <typename props> handle eigen_array_cast(typename props::Type const &src, handle base = handle(), bool writeable = true) {
    constexpr size_t elem_size = sizeof(typename props::Scalar);
    std::vector<size_t> shape, strides;
    if (props::vector) {
        shape.push_back(src.size());
        strides.push_back(elem_size * src.innerStride());
    }
    else {
        shape.push_back(src.rows());
        shape.push_back(src.cols());
        strides.push_back(elem_size * src.rowStride());
        strides.push_back(elem_size * src.colStride());
    }
    array a(std::move(shape), std::move(strides), src.data(), base);
    if (!writeable)
        array_proxy(a.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;

    return a.release();
}
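//
// Example of the computed buffer layout (hand-derived, for intuition only): a 3x4 column-major
// MatrixXd has rowStride() == 1 and colStride() == 3, so the resulting array gets shape {3, 4}
// and byte strides {8, 24}; a row-major matrix of the same size would get strides {32, 8}.
//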

// Takes an lvalue ref to some Eigen type and a (python) base object, creating a numpy array that
// references the Eigen object's data with `base` as the python-registered base class (if omitted,
// the base will be set to None, and lifetime management is up to the caller). The numpy array is
// non-writeable if the given type is const.
template <typename props, typename Type>
handle eigen_ref_array(Type &src, handle parent = none()) {
    // none here is to get past array's should-we-copy detection, which currently always
    // copies when there is no base. Setting the base to None should be harmless.
    return eigen_array_cast<props>(src, parent, !std::is_const<Type>::value);
}

// Takes a pointer to some dense, plain Eigen type, builds a capsule around it, then returns a numpy
// array that references the encapsulated data with a python-side reference to the capsule to tie
// its destruction to that of any dependent python objects. Const-ness is determined by whether or
// not the Type of the pointer given is const.
template <typename props, typename Type, typename = enable_if_t<is_eigen_dense_plain<Type>::value>>
handle eigen_encapsulate(Type *src) {
    capsule base(src, [](void *o) { delete static_cast<Type *>(o); });
    return eigen_ref_array<props>(*src, base);
}
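//
// In effect, a heap-allocated matrix handed to eigen_encapsulate is deleted only once the
// returned array (and every numpy view of it) is gone. A rough equivalent by hand:
//
//     auto *m = new Eigen::MatrixXd(Eigen::MatrixXd::Zero(4, 4));
//     handle arr = eigen_encapsulate<EigenProps<Eigen::MatrixXd>>(m);  // the capsule now owns m
//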

// Type caster for regular, dense matrix types (e.g. MatrixXd), but not maps/refs/etc. of dense
// types.
template<typename Type>
struct type_caster<Type, enable_if_t<is_eigen_dense_plain<Type>::value>> {
    using Scalar = typename Type::Scalar;
    using props = EigenProps<Type>;

    bool load(handle src, bool) {
        auto buf = array_t<Scalar>::ensure(src);
        if (!buf)
            return false;

        auto dims = buf.ndim();
        if (dims < 1 || dims > 2)
            return false;

        auto fits = props::conformable(buf);
        if (!fits)
            return false; // Non-conformable vector/matrix types

        value = Eigen::Map<const Type, 0, EigenDStride>(buf.data(), fits.rows, fits.cols, fits.stride);

        return true;
    }

private:

    // Cast implementation
    template <typename CType>
    static handle cast_impl(CType *src, return_value_policy policy, handle parent) {
        switch (policy) {
            case return_value_policy::take_ownership:
            case return_value_policy::automatic:
                return eigen_encapsulate<props>(src);
            case return_value_policy::move:
                return eigen_encapsulate<props>(new CType(std::move(*src)));
            case return_value_policy::copy:
                return eigen_array_cast<props>(*src);
            case return_value_policy::reference:
            case return_value_policy::automatic_reference:
                return eigen_ref_array<props>(*src);
            case return_value_policy::reference_internal:
                return eigen_ref_array<props>(*src, parent);
            default:
                throw cast_error("unhandled return_value_policy: should not happen!");
        };
    }

public:

    // Normal returned non-reference, non-const value:
    static handle cast(Type &&src, return_value_policy /* policy */, handle parent) {
        return cast_impl(&src, return_value_policy::move, parent);
    }
    // If you return a non-reference const, we mark the numpy array readonly:
    static handle cast(const Type &&src, return_value_policy /* policy */, handle parent) {
        return cast_impl(&src, return_value_policy::move, parent);
    }
    // lvalue reference return; default (automatic) becomes copy
    static handle cast(Type &src, return_value_policy policy, handle parent) {
        if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference)
            policy = return_value_policy::copy;
        return cast_impl(&src, policy, parent);
    }
    // const lvalue reference return; default (automatic) becomes copy
    static handle cast(const Type &src, return_value_policy policy, handle parent) {
        if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference)
            policy = return_value_policy::copy;
        return cast(&src, policy, parent);
    }
    // non-const pointer return
    static handle cast(Type *src, return_value_policy policy, handle parent) {
        return cast_impl(src, policy, parent);
    }
    // const pointer return
    static handle cast(const Type *src, return_value_policy policy, handle parent) {
        return cast_impl(src, policy, parent);
    }

    static PYBIND11_DESCR name() { return props::descriptor(); }

    operator Type*() { return &value; }
    operator Type&() { return value; }
    template <typename T> using cast_op_type = cast_op_type<T>;

private:
    Type value;
};
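//
// Usage sketch (the module handle `m`, the function names, and `some_static_matrix` are
// illustrative only): a matrix returned by value is moved into a capsule-owned numpy array,
// while an lvalue reference returned under the default (automatic) policy is copied, unless
// e.g. py::return_value_policy::reference_internal is requested explicitly.
//
//     m.def("make", []() { return Eigen::MatrixXd(Eigen::MatrixXd::Identity(3, 3)); });
//     m.def("view", []() -> Eigen::MatrixXd & { return some_static_matrix; });  // copies
//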

// Eigen Ref/Map classes have slightly different policy requirements, meaning we don't want to force
// `move` when a Ref/Map rvalue is returned; we treat Ref<> sort of like a pointer (we care about
// the underlying data, not the outer shell).
template <typename Return>
struct return_value_policy_override<Return, enable_if_t<is_eigen_dense_map<Return>::value>> {
    static return_value_policy policy(return_value_policy p) { return p; }
};

// Base class for casting reference/map/block/etc. objects back to python.
template <typename MapType> struct eigen_map_caster {
private:
    using props = EigenProps<MapType>;

public:

    // Directly referencing a ref/map's data is a bit dangerous (whatever the map/ref points to has
    // to stay around), but we'll allow it under the assumption that you know what you're doing (and
    // have an appropriate keep_alive in place). We return a numpy array pointing directly at the
    // ref's data (the numpy array ends up read-only if the ref was to a const matrix type). Note
    // that this means you need to ensure the referenced object is not destroyed while the array is
    // alive (e.g. with an appropriate keep_alive, or with a reference to a statically allocated
    // matrix).
    static handle cast(const MapType &src, return_value_policy policy, handle parent) {
        switch (policy) {
            case return_value_policy::copy:
                return eigen_array_cast<props>(src);
            case return_value_policy::reference_internal:
                return eigen_array_cast<props>(src, parent, is_eigen_mutable_map<MapType>::value);
            case return_value_policy::reference:
            case return_value_policy::automatic:
            case return_value_policy::automatic_reference:
                return eigen_array_cast<props>(src, none(), is_eigen_mutable_map<MapType>::value);
            default:
                // move, take_ownership don't make any sense for a ref/map:
                pybind11_fail("Invalid return_value_policy for Eigen Map/Ref/Block type");
        }
    }

    static PYBIND11_DESCR name() { return props::descriptor(); }

    // Explicitly delete these: we do not support python -> C++ conversion for these types (i.e.
    // they can be return types but not bound arguments). We still declare them (explicitly
    // deleted) so that you end up here if you try anyway.
    bool load(handle, bool) = delete;
    operator MapType() = delete;
    template <typename> using cast_op_type = MapType;
};

// We can return any map-like object (but can only load Refs, specialized next):
template <typename Type> struct type_caster<Type, enable_if_t<is_eigen_dense_map<Type>::value>>
    : eigen_map_caster<Type> {};

// Loader for Ref<...> arguments. See the documentation for info on how to make this work without
// copying (it requires some extra effort in many cases).
template <typename PlainObjectType, typename StrideType>
struct type_caster<
    Eigen::Ref<PlainObjectType, 0, StrideType>,
    enable_if_t<is_eigen_dense_map<Eigen::Ref<PlainObjectType, 0, StrideType>>::value>
> : public eigen_map_caster<Eigen::Ref<PlainObjectType, 0, StrideType>> {
private:
    using Type = Eigen::Ref<PlainObjectType, 0, StrideType>;
    using props = EigenProps<Type>;
    using Scalar = typename props::Scalar;
    using MapType = Eigen::Map<PlainObjectType, 0, StrideType>;
    using Array = array_t<Scalar, array::forcecast |
                          ((props::row_major ? props::inner_stride : props::outer_stride) == 1 ? array::c_style :
                           (props::row_major ? props::outer_stride : props::inner_stride) == 1 ? array::f_style : 0)>;
    static constexpr bool need_writeable = is_eigen_mutable_map<Type>::value;
    // Delay construction (these have no default constructor)
    std::unique_ptr<MapType> map;
    std::unique_ptr<Type> ref;
    // Our array. When possible, this is just a numpy array pointing to the source data, but
    // sometimes we can't avoid copying (e.g. input is not a numpy array at all, has an incompatible
    // layout, or is an array of a type that needs to be converted). Using a numpy temporary
    // (rather than an Eigen temporary) saves an extra copy when we need both type conversion and
    // storage order conversion. (Note that we refuse to use this temporary copy when loading an
    // argument for a Ref<M> with M non-const, i.e. a read-write reference).
    Array copy_or_ref;
public:
    bool load(handle src, bool convert) {
        // First check whether what we have is already an array of the right type. If not, we can't
        // avoid a copy (because the copy is also going to do type conversion).
        bool need_copy = !isinstance<Array>(src);

        EigenConformable<props::row_major> fits;
        if (!need_copy) {
            // We don't need a converting copy, but we also need to check whether the strides are
            // compatible with the Ref's stride requirements
            Array aref = reinterpret_borrow<Array>(src);

            if (aref && (!need_writeable || aref.writeable())) {
                fits = props::conformable(aref);
                if (!fits) return false; // Incompatible dimensions
                if (!fits.template stride_compatible<props>())
                    need_copy = true;
                else
                    copy_or_ref = std::move(aref);
            }
            else {
                need_copy = true;
            }
        }

        if (need_copy) {
            // We need to copy: if we need a mutable reference, or we're not supposed to convert
            // (either because we're in the no-convert overload pass, or because we're explicitly
            // instructed not to copy via `py::arg().noconvert()`), we have to fail loading.
            if (!convert || need_writeable) return false;

            Array copy = Array::ensure(src);
            if (!copy) return false;
            fits = props::conformable(copy);
            if (!fits || !fits.template stride_compatible<props>())
                return false;
            copy_or_ref = std::move(copy);
        }

        ref.reset();
        map.reset(new MapType(data(copy_or_ref), fits.rows, fits.cols, make_stride(fits.stride.outer(), fits.stride.inner())));
        ref.reset(new Type(*map));

        return true;
    }

    operator Type*() { return ref.get(); }
    operator Type&() { return *ref; }
    template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>;

private:
    template <typename T = Type, enable_if_t<is_eigen_mutable_map<T>::value, int> = 0>
    Scalar *data(Array &a) { return a.mutable_data(); }

    template <typename T = Type, enable_if_t<!is_eigen_mutable_map<T>::value, int> = 0>
    const Scalar *data(Array &a) { return a.data(); }

    // Attempt to figure out a constructor of `Stride` that will work.
    // If both strides are fixed, use a default constructor:
    template <typename S> using stride_ctor_default = bool_constant<
        S::InnerStrideAtCompileTime != Eigen::Dynamic && S::OuterStrideAtCompileTime != Eigen::Dynamic &&
        std::is_default_constructible<S>::value>;
    // Otherwise, if there is a two-index constructor, assume it is (outer,inner) like
    // Eigen::Stride, and use it:
    template <typename S> using stride_ctor_dual = bool_constant<
        !stride_ctor_default<S>::value && std::is_constructible<S, EigenIndex, EigenIndex>::value>;
    // Otherwise, if there is a one-index constructor, and just one of the strides is dynamic, use
    // it (passing whichever stride is dynamic).
    template <typename S> using stride_ctor_outer = bool_constant<
        !any_of<stride_ctor_default<S>, stride_ctor_dual<S>>::value &&
        S::OuterStrideAtCompileTime == Eigen::Dynamic && S::InnerStrideAtCompileTime != Eigen::Dynamic &&
        std::is_constructible<S, EigenIndex>::value>;
    template <typename S> using stride_ctor_inner = bool_constant<
        !any_of<stride_ctor_default<S>, stride_ctor_dual<S>>::value &&
        S::InnerStrideAtCompileTime == Eigen::Dynamic && S::OuterStrideAtCompileTime != Eigen::Dynamic &&
        std::is_constructible<S, EigenIndex>::value>;

    template <typename S = StrideType, enable_if_t<stride_ctor_default<S>::value, int> = 0>
    static S make_stride(EigenIndex, EigenIndex) { return S(); }
    template <typename S = StrideType, enable_if_t<stride_ctor_dual<S>::value, int> = 0>
    static S make_stride(EigenIndex outer, EigenIndex inner) { return S(outer, inner); }
    template <typename S = StrideType, enable_if_t<stride_ctor_outer<S>::value, int> = 0>
    static S make_stride(EigenIndex outer, EigenIndex) { return S(outer); }
    template <typename S = StrideType, enable_if_t<stride_ctor_inner<S>::value, int> = 0>
    static S make_stride(EigenIndex, EigenIndex inner) { return S(inner); }

};
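//
// Usage sketch (the module handle `m` and the function names are illustrative only): a
// Ref<const ...> argument accepts any conformable input, silently copying/converting when
// necessary (add py::arg().noconvert() to forbid that), while a mutable Ref binds only to a
// writeable array with matching dtype and strides and otherwise fails to load:
//
//     m.def("scale", [](Eigen::Ref<Eigen::VectorXd> v, double c) { v *= c; });
//     m.def("norm",  [](const Eigen::Ref<const Eigen::VectorXd> &v) { return v.norm(); });
//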

// type_caster for special matrix types (e.g. DiagonalMatrix), which are EigenBase, but not
// EigenDense (i.e. they don't have a data(), at least not with the usual matrix layout).
// load() is not supported, but we can cast them into the python domain by first copying to a
// regular Eigen::Matrix, then casting that.
template <typename Type>
struct type_caster<Type, enable_if_t<is_eigen_other<Type>::value>> {
protected:
    using Matrix = Eigen::Matrix<typename Type::Scalar, Type::RowsAtCompileTime, Type::ColsAtCompileTime>;
    using props = EigenProps<Matrix>;
public:
    static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) {
        handle h = eigen_encapsulate<props>(new Matrix(src));
        return h;
    }
    static handle cast(const Type *src, return_value_policy policy, handle parent) { return cast(*src, policy, parent); }

    static PYBIND11_DESCR name() { return props::descriptor(); }

    // Explicitly delete these: we do not support python -> C++ conversion for these types (i.e.
    // they can be return types but not bound arguments). We still declare them (explicitly
    // deleted) so that you end up here if you try anyway.
    bool load(handle, bool) = delete;
    operator Type() = delete;
    template <typename> using cast_op_type = Type;
};
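//
// Usage sketch (the module handle `m` and the function name are illustrative only): returning
// such a type from a bound function materializes a regular dense matrix on the Python side:
//
//     m.def("diag3", []() { return Eigen::DiagonalMatrix<double, 3>(1., 2., 3.); });
//     // arrives in Python as a plain 3x3 numpy array
//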

template<typename Type>
struct type_caster<Type, enable_if_t<is_eigen_sparse<Type>::value>> {
    typedef typename Type::Scalar Scalar;
    typedef typename std::remove_reference<decltype(*std::declval<Type>().outerIndexPtr())>::type StorageIndex;
    typedef typename Type::Index Index;
    static constexpr bool rowMajor = Type::IsRowMajor;

    bool load(handle src, bool) {
        if (!src)
            return false;

        auto obj = reinterpret_borrow<object>(src);
        object sparse_module = module::import("scipy.sparse");
        object matrix_type = sparse_module.attr(
            rowMajor ? "csr_matrix" : "csc_matrix");

        if (obj.get_type() != matrix_type.ptr()) {
            try {
                obj = matrix_type(obj);
            } catch (const error_already_set &) {
                return false;
            }
        }

        auto values = array_t<Scalar>((object) obj.attr("data"));
        auto innerIndices = array_t<StorageIndex>((object) obj.attr("indices"));
        auto outerIndices = array_t<StorageIndex>((object) obj.attr("indptr"));
        auto shape = pybind11::tuple((pybind11::object) obj.attr("shape"));
        auto nnz = obj.attr("nnz").cast<Index>();

        if (!values || !innerIndices || !outerIndices)
            return false;

        value = Eigen::MappedSparseMatrix<Scalar, Type::Flags, StorageIndex>(
            shape[0].cast<Index>(), shape[1].cast<Index>(), nnz,
            outerIndices.mutable_data(), innerIndices.mutable_data(), values.mutable_data());

        return true;
    }

    static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) {
        const_cast<Type&>(src).makeCompressed();

        object matrix_type = module::import("scipy.sparse").attr(
            rowMajor ? "csr_matrix" : "csc_matrix");

        array data((size_t) src.nonZeros(), src.valuePtr());
        array outerIndices((size_t) (rowMajor ? src.rows() : src.cols()) + 1, src.outerIndexPtr());
        array innerIndices((size_t) src.nonZeros(), src.innerIndexPtr());

        return matrix_type(
            std::make_tuple(data, innerIndices, outerIndices),
            std::make_pair(src.rows(), src.cols())
        ).release();
    }

    PYBIND11_TYPE_CASTER(Type, _<(Type::IsRowMajor) != 0>("scipy.sparse.csr_matrix[", "scipy.sparse.csc_matrix[")
            + npy_format_descriptor<Scalar>::name() + _("]"));
};
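//
// Usage sketch (the module handle `m` and the function name are illustrative only): a bound
// function taking or returning Eigen::SparseMatrix<double> exchanges data with scipy.sparse,
// the storage order selecting the Python-side class (column-major <-> csc_matrix, row-major <->
// csr_matrix); other sparse formats are converted via the scipy constructor used in load() above:
//
//     m.def("sparse_eye", [](int n) {
//         Eigen::SparseMatrix<double> I(n, n);
//         I.setIdentity();
//         return I;  // received as scipy.sparse.csc_matrix in Python
//     });
//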

NAMESPACE_END(detail)
NAMESPACE_END(pybind11)

#if defined(__GNUG__) || defined(__clang__)
#  pragma GCC diagnostic pop
#elif defined(_MSC_VER)
#  pragma warning(pop)
#endif