DMRITool  v0.1.1-139-g860d86b4
Diffusion MRI Tool
itkSpamsWeightedLassoSolver.h
Go to the documentation of this file.
1 
19 #ifndef __itkSpamsWeightedLassoSolver_h
20 #define __itkSpamsWeightedLassoSolver_h
21 
22 #include "itkSolverBase.h"
23 
24 #include <linalg.h>
25 #include <decomp.h>
26 
27 namespace itk
28 {
29 
48 template <class TPrecision>
49 class ITK_EXPORT SpamsWeightedLassoSolver
50  : public SolverBase<TPrecision>
51 {
52 public:
57  typedef SmartPointer<Self> Pointer;
58 
60  itkNewMacro(Self);
61 
64 
65  typedef typename Superclass::ValueType ValueType;
66  typedef typename Superclass::MatrixType MatrixType;
67  typedef typename Superclass::VectorType VectorType;
69  typedef typename Superclass::VectorPointer VectorPointer;
70 
74  typedef typename utl_shared_ptr<spams::Matrix<double> > SpamsMatrixPointer;
75  typedef typename utl_shared_ptr<spams::SpMatrix<double> > SpamsSpMatrixPointer;
76  typedef typename utl_shared_ptr<spams::Vector<double> > SpamsVectorPointer;
77 
78  typedef enum
79  {
80  L1CONS=0,
84 
85  itkSetMacro(ConstraintType, ConstraintType);
86  itkGetMacro(ConstraintType, ConstraintType);
87 
88  itkSetMacro(NumberOfThreads, int);
89  itkGetMacro(NumberOfThreads, int);
90 
91  itkSetMacro(Lambda, double);
92  itkGetMacro(Lambda, double);
93 
94  void SetA(const MatrixPointer& mat);
95  itkGetMacro(A, MatrixPointer);
96  void SetW(const MatrixPointer& W);
97  void Setw(const VectorPointer& w);
98  itkGetMacro(W, MatrixPointer);
99  void SetB(const MatrixPointer& B);
100  void Setb(const VectorPointer& b);
101  itkGetMacro(B, MatrixPointer);
102 
103  itkGetMacro(X, MatrixPointer);
104 
105  itkSetMacro(Positive,bool);
106  itkGetMacro(Positive,bool);
107  itkBooleanMacro(Positive);
108 
109 
111  {
112  int N = m_A->Columns();
113  utlException(N==0, "wrong size! m_A.Columns()="<<m_A->Columns());
114  return N;
115  }
// GetXNumber: returns m_B->Columns() as an int.
// If m_B has zero columns, utlException reports "wrong size!" (presumably by
// throwing -- see utlException in utlCoreMacro.h).
// Assumes m_B is non-null; ClearB() allocates a fresh matrix for it.
// NOTE(review): the name suggests this is the number of columns of the
// solution X, which would match the number of columns of B -- confirm.
116  int GetXNumber() const
117  {
118  int M = m_B->Columns();
119  utlException(M==0, "wrong size! m_B.Columns()="<<m_B->Columns());
120  return M;
121  }
122 
124  {
125  Superclass::Clear();
126  ClearA();
127  ClearW();
128  ClearB();
129  }
// ClearA: replaces m_A with a newly allocated, default-constructed MatrixType
// and m_As with a newly allocated, default-constructed SpamsMatrixType.
// The previously held matrices are released by their smart pointers once no
// other owner refers to them (assuming utl_shared_ptr has shared-ownership
// semantics). NOTE(review): m_As is presumably the SPAMS-side copy of m_A built
// by SetA; resetting both together keeps them consistent -- confirm against SetA.
130  void ClearA()
131  {
132  m_A=MatrixPointer(new MatrixType()); m_As=SpamsMatrixPointer(new SpamsMatrixType());
133  }
// ClearW: replaces m_W with a newly allocated, default-constructed MatrixType
// and m_Ws with a newly allocated, default-constructed SpamsMatrixType.
// The previously held matrices are released by their smart pointers once no
// other owner refers to them (assuming utl_shared_ptr has shared-ownership
// semantics). NOTE(review): m_Ws is presumably the SPAMS-side copy of the
// weights set via SetW/Setw; both are reset together -- confirm.
134  void ClearW()
135  {
136  m_W=MatrixPointer(new MatrixType()); m_Ws=SpamsMatrixPointer(new SpamsMatrixType());
137  }
// ClearB: replaces m_B with a newly allocated, default-constructed MatrixType
// and m_Bs with a newly allocated, default-constructed SpamsMatrixType.
// After this call GetXNumber() sees the column count of the fresh m_B.
// NOTE(review): m_Bs is presumably the SPAMS-side copy of the data set via
// SetB/Setb; both are reset together -- confirm.
138  void ClearB()
139  {
140  m_B=MatrixPointer(new MatrixType()); m_Bs=SpamsMatrixPointer(new SpamsMatrixType());
141  }
142 
143 
144  void VerifyInputs() const ITK_OVERRIDE;
145 
146  void Solve(const VectorType& xInitial=VectorType()) ITK_OVERRIDE;
147 
148  ValueType EvaluateCostFunctionInColumn(const VectorType& x, const int col) const;
149  ValueType EvaluateCostFunction(const MatrixType& x=MatrixType()) const ITK_OVERRIDE;
150 
151 
152 protected:
155 
156  void PrintSelf(std::ostream& os, Indent indent) const ITK_OVERRIDE;
157 
158  virtual typename LightObject::Pointer InternalClone() const ITK_OVERRIDE;
159 
161  MatrixPointer m_A;
162  MatrixPointer m_B;
163  MatrixPointer m_W;
164  MatrixPointer m_X;
165 
166  SpamsMatrixPointer m_As;
167  SpamsMatrixPointer m_Bs;
168  SpamsMatrixPointer m_Ws;
169  // SpamsMatrixType m_Xs;
170  SpamsSpMatrixPointer m_Xs;
171 
172  ConstraintType m_ConstraintType;
173 
175 
176  double m_Lambda;
178 
179 private:
180  SpamsWeightedLassoSolver(const Self&); //purposely not implemented
181  void operator=(const Self&); //purposely not implemented
182 
183 
184 };
185 
186 } // end namespace itk
187 
188 // Define instantiation macro for this template.
// ITK_TEMPLATE_SpamsWeightedLassoSolver(_, EXPORT, TypeX, TypeY) expands, inside
// namespace itk, to:
//   _( 1 ( class EXPORT SpamsWeightedLassoSolver< ITK_TEMPLATE_1 TypeX > ) )
// followed by the typedef itk::Templates::SpamsWeightedLassoSolver##TypeY
// (token-pasted name, e.g. TypeY = D gives SpamsWeightedLassoSolverD).
// The meaning of `_` and ITK_TEMPLATE_1 comes from ITK's explicit-instantiation
// macros, not from this file. NOTE(review): presumably consumed by
// Templates/itkSpamsWeightedLassoSolver+-.h when ITK_TEMPLATE_EXPLICIT is set -- confirm.
// Comments are kept outside the macro body: a `//` comment ending in a
// backslash would splice into the next line.
189 #define ITK_TEMPLATE_SpamsWeightedLassoSolver(_, EXPORT, TypeX, TypeY) \
190  namespace itk \
191  { \
192  _( 1 ( class EXPORT SpamsWeightedLassoSolver< ITK_TEMPLATE_1 TypeX > ) ) \
193  namespace Templates \
194  { \
195  typedef SpamsWeightedLassoSolver< ITK_TEMPLATE_1 TypeX > SpamsWeightedLassoSolver##TypeY; \
196  } \
197  }
198 
199 #if ITK_TEMPLATE_EXPLICIT
200 #include "Templates/itkSpamsWeightedLassoSolver+-.h"
201 #endif
202 
203 #if !defined(ITK_MANUAL_INSTANTIATION) && !defined(__itkSpamsWeightedLassoSolver_hxx)
205 #endif
206 
207 #endif
Sparse Matrix class.
Definition: linalg.h:63
TPrecision ValueType
Definition: itkSolverBase.h:51
#define utlException(cond, expout)
Definition: utlCoreMacro.h:548
Superclass::VectorPointer VectorPointer
#define ITK_OVERRIDE
Definition: utlITKMacro.h:46
Dense Vector class.
Definition: linalg.h:65
Contains sparse decomposition algorithms. It requires the toolbox linalg.
spams::SpMatrix< double > SpamsSpMatrixType
utl_shared_ptr< spams::Vector< double > > SpamsVectorPointer
utl_shared_ptr< MatrixType > MatrixPointer
Definition: itkSolverBase.h:56
Solve weighted LASSO using SPAMS.
Superclass::MatrixPointer MatrixPointer
Dense Matrix class.
Definition: linalg.h:61
Base class for some optimization solvers using primal-dual updates.
Definition: itkSolverBase.h:39
utl_shared_ptr< spams::Matrix< double > > SpamsMatrixPointer
utl_shared_ptr< spams::SpMatrix< double > > SpamsSpMatrixPointer