DMRITool  v0.1.1-139-g860d86b4
Diffusion MRI Tool
itkL1RegularizedLeastSquaresFISTASolver.hxx
Go to the documentation of this file.
1 
18 #ifndef __itkL1RegularizedLeastSquaresFISTASolver_hxx
19 #define __itkL1RegularizedLeastSquaresFISTASolver_hxx
20 
22 #include "utl.h"
23 #include "utlVNLBlas.h"
24 
25 namespace itk
26 {
27 
28 template < class TPrecision >
31  m_A(new MatrixType()),
32  m_b(new VectorType()),
33  m_w(new VectorType()),
34  m_At(new MatrixType()),
35  m_AtA(new MatrixType()),
36  m_Atb(new VectorType()),
37  m_xOld(new VectorType())
38 {
39  m_Step=-1;
40  m_L2Solver = L2SolverType::New();
41  m_UseL2SolverForInitialization = false;
42 }
43 
44 template < class TPrecision >
45 void
47 ::SetA (const MatrixPointer& mat)
48 {
49  itkDebugMacro("setting A to " << *mat);
50  if ( *this->m_A != *mat )
51  {
52  m_A=MatrixPointer(new MatrixType());
53  m_At=MatrixPointer(new MatrixType());
54  m_AtA=MatrixPointer(new MatrixType());
55  *this->m_A = *mat;
56  // utl::MatrixCopy(*mat, *this->m_A, 1.0, 'N');
57  this->Modified();
58  *this->m_At = this->m_A->GetTranspose();
59  // utl::MatrixCopy(*this->m_A, *m_At, 1.0, 'T');
60  utl::ProductUtlXtX(*m_A, *m_AtA);
61  m_Step = 0.5/m_AtA->GetTwoNorm();
62  if (m_b->Size()>0)
63  {
64  utlException(m_A->Rows()!=m_b->Size(), "wrong size of m_A");
65  m_Atb=VectorPointer(new VectorType());
66  utl::ProductUtlMv(*m_At, *m_b, *m_Atb);
67  }
68  }
69  if (m_UseL2SolverForInitialization)
70  m_L2Solver->SetA(m_A);
71 }
72 
73 template < class TPrecision >
74 void
76 ::Setb (const VectorPointer& b)
77 {
78  itkDebugMacro("setting b to " << *b);
79  if ( *this->m_b != *b )
80  {
81  m_b=VectorPointer(new VectorType());
82  *this->m_b = *b;
83  this->Modified();
84  if (m_A->Size()>0)
85  {
86  utlException(m_At->Columns()!=m_b->Size(), "wrong size of m_A");
87  m_Atb=VectorPointer(new VectorType());
88  utl::ProductUtlMv(*m_At, *m_b, *m_Atb);
89  }
90  }
91  if (m_UseL2SolverForInitialization)
92  m_L2Solver->Setb(m_b);
93 }
94 
95 template < class TPrecision >
96 void
98 ::Setw (const VectorPointer& w)
99 {
100  itkDebugMacro("setting w to " << *w);
101  if ( this->m_w != w )
102  {
103  this->m_w = w;
104  this->Modified();
105  }
106 }
107 
// SetwForInitialization(w): builds an N x N matrix (N = w->Size()) that is zero
// except for w on its diagonal, and passes it to the L2 warm-start solver via
// SetLambda() (presumably the L2 solver's regularization matrix -- confirm).
// Throws if m_UseL2SolverForInitialization is false. Initialize() calls this with
// m_w when the L2 solver has no Lambda yet.
// NOTE(review): this is a Doxygen extraction; the class-qualifier and signature
// lines (original lines 110-111) are missing and must be restored from the header.
// NOTE(review): the diagonal matrix is allocated before the flag check, so the
// allocation is wasted when the check throws.
108 template < class TPrecision >
109 void
112 {
113  int N = w->Size();
114  MatrixPointer lambda(new MatrixType(N, N));
115  lambda->Fill(0.0);
116  lambda->SetDiagonal(*w);
117  utlException(!m_UseL2SolverForInitialization, "need to set m_UseL2SolverForInitialization");
118  m_L2Solver->SetLambda(lambda);
119  this->Modified();
120 }
121 
// VerifyInputs(): runs the base-class checks, then requires, with
// N = GetXDimension(), that
//   - m_w has N entries,
//   - A has N columns, and
//   - A has as many rows as b has entries.
// Each violation throws via utlException with both sizes in the message.
// NOTE(review): this is a Doxygen extraction; the class-qualifier and signature
// lines (original lines 124-125) are missing, including whether the method is
// const -- restore them from the header.
// NOTE(review): Size() is compared against a signed int N; expect a
// signed/unsigned comparison warning.
122 template < class TPrecision >
123 void
126 {
127  Superclass::VerifyInputs();
128  int N = this->GetXDimension();
129  utlException(m_w->Size()!=N, "wrong size of m_w!, m_w->Size()="<<m_w->Size()<<", N="<<N);
130  utlException(m_A->Columns()!=N, "wrong size of m_A!, m_A->Columns()="<<m_A->Columns()<<", N="<<N);
131  utlException(m_A->Rows()!=m_b->Size(), "wrong rows of m_A!, m_A->Rows()="<<m_A->Rows()<<", m_b->Size()="<<m_b->Size());
132 }
133 
134 template < class TPrecision >
135 typename LightObject::Pointer
138 {
139  typename LightObject::Pointer loPtr = Superclass::InternalClone();
140  typename Self::Pointer rval = dynamic_cast<Self *>(loPtr.GetPointer());
141  if(rval.IsNull())
142  {
143  itkExceptionMacro(<< "downcast to type " << this->GetNameOfClass()<< " failed.");
144  }
145  rval->m_A = m_A;
146  rval->m_b = m_b;
147  rval->m_w = m_w;
148  rval->m_At = m_At;
149  rval->m_AtA = m_AtA;
150  rval->m_Atb = m_Atb;
151  rval->m_Step = m_Step;
152  rval->m_UseL2SolverForInitialization = m_UseL2SolverForInitialization;
153  rval->m_L2Solver = m_L2Solver->Clone();
154  return loPtr;
155 }
156 
157 template < class TPrecision >
158 void
160 ::Initialize ( const VectorType& xInitial)
161 {
162  utlShowPosition(this->GetDebug());
163  Superclass::Initialize(xInitial);
164  if (xInitial.Size()==0)
165  {
166  utlException(!m_UseL2SolverForInitialization, "need to set m_UseL2SolverForInitialization");
167  if (m_L2Solver->GetLambda()->Size()==0)
168  SetwForInitialization(this->m_w);
169  m_L2Solver->Solve();
170  this->m_x = m_L2Solver->Getx();
171  // m_L2Solver->Print(std::cout<<"m_L2Solver : ");
172  }
173  if (this->GetDebug())
174  utl::PrintUtlVector(this->m_x, "m_x initialization");
175 }
176 
177 template < class TPrecision >
178 void
180 ::Solve ( const VectorType& xInitial)
181 {
182  utlShowPosition(this->GetDebug());
183  this->VerifyInputs();
184  Initialize(xInitial);
185  Iterate();
186  // EndSolve();
187 }
188 
189 template < class TPrecision >
193 {
194  const VectorType* xx = (x.Size()!=0? (&x) : (&this->m_x));
195  VectorPointer e(new VectorType());
196  utl::ProductUtlMv(*m_A, *xx, *e);
197  utl::vSub(e->Size(), e->GetData(), m_b->GetData(), e->GetData());
198  // *e -= *m_b;
199  double eNorm = utl::cblas_nrm2(e->Size(), e->GetData(), 1);
200  VectorType tmp(e->Size());
201  utl::vMul(e->Size(), m_w->GetData(), (double*)xx->GetData(), tmp.GetData());
202  ValueType func = eNorm*eNorm + utl::cblas_asum(e->Size(), tmp.GetData(),1);
203  // ValueType func = e->squared_magnitude() + element_product(*m_w, *xx).one_norm();
204  return func;
205 }
206 
207 template < class TPrecision >
208 void
211 {
212  ValueType fValue = EvaluateCostFunction(), changePercentage=0, changePercentage_x=0;
213  this->m_CostFunction.push_back(fValue);
214  int size = this->m_CostFunction.size();
215  VectorType tmp(this->m_x.Size());
216  if (size>=2)
217  {
218  changePercentage = (this->m_CostFunction[size-2] - this->m_CostFunction[size-1])/this->m_CostFunction[size-2];
219  double xOldNorm = utl::cblas_nrm2(this->m_xOld->Size(), this->m_xOld->GetData(), 1);
220  if (xOldNorm>0)
221  {
222  utl::vSub(this->m_x.Size(), this->m_x.GetData(), m_xOld->GetData(), tmp.GetData());
223  changePercentage_x = utl::cblas_nrm2(tmp.Size(), tmp.GetData(), 1) / xOldNorm;
224  }
225  if (changePercentage <= this->m_MinRelativeChangeOfCostFunction && changePercentage>=0 && changePercentage_x<=this->m_MinRelativeChangeOfPrimalResidual)
226  {
227  this->m_UpdateInformation = Self::CONTINUE;
228  this->m_NumberOfChangeLessThanThreshold++;
229  }
230  else
231  {
232  this->m_UpdateInformation = Self::CONTINUE;
233  this->m_NumberOfChangeLessThanThreshold = 0;
234  }
235  if (this->m_NumberOfChangeLessThanThreshold == 3)
236  this->m_UpdateInformation = Self::STOP_MIN_CHANGE;
237  }
238  else
239  this->m_UpdateInformation = Self::CONTINUE;
240  // if (this->GetDebug())
241  // {
242  // // utl::PrintUtlVector(this->m_x, "m_x");
243  // utlPrintVar4(true, fValue, changePercentage, changePercentage_x, this->m_MinRelativeChangeOfCostFunction);
244  // std::cout << "m_NumberOfChangeLessThanThreshold = " << this->m_NumberOfChangeLessThanThreshold << std::endl << std::flush;
245  // }
246 }
247 
248 template < class TPrecision >
249 void
252 {
253  utlShowPosition(this->GetDebug());
254 
255  VectorType y=this->m_x, xg=this->m_x;
256  double t = 1.0, t_new=1.0, func=-1.0;
257  VectorType w_new = (*m_w) % m_Step, tmp(xg.Size());
258  this->m_CostFunction.push_back(EvaluateCostFunction());
259 
260  int N = GetXDimension();
261  this->m_NumberOfIterations=0;
262  int num_change_less = 0;
263  m_xOld->ReSize(this->m_x.Size());
264  while ( this->m_NumberOfIterations <= this->m_MaxNumberOfIterations )
265  {
266 
267  utl::cblas_copy(this->m_x.Size(), this->m_x.GetData(), 1, m_xOld->GetData(), 1);
268  if (this->GetDebug())
269  {
270  std::cout << "iter = " << this->m_NumberOfIterations << std::endl << std::flush;
271  }
272 
273  // FISTA
274  utl::ProductUtlMv(*m_AtA, y, tmp);
275  utl::vSub(tmp.Size(), tmp.GetData(), m_Atb->GetData(), tmp.GetData());
276  utl::cblas_scal<double>(tmp.Size(), m_Step*2, tmp.GetData(), 1);
277  utl::vSub(y.Size(), y.GetData(), tmp.GetData(), xg.GetData());
278  // xg = y - m_Step*2*(tmp - *m_Atb);
279  for ( int j = 0; j < N; j += 1 )
280  {
281  xg[j] = xg[j]>w_new[j] ? (xg[j]-w_new[j]) : (xg[j]<-w_new[j] ? (xg[j]+w_new[j]) : 0 );
282  }
283  t_new = 0.5+0.5*std::sqrt(1+4*t*t);
284 
285  utl::vSub(xg.Size(), xg.GetData(), this->m_x.GetData(), tmp.GetData());
286  utl::cblas_scal<double>(xg.Size(), (t-1)/t_new, tmp.GetData(), 1);
287  utl::vAdd(xg.Size(), xg.GetData(), tmp.GetData(), y.GetData());
288  // y = xg + (t-1)/t_new * (xg-this->m_x);
289 
290  // // ISTA
291  // xg = m_x - m_Step*2*(m_AtA*m_x-m_Atb);
292  // for ( int j = 0; j < N; j += 1 )
293  //
294  // xg[j] = xg[j]>w_new[j] ? (xg[j]-w_new[j]) : (xg[j]<-w_new[j] ? (xg[j]+w_new[j]) : 0 );
295  // }
296 
297  // update
298  // this->m_x = xg;
299  utl::cblas_copy(this->m_x.Size(), xg.GetData(), 1, this->m_x.GetData(), 1);
300  t = t_new;
301 
302  this->m_NumberOfIterations++;
303 
304  HistoryUpdateAndConvergenceCheck();
305  if (this->m_UpdateInformation==Self::STOP_MIN_CHANGE)
306  break;
307  }
308 }
309 
310 template <class TPrecision>
311 void
313 ::PrintSelf(std::ostream& os, Indent indent) const
314 {
315  Superclass::PrintSelf(os, indent);
316  utl::PrintUtlMatrix(*m_A, "m_A", " ", os<<indent);
317  utl::PrintUtlVector(*m_b, "m_b", " ", os<<indent);
318  utl::PrintUtlVector(*m_w, "m_w", " ", os<<indent);
319  PrintVar2(true, m_Step, m_UseL2SolverForInitialization, os<<indent);
320  if (m_L2Solver)
321  m_L2Solver->Print(os<<indent<< "m_L2Solver for initialization = ");
322 }
323 
324 }
325 
326 
327 #endif
328 
329 
330 
331 
NDArray is a N-Dimensional array class (row-major, c version)
Definition: utlFunctors.h:131
void vMul(int n, T *vecIn, T *vecIn2, T *vecOut)
interface to v*Mul
Definition: utlCoreMKL.h:307
Base class for some optimization solvers using primal-dual updates.
void ProductUtlXtX(const utl::NDArray< T, 2 > &A, utl::NDArray< T, 2 > &C, const double alpha=1.0, const double beta=0.0)
void vSub(int n, T *vecIn, T *vecIn2, T *vecOut)
interface to v*Sub
Definition: utlCoreMKL.h:279
void cblas_scal< double >(const INTT N, const double alpha, double *X, const INTT incX)
Definition: utlBlas.h:292
void cblas_copy(const INTT N, const T *X, const INTT incX, T *Y, const INTT incY)
helper functions specifically used in dmritool
ValueType EvaluateCostFunction(const VectorType &x=VectorType()) const ITK_OVERRIDE
void PrintUtlMatrix(const NDArray< T, 2 > &mat, const std::string &str="", const char *separate=" ", std::ostream &os=std::cout)
#define utlException(cond, expout)
Definition: utlCoreMacro.h:548
void vAdd(int n, T *vecIn, T *vecIn2, T *vecOut)
interface to v*Add
Definition: utlCoreMKL.h:300
virtual LightObject::Pointer InternalClone() const ITK_OVERRIDE
T cblas_asum(const INTT N, const T *X, const INTT incX)
void Initialize(const VectorType &xInitial=VectorType()) ITK_OVERRIDE
void PrintSelf(std::ostream &os, Indent indent) const ITK_OVERRIDE
#define utlShowPosition(cond)
Definition: utlCoreMacro.h:554
void PrintUtlVector(const NDArray< T, 1 > &vec, const std::string &str="", const char *separate=" ", std::ostream &os=std::cout, bool showStats=true)
SizeType Size() const
Definition: utlNDArray.h:321
T cblas_nrm2(const INTT N, const T *X, const INTT incX)
void ProductUtlMv(const utl::NDArray< T, 2 > &A, const utl::NDArray< T, 1 > &b, utl::NDArray< T, 1 > &c, const double alpha=1.0, const double beta=0.0)
void Solve(const VectorType &xInitial=VectorType()) ITK_OVERRIDE
#define PrintVar2(cond, var1, var2, os)
Definition: utlCoreMacro.h:454
Base class for some optimization solvers using primal-dual updates.
Definition: itkSolverBase.h:39
SmartPointer< Self > Pointer