DMRITool  v0.1.1-139-g860d86b4
Diffusion MRI Tool
mexSphericalPolarFourierImaging.cxx
Go to the documentation of this file.
1 
18 #include "mex.h"
19 #include "utl.h"
20 #include "utlMEX.h"
21 
23 #include "itkDWIReader.h"
29 
// MEX worker that configures and runs the spherical polar Fourier (SPF)
// estimation filter; templated on the element type T of the input arrays.
//
// prhs[0]: 4-D array of DWI data; the 4th dimension is the number of DWI volumes.
// prhs[1]: gradient orientations, one row per DWI volume, 3 columns.
// prhs[2]: b-values, one row per DWI volume, 1 column.
// prhs[3]: options struct; its fields are read by name below, each with the
//          default given in the corresponding GetScalarStructDef call.
// plhs[0]: receives the filter output converted to an mxArray.
// nlhs, nrhs: not read in the visible body (nlhs appears only in commented-out code).
//
// NOTE(review): this listing does not show the declarations of
// SPFIFilterBaseType, SPFIFilterType, or the L1SolverType used in the FISTA_LS
// branch -- confirm them against the complete source file.
30 template <typename T>
31  inline void callFunction(mxArray* plhs[], const mxArray* prhs[],
32  const int nlhs,const int nrhs)
33 {
    // The first three inputs must pass utl::mexCheckType<T>.
34  utlGlobalException(!utl::mexCheckType<T>(prhs[0]),"type of argument 1 is not consistent");
35  utlGlobalException(!utl::mexCheckType<T>(prhs[1]),"type of argument 2 is not consistent");
36  utlGlobalException(!utl::mexCheckType<T>(prhs[2]),"type of argument 3 is not consistent");
37 
38  typedef itk::VectorImage<T, 3> VectorImageType;
39  typedef itk::Image<double, 3> ImageType;
41 
42  utlException(mxGetNumberOfDimensions(prhs[0])!=4, "the input should have 4 dimension");
43 
    // "verbose" is also written to the global utl::LogLevel.
44  int verbose = utl::GetScalarStructDef<int>(prhs[3],"verbose",1);
45  utl::LogLevel = verbose;
46 
    // Only numberOfDWIs is used further down in the visible code; Nx, Ny and Nz are not read again.
47  const mwSize* dimsDWIs = mxGetDimensions(prhs[0]);
48  int Nx = static_cast<int>(dimsDWIs[0]);
49  int Ny = static_cast<int>(dimsDWIs[1]);
50  int Nz = static_cast<int>(dimsDWIs[2]);
51  int numberOfDWIs = static_cast<int>(dimsDWIs[3]);
52 
53 
    // prhs[1] must be numberOfDWIs x 3.
54  const mwSize* dimsOrientation = mxGetDimensions(prhs[1]);
55  utlException(dimsOrientation[0]!=numberOfDWIs, "wrong number of gradients");
56  utlException(dimsOrientation[1]!=3, "the column of gradient should be 3");
57 
    // prhs[2] must be numberOfDWIs x 1.
58  const mwSize* dimsBVec = mxGetDimensions(prhs[2]);
59  utlException(dimsBVec[0]!=numberOfDWIs, "wrong number of bVec");
60  utlException(dimsBVec[1]!=1, "the column of bVec should be 1");
61 
    // Read the options struct. "sh" and "ra" have no usable default (-1) and must be set to a positive value.
62  double MD0 = utl::GetScalarStructDef<double>(prhs[3],"MD0",-1.0);
63  double tau = utl::GetScalarStructDef<double>(prhs[3],"tau",ONE_OVER_4_PI_2);
64  double scale = utl::GetScalarStructDef<double>(prhs[3],"scale",-1.0);
65  int sh = utl::GetScalarStructDef<int>(prhs[3],"sh",-1);
66  utlGlobalException(sh<=0, "need to set sh");
67  int ra = utl::GetScalarStructDef<int>(prhs[3],"ra",-1);
68  utlGlobalException(ra<=0, "need to set ra");
69  std::string estimation = utl::GetScalarStructDef<std::string>(prhs[3],"estimation","LS");
70  std::string solver = utl::GetScalarStructDef<std::string>(prhs[3],"solver","SPAMS");
71  double lambdaSH = utl::GetScalarStructDef<double>(prhs[3],"lambdaSH",0.0);
72  double lambdaRA = utl::GetScalarStructDef<double>(prhs[3],"lambdaRA",0.0);
73  double lambdaL1 = utl::GetScalarStructDef<double>(prhs[3],"lambdaL1",0.0);
74  // int odfOrder = utl::GetScalarStructDef<double>(prhs[3],"odfOrder",2);
75  // double radius = utl::GetScalarStructDef<double>(prhs[3],"radius",0.015);
    // NOTE(review): numericalB0Weight, dictionaryArray, energyArray and energyPower are read here
    // but never used in the visible body -- confirm whether they are still meant to reach the filter.
76  mxArray* mdImageArray = utl::GetArrayStruct(prhs[3], "mdImage" );
77  double numericalB0Weight = utl::GetScalarStructDef<double>(prhs[3],"numericalB0Weight",-1.0);
78  mxArray* dictionaryArray = utl::GetArrayStruct(prhs[3], "dictionary" );
79  mxArray* energyArray = utl::GetArrayStruct(prhs[3], "energy" );
80  double energyPower = utl::GetScalarStructDef<double>(prhs[3],"energyPower",1.0);
81  int maxIter = utl::GetScalarStructDef<int>(prhs[3],"maxIter",1000);
82  double minChange = utl::GetScalarStructDef<double>(prhs[3],"minChange",0.0001);
83  mxArray* maskArray = utl::GetArrayStruct(prhs[3], "mask" );
    // NOTE(review): the default literal is -1.0 although the field is read as int.
84  int thread = utl::GetScalarStructDef<int>(prhs[3],"thread",-1.0);
85 
86 
    // Create the filter and select the original SPF basis.
87  typename SPFIFilterBaseType::Pointer spfiFilter=NULL;
89  spfiFilter = SPFIFilterType::New();
90  std::cout << "Use SPF basis" << std::endl << std::flush;
91  spfiFilter->SetIsOriginalBasis(true);
92 
    // The mask is optional: it is only set when the "mask" field yielded a non-null array.
93  if (maskArray)
94  {
95  typename ImageType::Pointer maskImage = ImageType::New();
96  itk::GetITKImageFromMXArray(maskArray, maskImage);
97  spfiFilter->SetMaskImage(maskImage);
98  }
    // Convert the DWI array to an ITK vector image and use it as the filter input.
99  typename VectorImageType::Pointer dwiImage = VectorImageType::New();
100  itk::GetITKVectorImageFromMXArray(prhs[0], dwiImage);
101  spfiFilter->SetInput(dwiImage);
102 
    // MD0 and tau are forwarded only when positive; otherwise the filter keeps whatever it already has.
103  if (MD0>0)
104  spfiFilter->SetMD0(MD0);
105  //NOTE: set tau before spfiFilter->GetSamplingSchemeQSpace()->SetBVector(bVec) because it uses tau to convert b values to q values
106  if (tau>0)
107  spfiFilter->GetSamplingSchemeQSpace()->SetTau(tau);
108 
109  typename SPFIFilterBaseType::MatrixPointer grad(new typename SPFIFilterBaseType::MatrixType());
110  utl::GetUtlMatrixFromMXArray(prhs[1], grad.get() );
111  *grad = utl::CartesianToSpherical(*grad); // convert to spherical format
112  spfiFilter->GetSamplingSchemeQSpace()->SetOrientationsSpherical(grad);
113 
114  typename SPFIFilterBaseType::STDVectorPointer bVec(new typename SPFIFilterBaseType::STDVectorType());
115  utl::GetSTDVectorFromMXArray(prhs[2], bVec.get() );
116  spfiFilter->GetSamplingSchemeQSpace()->SetBVector(bVec);
117  spfiFilter->SetSHRank(sh);
118  spfiFilter->SetRadialRank(ra);
    // scale is passed through unconditionally, including its default of -1.0.
119  spfiFilter->SetBasisScale(scale);
120  spfiFilter->SetIsAnalyticalB0(true);
121 
122  spfiFilter->SetLambdaSpherical(lambdaSH);
123  spfiFilter->SetLambdaRadial(lambdaRA);
124  spfiFilter->SetLambdaL1(lambdaL1);
125 
    // Select the estimation type: "LS", "L1_2" or "L1_DL"; any other string is rejected in the final else.
126  if (estimation=="LS")
127  {
128  spfiFilter->SetEstimationType(SPFIFilterBaseType::LS);
129  }
130  else if (estimation=="L1_2" || estimation=="L1_DL")
131  {
132 
133  if (estimation=="L1_2")
134  {
    // NOTE(review): this only fails when BOTH lambdas are <= 0, while the message says
    // "lambdaSH and lambdaRA" -- confirm whether || was intended.
135  utlGlobalException(lambdaSH<=0 && lambdaRA<=0, "need to set lambdaSH and lambdaRA when estimation=\"L1_2\".");
136  spfiFilter->SetEstimationType(SPFIFilterBaseType::L1_2);
137  }
138  else if (estimation=="L1_DL")
139  {
140  utlGlobalException(lambdaL1<=0, "need to set lambdaL1 when estimation=\"L1_DL\".");
141  spfiFilter->SetEstimationType(SPFIFilterBaseType::L1_DL);
142  }
143 
    // L1 solver selection. NOTE(review): a solver string other than "FISTA_LS" or "SPAMS"
    // matches neither branch, so no L1 solver is configured here and no error is raised.
144  if (solver=="FISTA_LS")
145  {
147  L1SolverType::Pointer l1Sol = L1SolverType::New();
    // NOTE(review): the argument is always true inside this branch.
148  l1Sol->SetUseL2SolverForInitialization(solver=="FISTA_LS");
149  l1Sol->SetMaxNumberOfIterations(maxIter);
    // minChange is used for both stopping thresholds.
150  l1Sol->SetMinRelativeChangeOfCostFunction(minChange);
151  l1Sol->SetMinRelativeChangeOfPrimalResidual(minChange);
152  spfiFilter->SetL1FISTASolver(l1Sol);
153  spfiFilter->SetL1SolverType(SPFIFilterBaseType::FISTA_LS);
154  }
155  if (solver=="SPAMS")
156  {
157  typedef itk::SpamsWeightedLassoSolver<double> L1SolverType;
158  L1SolverType::Pointer l1Sol = L1SolverType::New();
159  spfiFilter->SetL1SpamsSolver(l1Sol);
160  spfiFilter->SetL1SolverType(SPFIFilterBaseType::SPAMS);
161  }
162  }
163  else
164  utlGlobalException(true, "wrong estimation type");
165 
    // The MD image is optional, like the mask.
166  if (mdImageArray)
167  {
168  typename ImageType::Pointer mdImage = ImageType::New();
169  itk::GetITKImageFromMXArray(mdImageArray, mdImage);
170  spfiFilter->SetMDImage(mdImage);
171  }
172 
    // The energy power is hard-coded to 1.0; the "energyPower" option read above is not applied.
173  // spfiFilter->SetBasisEnergyPowerDL(energyPower);
174  spfiFilter->SetBasisEnergyPowerDL(1.0);
175 
    // The thread count is forwarded only when positive (the default is negative).
177  if (thread>0)
178  spfiFilter->SetNumberOfThreads(thread);
179  spfiFilter->SetDebug(utl::LogLevel>=LOG_DEBUG);
180  spfiFilter->SetLogLevel(utl::LogLevel);
181 
    // Run the filter and convert its output into plhs[0].
182  std::cout << "SPF estimation starts" << std::endl << std::flush;
183  spfiFilter->Update();
184  std::cout << "SPF estimation ends" << std::endl << std::flush;
185 
186  typename VectorImageType::Pointer spf = spfiFilter->GetOutput();
187  itk::GetMXArrayFromITKVectorImage(spf, plhs[0]);
188 
    // A second output (scale image) is disabled.
189  // if (nlhs==2)
190  // {
191  // typename ImageType::Pointer scaleImage = spfiFilter->GetScaleImage();
192  // itk::GetMXArrayFromITKImage(scaleImage, plhs[1]);
193  // }
194 }
195 
196 
// MATLAB MEX entry point.
// Requires exactly four right-hand-side arguments and exactly one left-hand-side
// output, then forwards everything to callFunction<double>.
197 void mexFunction(int nlhs, mxArray *plhs[],
198  int nrhs, const mxArray *prhs[])
199 {
200  utlGlobalException(nrhs!=4, "Bad number of inputs arguments");
201  utlGlobalException(nlhs!=1, "Bad number of outputs arguments");
202 
    // Dispatch on single precision is disabled; only the double instantiation is called.
203  // if (mxGetClassID(prhs[0]) == mxSINGLE_CLASS)
204  // callFunction<float>(plhs,prhs,nlhs,nrhs);
205  // else
206  callFunction<double>(plhs,prhs,nlhs,nrhs);
207 }
mxArray * GetArrayStruct(const mxArray *pr_struct, const char *name)
Definition: mexutils.h:333
void solver(const Matrix< T > &X, const AbstractMatrixB< T > &D, const Matrix< T > &alpha0, Matrix< T > &alpha, const ParamFISTA< T > &param1, Matrix< T > &optim_info, const GraphStruct< T > *graph_st=NULL, const TreeStruct< T > *tree_st=NULL, const GraphPathStruct< T > *graph_path_st=NULL)
AbstractMatrixB is basically either SpMatrix or Matrix.
Definition: fista.h:3436
helper functions specifically used in dmritool
solve least square problem with L1 regularization using FISTA
void GetMXArrayFromITKVectorImage(const SmartPointer< VectorImage< T, VImageDimension > > &image, mxArray *&pr)
Definition: mexITK.h:96
used for debug information. this->GetDebug()
Definition: utlCoreMacro.h:193
#define utlException(cond, expout)
Definition: utlCoreMacro.h:548
void GetITKVectorImageFromMXArray(const mxArray *pr, SmartPointer< VectorImage< T, 3 > > &image)
Definition: mexITK.h:128
int mwSize
Definition: utils.h:38
NDArray< T, 2 > CartesianToSpherical(const NDArray< T, 2 > &in)
#define utlGlobalException(cond, expout)
Definition: utlCoreMacro.h:372
#define ONE_OVER_4_PI_2
Definition: utlCoreMacro.h:63
static int LogLevel
Definition: utlCoreMacro.h:203
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
static void InitializeThreadedLibraries(const int numThreads)
Definition: utl.h:327
estimate the coefficients of generalized Spherical Polar Fourier basis which can be separated into dif...
solve weighted LASSO using spams
void GetUtlMatrixFromMXArray(const mxArray *pr, NDArray< T, 2 > *mat)
Definition: utlMEX.h:33
itk::VectorImage< ScalarType, 3 > VectorImageType
Definition: 4DImageMath.cxx:30
void GetSTDVectorFromMXArray(const mxArray *pr, std::vector< T > *vec)
Definition: mexSTD.h:30
void GetITKImageFromMXArray(const mxArray *pr, SmartPointer< Image< T, VImageDimension > > &image)
Definition: mexITK.h:55
void callFunction(mxArray *plhs[], const mxArray *prhs[], const int nlhs, const int nrhs)