DMRITool  v0.1.1-139-g860d86b4
Diffusion MRI Tool
itkL2RegularizedLeastSquaresSolver.hxx
Go to the documentation of this file.
1 
19 #ifndef __itkL2RegularizedLeastSquaresSolver_hxx
20 #define __itkL2RegularizedLeastSquaresSolver_hxx
21 
23 #include "utl.h"
24 #include "utlVNLBlas.h"
25 #include "utlVNLLapack.h"
26 
27 
28 namespace itk
29 {
30 
// Constructor: allocates empty containers for A, b, Lambda and the cached
// solution operator m_LS, and marks the condition number as "not computed".
// NOTE(review): listing lines 32-33 (presumably the class-qualified constructor
// name and the ':' that opens the initializer list) are absent from this extract.
31 template <class TPrecision>
34  m_A(new MatrixType()),
35  m_b(new VectorType()),
36  m_Lambda(new MatrixType()),
37  m_LS(new MatrixType())
38 {
// -1 is the "not yet computed" sentinel; Initialize() overwrites it and
// SetA()/SetLambda() reset it back to -1.
39  m_ConditionNumber = -1;
40  // empty matrix is symmetric
41  m_IsLambdaSymmetric = true;
42 }
43 
// SetA: stores a value copy of the design matrix A.
// Only when the new matrix differs from the current one does it:
//  - mark the object modified,
//  - drop the cached operator m_LS,
//  - reset the condition number,
// so that the next Initialize() rebuilds them.
// NOTE(review): listing line 46 (the class-qualified name) is absent from this extract.
44 template <class TPrecision>
45 void
47 ::SetA(const MatrixPointer& mat)
48 {
49  itkDebugMacro("setting A to " << *mat);
50  if ( *this->m_A != *mat )
51  {
52  // NOTE: use value copy because mat can be changed outside, while m_LS can only be changed inside.
53  m_A=MatrixPointer(new MatrixType());
54  *this->m_A = *mat;
55  this->Modified();
// An empty m_LS is what Initialize() checks to decide whether to recompute it.
56  m_LS = MatrixPointer(new MatrixType());
57  m_ConditionNumber=-1;
58  }
59 }
60 
// SetLambda: stores a value copy of the regularization matrix Lambda.
// When it changes, it caches whether Lambda is symmetric, which selects the
// pseudo-inverse routine used in Initialize(). It also drops the cached m_LS
// and the condition number so that they are recomputed.
// NOTE(review): listing lines 63-64 (class-qualified name, presumably
// ::SetLambda(const MatrixPointer& mat)) are absent from this extract.
61 template <class TPrecision>
62 void
65 {
66  itkDebugMacro("setting Lambda to " << *mat);
67  if ( *this->m_Lambda != *mat )
68  {
69  this->m_Lambda=MatrixPointer(new MatrixType());
70  *this->m_Lambda = *mat;
71  this->Modified();
// Cached here so Initialize() need not re-test symmetry on every rebuild.
72  m_IsLambdaSymmetric = m_Lambda->IsSymmetric();
73  m_LS = MatrixPointer(new MatrixType());
74  m_ConditionNumber=-1;
75  }
76 }
77 
// VerifyInputs: checks dimensional consistency before solving. It throws via
// utlGlobalException when:
//  - A has no rows, or the solution dimension N (GetXDimension()) is not positive;
//  - b does not have one entry per row of A;
//  - Lambda is non-empty and is either not square or not N x N.
// NOTE(review): listing lines 80-81 (class-qualified name, presumably
// ::VerifyInputs()) are absent from this extract.
78 template <class TPrecision>
79 void
82 {
83  Superclass::VerifyInputs();
84  int N = GetXDimension();
85  int M = m_A->Rows();
86  utlGlobalException(M<=0 || N<=0, "need to set m_A" );
// M > 0 is guaranteed past this point, so comparing it with the size type is safe.
87  utlGlobalException(M!=m_b->Size(), "wrong size of m_b! m_A->Rows()="<<m_A->Rows() <<", m_b->Size()="<<m_b->Size());
88  if (m_Lambda->Size()>0)
89  {
90  utlGlobalException(m_Lambda->Rows()!=m_Lambda->Columns(), "m_Lambda needs to be square");
91  utlGlobalException(m_Lambda->Rows()!=N, "wrong size of m_Lambda! m_Lambda->Rows()="<<m_Lambda->Rows() << ", N="<<N);
92  }
93 }
94 
// Clear: resets the base-class state, then calls ClearA(), Clearb() and
// ClearLambda() (declared elsewhere; presumably each empties its member).
// NOTE(review): m_LS, m_ConditionNumber and m_IsLambdaSymmetric are not reset
// directly here. Confirm that ClearA()/ClearLambda() invalidate them;
// otherwise a stale cached operator survives Clear().
// NOTE(review): listing lines 97-98 (class-qualified name, presumably ::Clear())
// are absent from this extract.
95 template <class TPrecision>
96 void
99 {
100  Superclass::Clear();
101  ClearA();
102  Clearb();
103  ClearLambda();
104 }
105 
// Initialize: forwards xInitial to the base class, then lazily builds the
// cached solution operator.
// Per the helper names, the operator is
//   m_LS = pinv(A^T A + Lambda) * A^T,
// so Solve() only needs m_LS * b.
// It also records an infinity-norm condition estimate
//   ||A^T A + Lambda||_inf * ||pinv(A^T A + Lambda)||_inf.
// The work is skipped while m_LS is non-empty; SetA()/SetLambda() empty it on change.
// NOTE(review): listing lines 106 and 109 (class-qualified name) are absent from this extract.
107 template <class TPrecision>
108 void
110 ::Initialize(const VectorType& xInitial)
111 {
112  Superclass::Initialize(xInitial);
113  if (m_LS->Size()==0)
114  {
115  m_LS = MatrixPointer(new MatrixType());
// Normal-equation matrix: A^T A, plus Lambda when a regularizer is set.
116  // utl::ProductVnlMtM(*m_A, *m_A, *m_LS);
117  utl::ProductUtlXtX(*m_A, *m_LS);
118  if (m_Lambda->Size()>0)
119  *m_LS += *m_Lambda;
120  // utl::vAdd(m_LS->Size(),m_LS->data_block(), m_Lambda->data_block(), m_LS->data_block());
121  m_ConditionNumber = m_LS->GetInfNorm();
122  MatrixPointer tmp(new MatrixType());
// A^T A is symmetric, so the sum is symmetric whenever Lambda is; the
// symmetric pseudo-inverse routine is then used.
// NOTE(review): if the matrix is singular, the pseudo-inverse is used and the
// product below is not a true condition number.
123  if (m_IsLambdaSymmetric)
124  *tmp = m_LS->PInverseSymmericMatrix();
125  else
126  *tmp = m_LS->PInverseMatrix();
127  m_ConditionNumber *= tmp->GetInfNorm();
// Final operator: pinv(A^T A + Lambda) * A^T, overwriting the normal-equation matrix.
128  utl::ProductUtlMMt(*tmp, *m_A, *m_LS);
129  }
130 }
131 
// Solve: validates inputs, builds the cached operator m_LS if needed, and
// writes the closed-form regularized least-squares solution m_x = m_LS * b.
// xInitial is only forwarded to Initialize(). With ProductUtlMv's default
// beta = 0, m_x is overwritten, so the result does not depend on xInitial.
// NOTE(review): listing line 134 (class-qualified name) is absent from this extract.
132 template <class TPrecision>
133 void
135 ::Solve(const VectorType& xInitial)
136 {
137  VerifyInputs();
138  Initialize(xInitial);
139  utl::ProductUtlMv(*m_LS, *m_b, this->m_x);
140 }
141 
// EvaluateCostFunction: returns ||A x - b||_2^2 + x^T Lambda x. The Lambda term
// is added only when Lambda is non-empty.
// An empty x means "evaluate at the current solution this->m_x".
// NOTE(review): for a non-symmetric Lambda, the minimizer of this cost satisfies
// (A^T A + (Lambda + Lambda^T)/2) x = A^T b. That differs from the
// pinv(A^T A + Lambda) A^T operator built in Initialize(), so Solve() does not
// minimize this cost in that case.
// NOTE(review): listing lines 143-145 (return type, class-qualified name and
// ::EvaluateCostFunction(const VectorType& x) const) are absent from this extract.
142 template <class TPrecision>
146 {
147  const VectorType* xx = (x.Size()!=0? (&x) : (&this->m_x));
148  VectorType tmp;
// Data-fidelity term: squared 2-norm of the residual A x - b, taken on a
// double-typed copy of the residual.
149  utl::ProductUtlMv(*m_A, *xx, tmp);
150  ValueType cost = utl::ToVector<double>(tmp - *m_b)->GetSquaredTwoNorm();
151  if (m_Lambda->Size()>0)
152  {
// Regularization term x^T Lambda x: tmp is reused to hold x^T Lambda.
153  utl::ProductUtlvM(*xx, *m_Lambda, tmp);
154  cost += tmp.InnerProduct(*xx);
155  }
156  return cost;
157 }
158 
// InternalClone: clones via the superclass and downcasts the result to Self,
// throwing if the downcast fails. It then value-copies A, b, Lambda and the
// cached m_LS, plus the cached symmetry flag and condition number, into the clone.
// Writing through *rval->m_X assumes the clone's own containers were allocated
// by its constructor; presumably Superclass::InternalClone() creates it that way.
// NOTE(review): listing lines 161-162 (class-qualified name and ::InternalClone() const)
// are absent from this extract.
159 template < class TPrecision >
160 typename LightObject::Pointer
163 {
164  typename LightObject::Pointer loPtr = Superclass::InternalClone();
165  typename Self::Pointer rval = dynamic_cast<Self *>(loPtr.GetPointer());
166  if(rval.IsNull())
167  {
168  itkExceptionMacro(<< "downcast to type " << this->GetNameOfClass()<< " failed.");
169  }
// Deep (value) copies, so the clone does not share matrices with this object.
170  *rval->m_A = *m_A;
171  *rval->m_b = *m_b;
172  *rval->m_Lambda = *m_Lambda;
173  *rval->m_LS = *m_LS;
174  rval->m_IsLambdaSymmetric = m_IsLambdaSymmetric;
175  rval->m_ConditionNumber = m_ConditionNumber;
176  return loPtr;
177 }
178 
// PrintSelf: prints the base-class state, then A, b, Lambda, the cached operator
// m_LS (only once it has been computed, i.e. when non-empty), the condition
// number and the Lambda symmetry flag.
// NOTE(review): listing line 181 (class-qualified name) is absent from this extract.
179 template <class TPrecision>
180 void
182 ::PrintSelf(std::ostream& os, Indent indent) const
183 {
184  Superclass::PrintSelf(os, indent);
185  utl::PrintUtlMatrix(*m_A, "m_A", " ", os << indent);
186  utl::PrintUtlVector(*m_b, "m_b", " ", os << indent);
187  utl::PrintUtlMatrix(*m_Lambda, "m_Lambda", " ", os << indent);
// m_LS stays empty until Initialize() builds it.
188  if (m_LS->Size()>0)
189  utl::PrintUtlMatrix(*m_LS, "m_LS", " ", os << indent);
190  os << indent << "m_ConditionNumber = " << m_ConditionNumber << std::endl << std::flush;
191  os << indent << "m_IsLambdaSymmetric = " << m_IsLambdaSymmetric << std::endl << std::flush;
192 }
193 
194 }
195 
196 #endif
NDArray is an N-dimensional array class (row-major, C version)
Definition: utlFunctors.h:131
void ProductUtlXtX(const utl::NDArray< T, 2 > &A, utl::NDArray< T, 2 > &C, const double alpha=1.0, const double beta=0.0)
Helper functions specifically used in DMRITool
void PrintUtlMatrix(const NDArray< T, 2 > &mat, const std::string &str="", const char *separate=" ", std::ostream &os=std::cout)
virtual LightObject::Pointer InternalClone() const ITK_OVERRIDE
void ProductUtlvM(const utl::NDArray< T, 1 > &b, const utl::NDArray< T, 2 > &A, utl::NDArray< T, 1 > &c, const double alpha=1.0, const double beta=0.0)
void Initialize(const VectorType &xInitial=VectorType()) ITK_OVERRIDE
ValueType EvaluateCostFunction(const VectorType &x=VectorType()) const ITK_OVERRIDE
#define utlGlobalException(cond, expout)
Definition: utlCoreMacro.h:372
void PrintSelf(std::ostream &os, Indent indent) const ITK_OVERRIDE
Solves a least-squares problem with L2 regularization
ValueType InnerProduct(const NDArrayBase< T, Dim > &vec) const
Definition: utlNDArray.h:1123
void Solve(const VectorType &xInitial=VectorType()) ITK_OVERRIDE
void ProductUtlMMt(const utl::NDArray< T, 2 > &A, const utl::NDArray< T, 2 > &B, utl::NDArray< T, 2 > &C, const double alpha=1.0, const double beta=0.0)
void PrintUtlVector(const NDArray< T, 1 > &vec, const std::string &str="", const char *separate=" ", std::ostream &os=std::cout, bool showStats=true)
SizeType Size() const
Definition: utlNDArray.h:321
void ProductUtlMv(const utl::NDArray< T, 2 > &A, const utl::NDArray< T, 1 > &b, utl::NDArray< T, 1 > &c, const double alpha=1.0, const double beta=0.0)
Base class for some optimization solvers using primal-dual updates.
Definition: itkSolverBase.h:39