18 #ifndef __itkSphericalPolarFourierImageFilter_hxx 19 #define __itkSphericalPolarFourierImageFilter_hxx 27 #include <itkProgressReporter.h> 32 template<
class TInputImage,
class TOutputImage >
39 this->ComputeScale(
true);
42 template<
class TInputImage,
class TOutputImage >
48 double scale_old = this->m_BasisScale;
50 this->m_BasisScale = scale;
52 this->ComputeScale(
true);
53 itkDebugMacro(
"setting this->m_BasisScale to " << this->m_BasisScale);
55 if (scale>0 && std::fabs((scale_old-this->m_BasisScale)/this->m_BasisScale)>1e-8)
66 template<
class TInputImage,
class TOutputImage >
67 typename LightObject::Pointer
72 typename LightObject::Pointer loPtr = Superclass::InternalClone();
77 itkExceptionMacro(<<
"downcast to type " << this->GetNameOfClass()<<
" failed.");
81 rval->m_G0DWI = m_G0DWI;
85 template<
class TInputImage,
class TOutputImage >
92 double tau = this->m_SamplingSchemeQSpace->GetTau();
93 if (this->m_IsOriginalBasis)
94 scale = 1.0 / (8*
M_PI*
M_PI*tau*this->m_MD0);
96 scale = 2*tau*this->m_MD0;
101 this->SetBasisScale(scale);
102 if (this->GetDebug())
103 std::cout <<
"m_BasisScale = " << this->m_BasisScale << std::endl;
107 template<
class TInputImage,
class TOutputImage >
112 int sh_b = (this->m_SHRank+1)*(this->m_SHRank+2)/2;
113 int n = index / sh_b;
114 int residual = index - n*sh_b;
116 std::vector<int> nlm;
119 nlm.push_back(lm[0]);
120 nlm.push_back(lm[1]);
124 template<
class TInputImage,
class TOutputImage >
132 template<
class TInputImage,
class TOutputImage >
137 std::vector<int> result;
138 int radialRank=-1, shRank=-1;
139 for (
int radialRank = 0; radialRank <= 10; radialRank += 1 )
141 for (
int shRank = 0; shRank <= 12; shRank += 2 )
143 int dim = RankToDim(
false, radialRank, shRank);
146 result.push_back(radialRank);
147 result.push_back(shRank);
156 template<
class TInputImage,
class TOutputImage >
159 ::RankToDim (
const bool is_radial,
const int radialRank,
const int shRank)
const 161 int radialRank_real = radialRank>=0?radialRank:this->m_RadialRank;
162 int shRank_real = shRank>=0?shRank:this->m_SHRank;
163 utlException(radialRank_real<0 || shRank_real<0,
"wrong rank");
165 return radialRank_real+1;
167 return (shRank_real + 1)*(shRank_real + 2)/2*(radialRank_real+1);
170 template<
class TInputImage,
class TOutputImage >
178 spf->SetSPFType(SPFGenerator::SPF);
179 spf->SetScale(this->m_BasisScale);
180 for (
int i = 0; i < this->m_RadialRank+1; i += 1 )
183 (*m_Gn0)[i] = spf->Evaluate(0);
187 template<
class TInputImage,
class TOutputImage >
193 if (this->m_SamplingSchemeQSpace->GetIndicesInShells()->size()==0)
194 this->m_SamplingSchemeQSpace->GroupRadiusValues();
197 typename Superclass::SamplingSchemeQSpaceType::Index2DVectorPointer indices = this->m_SamplingSchemeQSpace->GetIndicesInShells();
199 for (
int i = 0; i < indices->size(); i += 1 )
201 typename Superclass::SamplingSchemeQSpaceType::IndexVectorType indexTemp = (*indices)[i];
202 double qq = (*qVector)[ indexTemp[0] ];
203 double val = std::exp(-qq*qq/(2.0*this->m_BasisScale));
204 for (
int j = 0; j < indexTemp.size(); j += 1 )
205 (*m_G0DWI)[ indexTemp[j] ] = val;
218 template<
class TInputImage,
class TOutputImage >
225 if (this->m_SamplingSchemeQSpace->GetBVector()->size()>0 && this->m_SamplingSchemeQSpace->GetRadiusVector()->size()==0)
226 this->m_SamplingSchemeQSpace->ConvertBVectorToQVector();
228 if (this->m_SamplingSchemeQSpace->GetIndicesInShells()->size()==0)
229 this->m_SamplingSchemeQSpace->GroupRadiusValues();
233 n_s = this->m_SamplingSchemeQSpace->GetRadiusVector()->size();
236 const STDVectorPointer qVector = this->m_SamplingSchemeQSpace->GetRadiusVector();
237 typename Superclass::SamplingSchemeQSpaceType::Index2DVectorPointer indices = this->m_SamplingSchemeQSpace->GetIndicesInShells();
239 std::string threadIDStr = this->ThreadIDToString();
240 if (this->GetDebug())
242 std::ostringstream msg;
243 msg << threadIDStr <<
"m_BasisScale = " << this->m_BasisScale <<
", m_Tau = " << this->m_SamplingSchemeQSpace->GetTau() <<
", numberOfShell = " << indices->size() << std::endl;
244 this->WriteLogger(msg.str());
249 spf->SetScale(this->m_BasisScale);
251 if (this->m_IsOriginalBasis)
253 spf->SetSPFType(SPFGenerator::SPF);
256 n_b = this->m_RadialRank + 1;
261 std::ostringstream msg;
262 msg << threadIDStr <<
"Generating the "<< n_s <<
"x" << n_b <<
" RadialMatrix...\n";
263 this->WriteLogger(msg.str());
268 for (
int ib = 0; ib < n_b; ib += 1 )
271 for (
int shell = 0; shell < indices->size(); shell += 1 )
273 typename Superclass::SamplingSchemeQSpaceType::IndexVectorType indexTemp = (*indices)[shell];
274 double qq = (*qVector)[ indexTemp[0] ];
275 double spfVal = spf->Evaluate(qq,
false);
276 for (
int js = 0; js < indexTemp.size(); js += 1 )
277 (*B)(indexTemp[js],ib) = spfVal;
284 spf->SetSPFType(SPFGenerator::DSPF);
286 n_b = (this->m_RadialRank+1)*(this->m_SHRank/2+1);
290 std::ostringstream msg;
291 msg << threadIDStr <<
"Generating the "<< n_s <<
"x" << n_b <<
" RadialMatrix...\n";
292 this->WriteLogger(msg.str());
297 for (
int ib = 0; ib < this->m_RadialRank+1; ib += 1 )
300 for (
int l = 0; l <= this->m_SHRank; l += 2 )
303 int col = ib*(this->m_SHRank/2+1)+l/2;
304 for (
int shell = 0; shell < indices->size(); shell += 1 )
306 typename Superclass::SamplingSchemeQSpaceType::IndexVectorType indexTemp = (*indices)[shell];
307 double qq = (*qVector)[ indexTemp[0] ];
308 double spfVal = spf->Evaluate(qq,
false);
309 for (
int js = 0; js < indexTemp.size(); js += 1 )
310 (*B)(indexTemp[js],col) = spfVal;
319 std::ostringstream msg;
321 this->WriteLogger(msg.str());
324 this->m_BasisRadialMatrix = B;
327 template<
class TInputImage,
class TOutputImage >
334 if (this->m_SamplingSchemeQSpace->GetBVector()->size()>0 && this->m_SamplingSchemeQSpace->GetRadiusVector()->size()==0)
335 this->m_SamplingSchemeQSpace->ConvertBVectorToQVector();
337 if (this->m_SamplingSchemeQSpace->GetIndicesInShells()->size()==0)
338 this->m_SamplingSchemeQSpace->GroupRadiusValues();
340 if (this->m_BasisSHMatrix->Rows()==0)
344 if (this->m_BasisRadialMatrix->Rows()==0)
345 this->ComputeRadialMatrix();
348 MatrixPointer qOrientations = this->m_SamplingSchemeQSpace->GetOrientationsSpherical();
350 int qVector_size = this->m_SamplingSchemeQSpace->GetRadiusVector()->size();
351 int grad_size = qOrientations->Rows();
353 std::string threadIDStr = this->ThreadIDToString();
354 if (this->GetDebug())
356 std::ostringstream msg;
357 msg << threadIDStr <<
"this->m_SamplingSchemeQSpace->GetRadiusVector()->size() = " << this->m_SamplingSchemeQSpace->GetRadiusVector()->size() << std::endl;
358 msg << threadIDStr <<
"qOrientations->Rows() = " << qOrientations->Rows() << std::endl;
359 msg << threadIDStr <<
"qVector_size = " << qVector_size << std::endl;
360 msg << threadIDStr <<
"grad_size = " << grad_size << std::endl;
361 this->WriteLogger(msg.str());
363 utlException(qVector_size!=grad_size,
"qVector_size and grad_size should keep the same size");
364 int n_s = qVector_size;
367 int n_b_ra = this->m_RadialRank + 1;
368 int n_b_sh = (this->m_SHRank+1)*(this->m_SHRank+2)/2;
369 int n_b = n_b_ra * n_b_sh;
373 std::ostringstream msg;
374 msg << threadIDStr <<
"n_b_ra = " << n_b_ra << std::endl;
375 msg << threadIDStr <<
"n_b_sh = " << n_b_sh << std::endl;
376 msg << threadIDStr <<
"n_b = " << n_b << std::endl;
377 msg << threadIDStr <<
"Generating the "<< n_s <<
"x" << n_b <<
" basis matrix...\n";
378 this->WriteLogger(msg.str());
382 utlException(B_sh->Columns()!=n_b_sh,
"the SHMatrix does not have the right width. B_sh=" << *B_sh <<
", n_b_sh="<< n_b_sh);
383 utlException(B_sh->Rows()!=B_ra->Rows(),
"the SHMatrix and the RadialMatrix do not have the same samples. B_sh="<< *B_sh <<
", B_ra=" << *B_ra);
386 utlException(B_sh->Rows()!=B->Rows(),
"the SHMatrix and the basisMatrix do not have the same samples. B_sh="<< *B_sh <<
", B="<< *B);
388 if (this->m_IsOriginalBasis)
390 utlException(B_ra->Columns()!=n_b_ra,
"the RadialMatrix does not have the right width");
393 double *B_sh_data = B_sh->GetData();
394 double *B_ra_data = B_ra->GetData();
395 double *B_data = B->GetData();
396 int index_ra=0, index_B=0;
397 std::vector<int> index_sh(n_b_ra,0);
398 for (
int k = 0; k < n_s; k += 1 )
400 for (
int i = 0; i < n_b_ra; i += 1 )
402 for (
int j = 0; j < n_b_sh; j += 1 )
404 B_data[index_B] = B_ra_data[index_ra] * B_sh_data[index_sh[i] ];
425 utlException(B_ra->Columns()!=n_b_ra*(this->m_SHRank/2+1),
"the RadialMatrix does not have the right width");
426 for (
int n = 0; n < n_b_ra; n += 1 )
429 for (
int l = 0; l <= this->m_SHRank; l += 2 )
431 int col_ra = n*(this->m_SHRank/2+1)+l/2;
432 for (
int m = -l; m <= l; m += 1 )
434 int col = n*n_b_sh+jj;
435 for (
int k = 0; k < n_s; k += 1 )
437 (*B)(k,col) = (*B_ra)(k,col_ra) * (*B_sh)(k,jj);
465 std::ostringstream msg;
467 this->WriteLogger(msg.str());
470 this->m_BasisMatrix = B;
473 template<
class TInputImage,
class TOutputImage >
479 Pointer selfClone = this->Clone();
483 selfClone->m_SamplingSchemeQSpace = SamplingSchemeQSpaceType::New();
487 sampling->SetBVector(b0Vector);
488 sampling->SetOrientationsCartesian(grad);
489 selfClone->ComputeBasisMatrix();
490 this->m_BasisMatrixForB0 = selfClone->GetBasisMatrix();
493 template<
class TInputImage,
class TOutputImage >
500 int n_b_sh = (this->m_SHRank+1)*(this->m_SHRank+2)/2;
501 int n_b_ra = this->m_RadialRank + 1;
503 utlException((this->m_EstimationType==Self::LS || this->m_EstimationType==Self::L1_2) && this->m_LambdaSpherical<0 && this->m_LambdaRadial<0,
"need to set m_LambdaSpherical and m_LambdaSpherical");
504 utlException((this->m_EstimationType==Self::L1_DL) && this->m_LambdaL1<0,
"need to set m_LambdaL1");
505 if (this->m_EstimationType!=Self::L1_DL)
507 this->m_RegularizationWeight=
VectorPointer(
new VectorType(!this->m_IsAnalyticalB0?this->RankToDim():(n_b_sh*this->m_RadialRank)) );
508 for (
int i = 0; i <= ((!this->m_IsAnalyticalB0)?this->m_RadialRank:(this->m_RadialRank-1)); i += 1 )
511 for (
int l = 0; l <= this->m_SHRank; l += 2 )
514 if (this->m_IsAnalyticalB0)
515 lambda = this->m_LambdaSpherical*l*l*(l+1)*(l+1) + this->m_LambdaRadial*(i+1)*(i+1)*(i+2)*(i+2);
517 lambda = this->m_LambdaSpherical*l*l*(l+1)*(l+1) + this->m_LambdaRadial*i*i*(i+1)*(i+1);
518 for (
int m = -l; m <= l; m += 1 )
520 if (this->m_EstimationType==Self::L1_2 || this->m_EstimationType==Self::LS)
521 (*this->m_RegularizationWeight)[i*n_b_sh+j] = lambda;
529 if (this->m_BasisCombinationMatrix->Size()==0)
533 if (this->GetDebug())
536 if (this->m_BasisEnergyDL->Size()==0)
538 if (
std::abs(this->m_BasisEnergyPowerDL)>1e-10)
540 std::vector<double> vecTemp;
543 *this->m_BasisEnergyDL /= this->m_BasisEnergyDL->GetMean();
545 if (
std::abs(this->m_BasisEnergyPowerDL-1)>1e-10)
546 utl::PowerVector(this->m_BasisEnergyDL->Begin(), this->m_BasisEnergyDL->End(), this->m_BasisEnergyPowerDL);
551 this->m_BasisEnergyDL->Fill(1.0);
554 if (this->GetDebug())
557 utlException(this->m_BasisMatrix->Size()>0 && this->m_BasisCombinationMatrix->Rows()!=this->m_BasisMatrix->Columns()-n_b_sh,
"wrong size of dictionary. m_BasisCombinationMatrix->Rows()="<<this->m_BasisCombinationMatrix->Rows() <<
", BasisMatrix->Columns()-n_b_sh="<< this->m_BasisMatrix->Columns()-n_b_sh);
558 utlException(this->m_BasisMatrix->Size()>0 && this->m_BasisCombinationMatrix->Columns()!=this->m_BasisEnergyDL->Size(),
"wrong size of dictionary. this->m_BasisEnergyDL.size()="<<this->m_BasisEnergyDL->Size());
561 for (
int i = 0; i < this->m_BasisEnergyDL->Size(); i += 1 )
563 if (this->m_EstimationType==Self::L1_DL)
564 (*this->m_RegularizationWeight)[i] = this->m_LambdaL1 / (*this->m_BasisEnergyDL)[i];
566 (*this->m_RegularizationWeight)[i] = this->m_LambdaL2 / (*this->m_BasisEnergyDL)[i];
572 template<
class TInputImage,
class TOutputImage >
578 Superclass::BeforeThreadedGenerateData();
580 if (this->m_SamplingSchemeQSpace->GetBVector()->size()>0 && this->m_SamplingSchemeQSpace->GetRadiusVector()->size()==0)
581 this->m_SamplingSchemeQSpace->ConvertBVectorToQVector();
583 if (this->m_SamplingSchemeQSpace->GetIndicesInShells()->size()==0)
584 this->m_SamplingSchemeQSpace->GroupRadiusValues();
586 int n_b_sh = (this->m_SHRank+1)*(this->m_SHRank+2)/2;
587 int n_b_ra = this->m_RadialRank + 1;
590 if (!this->m_IsAnalyticalB0)
592 std::cout <<
"Use numerical way for constraint E(0)=1" << std::endl << std::flush;
593 this->ComputeBasisMatrixForB0();
597 if (m_Gn0->size()==0)
599 this->ComputeRadialVectorForE0InBasis();
600 this->ComputeRadialVectorForE0InDWI();
602 utlGlobalException(!this->m_IsOriginalBasis,
"TODO: use analytical way for E(0)=1 and DSPF basis");
603 std::cout <<
"Use analytical way for constraint E(0)=1" << std::endl << std::flush;
606 if (!this->IsAdaptiveScale())
608 std::cout <<
"Use the same scale for all voxels!" << std::endl << std::flush;
609 this->ComputeBasisMatrix();
612 if (!this->m_IsAnalyticalB0)
614 utlGlobalException(this->m_EstimationType==Self::L1_DL,
"L1-DL only supports analytical way");
615 *basisMatrix =
utl::ConnectUtlMatrix(*this->m_BasisMatrix, *utl::ToMatrix<double>(*this->m_BasisMatrixForB0 % this->m_B0Weight),
true);
620 utlException((*m_Gn0)[0]==0,
"it should be not zero!, (*m_Gn0)[0]="<< (*m_Gn0)[0]);
623 for (
int i = 0; i < this->m_RadialRank; i += 1 )
625 for (
int j = 0; j < n_b_sh; j += 1 )
627 for (
int ss = 0; ss < basisMatrix->Rows(); ss += 1 )
629 (*basisMatrix)(ss,i*n_b_sh+j) = (*this->m_BasisMatrix)(ss,(i+1)*n_b_sh+j) - (*m_Gn0)[i+1]/(*m_Gn0)[0] * (*this->m_BasisMatrix)(ss,j);
633 if (this->m_EstimationType==Self::L1_DL)
635 if (this->m_BasisCombinationMatrix->Size()==0)
640 basisMatrix = tmpMat;
643 if (this->GetDebug())
646 if (this->m_EstimationType==Self::LS)
648 this->m_L2Solver->SetA(basisMatrix);
650 else if (this->m_EstimationType==Self::L1_2 || this->m_EstimationType==Self::L1_DL)
652 if (this->m_L1SolverType==Self::FISTA_LS)
654 this->m_L1FISTASolver->SetA(basisMatrix);
657 else if (this->m_L1SolverType==Self::SPAMS)
658 this->m_L1SpamsSolver->SetA(basisMatrix);
665 std::cout <<
"Use adaptive scale for each voxel!" << std::endl << std::flush;
668 typename ScaleFromMDfilterType::Pointer scaleFromMDfilter = ScaleFromMDfilterType::New();
669 scaleFromMDfilter->SetMD0(this->m_MD0);
670 scaleFromMDfilter->SetTau(this->m_SamplingSchemeQSpace->GetTau());
671 scaleFromMDfilter->SetIsOriginalBasis(this->m_IsOriginalBasis);
672 scaleFromMDfilter->SetInput(this->m_MDImage);
673 scaleFromMDfilter->SetInPlace(
false);
674 scaleFromMDfilter->Update();
675 this->m_ScaleImage = scaleFromMDfilter->GetOutput();
678 this->ComputeRegularizationWeight();
679 if (this->GetDebug())
683 if (this->m_EstimationType==Self::LS)
685 if (this->m_LambdaSpherical>0 || this->m_LambdaRadial>0)
688 *lamMat = this->m_RegularizationWeight->GetDiagonalMatrix();
689 this->m_L2Solver->SetLambda(lamMat);
692 else if (this->m_EstimationType==Self::L1_2)
694 if (this->m_L1SolverType==Self::FISTA_LS)
696 this->m_L1FISTASolver->SetwForInitialization(this->m_RegularizationWeight);
697 this->m_L1FISTASolver->Setw(this->m_RegularizationWeight);
700 else if (this->m_L1SolverType==Self::SPAMS)
701 this->m_L1SpamsSolver->Setw(this->m_RegularizationWeight);
703 else if (this->m_EstimationType==Self::L1_DL)
705 if (this->m_L1SolverType==Self::FISTA_LS)
709 this->m_L1FISTASolver->SetwForInitialization(utl::ToVector<double> (*this->m_RegularizationWeight % qVector->size()) );
711 this->m_L1FISTASolver->Setw(utl::ToVector<double>(*this->m_RegularizationWeight % qVector->size()) );
713 else if (this->m_L1SolverType==Self::SPAMS)
714 this->m_L1SpamsSolver->Setw(utl::ToVector<double>(*this->m_RegularizationWeight % qVector->size()));
724 template<
class TInputImage,
class TOutputImage >
730 ProgressReporter progress(
this, threadId, outputRegionForThread.GetNumberOfPixels());
736 ImageRegionIteratorWithIndex<OutputImageType> outputIt(outputPtr, outputRegionForThread );
737 ImageRegionConstIteratorWithIndex<InputImageType> inputIt(inputPtr, outputRegionForThread );
738 ImageRegionIteratorWithIndex<MaskImageType> maskIt;
739 ImageRegionIteratorWithIndex<ScalarImageType> scaleIt;
740 if (this->IsMaskUsed())
741 maskIt = ImageRegionIteratorWithIndex<MaskImageType>(this->m_MaskImage, outputRegionForThread);
743 scaleIt = ImageRegionIteratorWithIndex<ScalarImageType>(this->m_ScaleImage, outputRegionForThread);
749 unsigned int numberOfCoeffcients = outputPtr->GetNumberOfComponentsPerPixel();;
750 outputPixel.SetSize(numberOfCoeffcients);
751 outputZero.SetSize(numberOfCoeffcients), outputZero.Fill(0.0);
752 unsigned int numberOfDWIs = inputPtr->GetNumberOfComponentsPerPixel();
753 inputPixel.SetSize(numberOfDWIs);
754 int n_b_sh = (this->m_SHRank+1)*(this->m_SHRank+2)/2;
755 int n_b_ra = this->m_RadialRank + 1;
756 int n_b = n_b_ra * n_b_sh;
758 VectorType dwiPixel(numberOfDWIs+this->m_BasisMatrixForB0->Rows()), dwiPixel_est(numberOfDWIs+this->m_BasisMatrixForB0->Rows()), dwiPixel_first(numberOfDWIs), coef(numberOfCoeffcients), coef_first;
762 Pointer selfClone = this->Clone();
764 selfClone->m_ThreadID = threadId;
765 std::string threadIDStr = selfClone->ThreadIDToString();
766 if (this->GetDebug())
768 std::ostringstream msg;
769 selfClone->Print(msg << threadIDStr <<
"selfClone = ");
770 this->WriteLogger(msg.str());
774 STDVectorPointer qVector = selfClone->m_SamplingSchemeQSpace->GetRadiusVector();
776 if (!this->IsAdaptiveScale())
778 selfClone->ComputeRadialVectorForE0InDWI();
779 selfClone->ComputeRadialVectorForE0InBasis();
783 if (!this->m_IsAnalyticalB0 && !this->IsAdaptiveScale())
784 *basisMatrix =
utl::ConnectUtlMatrix(*selfClone->m_BasisMatrix, *utl::ToMatrix<double>(*selfClone->m_BasisMatrixForB0 % selfClone->m_B0Weight),
true);
786 for (inputIt.GoToBegin(), outputIt.GoToBegin(), maskIt.GoToBegin(), scaleIt.GoToBegin();
788 progress.CompletedPixel(), ++inputIt, ++outputIt, ++maskIt, ++scaleIt)
791 if (this->IsMaskUsed() && maskIt.Get()<=1e-8)
793 outputIt.Set(outputZero);
797 inputPixel=inputIt.Get();
798 if (inputPixel.GetSquaredNorm()<=1e-8)
800 outputIt.Set(outputZero);
804 index = inputIt.GetIndex();
805 if (this->GetDebug())
807 std::ostringstream msg;
808 msg <<
"\n" << threadIDStr <<
"index = " << index << std::endl << std::flush;
809 this->WriteLogger(msg.str());
812 for (
int i = 0; i < numberOfDWIs; i += 1 )
813 dwiPixel[i] = inputPixel[i];
815 if (this->IsAdaptiveScale())
817 double scale = scaleIt.Get();
821 outputIt.Set(outputZero);
825 selfClone->SetBasisScale(scale);
826 if (selfClone->m_BasisMatrix->Rows()==0)
828 selfClone->ComputeBasisMatrix();
829 if (!selfClone->m_IsAnalyticalB0)
830 selfClone->ComputeBasisMatrixForB0();
833 selfClone->ComputeRadialVectorForE0InDWI();
834 selfClone->ComputeRadialVectorForE0InBasis();
838 if (!this->m_IsAnalyticalB0)
841 *basisMatrix =
utl::ConnectUtlMatrix(*selfClone->m_BasisMatrix, *utl::ToMatrix<double>((*selfClone->m_BasisMatrixForB0)%this->m_B0Weight),
true);
845 if (basisMatrix->Size()==0 || this->m_EstimationType==Self::L1_DL)
847 utlException((*selfClone->m_Gn0)[0]==0,
"it should be not zero!, (*selfClone->m_Gn0)[0]="<< (*selfClone->m_Gn0)[0]);
852 double *selfBasisMatrix_data = selfClone->m_BasisMatrix->GetData();
853 double *basisMatrix_data = basisMatrix->GetData();
854 int index_B=0, index_selfB=0, index_selfB_0=0;
855 for (
int ss = 0; ss < basisMatrix->Rows(); ss += 1 )
857 for (
int i = 0; i < this->m_RadialRank; i += 1 )
860 index_selfB += n_b_sh;
862 index_selfB_0 -= n_b_sh;
863 for (
int j = 0; j < n_b_sh; j += 1 )
865 basisMatrix_data[index_B] = selfBasisMatrix_data[index_selfB] - (*selfClone->m_Gn0)[i+1]/(*selfClone->m_Gn0)[0] * selfBasisMatrix_data[index_selfB_0];
872 index_selfB_0 += n_b_sh*this->m_RadialRank;
894 if (this->m_EstimationType==Self::L1_DL)
898 *basisMatrix = *tmpMat;
904 if (this->GetDebug())
906 std::ostringstream msg;
908 this->WriteLogger(msg.str());
911 if (this->m_EstimationType==Self::LS)
913 selfClone->m_L2Solver->SetA(basisMatrix);
915 else if (this->m_EstimationType==Self::L1_2 || this->m_EstimationType==Self::L1_DL)
917 if (this->m_L1SolverType==Self::FISTA_LS)
918 selfClone->m_L1FISTASolver->SetA(basisMatrix);
919 else if (this->m_L1SolverType==Self::SPAMS)
920 selfClone->m_L1SpamsSolver->SetA(basisMatrix);
924 if (this->m_IsAnalyticalB0)
926 for (
int i = 0; i < dwiPixel.Size(); i += 1 )
927 dwiPixel_first[i] = dwiPixel[i] - (*selfClone->m_G0DWI)[i];
929 if (this->m_EstimationType==Self::LS)
932 selfClone->m_L2Solver->Solve();
933 coef_first = selfClone->m_L2Solver->Getx();
939 else if (this->m_EstimationType==Self::L1_2 || this->m_EstimationType==Self::L1_DL)
941 if (this->m_L1SolverType==Self::FISTA_LS)
946 selfClone->m_L1FISTASolver->Solve();
947 coef_first = selfClone->m_L1FISTASolver->Getx();
949 else if (this->m_L1SolverType==Self::SPAMS)
952 selfClone->m_L1SpamsSolver->Solve();
953 coef_first = selfClone->m_L1SpamsSolver->Getx();
956 if (this->GetDebug())
958 std::ostringstream msg;
959 if (selfClone->m_L1SpamsSolver)
961 msg << threadIDStr <<
"use m_L1SpamsSolver" << std::endl << std::flush;
962 selfClone->m_L1SpamsSolver->
Print(msg << threadIDStr<<
"this->m_L1QPSolver = ");
963 double func = selfClone->m_L1SpamsSolver->EvaluateCostFunction();
964 msg << threadIDStr <<
"func spams = " << func << std::endl << std::flush;
966 if (selfClone->m_L1FISTASolver)
968 msg << threadIDStr <<
"use m_L1FISTASolver" << std::endl << std::flush;
969 selfClone->m_L1FISTASolver->Print(msg << threadIDStr<<
"this->m_L1FISTASolver = ");
970 std::vector<double> funcVec = selfClone->m_L1FISTASolver->GetCostFunction();
973 this->WriteLogger(msg.str());
977 if (this->m_EstimationType==Self::L1_DL)
981 for (
int i = 0; i < coef_first_tmp.
Size(); i += 1 )
982 coef[i+n_b_sh] = coef_first_tmp[i];
986 for (
int i = 0; i < coef_first.
Size(); i += 1 )
987 coef[i+n_b_sh] = coef_first[i];
991 for (
int l = 0; l <= this->m_SHRank; l += 2 )
993 for (
int m = -l; m <= l; m += 1 )
996 for (
int nn = 1; nn <= this->m_RadialRank; nn += 1 )
998 int index_j = this->GetIndexJ(nn,l,m);
999 sum_tmp += coef[index_j] * (*selfClone->m_Gn0)[nn];
1002 coef[jj] = (std::sqrt(4*
M_PI) - sum_tmp)/(*selfClone->m_Gn0)[0];
1004 coef[jj] = -sum_tmp/(*selfClone->m_Gn0)[0];
1012 for (
int i = 0; i < selfClone->m_BasisMatrixForB0->Rows(); i += 1 )
1013 dwiPixel[i+numberOfDWIs] = selfClone->m_B0Weight;
1017 if (this->m_EstimationType==Self::LS)
1020 selfClone->m_L2Solver->Solve();
1021 coef = selfClone->m_L2Solver->Getx();
1023 else if (this->m_EstimationType==Self::L1_2 || this->m_EstimationType==Self::L1_DL)
1025 if (this->m_L1SolverType==Self::FISTA_LS)
1028 selfClone->m_L1FISTASolver->Solve();
1029 coef = selfClone->m_L1FISTASolver->Getx();
1031 else if (this->m_L1SolverType==Self::SPAMS)
1034 selfClone->m_L1SpamsSolver->Solve();
1035 coef = selfClone->m_L1SpamsSolver->Getx();
1041 if (this->GetDebug())
1043 std::ostringstream msg;
1046 if (this->m_IsAnalyticalB0)
1059 if (selfClone->m_BasisMatrixForB0->Size()==0)
1060 selfClone->ComputeBasisMatrixForB0();
1064 this->WriteLogger(msg.str());
1067 for (
int i = 0; i < numberOfCoeffcients; i += 1 )
1068 outputPixel[i] = coef[i];
1070 outputIt.Set(outputPixel);
1074 template<
class TInputImage,
class TOutputImage >
1079 Superclass::PrintSelf(os, indent);
NDArray is a N-Dimensional array class (row-major, c version)
LightObject::Pointer InternalClone() const ITK_OVERRIDE
void ReadMatrix(const std::string &file, TMatrixType &matrix)
utl_shared_ptr< MatrixType > MatrixPointer
void ComputeRadialVectorForE0InDWI()
base filter for estimation of diffusion models
void ComputeBasisMatrixForB0()
helper functions specifically used in dmritool
std::vector< double > STDVectorType
int GetIndexSHj(const int l, const int m)
void ComputeBasisMatrix() ITK_OVERRIDE
bool IsImageEmpty(const SmartPointer< ImageType > &image)
void Print(std::ostream &os, const char *separate=" ") const
void PrintUtlMatrix(const NDArray< T, 2 > &mat, const std::string &str="", const char *separate=" ", std::ostream &os=std::cout)
static const std::string LearnedSPFDictionary_SH8_RA4_K250
void ComputeRadialMatrix() ITK_OVERRIDE
NDArray< T, 2 > ConnectUtlMatrix(const NDArray< T, 2 > &m1, const NDArray< T, 2 > &m2, const bool isConnectRow)
#define utlException(cond, expout)
std::vector< int > GetIndexSHlm(const int j)
void PrintSelf(std::ostream &os, Indent indent) const ITK_OVERRIDE
SmartPointer< Self > Pointer
void ComputeRegularizationWeight() ITK_OVERRIDE
void ComputeRadialVectorForE0InBasis()
std::vector< int > DimToRank(const int dimm) const ITK_OVERRIDE
int GetIndexJ(const int n, const int l, const int m) const ITK_OVERRIDE
void PowerVector(IteratorType v1, IteratorType v2, const double poww)
Superclass::InputImageIndexType InputImageIndexType
Superclass::InputImageConstPointer InputImageConstPointer
Compute SPF scale from mean diffusivity.
void ThreadedGenerateData(const OutputImageRegionType &outputRegionForThread, ThreadIdType threadId) ITK_OVERRIDE
NDArray< T, 1 > StdVectorToUtlVector(const std::vector< T > &vec)
utl_shared_ptr< STDVectorType > STDVectorPointer
SphericalPolarFourierImageFilter()
std::shared_ptr< NDArray< T, 2 > > ComputeSHMatrix(const unsigned int rank, const NDArray< T, 2 > &grad, const int mode)
std::string CreateExpandedPath(const std::string &path)
utl_shared_ptr< VectorType > VectorPointer
double ComputeScale(const bool setScale=true) ITK_OVERRIDE
void ProductUtlMM(const utl::NDArray< T, 2 > &A, const utl::NDArray< T, 2 > &B, utl::NDArray< T, 2 > &C, const double alpha=1.0, const double beta=0.0)
T abs(const T x)
template version of the fabs function
#define utlGlobalException(cond, expout)
static const std::string LearnedSPFEnergy_SH8_RA4_K250
SamplingSchemeQSpaceType::Pointer SamplingSchemeQSpacePointer
Superclass::InputImagePixelType InputImagePixelType
Superclass::OutputImagePixelType OutputImagePixelType
Superclass::OutputImagePointer OutputImagePointer
void ReadVector(const std::string &vectorStr, std::vector< T > &vec, const char *cc=" ")
void PrintVector(const std::vector< T > &vec, const std::string &str="", const char *separate=" ", std::ostream &os=std::cout, bool showStats=true)
void BeforeThreadedGenerateData() ITK_OVERRIDE
static void InitializeThreadedLibraries(const int numThreads)
int RankToDim(const bool is_radial=false, const int radialRank=-1, const int shRank=-1) const ITK_OVERRIDE
Superclass::OutputImageRegionType OutputImageRegionType
#define itkShowPositionThreadedLogger(cond)
NDArrayBase< T, Dim > & ElementAbsolute(T *outVec=NULL)
void PrintUtlVector(const NDArray< T, 1 > &vec, const std::string &str="", const char *separate=" ", std::ostream &os=std::cout, bool showStats=true)
std::vector< int > GetIndexNLM(const int index) const ITK_OVERRIDE
void ProductUtlMv(const utl::NDArray< T, 2 > &A, const utl::NDArray< T, 1 > &b, utl::NDArray< T, 1 > &c, const double alpha=1.0, const double beta=0.0)
SmartPointer< Self > Pointer
void SetBasisScale(const double scale) ITK_OVERRIDE