DMRITool  v0.1.1-139-g860d86b4
Diffusion MRI Tool
utlVNLBlas.h
Go to the documentation of this file.
1 
18 #ifndef __utlVNLBlas_h
19 #define __utlVNLBlas_h
20 
21 #include "utlCore.h"
22 #include "utlVNL.h"
23 #include "utlBlas.h"
24 
25 #ifdef UTL_USE_MKL
26 #include "utlMKL.h"
27 #endif
28 
29 namespace utl
30 {
31 
44 template <class T> inline void
45 MatrixCopy(const vnl_matrix<T>& mat, vnl_matrix<T>& matOut, const T alpha, const char trans='N')
46 {
47 #ifdef UTL_USE_MKL
48  if (trans=='N' || trans=='n' || trans=='R' || trans=='r')
49  matOut.set_size(mat.rows(), mat.cols());
50  else if (trans=='T' || trans=='t' || trans=='C' || trans=='c')
51  matOut.set_size(mat.cols(), mat.rows());
52  else
53  utlException(true, "wrong trans");
54  if ((trans=='N' || trans=='n') && std::fabs(alpha-1.0)<1e-10 )
55  utl::cblas_copy(mat.rows()*mat.cols(), mat.data_block(), 1, matOut.data_block(), 1);
56  else
57  utl::mkl_omatcopy<T>('R', trans, mat.rows(), mat.cols(), alpha, mat.data_block(), mat.cols(), matOut.data_block(), matOut.cols());
58 #else
59  if (trans=='N' || trans=='n')
60  matOut = mat;
61  else if (trans=='T' || trans=='t')
62  matOut = mat.transpose();
63  else if (trans=='C' || trans=='c')
64  matOut = mat.conjugate_transpose();
65  else if (trans=='R' || trans=='r')
66  {
67  matOut = mat;
68  vnl_c_vector<T>::conjugate(matOut.begin(),matOut.begin(),matOut.size());
69  }
70 
71  if (std::fabs(alpha-1)>1e-10)
72  matOut *= alpha;
73 #endif
74 }
75 
76 
// Generates utl::gemm_VnlMatrixTimesMatrix(bATrans, bBTrans, alpha, A, B, beta, C) for vnl_matrix<T>
// through the __utl_gemm_MatrixTimesMatrix macro from utlBlas.h. The macro is told how to query
// the matrix (rows, cols, data_block) and how to resize it (set_size).
// NOTE(review): presumably computes C := alpha*op(A)*op(B) + beta*C as in BLAS gemm -- confirm in utlBlas.h.
117 __utl_gemm_MatrixTimesMatrix(T, gemm_VnlMatrixTimesMatrix, Vnl, vnl_matrix<T>, rows, cols, data_block, set_size);
118 
119 
// Generates utl::gemv_VnlMatrixTimesVector(bATrans, alpha, A, X, beta, Y) for vnl_matrix<T> and
// vnl_vector<T> through the __utl_gemv_MatrixTimesVector macro from utlBlas.h. The macro receives
// the accessor names for the matrix (rows, cols, data_block) and the vector (size, data_block),
// plus the resize function (set_size).
// NOTE(review): presumably computes Y := alpha*op(A)*X + beta*Y as in BLAS gemv -- confirm in utlBlas.h.
140 __utl_gemv_MatrixTimesVector(T, gemv_VnlMatrixTimesVector, Vnl, vnl_matrix<T>, rows, cols, data_block, vnl_vector<T>, size, data_block, set_size);
141 
142 
// Generates utl::gemm_VnlVectorTimesMatrix(bATrans, alpha, X, A, beta, Y) for vnl_vector<T> and
// vnl_matrix<T> through the __utl_gevm_MatrixTimesVector macro from utlBlas.h.
// Note that the generated function is named "gemm_..." although the generating macro is "gevm";
// the name is kept as is because callers depend on it.
// NOTE(review): presumably computes Y := alpha*X*op(A) + beta*Y (row vector times matrix) -- confirm in utlBlas.h.
161 __utl_gevm_MatrixTimesVector(T, gemm_VnlVectorTimesMatrix, Vnl, vnl_matrix<T>, rows, cols, data_block, vnl_vector<T>, size, data_block, set_size);
162 
// Generates utl::syrk_VnlMatrix(trans, alpha, A, beta, C) for vnl_matrix<T> through the
// __utl_syrk_Matrix macro from utlBlas.h (the macro for syrk on row-major matrices).
// NOTE(review): presumably computes the symmetric rank-k update C := alpha*A*A^T + beta*C
// (or A^T*A when trans is set) as in BLAS syrk -- confirm in utlBlas.h.
185 __utl_syrk_Matrix(T, syrk_VnlMatrix, Vnl, vnl_matrix<T>, rows, cols, data_block, set_size);
186 
187 template <class T>
188 inline T
189 InnerProduct(const vnl_vector<T>& v1, const vnl_vector<T>& v2)
190 {
191  utlSAException(v1.size() != v2.size())(v1.size())(v2.size()).msg("vector sizes mismatch");
192  return utl::cblas_dot<T>(v1.size(), v1.data_block(), 1, v2.data_block(), 1);
193 }
194 
196 template <class T>
197 inline void
198 OuterProduct(const vnl_vector<T>& v1, const vnl_vector<T>& v2, vnl_matrix<T>& mat, const double alpha=1.0)
199 {
200  int M = v1.size(), N = v2.size();
201  if (M!=mat.rows() || N!=mat.cols())
202  {
203  mat.set_size(M, N);
204  mat.fill(0.0);
205  }
206  utl::cblas_ger<T>(CblasRowMajor, M, N, alpha, v1.data_block(), 1, v2.data_block(), 1, mat.data_block(), mat.cols());
207 }
208 
210 template <class T>
211 inline void
212 OuterProduct(const vnl_vector<T>& v1, vnl_matrix<T>& mat, const double alpha=1.0)
213 {
214  int M = v1.size();
215  if (M!=mat.rows() || M!=mat.cols())
216  {
217  mat.set_size(M, M);
218  mat.fill(0.0);
219  }
220  utl::cblas_syr<T>(CblasRowMajor, CblasUpper, M, alpha, v1.data_block(), 1, mat.data_block(), mat.cols());
221  T* data = mat.data_block();
222  for ( int i = 0; i < M; ++i )
223  for ( int j = 0; j < i; ++j )
224  data[i*M+j] = data[j*M+i];
225 }
226 
227 template <class T>
228 inline void
229 GetRow(const vnl_matrix<T>& mat, const int index, vnl_vector<T>& v1)
230 {
231  if (v1.size()!=mat.cols())
232  v1.set_size(mat.cols());
233  utl::cblas_copy<T>(v1.size(), mat.data_block()+mat.cols()*index, 1, v1.data_block(), 1);
234 }
235 
236 template <class T>
237 inline void
238 GetColumn(const vnl_matrix<T>& mat, const int index, vnl_vector<T>& v1)
239 {
240  if (v1.size()!=mat.rows())
241  v1.set_size(mat.rows());
242  utl::cblas_copy<T>(v1.size(), mat.data_block()+index, mat.cols(), v1.data_block(), 1);
243 }
244 
245 
248 }
249 
250 
251 #endif
252 
void cblas_copy(const INTT N, const T *X, const INTT incX, T *Y, const INTT incY)
void GetColumn(const vnl_matrix< T > &mat, const int index, vnl_vector< T > &v1)
Definition: utlVNLBlas.h:238
bool gemm_VnlMatrixTimesMatrix(const bool bATrans, const bool bBTrans, const T alpha, const vnl_matrix< T > &A, const vnl_matrix< T > &B, const T beta, vnl_matrix< T > &C)
Definition: utlVNLBlas.h:117
#define utlException(cond, expout)
Definition: utlCoreMacro.h:548
void syrk_VnlMatrix(const bool trans, const T alpha, const vnl_matrix< T > &A, const T beta, vnl_matrix< T > &C)
syrk_VnlMatrix
Definition: utlVNLBlas.h:185
void MatrixCopy(const vnl_matrix< T > &mat, vnl_matrix< T > &matOut, const T alpha, const char trans='N')
MatrixCopy. matOut := alpha * op(mat)
Definition: utlVNLBlas.h:45
#define __utl_gemv_MatrixTimesVector(T, FuncName, FuncHelperName, RowMajorMatrixName, GetRowsFuncName, GetColsFuncName, MatrixGetDataFuncName, VectorName, GetSizeFuncName, VectorGetDataFuncName, ReSizeFuncName)
Definition: utlBlas.h:551
Definition: utl.h:90
#define __utl_syrk_Matrix(T, FuncName, FuncHelperName, RowMajorMatrixName, GetRowsFuncName, GetColsFuncName, GetDataFuncName, ReSizeFuncName)
macro for syrk and row-major matrix
Definition: utlBlas.h:646
#define utlSAException(expr)
Definition: utlCoreMacro.h:543
void OuterProduct(const TVector1 &v1, const int N1, const TVector2 &v2, const int N2, TMatrix &mat)
Definition: utlMath.h:1299
#define __utl_gemm_MatrixTimesMatrix(T, FuncName, FuncHelperName, RowMajorMatrixName, GetRowsFuncName, GetColsFuncName, GetDataFuncName, ReSizeFuncName)
Definition: utlBlas.h:433
bool gemv_VnlMatrixTimesVector(const bool bATrans, const T alpha, const vnl_matrix< T > &A, const vnl_vector< T > &X, const T beta, vnl_vector< T > &Y)
Definition: utlVNLBlas.h:140
double InnerProduct(const TVector1 &v1, const TVector2 &v2, const int N1)
Definition: utlMath.h:1322
helper functions for VNL
#define __utl_gevm_MatrixTimesVector(T, FuncName, FuncHelperName, RowMajorMatrixName, GetRowsFuncName, GetColsFuncName, MatrixGetDataFuncName, VectorName, GetSizeFuncName, VectorGetDataFuncName, ReSizeFuncName)
Definition: utlBlas.h:593
void GetRow(const vnl_matrix< T > &mat, const int index, vnl_vector< T > &v1)
Definition: utlVNLBlas.h:229
bool gemm_VnlVectorTimesMatrix(const bool bATrans, const T alpha, const vnl_vector< T > &X, const vnl_matrix< T > &A, const T beta, vnl_vector< T > &Y)
Definition: utlVNLBlas.h:161