18 #ifndef __utlVNLBlas_h 19 #define __utlVNLBlas_h 44 template <
class T>
inline void 45 MatrixCopy(
const vnl_matrix<T>& mat, vnl_matrix<T>& matOut,
const T alpha,
const char trans=
'N')
48 if (trans==
'N' || trans==
'n' || trans==
'R' || trans==
'r')
49 matOut.set_size(mat.rows(), mat.cols());
50 else if (trans==
'T' || trans==
't' || trans==
'C' || trans==
'c')
51 matOut.set_size(mat.cols(), mat.rows());
54 if ((trans==
'N' || trans==
'n') && std::fabs(alpha-1.0)<1e-10 )
55 utl::cblas_copy(mat.rows()*mat.cols(), mat.data_block(), 1, matOut.data_block(), 1);
57 utl::mkl_omatcopy<T>(
'R', trans, mat.rows(), mat.cols(), alpha, mat.data_block(), mat.cols(), matOut.data_block(), matOut.cols());
59 if (trans==
'N' || trans==
'n')
61 else if (trans==
'T' || trans==
't')
62 matOut = mat.transpose();
63 else if (trans==
'C' || trans==
'c')
64 matOut = mat.conjugate_transpose();
65 else if (trans==
'R' || trans==
'r')
68 vnl_c_vector<T>::conjugate(matOut.begin(),matOut.begin(),matOut.size());
71 if (std::fabs(alpha-1)>1e-10)
140 __utl_gemv_MatrixTimesVector(T,
gemv_VnlMatrixTimesVector, Vnl, vnl_matrix<T>, rows, cols, data_block, vnl_vector<T>, size, data_block, set_size);
161 __utl_gevm_MatrixTimesVector(T,
gemm_VnlVectorTimesMatrix, Vnl, vnl_matrix<T>, rows, cols, data_block, vnl_vector<T>, size, data_block, set_size);
191 utlSAException(v1.size() != v2.size())(v1.size())(v2.size()).msg(
"vector sizes mismatch");
192 return utl::cblas_dot<T>(v1.size(), v1.data_block(), 1, v2.data_block(), 1);
198 OuterProduct(
const vnl_vector<T>& v1,
const vnl_vector<T>& v2, vnl_matrix<T>& mat,
const double alpha=1.0)
200 int M = v1.size(), N = v2.size();
201 if (M!=mat.rows() || N!=mat.cols())
206 utl::cblas_ger<T>(
CblasRowMajor, M, N, alpha, v1.data_block(), 1, v2.data_block(), 1, mat.data_block(), mat.cols());
212 OuterProduct(
const vnl_vector<T>& v1, vnl_matrix<T>& mat,
const double alpha=1.0)
215 if (M!=mat.rows() || M!=mat.cols())
221 T* data = mat.data_block();
222 for (
int i = 0; i < M; ++i )
223 for (
int j = 0; j < i; ++j )
224 data[i*M+j] = data[j*M+i];
229 GetRow(
const vnl_matrix<T>& mat,
const int index, vnl_vector<T>& v1)
231 if (v1.size()!=mat.cols())
232 v1.set_size(mat.cols());
233 utl::cblas_copy<T>(v1.size(), mat.data_block()+mat.cols()*index, 1, v1.data_block(), 1);
238 GetColumn(
const vnl_matrix<T>& mat,
const int index, vnl_vector<T>& v1)
240 if (v1.size()!=mat.rows())
241 v1.set_size(mat.rows());
242 utl::cblas_copy<T>(v1.size(), mat.data_block()+index, mat.cols(), v1.data_block(), 1);
void cblas_copy(const INTT N, const T *X, const INTT incX, T *Y, const INTT incY)
void GetColumn(const vnl_matrix< T > &mat, const int index, vnl_vector< T > &v1)
bool gemm_VnlMatrixTimesMatrix(const bool bATrans, const bool bBTrans, const T alpha, const vnl_matrix< T > &A, const vnl_matrix< T > &B, const T beta, vnl_matrix< T > &C)
#define utlException(cond, expout)
void syrk_VnlMatrix(const bool trans, const T alpha, const vnl_matrix< T > &A, const T beta, vnl_matrix< T > &C)
syrk_VnlMatrix
void MatrixCopy(const vnl_matrix< T > &mat, vnl_matrix< T > &matOut, const T alpha, const char trans='N')
MatrixCopy. A := alpha * op(A)
#define __utl_gemv_MatrixTimesVector(T, FuncName, FuncHelperName, RowMajorMatrixName, GetRowsFuncName, GetColsFuncName, MatrixGetDataFuncName, VectorName, GetSizeFuncName, VectorGetDataFuncName, ReSizeFuncName)
#define __utl_syrk_Matrix(T, FuncName, FuncHelperName, RowMajorMatrixName, GetRowsFuncName, GetColsFuncName, GetDataFuncName, ReSizeFuncName)
macro for syrk and row-major matrix
#define utlSAException(expr)
void OuterProduct(const TVector1 &v1, const int N1, const TVector2 &v2, const int N2, TMatrix &mat)
#define __utl_gemm_MatrixTimesMatrix(T, FuncName, FuncHelperName, RowMajorMatrixName, GetRowsFuncName, GetColsFuncName, GetDataFuncName, ReSizeFuncName)
bool gemv_VnlMatrixTimesVector(const bool bATrans, const T alpha, const vnl_matrix< T > &A, const vnl_vector< T > &X, const T beta, vnl_vector< T > &Y)
double InnerProduct(const TVector1 &v1, const TVector2 &v2, const int N1)
#define __utl_gevm_MatrixTimesVector(T, FuncName, FuncHelperName, RowMajorMatrixName, GetRowsFuncName, GetColsFuncName, MatrixGetDataFuncName, VectorName, GetSizeFuncName, VectorGetDataFuncName, ReSizeFuncName)
void GetRow(const vnl_matrix< T > &mat, const int index, vnl_vector< T > &v1)
bool gemm_VnlVectorTimesMatrix(const bool bATrans, const T alpha, const vnl_vector< T > &X, const vnl_matrix< T > &A, const T beta, vnl_vector< T > &Y)