19 #ifndef __itkSpamsWeightedLassoSolver_hxx 20 #define __itkSpamsWeightedLassoSolver_hxx 30 template <
class TPrecision >
45 m_NumberOfThreads = -1;
48 template <
class TPrecision >
53 itkDebugMacro(
"setting A to " << *mat);
54 if ( this->m_A != mat )
63 template <
class TPrecision >
68 itkDebugMacro(
"setting B to " << *b);
78 template <
class TPrecision >
83 itkDebugMacro(
"setting B to " << *b);
95 template <
class TPrecision >
100 itkDebugMacro(
"setting W to " << *w);
101 if ( this->m_W != w )
110 template <
class TPrecision >
115 itkDebugMacro(
"setting W to " << *w);
118 if ( this->m_W != W )
128 template <
class TPrecision >
133 Superclass::VerifyInputs();
134 int N = this->GetXDimension();
135 int M = this->GetXNumber();
136 utlException(m_W->Columns()!=M,
"wrong size of m_W!, m_W->Columns()="<<m_W->Columns()<<
", M="<<M);
137 utlException(m_A->Columns()!=N,
"wrong size of m_A!, m_A->Columns()="<<m_A->Columns()<<
", N="<<N);
138 utlException(m_A->Rows()!=m_B->Rows(),
"wrong rows of m_A!, m_A->Rows()="<<m_A->Rows()<<
", m_B->Rows()="<<m_B->Rows());
141 template <
class TPrecision >
142 typename LightObject::Pointer
146 typename LightObject::Pointer loPtr = Superclass::InternalClone();
150 itkExceptionMacro(<<
"downcast to type " << this->GetNameOfClass()<<
" failed.");
166 rval->m_ConstraintType = m_ConstraintType;
167 rval->m_Lambda = m_Lambda;
168 rval->m_Positive = m_Positive;
169 rval->m_NumberOfThreads = m_NumberOfThreads;
175 template <
class TPrecision >
181 this->VerifyInputs();
184 int N = this->GetXDimension();
185 int M = this->GetXNumber();
187 if (m_ConstraintType==L1CONS)
192 else if (m_ConstraintType==L2CONS)
197 else if (m_ConstraintType==
PENALTY)
208 spams::lassoWeight<double>(*m_Bs,*m_As,*m_Ws,*m_Xs,
utl::min(m_A->Columns(), m_A->Rows()), 0.5*m_Lambda, mode,m_Positive,m_NumberOfThreads);
215 this->m_x = m_X->GetColumn(0);
220 template <
class TPrecision >
227 VectorType e = (*m_A) * (*xx)-m_B->GetColumn(col);
232 template <
class TPrecision >
239 for (
int i = 0; i < xx->Columns(); i += 1 )
242 vec = xx->GetColumn(i);
243 func += EvaluateCostFunctionInColumn(vec, i);
249 template <
class TPrecision>
254 Superclass::PrintSelf(os, indent);
255 PrintVar3(
true, m_ConstraintType, m_Positive, m_Lambda, os<<indent);
NDArray is an N-dimensional array class (row-major, C version)
ValueType EvaluateCostFunction(const MatrixType &x=MatrixType()) const ITK_OVERRIDE
void PrintUtlMatrix(const NDArray< T, 2 > &mat, const std::string &str="", const char *separate=" ", std::ostream &os=std::cout)
ValueType EvaluateCostFunctionInColumn(const VectorType &x, const int col) const
#define utlException(cond, expout)
void VerifyInputs() const ITK_OVERRIDE
virtual LightObject::Pointer InternalClone() const ITK_OVERRIDE
SpamsWeightedLassoSolver()
void PrintSelf(std::ostream &os, Indent indent) const ITK_OVERRIDE
void SetW(const MatrixPointer &W)
const T & min(const T &a, const T &b)
Return the minimum of a and b.
#define utlGlobalException(cond, expout)
double GetSquaredTwoNorm() const
SmartPointer< Self > Pointer
Superclass::ValueType ValueType
void Solve(const VectorType &xInitial=VectorType()) ITK_OVERRIDE
void Setw(const VectorPointer &w)
void SetA(const MatrixPointer &mat)
utl_shared_ptr< VectorType > VectorPointer
#define PrintVar3(cond, var1, var2, var3, os)
utl_shared_ptr< MatrixType > MatrixPointer
void SpMatrixToUtlMatrix(const SpMatrix< T > &mat, utl::NDArray< T, 2 > &result)
#define utlShowPosition(cond)
void UtlMatrixToMatrix(const utl::NDArray< T, 2 > &matUtl, Matrix< T > &matSpams)
void SetB(const MatrixPointer &B)
void Setb(const VectorPointer &b)
Base class for some optimization solvers using primal-dual updates.
utl_shared_ptr< spams::Matrix< double > > SpamsMatrixPointer
utl_shared_ptr< spams::SpMatrix< double > > SpamsSpMatrixPointer