Eigen  3.2.8
ConjugateGradient.h
00001 // This file is part of Eigen, a lightweight C++ template library
00002 // for linear algebra.
00003 //
00004 // Copyright (C) 2011 Gael Guennebaud <gael.guennebaud@inria.fr>
00005 //
00006 // This Source Code Form is subject to the terms of the Mozilla
00007 // Public License v. 2.0. If a copy of the MPL was not distributed
00008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
00009 
00010 #ifndef EIGEN_CONJUGATE_GRADIENT_H
00011 #define EIGEN_CONJUGATE_GRADIENT_H
00012 
00013 namespace Eigen { 
00014 
00015 namespace internal {
00016 
00026 template<typename MatrixType, typename Rhs, typename Dest, typename Preconditioner>
00027 EIGEN_DONT_INLINE
00028 void conjugate_gradient(const MatrixType& mat, const Rhs& rhs, Dest& x,
00029                         const Preconditioner& precond, int& iters,
00030                         typename Dest::RealScalar& tol_error)
00031 {
00032   using std::sqrt;
00033   using std::abs;
00034   typedef typename Dest::RealScalar RealScalar;
00035   typedef typename Dest::Scalar Scalar;
00036   typedef Matrix<Scalar,Dynamic,1> VectorType;
00037   
00038   RealScalar tol = tol_error;
00039   int maxIters = iters;
00040   
00041   int n = mat.cols();
00042 
00043   VectorType residual = rhs - mat * x; //initial residual
00044 
00045   RealScalar rhsNorm2 = rhs.squaredNorm();
00046   if(rhsNorm2 == 0) 
00047   {
00048     x.setZero();
00049     iters = 0;
00050     tol_error = 0;
00051     return;
00052   }
00053   RealScalar threshold = tol*tol*rhsNorm2;
00054   RealScalar residualNorm2 = residual.squaredNorm();
00055   if (residualNorm2 < threshold)
00056   {
00057     iters = 0;
00058     tol_error = sqrt(residualNorm2 / rhsNorm2);
00059     return;
00060   }
00061   
00062   VectorType p(n);
00063   p = precond.solve(residual);      //initial search direction
00064 
00065   VectorType z(n), tmp(n);
00066   RealScalar absNew = numext::real(residual.dot(p));  // the square of the absolute value of r scaled by invM
00067   int i = 0;
00068   while(i < maxIters)
00069   {
00070     tmp.noalias() = mat * p;              // the bottleneck of the algorithm
00071 
00072     Scalar alpha = absNew / p.dot(tmp);   // the amount we travel on dir
00073     x += alpha * p;                       // update solution
00074     residual -= alpha * tmp;              // update residue
00075     
00076     residualNorm2 = residual.squaredNorm();
00077     if(residualNorm2 < threshold)
00078       break;
00079     
00080     z = precond.solve(residual);          // approximately solve for "A z = residual"
00081 
00082     RealScalar absOld = absNew;
00083     absNew = numext::real(residual.dot(z));     // update the absolute value of r
00084     RealScalar beta = absNew / absOld;            // calculate the Gram-Schmidt value used to create the new search direction
00085     p = z + beta * p;                             // update search direction
00086     i++;
00087   }
00088   tol_error = sqrt(residualNorm2 / rhsNorm2);
00089   iters = i;
00090 }
00091 
00092 }
00093 
00094 template< typename _MatrixType, int _UpLo=Lower,
00095           typename _Preconditioner = DiagonalPreconditioner<typename _MatrixType::Scalar> >
00096 class ConjugateGradient;
00097 
00098 namespace internal {
00099 
00100 template< typename _MatrixType, int _UpLo, typename _Preconditioner>
00101 struct traits<ConjugateGradient<_MatrixType,_UpLo,_Preconditioner> >
00102 {
00103   typedef _MatrixType MatrixType;
00104   typedef _Preconditioner Preconditioner;
00105 };
00106 
00107 }
00108 
00146 template< typename _MatrixType, int _UpLo, typename _Preconditioner>
00147 class ConjugateGradient : public IterativeSolverBase<ConjugateGradient<_MatrixType,_UpLo,_Preconditioner> >
00148 {
00149   typedef IterativeSolverBase<ConjugateGradient> Base;
00150   using Base::mp_matrix;
00151   using Base::m_error;
00152   using Base::m_iterations;
00153   using Base::m_info;
00154   using Base::m_isInitialized;
00155 public:
00156   typedef _MatrixType MatrixType;
00157   typedef typename MatrixType::Scalar Scalar;
00158   typedef typename MatrixType::Index Index;
00159   typedef typename MatrixType::RealScalar RealScalar;
00160   typedef _Preconditioner Preconditioner;
00161 
00162   enum {
00163     UpLo = _UpLo
00164   };
00165 
00166 public:
00167 
00169   ConjugateGradient() : Base() {}
00170 
00181   template<typename MatrixDerived>
00182   explicit ConjugateGradient(const EigenBase<MatrixDerived>& A) : Base(A.derived()) {}
00183 
00184   ~ConjugateGradient() {}
00185   
00191   template<typename Rhs,typename Guess>
00192   inline const internal::solve_retval_with_guess<ConjugateGradient, Rhs, Guess>
00193   solveWithGuess(const MatrixBase<Rhs>& b, const Guess& x0) const
00194   {
00195     eigen_assert(m_isInitialized && "ConjugateGradient is not initialized.");
00196     eigen_assert(Base::rows()==b.rows()
00197               && "ConjugateGradient::solve(): invalid number of rows of the right hand side matrix b");
00198     return internal::solve_retval_with_guess
00199             <ConjugateGradient, Rhs, Guess>(*this, b.derived(), x0);
00200   }
00201 
00203   template<typename Rhs,typename Dest>
00204   void _solveWithGuess(const Rhs& b, Dest& x) const
00205   {
00206     typedef typename internal::conditional<UpLo==(Lower|Upper),
00207                                            const MatrixType&,
00208                                            SparseSelfAdjointView<const MatrixType, UpLo>
00209                                           >::type MatrixWrapperType;
00210     m_iterations = Base::maxIterations();
00211     m_error = Base::m_tolerance;
00212 
00213     for(int j=0; j<b.cols(); ++j)
00214     {
00215       m_iterations = Base::maxIterations();
00216       m_error = Base::m_tolerance;
00217 
00218       typename Dest::ColXpr xj(x,j);
00219       internal::conjugate_gradient(MatrixWrapperType(*mp_matrix), b.col(j), xj, Base::m_preconditioner, m_iterations, m_error);
00220     }
00221 
00222     m_isInitialized = true;
00223     m_info = m_error <= Base::m_tolerance ? Success : NoConvergence;
00224   }
00225   
00227   template<typename Rhs,typename Dest>
00228   void _solve(const Rhs& b, Dest& x) const
00229   {
00230     x.setZero();
00231     _solveWithGuess(b,x);
00232   }
00233 
00234 protected:
00235 
00236 };
00237 
00238 
00239 namespace internal {
00240 
00241 template<typename _MatrixType, int _UpLo, typename _Preconditioner, typename Rhs>
00242 struct solve_retval<ConjugateGradient<_MatrixType,_UpLo,_Preconditioner>, Rhs>
00243   : solve_retval_base<ConjugateGradient<_MatrixType,_UpLo,_Preconditioner>, Rhs>
00244 {
00245   typedef ConjugateGradient<_MatrixType,_UpLo,_Preconditioner> Dec;
00246   EIGEN_MAKE_SOLVE_HELPERS(Dec,Rhs)
00247 
00248   template<typename Dest> void evalTo(Dest& dst) const
00249   {
00250     dec()._solve(rhs(),dst);
00251   }
00252 };
00253 
00254 } // end namespace internal
00255 
00256 } // end namespace Eigen
00257 
00258 #endif // EIGEN_CONJUGATE_GRADIENT_H
 All Classes Functions Variables Typedefs Enumerations Enumerator Friends