19 #ifndef __itkL2RegularizedLeastSquaresSolver_hxx 20 #define __itkL2RegularizedLeastSquaresSolver_hxx 31 template <
class TPrecision>
39 m_ConditionNumber = -1;
41 m_IsLambdaSymmetric =
true;
44 template <
class TPrecision>
49 itkDebugMacro(
"setting A to " << *mat);
50 if ( *this->m_A != *mat )
61 template <
class TPrecision>
66 itkDebugMacro(
"setting Lambda to " << *mat);
67 if ( *this->m_Lambda != *mat )
70 *this->m_Lambda = *mat;
72 m_IsLambdaSymmetric = m_Lambda->IsSymmetric();
78 template <
class TPrecision>
83 Superclass::VerifyInputs();
84 int N = GetXDimension();
87 utlGlobalException(M!=m_b->Size(),
"wrong size of m_b! m_A->rows=()"<<m_A->Rows() <<
", m_b->Size()="<<m_b->Size());
88 if (m_Lambda->Size()>0)
90 utlGlobalException(m_Lambda->Rows()!=m_Lambda->Columns(),
"m_Lambda needs to be square");
91 utlGlobalException(m_Lambda->Rows()!=N,
"wrong size of m_Lambda! m_Lambda->Rows()="<<m_Lambda->Rows() <<
", N="<<N);
95 template <
class TPrecision>
107 template <
class TPrecision>
112 Superclass::Initialize(xInitial);
118 if (m_Lambda->Size()>0)
121 m_ConditionNumber = m_LS->GetInfNorm();
123 if (m_IsLambdaSymmetric)
124 *tmp = m_LS->PInverseSymmericMatrix();
126 *tmp = m_LS->PInverseMatrix();
127 m_ConditionNumber *= tmp->GetInfNorm();
132 template <
class TPrecision>
138 Initialize(xInitial);
142 template <
class TPrecision>
150 ValueType cost = utl::ToVector<double>(tmp - *m_b)->GetSquaredTwoNorm();
151 if (m_Lambda->Size()>0)
159 template <
class TPrecision >
160 typename LightObject::Pointer
164 typename LightObject::Pointer loPtr = Superclass::InternalClone();
168 itkExceptionMacro(<<
"downcast to type " << this->GetNameOfClass()<<
" failed.");
172 *rval->m_Lambda = *m_Lambda;
174 rval->m_IsLambdaSymmetric = m_IsLambdaSymmetric;
175 rval->m_ConditionNumber = m_ConditionNumber;
179 template <
class TPrecision>
184 Superclass::PrintSelf(os, indent);
190 os << indent <<
"m_ConditionNumber = " << m_ConditionNumber << std::endl << std::flush;
191 os << indent <<
"m_IsLambdaSymmetric = " << m_IsLambdaSymmetric << std::endl << std::flush;
NDArray is an N-dimensional array class (row-major, C version)
Superclass::ValueType ValueType
Superclass::MatrixPointer MatrixPointer
SmartPointer< Self > Pointer
void ProductUtlXtX(const utl::NDArray< T, 2 > &A, utl::NDArray< T, 2 > &C, const double alpha=1.0, const double beta=0.0)
helper functions specifically used in dmritool
void SetLambda(const MatrixPointer &mat)
void PrintUtlMatrix(const NDArray< T, 2 > &mat, const std::string &str="", const char *separate=" ", std::ostream &os=std::cout)
virtual LightObject::Pointer InternalClone() const ITK_OVERRIDE
void Clear() ITK_OVERRIDE
void ProductUtlvM(const utl::NDArray< T, 1 > &b, const utl::NDArray< T, 2 > &A, utl::NDArray< T, 1 > &c, const double alpha=1.0, const double beta=0.0)
void Initialize(const VectorType &xInitial=VectorType()) ITK_OVERRIDE
ValueType EvaluateCostFunction(const VectorType &x=VectorType()) const ITK_OVERRIDE
#define utlGlobalException(cond, expout)
L2RegularizedLeastSquaresSolver()
void VerifyInputs() const ITK_OVERRIDE
void PrintSelf(std::ostream &os, Indent indent) const ITK_OVERRIDE
Solve a least-squares problem with L2 regularization
ValueType InnerProduct(const NDArrayBase< T, Dim > &vec) const
void Solve(const VectorType &xInitial=VectorType()) ITK_OVERRIDE
void ProductUtlMMt(const utl::NDArray< T, 2 > &A, const utl::NDArray< T, 2 > &B, utl::NDArray< T, 2 > &C, const double alpha=1.0, const double beta=0.0)
void PrintUtlVector(const NDArray< T, 1 > &vec, const std::string &str="", const char *separate=" ", std::ostream &os=std::cout, bool showStats=true)
void SetA(const MatrixPointer &mat)
void ProductUtlMv(const utl::NDArray< T, 2 > &A, const utl::NDArray< T, 1 > &b, utl::NDArray< T, 1 > &c, const double alpha=1.0, const double beta=0.0)
Base class for some optimization solvers using primal-dual updates.