// MOAB 4.9.3pre
00001 // This file is part of Eigen, a lightweight C++ template library 00002 // for linear algebra. 00003 // 00004 // Copyright (C) 2008-2009 Gael Guennebaud <[email protected]> 00005 // 00006 // This Source Code Form is subject to the terms of the Mozilla 00007 // Public License v. 2.0. If a copy of the MPL was not distributed 00008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 00009 00010 #ifndef EIGEN_SOLVETRIANGULAR_H 00011 #define EIGEN_SOLVETRIANGULAR_H 00012 00013 namespace Eigen { 00014 00015 namespace internal { 00016 00017 // Forward declarations: 00018 // The following two routines are implemented in the products/TriangularSolver*.h files 00019 template<typename LhsScalar, typename RhsScalar, typename Index, int Side, int Mode, bool Conjugate, int StorageOrder> 00020 struct triangular_solve_vector; 00021 00022 template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherStorageOrder> 00023 struct triangular_solve_matrix; 00024 00025 // small helper struct extracting some traits on the underlying solver operation 00026 template<typename Lhs, typename Rhs, int Side> 00027 class trsolve_traits 00028 { 00029 private: 00030 enum { 00031 RhsIsVectorAtCompileTime = (Side==OnTheLeft ? Rhs::ColsAtCompileTime : Rhs::RowsAtCompileTime)==1 00032 }; 00033 public: 00034 enum { 00035 Unrolling = (RhsIsVectorAtCompileTime && Rhs::SizeAtCompileTime != Dynamic && Rhs::SizeAtCompileTime <= 8) 00036 ? CompleteUnrolling : NoUnrolling, 00037 RhsVectors = RhsIsVectorAtCompileTime ? 
1 : Dynamic 00038 }; 00039 }; 00040 00041 template<typename Lhs, typename Rhs, 00042 int Side, // can be OnTheLeft/OnTheRight 00043 int Mode, // can be Upper/Lower | UnitDiag 00044 int Unrolling = trsolve_traits<Lhs,Rhs,Side>::Unrolling, 00045 int RhsVectors = trsolve_traits<Lhs,Rhs,Side>::RhsVectors 00046 > 00047 struct triangular_solver_selector; 00048 00049 template<typename Lhs, typename Rhs, int Side, int Mode> 00050 struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,1> 00051 { 00052 typedef typename Lhs::Scalar LhsScalar; 00053 typedef typename Rhs::Scalar RhsScalar; 00054 typedef blas_traits<Lhs> LhsProductTraits; 00055 typedef typename LhsProductTraits::ExtractType ActualLhsType; 00056 typedef Map<Matrix<RhsScalar,Dynamic,1>, Aligned> MappedRhs; 00057 static void run(const Lhs& lhs, Rhs& rhs) 00058 { 00059 ActualLhsType actualLhs = LhsProductTraits::extract(lhs); 00060 00061 // FIXME find a way to allow an inner stride if packet_traits<Scalar>::size==1 00062 00063 bool useRhsDirectly = Rhs::InnerStrideAtCompileTime==1 || rhs.innerStride()==1; 00064 00065 ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhs,rhs.size(), 00066 (useRhsDirectly ? rhs.data() : 0)); 00067 00068 if(!useRhsDirectly) 00069 MappedRhs(actualRhs,rhs.size()) = rhs; 00070 00071 triangular_solve_vector<LhsScalar, RhsScalar, Index, Side, Mode, LhsProductTraits::NeedToConjugate, 00072 (int(Lhs::Flags) & RowMajorBit) ? 
RowMajor : ColMajor> 00073 ::run(actualLhs.cols(), actualLhs.data(), actualLhs.outerStride(), actualRhs); 00074 00075 if(!useRhsDirectly) 00076 rhs = MappedRhs(actualRhs, rhs.size()); 00077 } 00078 }; 00079 00080 // the rhs is a matrix 00081 template<typename Lhs, typename Rhs, int Side, int Mode> 00082 struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,Dynamic> 00083 { 00084 typedef typename Rhs::Scalar Scalar; 00085 typedef blas_traits<Lhs> LhsProductTraits; 00086 typedef typename LhsProductTraits::DirectLinearAccessType ActualLhsType; 00087 00088 static void run(const Lhs& lhs, Rhs& rhs) 00089 { 00090 typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsProductTraits::extract(lhs); 00091 00092 const Index size = lhs.rows(); 00093 const Index othersize = Side==OnTheLeft? rhs.cols() : rhs.rows(); 00094 00095 typedef internal::gemm_blocking_space<(Rhs::Flags&RowMajorBit) ? RowMajor : ColMajor,Scalar,Scalar, 00096 Rhs::MaxRowsAtCompileTime, Rhs::MaxColsAtCompileTime, Lhs::MaxRowsAtCompileTime,4> BlockingType; 00097 00098 BlockingType blocking(rhs.rows(), rhs.cols(), size, 1, false); 00099 00100 triangular_solve_matrix<Scalar,Index,Side,Mode,LhsProductTraits::NeedToConjugate,(int(Lhs::Flags) & RowMajorBit) ? RowMajor : ColMajor, 00101 (Rhs::Flags&RowMajorBit) ? 
RowMajor : ColMajor> 00102 ::run(size, othersize, &actualLhs.coeffRef(0,0), actualLhs.outerStride(), &rhs.coeffRef(0,0), rhs.outerStride(), blocking); 00103 } 00104 }; 00105 00106 /*************************************************************************** 00107 * meta-unrolling implementation 00108 ***************************************************************************/ 00109 00110 template<typename Lhs, typename Rhs, int Mode, int LoopIndex, int Size, 00111 bool Stop = LoopIndex==Size> 00112 struct triangular_solver_unroller; 00113 00114 template<typename Lhs, typename Rhs, int Mode, int LoopIndex, int Size> 00115 struct triangular_solver_unroller<Lhs,Rhs,Mode,LoopIndex,Size,false> { 00116 enum { 00117 IsLower = ((Mode&Lower)==Lower), 00118 DiagIndex = IsLower ? LoopIndex : Size - LoopIndex - 1, 00119 StartIndex = IsLower ? 0 : DiagIndex+1 00120 }; 00121 static void run(const Lhs& lhs, Rhs& rhs) 00122 { 00123 if (LoopIndex>0) 00124 rhs.coeffRef(DiagIndex) -= lhs.row(DiagIndex).template segment<LoopIndex>(StartIndex).transpose() 00125 .cwiseProduct(rhs.template segment<LoopIndex>(StartIndex)).sum(); 00126 00127 if(!(Mode & UnitDiag)) 00128 rhs.coeffRef(DiagIndex) /= lhs.coeff(DiagIndex,DiagIndex); 00129 00130 triangular_solver_unroller<Lhs,Rhs,Mode,LoopIndex+1,Size>::run(lhs,rhs); 00131 } 00132 }; 00133 00134 template<typename Lhs, typename Rhs, int Mode, int LoopIndex, int Size> 00135 struct triangular_solver_unroller<Lhs,Rhs,Mode,LoopIndex,Size,true> { 00136 static void run(const Lhs&, Rhs&) {} 00137 }; 00138 00139 template<typename Lhs, typename Rhs, int Mode> 00140 struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,CompleteUnrolling,1> { 00141 static void run(const Lhs& lhs, Rhs& rhs) 00142 { triangular_solver_unroller<Lhs,Rhs,Mode,0,Rhs::SizeAtCompileTime>::run(lhs,rhs); } 00143 }; 00144 00145 template<typename Lhs, typename Rhs, int Mode> 00146 struct triangular_solver_selector<Lhs,Rhs,OnTheRight,Mode,CompleteUnrolling,1> { 00147 static void 
run(const Lhs& lhs, Rhs& rhs) 00148 { 00149 Transpose<const Lhs> trLhs(lhs); 00150 Transpose<Rhs> trRhs(rhs); 00151 00152 triangular_solver_unroller<Transpose<const Lhs>,Transpose<Rhs>, 00153 ((Mode&Upper)==Upper ? Lower : Upper) | (Mode&UnitDiag), 00154 0,Rhs::SizeAtCompileTime>::run(trLhs,trRhs); 00155 } 00156 }; 00157 00158 } // end namespace internal 00159 00160 /*************************************************************************** 00161 * TriangularView methods 00162 ***************************************************************************/ 00163 00164 template<typename MatrixType, unsigned int Mode> 00165 template<int Side, typename OtherDerived> 00166 void TriangularViewImpl<MatrixType,Mode,Dense>::solveInPlace(const MatrixBase<OtherDerived>& _other) const 00167 { 00168 OtherDerived& other = _other.const_cast_derived(); 00169 eigen_assert( derived().cols() == derived().rows() && ((Side==OnTheLeft && derived().cols() == other.rows()) || (Side==OnTheRight && derived().cols() == other.cols())) ); 00170 eigen_assert((!(Mode & ZeroDiag)) && bool(Mode & (Upper|Lower))); 00171 00172 enum { copy = internal::traits<OtherDerived>::Flags & RowMajorBit && OtherDerived::IsVectorAtCompileTime }; 00173 typedef typename internal::conditional<copy, 00174 typename internal::plain_matrix_type_column_major<OtherDerived>::type, OtherDerived&>::type OtherCopy; 00175 OtherCopy otherCopy(other); 00176 00177 internal::triangular_solver_selector<MatrixType, typename internal::remove_reference<OtherCopy>::type, 00178 Side, Mode>::run(derived().nestedExpression(), otherCopy); 00179 00180 if (copy) 00181 other = otherCopy; 00182 } 00183 00184 template<typename Derived, unsigned int Mode> 00185 template<int Side, typename Other> 00186 const internal::triangular_solve_retval<Side,TriangularView<Derived,Mode>,Other> 00187 TriangularViewImpl<Derived,Mode,Dense>::solve(const MatrixBase<Other>& other) const 00188 { 00189 return 
internal::triangular_solve_retval<Side,TriangularViewType,Other>(derived(), other.derived()); 00190 } 00191 00192 namespace internal { 00193 00194 00195 template<int Side, typename TriangularType, typename Rhs> 00196 struct traits<triangular_solve_retval<Side, TriangularType, Rhs> > 00197 { 00198 typedef typename internal::plain_matrix_type_column_major<Rhs>::type ReturnType; 00199 }; 00200 00201 template<int Side, typename TriangularType, typename Rhs> struct triangular_solve_retval 00202 : public ReturnByValue<triangular_solve_retval<Side, TriangularType, Rhs> > 00203 { 00204 typedef typename remove_all<typename Rhs::Nested>::type RhsNestedCleaned; 00205 typedef ReturnByValue<triangular_solve_retval> Base; 00206 00207 triangular_solve_retval(const TriangularType& tri, const Rhs& rhs) 00208 : m_triangularMatrix(tri), m_rhs(rhs) 00209 {} 00210 00211 inline Index rows() const { return m_rhs.rows(); } 00212 inline Index cols() const { return m_rhs.cols(); } 00213 00214 template<typename Dest> inline void evalTo(Dest& dst) const 00215 { 00216 if(!(is_same<RhsNestedCleaned,Dest>::value && extract_data(dst) == extract_data(m_rhs))) 00217 dst = m_rhs; 00218 m_triangularMatrix.template solveInPlace<Side>(dst); 00219 } 00220 00221 protected: 00222 const TriangularType& m_triangularMatrix; 00223 typename Rhs::Nested m_rhs; 00224 }; 00225 00226 } // namespace internal 00227 00228 } // end namespace Eigen 00229 00230 #endif // EIGEN_SOLVETRIANGULAR_H