18 #ifndef __itkSphericalPolarFourierEstimationImageFilter_hxx 19 #define __itkSphericalPolarFourierEstimationImageFilter_hxx 27 template<
class TInputImage,
class TOutputImage >
35 m_EstimationType = LS;
36 m_LambdaSpherical = 0;
41 m_IsAnalyticalB0 =
true;
42 m_BasisEnergyPowerDL = 1.0;
44 m_MDImage=ScalarImageType::New();
45 m_ScaleImage=ScalarImageType::New();
48 m_L1SolverType = SPAMS;
51 m_L1FISTASolver = NULL;
52 m_L1SpamsSolver = NULL;
54 m_IsOriginalBasis =
true;
58 template<
class TInputImage,
class TOutputImage >
64 Superclass::VerifyInputParameters();
67 utlGlobalException(m_LambdaSpherical<0 || m_LambdaRadial<0,
"negatvie regularization parameters");
72 template<
class TInputImage,
class TOutputImage >
78 Superclass::GenerateOutputInformation();
80 unsigned int numberOfComponentsPerPixel = this->RankToDim();
81 outputPtr->SetNumberOfComponentsPerPixel(numberOfComponentsPerPixel);
84 template<
class TInputImage,
class TOutputImage >
93 template<
class TInputImage,
class TOutputImage >
99 double scale_old = m_BasisScale;
101 m_BasisScale = scale;
103 this->ComputeScale(
true);
104 itkDebugMacro(
"setting m_BasisScale to " << m_BasisScale);
106 if (scale>0 && std::fabs((scale_old-m_BasisScale)/m_BasisScale)>1e-8)
115 template<
class TInputImage,
class TOutputImage >
116 typename LightObject::Pointer
121 typename LightObject::Pointer loPtr = Superclass::InternalClone();
126 itkExceptionMacro(<<
"downcast to type " << this->GetNameOfClass()<<
" failed.");
128 rval->m_LambdaSpherical = m_LambdaSpherical;
129 rval->m_LambdaRadial = m_LambdaRadial;
130 rval->m_LambdaL1 = m_LambdaL1;
131 rval->m_LambdaL2 = m_LambdaL2;
132 rval->m_BasisScale = m_BasisScale;
133 rval->m_IsAnalyticalB0 = m_IsAnalyticalB0;
134 rval->m_B0Weight = m_B0Weight;
135 rval->m_EstimationType = m_EstimationType;
136 rval->m_L1SolverType = m_L1SolverType;
139 rval->m_BasisMatrixForB0 = m_BasisMatrixForB0;
140 rval->m_IsOriginalBasis = m_IsOriginalBasis;
142 rval->m_BasisCombinationMatrix = m_BasisCombinationMatrix;
143 rval->m_BasisEnergyDL = m_BasisEnergyDL;
144 rval->m_BasisEnergyPowerDL = m_BasisEnergyPowerDL;
146 rval->m_MDImage = m_MDImage;
147 rval->m_ScaleImage = m_ScaleImage;
150 rval->m_L2Solver = m_L2Solver->Clone();
152 rval->m_L1FISTASolver = m_L1FISTASolver->Clone();
154 rval->m_L1SpamsSolver = m_L1SpamsSolver->Clone();
159 template<
class TInputImage,
class TOutputImage >
165 if (this->m_SamplingSchemeQSpace->GetBVector()->size()>0 && this->m_SamplingSchemeQSpace->GetRadiusVector()->size()==0)
166 this->m_SamplingSchemeQSpace->ConvertBVectorToQVector();
168 this->VerifyInputParameters();
172 if (this->GetDebug() && this->GetNumberOfThreads()>1)
173 this->CreateLoggerVector();
175 if (this->GetDebug())
176 this->
Print(std::cout<<
"this BeforeThreadedGenerateData = ");
178 std::cout <<
"Use " << this->GetNumberOfThreads() <<
" threads!" << std::endl << std::flush;
180 std::cout <<
"Use a mask" << std::endl << std::flush;
182 if (this->m_EstimationType==Self::LS)
184 std::cout <<
"Use Least Square Estimation" << std::endl << std::flush;
185 if (!this->m_L2Solver)
186 this->m_L2Solver = L2SolverType::New();
188 else if (this->m_EstimationType==Self::L1_2 || this->m_EstimationType==Self::L1_DL )
190 if (this->m_EstimationType==Self::L1_2)
191 std::cout <<
"Use L1 Estimation with two lambdas m_LambdaSpherical, m_LambdaRadial (L1_2)" << std::endl << std::flush;
192 if (this->m_EstimationType==Self::L1_DL)
193 std::cout <<
"Use L1 Estimation with learned dictionary, one m_LambdaL1" << std::endl << std::flush;
194 utlGlobalException(!this->m_L1FISTASolver && !this->m_L1SpamsSolver,
"must set a L1 solver first");
195 utlGlobalException( (this->m_L1SolverType==Self::FISTA_LS) && !this->m_L1FISTASolver,
"m_L1FISTASolver is needed.");
196 utlGlobalException( (this->m_L1SolverType==Self::SPAMS) && !this->m_L1SpamsSolver,
"m_L1SpamsSolver is needed.");
198 if (this->m_L1SolverType==Self::FISTA_LS)
200 std::cout <<
"Use FISTA with least square initialization (FISTA_LS)" << std::endl << std::flush;
201 this->m_L1FISTASolver->SetUseL2SolverForInitialization(
true);
203 else if (this->m_L1SolverType==Self::SPAMS)
204 std::cout <<
"Use SPAMS for weighted lasso (SPAMS)" << std::endl << std::flush;
210 template<
class TInputImage,
class TOutputImage >
216 #ifdef DMRITOOL_USE_OPENMP 222 if (this->m_L1SolverType==Self::SPAMS)
223 this->m_L1SpamsSolver->SetNumberOfThreads(1);
229 template<
class TInputImage,
class TOutputImage >
234 Superclass::PrintSelf(os, indent);
235 PrintVar2(
true, m_BasisScale, m_IsOriginalBasis, os<<indent);
236 PrintVar2(
true, m_IsAnalyticalB0, m_B0Weight, os<<indent);
237 if (m_BasisCombinationMatrix->Rows()!=0)
238 utl::PrintUtlMatrix(*m_BasisCombinationMatrix,
"m_BasisCombinationMatrix",
" ", os<<indent);
239 if (m_EstimationType==LS)
240 os << indent <<
"Least Square Estimation is used " << std::endl;
241 else if (m_EstimationType==L1_2)
242 os << indent <<
"L1 Estimation is used (two lambdas, Laplace-Beltrami like regularization)" << std::endl;
243 else if (m_EstimationType==L1_DL)
244 os << indent <<
"L1 Estimation with learned dictionary is used " << std::endl;
245 PrintVar4(
true, m_LambdaSpherical, m_LambdaRadial, m_LambdaL1, m_LambdaL2, os<<indent);
SphericalPolarFourierEstimationImageFilter()
void GenerateOutputInformation() ITK_OVERRIDE
utl_shared_ptr< MatrixType > MatrixPointer
base filter for estimation of diffusion models
helper functions specifically used in dmritool
Superclass::InputImagePointer InputImagePointer
bool IsImageEmpty(const SmartPointer< ImageType > &image)
void PrintUtlMatrix(const NDArray< T, 2 > &mat, const std::string &str="", const char *separate=" ", std::ostream &os=std::cout)
SmartPointer< Self > Pointer
void InitializeThreadedLibraries() ITK_OVERRIDE
virtual double ComputeScale(const bool setScale=true)
#define utlGlobalException(cond, expout)
bool VerifyImageInformation(const SmartPointer< Image1Type > &image1, const SmartPointer< Image2Type > &image2, const bool isMinimalDimension=false)
estimate the coeffcients of generalized Spherical Polar Fourier basis which can be separated into dif...
Superclass::OutputImagePointer OutputImagePointer
static void InitializeThreadedLibraries(const int numThreads)
void BeforeThreadedGenerateData() ITK_OVERRIDE
#define itkShowPositionThreadedLogger(cond)
void PrintSelf(std::ostream &os, Indent indent) const ITK_OVERRIDE
virtual void SetBasisScale(const double scale)
Superclass::InputImageType InputImageType
LightObject::Pointer InternalClone() const ITK_OVERRIDE
#define PrintVar4(cond, var1, var2, var3, var4, os)
#define PrintVar2(cond, var1, var2, os)
void VerifyInputParameters() const ITK_OVERRIDE