DMRITool  v0.1.1-139-g860d86b4
Diffusion MRI Tool
itkGeneralizedHighOrderTensorImageFilter.hxx
Go to the documentation of this file.
1 
18 #ifndef __itkGeneralizedHighOrderTensorImageFilter_hxx
19 #define __itkGeneralizedHighOrderTensorImageFilter_hxx
20 
22 #include "itkProgressReporter.h"
23 
24 namespace itk
25 {
// Presumably the constructor. Listing lines 27-28 are missing from this view.
// They would hold the
// GeneralizedHighOrderTensorImageFilter< TInputImage, TOutputImage >
// ::GeneralizedHighOrderTensorImageFilter() signature and any initializer list.
// The visible body is empty, so any member initialization, if present, lives in
// lines not shown here.
26 template< class TInputImage, class TOutputImage >
29 {
30 }
31 
32 template< class TInputImage, class TOutputImage >
33 double
35 ::ComputeScale(const bool setScale)
36 {
37  utlShowPosition(this->GetDebug());
38  double scale = -1;
39  STDVectorPointer bVector = this->m_SamplingSchemeQSpace->GetBVector();
40  MatrixPointer qOrientations = this->m_SamplingSchemeQSpace->GetOrientationsSpherical();
41  // utl::PrintVector(*this->m_SamplingSchemeQSpace->GetBVector(), "this->m_SamplingSchemeQSpace->GetBVector()");
42  // utl::PrintUtlMatrix(*qOrientations, "qOrientations");
43  utlGlobalException(bVector->size()==0, "no b values");
44  utlGlobalException(qOrientations->Rows()==0, "no gradients");
45  double bMax = *std::max_element(this->m_SamplingSchemeQSpace->GetBVector()->begin(), this->m_SamplingSchemeQSpace->GetBVector()->end());
46  scale = 0.5*bMax/(4*M_PI*M_PI*this->m_SamplingSchemeQSpace->GetTau());
47  if (setScale)
48  {
49  this->SetBasisScale(scale);
50  }
51  if (this->GetDebug())
52  std::cout << "m_BasisScale = " << this->m_BasisScale << std::endl;
53  return scale;
54 
55 }
56 
57 template< class TInputImage, class TOutputImage >
58 std::vector<int>
60 ::DimToRank ( const int dimm ) const
61 {
62  std::vector<int> result;
63  int radialRank=-1, shRank=-1;
64  for ( int radialRank = 1; radialRank <= 10; radialRank += 1 )
65  {
66  for ( int shRank = 0; shRank <= 12; shRank += 2 )
67  {
68  int dim = RankToDim(false, radialRank, shRank);
69  if (dim==dimm)
70  {
71  result.push_back(radialRank);
72  result.push_back(shRank);
73  return result;
74  }
75  }
76  }
77  utlException(true, "wrong logic");
78  return result;
79 }
80 
81 template< class TInputImage, class TOutputImage >
82 int
84 ::RankToDim (const bool is_radial, const int radialRank, const int shRank) const
85 {
86  int radialRank_real = radialRank>=0?radialRank:this->m_RadialRank;
87  int shRank_real = shRank>=0?shRank:this->m_SHRank;
88  if (is_radial)
89  return radialRank_real;
90  else
91  return (shRank_real + 1)*(shRank_real + 2)/2*(radialRank_real);
92 }
93 
// Presumably VerifyInputParameters(). The name is inferred from the Superclass
// call in the body; the signature lines 96-97 are missing from this listing.
// Validates the configuration before estimation:
//  - runs the base-class checks;
//  - rejects a non-positive radial rank;
//  - rejects the L1_2 / L1_DL estimation types, which this filter flags with
//    "TODO" (apparently not supported here).
// Failures are reported through utlGlobalException.
94 template< class TInputImage, class TOutputImage >
95 void
98 {
99  utlShowPosition(this->GetDebug());
100  Superclass::VerifyInputParameters();
// m_RadialRank is the number of radial basis columns, so it must be at least 1.
101  utlGlobalException(this->m_RadialRank<=0, "m_RadialRank should be no less than 1");
102  utlGlobalException(this->m_EstimationType==Superclass::L1_2, "TODO");
103  utlGlobalException(this->m_EstimationType==Superclass::L1_DL, "TODO");
104 }
105 
// Presumably ComputeRadialMatrix(). The name is inferred from the
// `this->ComputeRadialMatrix()` call elsewhere in this file; the signature
// lines 108-109 are missing from this listing.
// Builds m_BasisRadialMatrix, an n_s x m_RadialRank matrix with one row per
// b-value. Entry (js, ib) is (q_js / sqrt(m_BasisScale))^(2*(ib+1)), where q is
// recovered from b via b = 4*pi^2*tau*q^2.
// NOTE(review): m_BasisScale must already be set and positive. A zero scale
// makes the division below produce inf/NaN.
106 template< class TInputImage, class TOutputImage >
107 void
110 {
111  utlShowPosition(this->GetDebug());
112 
113  int n_s, n_b;
114  n_s = this->m_SamplingSchemeQSpace->GetBVector()->size();
115  utlGlobalException( n_s==0, "no b vector");
116 
117  if (this->GetDebug())
118  std::cout << "m_BasisScale = " << this->m_BasisScale << std::endl;
119 
120  const STDVectorPointer bVector = this->m_SamplingSchemeQSpace->GetBVector();
121 
122  // NOTE: b = 4\pi^2 * _tau * q^2, hence q = sqrt(b / (4\pi^2 * _tau))
123  STDVectorType qVector(*bVector);
// NOTE(review): int index compared against size() (signed/unsigned comparison).
124  for ( int i = 0; i < bVector->size(); i += 1 )
125  {
126  qVector[i] = std::sqrt((*bVector)[i]/(4*M_PI*M_PI*this->m_SamplingSchemeQSpace->GetTau()));
127  }
128 
129 
130  // the radial basis of rank N has N terms x^2, x^4, ..., x^{2N} (exponent 2*(ib+1)); there is no constant x^0 term
131  n_b = this->m_RadialRank;
132 
133 
134  if(this->GetDebug())
135  std::cout << "Generating the "<< n_s << "x" << n_b << " RadialMatrix...\n";
136 
137  MatrixPointer B(new MatrixType(n_s,n_b));
138 
139  for ( int js = 0; js < n_s; js += 1 )
140  {
// Normalized radial coordinate; requires m_BasisScale > 0.
141  double x_temp = qVector[js] / std::sqrt(this->m_BasisScale);
142  for ( int ib = 0; ib < n_b; ib += 1 )
143  {
144  (*B)(js,ib) = std::pow(x_temp, 2.0*(ib+1));
145  }
146  }
147 
148 
149  if(this->GetDebug())
150  {
151  std::cout << "Generated the "<< n_s << "x" << n_b << " RadialMatrix...\n";
152  utl::PrintUtlMatrix(*B,"RadialMatrix");
153  }
154 
155  this->m_BasisRadialMatrix = B;
156 
157 }
158 
// Presumably ComputeBasisMatrix(). The name is inferred from the call in the
// pre-threading setup of this file; the signature lines 161-162 are missing from
// this listing.
// Builds m_BasisMatrix (n_s x m_RadialRank*n_b_sh) as the element-wise product
// of the radial and SH bases. Column i*n_b_sh + j holds B_ra(:,i) .* B_sh(:,j),
// so columns are grouped by radial index, with the SH index varying fastest.
// The SH and radial matrices are computed lazily, only when they have no rows.
// NOTE(review): cached matrices are reused whenever they are non-empty. A change
// of m_BasisScale or of the ranks after a first computation would therefore not
// be reflected here unless the caches are cleared elsewhere. TODO: confirm.
159 template< class TInputImage, class TOutputImage >
160 void
163 {
164  utlShowPosition(this->GetDebug());
165 
166  if (this->m_BasisSHMatrix->Rows()==0)
167  this->ComputeSHMatrix();
168  MatrixPointer B_sh = this->m_BasisSHMatrix;
169 
170  if (this->m_BasisRadialMatrix->Rows()==0)
171  this->ComputeRadialMatrix();
172  MatrixPointer B_ra = this->m_BasisRadialMatrix;
173 
174  MatrixPointer qOrientations = this->m_SamplingSchemeQSpace->GetOrientationsSpherical();
175 
176  int bVector_size = this->m_SamplingSchemeQSpace->GetBVector()->size();
177  int grad_size = qOrientations->Rows();
178 
179  if (this->GetDebug())
180  {
181  std::cout << "this->m_SamplingSchemeQSpace->GetBVector()->size() = " << this->m_SamplingSchemeQSpace->GetBVector()->size() << std::endl;
182  std::cout << "qOrientations->Rows() = " << qOrientations->Rows() << std::endl;
183  std::cout << "bVector_size = " << bVector_size << std::endl;
184  std::cout << "grad_size = " << grad_size << std::endl;
185  }
// Every sample needs both a b-value and an orientation.
186  utlException(bVector_size!=grad_size, "bVector_size and grad_size should keep the same size");
187  int n_s = bVector_size;
188 
189  // the basis has two parts: m_RadialRank radial terms times (L+1)(L+2)/2 SH terms
190  int n_b_ra = this->m_RadialRank;
191  int n_b_sh = (this->m_SHRank+1)*(this->m_SHRank+2)/2;
192  int n_b = n_b_ra * n_b_sh;
193 
194  if(this->GetDebug())
195  {
196  std::cout << "n_b_ra = " << n_b_ra << std::endl;
197  std::cout << "n_b_sh = " << n_b_sh << std::endl;
198  std::cout << "n_b = " << n_b << std::endl;
199  std::cout << "Generating the "<< n_s << "x" << n_b << " basis matrix...\n";
200  }
201 
202  MatrixPointer B(new MatrixType(n_s, n_b));
// Shape checks: the two factor bases must agree with each other and with B.
203  utlException(B_sh->Columns()!=n_b_sh, "the SHMatrix does not have the right width. B_sh=" << *B_sh << ", n_b_sh="<< n_b_sh);
204  utlException(B_sh->Rows()!=B_ra->Rows(), "the SHMatrix and the RadialMatrix do not have the same samples. B_sh="<< *B_sh << ", B_ra=" << *B_ra);
205  utlException(B_sh->Rows()!=B->Rows(), "the SHMatrix and the basisMatrix do not have the same samples. B_sh="<< *B_sh << ", B="<< *B);
206 
207  utlException(B_ra->Columns()!=n_b_ra, "the RadialMatrix does not have the right width");
// Tensor-product fill: radial index i outer, SH index j inner, sample k per column.
208  for ( int i = 0; i < n_b_ra; i += 1 )
209  for ( int j = 0; j < n_b_sh; j += 1 )
210  for ( int k = 0; k < n_s; k += 1 )
211  {
212  (*B)(k,i*n_b_sh+j) = (*B_ra)(k,i) * (*B_sh)(k,j);
213  }
214 
215  if(this->GetDebug())
216  {
217  utl::PrintUtlMatrix(*B,"BasisMatrix");
218  }
219 
220  this->m_BasisMatrix = B;
221 }
222 
// Presumably ComputeRegularizationWeight(). The name is inferred from the call
// in the pre-threading setup of this file; the signature lines 225-226 are
// missing from this listing.
// Fills m_RegularizationWeight (length RankToDim()) with one weight per basis
// column. It uses the same column layout as the basis matrix (index i*n_b_sh + j):
//   weight(i, l, m) = m_LambdaSpherical * l^2 (l+1)^2 + m_LambdaRadial * i^2 (i+1)^2
// for radial index i and even SH orders l with -l <= m <= l.
// The radial term is zero for i = 0, and the spherical term is zero for l = 0.
223 template< class TInputImage, class TOutputImage >
224 void
227 {
228  // utlShowPosition(this->GetDebug());
229 
230  int n_b_sh = (this->m_SHRank+1)*(this->m_SHRank+2)/2;
// NOTE(review): n_b_ra is not used below.
231  int n_b_ra = this->m_RadialRank;
232 
233  this->m_RegularizationWeight = VectorPointer(new VectorType(this->RankToDim()));
234  for ( int i = 0; i <= this->m_RadialRank-1; i += 1 )
235  {
// j enumerates the (l, m) pairs of the SH part in the same order as the SH columns.
236  int j = 0;
237  for ( int l = 0; l <= this->m_SHRank; l += 2 )
238  {
239  for ( int m = -l; m <= l; m += 1 )
240  {
241  (*this->m_RegularizationWeight)(i*n_b_sh+j) = this->m_LambdaSpherical*l*l*(l+1)*(l+1) + this->m_LambdaRadial*i*i*(i+1)*(i+1);
242  j++;
243  }
244  }
245  }
246 }
247 
// Presumably BeforeThreadedGenerateData(). The signature lines 250-251 are
// missing from this listing.
// One-time setup before the per-thread fitting:
//  - computes and stores the basis scale (ComputeScale() with its default
//    setScale=true);
//  - builds the basis matrix;
//  - validates the parameters;
//  - creates m_L2Solver with the basis matrix as A.
// When either regularization weight is positive, a matrix whose diagonal is
// m_RegularizationWeight is handed to the solver via SetLambda.
// NOTE(review): VerifyInputParameters() runs only after ComputeScale() and
// ComputeBasisMatrix(). Invalid parameters (e.g. m_RadialRank <= 0) are
// therefore reported only after the basis has been built with them.
248 template< class TInputImage, class TOutputImage >
249 void
252 {
253  utlShowPosition(this->GetDebug());
254  ComputeScale();
255  ComputeBasisMatrix();
256  this->VerifyInputParameters();
257  this->m_L2Solver = L2SolverType::New();
258  this->m_L2Solver->SetA(this->m_BasisMatrix);
259  if (this->m_LambdaSpherical>0 || this->m_LambdaRadial>0)
260  {
261  this->ComputeRegularizationWeight();
// Square matrix sized to the weight vector; only the diagonal is set explicitly
// (presumably the rest is zero-initialized by MatrixType, TODO confirm).
262  MatrixPointer mat(new MatrixType(this->m_RegularizationWeight->Size(), this->m_RegularizationWeight->Size()));
263  mat->SetDiagonal(*this->m_RegularizationWeight);
264  this->m_L2Solver->SetLambda(mat);
265  }
266  // if (this->GetDebug())
267  // {
268  // this->m_L2Solver->Initialize();
269  // MatrixType ls = this->m_L2Solver->GetLS();
270  // utl::PrintUtlMatrix(ls,"LS");
271  // }
272 }
273 
274 template< class TInputImage, class TOutputImage >
275 void
277 ::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread,ThreadIdType threadId )
278 {
279  utlShowPosition(this->GetDebug());
280  ProgressReporter progress(this, threadId, outputRegionForThread.GetNumberOfPixels());
281  // Pointers
282  InputImageConstPointer inputPtr = this->GetInput();
283  OutputImagePointer outputPtr = this->GetOutput();
284 
285  Pointer selfClone = this->Clone();
286 
287  // iterator for the output image
288  ImageRegionIteratorWithIndex<OutputImageType> outputIt(outputPtr, outputRegionForThread );
289  ImageRegionConstIteratorWithIndex<InputImageType> inputIt(inputPtr, outputRegionForThread );
290  ImageRegionIteratorWithIndex<MaskImageType> maskIt;
291  if (this->IsMaskUsed())
292  maskIt = ImageRegionIteratorWithIndex<MaskImageType>(this->m_MaskImage, outputRegionForThread);
293 
294  InputImagePixelType inputPixel;
295  // OutputImageIndexType outputIndex;
296  OutputImagePixelType outputPixel;
297 
298  unsigned int numberOfCoeffcients = outputPtr->GetNumberOfComponentsPerPixel();;
299  outputPixel.SetSize(numberOfCoeffcients);
300  unsigned int numberofDWIs = inputPtr->GetNumberOfComponentsPerPixel();
301  inputPixel.SetSize(numberofDWIs);
302 
303  inputIt.GoToBegin();
304  outputIt.GoToBegin();
305  VectorType dwiPixel(numberofDWIs), coef(numberOfCoeffcients);
306  while( !inputIt.IsAtEnd() )
307  {
308  if (!this->IsMaskUsed() || (this->IsMaskUsed() && maskIt.Get()>0))
309  {
310  inputPixel=inputIt.Get();
311  for ( int i = 0; i < numberofDWIs; i += 1 )
312  dwiPixel[i] = -std::log(inputPixel[i]);
313  selfClone->m_L2Solver->Setb(VectorPointer(new VectorType(dwiPixel)));
314  // outputIndex=outputIt.GetIndex();
315  // std::cout << "index="<<outputIndex << std::endl << std::flush;
316  selfClone->m_L2Solver->Solve();
317  coef = selfClone->m_L2Solver->Getx();
318  for ( int i = 0; i < numberOfCoeffcients; i += 1 )
319  outputPixel[i] = coef[i];
320  // utl::PrintContainer(inputPixel.GetDataPointer(), inputPixel.GetDataPointer()+inputPixel.GetSize(), "dwi");
321  // utl::PrintContainer(outputPixel.GetDataPointer(), inputPixel.GetDataPointer()+outputPixel.GetSize(), "coef");
322  }
323  else
324  outputPixel.Fill(0.0);
325 
326  outputIt.Set(outputPixel);
327  progress.CompletedPixel();
328 
329  if (this->IsMaskUsed())
330  ++maskIt;
331  ++outputIt;
332  ++inputIt;
333  }
334 }
335 
336 template< class TInputImage, class TOutputImage >
337 void
339 ::PrintSelf(std::ostream& os, Indent indent) const
340 {
341  Superclass::PrintSelf(os, indent);
342 }
343 
344 }
345 
346 #endif
void PrintSelf(std::ostream &os, Indent indent) const ITK_OVERRIDE
void PrintUtlMatrix(const NDArray< T, 2 > &mat, const std::string &str="", const char *separate=" ", std::ostream &os=std::cout)
#define utlException(cond, expout)
Definition: utlCoreMacro.h:548
std::shared_ptr< NDArray< T, 2 > > ComputeSHMatrix(const unsigned int rank, const NDArray< T, 2 > &grad, const int mode)
Definition: utl.h:171
#define M_PI
Definition: utlCoreMacro.h:57
#define utlGlobalException(cond, expout)
Definition: utlCoreMacro.h:372
void Fill(const T &value)
Definition: utlNDArray.h:922
int RankToDim(const bool is_radial=false, const int radialRank=-1, const int shRank=-1) const ITK_OVERRIDE
void ThreadedGenerateData(const OutputImageRegionType &outputRegionForThread, ThreadIdType threadId) ITK_OVERRIDE
#define utlShowPosition(cond)
Definition: utlCoreMacro.h:554
std::vector< int > DimToRank(const int dimm) const ITK_OVERRIDE
double ComputeScale(const bool setScale=true) ITK_OVERRIDE
T max_element(const std::vector< T > &v)
Definition: utlCore.h:205