// MOAB 4.9.3pre
00001 // This file is part of Eigen, a lightweight C++ template library 00002 // for linear algebra. 00003 // 00004 // Copyright (C) 2008-2015 Gael Guennebaud <[email protected]> 00005 // 00006 // This Source Code Form is subject to the terms of the Mozilla 00007 // Public License v. 2.0. If a copy of the MPL was not distributed 00008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 00009 00010 #ifndef EIGEN_SPARSEDENSEPRODUCT_H 00011 #define EIGEN_SPARSEDENSEPRODUCT_H 00012 00013 namespace Eigen { 00014 00015 namespace internal { 00016 00017 template <> struct product_promote_storage_type<Sparse,Dense, OuterProduct> { typedef Sparse ret; }; 00018 template <> struct product_promote_storage_type<Dense,Sparse, OuterProduct> { typedef Sparse ret; }; 00019 00020 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType, 00021 typename AlphaType, 00022 int LhsStorageOrder = ((SparseLhsType::Flags&RowMajorBit)==RowMajorBit) ? RowMajor : ColMajor, 00023 bool ColPerCol = ((DenseRhsType::Flags&RowMajorBit)==0) || DenseRhsType::ColsAtCompileTime==1> 00024 struct sparse_time_dense_product_impl; 00025 00026 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType> 00027 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar, RowMajor, true> 00028 { 00029 typedef typename internal::remove_all<SparseLhsType>::type Lhs; 00030 typedef typename internal::remove_all<DenseRhsType>::type Rhs; 00031 typedef typename internal::remove_all<DenseResType>::type Res; 00032 typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator; 00033 typedef evaluator<Lhs> LhsEval; 00034 static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha) 00035 { 00036 LhsEval lhsEval(lhs); 00037 00038 Index n = lhs.outerSize(); 00039 #ifdef EIGEN_HAS_OPENMP 00040 Eigen::initParallel(); 00041 Index threads = Eigen::nbThreads(); 00042 
#endif 00043 00044 for(Index c=0; c<rhs.cols(); ++c) 00045 { 00046 #ifdef EIGEN_HAS_OPENMP 00047 // This 20000 threshold has been found experimentally on 2D and 3D Poisson problems. 00048 // It basically represents the minimal amount of work to be done to be worth it. 00049 if(threads>1 && lhsEval.nonZerosEstimate() > 20000) 00050 { 00051 #pragma omp parallel for schedule(dynamic,(n+threads*4-1)/(threads*4)) num_threads(threads) 00052 for(Index i=0; i<n; ++i) 00053 processRow(lhsEval,rhs,res,alpha,i,c); 00054 } 00055 else 00056 #endif 00057 { 00058 for(Index i=0; i<n; ++i) 00059 processRow(lhsEval,rhs,res,alpha,i,c); 00060 } 00061 } 00062 } 00063 00064 static void processRow(const LhsEval& lhsEval, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha, Index i, Index col) 00065 { 00066 typename Res::Scalar tmp(0); 00067 for(LhsInnerIterator it(lhsEval,i); it ;++it) 00068 tmp += it.value() * rhs.coeff(it.index(),col); 00069 res.coeffRef(i,col) += alpha * tmp; 00070 } 00071 00072 }; 00073 00074 // FIXME: what is the purpose of the following specialization? Is it for the BlockedSparse format? 
00075 template<typename T1, typename T2/*, int _Options, typename _StrideType*/> 00076 struct scalar_product_traits<T1, Ref<T2/*, _Options, _StrideType*/> > 00077 { 00078 enum { 00079 Defined = 1 00080 }; 00081 typedef typename CwiseUnaryOp<scalar_multiple2_op<T1, typename T2::Scalar>, T2>::PlainObject ReturnType; 00082 }; 00083 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType, typename AlphaType> 00084 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, AlphaType, ColMajor, true> 00085 { 00086 typedef typename internal::remove_all<SparseLhsType>::type Lhs; 00087 typedef typename internal::remove_all<DenseRhsType>::type Rhs; 00088 typedef typename internal::remove_all<DenseResType>::type Res; 00089 typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator; 00090 static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha) 00091 { 00092 evaluator<Lhs> lhsEval(lhs); 00093 for(Index c=0; c<rhs.cols(); ++c) 00094 { 00095 for(Index j=0; j<lhs.outerSize(); ++j) 00096 { 00097 // typename Res::Scalar rhs_j = alpha * rhs.coeff(j,c); 00098 typename internal::scalar_product_traits<AlphaType, typename Rhs::Scalar>::ReturnType rhs_j(alpha * rhs.coeff(j,c)); 00099 for(LhsInnerIterator it(lhsEval,j); it ;++it) 00100 res.coeffRef(it.index(),c) += it.value() * rhs_j; 00101 } 00102 } 00103 } 00104 }; 00105 00106 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType> 00107 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar, RowMajor, false> 00108 { 00109 typedef typename internal::remove_all<SparseLhsType>::type Lhs; 00110 typedef typename internal::remove_all<DenseRhsType>::type Rhs; 00111 typedef typename internal::remove_all<DenseResType>::type Res; 00112 typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator; 00113 static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, 
DenseResType& res, const typename Res::Scalar& alpha) 00114 { 00115 evaluator<Lhs> lhsEval(lhs); 00116 for(Index j=0; j<lhs.outerSize(); ++j) 00117 { 00118 typename Res::RowXpr res_j(res.row(j)); 00119 for(LhsInnerIterator it(lhsEval,j); it ;++it) 00120 res_j += (alpha*it.value()) * rhs.row(it.index()); 00121 } 00122 } 00123 }; 00124 00125 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType> 00126 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar, ColMajor, false> 00127 { 00128 typedef typename internal::remove_all<SparseLhsType>::type Lhs; 00129 typedef typename internal::remove_all<DenseRhsType>::type Rhs; 00130 typedef typename internal::remove_all<DenseResType>::type Res; 00131 typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator; 00132 static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha) 00133 { 00134 evaluator<Lhs> lhsEval(lhs); 00135 for(Index j=0; j<lhs.outerSize(); ++j) 00136 { 00137 typename Rhs::ConstRowXpr rhs_j(rhs.row(j)); 00138 for(LhsInnerIterator it(lhsEval,j); it ;++it) 00139 res.row(it.index()) += (alpha*it.value()) * rhs_j; 00140 } 00141 } 00142 }; 00143 00144 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,typename AlphaType> 00145 inline void sparse_time_dense_product(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha) 00146 { 00147 sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, AlphaType>::run(lhs, rhs, res, alpha); 00148 } 00149 00150 } // end namespace internal 00151 00152 namespace internal { 00153 00154 template<typename Lhs, typename Rhs, int ProductType> 00155 struct generic_product_impl<Lhs, Rhs, SparseShape, DenseShape, ProductType> 00156 : generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,SparseShape,DenseShape,ProductType> > 00157 { 00158 typedef typename 
Product<Lhs,Rhs>::Scalar Scalar; 00159 00160 template<typename Dest> 00161 static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) 00162 { 00163 typedef typename nested_eval<Lhs,((Rhs::Flags&RowMajorBit)==0) ? 1 : Rhs::ColsAtCompileTime>::type LhsNested; 00164 typedef typename nested_eval<Rhs,((Lhs::Flags&RowMajorBit)==0) ? 1 : Dynamic>::type RhsNested; 00165 LhsNested lhsNested(lhs); 00166 RhsNested rhsNested(rhs); 00167 internal::sparse_time_dense_product(lhsNested, rhsNested, dst, alpha); 00168 } 00169 }; 00170 00171 template<typename Lhs, typename Rhs, int ProductType> 00172 struct generic_product_impl<Lhs, Rhs, SparseTriangularShape, DenseShape, ProductType> 00173 : generic_product_impl<Lhs, Rhs, SparseShape, DenseShape, ProductType> 00174 {}; 00175 00176 template<typename Lhs, typename Rhs, int ProductType> 00177 struct generic_product_impl<Lhs, Rhs, DenseShape, SparseShape, ProductType> 00178 : generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,DenseShape,SparseShape,ProductType> > 00179 { 00180 typedef typename Product<Lhs,Rhs>::Scalar Scalar; 00181 00182 template<typename Dst> 00183 static void scaleAndAddTo(Dst& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) 00184 { 00185 typedef typename nested_eval<Lhs,((Rhs::Flags&RowMajorBit)==0) ? Dynamic : 1>::type LhsNested; 00186 typedef typename nested_eval<Rhs,((Lhs::Flags&RowMajorBit)==RowMajorBit) ? 
1 : Lhs::RowsAtCompileTime>::type RhsNested; 00187 LhsNested lhsNested(lhs); 00188 RhsNested rhsNested(rhs); 00189 00190 // transpose everything 00191 Transpose<Dst> dstT(dst); 00192 internal::sparse_time_dense_product(rhsNested.transpose(), lhsNested.transpose(), dstT, alpha); 00193 } 00194 }; 00195 00196 template<typename Lhs, typename Rhs, int ProductType> 00197 struct generic_product_impl<Lhs, Rhs, DenseShape, SparseTriangularShape, ProductType> 00198 : generic_product_impl<Lhs, Rhs, DenseShape, SparseShape, ProductType> 00199 {}; 00200 00201 template<typename LhsT, typename RhsT, bool NeedToTranspose> 00202 struct sparse_dense_outer_product_evaluator 00203 { 00204 protected: 00205 typedef typename conditional<NeedToTranspose,RhsT,LhsT>::type Lhs1; 00206 typedef typename conditional<NeedToTranspose,LhsT,RhsT>::type ActualRhs; 00207 typedef Product<LhsT,RhsT,DefaultProduct> ProdXprType; 00208 00209 // if the actual left-hand side is a dense vector, 00210 // then build a sparse-view so that we can seamlessly iterate over it. 00211 typedef typename conditional<is_same<typename internal::traits<Lhs1>::StorageKind,Sparse>::value, 00212 Lhs1, SparseView<Lhs1> >::type ActualLhs; 00213 typedef typename conditional<is_same<typename internal::traits<Lhs1>::StorageKind,Sparse>::value, 00214 Lhs1 const&, SparseView<Lhs1> >::type LhsArg; 00215 00216 typedef evaluator<ActualLhs> LhsEval; 00217 typedef evaluator<ActualRhs> RhsEval; 00218 typedef typename evaluator<ActualLhs>::InnerIterator LhsIterator; 00219 typedef typename ProdXprType::Scalar Scalar; 00220 00221 public: 00222 enum { 00223 Flags = NeedToTranspose ? 
RowMajorBit : 0, 00224 CoeffReadCost = HugeCost 00225 }; 00226 00227 class InnerIterator : public LhsIterator 00228 { 00229 public: 00230 InnerIterator(const sparse_dense_outer_product_evaluator &xprEval, Index outer) 00231 : LhsIterator(xprEval.m_lhsXprImpl, 0), 00232 m_outer(outer), 00233 m_empty(false), 00234 m_factor(get(xprEval.m_rhsXprImpl, outer, typename internal::traits<ActualRhs>::StorageKind() )) 00235 {} 00236 00237 EIGEN_STRONG_INLINE Index outer() const { return m_outer; } 00238 EIGEN_STRONG_INLINE Index row() const { return NeedToTranspose ? m_outer : LhsIterator::index(); } 00239 EIGEN_STRONG_INLINE Index col() const { return NeedToTranspose ? LhsIterator::index() : m_outer; } 00240 00241 EIGEN_STRONG_INLINE Scalar value() const { return LhsIterator::value() * m_factor; } 00242 EIGEN_STRONG_INLINE operator bool() const { return LhsIterator::operator bool() && (!m_empty); } 00243 00244 protected: 00245 Scalar get(const RhsEval &rhs, Index outer, Dense = Dense()) const 00246 { 00247 return rhs.coeff(outer); 00248 } 00249 00250 Scalar get(const RhsEval &rhs, Index outer, Sparse = Sparse()) 00251 { 00252 typename RhsEval::InnerIterator it(rhs, outer); 00253 if (it && it.index()==0 && it.value()!=Scalar(0)) 00254 return it.value(); 00255 m_empty = true; 00256 return Scalar(0); 00257 } 00258 00259 Index m_outer; 00260 bool m_empty; 00261 Scalar m_factor; 00262 }; 00263 00264 sparse_dense_outer_product_evaluator(const Lhs1 &lhs, const ActualRhs &rhs) 00265 : m_lhs(lhs), m_lhsXprImpl(m_lhs), m_rhsXprImpl(rhs) 00266 { 00267 EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); 00268 } 00269 00270 // transpose case 00271 sparse_dense_outer_product_evaluator(const ActualRhs &rhs, const Lhs1 &lhs) 00272 : m_lhs(lhs), m_lhsXprImpl(m_lhs), m_rhsXprImpl(rhs) 00273 { 00274 EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); 00275 } 00276 00277 protected: 00278 const LhsArg m_lhs; 00279 evaluator<ActualLhs> m_lhsXprImpl; 00280 evaluator<ActualRhs> m_rhsXprImpl; 00281 }; 00282 
00283 // sparse * dense outer product 00284 template<typename Lhs, typename Rhs> 00285 struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, OuterProduct, SparseShape, DenseShape> 00286 : sparse_dense_outer_product_evaluator<Lhs,Rhs, Lhs::IsRowMajor> 00287 { 00288 typedef sparse_dense_outer_product_evaluator<Lhs,Rhs, Lhs::IsRowMajor> Base; 00289 00290 typedef Product<Lhs, Rhs> XprType; 00291 typedef typename XprType::PlainObject PlainObject; 00292 00293 explicit product_evaluator(const XprType& xpr) 00294 : Base(xpr.lhs(), xpr.rhs()) 00295 {} 00296 00297 }; 00298 00299 template<typename Lhs, typename Rhs> 00300 struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, OuterProduct, DenseShape, SparseShape> 00301 : sparse_dense_outer_product_evaluator<Lhs,Rhs, Rhs::IsRowMajor> 00302 { 00303 typedef sparse_dense_outer_product_evaluator<Lhs,Rhs, Rhs::IsRowMajor> Base; 00304 00305 typedef Product<Lhs, Rhs> XprType; 00306 typedef typename XprType::PlainObject PlainObject; 00307 00308 explicit product_evaluator(const XprType& xpr) 00309 : Base(xpr.lhs(), xpr.rhs()) 00310 {} 00311 00312 }; 00313 00314 } // end namespace internal 00315 00316 } // end namespace Eigen 00317 00318 #endif // EIGEN_SPARSEDENSEPRODUCT_H