From c48ff25bfc5cbe3f7a4f30d5dd93420a4e73366b Mon Sep 17 00:00:00 2001
From: eric bylaska
Date: Wed, 13 Dec 2023 12:25:20 -0800
Subject: [PATCH] ...EJB

---
 Nwpw/nwpwlib/device/gdevices_cuda.hpp | 120 ++++++++++++++++++++++++++
 1 file changed, 120 insertions(+)

diff --git a/Nwpw/nwpwlib/device/gdevices_cuda.hpp b/Nwpw/nwpwlib/device/gdevices_cuda.hpp
index 7f5ea8a6..7b7c804b 100644
--- a/Nwpw/nwpwlib/device/gdevices_cuda.hpp
+++ b/Nwpw/nwpwlib/device/gdevices_cuda.hpp
@@ -885,6 +885,126 @@ class Gdevices {
     inuse[i_st1] = false;
     inuse[i_sa1] = false;
   }
+
+  /**************************************
+   *                                    *
+   *          gemm_host_staged          *
+   *                                    *
+   **************************************/
+  /* Shared body of NN_zgemm, CN_zgemm and NC_zgemm.
+   *
+   * Computes C = alpha*op(A)*op(B) + beta*C with cublasDgemm on
+   * column-major host matrices of doubles. A and B (and C, when
+   * beta != 0) are copied into device buffers, the product is formed
+   * on the device, and ldc x n doubles of C are copied back to host_c.
+   *
+   *   opa, opb    cuBLAS operation applied to A and to B
+   *   m, n, k     op(A) is m x k, op(B) is k x n, C is m x n
+   *   alpha, beta real scalars of the update
+   *   host_a      lda x (opa == CUBLAS_OP_N ? k : m) doubles
+   *   host_b      ldb x (opb == CUBLAS_OP_N ? n : k) doubles
+   *   host_c      ldc x n doubles; uploaded only when beta != 0
+   *
+   * NOTE(review): despite the zgemm names of the callers the data is
+   * handled as real doubles, so CUBLAS_OP_C acts as a plain transpose.
+   * Confirm whether complex (cublasZgemm) arithmetic was intended.
+   * NOTE(review): the multiply is issued through master_handle; confirm
+   * that handle's stream is ordered with the stream[0] download of C.
+   */
+  void gemm_host_staged(cublasOperation_t opa, cublasOperation_t opb,
+                        int m, int n, int k,
+                        double alpha,
+                        double *host_a, int lda,
+                        double *host_b, int ldb,
+                        double beta,
+                        double *host_c, int ldc)
+  {
+    // A transposed operand is stored with its dimensions swapped, so the
+    // number of stored columns depends on the requested operation.
+    const int acols = (opa == CUBLAS_OP_N) ? k : m;
+    const int bcols = (opb == CUBLAS_OP_N) ? n : k;
+
+    int ia = fetch_dev_mem_indx(((size_t)lda) * ((size_t)acols));
+    int ib = fetch_dev_mem_indx(((size_t)ldb) * ((size_t)bcols));
+    int ic = fetch_dev_mem_indx(((size_t)ldc) * ((size_t)n));
+
+    NWPW_CUBLAS_ERROR(cublasSetMatrixAsync(lda, acols, sizeof(double), host_a, lda, dev_mem[ia], lda, stream[0]));
+    NWPW_CUBLAS_ERROR(cublasSetMatrixAsync(ldb, bcols, sizeof(double), host_b, ldb, dev_mem[ib], ldb, stream[0]));
+
+    // cuBLAS reads C whenever beta != 0, so its host values must be on the
+    // device before the multiply; with beta == 0 the upload is skipped.
+    if (beta != 0.0) {
+      NWPW_CUBLAS_ERROR(cublasSetMatrixAsync(ldc, n, sizeof(double), host_c, ldc, dev_mem[ic], ldc, stream[0]));
+    }
+
+    // Wait for the uploads on stream[0] before issuing the multiply.
+    NWPW_CUDA_ERROR(cudaStreamSynchronize(stream[0]));
+    NWPW_CUBLAS_ERROR(cublasDgemm(master_handle, opa, opb, m, n, k,
+                                  &alpha,
+                                  dev_mem[ia], lda,
+                                  dev_mem[ib], ldb,
+                                  &beta,
+                                  dev_mem[ic], ldc));
+    NWPW_CUBLAS_ERROR(cublasGetMatrixAsync(ldc, n, sizeof(double), dev_mem[ic], ldc, host_c, ldc, stream[0]));
+    NWPW_CUDA_ERROR(cudaStreamSynchronize(stream[0]));
+
+    // Clear the in-use flags of the three device buffers.
+    inuse[ia] = false;
+    inuse[ib] = false;
+    inuse[ic] = false;
+  }
+
+  /**************************************
+   *                                    *
+   *              NN_zgemm              *
+   *                                    *
+   **************************************/
+  /* C = alpha*A*B + beta*C for column-major host matrices of doubles:
+     A is m x k (lda), B is k x n (ldb), C is m x n (ldc). */
+  void NN_zgemm(int m, int n, int k,
+                double alpha,
+                double *host_a, int lda,
+                double *host_b, int ldb,
+                double beta,
+                double *host_c, int ldc)
+  {
+    gemm_host_staged(matN, matN, m, n, k, alpha, host_a, lda, host_b, ldb, beta, host_c, ldc);
+  }
+
+  /**************************************
+   *                                    *
+   *              CN_zgemm              *
+   *                                    *
+   **************************************/
+  /* C = alpha*A^H*B + beta*C (A^H is A^T for this real-valued data):
+     A is stored k x m (lda), B is k x n (ldb), C is m x n (ldc). */
+  void CN_zgemm(int m, int n, int k,
+                double alpha,
+                double *host_a, int lda,
+                double *host_b, int ldb,
+                double beta,
+                double *host_c, int ldc)
+  {
+    gemm_host_staged(CUBLAS_OP_C, matN, m, n, k, alpha, host_a, lda, host_b, ldb, beta, host_c, ldc);
+  }
+
+  /**************************************
+   *                                    *
+   *              NC_zgemm              *
+   *                                    *
+   **************************************/
+  /* C = alpha*A*B^H + beta*C (B^H is B^T for this real-valued data):
+     A is m x k (lda), B is stored n x k (ldb), C is m x n (ldc). */
+  void NC_zgemm(int m, int n, int k,
+                double alpha,
+                double *host_a, int lda,
+                double *host_b, int ldb,
+                double beta,
+                double *host_c, int ldc)
+  {
+    gemm_host_staged(matN, CUBLAS_OP_C, m, n, k, alpha, host_a, lda, host_b, ldb, beta, host_c, ldc);
+  }
+
 
   /********************/
   /* psi_dev functions*/