Skip to content

Commit

Permalink
...EJB
Browse files Browse the repository at this point in the history
  • Loading branch information
ebylaska committed Dec 13, 2023
1 parent 979cea4 commit c48ff25
Showing 1 changed file with 107 additions and 0 deletions.
107 changes: 107 additions & 0 deletions Nwpw/nwpwlib/device/gdevices_cuda.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,113 @@ class Gdevices {
inuse[i_st1] = false;
inuse[i_sa1] = false;
}


/**************************************
* *
* NN_zgemm *
* *
**************************************/
void NN_zgemm(int m, int n, int k,
double alpha,
double *host_a, int lda,
double *host_b, int ldb,
double beta,
double *host_c,int ldc)
{
int ia = fetch_dev_mem_indx(((size_t)lda) * ((size_t)k));
int ib = fetch_dev_mem_indx(((size_t)ldb) * ((size_t)n));
int ic = fetch_dev_mem_indx(((size_t)ldc) * ((size_t)n));

NWPW_CUBLAS_ERROR(cublasSetMatrixAsync(lda, k, sizeof(double), host_a, lda, dev_mem[ia], lda, stream[0]));
NWPW_CUBLAS_ERROR(cublasSetMatrixAsync(ldb, n, sizeof(double), host_b, ldb, dev_mem[ib], ldb, stream[0]));

NWPW_CUDA_ERROR(cudaStreamSynchronize(stream[0]));
NWPW_CUBLAS_ERROR(cublasDgemm(master_handle,matN,matN,m,n,k,
&alpha,
dev_mem[ia],lda,
dev_mem[ib],ldb,
&beta,
dev_mem[ic],ldc));
NWPW_CUBLAS_ERROR(cublasGetMatrixAsync(ldc,n,sizeof(double),dev_mem[ic],ldc,host_c,ldc,stream[0]));
NWPW_CUDA_ERROR(cudaStreamSynchronize(stream[0]));

inuse[ia] = false;
inuse[ib] = false;
inuse[ic] = false;
}


/**************************************
* *
* CN_zgemm *
* *
**************************************/
void CN_zgemm(int m, int n, int k,
double alpha,
double *host_a, int lda,
double *host_b, int ldb,
double beta,
double *host_c,int ldc)
{
int ia = fetch_dev_mem_indx(((size_t)lda) * ((size_t)k));
int ib = fetch_dev_mem_indx(((size_t)ldb) * ((size_t)n));
int ic = fetch_dev_mem_indx(((size_t)ldc) * ((size_t)n));

NWPW_CUBLAS_ERROR(cublasSetMatrixAsync(lda, k, sizeof(double), host_a, lda, dev_mem[ia], lda, stream[0]));
NWPW_CUBLAS_ERROR(cublasSetMatrixAsync(ldb, n, sizeof(double), host_b, ldb, dev_mem[ib], ldb, stream[0]));

NWPW_CUDA_ERROR(cudaStreamSynchronize(stream[0]));
NWPW_CUBLAS_ERROR(cublasDgemm(master_handle,matN,matN,m,n,k,
&alpha,
dev_mem[ia],lda,
dev_mem[ib],ldb,
&beta,
dev_mem[ic],ldc));
NWPW_CUBLAS_ERROR(cublasGetMatrixAsync(ldc,n,sizeof(double),dev_mem[ic],ldc,host_c,ldc,stream[0]));
NWPW_CUDA_ERROR(cudaStreamSynchronize(stream[0]));

inuse[ia] = false;
inuse[ib] = false;
inuse[ic] = false;
}

/**************************************
* *
* NC_zgemm *
* *
**************************************/
void NC_zgemm(int m, int n, int k,
double alpha,
double *host_a, int lda,
double *host_b, int ldb,
double beta,
double *host_c,int ldc)
{
int ia = fetch_dev_mem_indx(((size_t)lda) * ((size_t)k));
int ib = fetch_dev_mem_indx(((size_t)ldb) * ((size_t)n));
int ic = fetch_dev_mem_indx(((size_t)ldc) * ((size_t)n));

NWPW_CUBLAS_ERROR(cublasSetMatrixAsync(lda, k, sizeof(double), host_a, lda, dev_mem[ia], lda, stream[0]));
NWPW_CUBLAS_ERROR(cublasSetMatrixAsync(ldb, n, sizeof(double), host_b, ldb, dev_mem[ib], ldb, stream[0]));

NWPW_CUDA_ERROR(cudaStreamSynchronize(stream[0]));
NWPW_CUBLAS_ERROR(cublasDgemm(master_handle,matN,matN,m,n,k,
&alpha,
dev_mem[ia],lda,
dev_mem[ib],ldb,
&beta,
dev_mem[ic],ldc));
NWPW_CUBLAS_ERROR(cublasGetMatrixAsync(ldc,n,sizeof(double),dev_mem[ic],ldc,host_c,ldc,stream[0]));
NWPW_CUDA_ERROR(cudaStreamSynchronize(stream[0]));

inuse[ia] = false;
inuse[ib] = false;
inuse[ic] = false;
}




/********************/
/* psi_dev functions*/
Expand Down

0 comments on commit c48ff25

Please sign in to comment.