DMRITool  v0.1.1-139-g860d86b4
Diffusion MRI Tool
itkSphericalPolarFourierEstimationImageFilter.hxx
Go to the documentation of this file.
1 
18 #ifndef __itkSphericalPolarFourierEstimationImageFilter_hxx
19 #define __itkSphericalPolarFourierEstimationImageFilter_hxx
20 
22 #include "utl.h"
23 
24 namespace itk
25 {
26 
// Constructor: allocates empty helper matrices/vectors and sets the default parameters.
// NOTE(review): listing lines 28-29 (the constructor signature and the start of the
// member-initializer list) are missing from this extract; only the tail is visible.
27 template< class TInputImage, class TOutputImage >
30  m_BasisCombinationMatrix(new MatrixType()),
31  m_BasisEnergyDL(new VectorType()),
32  m_BasisMatrixForB0(new MatrixType())
33 {
// -1.0 marks the scale as not set: VerifyInputParameters() rejects negative scales, and
// SetBasisScale() with a non-positive value falls back to ComputeScale(true).
34  m_BasisScale = -1.0;
// Default: least-squares estimation with every regularization weight switched off.
35  m_EstimationType = LS;
36  m_LambdaSpherical = 0;
37  m_LambdaRadial = 0;
38  m_LambdaL1 = 0;
39  m_LambdaL2 = 0;
40  m_B0Weight = 1.0;
41  m_IsAnalyticalB0 = true;
42  m_BasisEnergyPowerDL = 1.0;
43 
// Fresh image objects; VerifyInputParameters() only checks m_MDImage when it is non-empty.
44  m_MDImage=ScalarImageType::New();
45  m_ScaleImage=ScalarImageType::New();
46 
47  // m_L1SolverType = FISTA_LS;
// Default L1 solver type. The matching solver object is NOT created here (see the NULLs
// below); it must be supplied before L1 estimation is run.
48  m_L1SolverType = SPAMS;
49 
// Solvers start unset: the L2 solver is created lazily for LS estimation, and the L1
// solvers must be provided by the user.
50  m_L2Solver = NULL;
51  m_L1FISTASolver = NULL;
52  m_L1SpamsSolver = NULL;
53 
54  m_IsOriginalBasis = true;
55 
56 }
57 
// Validates the filter parameters before estimation. The Superclass call below shows this
// overrides VerifyInputParameters(); listing lines 60-61 (its signature) are missing here.
// Throws (via utlGlobalException) when:
//  - the basis scale is negative (-1.0 is the "not set" value from the constructor),
//  - the spherical or radial regularization weight is negative,
//  - m_MDImage is non-empty and its image information does not match the input image
//    (checked with isMinimalDimension=true).
58 template< class TInputImage, class TOutputImage >
59 void
62 {
63  itkShowPositionThreadedLogger(this->GetDebug());
64  Superclass::VerifyInputParameters();
65 
66  utlGlobalException(m_BasisScale<0, "negative scale");
67  utlGlobalException(m_LambdaSpherical<0 || m_LambdaRadial<0, "negative regularization parameters");
// GetInput() yields a const pointer; it is cast to a non-const InputImagePointer so it can
// be passed to VerifyImageInformation() below.
68  InputImagePointer input = const_cast<InputImageType *>(this->GetInput());
69  utlGlobalException(!IsImageEmpty(m_MDImage) && !itk::VerifyImageInformation(input,m_MDImage,true), "wrong information in m_MDImage");
70 }
71 
// Output information: the superclass sets up the output first. Then the number of
// components per output pixel is set to RankToDim().
// NOTE(review): RankToDim() is presumably the length of the coefficient vector for the
// current basis rank; confirm. Listing lines 74-75 (signature) are missing here.
72 template< class TInputImage, class TOutputImage >
73 void
76 {
77  itkShowPositionThreadedLogger(this->GetDebug());
78  Superclass::GenerateOutputInformation();
79  OutputImagePointer outputPtr = this->GetOutput();
80  unsigned int numberOfComponentsPerPixel = this->RankToDim();
81  outputPtr->SetNumberOfComponentsPerPixel(numberOfComponentsPerPixel);
82 }
83 
// Base-class scale computation: resets m_BasisScale to -1.0 (the "not set" value) and
// returns it. The setScale argument is ignored in this implementation.
// NOTE(review): presumably overridden by derived filters that compute a real scale; confirm.
// Listing line 86 (the class qualifier) is missing from this extract.
84 template< class TInputImage, class TOutputImage >
85 double
87 ::ComputeScale(const bool setScale)
88 {
89  m_BasisScale = -1.0;
90  return m_BasisScale;
91 }
92 
// Sets the basis scale.
// A positive `scale` is stored directly; otherwise the scale comes from ComputeScale(true).
// If a positive `scale` differs from the previous value by more than a relative tolerance
// of 1e-8, the filter is marked Modified and the cached basis matrices (radial, full and
// b0) are replaced by new empty matrices.
// NOTE(review): when scale<=0, the cached matrices are never invalidated, even if
// ComputeScale(true) changed m_BasisScale. Confirm that derived ComputeScale()
// implementations take care of this themselves.
// Listing line 95 (the class qualifier) is missing from this extract.
93 template< class TInputImage, class TOutputImage >
94 void
96 ::SetBasisScale(const double scale)
97 {
98  itkShowPositionThreadedLogger(this->GetDebug());
99  double scale_old = m_BasisScale;
100  if (scale>0)
101  m_BasisScale = scale;
102  else
103  this->ComputeScale(true);
104  itkDebugMacro("setting m_BasisScale to " << m_BasisScale);
105 
// The scale>0 guard also keeps m_BasisScale positive here, so the relative difference
// never divides by zero.
106  if (scale>0 && std::fabs((scale_old-m_BasisScale)/m_BasisScale)>1e-8)
107  {
108  this->Modified();
109  this->m_BasisRadialMatrix=MatrixPointer(new MatrixType());
110  this->m_BasisMatrix=MatrixPointer(new MatrixType());
111  this->m_BasisMatrixForB0=MatrixPointer(new MatrixType());
112  }
113 }
114 
// Clone support: obtains a copy from Superclass::InternalClone(), downcasts it to Self
// (throwing via itkExceptionMacro if that fails) and copies this class's parameters into it.
// - The basis-combination matrix, energy vector, b0 basis matrix, m_MDImage and
//   m_ScaleImage are shared with the clone through pointer copies.
// - Solvers that are set are duplicated with Clone().
// NOTE(review): m_BasisRadialMatrix / m_BasisMatrix are not copied here. They are accessed
// via this-> elsewhere, which suggests Superclass members; confirm Superclass clones them.
// Listing lines 117-118 (signature) are missing from this extract.
115 template< class TInputImage, class TOutputImage >
116 typename LightObject::Pointer
119 {
120  itkShowPositionThreadedLogger(this->GetDebug());
121  typename LightObject::Pointer loPtr = Superclass::InternalClone();
122 
123  typename Self::Pointer rval = dynamic_cast<Self *>(loPtr.GetPointer());
124  if(rval.IsNull())
125  {
126  itkExceptionMacro(<< "downcast to type " << this->GetNameOfClass()<< " failed.");
127  }
128  rval->m_LambdaSpherical = m_LambdaSpherical;
129  rval->m_LambdaRadial = m_LambdaRadial;
130  rval->m_LambdaL1 = m_LambdaL1;
131  rval->m_LambdaL2 = m_LambdaL2;
132  rval->m_BasisScale = m_BasisScale;
133  rval->m_IsAnalyticalB0 = m_IsAnalyticalB0;
134  rval->m_B0Weight = m_B0Weight;
135  rval->m_EstimationType = m_EstimationType;
136  rval->m_L1SolverType = m_L1SolverType;
137 
138  // NOTE: the shared_ptr members are shared with the clone instead of deep-copied. This is safe only while threads read the data and never modify it.
139  rval->m_BasisMatrixForB0 = m_BasisMatrixForB0;
140  rval->m_IsOriginalBasis = m_IsOriginalBasis;
141 
142  rval->m_BasisCombinationMatrix = m_BasisCombinationMatrix;
143  rval->m_BasisEnergyDL = m_BasisEnergyDL;
144  rval->m_BasisEnergyPowerDL = m_BasisEnergyPowerDL;
145 
// Image smart pointers are copied, so the clone refers to the same image objects.
146  rval->m_MDImage = m_MDImage;
147  rval->m_ScaleImage = m_ScaleImage;
148 
149  if (m_L2Solver)
150  rval->m_L2Solver = m_L2Solver->Clone();
151  if (m_L1FISTASolver)
152  rval->m_L1FISTASolver = m_L1FISTASolver->Clone();
153  if (m_L1SpamsSolver)
154  rval->m_L1SpamsSolver = m_L1SpamsSolver->Clone();
155 
156  return loPtr;
157 }
158 
// Preparation step run before the threaded estimation. Listing lines 161-162 (signature) are
// missing; the debug print below labels it BeforeThreadedGenerateData. It:
// - calls ConvertBVectorToQVector() when the q-space scheme has b-values but no radii,
// - validates the parameters (VerifyInputParameters),
// - in debug mode, creates per-thread loggers (only with >1 thread) and prints the filter,
// - reports the thread count, mask use and estimation type on std::cout,
// - for LS: creates the L2 solver if none is set yet,
// - for L1_2 / L1_DL: requires the L1 solver that matches m_L1SolverType; with FISTA_LS it
//   enables least-squares initialization.
// Throws (utlGlobalException) for a missing L1 solver or an unknown estimation type.
// NOTE(review): an m_L1SolverType other than FISTA_LS / SPAMS is not rejected here.
159 template< class TInputImage, class TOutputImage >
160 void
163 {
164  itkShowPositionThreadedLogger(this->GetDebug());
165  if (this->m_SamplingSchemeQSpace->GetBVector()->size()>0 && this->m_SamplingSchemeQSpace->GetRadiusVector()->size()==0)
166  this->m_SamplingSchemeQSpace->ConvertBVectorToQVector();
167 
168  this->VerifyInputParameters();
169 
170 
171  // create m_LoggerVector for multiple threads
172  if (this->GetDebug() && this->GetNumberOfThreads()>1)
173  this->CreateLoggerVector();
174 
175  if (this->GetDebug())
176  this->Print(std::cout<<"this BeforeThreadedGenerateData = ");
177 
178  std::cout << "Use " << this->GetNumberOfThreads() << " threads!" << std::endl << std::flush;
179  if (!IsImageEmpty(this->m_MaskImage))
180  std::cout << "Use a mask" << std::endl << std::flush;
181 
182  if (this->m_EstimationType==Self::LS)
183  {
184  std::cout << "Use Least Square Estimation" << std::endl << std::flush;
// A user-supplied L2 solver is kept; a default one is created only if none is set.
185  if (!this->m_L2Solver)
186  this->m_L2Solver = L2SolverType::New();
187  }
188  else if (this->m_EstimationType==Self::L1_2 || this->m_EstimationType==Self::L1_DL )
189  {
190  if (this->m_EstimationType==Self::L1_2)
191  std::cout << "Use L1 Estimation with two lambdas m_LambdaSpherical, m_LambdaRadial (L1_2)" << std::endl << std::flush;
192  if (this->m_EstimationType==Self::L1_DL)
193  std::cout << "Use L1 Estimation with learned dictionary, one m_LambdaL1" << std::endl << std::flush;
// At least one L1 solver must exist, and it must be the one selected by m_L1SolverType.
194  utlGlobalException(!this->m_L1FISTASolver && !this->m_L1SpamsSolver, "must set a L1 solver first");
195  utlGlobalException( (this->m_L1SolverType==Self::FISTA_LS) && !this->m_L1FISTASolver, "m_L1FISTASolver is needed.");
196  utlGlobalException( (this->m_L1SolverType==Self::SPAMS) && !this->m_L1SpamsSolver, "m_L1SpamsSolver is needed.");
197 
198  if (this->m_L1SolverType==Self::FISTA_LS)
199  {
200  std::cout << "Use FISTA with least square initialization (FISTA_LS)" << std::endl << std::flush;
201  this->m_L1FISTASolver->SetUseL2SolverForInitialization(true);
202  }
203  else if (this->m_L1SolverType==Self::SPAMS)
204  std::cout << "Use SPAMS for weighted lasso (SPAMS)" << std::endl << std::flush;
205  }
206  else
207  utlGlobalException(true, "wrong type");
208 }
209 
// Configures the threading of the SPAMS L1 solver before the threaded estimation
// (only in builds with DMRITOOL_USE_OPENMP).
// NOTE(review): listing lines 212-213 (signature) and 215 are not visible in this extract.
210 template< class TInputImage, class TOutputImage >
211 void
214 {
216 #ifdef DMRITOOL_USE_OPENMP
217  // if (this->m_L1SolverType==Self::SPAMS)
218  // this->m_L1SpamsSolver->SetNumberOfThreads(this->GetNumberOfThreads()>1?1:-1);
219 
220  // NOTE: SpamsSolver appears to work better with OMP_NUM_THREADS=1 than with multiple threads, even when this->GetNumberOfThreads()==1.
221  // This is possibly because, when this->GetNumberOfThreads()==1, BLAS already runs multi-threaded.
// m_L1SolverType defaults to SPAMS, but m_L1SpamsSolver defaults to NULL and is only
// required for L1 estimation. Guard against a missing solver (e.g. LS estimation) instead
// of dereferencing a null pointer.
222  if (this->m_L1SolverType==Self::SPAMS && this->m_L1SpamsSolver)
223  this->m_L1SpamsSolver->SetNumberOfThreads(1);
224 #endif
225 
226 }
227 
228 
// Prints the filter state to `os`:
// - the superclass state,
// - basis scale and original-basis flag,
// - analytical-b0 flag and b0 weight,
// - the basis combination matrix (only when it has rows),
// - one line describing the estimation type (nothing for other types),
// - the four regularization weights.
// Listing line 231 (the class qualifier) is missing from this extract.
229 template< class TInputImage, class TOutputImage >
230 void
232 ::PrintSelf(std::ostream& os, Indent indent) const
233 {
234  Superclass::PrintSelf(os, indent);
235  PrintVar2(true, m_BasisScale, m_IsOriginalBasis, os<<indent);
236  PrintVar2(true, m_IsAnalyticalB0, m_B0Weight, os<<indent);
237  if (m_BasisCombinationMatrix->Rows()!=0)
238  utl::PrintUtlMatrix(*m_BasisCombinationMatrix, "m_BasisCombinationMatrix", " ", os<<indent);
239  if (m_EstimationType==LS)
240  os << indent << "Least Square Estimation is used " << std::endl;
241  else if (m_EstimationType==L1_2)
242  os << indent << "L1 Estimation is used (two lambdas, Laplace-Beltrami like regularization)" << std::endl;
243  else if (m_EstimationType==L1_DL)
244  os << indent << "L1 Estimation with learned dictionary is used " << std::endl;
245  PrintVar4(true, m_LambdaSpherical, m_LambdaRadial, m_LambdaL1, m_LambdaL2, os<<indent);
246 
247 }
248 
249 }
250 
251 #endif
252 
253 
254 
base filter for estimation of diffusion models
helper functions specifically used in dmritool
bool IsImageEmpty(const SmartPointer< ImageType > &image)
Definition: utlITK.h:435
void Print(Args...args)
Definition: utlCore11.h:87
void PrintUtlMatrix(const NDArray< T, 2 > &mat, const std::string &str="", const char *separate=" ", std::ostream &os=std::cout)
#define utlGlobalException(cond, expout)
Definition: utlCoreMacro.h:372
bool VerifyImageInformation(const SmartPointer< Image1Type > &image1, const SmartPointer< Image2Type > &image2, const bool isMinimalDimension=false)
Definition: utlITK.h:850
estimate the coefficients of generalized Spherical Polar Fourier basis which can be separated into dif...
static void InitializeThreadedLibraries(const int numThreads)
Definition: utl.h:327
#define itkShowPositionThreadedLogger(cond)
Definition: utlITKMacro.h:192
void PrintSelf(std::ostream &os, Indent indent) const ITK_OVERRIDE
#define PrintVar4(cond, var1, var2, var3, var4, os)
Definition: utlCoreMacro.h:470
#define PrintVar2(cond, var1, var2, os)
Definition: utlCoreMacro.h:454