MOAB  4.9.3pre
BlasUtil.h
Go to the documentation of this file.
00001 // This file is part of Eigen, a lightweight C++ template library
00002 // for linear algebra.
00003 //
00004 // Copyright (C) 2009-2010 Gael Guennebaud <[email protected]>
00005 //
00006 // This Source Code Form is subject to the terms of the Mozilla
00007 // Public License v. 2.0. If a copy of the MPL was not distributed
00008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
00009 
00010 #ifndef EIGEN_BLASUTIL_H
00011 #define EIGEN_BLASUTIL_H
00012 
00013 // This file contains many lightweight helper classes used to
00014 // implement and control fast level 2 and level 3 BLAS-like routines.
00015 
00016 namespace Eigen {
00017 
00018 namespace internal {
00019 
00020 // forward declarations
00021 template<typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs=false, bool ConjugateRhs=false>
00022 struct gebp_kernel;
00023 
00024 template<typename Scalar, typename Index, typename DataMapper, int nr, int StorageOrder, bool Conjugate = false, bool PanelMode=false>
00025 struct gemm_pack_rhs;
00026 
00027 template<typename Scalar, typename Index, typename DataMapper, int Pack1, int Pack2, int StorageOrder, bool Conjugate = false, bool PanelMode = false>
00028 struct gemm_pack_lhs;
00029 
00030 template<
00031   typename Index,
00032   typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
00033   typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
00034   int ResStorageOrder>
00035 struct general_matrix_matrix_product;
00036 
00037 template<typename Index,
00038          typename LhsScalar, typename LhsMapper, int LhsStorageOrder, bool ConjugateLhs,
00039          typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version=Specialized>
00040 struct general_matrix_vector_product;
00041 
00042 
00043 template<bool Conjugate> struct conj_if;
00044 
00045 template<> struct conj_if<true> {
00046   template<typename T>
00047   inline T operator()(const T& x) { return numext::conj(x); }
00048   template<typename T>
00049   inline T pconj(const T& x) { return internal::pconj(x); }
00050 };
00051 
00052 template<> struct conj_if<false> {
00053   template<typename T>
00054   inline const T& operator()(const T& x) { return x; }
00055   template<typename T>
00056   inline const T& pconj(const T& x) { return x; }
00057 };
00058 
00059 template<typename Scalar> struct conj_helper<Scalar,Scalar,false,false>
00060 {
00061   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const { return internal::pmadd(x,y,c); }
00062   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const { return internal::pmul(x,y); }
00063 };
00064 
00065 template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, false,true>
00066 {
00067   typedef std::complex<RealScalar> Scalar;
00068   EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
00069   { return c + pmul(x,y); }
00070 
00071   EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
00072   { return Scalar(numext::real(x)*numext::real(y) + numext::imag(x)*numext::imag(y), numext::imag(x)*numext::real(y) - numext::real(x)*numext::imag(y)); }
00073 };
00074 
00075 template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, true,false>
00076 {
00077   typedef std::complex<RealScalar> Scalar;
00078   EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
00079   { return c + pmul(x,y); }
00080 
00081   EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
00082   { return Scalar(numext::real(x)*numext::real(y) + numext::imag(x)*numext::imag(y), numext::real(x)*numext::imag(y) - numext::imag(x)*numext::real(y)); }
00083 };
00084 
00085 template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, true,true>
00086 {
00087   typedef std::complex<RealScalar> Scalar;
00088   EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
00089   { return c + pmul(x,y); }
00090 
00091   EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
00092   { return Scalar(numext::real(x)*numext::real(y) - numext::imag(x)*numext::imag(y), - numext::real(x)*numext::imag(y) - numext::imag(x)*numext::real(y)); }
00093 };
00094 
00095 template<typename RealScalar,bool Conj> struct conj_helper<std::complex<RealScalar>, RealScalar, Conj,false>
00096 {
00097   typedef std::complex<RealScalar> Scalar;
00098   EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const RealScalar& y, const Scalar& c) const
00099   { return padd(c, pmul(x,y)); }
00100   EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const RealScalar& y) const
00101   { return conj_if<Conj>()(x)*y; }
00102 };
00103 
00104 template<typename RealScalar,bool Conj> struct conj_helper<RealScalar, std::complex<RealScalar>, false,Conj>
00105 {
00106   typedef std::complex<RealScalar> Scalar;
00107   EIGEN_STRONG_INLINE Scalar pmadd(const RealScalar& x, const Scalar& y, const Scalar& c) const
00108   { return padd(c, pmul(x,y)); }
00109   EIGEN_STRONG_INLINE Scalar pmul(const RealScalar& x, const Scalar& y) const
00110   { return x*conj_if<Conj>()(y); }
00111 };
00112 
00113 template<typename From,typename To> struct get_factor {
00114   EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE To run(const From& x) { return x; }
00115 };
00116 
00117 template<typename Scalar> struct get_factor<Scalar,typename NumTraits<Scalar>::Real> {
00118   EIGEN_DEVICE_FUNC
00119   static EIGEN_STRONG_INLINE typename NumTraits<Scalar>::Real run(const Scalar& x) { return numext::real(x); }
00120 };
00121 
00122 
00123 template<typename Scalar, typename Index>
00124 class BlasVectorMapper {
00125   public:
00126   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasVectorMapper(Scalar *data) : m_data(data) {}
00127 
00128   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
00129     return m_data[i];
00130   }
00131   template <typename Packet, int AlignmentType>
00132   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet load(Index i) const {
00133     return ploadt<Packet, AlignmentType>(m_data + i);
00134   }
00135 
00136   template <typename Packet>
00137   EIGEN_DEVICE_FUNC bool aligned(Index i) const {
00138     return (size_t(m_data+i)%sizeof(Packet))==0;
00139   }
00140 
00141   protected:
00142   Scalar* m_data;
00143 };
00144 
00145 template<typename Scalar, typename Index, int AlignmentType>
00146 class BlasLinearMapper {
00147   public:
00148   typedef typename packet_traits<Scalar>::type Packet;
00149   typedef typename packet_traits<Scalar>::half HalfPacket;
00150 
00151   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasLinearMapper(Scalar *data) : m_data(data) {}
00152 
00153   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void prefetch(int i) const {
00154     internal::prefetch(&operator()(i));
00155   }
00156 
00157   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar& operator()(Index i) const {
00158     return m_data[i];
00159   }
00160 
00161   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
00162     return ploadt<Packet, AlignmentType>(m_data + i);
00163   }
00164 
00165   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i) const {
00166     return ploadt<HalfPacket, AlignmentType>(m_data + i);
00167   }
00168 
00169   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, const Packet &p) const {
00170     pstoret<Scalar, Packet, AlignmentType>(m_data + i, p);
00171   }
00172 
00173   protected:
00174   Scalar *m_data;
00175 };
00176 
00177 // Lightweight helper class to access matrix coefficients.
00178 template<typename Scalar, typename Index, int StorageOrder, int AlignmentType = Unaligned>
00179 class blas_data_mapper {
00180   public:
00181   typedef typename packet_traits<Scalar>::type Packet;
00182   typedef typename packet_traits<Scalar>::half HalfPacket;
00183 
00184   typedef BlasLinearMapper<Scalar, Index, AlignmentType> LinearMapper;
00185   typedef BlasVectorMapper<Scalar, Index> VectorMapper;
00186 
00187   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride) : m_data(data), m_stride(stride) {}
00188 
00189   EIGEN_DEVICE_FUNC  EIGEN_ALWAYS_INLINE blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType>
00190   getSubMapper(Index i, Index j) const {
00191     return blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType>(&operator()(i, j), m_stride);
00192   }
00193 
00194   EIGEN_DEVICE_FUNC  EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
00195     return LinearMapper(&operator()(i, j));
00196   }
00197 
00198   EIGEN_DEVICE_FUNC  EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
00199     return VectorMapper(&operator()(i, j));
00200   }
00201 
00202 
00203   EIGEN_DEVICE_FUNC
00204   EIGEN_ALWAYS_INLINE Scalar& operator()(Index i, Index j) const {
00205     return m_data[StorageOrder==RowMajor ? j + i*m_stride : i + j*m_stride];
00206   }
00207 
00208   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const {
00209     return ploadt<Packet, AlignmentType>(&operator()(i, j));
00210   }
00211 
00212   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i, Index j) const {
00213     return ploadt<HalfPacket, AlignmentType>(&operator()(i, j));
00214   }
00215 
00216   template<typename SubPacket>
00217   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void scatterPacket(Index i, Index j, const SubPacket &p) const {
00218     pscatter<Scalar, SubPacket>(&operator()(i, j), p, m_stride);
00219   }
00220 
00221   template<typename SubPacket>
00222   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SubPacket gatherPacket(Index i, Index j) const {
00223     return pgather<Scalar, SubPacket>(&operator()(i, j), m_stride);
00224   }
00225 
00226   EIGEN_DEVICE_FUNC const Index stride() const { return m_stride; }
00227   EIGEN_DEVICE_FUNC const Scalar* data() const { return m_data; }
00228 
00229   EIGEN_DEVICE_FUNC Index firstAligned(Index size) const {
00230     if (size_t(m_data)%sizeof(Scalar)) {
00231       return -1;
00232     }
00233     return internal::first_default_aligned(m_data, size);
00234   }
00235 
00236   protected:
00237   Scalar* EIGEN_RESTRICT m_data;
00238   const Index m_stride;
00239 };
00240 
00241 // lightweight helper class to access matrix coefficients (const version)
00242 template<typename Scalar, typename Index, int StorageOrder>
00243 class const_blas_data_mapper : public blas_data_mapper<const Scalar, Index, StorageOrder> {
00244   public:
00245   EIGEN_ALWAYS_INLINE const_blas_data_mapper(const Scalar *data, Index stride) : blas_data_mapper<const Scalar, Index, StorageOrder>(data, stride) {}
00246 
00247   EIGEN_ALWAYS_INLINE const_blas_data_mapper<Scalar, Index, StorageOrder> getSubMapper(Index i, Index j) const {
00248     return const_blas_data_mapper<Scalar, Index, StorageOrder>(&(this->operator()(i, j)), this->m_stride);
00249   }
00250 };
00251 
00252 
00253 /* Helper class to analyze the factors of a Product expression.
00254  * In particular it allows to pop out operator-, scalar multiples,
00255  * and conjugate */
00256 template<typename XprType> struct blas_traits
00257 {
00258   typedef typename traits<XprType>::Scalar Scalar;
00259   typedef const XprType& ExtractType;
00260   typedef XprType _ExtractType;
00261   enum {
00262     IsComplex = NumTraits<Scalar>::IsComplex,
00263     IsTransposed = false,
00264     NeedToConjugate = false,
00265     HasUsableDirectAccess = (    (int(XprType::Flags)&DirectAccessBit)
00266                               && (   bool(XprType::IsVectorAtCompileTime)
00267                                   || int(inner_stride_at_compile_time<XprType>::ret) == 1)
00268                              ) ?  1 : 0
00269   };
00270   typedef typename conditional<bool(HasUsableDirectAccess),
00271     ExtractType,
00272     typename _ExtractType::PlainObject
00273     >::type DirectLinearAccessType;
00274   static inline ExtractType extract(const XprType& x) { return x; }
00275   static inline const Scalar extractScalarFactor(const XprType&) { return Scalar(1); }
00276 };
00277 
00278 // pop conjugate
00279 template<typename Scalar, typename NestedXpr>
00280 struct blas_traits<CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> >
00281  : blas_traits<NestedXpr>
00282 {
00283   typedef blas_traits<NestedXpr> Base;
00284   typedef CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> XprType;
00285   typedef typename Base::ExtractType ExtractType;
00286 
00287   enum {
00288     IsComplex = NumTraits<Scalar>::IsComplex,
00289     NeedToConjugate = Base::NeedToConjugate ? 0 : IsComplex
00290   };
00291   static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
00292   static inline Scalar extractScalarFactor(const XprType& x) { return conj(Base::extractScalarFactor(x.nestedExpression())); }
00293 };
00294 
00295 // pop scalar multiple
00296 template<typename Scalar, typename NestedXpr>
00297 struct blas_traits<CwiseUnaryOp<scalar_multiple_op<Scalar>, NestedXpr> >
00298  : blas_traits<NestedXpr>
00299 {
00300   typedef blas_traits<NestedXpr> Base;
00301   typedef CwiseUnaryOp<scalar_multiple_op<Scalar>, NestedXpr> XprType;
00302   typedef typename Base::ExtractType ExtractType;
00303   static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
00304   static inline Scalar extractScalarFactor(const XprType& x)
00305   { return x.functor().m_other * Base::extractScalarFactor(x.nestedExpression()); }
00306 };
00307 
00308 // pop opposite
00309 template<typename Scalar, typename NestedXpr>
00310 struct blas_traits<CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> >
00311  : blas_traits<NestedXpr>
00312 {
00313   typedef blas_traits<NestedXpr> Base;
00314   typedef CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> XprType;
00315   typedef typename Base::ExtractType ExtractType;
00316   static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
00317   static inline Scalar extractScalarFactor(const XprType& x)
00318   { return - Base::extractScalarFactor(x.nestedExpression()); }
00319 };
00320 
00321 // pop/push transpose
00322 template<typename NestedXpr>
00323 struct blas_traits<Transpose<NestedXpr> >
00324  : blas_traits<NestedXpr>
00325 {
00326   typedef typename NestedXpr::Scalar Scalar;
00327   typedef blas_traits<NestedXpr> Base;
00328   typedef Transpose<NestedXpr> XprType;
00329   typedef Transpose<const typename Base::_ExtractType>  ExtractType; // const to get rid of a compile error; anyway blas traits are only used on the RHS
00330   typedef Transpose<const typename Base::_ExtractType> _ExtractType;
00331   typedef typename conditional<bool(Base::HasUsableDirectAccess),
00332     ExtractType,
00333     typename ExtractType::PlainObject
00334     >::type DirectLinearAccessType;
00335   enum {
00336     IsTransposed = Base::IsTransposed ? 0 : 1
00337   };
00338   static inline ExtractType extract(const XprType& x) { return ExtractType(Base::extract(x.nestedExpression())); }
00339   static inline Scalar extractScalarFactor(const XprType& x) { return Base::extractScalarFactor(x.nestedExpression()); }
00340 };
00341 
00342 template<typename T>
00343 struct blas_traits<const T>
00344      : blas_traits<T>
00345 {};
00346 
00347 template<typename T, bool HasUsableDirectAccess=blas_traits<T>::HasUsableDirectAccess>
00348 struct extract_data_selector {
00349   static const typename T::Scalar* run(const T& m)
00350   {
00351     return blas_traits<T>::extract(m).data();
00352   }
00353 };
00354 
00355 template<typename T>
00356 struct extract_data_selector<T,false> {
00357   static typename T::Scalar* run(const T&) { return 0; }
00358 };
00359 
00360 template<typename T> const typename T::Scalar* extract_data(const T& m)
00361 {
00362   return extract_data_selector<T>::run(m);
00363 }
00364 
00365 } // end namespace internal
00366 
00367 } // end namespace Eigen
00368 
00369 #endif // EIGEN_BLASUTIL_H
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines