DMRITool  v0.1.1-139-g860d86b4
Diffusion MRI Tool
itkSpamsWeightedLassoSolver.hxx
Go to the documentation of this file.
1 
19 #ifndef __itkSpamsWeightedLassoSolver_hxx
20 #define __itkSpamsWeightedLassoSolver_hxx
21 
23 #include "utlCore.h"
24 // #include "utlITKSpams.h"
25 #include "utlSpams.h"
26 
27 namespace itk
28 {
29 
// Constructor: allocates empty dense matrices for A, B, W and X, empty SPAMS
// dense copies of A, B and W, and an empty SPAMS sparse matrix m_Xs into which
// Solve() receives the solution.
// Defaults: PENALTY formulation, no positivity constraint, lambda = 1.0.
// m_NumberOfThreads = -1 is forwarded unchanged to spams::lassoWeight in
// Solve(); presumably -1 means "let SPAMS decide" -- TODO confirm.
30 template < class TPrecision >
33  m_A(new MatrixType()),
34  m_B(new MatrixType()),
35  m_W(new MatrixType()),
36  m_X(new MatrixType()),
37  m_As(new SpamsMatrixType()),
38  m_Bs(new SpamsMatrixType()),
39  m_Ws(new SpamsMatrixType()),
40  m_Xs(new SpamsSpMatrixType())
41 {
42  m_ConstraintType=PENALTY;
43  m_Positive = false;
44  m_Lambda = 1.0;
45  m_NumberOfThreads = -1;
46 }
47 
// Sets the matrix A used in the data term ||A x - b||^2 and rebuilds the
// cached SPAMS copy m_As.
// mat must be non-null: it is dereferenced for the debug message and again
// (as m_A) for the SPAMS conversion.
// Change detection is by pointer identity: passing the same pointer after
// modifying the matrix in place does NOT refresh m_As or call Modified().
48 template < class TPrecision >
49 void
51 ::SetA (const MatrixPointer& mat)
52 {
53  itkDebugMacro("setting A to " << *mat);
54  if ( this->m_A != mat )
55  {
56  this->m_A = mat;
// Rebuild the SPAMS copy in a new object; the old one may be shared by clones.
57  m_As = SpamsMatrixPointer(new SpamsMatrixType());
58  spams::UtlMatrixToMatrix(*m_A, *m_As);
59  this->Modified();
60  }
61 }
62 
// Sets the right-hand-side matrix B (column col of B is the b compared with
// A x for column col of the solution) and rebuilds the cached SPAMS copy m_Bs.
// b must be non-null: it is dereferenced for the debug message and again
// (as m_B) for the SPAMS conversion.
// Change detection is by pointer identity: passing the same pointer after
// modifying the matrix in place does NOT refresh m_Bs or call Modified().
63 template < class TPrecision >
64 void
66 ::SetB (const MatrixPointer& b)
67 {
68  itkDebugMacro("setting B to " << *b);
69  if ( this->m_B != b )
70  {
71  m_B = b;
// Rebuild the SPAMS copy in a new object; the old one may be shared by clones.
72  m_Bs = SpamsMatrixPointer(new SpamsMatrixType());
73  spams::UtlMatrixToMatrix(*m_B, *m_Bs);
74  this->Modified();
75  }
76 }
77 
// Vector overload for B: wraps b into a newly allocated one-column matrix,
// makes it the new m_B, rebuilds the SPAMS copy m_Bs and calls Modified().
// b must be non-null (it is dereferenced unconditionally).
78 template < class TPrecision >
79 void
81 ::Setb (const VectorPointer& b)
82 {
83  itkDebugMacro("setting B to " << *b);
84  MatrixPointer B( new MatrixType(b->Size(),1));
85  B->SetColumn(0, *b);
// NOTE(review): B was allocated just above, so this pointer comparison is
// always true and the update below always runs.
86  if ( this->m_B != B )
87  {
88  m_B = B;
89  m_Bs = SpamsMatrixPointer(new SpamsMatrixType());
90  spams::UtlMatrixToMatrix(*m_B, *m_Bs);
91  this->Modified();
92  }
93 }
94 
// Sets the weight matrix W of the weighted l1 term (column col of W weights
// column col of x in EvaluateCostFunctionInColumn) and rebuilds the cached
// SPAMS copy m_Ws.
// w must be non-null: it is dereferenced for the debug message and again
// (as m_W) for the SPAMS conversion.
// Change detection is by pointer identity: passing the same pointer after
// modifying the matrix in place does NOT refresh m_Ws or call Modified().
95 template < class TPrecision >
96 void
98 ::SetW (const MatrixPointer& w)
99 {
100  itkDebugMacro("setting W to " << *w);
101  if ( this->m_W != w )
102  {
103  this->m_W = w;
// Rebuild the SPAMS copy in a new object; the old one may be shared by clones.
104  m_Ws = SpamsMatrixPointer(new SpamsMatrixType());
105  spams::UtlMatrixToMatrix(*m_W, *m_Ws);
106  this->Modified();
107  }
108 }
109 
// Vector overload of the W setter: wraps w into a newly allocated one-column
// matrix, makes it the new m_W, rebuilds the SPAMS copy m_Ws and calls
// Modified(). w must be non-null (it is dereferenced unconditionally).
110 template < class TPrecision >
111 void
114 {
115  itkDebugMacro("setting W to " << *w);
116  MatrixPointer W( new MatrixType(w->Size(),1) );
117  W->SetColumn(0, *w);
// NOTE(review): W was allocated just above, so this pointer comparison is
// always true and the update below always runs.
118  if ( this->m_W != W )
119  {
120  m_W = W;
121  m_Ws = SpamsMatrixPointer(new SpamsMatrixType());
122  spams::UtlMatrixToMatrix(*m_W, *m_Ws);
123  this->Modified();
124  }
125 }
126 
127 
// Input-size validation (presumably the VerifyInputs() override that Solve()
// calls first). After the superclass checks it requires:
//   - W to have M = GetXNumber() columns,
//   - A to have N = GetXDimension() columns,
//   - A and B to have the same number of rows.
// Each violation raises utlException with the offending sizes.
// NOTE(review): the row count of W is not checked against N, although each
// column of W is multiplied with an N-length column of x in
// EvaluateCostFunctionInColumn; consider adding that check.
128 template < class TPrecision >
129 void
132 {
133  Superclass::VerifyInputs();
134  int N = this->GetXDimension();
135  int M = this->GetXNumber();
136  utlException(m_W->Columns()!=M, "wrong size of m_W!, m_W->Columns()="<<m_W->Columns()<<", M="<<M);
137  utlException(m_A->Columns()!=N, "wrong size of m_A!, m_A->Columns()="<<m_A->Columns()<<", N="<<N);
138  utlException(m_A->Rows()!=m_B->Rows(), "wrong rows of m_A!, m_A->Rows()="<<m_A->Rows()<<", m_B->Rows()="<<m_B->Rows());
139 }
140 
141 template < class TPrecision >
142 typename LightObject::Pointer
145 {
146  typename LightObject::Pointer loPtr = Superclass::InternalClone();
147  typename Self::Pointer rval = dynamic_cast<Self *>(loPtr.GetPointer());
148  if(rval.IsNull())
149  {
150  itkExceptionMacro(<< "downcast to type " << this->GetNameOfClass()<< " failed.");
151  }
152  rval->m_A = m_A;
153  rval->m_As = m_As;
154  rval->m_B = m_B;
155  rval->m_Bs = m_Bs;
156  rval->m_W = m_W;
157  rval->m_Ws = m_Ws;
158 
159  rval->m_X = m_X;
160  rval->m_Xs = m_Xs;
161  // if (m_Xs->m()*m_Xs->n()>0)
162  // rval->m_Xs->copy(*m_Xs);
163  // else
164  // rval->m_Xs= SpamsSpMatrixPointer(new SpamsSpMatrixType());
165 
166  rval->m_ConstraintType = m_ConstraintType;
167  rval->m_Lambda = m_Lambda;
168  rval->m_Positive = m_Positive;
169  rval->m_NumberOfThreads = m_NumberOfThreads;
170 
171  return loPtr;
172 }
173 
174 
// Solves the weighted lasso for every column of B with spams::lassoWeight.
// The result is stored in m_Xs (sparse) and m_X (dense). When there is a
// single column (M == 1), it is also copied into this->m_x.
// The initial-guess argument is unnamed and therefore ignored.
// Only the PENALTY formulation is implemented:
//   - L1CONS and L2CONS throw "TODO";
//   - any other constraint type throws "wrong m_ConstraintType".
175 template < class TPrecision >
176 void
178 ::Solve (const VectorType& )
179 {
180  utlShowPosition(this->GetDebug());
181  this->VerifyInputs();
182  // this->Initialize();
183  // SpamsSpMatrixType xs;
// NOTE(review): N is not used below in this function.
184  int N = this->GetXDimension();
185  int M = this->GetXNumber();
187  if (m_ConstraintType==L1CONS)
188  {
189  mode = spams::L1COEFFS;
190  utlGlobalException(true, "TODO");
191  }
192  else if (m_ConstraintType==L2CONS)
193  {
194  mode = spams::L2ERROR;
195  utlGlobalException(true, "TODO");
196  }
197  else if (m_ConstraintType==PENALTY)
198  {
199  mode = spams::PENALTY;
200  // m_Bs.print("m_Bs");
201  // m_As.print("m_As");
202  // m_Ws.print("m_Ws");
203  // std::cout << "m_Lambda = " << m_Lambda << std::endl << std::flush;
204  // std::cout << "utl::min(m_A.Columns(), m_A.Rows()) = " << utl::min<int>(m_A.Columns(), m_A.Rows()) << std::endl << std::flush;
205  // std::cout << "mode = " << mode << std::endl << std::flush;
206  // std::cout << "m_Positive = " << m_Positive << std::endl << std::flush;
// Solution is written into *m_Xs.
// The lambda passed is 0.5*m_Lambda. Presumably SPAMS minimizes
// 0.5||b - A x||^2 + lambda'||w.*x||_1, so halving lambda gives the same
// minimizer as EvaluateCostFunctionInColumn's ||A x - b||^2 + m_Lambda||w.*x||_1.
// TODO confirm against the SPAMS docs.
// The min(cols, rows) of A argument is presumably the maximum number of
// non-zeros per column.
208  spams::lassoWeight<double>(*m_Bs,*m_As,*m_Ws,*m_Xs,utl::min(m_A->Columns(), m_A->Rows()), 0.5*m_Lambda, mode,m_Positive,m_NumberOfThreads);
209  }
210  else
211  utlGlobalException(true, "wrong m_ConstraintType");
// Allocate a new dense matrix; an older m_X object may still be shared by clones.
212  m_X= MatrixPointer(new MatrixType());
213  spams::SpMatrixToUtlMatrix(*m_Xs, *m_X);
214  if (M==1)
215  this->m_x = m_X->GetColumn(0);
216  // utl::PrintUtlMatrix(m_X, "m_X");
217  // utl::PrintUtlVector(this->m_x, "m_x");
218 }
219 
// Cost of one column of the solution:
//   ||A x - B(:,col)||^2 + m_Lambda * ||W(:,col) % x||_1
// where % is presumably the element-wise product.
// Throws (utlException) if x is empty.
// The weighted l1 norm is evaluated on a double vector (utl::ToVector<double>).
220 template < class TPrecision >
223 ::EvaluateCostFunctionInColumn (const VectorType& x, const int col) const
224 {
225  utlException(x.Size()==0, "need to give a vector");
226  const VectorType* xx = &x;
227  VectorType e = (*m_A) * (*xx)-m_B->GetColumn(col);
228  ValueType func = e.GetSquaredTwoNorm() + m_Lambda* utl::ToVector<double>(m_W->GetColumn(col) % (*xx))->GetOneNorm();
229  return func;
230 }
231 
// Total cost: the sum of EvaluateCostFunctionInColumn over the columns of x.
// If x is empty (its default value), the last solution m_X computed by
// Solve() is evaluated instead.
232 template < class TPrecision >
236 {
237  ValueType func=0;
238  const MatrixType* xx = (x.Size()!=0? (&x) : (this->m_X.get()));
// NOTE(review): int loop index compared with Columns(); if Columns() returns
// an unsigned type, this is a signed/unsigned comparison.
239  for ( int i = 0; i < xx->Columns(); i += 1 )
240  {
241  VectorType vec;
242  vec = xx->GetColumn(i);
243  func += EvaluateCostFunctionInColumn(vec, i);
244  }
245  return func;
246 }
247 
248 
// Prints, in order:
//   - the superclass state;
//   - m_ConstraintType, m_Positive and m_Lambda;
//   - the dense matrices A, B, W and X.
// The SPAMS copies are not printed.
// Relies on m_A, m_B, m_W and m_X being non-null; the constructor allocates them.
249 template <class TPrecision>
250 void
252 ::PrintSelf(std::ostream& os, Indent indent) const
253 {
254  Superclass::PrintSelf(os, indent);
255  PrintVar3(true, m_ConstraintType, m_Positive, m_Lambda, os<<indent);
256  utl::PrintUtlMatrix(*m_A, "m_A", " ", os<<indent);
257  utl::PrintUtlMatrix(*m_B, "m_B", " ", os<<indent);
258  utl::PrintUtlMatrix(*m_W, "m_W", " ", os<<indent);
259  utl::PrintUtlMatrix(*m_X, "m_X", " ", os<<indent);
260 
261  // m_As.print("m_As");
262  // m_Bs.print("m_Bs");
263  // m_Ws.print("m_Ws");
264  // m_Xs.print("m_Xs");
265 }
266 
267 }
268 
269 #endif
270 
271 
272 
NDArray is an N-dimensional array class (row-major, C version)
Definition: utlFunctors.h:131
Sparse Matrix class.
Definition: linalg.h:63
TPrecision ValueType
Definition: itkSolverBase.h:51
constraint_type
Definition: decomp.h:88
ValueType EvaluateCostFunction(const MatrixType &x=MatrixType()) const ITK_OVERRIDE
void PrintUtlMatrix(const NDArray< T, 2 > &mat, const std::string &str="", const char *separate=" ", std::ostream &os=std::cout)
ValueType EvaluateCostFunctionInColumn(const VectorType &x, const int col) const
#define utlException(cond, expout)
Definition: utlCoreMacro.h:548
virtual LightObject::Pointer InternalClone() const ITK_OVERRIDE
void PrintSelf(std::ostream &os, Indent indent) const ITK_OVERRIDE
const T & min(const T &a, const T &b)
Returns the minimum of a and b.
Definition: utlCore.h:257
#define utlGlobalException(cond, expout)
Definition: utlCoreMacro.h:372
double GetSquaredTwoNorm() const
Definition: utlNDArray.h:1017
SmartPointer< Self > Pointer
Definition: itkSolverBase.h:45
void Solve(const VectorType &xInitial=VectorType()) ITK_OVERRIDE
void SetA(const MatrixPointer &mat)
utl_shared_ptr< VectorType > VectorPointer
Definition: itkSolverBase.h:57
#define PrintVar3(cond, var1, var2, var3, os)
Definition: utlCoreMacro.h:462
utl_shared_ptr< MatrixType > MatrixPointer
Definition: itkSolverBase.h:56
void SpMatrixToUtlMatrix(const SpMatrix< T > &mat, utl::NDArray< T, 2 > &result)
Definition: utlSpams.h:82
#define utlShowPosition(cond)
Definition: utlCoreMacro.h:554
void UtlMatrixToMatrix(const utl::NDArray< T, 2 > &matUtl, Matrix< T > &matSpams)
Definition: utlSpams.h:28
SizeType Size() const
Definition: utlNDArray.h:321
Dense Matrix class.
Definition: linalg.h:61
Base class for some optimization solvers using primal-dual updates.
Definition: itkSolverBase.h:39
utl_shared_ptr< spams::Matrix< double > > SpamsMatrixPointer
utl_shared_ptr< spams::SpMatrix< double > > SpamsSpMatrixPointer