blob: fcc18f5c9c8c53687866407dbb3d0269030497b3 [file] [log] [blame]
Narayan Kamathc981c482012-11-02 10:59:05 +00001// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2008-2011 Gael Guennebaud <gael.guennebaud@inria.fr>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
11#define EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
12
13namespace Eigen {
14
15namespace internal {
16
17
18// perform a pseudo in-place sparse * sparse product assuming all matrices are col major
19template<typename Lhs, typename Rhs, typename ResultType>
Carlos Hernandez7faaa9f2014-08-05 17:53:32 -070020static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res, const typename ResultType::RealScalar& tolerance)
Narayan Kamathc981c482012-11-02 10:59:05 +000021{
22 // return sparse_sparse_product_with_pruning_impl2(lhs,rhs,res);
23
24 typedef typename remove_all<Lhs>::type::Scalar Scalar;
25 typedef typename remove_all<Lhs>::type::Index Index;
26
27 // make sure to call innerSize/outerSize since we fake the storage order.
28 Index rows = lhs.innerSize();
29 Index cols = rhs.outerSize();
Carlos Hernandez7faaa9f2014-08-05 17:53:32 -070030 //Index size = lhs.outerSize();
Narayan Kamathc981c482012-11-02 10:59:05 +000031 eigen_assert(lhs.outerSize() == rhs.innerSize());
32
33 // allocate a temporary buffer
34 AmbiVector<Scalar,Index> tempVector(rows);
35
36 // estimate the number of non zero entries
37 // given a rhs column containing Y non zeros, we assume that the respective Y columns
38 // of the lhs differs in average of one non zeros, thus the number of non zeros for
39 // the product of a rhs column with the lhs is X+Y where X is the average number of non zero
40 // per column of the lhs.
41 // Therefore, we have nnz(lhs*rhs) = nnz(lhs) + nnz(rhs)
42 Index estimated_nnz_prod = lhs.nonZeros() + rhs.nonZeros();
43
44 // mimics a resizeByInnerOuter:
45 if(ResultType::IsRowMajor)
46 res.resize(cols, rows);
47 else
48 res.resize(rows, cols);
49
50 res.reserve(estimated_nnz_prod);
51 double ratioColRes = double(estimated_nnz_prod)/double(lhs.rows()*rhs.cols());
52 for (Index j=0; j<cols; ++j)
53 {
54 // FIXME:
55 //double ratioColRes = (double(rhs.innerVector(j).nonZeros()) + double(lhs.nonZeros())/double(lhs.cols()))/double(lhs.rows());
56 // let's do a more accurate determination of the nnz ratio for the current column j of res
57 tempVector.init(ratioColRes);
58 tempVector.setZero();
59 for (typename Rhs::InnerIterator rhsIt(rhs, j); rhsIt; ++rhsIt)
60 {
61 // FIXME should be written like this: tmp += rhsIt.value() * lhs.col(rhsIt.index())
62 tempVector.restart();
63 Scalar x = rhsIt.value();
64 for (typename Lhs::InnerIterator lhsIt(lhs, rhsIt.index()); lhsIt; ++lhsIt)
65 {
66 tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x;
67 }
68 }
69 res.startVec(j);
70 for (typename AmbiVector<Scalar,Index>::Iterator it(tempVector,tolerance); it; ++it)
71 res.insertBackByOuterInner(j,it.index()) = it.value();
72 }
73 res.finalize();
74}
75
76template<typename Lhs, typename Rhs, typename ResultType,
77 int LhsStorageOrder = traits<Lhs>::Flags&RowMajorBit,
78 int RhsStorageOrder = traits<Rhs>::Flags&RowMajorBit,
79 int ResStorageOrder = traits<ResultType>::Flags&RowMajorBit>
80struct sparse_sparse_product_with_pruning_selector;
81
82template<typename Lhs, typename Rhs, typename ResultType>
83struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor>
84{
85 typedef typename traits<typename remove_all<Lhs>::type>::Scalar Scalar;
86 typedef typename ResultType::RealScalar RealScalar;
87
Carlos Hernandez7faaa9f2014-08-05 17:53:32 -070088 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
Narayan Kamathc981c482012-11-02 10:59:05 +000089 {
90 typename remove_all<ResultType>::type _res(res.rows(), res.cols());
91 internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,ResultType>(lhs, rhs, _res, tolerance);
92 res.swap(_res);
93 }
94};
95
96template<typename Lhs, typename Rhs, typename ResultType>
97struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor>
98{
99 typedef typename ResultType::RealScalar RealScalar;
Carlos Hernandez7faaa9f2014-08-05 17:53:32 -0700100 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
Narayan Kamathc981c482012-11-02 10:59:05 +0000101 {
102 // we need a col-major matrix to hold the result
Carlos Hernandez7faaa9f2014-08-05 17:53:32 -0700103 typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::Index> SparseTemporaryType;
Narayan Kamathc981c482012-11-02 10:59:05 +0000104 SparseTemporaryType _res(res.rows(), res.cols());
105 internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,SparseTemporaryType>(lhs, rhs, _res, tolerance);
106 res = _res;
107 }
108};
109
110template<typename Lhs, typename Rhs, typename ResultType>
111struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor>
112{
113 typedef typename ResultType::RealScalar RealScalar;
Carlos Hernandez7faaa9f2014-08-05 17:53:32 -0700114 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
Narayan Kamathc981c482012-11-02 10:59:05 +0000115 {
116 // let's transpose the product to get a column x column product
117 typename remove_all<ResultType>::type _res(res.rows(), res.cols());
118 internal::sparse_sparse_product_with_pruning_impl<Rhs,Lhs,ResultType>(rhs, lhs, _res, tolerance);
119 res.swap(_res);
120 }
121};
122
123template<typename Lhs, typename Rhs, typename ResultType>
124struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor>
125{
126 typedef typename ResultType::RealScalar RealScalar;
Carlos Hernandez7faaa9f2014-08-05 17:53:32 -0700127 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
Narayan Kamathc981c482012-11-02 10:59:05 +0000128 {
Carlos Hernandez7faaa9f2014-08-05 17:53:32 -0700129 typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::Index> ColMajorMatrixLhs;
130 typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::Index> ColMajorMatrixRhs;
131 ColMajorMatrixLhs colLhs(lhs);
132 ColMajorMatrixRhs colRhs(rhs);
133 internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,ColMajorMatrixRhs,ResultType>(colLhs, colRhs, res, tolerance);
Narayan Kamathc981c482012-11-02 10:59:05 +0000134
135 // let's transpose the product to get a column x column product
136// typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType;
137// SparseTemporaryType _res(res.cols(), res.rows());
138// sparse_sparse_product_with_pruning_impl<Rhs,Lhs,SparseTemporaryType>(rhs, lhs, _res);
139// res = _res.transpose();
140 }
141};
142
// NOTE: the two other cases (col row *) must never occur since they are caught
// by ProductReturnType, which transforms them to (col col *) by evaluating rhs.
145
146} // end namespace internal
147
148} // end namespace Eigen
149
150#endif // EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H