18 #ifndef __itkGeneralizedHighOrderTensorImageFilter_hxx 19 #define __itkGeneralizedHighOrderTensorImageFilter_hxx 22 #include "itkProgressReporter.h" 26 template<
class TInputImage,
class TOutputImage >
// Fragment of a member-function template that derives a basis scale from the
// q-space sampling scheme, stores it via SetBasisScale() and prints it.
// NOTE(review): the leading integers appear to be original line numbers; the
// signature, the declaration of `scale`, any guards and the return statement
// fall on lines missing from this listing. Presumably this is
// ComputeScale(const bool setScale) -- confirm against the full file.
32 template<
class TInputImage,
class TOutputImage >
// Orientation matrix of the sampling scheme. No line visible in this
// fragment reads qOrientations afterwards.
40 MatrixPointer qOrientations = this->m_SamplingSchemeQSpace->GetOrientationsSpherical();
// Largest entry of the sampling scheme's b-vector.
// NOTE(review): std::max_element returns end() for an empty range, so the
// dereference is undefined if the b-vector is empty; no guard is visible here.
45 double bMax = *
std::max_element(this->m_SamplingSchemeQSpace->GetBVector()->begin(), this->m_SamplingSchemeQSpace->GetBVector()->end());
// scale = 0.5 * bMax / (4 * pi^2 * GetTau())
46 scale = 0.5*bMax/(4*
M_PI*
M_PI*this->m_SamplingSchemeQSpace->GetTau());
// Hand the computed value to SetBasisScale().
// NOTE(review): whether this call is conditional cannot be told from here.
49 this->SetBasisScale(scale);
// Print the member m_BasisScale to stdout (any debug guard is not visible).
52 std::cout <<
"m_BasisScale = " << this->m_BasisScale << std::endl;
// Fragment of a member-function template that scans (radialRank, shRank)
// pairs, computes the dimension of each pair with RankToDim(), and appends a
// pair to `result`.
// NOTE(review): the signature, the condition guarding the push_back calls and
// the return are on lines missing from this listing. Presumably this is
// DimToRank(const int dimm) and the pair is kept when its dimension equals
// the requested one -- confirm against the full file.
57 template<
class TInputImage,
class TOutputImage >
62 std::vector<int> result;
// NOTE(review): these two locals are shadowed by the loop variables declared
// in the for-statements below and are not read by any visible line.
63 int radialRank=-1, shRank=-1;
// Radial ranks 1..10 in steps of 1.
64 for (
int radialRank = 1; radialRank <= 10; radialRank += 1 )
// Spherical-harmonic ranks 0..12 in steps of 2 (even ranks only).
66 for (
int shRank = 0; shRank <= 12; shRank += 2 )
// Dimension for this pair, requested with is_radial = false.
68 int dim = RankToDim(
false, radialRank, shRank);
// result receives radialRank first, then shRank.
71 result.push_back(radialRank);
72 result.push_back(shRank);
// RankToDim: converts a (radialRank, shRank) pair into a dimension count.
//   is_radial  - not read by any visible line; presumably selects between the
//                two return statements below (TODO confirm, the branch lines
//                are missing from this listing).
//   radialRank - radial rank; a negative value falls back to m_RadialRank.
//   shRank     - spherical-harmonic rank; a negative value falls back to
//                m_SHRank.
// NOTE(review): the return type and class qualifier are also on lines not
// shown here.
81 template<
class TInputImage,
class TOutputImage >
84 ::RankToDim (
const bool is_radial,
const int radialRank,
const int shRank)
const 86 int radialRank_real = radialRank>=0?radialRank:this->m_RadialRank;
87 int shRank_real = shRank>=0?shRank:this->m_SHRank;
// First return: the effective radial rank alone.
89 return radialRank_real;
// Second return: (shRank+1)*(shRank+2)/2 * radialRank. The product of two
// consecutive integers is even, so the integer division by 2 is exact.
91 return (shRank_real + 1)*(shRank_real + 2)/2*(radialRank_real);
// Fragment of a member-function template whose only visible statement
// forwards to the superclass implementation of VerifyInputParameters().
// NOTE(review): the signature and any additional checks performed by this
// override are on lines missing from this listing.
94 template<
class TInputImage,
class TOutputImage >
100 Superclass::VerifyInputParameters();
// Fragment of a member-function template that fills an n_s x n_b matrix B
// with even powers of a scaled q-value and stores it in m_BasisRadialMatrix.
// NOTE(review): the signature and the declarations of n_s, n_b, qVector and B
// are on lines missing from this listing; presumably this is
// ComputeRadialMatrix() -- confirm against the full file.
106 template<
class TInputImage,
class TOutputImage >
// n_s = number of entries in the sampling scheme's b-vector.
114 n_s = this->m_SamplingSchemeQSpace->GetBVector()->size();
117 if (this->GetDebug())
118 std::cout <<
"m_BasisScale = " << this->m_BasisScale << std::endl;
120 const STDVectorPointer bVector = this->m_SamplingSchemeQSpace->GetBVector();
// qVector[i] = sqrt( b_i / (4 * pi^2 * GetTau()) )
// NOTE(review): `int i` is compared against the unsigned bVector->size().
124 for (
int i = 0; i < bVector->size(); i += 1 )
126 qVector[i] = std::sqrt((*bVector)[i]/(4*
M_PI*
M_PI*this->m_SamplingSchemeQSpace->GetTau()));
// n_b = number of matrix columns, taken from m_RadialRank.
131 n_b = this->m_RadialRank;
135 std::cout <<
"Generating the "<< n_s <<
"x" << n_b <<
" RadialMatrix...\n";
// Row js corresponds to sample js; x_temp = q_js / sqrt(m_BasisScale).
139 for (
int js = 0; js < n_s; js += 1 )
141 double x_temp = qVector[js] / std::sqrt(this->m_BasisScale);
// Column ib holds x_temp^(2*(ib+1)), i.e. exponents 2, 4, ..., 2*n_b.
142 for (
int ib = 0; ib < n_b; ib += 1 )
144 (*B)(js,ib) = std::pow(x_temp, 2.0*(ib+1));
151 std::cout <<
"Generated the "<< n_s <<
"x" << n_b <<
" RadialMatrix...\n";
// Publish the result as the filter's radial basis matrix.
155 this->m_BasisRadialMatrix = B;
// Fragment of a member-function template that combines the radial matrix and
// the SH matrix column by column into an n_s x n_b matrix B and stores it in
// m_BasisMatrix, where n_b = n_b_ra * n_b_sh.
// NOTE(review): the signature and the declarations of B, B_sh and B_ra are on
// lines missing from this listing; presumably this is ComputeBasisMatrix() --
// confirm against the full file.
159 template<
class TInputImage,
class TOutputImage >
// The action taken when m_BasisSHMatrix has zero rows is not visible here.
166 if (this->m_BasisSHMatrix->Rows()==0)
// Build the radial matrix on demand when it has zero rows.
170 if (this->m_BasisRadialMatrix->Rows()==0)
171 this->ComputeRadialMatrix();
174 MatrixPointer qOrientations = this->m_SamplingSchemeQSpace->GetOrientationsSpherical();
// Sample counts taken from the b-vector and from the orientation matrix.
176 int bVector_size = this->m_SamplingSchemeQSpace->GetBVector()->size();
177 int grad_size = qOrientations->Rows();
179 if (this->GetDebug())
181 std::cout <<
"this->m_SamplingSchemeQSpace->GetBVector()->size() = " << this->m_SamplingSchemeQSpace->GetBVector()->size() << std::endl;
182 std::cout <<
"qOrientations->Rows() = " << qOrientations->Rows() << std::endl;
183 std::cout <<
"bVector_size = " << bVector_size << std::endl;
184 std::cout <<
"grad_size = " << grad_size << std::endl;
// utlException(cond, msg): presumably raises when cond is true -- the macro
// body is not shown; confirm.
186 utlException(bVector_size!=grad_size,
"bVector_size and grad_size should keep the same size");
187 int n_s = bVector_size;
// Column counts: radial part, SH part ((SHRank+1)*(SHRank+2)/2), and total.
190 int n_b_ra = this->m_RadialRank;
191 int n_b_sh = (this->m_SHRank+1)*(this->m_SHRank+2)/2;
192 int n_b = n_b_ra * n_b_sh;
196 std::cout <<
"n_b_ra = " << n_b_ra << std::endl;
197 std::cout <<
"n_b_sh = " << n_b_sh << std::endl;
198 std::cout <<
"n_b = " << n_b << std::endl;
199 std::cout <<
"Generating the "<< n_s <<
"x" << n_b <<
" basis matrix...\n";
// Shape checks on the factor matrices before they are combined.
203 utlException(B_sh->Columns()!=n_b_sh,
"the SHMatrix does not have the right width. B_sh=" << *B_sh <<
", n_b_sh="<< n_b_sh);
204 utlException(B_sh->Rows()!=B_ra->Rows(),
"the SHMatrix and the RadialMatrix do not have the same samples. B_sh="<< *B_sh <<
", B_ra=" << *B_ra);
205 utlException(B_sh->Rows()!=B->Rows(),
"the SHMatrix and the basisMatrix do not have the same samples. B_sh="<< *B_sh <<
", B="<< *B);
207 utlException(B_ra->Columns()!=n_b_ra,
"the RadialMatrix does not have the right width");
// Column (i*n_b_sh + j) of B is the element-wise product, over samples k, of
// radial column i and SH column j.
208 for (
int i = 0; i < n_b_ra; i += 1 )
209 for (
int j = 0; j < n_b_sh; j += 1 )
210 for (
int k = 0; k < n_s; k += 1 )
212 (*B)(k,i*n_b_sh+j) = (*B_ra)(k,i) * (*B_sh)(k,j);
// Publish the result as the filter's basis matrix.
220 this->m_BasisMatrix = B;
// Fragment of a member-function template that writes one regularization
// weight per basis coefficient into m_RegularizationWeight.
// NOTE(review): the signature, the sizing of m_RegularizationWeight and the
// declaration/update of the index `j` are on lines missing from this listing;
// presumably this is ComputeRegularizationWeight() -- confirm.
223 template<
class TInputImage,
class TOutputImage >
230 int n_b_sh = (this->m_SHRank+1)*(this->m_SHRank+2)/2;
// n_b_ra is not read by any line visible in this fragment.
231 int n_b_ra = this->m_RadialRank;
// i: radial index 0..m_RadialRank-1.
234 for (
int i = 0; i <= this->m_RadialRank-1; i += 1 )
// l: even values 0..m_SHRank.
237 for (
int l = 0; l <= this->m_SHRank; l += 2 )
// m: -l..l; m itself does not appear in the weight expression.
239 for (
int m = -l; m <= l; m += 1 )
// weight = LambdaSpherical * l^2 (l+1)^2 + LambdaRadial * i^2 (i+1)^2
241 (*this->m_RegularizationWeight)(i*n_b_sh+j) = this->m_LambdaSpherical*l*l*(l+1)*(l+1) + this->m_LambdaRadial*i*i*(i+1)*(i+1);
// Fragment of a member-function template that builds the basis matrix,
// verifies the input parameters, and configures a new L2 solver with that
// matrix and, when regularization is enabled, a diagonal weight matrix.
// NOTE(review): the signature and the declaration of `mat` are on lines
// missing from this listing; presumably this is BeforeThreadedGenerateData()
// -- confirm against the full file.
248 template<
class TInputImage,
class TOutputImage >
// NOTE(review): the basis matrix is built before VerifyInputParameters()
// runs -- confirm that this order is intended.
255 ComputeBasisMatrix();
256 this->VerifyInputParameters();
257 this->m_L2Solver = L2SolverType::New();
258 this->m_L2Solver->SetA(this->m_BasisMatrix);
// Regularization is configured only when at least one lambda is positive.
259 if (this->m_LambdaSpherical>0 || this->m_LambdaRadial>0)
261 this->ComputeRegularizationWeight();
// The weights become the diagonal of `mat`, which is passed to SetLambda().
263 mat->SetDiagonal(*this->m_RegularizationWeight);
264 this->m_L2Solver->SetLambda(mat);
// Fragment of a member-function template that walks outputRegionForThread
// and, per voxel, takes -log of every input component, runs the L2 solver of
// a cloned filter, and writes the solver's result to the output pixel.
// NOTE(review): the signature, the declarations of inputPtr/outputPtr/
// inputPixel/outputPixel, the hand-off of dwiPixel to the solver, the else
// branch and the iterator increments are on lines missing from this listing;
// presumably this is ThreadedGenerateData(outputRegionForThread, threadId)
// -- confirm against the full file.
274 template<
class TInputImage,
class TOutputImage >
280 ProgressReporter progress(
this, threadId, outputRegionForThread.GetNumberOfPixels());
// The solver is reached through this clone rather than through `this`;
// presumably so that each thread has its own solver state -- confirm.
285 Pointer selfClone = this->Clone();
288 ImageRegionIteratorWithIndex<OutputImageType> outputIt(outputPtr, outputRegionForThread );
289 ImageRegionConstIteratorWithIndex<InputImageType> inputIt(inputPtr, outputRegionForThread );
// maskIt is only bound to an image when a mask is in use.
290 ImageRegionIteratorWithIndex<MaskImageType> maskIt;
291 if (this->IsMaskUsed())
292 maskIt = ImageRegionIteratorWithIndex<MaskImageType>(this->m_MaskImage, outputRegionForThread);
// NOTE(review): the doubled ';' below is a stray empty statement.
298 unsigned int numberOfCoeffcients = outputPtr->GetNumberOfComponentsPerPixel();;
299 outputPixel.SetSize(numberOfCoeffcients);
300 unsigned int numberofDWIs = inputPtr->GetNumberOfComponentsPerPixel();
301 inputPixel.SetSize(numberofDWIs);
304 outputIt.GoToBegin();
305 VectorType dwiPixel(numberofDWIs), coef(numberOfCoeffcients);
306 while( !inputIt.IsAtEnd() )
// Process the voxel when no mask is used or the mask value is positive.
// NOTE(review): the second IsMaskUsed() test is redundant after the `||`.
308 if (!this->IsMaskUsed() || (this->IsMaskUsed() && maskIt.Get()>0))
310 inputPixel=inputIt.Get();
// NOTE(review): `int i` is compared against unsigned counts in both loops
// below, and -log() of a zero or negative component yields inf/NaN; no guard
// is visible here.
311 for (
int i = 0; i < numberofDWIs; i += 1 )
312 dwiPixel[i] = -std::log(inputPixel[i]);
316 selfClone->m_L2Solver->Solve();
317 coef = selfClone->m_L2Solver->Getx();
318 for (
int i = 0; i < numberOfCoeffcients; i += 1 )
319 outputPixel[i] = coef[i];
// Zero-filled output pixel; presumably the else branch for voxels that are
// skipped by the mask test -- the `else` line itself is not shown.
324 outputPixel.
Fill(0.0);
326 outputIt.Set(outputPixel);
327 progress.CompletedPixel();
// Body not visible; presumably advances maskIt -- confirm.
329 if (this->IsMaskUsed())
// Fragment of a member-function template whose only visible statement
// forwards `os` and `indent` to Superclass::PrintSelf().
// NOTE(review): the signature and any members printed by this override are on
// lines missing from this listing.
336 template<
class TInputImage,
class TOutputImage >
341 Superclass::PrintSelf(os, indent);
void PrintSelf(std::ostream &os, Indent indent) const ITK_OVERRIDE
void ComputeRegularizationWeight() ITK_OVERRIDE
utl_shared_ptr< MatrixType > MatrixPointer
std::vector< double > STDVectorType
void PrintUtlMatrix(const NDArray< T, 2 > &mat, const std::string &str="", const char *separate=" ", std::ostream &os=std::cout)
#define utlException(cond, expout)
SmartPointer< Self > Pointer
Superclass::InputImageConstPointer InputImageConstPointer
utl_shared_ptr< STDVectorType > STDVectorPointer
std::shared_ptr< NDArray< T, 2 > > ComputeSHMatrix(const unsigned int rank, const NDArray< T, 2 > &grad, const int mode)
utl_shared_ptr< VectorType > VectorPointer
GeneralizedHighOrderTensorImageFilter()
#define utlGlobalException(cond, expout)
void Fill(const T &value)
int RankToDim(const bool is_radial=false, const int radialRank=-1, const int shRank=-1) const ITK_OVERRIDE
Superclass::InputImagePixelType InputImagePixelType
Superclass::OutputImagePixelType OutputImagePixelType
Superclass::OutputImagePointer OutputImagePointer
void ThreadedGenerateData(const OutputImageRegionType &outputRegionForThread, ThreadIdType threadId) ITK_OVERRIDE
Superclass::OutputImageRegionType OutputImageRegionType
void VerifyInputParameters() const ITK_OVERRIDE
void BeforeThreadedGenerateData() ITK_OVERRIDE
#define utlShowPosition(cond)
std::vector< int > DimToRank(const int dimm) const ITK_OVERRIDE
void ComputeBasisMatrix() ITK_OVERRIDE
void ComputeRadialMatrix() ITK_OVERRIDE
double ComputeScale(const bool setScale=true) ITK_OVERRIDE
T max_element(const std::vector< T > &v)