Narayan Kamath | c981c48 | 2012-11-02 10:59:05 +0000 | [diff] [blame] | 1 | // This file is part of Eigen, a lightweight C++ template library |
| 2 | // for linear algebra. |
| 3 | // |
| 4 | // Copyright (C) 2011 Kolja Brix <brix@igpm.rwth-aachen.de> |
| 5 | // Copyright (C) 2011 Andreas Platen <andiplaten@gmx.de> |
Carlos Hernandez | 7faaa9f | 2014-08-05 17:53:32 -0700 | [diff] [blame] | 6 | // Copyright (C) 2012 Chen-Pang He <jdh8@ms63.hinet.net> |
Narayan Kamath | c981c48 | 2012-11-02 10:59:05 +0000 | [diff] [blame] | 7 | // |
| 8 | // This Source Code Form is subject to the terms of the Mozilla |
| 9 | // Public License v. 2.0. If a copy of the MPL was not distributed |
| 10 | // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. |
| 11 | |
Narayan Kamath | c981c48 | 2012-11-02 10:59:05 +0000 | [diff] [blame] | 12 | #ifndef KRONECKER_TENSOR_PRODUCT_H |
| 13 | #define KRONECKER_TENSOR_PRODUCT_H |
| 14 | |
Narayan Kamath | c981c48 | 2012-11-02 10:59:05 +0000 | [diff] [blame] | 15 | namespace Eigen { |
| 16 | |
// Forward declaration of Eigen's SparseMatrix; it is named by the conversion
// operator of KroneckerProductSparse further down in this header.
template<typename SparseScalar, int SparseOptions, typename SparseIndex> class SparseMatrix;
Narayan Kamath | c981c48 | 2012-11-02 10:59:05 +0000 | [diff] [blame] | 18 | |
| 19 | /*! |
Carlos Hernandez | 7faaa9f | 2014-08-05 17:53:32 -0700 | [diff] [blame] | 20 | * \brief Kronecker tensor product helper class for dense matrices |
Narayan Kamath | c981c48 | 2012-11-02 10:59:05 +0000 | [diff] [blame] | 21 | * |
Carlos Hernandez | 7faaa9f | 2014-08-05 17:53:32 -0700 | [diff] [blame] | 22 | * This class is the return value of kroneckerProduct(MatrixBase, |
| 23 | * MatrixBase). Use the function rather than construct this class |
| 24 | * directly to avoid specifying template prarameters. |
| 25 | * |
| 26 | * \tparam Lhs Type of the left-hand side, a matrix expression. |
| 27 | * \tparam Rhs Type of the rignt-hand side, a matrix expression. |
Narayan Kamath | c981c48 | 2012-11-02 10:59:05 +0000 | [diff] [blame] | 28 | */ |
Carlos Hernandez | 7faaa9f | 2014-08-05 17:53:32 -0700 | [diff] [blame] | 29 | template<typename Lhs, typename Rhs> |
| 30 | class KroneckerProduct : public ReturnByValue<KroneckerProduct<Lhs,Rhs> > |
Narayan Kamath | c981c48 | 2012-11-02 10:59:05 +0000 | [diff] [blame] | 31 | { |
Carlos Hernandez | 7faaa9f | 2014-08-05 17:53:32 -0700 | [diff] [blame] | 32 | private: |
| 33 | typedef ReturnByValue<KroneckerProduct> Base; |
| 34 | typedef typename Base::Scalar Scalar; |
| 35 | typedef typename Base::Index Index; |
| 36 | |
| 37 | public: |
| 38 | /*! \brief Constructor. */ |
| 39 | KroneckerProduct(const Lhs& A, const Rhs& B) |
| 40 | : m_A(A), m_B(B) |
| 41 | {} |
| 42 | |
| 43 | /*! \brief Evaluate the Kronecker tensor product. */ |
| 44 | template<typename Dest> void evalTo(Dest& dst) const; |
| 45 | |
| 46 | inline Index rows() const { return m_A.rows() * m_B.rows(); } |
| 47 | inline Index cols() const { return m_A.cols() * m_B.cols(); } |
| 48 | |
| 49 | Scalar coeff(Index row, Index col) const |
| 50 | { |
| 51 | return m_A.coeff(row / m_B.rows(), col / m_B.cols()) * |
| 52 | m_B.coeff(row % m_B.rows(), col % m_B.cols()); |
| 53 | } |
| 54 | |
| 55 | Scalar coeff(Index i) const |
| 56 | { |
| 57 | EIGEN_STATIC_ASSERT_VECTOR_ONLY(KroneckerProduct); |
| 58 | return m_A.coeff(i / m_A.size()) * m_B.coeff(i % m_A.size()); |
| 59 | } |
| 60 | |
| 61 | private: |
| 62 | typename Lhs::Nested m_A; |
| 63 | typename Rhs::Nested m_B; |
| 64 | }; |
| 65 | |
| 66 | /*! |
| 67 | * \brief Kronecker tensor product helper class for sparse matrices |
| 68 | * |
| 69 | * If at least one of the operands is a sparse matrix expression, |
| 70 | * then this class is returned and evaluates into a sparse matrix. |
| 71 | * |
| 72 | * This class is the return value of kroneckerProduct(EigenBase, |
| 73 | * EigenBase). Use the function rather than construct this class |
| 74 | * directly to avoid specifying template prarameters. |
| 75 | * |
| 76 | * \tparam Lhs Type of the left-hand side, a matrix expression. |
| 77 | * \tparam Rhs Type of the rignt-hand side, a matrix expression. |
| 78 | */ |
| 79 | template<typename Lhs, typename Rhs> |
| 80 | class KroneckerProductSparse : public EigenBase<KroneckerProductSparse<Lhs,Rhs> > |
| 81 | { |
| 82 | private: |
| 83 | typedef typename internal::traits<KroneckerProductSparse>::Index Index; |
| 84 | |
| 85 | public: |
| 86 | /*! \brief Constructor. */ |
| 87 | KroneckerProductSparse(const Lhs& A, const Rhs& B) |
| 88 | : m_A(A), m_B(B) |
| 89 | {} |
| 90 | |
| 91 | /*! \brief Evaluate the Kronecker tensor product. */ |
| 92 | template<typename Dest> void evalTo(Dest& dst) const; |
| 93 | |
| 94 | inline Index rows() const { return m_A.rows() * m_B.rows(); } |
| 95 | inline Index cols() const { return m_A.cols() * m_B.cols(); } |
| 96 | |
| 97 | template<typename Scalar, int Options, typename Index> |
| 98 | operator SparseMatrix<Scalar, Options, Index>() |
| 99 | { |
| 100 | SparseMatrix<Scalar, Options, Index> result; |
| 101 | evalTo(result.derived()); |
| 102 | return result; |
| 103 | } |
| 104 | |
| 105 | private: |
| 106 | typename Lhs::Nested m_A; |
| 107 | typename Rhs::Nested m_B; |
| 108 | }; |
| 109 | |
| 110 | template<typename Lhs, typename Rhs> |
| 111 | template<typename Dest> |
| 112 | void KroneckerProduct<Lhs,Rhs>::evalTo(Dest& dst) const |
| 113 | { |
| 114 | const int BlockRows = Rhs::RowsAtCompileTime, |
| 115 | BlockCols = Rhs::ColsAtCompileTime; |
| 116 | const Index Br = m_B.rows(), |
| 117 | Bc = m_B.cols(); |
| 118 | for (Index i=0; i < m_A.rows(); ++i) |
| 119 | for (Index j=0; j < m_A.cols(); ++j) |
| 120 | Block<Dest,BlockRows,BlockCols>(dst,i*Br,j*Bc,Br,Bc) = m_A.coeff(i,j) * m_B; |
Narayan Kamath | c981c48 | 2012-11-02 10:59:05 +0000 | [diff] [blame] | 121 | } |
| 122 | |
Carlos Hernandez | 7faaa9f | 2014-08-05 17:53:32 -0700 | [diff] [blame] | 123 | template<typename Lhs, typename Rhs> |
| 124 | template<typename Dest> |
| 125 | void KroneckerProductSparse<Lhs,Rhs>::evalTo(Dest& dst) const |
Narayan Kamath | c981c48 | 2012-11-02 10:59:05 +0000 | [diff] [blame] | 126 | { |
Carlos Hernandez | 7faaa9f | 2014-08-05 17:53:32 -0700 | [diff] [blame] | 127 | const Index Br = m_B.rows(), |
| 128 | Bc = m_B.cols(); |
| 129 | dst.resize(rows(),cols()); |
| 130 | dst.resizeNonZeros(0); |
| 131 | dst.reserve(m_A.nonZeros() * m_B.nonZeros()); |
Narayan Kamath | c981c48 | 2012-11-02 10:59:05 +0000 | [diff] [blame] | 132 | |
Carlos Hernandez | 7faaa9f | 2014-08-05 17:53:32 -0700 | [diff] [blame] | 133 | for (Index kA=0; kA < m_A.outerSize(); ++kA) |
Narayan Kamath | c981c48 | 2012-11-02 10:59:05 +0000 | [diff] [blame] | 134 | { |
Carlos Hernandez | 7faaa9f | 2014-08-05 17:53:32 -0700 | [diff] [blame] | 135 | for (Index kB=0; kB < m_B.outerSize(); ++kB) |
Narayan Kamath | c981c48 | 2012-11-02 10:59:05 +0000 | [diff] [blame] | 136 | { |
Carlos Hernandez | 7faaa9f | 2014-08-05 17:53:32 -0700 | [diff] [blame] | 137 | for (typename Lhs::InnerIterator itA(m_A,kA); itA; ++itA) |
Narayan Kamath | c981c48 | 2012-11-02 10:59:05 +0000 | [diff] [blame] | 138 | { |
Carlos Hernandez | 7faaa9f | 2014-08-05 17:53:32 -0700 | [diff] [blame] | 139 | for (typename Rhs::InnerIterator itB(m_B,kB); itB; ++itB) |
Narayan Kamath | c981c48 | 2012-11-02 10:59:05 +0000 | [diff] [blame] | 140 | { |
Carlos Hernandez | 7faaa9f | 2014-08-05 17:53:32 -0700 | [diff] [blame] | 141 | const Index i = itA.row() * Br + itB.row(), |
| 142 | j = itA.col() * Bc + itB.col(); |
| 143 | dst.insert(i,j) = itA.value() * itB.value(); |
Narayan Kamath | c981c48 | 2012-11-02 10:59:05 +0000 | [diff] [blame] | 144 | } |
| 145 | } |
| 146 | } |
| 147 | } |
| 148 | } |
| 149 | |
Carlos Hernandez | 7faaa9f | 2014-08-05 17:53:32 -0700 | [diff] [blame] | 150 | namespace internal { |
| 151 | |
| 152 | template<typename _Lhs, typename _Rhs> |
| 153 | struct traits<KroneckerProduct<_Lhs,_Rhs> > |
| 154 | { |
| 155 | typedef typename remove_all<_Lhs>::type Lhs; |
| 156 | typedef typename remove_all<_Rhs>::type Rhs; |
| 157 | typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar; |
| 158 | |
| 159 | enum { |
| 160 | Rows = size_at_compile_time<traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime>::ret, |
| 161 | Cols = size_at_compile_time<traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime>::ret, |
| 162 | MaxRows = size_at_compile_time<traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime>::ret, |
| 163 | MaxCols = size_at_compile_time<traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime>::ret, |
| 164 | CoeffReadCost = Lhs::CoeffReadCost + Rhs::CoeffReadCost + NumTraits<Scalar>::MulCost |
| 165 | }; |
| 166 | |
| 167 | typedef Matrix<Scalar,Rows,Cols> ReturnType; |
| 168 | }; |
| 169 | |
| 170 | template<typename _Lhs, typename _Rhs> |
| 171 | struct traits<KroneckerProductSparse<_Lhs,_Rhs> > |
| 172 | { |
| 173 | typedef MatrixXpr XprKind; |
| 174 | typedef typename remove_all<_Lhs>::type Lhs; |
| 175 | typedef typename remove_all<_Rhs>::type Rhs; |
| 176 | typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar; |
| 177 | typedef typename promote_storage_type<typename traits<Lhs>::StorageKind, typename traits<Rhs>::StorageKind>::ret StorageKind; |
| 178 | typedef typename promote_index_type<typename Lhs::Index, typename Rhs::Index>::type Index; |
| 179 | |
| 180 | enum { |
| 181 | LhsFlags = Lhs::Flags, |
| 182 | RhsFlags = Rhs::Flags, |
| 183 | |
| 184 | RowsAtCompileTime = size_at_compile_time<traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime>::ret, |
| 185 | ColsAtCompileTime = size_at_compile_time<traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime>::ret, |
| 186 | MaxRowsAtCompileTime = size_at_compile_time<traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime>::ret, |
| 187 | MaxColsAtCompileTime = size_at_compile_time<traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime>::ret, |
| 188 | |
| 189 | EvalToRowMajor = (LhsFlags & RhsFlags & RowMajorBit), |
| 190 | RemovedBits = ~(EvalToRowMajor ? 0 : RowMajorBit), |
| 191 | |
| 192 | Flags = ((LhsFlags | RhsFlags) & HereditaryBits & RemovedBits) |
| 193 | | EvalBeforeNestingBit | EvalBeforeAssigningBit, |
| 194 | CoeffReadCost = Dynamic |
| 195 | }; |
| 196 | }; |
| 197 | |
Narayan Kamath | c981c48 | 2012-11-02 10:59:05 +0000 | [diff] [blame] | 198 | } // end namespace internal |
| 199 | |
Narayan Kamath | c981c48 | 2012-11-02 10:59:05 +0000 | [diff] [blame] | 200 | /*! |
Carlos Hernandez | 7faaa9f | 2014-08-05 17:53:32 -0700 | [diff] [blame] | 201 | * \ingroup KroneckerProduct_Module |
| 202 | * |
Narayan Kamath | c981c48 | 2012-11-02 10:59:05 +0000 | [diff] [blame] | 203 | * Computes Kronecker tensor product of two dense matrices |
| 204 | * |
Carlos Hernandez | 7faaa9f | 2014-08-05 17:53:32 -0700 | [diff] [blame] | 205 | * \warning If you want to replace a matrix by its Kronecker product |
| 206 | * with some matrix, do \b NOT do this: |
| 207 | * \code |
| 208 | * A = kroneckerProduct(A,B); // bug!!! caused by aliasing effect |
| 209 | * \endcode |
| 210 | * instead, use eval() to work around this: |
| 211 | * \code |
| 212 | * A = kroneckerProduct(A,B).eval(); |
| 213 | * \endcode |
Narayan Kamath | c981c48 | 2012-11-02 10:59:05 +0000 | [diff] [blame] | 214 | * |
| 215 | * \param a Dense matrix a |
| 216 | * \param b Dense matrix b |
Carlos Hernandez | 7faaa9f | 2014-08-05 17:53:32 -0700 | [diff] [blame] | 217 | * \return Kronecker tensor product of a and b |
Narayan Kamath | c981c48 | 2012-11-02 10:59:05 +0000 | [diff] [blame] | 218 | */ |
Carlos Hernandez | 7faaa9f | 2014-08-05 17:53:32 -0700 | [diff] [blame] | 219 | template<typename A, typename B> |
| 220 | KroneckerProduct<A,B> kroneckerProduct(const MatrixBase<A>& a, const MatrixBase<B>& b) |
Narayan Kamath | c981c48 | 2012-11-02 10:59:05 +0000 | [diff] [blame] | 221 | { |
Carlos Hernandez | 7faaa9f | 2014-08-05 17:53:32 -0700 | [diff] [blame] | 222 | return KroneckerProduct<A, B>(a.derived(), b.derived()); |
Narayan Kamath | c981c48 | 2012-11-02 10:59:05 +0000 | [diff] [blame] | 223 | } |
| 224 | |
| 225 | /*! |
Carlos Hernandez | 7faaa9f | 2014-08-05 17:53:32 -0700 | [diff] [blame] | 226 | * \ingroup KroneckerProduct_Module |
Narayan Kamath | c981c48 | 2012-11-02 10:59:05 +0000 | [diff] [blame] | 227 | * |
Carlos Hernandez | 7faaa9f | 2014-08-05 17:53:32 -0700 | [diff] [blame] | 228 | * Computes Kronecker tensor product of two matrices, at least one of |
| 229 | * which is sparse |
Narayan Kamath | c981c48 | 2012-11-02 10:59:05 +0000 | [diff] [blame] | 230 | * |
Carlos Hernandez | 7faaa9f | 2014-08-05 17:53:32 -0700 | [diff] [blame] | 231 | * \param a Dense/sparse matrix a |
| 232 | * \param b Dense/sparse matrix b |
| 233 | * \return Kronecker tensor product of a and b, stored in a sparse |
| 234 | * matrix |
Narayan Kamath | c981c48 | 2012-11-02 10:59:05 +0000 | [diff] [blame] | 235 | */ |
Carlos Hernandez | 7faaa9f | 2014-08-05 17:53:32 -0700 | [diff] [blame] | 236 | template<typename A, typename B> |
| 237 | KroneckerProductSparse<A,B> kroneckerProduct(const EigenBase<A>& a, const EigenBase<B>& b) |
Narayan Kamath | c981c48 | 2012-11-02 10:59:05 +0000 | [diff] [blame] | 238 | { |
Carlos Hernandez | 7faaa9f | 2014-08-05 17:53:32 -0700 | [diff] [blame] | 239 | return KroneckerProductSparse<A,B>(a.derived(), b.derived()); |
Narayan Kamath | c981c48 | 2012-11-02 10:59:05 +0000 | [diff] [blame] | 240 | } |
| 241 | |
| 242 | } // end namespace Eigen |
| 243 | |
| 244 | #endif // KRONECKER_TENSOR_PRODUCT_H |