eigen.h: return compile time vectors as 1D NumPy arrays
diff --git a/include/pybind11/eigen.h b/include/pybind11/eigen.h
index 746f0a2..00f0347 100644
--- a/include/pybind11/eigen.h
+++ b/include/pybind11/eigen.h
@@ -41,6 +41,7 @@
struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value>::type> {
typedef typename Type::Scalar Scalar;
static constexpr bool rowMajor = Type::Flags & Eigen::RowMajorBit;
+ static constexpr bool isVector = Type::IsVectorAtCompileTime;
bool load(handle src, bool) {
array_t<Scalar> buffer(src, true);
@@ -50,7 +51,7 @@
buffer_info info = buffer.request();
if (info.ndim == 1) {
typedef Eigen::Stride<Eigen::Dynamic, 0> Strides;
- if (!Type::IsVectorAtCompileTime &&
+ if (!isVector &&
!(Type::RowsAtCompileTime == Eigen::Dynamic &&
Type::ColsAtCompileTime == Eigen::Dynamic))
return false;
@@ -87,23 +88,39 @@
}
static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) {
- array result(buffer_info(
- /* Pointer to buffer */
- const_cast<Scalar *>(src.data()),
- /* Size of one scalar */
- sizeof(Scalar),
- /* Python struct-style format descriptor */
- format_descriptor<Scalar>::value,
- /* Number of dimensions */
- 2,
- /* Buffer dimensions */
- { (size_t) src.rows(),
- (size_t) src.cols() },
- /* Strides (in bytes) for each index */
- { sizeof(Scalar) * (rowMajor ? src.cols() : 1),
- sizeof(Scalar) * (rowMajor ? 1 : src.rows()) }
- ));
- return result.release();
+ if (isVector) {
+ return array(buffer_info(
+ /* Pointer to buffer */
+ const_cast<Scalar *>(src.data()),
+ /* Size of one scalar */
+ sizeof(Scalar),
+ /* Python struct-style format descriptor */
+ format_descriptor<Scalar>::value,
+ /* Number of dimensions */
+ 1,
+ /* Buffer dimensions */
+ { (size_t) src.size() },
+ /* Strides (in bytes) for each index */
+ { sizeof(Scalar) }
+ )).release();
+ } else {
+ return array(buffer_info(
+ /* Pointer to buffer */
+ const_cast<Scalar *>(src.data()),
+ /* Size of one scalar */
+ sizeof(Scalar),
+ /* Python struct-style format descriptor */
+ format_descriptor<Scalar>::value,
+ /* Number of dimensions */
+ isVector ? 1 : 2,
+ /* Buffer dimensions */
+ { (size_t) src.rows(),
+ (size_t) src.cols() },
+ /* Strides (in bytes) for each index */
+ { sizeof(Scalar) * (rowMajor ? src.cols() : 1),
+ sizeof(Scalar) * (rowMajor ? 1 : src.rows()) }
+ )).release();
+ }
}
template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>;