18 #ifndef __itkL1RegularizedLeastSquaresFISTASolver_hxx 19 #define __itkL1RegularizedLeastSquaresFISTASolver_hxx 28 template <
class TPrecision >
40 m_L2Solver = L2SolverType::New();
41 m_UseL2SolverForInitialization =
false;
44 template <
class TPrecision >
49 itkDebugMacro(
"setting A to " << *mat);
50 if ( *this->m_A != *mat )
58 *this->m_At = this->m_A->GetTranspose();
61 m_Step = 0.5/m_AtA->GetTwoNorm();
64 utlException(m_A->Rows()!=m_b->Size(),
"wrong size of m_A");
69 if (m_UseL2SolverForInitialization)
70 m_L2Solver->SetA(m_A);
73 template <
class TPrecision >
78 itkDebugMacro(
"setting b to " << *b);
79 if ( *this->m_b != *b )
86 utlException(m_At->Columns()!=m_b->Size(),
"wrong size of m_A");
91 if (m_UseL2SolverForInitialization)
92 m_L2Solver->Setb(m_b);
95 template <
class TPrecision >
100 itkDebugMacro(
"setting w to " << *w);
101 if ( this->m_w != w )
108 template <
class TPrecision >
116 lambda->SetDiagonal(*w);
117 utlException(!m_UseL2SolverForInitialization,
"need to set m_UseL2SolverForInitialization");
118 m_L2Solver->SetLambda(lambda);
122 template <
class TPrecision >
127 Superclass::VerifyInputs();
128 int N = this->GetXDimension();
129 utlException(m_w->Size()!=N,
"wrong size of m_w!, m_w->Size()="<<m_w->Size()<<
", N="<<N);
130 utlException(m_A->Columns()!=N,
"wrong size of m_A!, m_A->Columns()="<<m_A->Columns()<<
", N="<<N);
131 utlException(m_A->Rows()!=m_b->Size(),
"wrong rows of m_A!, m_A->Rows()="<<m_A->Rows()<<
", m_b->Size()="<<m_b->Size());
134 template <
class TPrecision >
135 typename LightObject::Pointer
139 typename LightObject::Pointer loPtr = Superclass::InternalClone();
143 itkExceptionMacro(<<
"downcast to type " << this->GetNameOfClass()<<
" failed.");
151 rval->m_Step = m_Step;
152 rval->m_UseL2SolverForInitialization = m_UseL2SolverForInitialization;
153 rval->m_L2Solver = m_L2Solver->Clone();
157 template <
class TPrecision >
163 Superclass::Initialize(xInitial);
164 if (xInitial.
Size()==0)
166 utlException(!m_UseL2SolverForInitialization,
"need to set m_UseL2SolverForInitialization");
167 if (m_L2Solver->GetLambda()->Size()==0)
168 SetwForInitialization(this->m_w);
170 this->m_x = m_L2Solver->Getx();
173 if (this->GetDebug())
177 template <
class TPrecision >
183 this->VerifyInputs();
184 Initialize(xInitial);
189 template <
class TPrecision >
197 utl::vSub(e->Size(), e->GetData(), m_b->GetData(), e->GetData());
201 utl::vMul(e->Size(), m_w->GetData(), (
double*)xx->
GetData(), tmp.GetData());
207 template <
class TPrecision >
212 ValueType fValue = EvaluateCostFunction(), changePercentage=0, changePercentage_x=0;
213 this->m_CostFunction.push_back(fValue);
214 int size = this->m_CostFunction.size();
218 changePercentage = (this->m_CostFunction[size-2] - this->m_CostFunction[size-1])/this->m_CostFunction[size-2];
219 double xOldNorm =
utl::cblas_nrm2(this->m_xOld->Size(), this->m_xOld->GetData(), 1);
222 utl::vSub(this->m_x.Size(), this->m_x.GetData(), m_xOld->GetData(), tmp.GetData());
223 changePercentage_x =
utl::cblas_nrm2(tmp.Size(), tmp.GetData(), 1) / xOldNorm;
225 if (changePercentage <= this->m_MinRelativeChangeOfCostFunction && changePercentage>=0 && changePercentage_x<=this->m_MinRelativeChangeOfPrimalResidual)
227 this->m_UpdateInformation = Self::CONTINUE;
228 this->m_NumberOfChangeLessThanThreshold++;
232 this->m_UpdateInformation = Self::CONTINUE;
233 this->m_NumberOfChangeLessThanThreshold = 0;
235 if (this->m_NumberOfChangeLessThanThreshold == 3)
236 this->m_UpdateInformation = Self::STOP_MIN_CHANGE;
239 this->m_UpdateInformation = Self::CONTINUE;
248 template <
class TPrecision >
256 double t = 1.0, t_new=1.0, func=-1.0;
257 VectorType w_new = (*m_w) % m_Step, tmp(xg.Size());
258 this->m_CostFunction.push_back(EvaluateCostFunction());
260 int N = GetXDimension();
261 this->m_NumberOfIterations=0;
262 int num_change_less = 0;
263 m_xOld->ReSize(this->m_x.Size());
264 while ( this->m_NumberOfIterations <= this->m_MaxNumberOfIterations )
267 utl::cblas_copy(this->m_x.Size(), this->m_x.GetData(), 1, m_xOld->GetData(), 1);
268 if (this->GetDebug())
270 std::cout <<
"iter = " << this->m_NumberOfIterations << std::endl << std::flush;
275 utl::vSub(tmp.Size(), tmp.GetData(), m_Atb->GetData(), tmp.GetData());
279 for (
int j = 0; j < N; j += 1 )
281 xg[j] = xg[j]>w_new[j] ? (xg[j]-w_new[j]) : (xg[j]<-w_new[j] ? (xg[j]+w_new[j]) : 0 );
283 t_new = 0.5+0.5*std::sqrt(1+4*t*t);
285 utl::vSub(xg.Size(), xg.GetData(), this->m_x.GetData(), tmp.GetData());
299 utl::cblas_copy(this->m_x.Size(), xg.GetData(), 1, this->m_x.GetData(), 1);
302 this->m_NumberOfIterations++;
304 HistoryUpdateAndConvergenceCheck();
305 if (this->m_UpdateInformation==Self::STOP_MIN_CHANGE)
310 template <
class TPrecision>
315 Superclass::PrintSelf(os, indent);
319 PrintVar2(
true, m_Step, m_UseL2SolverForInitialization, os<<indent);
321 m_L2Solver->Print(os<<indent<<
"m_L2Solver for initialization = ");
NDArray is a N-Dimensional array class (row-major, c version)
void Setb(const VectorPointer &b)
void vMul(int n, T *vecIn, T *vecIn2, T *vecOut)
interface to v*Mul
void SetwForInitialization(const VectorPointer &w)
Base class for some optimization solvers using primal-dual updates.
void ProductUtlXtX(const utl::NDArray< T, 2 > &A, utl::NDArray< T, 2 > &C, const double alpha=1.0, const double beta=0.0)
void vSub(int n, T *vecIn, T *vecIn2, T *vecOut)
interface to v*Sub
void cblas_scal< double >(const INTT N, const double alpha, double *X, const INTT incX)
void cblas_copy(const INTT N, const T *X, const INTT incX, T *Y, const INTT incY)
helper functions specifically used in dmritool
Superclass::MatrixPointer MatrixPointer
ValueType EvaluateCostFunction(const VectorType &x=VectorType()) const ITK_OVERRIDE
void VerifyInputs() const ITK_OVERRIDE
void PrintUtlMatrix(const NDArray< T, 2 > &mat, const std::string &str="", const char *separate=" ", std::ostream &os=std::cout)
#define utlException(cond, expout)
void vAdd(int n, T *vecIn, T *vecIn2, T *vecOut)
interface to v*Add
L1RegularizedLeastSquaresFISTASolver()
Superclass::ValueType ValueType
Superclass::VectorPointer VectorPointer
virtual LightObject::Pointer InternalClone() const ITK_OVERRIDE
T cblas_asum(const INTT N, const T *X, const INTT incX)
void Initialize(const VectorType &xInitial=VectorType()) ITK_OVERRIDE
void PrintSelf(std::ostream &os, Indent indent) const ITK_OVERRIDE
void Iterate() ITK_OVERRIDE
#define utlShowPosition(cond)
void Setw(const VectorPointer &w)
void PrintUtlVector(const NDArray< T, 1 > &vec, const std::string &str="", const char *separate=" ", std::ostream &os=std::cout, bool showStats=true)
void SetA(const MatrixPointer &mat)
T cblas_nrm2(const INTT N, const T *X, const INTT incX)
void HistoryUpdateAndConvergenceCheck() ITK_OVERRIDE
void ProductUtlMv(const utl::NDArray< T, 2 > &A, const utl::NDArray< T, 1 > &b, utl::NDArray< T, 1 > &c, const double alpha=1.0, const double beta=0.0)
void Solve(const VectorType &xInitial=VectorType()) ITK_OVERRIDE
#define PrintVar2(cond, var1, var2, os)
Base class for some optimization solvers using primal-dual updates.
SmartPointer< Self > Pointer