// Eigen 3.2.8
00001 // This file is part of Eigen, a lightweight C++ template library 00002 // for linear algebra. 00003 // 00004 // Copyright (C) 2008 Gael Guennebaud <gael.guennebaud@inria.fr> 00005 // 00006 // This Source Code Form is subject to the terms of the Mozilla 00007 // Public License v. 2.0. If a copy of the MPL was not distributed 00008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 00009 00010 #ifndef EIGEN_SPARSE_CWISE_BINARY_OP_H 00011 #define EIGEN_SPARSE_CWISE_BINARY_OP_H 00012 00013 namespace Eigen { 00014 00015 // Here we have to handle 3 cases: 00016 // 1 - sparse op dense 00017 // 2 - dense op sparse 00018 // 3 - sparse op sparse 00019 // We also need to implement a 4th iterator for: 00020 // 4 - dense op dense 00021 // Finally, we also need to distinguish between the product and other operations : 00022 // configuration returned mode 00023 // 1 - sparse op dense product sparse 00024 // generic dense 00025 // 2 - dense op sparse product sparse 00026 // generic dense 00027 // 3 - sparse op sparse product sparse 00028 // generic sparse 00029 // 4 - dense op dense product dense 00030 // generic dense 00031 00032 namespace internal { 00033 00034 template<> struct promote_storage_type<Dense,Sparse> 00035 { typedef Sparse ret; }; 00036 00037 template<> struct promote_storage_type<Sparse,Dense> 00038 { typedef Sparse ret; }; 00039 00040 template<typename BinaryOp, typename Lhs, typename Rhs, typename Derived, 00041 typename _LhsStorageMode = typename traits<Lhs>::StorageKind, 00042 typename _RhsStorageMode = typename traits<Rhs>::StorageKind> 00043 class sparse_cwise_binary_op_inner_iterator_selector; 00044 00045 } // end namespace internal 00046 00047 template<typename BinaryOp, typename Lhs, typename Rhs> 00048 class CwiseBinaryOpImpl<BinaryOp, Lhs, Rhs, Sparse> 00049 : public SparseMatrixBase<CwiseBinaryOp<BinaryOp, Lhs, Rhs> > 00050 { 00051 public: 00052 class InnerIterator; 00053 class ReverseInnerIterator; 00054 typedef 
CwiseBinaryOp<BinaryOp, Lhs, Rhs> Derived; 00055 EIGEN_SPARSE_PUBLIC_INTERFACE(Derived) 00056 CwiseBinaryOpImpl() 00057 { 00058 EIGEN_STATIC_ASSERT(( 00059 (!internal::is_same<typename internal::traits<Lhs>::StorageKind, 00060 typename internal::traits<Rhs>::StorageKind>::value) 00061 || ((Lhs::Flags&RowMajorBit) == (Rhs::Flags&RowMajorBit))), 00062 THE_STORAGE_ORDER_OF_BOTH_SIDES_MUST_MATCH); 00063 } 00064 }; 00065 00066 template<typename BinaryOp, typename Lhs, typename Rhs> 00067 class CwiseBinaryOpImpl<BinaryOp,Lhs,Rhs,Sparse>::InnerIterator 00068 : public internal::sparse_cwise_binary_op_inner_iterator_selector<BinaryOp,Lhs,Rhs,typename CwiseBinaryOpImpl<BinaryOp,Lhs,Rhs,Sparse>::InnerIterator> 00069 { 00070 public: 00071 typedef typename Lhs::Index Index; 00072 typedef internal::sparse_cwise_binary_op_inner_iterator_selector< 00073 BinaryOp,Lhs,Rhs, InnerIterator> Base; 00074 00075 // NOTE: we have to prefix Index by "typename Lhs::" to avoid an ICE with VC11 00076 EIGEN_STRONG_INLINE InnerIterator(const CwiseBinaryOpImpl& binOp, typename Lhs::Index outer) 00077 : Base(binOp.derived(),outer) 00078 {} 00079 }; 00080 00081 /*************************************************************************** 00082 * Implementation of inner-iterators 00083 ***************************************************************************/ 00084 00085 // template<typename T> struct internal::func_is_conjunction { enum { ret = false }; }; 00086 // template<typename T> struct internal::func_is_conjunction<internal::scalar_product_op<T> > { enum { ret = true }; }; 00087 00088 // TODO generalize the internal::scalar_product_op specialization to all conjunctions if any ! 
00089 00090 namespace internal { 00091 00092 // sparse - sparse (generic) 00093 template<typename BinaryOp, typename Lhs, typename Rhs, typename Derived> 00094 class sparse_cwise_binary_op_inner_iterator_selector<BinaryOp, Lhs, Rhs, Derived, Sparse, Sparse> 00095 { 00096 typedef CwiseBinaryOp<BinaryOp, Lhs, Rhs> CwiseBinaryXpr; 00097 typedef typename traits<CwiseBinaryXpr>::Scalar Scalar; 00098 typedef typename traits<CwiseBinaryXpr>::_LhsNested _LhsNested; 00099 typedef typename traits<CwiseBinaryXpr>::_RhsNested _RhsNested; 00100 typedef typename _LhsNested::InnerIterator LhsIterator; 00101 typedef typename _RhsNested::InnerIterator RhsIterator; 00102 typedef typename Lhs::Index Index; 00103 00104 public: 00105 00106 EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer) 00107 : m_lhsIter(xpr.lhs(),outer), m_rhsIter(xpr.rhs(),outer), m_functor(xpr.functor()) 00108 { 00109 this->operator++(); 00110 } 00111 00112 EIGEN_STRONG_INLINE Derived& operator++() 00113 { 00114 if (m_lhsIter && m_rhsIter && (m_lhsIter.index() == m_rhsIter.index())) 00115 { 00116 m_id = m_lhsIter.index(); 00117 m_value = m_functor(m_lhsIter.value(), m_rhsIter.value()); 00118 ++m_lhsIter; 00119 ++m_rhsIter; 00120 } 00121 else if (m_lhsIter && (!m_rhsIter || (m_lhsIter.index() < m_rhsIter.index()))) 00122 { 00123 m_id = m_lhsIter.index(); 00124 m_value = m_functor(m_lhsIter.value(), Scalar(0)); 00125 ++m_lhsIter; 00126 } 00127 else if (m_rhsIter && (!m_lhsIter || (m_lhsIter.index() > m_rhsIter.index()))) 00128 { 00129 m_id = m_rhsIter.index(); 00130 m_value = m_functor(Scalar(0), m_rhsIter.value()); 00131 ++m_rhsIter; 00132 } 00133 else 00134 { 00135 m_value = 0; // this is to avoid a compilation warning 00136 m_id = -1; 00137 } 00138 return *static_cast<Derived*>(this); 00139 } 00140 00141 EIGEN_STRONG_INLINE Scalar value() const { return m_value; } 00142 00143 EIGEN_STRONG_INLINE Index index() const { return m_id; } 00144 EIGEN_STRONG_INLINE 
Index row() const { return Lhs::IsRowMajor ? m_lhsIter.row() : index(); } 00145 EIGEN_STRONG_INLINE Index col() const { return Lhs::IsRowMajor ? index() : m_lhsIter.col(); } 00146 00147 EIGEN_STRONG_INLINE operator bool() const { return m_id>=0; } 00148 00149 protected: 00150 LhsIterator m_lhsIter; 00151 RhsIterator m_rhsIter; 00152 const BinaryOp& m_functor; 00153 Scalar m_value; 00154 Index m_id; 00155 }; 00156 00157 // sparse - sparse (product) 00158 template<typename T, typename Lhs, typename Rhs, typename Derived> 00159 class sparse_cwise_binary_op_inner_iterator_selector<scalar_product_op<T>, Lhs, Rhs, Derived, Sparse, Sparse> 00160 { 00161 typedef scalar_product_op<T> BinaryFunc; 00162 typedef CwiseBinaryOp<BinaryFunc, Lhs, Rhs> CwiseBinaryXpr; 00163 typedef typename CwiseBinaryXpr::Scalar Scalar; 00164 typedef typename traits<CwiseBinaryXpr>::_LhsNested _LhsNested; 00165 typedef typename _LhsNested::InnerIterator LhsIterator; 00166 typedef typename traits<CwiseBinaryXpr>::_RhsNested _RhsNested; 00167 typedef typename _RhsNested::InnerIterator RhsIterator; 00168 typedef typename Lhs::Index Index; 00169 public: 00170 00171 EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer) 00172 : m_lhsIter(xpr.lhs(),outer), m_rhsIter(xpr.rhs(),outer), m_functor(xpr.functor()) 00173 { 00174 while (m_lhsIter && m_rhsIter && (m_lhsIter.index() != m_rhsIter.index())) 00175 { 00176 if (m_lhsIter.index() < m_rhsIter.index()) 00177 ++m_lhsIter; 00178 else 00179 ++m_rhsIter; 00180 } 00181 } 00182 00183 EIGEN_STRONG_INLINE Derived& operator++() 00184 { 00185 ++m_lhsIter; 00186 ++m_rhsIter; 00187 while (m_lhsIter && m_rhsIter && (m_lhsIter.index() != m_rhsIter.index())) 00188 { 00189 if (m_lhsIter.index() < m_rhsIter.index()) 00190 ++m_lhsIter; 00191 else 00192 ++m_rhsIter; 00193 } 00194 return *static_cast<Derived*>(this); 00195 } 00196 00197 EIGEN_STRONG_INLINE Scalar value() const { return m_functor(m_lhsIter.value(), 
m_rhsIter.value()); } 00198 00199 EIGEN_STRONG_INLINE Index index() const { return m_lhsIter.index(); } 00200 EIGEN_STRONG_INLINE Index row() const { return m_lhsIter.row(); } 00201 EIGEN_STRONG_INLINE Index col() const { return m_lhsIter.col(); } 00202 00203 EIGEN_STRONG_INLINE operator bool() const { return (m_lhsIter && m_rhsIter); } 00204 00205 protected: 00206 LhsIterator m_lhsIter; 00207 RhsIterator m_rhsIter; 00208 const BinaryFunc& m_functor; 00209 }; 00210 00211 // sparse - dense (product) 00212 template<typename T, typename Lhs, typename Rhs, typename Derived> 00213 class sparse_cwise_binary_op_inner_iterator_selector<scalar_product_op<T>, Lhs, Rhs, Derived, Sparse, Dense> 00214 { 00215 typedef scalar_product_op<T> BinaryFunc; 00216 typedef CwiseBinaryOp<BinaryFunc, Lhs, Rhs> CwiseBinaryXpr; 00217 typedef typename CwiseBinaryXpr::Scalar Scalar; 00218 typedef typename traits<CwiseBinaryXpr>::_LhsNested _LhsNested; 00219 typedef typename traits<CwiseBinaryXpr>::RhsNested RhsNested; 00220 typedef typename _LhsNested::InnerIterator LhsIterator; 00221 typedef typename Lhs::Index Index; 00222 enum { IsRowMajor = (int(Lhs::Flags)&RowMajorBit)==RowMajorBit }; 00223 public: 00224 00225 EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer) 00226 : m_rhs(xpr.rhs()), m_lhsIter(xpr.lhs(),outer), m_functor(xpr.functor()), m_outer(outer) 00227 {} 00228 00229 EIGEN_STRONG_INLINE Derived& operator++() 00230 { 00231 ++m_lhsIter; 00232 return *static_cast<Derived*>(this); 00233 } 00234 00235 EIGEN_STRONG_INLINE Scalar value() const 00236 { return m_functor(m_lhsIter.value(), 00237 m_rhs.coeff(IsRowMajor?m_outer:m_lhsIter.index(),IsRowMajor?m_lhsIter.index():m_outer)); } 00238 00239 EIGEN_STRONG_INLINE Index index() const { return m_lhsIter.index(); } 00240 EIGEN_STRONG_INLINE Index row() const { return m_lhsIter.row(); } 00241 EIGEN_STRONG_INLINE Index col() const { return m_lhsIter.col(); } 00242 00243 
EIGEN_STRONG_INLINE operator bool() const { return m_lhsIter; } 00244 00245 protected: 00246 RhsNested m_rhs; 00247 LhsIterator m_lhsIter; 00248 const BinaryFunc m_functor; 00249 const Index m_outer; 00250 }; 00251 00252 // sparse - dense (product) 00253 template<typename T, typename Lhs, typename Rhs, typename Derived> 00254 class sparse_cwise_binary_op_inner_iterator_selector<scalar_product_op<T>, Lhs, Rhs, Derived, Dense, Sparse> 00255 { 00256 typedef scalar_product_op<T> BinaryFunc; 00257 typedef CwiseBinaryOp<BinaryFunc, Lhs, Rhs> CwiseBinaryXpr; 00258 typedef typename CwiseBinaryXpr::Scalar Scalar; 00259 typedef typename traits<CwiseBinaryXpr>::_RhsNested _RhsNested; 00260 typedef typename _RhsNested::InnerIterator RhsIterator; 00261 typedef typename Lhs::Index Index; 00262 00263 enum { IsRowMajor = (int(Rhs::Flags)&RowMajorBit)==RowMajorBit }; 00264 public: 00265 00266 EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer) 00267 : m_xpr(xpr), m_rhsIter(xpr.rhs(),outer), m_functor(xpr.functor()), m_outer(outer) 00268 {} 00269 00270 EIGEN_STRONG_INLINE Derived& operator++() 00271 { 00272 ++m_rhsIter; 00273 return *static_cast<Derived*>(this); 00274 } 00275 00276 EIGEN_STRONG_INLINE Scalar value() const 00277 { return m_functor(m_xpr.lhs().coeff(IsRowMajor?m_outer:m_rhsIter.index(),IsRowMajor?m_rhsIter.index():m_outer), m_rhsIter.value()); } 00278 00279 EIGEN_STRONG_INLINE Index index() const { return m_rhsIter.index(); } 00280 EIGEN_STRONG_INLINE Index row() const { return m_rhsIter.row(); } 00281 EIGEN_STRONG_INLINE Index col() const { return m_rhsIter.col(); } 00282 00283 EIGEN_STRONG_INLINE operator bool() const { return m_rhsIter; } 00284 00285 protected: 00286 const CwiseBinaryXpr& m_xpr; 00287 RhsIterator m_rhsIter; 00288 const BinaryFunc& m_functor; 00289 const Index m_outer; 00290 }; 00291 00292 } // end namespace internal 00293 00294 
/*************************************************************************** 00295 * Implementation of SparseMatrixBase and SparseCwise functions/operators 00296 ***************************************************************************/ 00297 00298 template<typename Derived> 00299 template<typename OtherDerived> 00300 EIGEN_STRONG_INLINE Derived & 00301 SparseMatrixBase<Derived>::operator-=(const SparseMatrixBase<OtherDerived> &other) 00302 { 00303 return derived() = derived() - other.derived(); 00304 } 00305 00306 template<typename Derived> 00307 template<typename OtherDerived> 00308 EIGEN_STRONG_INLINE Derived & 00309 SparseMatrixBase<Derived>::operator+=(const SparseMatrixBase<OtherDerived>& other) 00310 { 00311 return derived() = derived() + other.derived(); 00312 } 00313 00314 template<typename Derived> 00315 template<typename OtherDerived> 00316 EIGEN_STRONG_INLINE const typename SparseMatrixBase<Derived>::template CwiseProductDenseReturnType<OtherDerived>::Type 00317 SparseMatrixBase<Derived>::cwiseProduct(const MatrixBase<OtherDerived> &other) const 00318 { 00319 return typename CwiseProductDenseReturnType<OtherDerived>::Type(derived(), other.derived()); 00320 } 00321 00322 } // end namespace Eigen 00323 00324 #endif // EIGEN_SPARSE_CWISE_BINARY_OP_H