DMRITool  v0.1.1-139-g860d86b4
Diffusion MRI Tool
dicts.h
1 
2 /* Software SPAMS v2.1 - Copyright 2009-2011 Julien Mairal
3  *
4  * This file is part of SPAMS.
5  *
6  * SPAMS is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * SPAMS is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU General Public License for more details.
15  *
16  * You should have received a copy of the GNU General Public License
17  * along with SPAMS. If not, see <http://www.gnu.org/licenses/>.
18  */
19 
32 #ifndef DICTS_H
33 #define DICTS_H
34 
35 #include <decomp.h>
36 
37 namespace spams
38 {
39 
40 static char buffer_string[50];
43 
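// Parameters for dictionary learning; the constructor below sets defaults for everything except iter and lambda, which the caller must set.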
44 template <typename T> struct ParamDictLearn {
45  public:
46  ParamDictLearn() :
47  mode(PENALTY),
48  posAlpha(false),
49  modeD(L2),
50  posD(false),
51  modeParam(AUTO),
52  t0(1e-5),
53  rho(5),
54  gamma1(0),
55  mu(0),
56  lambda3(0),
57  lambda4(0),
58  lambda2(0),
59  gamma2(0),
60  approx(0.0),
61  p(1.0),
62  whiten(false),
63  expand(false),
64  isConstant(false),
65  updateConstant(true),
66  ThetaDiag(false),
67  ThetaDiagPlus(false),
68  ThetaId(false),
69  DequalsW(false),
70  weightClasses(false),
71  balanceClasses(false),
72  extend(false),
73  pattern(false),
74  stochastic(false),
75  scaleW(false),
76  batch(false),
77  verbose(true),
78  clean(true),
79  log(false),
80  updateD(true),
81  updateW(true),
82  updateTheta(true),
83  logName(NULL),
84  iter_updateD(1) { };
85  ~ParamDictLearn() { delete[](logName); };
86  int iter;
87  T lambda;
88  constraint_type mode;
89  bool posAlpha;
90  constraint_type_D modeD;
91  bool posD;
92  mode_compute modeParam;
93  T t0;
94  T rho;
95  T gamma1;
96  T mu;
97  T lambda3;
98  T lambda4;
99  T lambda2;
100  T gamma2;
101  T approx;
102  T p;
103  bool whiten;
104  bool expand;
105  bool isConstant;
106  bool updateConstant;
107  bool ThetaDiag;
108  bool ThetaDiagPlus;
109  bool ThetaId;
110  bool DequalsW;
111  bool weightClasses;
112  bool balanceClasses;
113  bool extend;
114  bool pattern;
115  bool stochastic;
116  bool scaleW;
117  bool batch;
118  bool verbose;
119  bool clean;
120  bool log;
121  bool updateD;
122  bool updateW;
123  bool updateTheta;
124  char* logName;
125  int iter_updateD;
126 };
127 
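// Online dictionary learning: the Trainer keeps the current dictionary _D together with the accumulated statistics _A (from alpha*alpha^T) and _B (from x*alpha^T), so training can be stopped and resumed.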
128 template <typename T> class Trainer {
129  public:
131  Trainer();
133  Trainer(const int k, const int batchsize = 256,
134  const int NUM_THREADS=-1);
136  Trainer(const Matrix<T>& D, const int batchsize = 256,
137  const int NUM_THREADS=-1);
139  Trainer(const Matrix<T>& A, const Matrix<T>& B, const Matrix<T>& D,
140  const int itercount, const int batchsize,
141  const int NUM_THREADS);
142 
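// Training entry points: train() implements the online/batch/stochastic algorithms; trainOffline() is a variant that keeps the coefficients of all samples in memory.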
144  void train(const Data<T>& X, const ParamDictLearn<T>& param);
145  void trainOffline(const Data<T>& X, const ParamDictLearn<T>& param);
146 
148  void train(const Data<T>& X, const vector_groups& groups,
149  const int J, const constraint_type
150  mode, const bool whiten = false, const T* param_C = NULL,
151  const int p = 1, const bool pattern = false);
152 
154  void getA(Matrix<T>& A) const { A.copy(_A);};
155  void getB(Matrix<T>& B) const { B.copy(_B);};
156  void getD(Matrix<T>& D) const { D.copy(_D);};
157  int getIter() const { return _itercount; };
158 
159  private:
161  explicit Trainer<T>(const Trainer<T>& trainer);
163  Trainer<T>& operator=(const Trainer<T>& trainer);
164 
166  void cleanDict(const Data<T>& X, Matrix<T>& G,
167  const bool posD = false,
168  const constraint_type_D modeD = L2, const T gamma1 = 0,
169  const T gamma2 = 0,
170  const T maxCorrel =
171  0.999999);
172 
174  void cleanDict(Matrix<T>& G);
175 
176  Matrix<T> _A;
177  Matrix<T> _B;
178  Matrix<T> _D;
179  int _k;
180  bool _initialDict;
181  int _itercount;
182  int _batchsize;
183  int _NUM_THREADS;
184 };
185 
187 template <typename T> Trainer<T>::Trainer() : _k(0), _initialDict(false),
188  _itercount(0), _batchsize(256) {
189  _NUM_THREADS=1;
190 #ifdef _OPENMP
191  _NUM_THREADS = MIN(MAX_THREADS,omp_get_num_procs());
192 #endif
193  _batchsize=floor(_batchsize*(_NUM_THREADS+1)/2);
194  };
195 
197 template <typename T> Trainer<T>::Trainer(const int k, const
198  int batchsize, const int NUM_THREADS) : _k(k),
199  _initialDict(false), _itercount(0),_batchsize(batchsize),
200  _NUM_THREADS(NUM_THREADS) {
201  if (_NUM_THREADS == -1) {
202  _NUM_THREADS=1;
203 #ifdef _OPENMP
204  _NUM_THREADS = MIN(MAX_THREADS,omp_get_num_procs());
205 #endif
206  }
207  };
208 
210 template <typename T> Trainer<T>::Trainer(const Matrix<T>& D,
211  const int batchsize, const int NUM_THREADS) : _k(D.n()),
212  _initialDict(true),_itercount(0),_batchsize(batchsize),
213  _NUM_THREADS(NUM_THREADS) {
214  _D.copy(D);
215  _A.resize(D.n(),D.n());
216  _B.resize(D.m(),D.n());
217  if (_NUM_THREADS == -1) {
218  _NUM_THREADS=1;
219 #ifdef _OPENMP
220  _NUM_THREADS = MIN(MAX_THREADS,omp_get_num_procs());
221 #endif
222  }
223  }
224 
226 template <typename T> Trainer<T>::Trainer(const Matrix<T>& A, const Matrix<T>&
227  B, const Matrix<T>& D, const int itercount, const int batchsize,
228  const int NUM_THREADS) : _k(D.n()),_initialDict(true),_itercount(itercount),
229  _batchsize(batchsize),
230  _NUM_THREADS(NUM_THREADS) {
231  _D.copy(D);
232  _A.copy(A);
233  _B.copy(B);
234  if (_NUM_THREADS == -1) {
235  _NUM_THREADS=1;
236 #ifdef _OPENMP
237  _NUM_THREADS = MIN(MAX_THREADS,omp_get_num_procs());
238 #endif
239  }
240  };
241 
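// Re-initializes dictionary atoms that are nearly duplicated (normalized inner product above maxCorrel) or nearly zero: the atom is replaced by a random data column, projected or normalized, and the corresponding column and row of G = D^T D are updated.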
242 template <typename T>
243 void Trainer<T>::cleanDict(const Data<T>& X, Matrix<T>& G,
244  const bool posD,
245  const constraint_type_D modeD, const T gamma1,
246  const T gamma2,
247  const T maxCorrel) {
248  int sparseD = modeD == L1L2 ? 2 : 6;
249  const int k = _D.n();
250  const int n = _D.m();
251  const int M = X.n();
252  T* const pr_G=G.rawX();
253  Vector<T> aleat(n);
254  Vector<T> col(n);
255  for (int i = 0; i<k; ++i) {
256  //pr_G[i*k+i] += 1e-10;
257  for (int j = i; j<k; ++j) {
258  if ((j > i && abs(pr_G[i*k+j])/sqrt(pr_G[i*k+i]*pr_G[j*k+j]) > maxCorrel) ||
259  (j == i && abs(pr_G[i*k+j]) < 1e-4)) {
261  const int ind = random() % M;
262  Vector<T> d, g;
263  _D.refCol(j,d);
264  X.getData(col,ind);
265  d.copy(col);
266  if (modeD != L2) {
267  aleat.copy(d);
268  aleat.sparseProject(d,T(1.0),sparseD,gamma1,gamma2,T(2.0),posD);
269  } else {
270  if (posD) d.thrsPos();
271  d.normalize();
272  }
273  G.refCol(j,g);
274  _D.multTrans(d,g);
275  for (int l = 0; l<_D.n(); ++l)
276  pr_G[l*k+j] = pr_G[j*k+l];
277  }
278  }
279  }
280 }
281 
282 
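// Adds a small ridge to the diagonal of G = D^T D for numerical stability.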
283 template <typename T>
284 void Trainer<T>::cleanDict(Matrix<T>& G) {
285  const int k = _D.n();
286  const int n = _D.m();
287  T* const pr_G=G.rawX();
288  for (int i = 0; i<k; ++i) {
289  pr_G[i*k+i] += 1e-10;
290  }
291 }
292 
293 
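// Main training routine: each iteration sparse-codes a mini-batch in parallel (LARS, OMP, or a conjugate-gradient L2 solve depending on param.mode), accumulates the statistics A and B, and then updates the dictionary (block coordinate descent on its columns, or a gradient step in stochastic mode).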
294 template <typename T>
295 void Trainer<T>::train(const Data<T>& X, const ParamDictLearn<T>& param) {
296 
297  T rho = param.rho;
298  T t0 = param.t0;
299  int sparseD = param.modeD == L1L2 ? 2 : param.modeD == L1L2MU ? 7 : 6;
300  int NUM_THREADS=init_omp(_NUM_THREADS);
301  if (param.verbose) {
302  cout << "num param iterD: " << param.iter_updateD << endl;
303  if (param.batch) {
304  cout << "Batch Mode" << endl;
305  } else if (param.stochastic) {
306  cout << "Stochastic Gradient. rho : " << rho << ", t0 : " << t0 << endl;
307  } else {
308  if (param.modeParam == AUTO) {
309  cout << "Online Dictionary Learning with no parameter " << endl;
310  } else if (param.modeParam == PARAM1) {
311  cout << "Online Dictionary Learning with parameters: " << t0 << " rho: " << rho << endl;
312  } else {
313  cout << "Online Dictionary Learning with exponential decay t0: " << t0 << " rho: " << rho << endl;
314  }
315  }
316  if (param.posD)
317  cout << "Positivity constraints on D activated" << endl;
318  if (param.posAlpha)
319  cout << "Positivity constraints on alpha activated" << endl;
320  if (param.modeD != L2) cout << "Sparse dictionaries, mode: " << param.modeD << ", gamma1: " << param.gamma1 << ", gamma2: " << param.gamma2 << endl;
321  cout << "mode Alpha " << param.mode << endl;
322  if (param.clean) cout << "Cleaning activated " << endl;
323  if (param.log && param.logName) {
324  cout << "log activated " << endl;
325  cerr << param.logName << endl;
326  }
327  if (param.mode == PENALTY && param.lambda==0 && param.lambda2 > 0 && !param.posAlpha)
328  cout << "L2 solver is used" << endl;
329  if (_itercount > 0)
330  cout << "Retraining from iteration " << _itercount << endl;
331  flush(cout);
332  }
333 
334  const int M = X.n();
335  const int K = _k;
336  const int n = X.m();
337  const int L = param.mode == SPARSITY ? static_cast<int>(param.lambda) :
338  param.mode == PENALTY && param.lambda == 0 && param.lambda2 > 0 && !param.posAlpha ? K : MIN(n,K);
339  const int batchsize= param.batch ? M : MIN(_batchsize,M);
340 
341  if (param.verbose) {
342  cout << "batch size: " << batchsize << endl;
343  cout << "L: " << L << endl;
344  cout << "lambda: " << param.lambda << endl;
345  cout << "mode: " << param.mode << endl;
346  flush(cout);
347  }
348 
349  if (_D.m() != n || _D.n() != K)
350  _initialDict=false;
351 
352  srandom(0);
353  Vector<T> col(n);
354  if (!_initialDict) {
355  _D.resize(n,K);
356  for (int i = 0; i<K; ++i) {
357  const int ind = random() % M;
358  Vector<T> d;
359  _D.refCol(i,d);
360  X.getData(col,ind);
361  d.copy(col);
362  }
363  _initialDict=true;
364  }
365 
366  if (param.verbose) {
367  cout << "*****Online Dictionary Learning*****" << endl;
368  flush(cout);
369  }
370 
371  Vector<T> tmp(n);
372  if (param.modeD != L2) {
373  for (int i = 0; i<K; ++i) {
374  Vector<T> d;
375  _D.refCol(i,d);
376  tmp.copy(d);
377  tmp.sparseProject(d,T(1.0),sparseD,param.gamma1,
378  param.gamma2,T(2.0),param.posD);
379  }
380  } else {
381  if (param.posD) _D.thrsPos();
382  _D.normalize();
383  }
384 
385  int count=0;
386  int countPrev=0;
387  T scalt0 = abs<T>(t0);
388  if (_itercount == 0) {
389  _A.resize(K,K);
390  _A.setZeros();
391  _B.resize(n,K);
392  _B.setZeros();
393  if (!param.batch) {
394  _A.setDiag(scalt0);
395  _B.copy(_D);
396  _B.scal(scalt0);
397  }
398  }
399 
400  //Matrix<T> G(K,K);
401 
402  Matrix<T> Borig(n,K);
403  Matrix<T> Aorig(K,K);
404  Matrix<T> Bodd(n,K);
405  Matrix<T> Aodd(K,K);
406  Matrix<T> Beven(n,K);
407  Matrix<T> Aeven(K,K);
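// Per-thread buffers for the parallel sparse-coding step below.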
408  SpVector<T>* spcoeffT=new SpVector<T>[_NUM_THREADS];
409  Vector<T>* DtRT=new Vector<T>[_NUM_THREADS];
410  Vector<T>* XT=new Vector<T>[_NUM_THREADS];
411  Matrix<T>* BT=new Matrix<T>[_NUM_THREADS];
412  Matrix<T>* AT=new Matrix<T>[_NUM_THREADS];
413  Matrix<T>* GsT=new Matrix<T>[_NUM_THREADS];
414  Matrix<T>* GaT=new Matrix<T>[_NUM_THREADS];
415  Matrix<T>* invGsT=new Matrix<T>[_NUM_THREADS];
416  Matrix<T>* workT=new Matrix<T>[_NUM_THREADS];
417  Vector<T>* uT=new Vector<T>[_NUM_THREADS];
418  for (int i = 0; i<_NUM_THREADS; ++i) {
419  spcoeffT[i].resize(K);
420  DtRT[i].resize(K);
421  XT[i].resize(n);
422  BT[i].resize(n,K);
423  BT[i].setZeros();
424  AT[i].resize(K,K);
425  AT[i].setZeros();
426  GsT[i].resize(L,L);
427  GsT[i].setZeros();
428  invGsT[i].resize(L,L);
429  invGsT[i].setZeros();
430  GaT[i].resize(K,L);
431  GaT[i].setZeros();
432  workT[i].resize(K,3);
433  workT[i].setZeros();
434  uT[i].resize(L);
435  uT[i].setZeros();
436  }
437 
438  Timer time, time2;
439  time.start();
440  srandom(0);
441  Vector<int> perm;
442  perm.randperm(M);
443 
444  Aodd.setZeros();
445  Bodd.setZeros();
446  Aeven.setZeros();
447  Beven.setZeros();
448  Aorig.copy(_A);
449  Borig.copy(_B);
450 
451  int JJ = param.iter < 0 ? 100000000 : param.iter;
452  bool even=true;
453  int last_written=-40;
454  int i;
455  for (i = 0; i<JJ; ++i) {
456  if (param.verbose && i%100==0) {
457  cout << "Iteration: " << i << endl;
458  flush(cout);
459  }
460  time.stop();
461  if (param.iter < 0 &&
462  time.getElapsed() > T(-param.iter)) break;
463  if (param.log) {
464  int seconds=static_cast<int>(floor(log(time.getElapsed())*5));
465  if (seconds > last_written) {
466  last_written++;
467  sprintf(buffer_string,"%s_%d.log",param.logName,
468  last_written+40);
469  writeLog(_D,T(time.getElapsed()),i,buffer_string);
470  fprintf(stderr,"\r%d",i);
471  }
472  }
473  time.start();
474 
475  Matrix<T> G;
476  _D.XtX(G);
477  if (param.clean)
478  this->cleanDict(X,G,param.posD,
479  param.modeD,param.gamma1,param.gamma2);
480  G.addDiag(MAX(param.lambda2,1e-10));
481  int j;
482  for (j = 0; j<_NUM_THREADS; ++j) {
483  AT[j].setZeros();
484  BT[j].setZeros();
485  }
486 
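// Sparse-code the current mini-batch in parallel; each thread accumulates its own contribution to A and B in AT/BT.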
487 #pragma omp parallel for private(j)
488  for (j = 0; j<batchsize; ++j) {
489 #ifdef _OPENMP
490  int numT=omp_get_thread_num();
491 #else
492  int numT=0;
493 #endif
494  const int index=perm[(j+i*batchsize) % M];
495  Vector<T>& Xj = XT[numT];
496  SpVector<T>& spcoeffj = spcoeffT[numT];
497  Vector<T>& DtRj = DtRT[numT];
498  //X.refCol(index,Xj);
499  X.getData(Xj,index);
500  if (param.whiten) {
501  if (param.pattern) {
502  Vector<T> mean(4);
503  Xj.whiten(mean,param.pattern);
504  } else {
505  Xj.whiten(X.V());
506  }
507  }
508  _D.multTrans(Xj,DtRj);
509  Matrix<T>& Gs = GsT[numT];
510  Matrix<T>& Ga = GaT[numT];
511  Matrix<T>& invGs = invGsT[numT];
512  Matrix<T>& work= workT[numT];
513  Vector<T>& u = uT[numT];
514  Vector<int> ind;
515  Vector<T> coeffs_sparse;
516  spcoeffj.setL(L);
517  spcoeffj.refIndices(ind);
518  spcoeffj.refVal(coeffs_sparse);
519  T normX=Xj.nrm2sq();
520  coeffs_sparse.setZeros();
521  if (param.mode < SPARSITY) {
522  if (param.mode == PENALTY && param.lambda==0 && param.lambda2 > 0 && !param.posAlpha) {
523  Matrix<T>& GG = G;
524  u.set(0);
525  GG.conjugateGradient(DtRj,u,1e-4,2*K);
526  for (int k = 0; k<K; ++k) {
527  ind[k]=k;
528  coeffs_sparse[k]=u[k];
529  }
530  } else {
531  coreLARS2(DtRj,G,Gs,Ga,invGs,u,coeffs_sparse,ind,work,normX,param.mode,param.lambda,param.posAlpha);
532  }
533  } else {
534  if (param.mode == SPARSITY) {
535  coreORMPB(DtRj,G,ind,coeffs_sparse,normX,L,T(0.0),T(0.0));
536  } else if (param.mode==L2ERROR2) {
537  coreORMPB(DtRj,G,ind,coeffs_sparse,normX,L,param.lambda,T(0.0));
538  } else {
539  coreORMPB(DtRj,G,ind,coeffs_sparse,normX,L,T(0.0),param.lambda);
540  }
541  }
542  int count2=0;
543  for (int k = 0; k<L; ++k)
544  if (ind[k] == -1) {
545  break;
546  } else {
547  ++count2;
548  }
549  sort(ind.rawX(),coeffs_sparse.rawX(),0,count2-1);
550  spcoeffj.setL(count2);
551  AT[numT].rank1Update(spcoeffj);
552  BT[numT].rank1Update(Xj,spcoeffj);
553  }
554 
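// Dictionary update: batch mode aggregates the statistics of the whole pass, stochastic mode takes a projected gradient step on D, and the default online mode rescales past statistics with a forgetting factor before the column-wise update of D.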
555  if (param.batch) {
556  _A.setZeros();
557  _B.setZeros();
558  for (j = 0; j<_NUM_THREADS; ++j) {
559  _A.add(AT[j]);
560  _B.add(BT[j]);
561  }
562  Vector<T> di, ai,bi;
563  Vector<T> newd(n);
564  for (j = 0; j<param.iter_updateD; ++j) {
565  for (int k = 0; k<K; ++k) {
566  if (_A[k*K+k] > 1e-6) {
567  _D.refCol(k,di);
568  _A.refCol(k,ai);
569  _B.refCol(k,bi);
570  _D.mult(ai,newd,T(-1.0));
571  newd.add(bi);
572  newd.scal(T(1.0)/_A[k*K+k]);
573  newd.add(di);
574  if (param.modeD != L2) {
575  newd.sparseProject(di,T(1.0),
576  sparseD,param.gamma1,
577  param.gamma2,T(2.0),param.posD);
578  } else {
579  if (param.posD) newd.thrsPos();
580  newd.normalize2();
581  di.copy(newd);
582  }
583  } else if (param.clean) {
584  _D.refCol(k,di);
585  di.setZeros();
586  }
587  }
588  }
589  } else if (param.stochastic) {
590  _A.setZeros();
591  _B.setZeros();
592  for (j = 0; j<_NUM_THREADS; ++j) {
593  _A.add(AT[j]);
594  _B.add(BT[j]);
595  }
596  _D.mult(_A,_B,false,false,T(-1.0),T(1.0));
597  T step_grad=rho/T(t0+batchsize*(i+1));
598  _D.add(_B,step_grad);
599  Vector<T> dj;
600  Vector<T> dnew(n);
601  if (param.modeD != L2) {
602  for (j = 0; j<K; ++j) {
603  _D.refCol(j,dj);
604  dnew.copy(dj);
605  dnew.sparseProject(dj,T(1.0),sparseD,param.gamma1,
606  param.gamma2,T(2.0),param.posD);
607  }
608  } else {
609  for (j = 0; j<K; ++j) {
610  _D.refCol(j,dj);
611  if (param.posD) dj.thrsPos();
612  dj.normalize2();
613  }
614  }
615  } else {
616 
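// Default online mode: statistics of the current and previous pass over the data are kept separately (Aeven/Beven, Aodd/Bodd) and downweighted by the factor scal before the dictionary update.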
619  int epoch = (((i+1) % M)*batchsize) / M;
620  if ((even && ((epoch % 2) == 1)) || (!even && ((epoch % 2) == 0))) {
621  Aodd.copy(Aeven);
622  Bodd.copy(Beven);
623  Aeven.setZeros();
624  Beven.setZeros();
625  count=countPrev;
626  countPrev=0;
627  even=!even;
628  }
629 
630  int ii=_itercount+i;
631  int num_elem=MIN(2*M, ii < batchsize ? ii*batchsize :
632  batchsize*batchsize+ii-batchsize);
633  T scal2=T(T(1.0)/batchsize);
634  T scal;
635  int totaliter=_itercount+count;
636  if (param.modeParam == PARAM2) {
637  scal=param.rho;
638  } else if (param.modeParam == PARAM1) {
639  scal=MAX(0.95,pow(T(totaliter)/T(totaliter+1),-rho));
640  } else {
641  scal = T(_itercount+num_elem+1-
642  batchsize)/T(_itercount+num_elem+1);
643  }
644  Aeven.scal(scal);
645  Beven.scal(scal);
646  Aodd.scal(scal);
647  Bodd.scal(scal);
648  if ((_itercount > 0 && i*batchsize < M)
649  || (_itercount == 0 && t0 != 0 &&
650  i*batchsize < 10000)) {
651  Aorig.scal(scal);
652  Borig.scal(scal);
653  _A.copy(Aorig);
654  _B.copy(Borig);
655  } else {
656  _A.setZeros();
657  _B.setZeros();
658  }
659  for (j = 0; j<_NUM_THREADS; ++j) {
660  Aeven.add(AT[j],scal2);
661  Beven.add(BT[j],scal2);
662  }
663  _A.add(Aodd);
664  _A.add(Aeven);
665  _B.add(Bodd);
666  _B.add(Beven);
667  ++count;
668  ++countPrev;
669 
670  Vector<T> di, ai,bi;
671  Vector<T> newd(n);
672  for (j = 0; j<param.iter_updateD; ++j) {
673  for (int k = 0; k<K; ++k) {
674  if (_A[k*K+k] > 1e-6) {
675  _D.refCol(k,di);
676  _A.refCol(k,ai);
677  _B.refCol(k,bi);
678  _D.mult(ai,newd,T(-1.0));
679  newd.add(bi);
680  newd.scal(T(1.0)/_A[k*K+k]);
681  newd.add(di);
682  if (param.modeD != L2) {
683  newd.sparseProject(di,T(1.0),sparseD,
684  param.gamma1,param.gamma2,T(2.0),param.posD);
685  } else {
686  if (param.posD) newd.thrsPos();
687  newd.normalize2();
688  di.copy(newd);
689  }
690  } else if (param.clean &&
691  ((_itercount+i)*batchsize) > 10000) {
692  _D.refCol(k,di);
693  di.setZeros();
694  }
695  }
696  }
697  }
698  }
699 
700  _itercount += i;
701  if (param.verbose)
702  time.printElapsed();
703  delete[](spcoeffT);
704  delete[](DtRT);
705  delete[](AT);
706  delete[](BT);
707  delete[](GsT);
708  delete[](invGsT);
709  delete[](GaT);
710  delete[](uT);
711  delete[](XT);
712  delete[](workT);
713 };
714 
715 
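// Writes the elapsed time, the iteration index and the current dictionary to a text file (used when param.log is set).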
716 template <typename T>
717 void writeLog(const Matrix<T>& D, const T time, int iter,
718  char* name) {
719  std::ofstream f;
720  f.precision(12);
721  f.flags(std::ios_base::scientific);
722  f.open(name, ofstream::trunc);
723  f << time << " " << iter << std::endl;
724  for (int i = 0; i<D.n(); ++i) {
725  for (int j = 0; j<D.m(); ++j) {
726  f << D[i*D.m()+j] << " ";
727  }
728  f << std::endl;
729  }
730  f << std::endl;
731  f.close();
732 };
733 
734 
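// Offline variant: the coefficients of all samples are kept in the matrix coeffs and refined with ISTA (coreIST / coreISTconstrained), warm-started from their previous values, instead of being recomputed from scratch.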
735 template <typename T>
736 void Trainer<T>::trainOffline(const Data<T>& X,
737  const ParamDictLearn<T>& param) {
738 
739  int sparseD = param.modeD == L1L2 ? 2 : 6;
740  int J = param.iter;
741  int batch_size= _batchsize;
742  int batchsize= _batchsize;
743  int NUM_THREADS=init_omp(_NUM_THREADS);
744 
745  const int n = X.m();
746  const int K = _k;
747  const int M = X.n();
748  cout << "*****Offline Dictionary Learning*****" << endl;
749  fprintf(stderr,"num param iterD: %d\n",param.iter_updateD);
750  cout << "batch size: " << _batchsize << endl;
751  cout << "lambda: " << param.lambda << endl;
752  cout << "X: " << n << " x " << M << endl;
753  cout << "D: " << n << " x " << K << endl;
754  flush(cout);
755 
756  srandom(0);
757  Vector<T> col(n);
758  if (!_initialDict) {
759  _D.resize(n,K);
760  for (int i = 0; i<K; ++i) {
761  const int ind = random() % M;
762  Vector<T> d;
763  _D.refCol(i,d);
764  X.getData(col,ind);
765  d.copy(col);
766  }
767  _initialDict=true;
768  }
769 
770  Vector<T> tmp(n);
771  if (param.modeD != L2) {
772  for (int i = 0; i<K; ++i) {
773  Vector<T> d;
774  _D.refCol(i,d);
775  tmp.copy(d);
776  tmp.sparseProject(d,T(1.0),sparseD,param.gamma1,
777  param.gamma2,T(2.0),param.posD);
778  }
779  } else {
780  if (param.posD) _D.thrsPos();
781  _D.normalize();
782  }
783 
784  Matrix<T> G(K,K);
785  Matrix<T> coeffs(K,M);
786  coeffs.setZeros();
787 
788  Matrix<T> B(n,K);
789  Matrix<T> A(K,K);
790 
791  SpVector<T>* spcoeffT=new SpVector<T>[NUM_THREADS];
792  Vector<T>* DtRT=new Vector<T>[NUM_THREADS];
793  Vector<T>* coeffsoldT=new Vector<T>[NUM_THREADS];
794  Matrix<T>* BT=new Matrix<T>[NUM_THREADS];
795  Matrix<T>* AT=new Matrix<T>[NUM_THREADS];
796  for (int i = 0; i<NUM_THREADS; ++i) {
797  spcoeffT[i].resize(K);
798  DtRT[i].resize(K);
799  coeffsoldT[i].resize(K);
800  BT[i].resize(n,K);
801  BT[i].setZeros();
802  AT[i].resize(K,K);
803  AT[i].setZeros();
804  }
805 
806  Timer time;
807  time.start();
808  srandom(0);
809  Vector<int> perm;
810  perm.randperm(M);
811  int JJ = J < 0 ? 100000000 : J;
812  Vector<T> weights(M);
813  weights.setZeros();
814 
815  for (int i = 0; i<JJ; ++i) {
816  if (J < 0 && time.getElapsed() > T(-J)) break;
817  _D.XtX(G);
818  if (param.clean)
819  this->cleanDict(X,G,param.posD,
820  param.modeD,param.gamma1,param.gamma2);
821  int j;
822 #pragma omp parallel for private(j)
823  for (j = 0; j<batch_size; ++j) {
824 #ifdef _OPENMP
825  int numT=omp_get_thread_num();
826 #else
827  int numT=0;
828 #endif
829  const int ind=perm[(j+i*batch_size) % M];
830  Vector<T> Xj, coeffj;
831  SpVector<T>& spcoeffj = spcoeffT[numT];
832  Vector<T>& DtRj = DtRT[numT];
833  Vector<T>& oldcoeffj = coeffsoldT[numT];
834  X.getData(Xj,ind);
835  if (param.whiten) {
836  if (param.pattern) {
837  Vector<T> mean(4);
838  Xj.whiten(mean,param.pattern);
839  } else {
840  Xj.whiten(X.V());
841  }
842  }
843  coeffs.refCol(ind,coeffj);
844  oldcoeffj.copy(coeffj);
845  _D.multTrans(Xj,DtRj);
846  coeffj.toSparse(spcoeffj);
847  G.mult(spcoeffj,DtRj,T(-1.0),T(1.0));
848  if (param.mode == PENALTY) {
849  coreIST(G,DtRj,coeffj,param.lambda,200,T(1e-3));
850  } else {
851  T normX = Xj.nrm2sq();
852  coreISTconstrained(G,DtRj,coeffj,normX,param.lambda,200,T(1e-3));
853  }
854  oldcoeffj.toSparse(spcoeffj);
855  AT[numT].rank1Update(spcoeffj,-weights[ind]);
856  coeffj.toSparse(spcoeffj);
857  AT[numT].rank1Update(spcoeffj);
858  weights[ind]++;
859  oldcoeffj.scal(weights[ind]);
860  oldcoeffj.sub(coeffj);
861  oldcoeffj.toSparse(spcoeffj);
862  BT[numT].rank1Update(Xj,spcoeffj,T(-1.0));
863  }
864 
865  A.setZeros();
866  B.setZeros();
867  T scal;
868  int totaliter=i;
869  int ii = i;
870  int num_elem=MIN(2*M, ii < batchsize ? ii*batchsize :
871  batchsize*batchsize+ii-batchsize);
872  if (param.modeParam == PARAM2) {
873  scal=param.rho;
874  } else if (param.modeParam == PARAM1) {
875  scal=MAX(0.95,pow(T(totaliter)/T(totaliter+1),-param.rho));
876  } else {
877  scal = T(num_elem+1-
878  batchsize)/T(num_elem+1);
879  }
880  for (j = 0; j<NUM_THREADS; ++j) {
881  A.add(AT[j]);
882  B.add(BT[j]);
883  AT[j].scal(scal);
884  BT[j].scal(scal);
885  }
886  weights.scal(scal);
887  Vector<T> di, ai,bi;
888  Vector<T> newd(n);
889  for (j = 0; j<param.iter_updateD; ++j) {
890  for (int k = 0; k<K; ++k) {
891  if (A[k*K+k] > 1e-6) {
892  _D.refCol(k,di);
893  A.refCol(k,ai);
894  B.refCol(k,bi);
895  _D.mult(ai,newd,T(-1.0));
896  newd.add(bi);
897  newd.scal(T(1.0)/A[k*K+k]);
898  newd.add(di);
899  if (param.modeD != L2) {
900  newd.sparseProject(di,T(1.0),
901  sparseD,param.gamma1,
902  param.gamma2,T(2.0),param.posD);
903  } else {
904  if (param.posD) newd.thrsPos();
905  newd.normalize2();
906  di.copy(newd);
907  }
908  } else if (param.clean) {
909  _D.refCol(k,di);
910  di.setZeros();
911  }
912  }
913  }
914 
915  if (param.verbose)
916  {
917  Vector<T> l1norms;
918  coeffs.norm_l1_rows(l1norms);
919  double sumL1=0.0;
920  for ( int iii = 0; iii < l1norms.n(); iii += 1 )
921  sumL1 += l1norms[iii];
922  std::cout << "i = "<< i << ", coeffs sumL1 = " << sumL1 << std::endl << std::flush;
923  }
924 
925  }
926  _D.XtX(G);
927  if (param.clean)
928  this->cleanDict(X,G,param.posD,param.modeD,
929  param.gamma1,param.gamma2);
930  time.printElapsed();
931  delete[](spcoeffT);
932  delete[](DtRT);
933  delete[](AT);
934  delete[](BT);
935  delete[](coeffsoldT);
936 }
937 
938 
939 }
940 
941 #endif
942 