Narayan Kamath | c981c48 | 2012-11-02 10:59:05 +0000 | [diff] [blame] | 1 | // This file is part of Eigen, a lightweight C++ template library |
| 2 | // for linear algebra. |
| 3 | // |
| 4 | // Copyright (C) 2008-2010 Gael Guennebaud <gael.guennebaud@inria.fr> |
| 5 | // |
| 6 | // This Source Code Form is subject to the terms of the Mozilla |
| 7 | // Public License v. 2.0. If a copy of the MPL was not distributed |
| 8 | // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. |
| 9 | |
| 10 | #ifndef EIGEN_SPARSEPRODUCT_H |
| 11 | #define EIGEN_SPARSEPRODUCT_H |
| 12 | |
| 13 | namespace Eigen { |
| 14 | |
| 15 | template<typename Lhs, typename Rhs> |
| 16 | struct SparseSparseProductReturnType |
| 17 | { |
| 18 | typedef typename internal::traits<Lhs>::Scalar Scalar; |
Carlos Hernandez | 7faaa9f | 2014-08-05 17:53:32 -0700 | [diff] [blame] | 19 | typedef typename internal::traits<Lhs>::Index Index; |
Narayan Kamath | c981c48 | 2012-11-02 10:59:05 +0000 | [diff] [blame] | 20 | enum { |
| 21 | LhsRowMajor = internal::traits<Lhs>::Flags & RowMajorBit, |
| 22 | RhsRowMajor = internal::traits<Rhs>::Flags & RowMajorBit, |
| 23 | TransposeRhs = (!LhsRowMajor) && RhsRowMajor, |
| 24 | TransposeLhs = LhsRowMajor && (!RhsRowMajor) |
| 25 | }; |
| 26 | |
| 27 | typedef typename internal::conditional<TransposeLhs, |
Carlos Hernandez | 7faaa9f | 2014-08-05 17:53:32 -0700 | [diff] [blame] | 28 | SparseMatrix<Scalar,0,Index>, |
Narayan Kamath | c981c48 | 2012-11-02 10:59:05 +0000 | [diff] [blame] | 29 | typename internal::nested<Lhs,Rhs::RowsAtCompileTime>::type>::type LhsNested; |
| 30 | |
| 31 | typedef typename internal::conditional<TransposeRhs, |
Carlos Hernandez | 7faaa9f | 2014-08-05 17:53:32 -0700 | [diff] [blame] | 32 | SparseMatrix<Scalar,0,Index>, |
Narayan Kamath | c981c48 | 2012-11-02 10:59:05 +0000 | [diff] [blame] | 33 | typename internal::nested<Rhs,Lhs::RowsAtCompileTime>::type>::type RhsNested; |
| 34 | |
| 35 | typedef SparseSparseProduct<LhsNested, RhsNested> Type; |
| 36 | }; |
| 37 | |
| 38 | namespace internal { |
| 39 | template<typename LhsNested, typename RhsNested> |
| 40 | struct traits<SparseSparseProduct<LhsNested, RhsNested> > |
| 41 | { |
| 42 | typedef MatrixXpr XprKind; |
| 43 | // clean the nested types: |
| 44 | typedef typename remove_all<LhsNested>::type _LhsNested; |
| 45 | typedef typename remove_all<RhsNested>::type _RhsNested; |
| 46 | typedef typename _LhsNested::Scalar Scalar; |
| 47 | typedef typename promote_index_type<typename traits<_LhsNested>::Index, |
| 48 | typename traits<_RhsNested>::Index>::type Index; |
| 49 | |
| 50 | enum { |
| 51 | LhsCoeffReadCost = _LhsNested::CoeffReadCost, |
| 52 | RhsCoeffReadCost = _RhsNested::CoeffReadCost, |
| 53 | LhsFlags = _LhsNested::Flags, |
| 54 | RhsFlags = _RhsNested::Flags, |
| 55 | |
| 56 | RowsAtCompileTime = _LhsNested::RowsAtCompileTime, |
| 57 | ColsAtCompileTime = _RhsNested::ColsAtCompileTime, |
| 58 | MaxRowsAtCompileTime = _LhsNested::MaxRowsAtCompileTime, |
| 59 | MaxColsAtCompileTime = _RhsNested::MaxColsAtCompileTime, |
| 60 | |
| 61 | InnerSize = EIGEN_SIZE_MIN_PREFER_FIXED(_LhsNested::ColsAtCompileTime, _RhsNested::RowsAtCompileTime), |
| 62 | |
| 63 | EvalToRowMajor = (RhsFlags & LhsFlags & RowMajorBit), |
| 64 | |
| 65 | RemovedBits = ~(EvalToRowMajor ? 0 : RowMajorBit), |
| 66 | |
| 67 | Flags = (int(LhsFlags | RhsFlags) & HereditaryBits & RemovedBits) |
| 68 | | EvalBeforeAssigningBit |
| 69 | | EvalBeforeNestingBit, |
| 70 | |
| 71 | CoeffReadCost = Dynamic |
| 72 | }; |
| 73 | |
| 74 | typedef Sparse StorageKind; |
| 75 | }; |
| 76 | |
| 77 | } // end namespace internal |
| 78 | |
| 79 | template<typename LhsNested, typename RhsNested> |
| 80 | class SparseSparseProduct : internal::no_assignment_operator, |
| 81 | public SparseMatrixBase<SparseSparseProduct<LhsNested, RhsNested> > |
| 82 | { |
| 83 | public: |
| 84 | |
| 85 | typedef SparseMatrixBase<SparseSparseProduct> Base; |
| 86 | EIGEN_DENSE_PUBLIC_INTERFACE(SparseSparseProduct) |
| 87 | |
| 88 | private: |
| 89 | |
| 90 | typedef typename internal::traits<SparseSparseProduct>::_LhsNested _LhsNested; |
| 91 | typedef typename internal::traits<SparseSparseProduct>::_RhsNested _RhsNested; |
| 92 | |
| 93 | public: |
| 94 | |
| 95 | template<typename Lhs, typename Rhs> |
| 96 | EIGEN_STRONG_INLINE SparseSparseProduct(const Lhs& lhs, const Rhs& rhs) |
| 97 | : m_lhs(lhs), m_rhs(rhs), m_tolerance(0), m_conservative(true) |
| 98 | { |
| 99 | init(); |
| 100 | } |
| 101 | |
| 102 | template<typename Lhs, typename Rhs> |
Carlos Hernandez | 7faaa9f | 2014-08-05 17:53:32 -0700 | [diff] [blame] | 103 | EIGEN_STRONG_INLINE SparseSparseProduct(const Lhs& lhs, const Rhs& rhs, const RealScalar& tolerance) |
Narayan Kamath | c981c48 | 2012-11-02 10:59:05 +0000 | [diff] [blame] | 104 | : m_lhs(lhs), m_rhs(rhs), m_tolerance(tolerance), m_conservative(false) |
| 105 | { |
| 106 | init(); |
| 107 | } |
| 108 | |
Carlos Hernandez | 7faaa9f | 2014-08-05 17:53:32 -0700 | [diff] [blame] | 109 | SparseSparseProduct pruned(const Scalar& reference = 0, const RealScalar& epsilon = NumTraits<RealScalar>::dummy_precision()) const |
Narayan Kamath | c981c48 | 2012-11-02 10:59:05 +0000 | [diff] [blame] | 110 | { |
Carlos Hernandez | 7faaa9f | 2014-08-05 17:53:32 -0700 | [diff] [blame] | 111 | using std::abs; |
| 112 | return SparseSparseProduct(m_lhs,m_rhs,abs(reference)*epsilon); |
Narayan Kamath | c981c48 | 2012-11-02 10:59:05 +0000 | [diff] [blame] | 113 | } |
| 114 | |
| 115 | template<typename Dest> |
| 116 | void evalTo(Dest& result) const |
| 117 | { |
| 118 | if(m_conservative) |
| 119 | internal::conservative_sparse_sparse_product_selector<_LhsNested, _RhsNested, Dest>::run(lhs(),rhs(),result); |
| 120 | else |
| 121 | internal::sparse_sparse_product_with_pruning_selector<_LhsNested, _RhsNested, Dest>::run(lhs(),rhs(),result,m_tolerance); |
| 122 | } |
| 123 | |
| 124 | EIGEN_STRONG_INLINE Index rows() const { return m_lhs.rows(); } |
| 125 | EIGEN_STRONG_INLINE Index cols() const { return m_rhs.cols(); } |
| 126 | |
| 127 | EIGEN_STRONG_INLINE const _LhsNested& lhs() const { return m_lhs; } |
| 128 | EIGEN_STRONG_INLINE const _RhsNested& rhs() const { return m_rhs; } |
| 129 | |
| 130 | protected: |
| 131 | void init() |
| 132 | { |
| 133 | eigen_assert(m_lhs.cols() == m_rhs.rows()); |
| 134 | |
| 135 | enum { |
| 136 | ProductIsValid = _LhsNested::ColsAtCompileTime==Dynamic |
| 137 | || _RhsNested::RowsAtCompileTime==Dynamic |
| 138 | || int(_LhsNested::ColsAtCompileTime)==int(_RhsNested::RowsAtCompileTime), |
| 139 | AreVectors = _LhsNested::IsVectorAtCompileTime && _RhsNested::IsVectorAtCompileTime, |
| 140 | SameSizes = EIGEN_PREDICATE_SAME_MATRIX_SIZE(_LhsNested,_RhsNested) |
| 141 | }; |
| 142 | // note to the lost user: |
| 143 | // * for a dot product use: v1.dot(v2) |
| 144 | // * for a coeff-wise product use: v1.cwise()*v2 |
| 145 | EIGEN_STATIC_ASSERT(ProductIsValid || !(AreVectors && SameSizes), |
| 146 | INVALID_VECTOR_VECTOR_PRODUCT__IF_YOU_WANTED_A_DOT_OR_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTIONS) |
| 147 | EIGEN_STATIC_ASSERT(ProductIsValid || !(SameSizes && !AreVectors), |
| 148 | INVALID_MATRIX_PRODUCT__IF_YOU_WANTED_A_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTION) |
| 149 | EIGEN_STATIC_ASSERT(ProductIsValid || SameSizes, INVALID_MATRIX_PRODUCT) |
| 150 | } |
| 151 | |
| 152 | LhsNested m_lhs; |
| 153 | RhsNested m_rhs; |
| 154 | RealScalar m_tolerance; |
| 155 | bool m_conservative; |
| 156 | }; |
| 157 | |
| 158 | // sparse = sparse * sparse |
| 159 | template<typename Derived> |
| 160 | template<typename Lhs, typename Rhs> |
| 161 | inline Derived& SparseMatrixBase<Derived>::operator=(const SparseSparseProduct<Lhs,Rhs>& product) |
| 162 | { |
| 163 | product.evalTo(derived()); |
| 164 | return derived(); |
| 165 | } |
| 166 | |
| 167 | /** \returns an expression of the product of two sparse matrices. |
| 168 | * By default a conservative product preserving the symbolic non zeros is performed. |
| 169 | * The automatic pruning of the small values can be achieved by calling the pruned() function |
| 170 | * in which case a totally different product algorithm is employed: |
| 171 | * \code |
| 172 | * C = (A*B).pruned(); // supress numerical zeros (exact) |
| 173 | * C = (A*B).pruned(ref); |
| 174 | * C = (A*B).pruned(ref,epsilon); |
| 175 | * \endcode |
| 176 | * where \c ref is a meaningful non zero reference value. |
| 177 | * */ |
| 178 | template<typename Derived> |
| 179 | template<typename OtherDerived> |
| 180 | inline const typename SparseSparseProductReturnType<Derived,OtherDerived>::Type |
| 181 | SparseMatrixBase<Derived>::operator*(const SparseMatrixBase<OtherDerived> &other) const |
| 182 | { |
| 183 | return typename SparseSparseProductReturnType<Derived,OtherDerived>::Type(derived(), other.derived()); |
| 184 | } |
| 185 | |
| 186 | } // end namespace Eigen |
| 187 | |
| 188 | #endif // EIGEN_SPARSEPRODUCT_H |