32 const int nlhs,
const int nrhs)
34 utlGlobalException(!utl::mexCheckType<T>(prhs[0]),
"type of argument 1 is not consistent");
35 utlGlobalException(!utl::mexCheckType<T>(prhs[1]),
"type of argument 2 is not consistent");
36 utlGlobalException(!utl::mexCheckType<T>(prhs[2]),
"type of argument 3 is not consistent");
39 typedef itk::Image<double, 3> ImageType;
42 utlException(mxGetNumberOfDimensions(prhs[0])!=4,
"the input should have 4 dimension");
44 int verbose = utl::GetScalarStructDef<int>(prhs[3],
"verbose",1);
47 const mwSize* dimsDWIs = mxGetDimensions(prhs[0]);
48 int Nx =
static_cast<int>(dimsDWIs[0]);
49 int Ny =
static_cast<int>(dimsDWIs[1]);
50 int Nz =
static_cast<int>(dimsDWIs[2]);
51 int numberOfDWIs =
static_cast<int>(dimsDWIs[3]);
54 const mwSize* dimsOrientation = mxGetDimensions(prhs[1]);
55 utlException(dimsOrientation[0]!=numberOfDWIs,
"wrong number of gradients");
56 utlException(dimsOrientation[1]!=3,
"the column of gradient should be 3");
58 const mwSize* dimsBVec = mxGetDimensions(prhs[2]);
59 utlException(dimsBVec[0]!=numberOfDWIs,
"wrong number of bVec");
60 utlException(dimsBVec[1]!=1,
"the column of bVec should be 1");
62 double MD0 = utl::GetScalarStructDef<double>(prhs[3],
"MD0",-1.0);
63 double tau = utl::GetScalarStructDef<double>(prhs[3],
"tau",
ONE_OVER_4_PI_2);
64 double scale = utl::GetScalarStructDef<double>(prhs[3],
"scale",-1.0);
65 int sh = utl::GetScalarStructDef<int>(prhs[3],
"sh",-1);
67 int ra = utl::GetScalarStructDef<int>(prhs[3],
"ra",-1);
69 std::string estimation = utl::GetScalarStructDef<std::string>(prhs[3],
"estimation",
"LS");
70 std::string
solver = utl::GetScalarStructDef<std::string>(prhs[3],
"solver",
"SPAMS");
71 double lambdaSH = utl::GetScalarStructDef<double>(prhs[3],
"lambdaSH",0.0);
72 double lambdaRA = utl::GetScalarStructDef<double>(prhs[3],
"lambdaRA",0.0);
73 double lambdaL1 = utl::GetScalarStructDef<double>(prhs[3],
"lambdaL1",0.0);
77 double numericalB0Weight = utl::GetScalarStructDef<double>(prhs[3],
"numericalB0Weight",-1.0);
80 double energyPower = utl::GetScalarStructDef<double>(prhs[3],
"energyPower",1.0);
81 int maxIter = utl::GetScalarStructDef<int>(prhs[3],
"maxIter",1000);
82 double minChange = utl::GetScalarStructDef<double>(prhs[3],
"minChange",0.0001);
84 int thread = utl::GetScalarStructDef<int>(prhs[3],
"thread",-1.0);
87 typename SPFIFilterBaseType::Pointer spfiFilter=NULL;
89 spfiFilter = SPFIFilterType::New();
90 std::cout <<
"Use SPF basis" << std::endl << std::flush;
91 spfiFilter->SetIsOriginalBasis(
true);
95 typename ImageType::Pointer maskImage = ImageType::New();
97 spfiFilter->SetMaskImage(maskImage);
99 typename VectorImageType::Pointer dwiImage = VectorImageType::New();
101 spfiFilter->SetInput(dwiImage);
104 spfiFilter->SetMD0(MD0);
107 spfiFilter->GetSamplingSchemeQSpace()->SetTau(tau);
109 typename SPFIFilterBaseType::MatrixPointer grad(
new typename SPFIFilterBaseType::MatrixType());
112 spfiFilter->GetSamplingSchemeQSpace()->SetOrientationsSpherical(grad);
114 typename SPFIFilterBaseType::STDVectorPointer bVec(
new typename SPFIFilterBaseType::STDVectorType());
116 spfiFilter->GetSamplingSchemeQSpace()->SetBVector(bVec);
117 spfiFilter->SetSHRank(sh);
118 spfiFilter->SetRadialRank(ra);
119 spfiFilter->SetBasisScale(scale);
120 spfiFilter->SetIsAnalyticalB0(
true);
122 spfiFilter->SetLambdaSpherical(lambdaSH);
123 spfiFilter->SetLambdaRadial(lambdaRA);
124 spfiFilter->SetLambdaL1(lambdaL1);
126 if (estimation==
"LS")
128 spfiFilter->SetEstimationType(SPFIFilterBaseType::LS);
130 else if (estimation==
"L1_2" || estimation==
"L1_DL")
133 if (estimation==
"L1_2")
135 utlGlobalException(lambdaSH<=0 && lambdaRA<=0,
"need to set lambdaSH and lambdaRA when estimation=\"L1_2\".");
136 spfiFilter->SetEstimationType(SPFIFilterBaseType::L1_2);
138 else if (estimation==
"L1_DL")
141 spfiFilter->SetEstimationType(SPFIFilterBaseType::L1_DL);
144 if (solver==
"FISTA_LS")
147 L1SolverType::Pointer l1Sol = L1SolverType::New();
148 l1Sol->SetUseL2SolverForInitialization(solver==
"FISTA_LS");
149 l1Sol->SetMaxNumberOfIterations(maxIter);
150 l1Sol->SetMinRelativeChangeOfCostFunction(minChange);
151 l1Sol->SetMinRelativeChangeOfPrimalResidual(minChange);
152 spfiFilter->SetL1FISTASolver(l1Sol);
153 spfiFilter->SetL1SolverType(SPFIFilterBaseType::FISTA_LS);
158 L1SolverType::Pointer l1Sol = L1SolverType::New();
159 spfiFilter->SetL1SpamsSolver(l1Sol);
160 spfiFilter->SetL1SolverType(SPFIFilterBaseType::SPAMS);
168 typename ImageType::Pointer mdImage = ImageType::New();
170 spfiFilter->SetMDImage(mdImage);
174 spfiFilter->SetBasisEnergyPowerDL(1.0);
178 spfiFilter->SetNumberOfThreads(thread);
179 spfiFilter->SetDebug(utl::LogLevel>=
LOG_DEBUG);
180 spfiFilter->SetLogLevel(utl::LogLevel);
182 std::cout <<
"SPF estimation starts" << std::endl << std::flush;
183 spfiFilter->Update();
184 std::cout <<
"SPF estimation ends" << std::endl << std::flush;
186 typename VectorImageType::Pointer spf = spfiFilter->GetOutput();
198 int nrhs,
const mxArray *prhs[])
206 callFunction<double>(plhs,prhs,nlhs,nrhs);
mxArray * GetArrayStruct(const mxArray *pr_struct, const char *name)
void solver(const Matrix< T > &X, const AbstractMatrixB< T > &D, const Matrix< T > &alpha0, Matrix< T > &alpha, const ParamFISTA< T > &param1, Matrix< T > &optim_info, const GraphStruct< T > *graph_st=NULL, const TreeStruct< T > *tree_st=NULL, const GraphPathStruct< T > *graph_path_st=NULL)
AbstractMatrixB is basically either SpMatrix or Matrix.
helper functions specifically used in dmritool
solve least-squares problem with L1 regularization using FISTA
void GetMXArrayFromITKVectorImage(const SmartPointer< VectorImage< T, VImageDimension > > &image, mxArray *&pr)
used for debug information. this->GetDebug()
#define utlException(cond, expout)
void GetITKVectorImageFromMXArray(const mxArray *pr, SmartPointer< VectorImage< T, 3 > > &image)
NDArray< T, 2 > CartesianToSpherical(const NDArray< T, 2 > &in)
#define utlGlobalException(cond, expout)
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
estimate the coefficients in SPF model
static void InitializeThreadedLibraries(const int numThreads)
estimate the coefficients of generalized Spherical Polar Fourier basis which can be separated into dif...
solve weighted LASSO using spams
void GetUtlMatrixFromMXArray(const mxArray *pr, NDArray< T, 2 > *mat)
itk::VectorImage< ScalarType, 3 > VectorImageType
void GetSTDVectorFromMXArray(const mxArray *pr, std::vector< T > *vec)
void GetITKImageFromMXArray(const mxArray *pr, SmartPointer< Image< T, VImageDimension > > &image)
void callFunction(mxArray *plhs[], const mxArray *prhs[], const int nlhs, const int nrhs)