![]() |
Eigen
3.2.8
|
00001 // This file is part of Eigen, a lightweight C++ template library 00002 // for linear algebra. 00003 // 00004 // Copyright (C) 2011 Gael Guennebaud <gael.guennebaud@inria.fr> 00005 // 00006 // This Source Code Form is subject to the terms of the Mozilla 00007 // Public License v. 2.0. If a copy of the MPL was not distributed 00008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 00009 00010 #ifndef EIGEN_CONJUGATE_GRADIENT_H 00011 #define EIGEN_CONJUGATE_GRADIENT_H 00012 00013 namespace Eigen { 00014 00015 namespace internal { 00016 00026 template<typename MatrixType, typename Rhs, typename Dest, typename Preconditioner> 00027 EIGEN_DONT_INLINE 00028 void conjugate_gradient(const MatrixType& mat, const Rhs& rhs, Dest& x, 00029 const Preconditioner& precond, int& iters, 00030 typename Dest::RealScalar& tol_error) 00031 { 00032 using std::sqrt; 00033 using std::abs; 00034 typedef typename Dest::RealScalar RealScalar; 00035 typedef typename Dest::Scalar Scalar; 00036 typedef Matrix<Scalar,Dynamic,1> VectorType; 00037 00038 RealScalar tol = tol_error; 00039 int maxIters = iters; 00040 00041 int n = mat.cols(); 00042 00043 VectorType residual = rhs - mat * x; //initial residual 00044 00045 RealScalar rhsNorm2 = rhs.squaredNorm(); 00046 if(rhsNorm2 == 0) 00047 { 00048 x.setZero(); 00049 iters = 0; 00050 tol_error = 0; 00051 return; 00052 } 00053 RealScalar threshold = tol*tol*rhsNorm2; 00054 RealScalar residualNorm2 = residual.squaredNorm(); 00055 if (residualNorm2 < threshold) 00056 { 00057 iters = 0; 00058 tol_error = sqrt(residualNorm2 / rhsNorm2); 00059 return; 00060 } 00061 00062 VectorType p(n); 00063 p = precond.solve(residual); //initial search direction 00064 00065 VectorType z(n), tmp(n); 00066 RealScalar absNew = numext::real(residual.dot(p)); // the square of the absolute value of r scaled by invM 00067 int i = 0; 00068 while(i < maxIters) 00069 { 00070 tmp.noalias() = mat * p; // the bottleneck of the algorithm 00071 00072 Scalar alpha = absNew / p.dot(tmp); // the amount we travel on dir 00073 x += alpha * p; // update solution 00074 residual -= alpha * tmp; // update residue 00075 00076 residualNorm2 = residual.squaredNorm(); 00077 if(residualNorm2 < threshold) 00078 break; 00079 00080 z = precond.solve(residual); // approximately solve for "A z = residual" 00081 00082 RealScalar absOld = absNew; 00083 absNew = numext::real(residual.dot(z)); // update the absolute value of r 00084 RealScalar beta = absNew / absOld; // calculate the Gram-Schmidt value used to create the new search direction 00085 p = z + beta * p; // update search direction 00086 i++; 00087 } 00088 tol_error = sqrt(residualNorm2 / rhsNorm2); 00089 iters = i; 00090 } 00091 00092 } 00093 00094 template< typename _MatrixType, int _UpLo=Lower, 00095 typename _Preconditioner = DiagonalPreconditioner<typename _MatrixType::Scalar> > 00096 class ConjugateGradient; 00097 00098 namespace internal { 00099 00100 template< typename _MatrixType, int _UpLo, typename _Preconditioner> 00101 struct traits<ConjugateGradient<_MatrixType,_UpLo,_Preconditioner> > 00102 { 00103 typedef _MatrixType MatrixType; 00104 typedef _Preconditioner Preconditioner; 00105 }; 00106 00107 } 00108 00146 template< typename _MatrixType, int _UpLo, typename _Preconditioner> 00147 class ConjugateGradient : public IterativeSolverBase<ConjugateGradient<_MatrixType,_UpLo,_Preconditioner> > 00148 { 00149 typedef IterativeSolverBase<ConjugateGradient> Base; 00150 using Base::mp_matrix; 00151 using Base::m_error; 00152 using Base::m_iterations; 00153 using Base::m_info; 00154 using Base::m_isInitialized; 00155 public: 00156 typedef _MatrixType MatrixType; 00157 typedef typename MatrixType::Scalar Scalar; 00158 typedef typename MatrixType::Index Index; 00159 typedef typename MatrixType::RealScalar RealScalar; 00160 typedef _Preconditioner Preconditioner; 00161 00162 enum { 00163 UpLo = _UpLo 00164 }; 00165 00166 public: 00167 00169 ConjugateGradient() : Base() {} 00170 00181 template<typename MatrixDerived> 00182 explicit ConjugateGradient(const EigenBase<MatrixDerived>& A) : Base(A.derived()) {} 00183 00184 ~ConjugateGradient() {} 00185 00191 template<typename Rhs,typename Guess> 00192 inline const internal::solve_retval_with_guess<ConjugateGradient, Rhs, Guess> 00193 solveWithGuess(const MatrixBase<Rhs>& b, const Guess& x0) const 00194 { 00195 eigen_assert(m_isInitialized && "ConjugateGradient is not initialized."); 00196 eigen_assert(Base::rows()==b.rows() 00197 && "ConjugateGradient::solve(): invalid number of rows of the right hand side matrix b"); 00198 return internal::solve_retval_with_guess 00199 <ConjugateGradient, Rhs, Guess>(*this, b.derived(), x0); 00200 } 00201 00203 template<typename Rhs,typename Dest> 00204 void _solveWithGuess(const Rhs& b, Dest& x) const 00205 { 00206 typedef typename internal::conditional<UpLo==(Lower|Upper), 00207 const MatrixType&, 00208 SparseSelfAdjointView<const MatrixType, UpLo> 00209 >::type MatrixWrapperType; 00210 m_iterations = Base::maxIterations(); 00211 m_error = Base::m_tolerance; 00212 00213 for(int j=0; j<b.cols(); ++j) 00214 { 00215 m_iterations = Base::maxIterations(); 00216 m_error = Base::m_tolerance; 00217 00218 typename Dest::ColXpr xj(x,j); 00219 internal::conjugate_gradient(MatrixWrapperType(*mp_matrix), b.col(j), xj, Base::m_preconditioner, m_iterations, m_error); 00220 } 00221 00222 m_isInitialized = true; 00223 m_info = m_error <= Base::m_tolerance ? Success : NoConvergence; 00224 } 00225 00227 template<typename Rhs,typename Dest> 00228 void _solve(const Rhs& b, Dest& x) const 00229 { 00230 x.setZero(); 00231 _solveWithGuess(b,x); 00232 } 00233 00234 protected: 00235 00236 }; 00237 00238 00239 namespace internal { 00240 00241 template<typename _MatrixType, int _UpLo, typename _Preconditioner, typename Rhs> 00242 struct solve_retval<ConjugateGradient<_MatrixType,_UpLo,_Preconditioner>, Rhs> 00243 : solve_retval_base<ConjugateGradient<_MatrixType,_UpLo,_Preconditioner>, Rhs> 00244 { 00245 typedef ConjugateGradient<_MatrixType,_UpLo,_Preconditioner> Dec; 00246 EIGEN_MAKE_SOLVE_HELPERS(Dec,Rhs) 00247 00248 template<typename Dest> void evalTo(Dest& dst) const 00249 { 00250 dec()._solve(rhs(),dst); 00251 } 00252 }; 00253 00254 } // end namespace internal 00255 00256 } // end namespace Eigen 00257 00258 #endif // EIGEN_CONJUGATE_GRADIENT_H