Skip to content

Commit

Permalink
Merge pull request #121 from abouteiller/bugfix/auxcuda-macros
Browse files Browse the repository at this point in the history
The macros to convert trans/notrans etc were not correct for use inline
  • Loading branch information
abouteiller authored Sep 6, 2024
2 parents 5fa144b + f7dac24 commit 17c6c95
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 74 deletions.
61 changes: 25 additions & 36 deletions src/dplasmaaux_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,43 +21,32 @@
*/
#if !defined(CUBLAS_H_)
#include <cublas_v2.h>
#endif /* !defined(CUBLAS_V2_H_) */

#define dplasma_cublas_side(side) \
assert( (side == dplasmaRight) || (side == dplasmaLeft) ); \
side = (side == dplasmaRight) ? CUBLAS_SIDE_RIGHT : CUBLAS_SIDE_LEFT;


#define dplasma_cublas_diag(diag) \
assert( (diag == dplasmaNonUnit) || (diag == dplasmaUnit) ); \
diag = (diag == dplasmaNonUnit) ? CUBLAS_DIAG_NON_UNIT : CUBLAS_DIAG_UNIT;

#define dplasma_cublas_fill(fill) \
assert( (fill == dplasmaLower) || (fill == dplasmaUpper) ); \
fill = (fill == dplasmaLower) ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;

#if defined(PRECISION_z) || defined(PRECISION_c)
#define dplasma_cublas_op(trans) \
assert( (trans == dplasmaNoTrans) || (trans == dplasmaTrans) || (trans == dplasmaConjTrans) ); \
switch(trans){ \
case dplasmaNoTrans: \
trans = CUBLAS_OP_N; \
break; \
case dplasmaTrans: \
trans = CUBLAS_OP_T; \
break; \
case dplasmaConjTrans: \
trans = CUBLAS_OP_C; \
break; \
default: \
trans = CUBLAS_OP_N; \
break; \
}
#include "dplasma/constants.h"

static inline cublasSideMode_t dplasma_cublas_side(int side) {
assert( (side == dplasmaRight) || (side == dplasmaLeft) );
return (side == dplasmaRight) ? CUBLAS_SIDE_RIGHT : CUBLAS_SIDE_LEFT;
}

static inline cublasDiagType_t dplasma_cublas_diag(int diag) {
assert( (diag == dplasmaNonUnit) || (diag == dplasmaUnit) );
return (diag == dplasmaNonUnit) ? CUBLAS_DIAG_NON_UNIT : CUBLAS_DIAG_UNIT;
}

static inline cublasFillMode_t dplasma_cublas_fill(int fill) {
assert( (fill == dplasmaLower) || (fill == dplasmaUpper) );
return (fill == dplasmaLower) ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
}

