DMRITool  v0.1.1-139-g860d86b4
Diffusion MRI Tool
test_Spams2.cxx
Go to the documentation of this file.
1 
2 /* Software SPAMS v2.1 - Copyright 2009-2011 Julien Mairal
3  *
4  * This file is part of SPAMS.
5  *
6  * SPAMS is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * SPAMS is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU General Public License for more details.
15  *
16  * You should have received a copy of the GNU General Public License
17  * along with SPAMS. If not, see <http://www.gnu.org/licenses/>.
18  */
19 
33 #include <mexutils.h>
34 #include <decomp.h>
35 
36 template <typename T>
37  inline void callFunction(mxArray* plhs[], const mxArray*prhs[]) {
38  if (!utl::mexCheckType<T>(prhs[0]))
39  mexErrMsgTxt("type of argument 1 is not consistent");
40  if (mxIsSparse(prhs[0]))
41  mexErrMsgTxt("argument 1 should be full");
42  if (!utl::mexCheckType<T>(prhs[1]))
43  mexErrMsgTxt("type of argument 2 is not consistent");
44  if (mxIsSparse(prhs[1]))
45  mexErrMsgTxt("argument 2 should be full");
46  if (mxIsSparse(prhs[2]))
47  mexErrMsgTxt("argument 3 should be full");
48  if (!mxIsStruct(prhs[3]))
49  mexErrMsgTxt("argument 4 should be struct");
50 
51 
52  T* prX = reinterpret_cast<T*>(mxGetPr(prhs[0]));
53  const mwSize* dimsX=mxGetDimensions(prhs[0]);
54  int n=static_cast<int>(dimsX[0]);
55  int M=static_cast<int>(dimsX[1]);
56 
57  T* prD = reinterpret_cast<T*>(mxGetPr(prhs[1]));
58  const mwSize* dimsD=mxGetDimensions(prhs[1]);
59  int nD=static_cast<int>(dimsD[0]);
60  int K=static_cast<int>(dimsD[1]);
61  if (n != nD) mexErrMsgTxt("argument sizes are not consistent");
62 
63  T lambda = utl::GetScalarStruct<T>(prhs[3],"lambda");
64  int L = utl::GetScalarStructDef<int>(prhs[3],"L",K);
65  int numThreads = utl::GetScalarStructDef<int>(prhs[3],"numThreads",-1);
66  bool pos = utl::GetScalarStructDef<bool>(prhs[3],"pos",false);
67  spams::constraint_type mode = (spams::constraint_type)utl::GetScalarStructDef<int>(prhs[3],"mode",spams::PENALTY);
68  if (L > n) {
69  printf("L is changed to %d\n",n);
70  L=n;
71  }
72  if (L > K) {
73  printf("L is changed to %d\n",K);
74  L=K;
75  }
76  spams::Matrix<T> X(prX,n,M);
77  spams::Matrix<T> D(prD,n,K);
78 
79 
80  T* prWeight = reinterpret_cast<T*>(mxGetPr(prhs[2]));
81  const mwSize* dimsW=mxGetDimensions(prhs[2]);
82  int KK=static_cast<int>(dimsW[0]);
83  int MM=static_cast<int>(dimsW[1]);
84  if (K != KK || M != MM) mexErrMsgTxt("argument sizes are not consistent");
85 
86 
87  spams::Matrix<T> weight(prWeight,KK,MM);
88 
89  spams::SpMatrix<T> alpha;
90  spams::lassoWeight<T>(X,D,weight,alpha,L,lambda,mode,pos,numThreads);
91  utl::ConvertSpMatrix(plhs[0],alpha.m(),alpha.n(),alpha.n(),
92  alpha.nzmax(),alpha.v(),alpha.r(),alpha.pB());
93  }
94 
95  void mexFunction(int nlhs, mxArray *plhs[],int nrhs, const mxArray *prhs[]) {
96  if (nrhs != 4)
97  mexErrMsgTxt("Bad number of inputs arguments");
98 
99  if (!(nlhs == 1 || nlhs == 1))
100  mexErrMsgTxt("Bad number of output arguments");
101 
102  if (mxGetClassID(prhs[0]) == mxDOUBLE_CLASS) {
103  callFunction<double>(plhs,prhs);
104  } else {
105  callFunction<float>(plhs,prhs);
106  }
107  }
108 
109 
110 
111 
Sparse Matrix class.
Definition: linalg.h:63
constraint_type
Definition: decomp.h:88
void callFunction(mxArray *plhs[], const mxArray *prhs[])
Definition: test_Spams2.cxx:37
void ConvertSpMatrix(mxArray *&matlab_mat, int K, int M, int n, int nzmax, const T *v, const int *r, const int *pB)
Convert a sparse matrix to a MATLAB sparse matrix.
Definition: mexutils.h:236
int mwSize
Definition: utils.h:38
Contains miscellaneous utl functions for mex files. Some of the code is adapted from SPAMS...
Contains sparse decomposition algorithms. It requires the linalg toolbox.
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Definition: test_Spams2.cxx:95
Dense Matrix class.
Definition: linalg.h:61