DMRITool  v0.1.1-139-g860d86b4
Diffusion MRI Tool
itkL1RegularizedLeastSquaresFISTASolver.h
Go to the documentation of this file.
1 
18 #ifndef __itkL1RegularizedLeastSquaresFistaSolver_h
19 #define __itkL1RegularizedLeastSquaresFistaSolver_h
20 
21 #include "itkIterativeSolverBase.h"
23 
24 namespace itk
25 {
26 
/**
 * L1RegularizedLeastSquaresFISTASolver: iterative solver for a least-squares
 * problem with L1 regularization using FISTA, derived from
 * IterativeSolverBase<TPrecision>.
 *
 * Inputs are a matrix A (SetA), a vector b (Setb) and a vector w (Setw).
 * NOTE(review): presumably the objective is 0.5*||A x - b||^2 + ||diag(w) x||_1
 * with w acting as per-coefficient L1 weights -- confirm against
 * EvaluateCostFunction() in the .hxx implementation.
 *
 * NOTE(review): this listing is incomplete. Self, Superclass, MatrixType,
 * VectorType, m_L2Solver and m_UseL2SolverForInitialization are used here but
 * declared on lines that are not shown, and the declaration lines for the two
 * inline bodies starting at listing lines 82 and 89 are missing.
 */
43 template <class TPrecision>
44 class ITK_EXPORT L1RegularizedLeastSquaresFISTASolver : public IterativeSolverBase<TPrecision>
45 {
46 public:
 // ITK smart-pointer alias and object-factory New() (generated by itkNewMacro).
51  typedef SmartPointer<Self> Pointer;
52 
54  itkNewMacro(Self);
55 
58 
 // Scalar, matrix-pointer and vector-pointer types are inherited from the base solver.
59  typedef typename Superclass::ValueType ValueType;
62  typedef typename Superclass::MatrixPointer MatrixPointer;
63  typedef typename Superclass::VectorPointer VectorPointer;
64 
 // Type name is spelled "Infomation" in the base class; it must match there.
66  typedef typename Superclass::UpdateInfomationType UpdateInfomationType;
68 
 /** Input A. SetA is implemented out of line; GetA() (itkGetMacro) returns m_A. */
69  void SetA(const MatrixPointer& mat);
70  itkGetMacro(A, MatrixPointer);
 /** Input w. Setw is implemented out of line; Getw() (itkGetMacro) returns m_w.
  *  NOTE(review): SetwForInitialization presumably sets w only for the L2
  *  initialization solve -- confirm in the .hxx. */
71  void Setw(const VectorPointer& w);
72  void SetwForInitialization(const VectorPointer& w);
73  itkGetMacro(w, VectorPointer);
 /** Input b. Setb is implemented out of line; Getb() (itkGetMacro) returns m_b. */
74  void Setb(const VectorPointer& b);
75  itkGetMacro(b, VectorPointer);
76 
 /** Set/Get/On/Off accessors for the UseL2SolverForInitialization flag.
  *  NOTE(review): presumably, when on, the initial x is obtained from the
  *  L2-regularized least-squares solver -- confirm in Initialize(). */
77  itkSetMacro(UseL2SolverForInitialization,bool);
78  itkGetMacro(UseL2SolverForInitialization,bool);
79  itkBooleanMacro(UseL2SolverForInitialization);
80 
 // Body whose declaration line is missing from this listing: returns
 // m_A->Columns() and raises utlException when A has zero columns.
 // NOTE(review): Columns() return type is not visible here; storing it in an
 // int may narrow for very large matrices.
82  {
83  int N = m_A->Columns();
84  utlException(N==0, "wrong size! m_A->Columns()="<<m_A->Columns());
85  return N;
86  }
87 
 // Body whose declaration line is missing from this listing (presumably
 // Clear()): resets base-class state, replaces m_A, m_AtA, m_Atb, m_b and m_w
 // with freshly allocated empty objects, and clears the inner L2 solver.
 // NOTE(review): m_At and m_xOld are not reset here -- confirm that is intended.
89  {
90  Superclass::Clear();
91  m_A=MatrixPointer(new MatrixType());
92  m_AtA=MatrixPointer(new MatrixType());
93  m_Atb=VectorPointer(new VectorType());
94  m_b=VectorPointer(new VectorType());
95  m_w=VectorPointer(new VectorType());
96  m_L2Solver->Clear();
97  }
 /** Drop A together with the A-dependent caches m_AtA and m_Atb, and clear A in
  *  the inner L2 solver.
  *  NOTE(review): m_At is not reset here, so a stale value survives until it is
  *  recomputed -- confirm SetA() always rebuilds it. */
98  void ClearA()
99  {
100  m_A=MatrixPointer(new MatrixType());
101  m_AtA=MatrixPointer(new MatrixType());
102  m_Atb=VectorPointer(new VectorType());
103  m_L2Solver->ClearA();
104  }
 /** Drop w and clear the corresponding state of the inner L2 solver.
  *  NOTE(review): "ClearLamabda" must match the spelling declared by the L2
  *  solver class; rename both together if it is ever corrected. */
105  void Clearw()
106  {
107  m_w=VectorPointer(new VectorType());
108  m_L2Solver->ClearLamabda();
109  }
 /** Drop b and m_Atb (which, by its name, depends on b), and clear b in the
  *  inner L2 solver. */
110  void Clearb()
111  {
112  m_b=VectorPointer(new VectorType());
113  m_Atb=VectorPointer(new VectorType());
114  m_L2Solver->Clearb();
115  }
116 
 // Overrides of IterativeSolverBase hooks; implementations are in the .hxx file.
118  void HistoryUpdateAndConvergenceCheck() ITK_OVERRIDE;
119 
120  void VerifyInputs() const ITK_OVERRIDE;
121 
 // xInitial defaults to an empty vector; how an empty start is handled is
 // decided in the .hxx (not shown).
122  void Solve(const VectorType& xInitial=VectorType()) ITK_OVERRIDE;
123  void Iterate() ITK_OVERRIDE;
124  void Initialize(const VectorType& xInitial=VectorType()) ITK_OVERRIDE;
125 
 // x defaults to an empty vector; presumably the current solution is used then
 // -- confirm in the .hxx.
126  ValueType EvaluateCostFunction(const VectorType& x=VectorType()) const ITK_OVERRIDE;
127 
128 
129 protected:
132 
133  void PrintSelf(std::ostream& os, Indent indent) const ITK_OVERRIDE;
134 
135  virtual typename LightObject::Pointer InternalClone() const ITK_OVERRIDE;
136 
 // Problem inputs, exposed through the Set/Get accessors above.
138  MatrixPointer m_A;
140  VectorPointer m_b;
141 
 // L1 weights (see the NOTE(review) on the objective in the class comment).
143  VectorPointer m_w;
144 
146 
147 private:
148  L1RegularizedLeastSquaresFISTASolver(const Self&); //purposely not implemented
149  void operator=(const Self&); //purposely not implemented
150 
 // Cached A/b-derived quantities, judging by their names (presumably A^T,
 // A^T A and A^T b). ClearA() resets m_AtA and m_Atb; Clearb() resets m_Atb.
151  MatrixPointer m_At;
152  MatrixPointer m_AtA;
153  VectorPointer m_Atb;
 // NOTE(review): presumably the FISTA gradient step size -- confirm in the .hxx.
154  double m_Step;
155 
156  // private members
 // NOTE(review): presumably the previous iterate used by FISTA's momentum step.
157  VectorPointer m_xOld;
158 
160 };
161 
162 } // end namespace itk
163 
164 // Define instantiation macro for this template.
// Explicit-instantiation helper in the ITK ITK_TEMPLATE_* style.
//   _      : caller-supplied macro applied to the class declaration token group
//            "1 ( class EXPORT L1RegularizedLeastSquaresFISTASolver<TypeX> )".
//   EXPORT : export/visibility token placed before "class".
//   TypeX  : template argument (TPrecision), wrapped with ITK_TEMPLATE_1.
//   TypeY  : suffix pasted onto the typedef name created in itk::Templates.
// Comments are kept out of the macro body because "//" on a backslash-continued
// line would swallow the continuation.
165 #define ITK_TEMPLATE_L1RegularizedLeastSquaresFISTASolver(_, EXPORT, TypeX, TypeY) \
166  namespace itk \
167  { \
168  _( 1 ( class EXPORT L1RegularizedLeastSquaresFISTASolver< ITK_TEMPLATE_1 TypeX > ) ) \
169  namespace Templates \
170  { \
171  typedef L1RegularizedLeastSquaresFISTASolver< ITK_TEMPLATE_1 TypeX > L1RegularizedLeastSquaresFISTASolver##TypeY; \
172  } \
173  }
174 
175 #if ITK_TEMPLATE_EXPLICIT
176 #include "Templates/itkL1RegularizedLeastSquaresFISTASolver+-.h"
177 #endif
178 
179 #if !defined(ITK_MANUAL_INSTANTIATION) && !defined(__itkL1RegularizedLeastSquaresFISTASolver_hxx)
181 #endif
182 
183 #endif
184 
NDArray is an N-dimensional array class (row-major, C version).
Definition: utlFunctors.h:131
Base class for some optimization solvers using primal-dual updates.
Solves a least-squares problem with L1 regularization using FISTA.
#define utlException(cond, expout)
Definition: utlCoreMacro.h:548
#define ITK_OVERRIDE
Definition: utlITKMacro.h:46
Solves a least-squares problem with L2 regularization.
L2RegularizedLeastSquaresSolver< TPrecision > L2SolverType
std::vector< ValueType > ValueContainerType
Definition: itkSolverBase.h:58