Add support for Eigen::Ref<...> function arguments
Eigen::Ref is a common way to pass eigen dense types without needing a
template, e.g. the single definition `void
func(Eigen::Ref<Eigen::MatrixXd> x)` can be called with any double
matrix-like object.
The current pybind11 eigen support fails with internal errors if
attempting to bind a function with an Eigen::Ref<...> argument because
Eigen::Ref<...> satisfies the "is_eigen_dense" requirement, but can't
compile if actually used: Eigen::Ref<...> itself is not default
constructible, and so the argument std::tuple containing an
Eigen::Ref<...> isn't constructible, which results in compilation
failure.
This commit adds support for Eigen::Ref<...> by giving it its own
type_caster implementation which consists of an internal type_caster of
the referenced type, load/cast methods that dispatch to the internal
type_caster, and a unique_ptr to an Eigen::Ref<> instance that gets
set during load().
There is, of course, no performance advantage for pybind11-using code of
using Eigen::Ref<...>--we are allocating a matrix of the derived type
when loading it--but this has the advantage of allowing pybind11 to bind
transparently to C++ methods taking Eigen::Refs.
diff --git a/example/eigen.cpp b/example/eigen.cpp
index f99ae3a..728b575 100644
--- a/example/eigen.cpp
+++ b/example/eigen.cpp
@@ -9,6 +9,7 @@
#include "example.h"
#include <pybind11/eigen.h>
+#include <Eigen/Cholesky>
Eigen::VectorXf double_col(const Eigen::VectorXf& x)
{ return 2.0f * x; }
@@ -19,6 +20,14 @@
Eigen::MatrixXf double_mat_cm(const Eigen::MatrixXf& x)
{ return 2.0f * x; }
+// Different ways of passing via Eigen::Ref; the first and second are the Eigen-recommended
+Eigen::MatrixXd cholesky1(Eigen::Ref<Eigen::MatrixXd> &x) { return x.llt().matrixL(); }
+Eigen::MatrixXd cholesky2(const Eigen::Ref<const Eigen::MatrixXd> &x) { return x.llt().matrixL(); }
+Eigen::MatrixXd cholesky3(const Eigen::Ref<Eigen::MatrixXd> &x) { return x.llt().matrixL(); }
+Eigen::MatrixXd cholesky4(Eigen::Ref<const Eigen::MatrixXd> &x) { return x.llt().matrixL(); }
+Eigen::MatrixXd cholesky5(Eigen::Ref<Eigen::MatrixXd> x) { return x.llt().matrixL(); }
+Eigen::MatrixXd cholesky6(Eigen::Ref<const Eigen::MatrixXd> x) { return x.llt().matrixL(); }
+
typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> MatrixXfRowMajor;
MatrixXfRowMajor double_mat_rm(const MatrixXfRowMajor& x)
{ return 2.0f * x; }
@@ -40,6 +49,12 @@
m.def("double_row", &double_row);
m.def("double_mat_cm", &double_mat_cm);
m.def("double_mat_rm", &double_mat_rm);
+ m.def("cholesky1", &cholesky1);
+ m.def("cholesky2", &cholesky2);
+ m.def("cholesky3", &cholesky3);
+ m.def("cholesky4", &cholesky4);
+ m.def("cholesky5", &cholesky5);
+ m.def("cholesky6", &cholesky6);
m.def("fixed_r", [mat]() -> FixedMatrixR {
return FixedMatrixR(mat);
diff --git a/example/eigen.py b/example/eigen.py
index e69605d..6cdc394 100644
--- a/example/eigen.py
+++ b/example/eigen.py
@@ -11,6 +11,7 @@
from example import sparse_passthrough_r, sparse_passthrough_c
from example import double_row, double_col
from example import double_mat_cm, double_mat_rm
+from example import cholesky1, cholesky2, cholesky3, cholesky4, cholesky5, cholesky6
try:
import numpy as np
import scipy
@@ -70,3 +71,10 @@
for slice_idx, ref_mat in enumerate(slices):
print("double_mat_cm(%d) = %s" % (slice_idx, check_got_vs_ref(double_mat_cm(ref_mat), 2.0 * ref_mat)))
print("double_mat_rm(%d) = %s" % (slice_idx, check_got_vs_ref(double_mat_rm(ref_mat), 2.0 * ref_mat)))
+
+i = 1
+for chol in [cholesky1, cholesky2, cholesky3, cholesky4, cholesky5, cholesky6]:
+ mymat = chol(np.array([[1,2,4], [2,13,23], [4,23,77]]))
+ print("cholesky" + str(i) + " " + ("OK" if (mymat == np.array([[1,0,0], [2,3,0], [4,5,6]])).all() else "NOT OKAY"))
+ i += 1
+
diff --git a/example/eigen.ref b/example/eigen.ref
index 03091cc..93e88ad 100644
--- a/example/eigen.ref
+++ b/example/eigen.ref
@@ -27,3 +27,9 @@
double_mat_rm(1) = OK
double_mat_cm(2) = OK
double_mat_rm(2) = OK
+cholesky1 OK
+cholesky2 OK
+cholesky3 OK
+cholesky4 OK
+cholesky5 OK
+cholesky6 OK
diff --git a/include/pybind11/eigen.h b/include/pybind11/eigen.h
index 987b547..b9f22b0 100644
--- a/include/pybind11/eigen.h
+++ b/include/pybind11/eigen.h
@@ -40,6 +40,19 @@
static constexpr bool value = decltype(test(std::declval<T>()))::value;
};
+// Eigen::Ref<Derived> satisfies is_eigen_dense, but isn't constructible, which means we can't load
+// it (since there is no reference!), but we can cast from it.
+template <typename T> class is_eigen_ref {
+private:
+ template<typename Derived> static typename std::enable_if<
+ std::is_same<typename std::remove_const<T>::type, Eigen::Ref<Derived>>::value,
+ Derived>::type test(const Eigen::Ref<Derived> &);
+ static void test(...);
+public:
+ typedef decltype(test(std::declval<T>())) Derived;
+ static constexpr bool value = !std::is_void<Derived>::value;
+};
+
template <typename T> class is_eigen_sparse {
private:
template<typename Derived> static std::true_type test(const Eigen::SparseMatrixBase<Derived> &);
@@ -49,7 +62,7 @@
};
template<typename Type>
-struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value>::type> {
+struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value && !is_eigen_ref<Type>::value>::type> {
typedef typename Type::Scalar Scalar;
static constexpr bool rowMajor = Type::Flags & Eigen::RowMajorBit;
static constexpr bool isVector = Type::IsVectorAtCompileTime;
@@ -150,6 +163,26 @@
};
template<typename Type>
+struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value && is_eigen_ref<Type>::value>::type> {
+private:
+ using Derived = typename std::remove_const<typename is_eigen_ref<Type>::Derived>::type;
+ using DerivedCaster = type_caster<Derived>;
+ DerivedCaster derived_caster;
+protected:
+ std::unique_ptr<Type> value;
+public:
+ bool load(handle src, bool convert) { if (derived_caster.load(src, convert)) { value.reset(new Type(derived_caster.operator Derived&())); return true; } return false; }
+ static handle cast(const Type &src, return_value_policy policy, handle parent) { return DerivedCaster::cast(src, policy, parent); }
+ static handle cast(const Type *src, return_value_policy policy, handle parent) { return DerivedCaster::cast(*src, policy, parent); }
+
+ static PYBIND11_DESCR name() { return DerivedCaster::name(); }
+
+ operator Type*() { return value.get(); }
+ operator Type&() { if (!value) pybind11_fail("Eigen::Ref<...> value not loaded"); return *value; }
+ template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>;
+};
+
+template<typename Type>
struct type_caster<Type, typename std::enable_if<is_eigen_sparse<Type>::value>::type> {
typedef typename Type::Scalar Scalar;
typedef typename std::remove_reference<decltype(*std::declval<Type>().outerIndexPtr())>::type StorageIndex;