static inline cublasOperation_t dplasma_cublas_op(int trans) {
#if defined(PRECISION_d) || defined(PRECISION_s)
assert( (trans == dplasmaNoTrans) || (trans == dplasmaTrans) );
#else
#define dplasma_cublas_op(trans) \
assert( (trans == dplasmaNoTrans) || (trans == dplasmaTrans) ); \
trans = (trans == dplasmaNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
#endif /* PRECISION_z || PRECISION_c */
assert( (trans == dplasmaNoTrans) || (trans == dplasmaTrans) || (trans == dplasmaConjTrans) );
#endif /* PRECISION_d || PRECISION_s */
return (trans == dplasmaConjTrans) ? CUBLAS_OP_C: ((trans == dplasmaTrans) ? CUBLAS_OP_T : CUBLAS_OP_N);
}
#endif /* !defined(CUBLAS_V2_H_) */

extern parsec_info_id_t dplasma_dtd_cuda_infoid;
extern parsec_info_id_t dplasma_dtd_cuda_workspace_infoid;
Expand Down
50 changes: 20 additions & 30 deletions src/dplasmaaux_hip.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,41 +16,31 @@
#include <hipsolver/hipsolver.h>
#include <rocsolver/rocsolver.h>

#define dplasma_hipblas_side(side) \
assert( (side == dplasmaRight) || (side == dplasmaLeft) ); \
side = (side == dplasmaRight) ? HIPBLAS_SIDE_RIGHT : HIPBLAS_SIDE_LEFT;
#include "dplasma/constants.h"

static inline hipblasSideMode_t dplasma_hipblas_side(int side) {
assert( (side == dplasmaRight) || (side == dplasmaLeft) );
return (side == dplasmaRight) ? HIPBLAS_SIDE_RIGHT : HIPBLAS_SIDE_LEFT;
}

#define dplasma_hipblas_diag(diag) \
assert( (diag == dplasmaNonUnit) || (diag == dplasmaUnit) ); \
diag = (diag == dplasmaNonUnit) ? HIPBLAS_DIAG_NON_UNIT : HIPBLAS_DIAG_UNIT;
static inline hipblasDiagType_t dplasma_hipblas_diag(int diag) {
assert( (diag == dplasmaNonUnit) || (diag == dplasmaUnit) );
return (diag == dplasmaNonUnit) ? HIPBLAS_DIAG_NON_UNIT : HIPBLAS_DIAG_UNIT;
}

#define dplasma_hipblas_fill(fill) \
assert( (fill == dplasmaLower) || (fill == dplasmaUpper) ); \
fill = (fill == dplasmaLower) ? HIPBLAS_FILL_MODE_LOWER : HIPBLAS_FILL_MODE_UPPER;
static inline hipblasFillMode_t dplasma_hipblas_fill(int fill) {
assert( (fill == dplasmaLower) || (fill == dplasmaUpper) );
return (fill == dplasmaLower) ? HIPBLAS_FILL_MODE_LOWER : HIPBLAS_FILL_MODE_UPPER;
}

#if defined(PRECISION_z) || defined(PRECISION_c)
#define dplasma_hipblas_op(trans) \
assert( (trans == dplasmaNoTrans) || (trans == dplasmaTrans) || (trans == dplasmaConjTrans) ); \
switch(trans){ \
case dplasmaNoTrans: \
trans = HIPBLAS_OP_N; \
break; \
case dplasmaTrans: \
trans = HIPBLAS_OP_T; \
break; \
case dplasmaConjTrans: \
trans = HIPBLAS_OP_C; \
break; \
default: \
trans = HIPBLAS_OP_N; \
break; \
}
static inline hipblasOperation_t dplasma_hipblas_op(int trans) {
#if defined(PRECISION_d) || defined(PRECISION_s)
assert( (trans == dplasmaNoTrans) || (trans == dplasmaTrans) );
#else
#define dplasma_hipblas_op(trans) \
assert( (trans == dplasmaNoTrans) || (trans == dplasmaTrans) ); \
trans = (trans == dplasmaNoTrans) ? HIPBLAS_OP_N : HIPBLAS_OP_T;
#endif /* PRECISION_z || PRECISION_c */
assert( (trans == dplasmaNoTrans) || (trans == dplasmaTrans) || (trans == dplasmaConjTrans) );
#endif /* PRECISION_d || PRECISION_s */
return (trans == dplasmaConjTrans) ? HIPBLAS_OP_C: ((trans == dplasmaTrans) ? HIPBLAS_OP_T : HIPBLAS_OP_N);
}

extern parsec_info_id_t dplasma_dtd_hip_infoid;

Expand Down
5 changes: 1 addition & 4 deletions src/dtd_wrappers/zgemm.c
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,6 @@ parsec_core_zgemm_cuda(parsec_device_gpu_module_t* gpu_device,
double betag = beta;
#endif

dplasma_cublas_op(transA);
dplasma_cublas_op(transB);

#if defined(PARSEC_DEBUG_NOISIER)
{
char tmp[MAX_TASK_STRLEN];
Expand All @@ -80,7 +77,7 @@ parsec_core_zgemm_cuda(parsec_device_gpu_module_t* gpu_device,

parsec_cuda_exec_stream_t* cuda_stream = (parsec_cuda_exec_stream_t*)gpu_stream;
cublasSetStream( handles->cublas_handle, cuda_stream->cuda_stream );
status = cublasZgemm(handles->cublas_handle, transA, transB,
status = cublasZgemm(handles->cublas_handle, dplasma_cublas_op(transA), dplasma_cublas_op(transB),
n, m, k,
&alphag, (cuDoubleComplex*)Ag, lda,
(cuDoubleComplex*)Bg, ldb,
Expand Down
5 changes: 1 addition & 4 deletions src/dtd_wrappers/zherk.c
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,6 @@ parsec_core_zherk_cuda(parsec_device_gpu_module_t* gpu_device,
Ag = parsec_dtd_get_dev_ptr(this_task, 0);
Cg = parsec_dtd_get_dev_ptr(this_task, 1);

dplasma_cublas_op(trans);
dplasma_cublas_fill(uplo);

handles = parsec_info_get(&gpu_stream->infos, dplasma_dtd_cuda_infoid);

#if defined(PARSEC_DEBUG_NOISIER)
Expand All @@ -68,7 +65,7 @@ parsec_core_zherk_cuda(parsec_device_gpu_module_t* gpu_device,

parsec_cuda_exec_stream_t* cuda_stream = (parsec_cuda_exec_stream_t*)gpu_stream;
cublasSetStream( handles->cublas_handle, cuda_stream->cuda_stream );
status = cublasZherk(handles->cublas_handle, uplo, trans,
status = cublasZherk(handles->cublas_handle, dplasma_cublas_fill(uplo), dplasma_cublas_op(trans),
m, n,
&alpha, (cuDoubleComplex*)Ag, lda,
&beta, (cuDoubleComplex*)Cg, ldc );
Expand Down

0 comments on commit 17c6c95

Please sign in to comment.