forked from DefTruth/CUDA-Learn-Notes
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhgemm_cublas.cu
82 lines (68 loc) · 2.31 KB
/
hgemm_cublas.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#include <stdio.h>
#include <stdlib.h>
#include <float.h>
#include <algorithm>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <string>
#include <vector>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <mma.h>
#include <torch/types.h>
#include <torch/extension.h>
#include "cublas_v2.h"
// Computes C = A * B for row-major FP16 matrices with a single cuBLAS
// tensor-op GEMM (FP16 inputs, FP16 compute via CUBLAS_COMPUTE_16F).
//
// A is M x K, B is K x N and C is M x N, each row-major with leading
// dimension equal to its row length (K, N and N respectively). cuBLAS is
// column-major and reads a row-major matrix as its transpose, so the call
// computes C^T = B^T * A^T in column-major terms: B is passed as the first
// operand and the M/N extents are swapped. C is overwritten (beta = 0).
//
// A, B and C are handed straight to cuBLAS and so presumably must be device
// pointers -- TODO confirm at every call site.
//
// Throws std::runtime_error if an extent does not fit in int (cuBLAS takes
// int), or if handle creation, the zero-fill, or the GEMM reports an error.
//
// NOTE(review): cublasSetStream is never called, so work is issued on the
// default stream, and the handle is created once and is tied to the device
// that was current on the first call. Neither follows PyTorch's current
// stream/device.
void cublas_tensor_op_row_major(half *A, half *B, half *C, size_t M,
                                size_t N, size_t K) {
  // Create the handle once and reuse it. cublasCreate allocates internal
  // resources and cublasDestroy implicitly synchronizes the device, so a
  // create/destroy pair per call stalls every GEMM, and creating without
  // destroying leaks one handle per call. Function-local static
  // initialization is thread-safe since C++11 and is retried if it throws.
  static cublasHandle_t handle = [] {
    cublasHandle_t h = nullptr;
    if (cublasCreate(&h) != CUBLAS_STATUS_SUCCESS) {
      throw std::runtime_error(
          "cublas_tensor_op_row_major: cublasCreate failed");
    }
    if (cublasSetMathMode(h, CUBLAS_TENSOR_OP_MATH) != CUBLAS_STATUS_SUCCESS) {
      cublasDestroy(h);
      throw std::runtime_error(
          "cublas_tensor_op_row_major: cublasSetMathMode failed");
    }
    return h;
  }();

  // cuBLAS takes int extents and leading dimensions; refuse to truncate.
  const size_t int_max = static_cast<size_t>(std::numeric_limits<int>::max());
  if (M > int_max || N > int_max || K > int_max) {
    throw std::runtime_error(
        "cublas_tensor_op_row_major: matrix extent exceeds INT_MAX");
  }

  // Empty output: nothing to write.
  if (M == 0 || N == 0) {
    return;
  }

  // K == 0: the product is an all-zero M x N matrix, but cuBLAS rejects a
  // leading dimension of 0 for A (it requires >= max(1, K)). All-zero bytes
  // encode +0.0 in FP16. Issued on the default stream, like the GEMM below.
  if (K == 0) {
    const cudaError_t err =
        cudaMemsetAsync(C, 0, M * N * sizeof(half), /*stream=*/0);
    if (err != cudaSuccess) {
      throw std::runtime_error(
          std::string("cublas_tensor_op_row_major: cudaMemsetAsync failed: ") +
          cudaGetErrorString(err));
    }
    return;
  }

  const int m = static_cast<int>(M);
  const int n = static_cast<int>(N);
  const int k = static_cast<int>(K);

  // FP16 scalars, as required by CUBLAS_COMPUTE_16F. __float2half avoids
  // depending on an implicit double -> half conversion.
  const half alpha = __float2half(1.0f);
  const half beta = __float2half(0.0f);

  const cublasStatus_t status = cublasGemmEx(handle,
                                             CUBLAS_OP_N,
                                             CUBLAS_OP_N,
                                             n, m, k,
                                             &alpha,
                                             B, CUDA_R_16F, n,
                                             A, CUDA_R_16F, k,
                                             &beta,
                                             C, CUDA_R_16F, n,
                                             CUBLAS_COMPUTE_16F,
                                             CUBLAS_GEMM_DEFAULT_TENSOR_OP);
  if (status != CUBLAS_STATUS_SUCCESS) {
    throw std::runtime_error(
        "cublas_tensor_op_row_major: cublasGemmEx failed with status " +
        std::to_string(static_cast<int>(status)));
  }
}
// TODO: add cublas_tensor_op_col_major
// --------------------- PyTorch bindings for custom kernel -----------------------
// Stringizes its argument token-for-token, e.g. STRINGFY(foo) -> "foo".
#define STRINGFY(str) #str
// Registers `func` under its own name, using that name as the docstring.
// Expects a module object named `m` in scope at the expansion site
// (presumably the pybind11 module from PYBIND11_MODULE -- confirm at use).
#define TORCH_BINDING_COMMON_EXTENSION(func) \
  m.def(STRINGFY(func), &func, STRINGFY(func));
// Throws std::runtime_error (after printing the tensor's options to stdout)
// unless tensor T has dtype th_type. Expands to a braced compound statement,
// so it works with or without a trailing semicolon and its internal `if`
// cannot capture an `else` that follows the macro.
#define CHECK_TORCH_TENSOR_DTYPE(T, th_type)                       \
  {                                                                \
    if (((T).options().dtype() != (th_type))) {                    \
      std::cout << "Tensor Info:" << (T).options() << std::endl;   \
      throw std::runtime_error("values must be "#th_type);         \
    }                                                              \
  }
// Throws std::runtime_error unless T is exactly 2-D with sizes (S0, S1).
// The rank is checked first so that size(1) is never queried on a 1-D
// tensor, and a higher-rank tensor whose first two sizes happen to match is
// rejected instead of silently accepted. Braced for the same reason as above.
#define CHECK_TORCH_TENSOR_SHAPE(T, S0, S1)                        \
  {                                                                \
    if ((T).dim() != 2) {                                          \
      throw std::runtime_error("Tensor must be 2-D!");             \
    }                                                              \
    if (((T).size(0) != (S0)) || ((T).size(1) != (S1))) {          \
      throw std::runtime_error("Tensor size mismatch!");           \
    }                                                              \
  }
// cublas tensor op with row major B matrix
// PyTorch-facing wrapper: computes c = a @ b for FP16 tensors, writing the
// result into c in place via cublas_tensor_op_row_major.
//
// a is (M, K), b is (K, N), c is (M, N). Every tensor must be torch::kHalf,
// 2-D, on a CUDA device (all the same one), and contiguous: the GEMM receives
// raw data pointers and assumes dense row-major storage with leading
// dimensions K, N and N, so a strided view (e.g. b.t()) would otherwise be
// read with the wrong layout and silently give wrong results, and a CPU
// tensor would hand a host pointer to cuBLAS.
//
// Throws std::runtime_error on any dtype, rank, device, layout or shape
// mismatch.
//
// NOTE(review): no device guard is installed, so the tensors are presumably
// expected to live on the current CUDA device -- TODO confirm.
void hgemm_cublas_tensor_op_row_major(
  torch::Tensor a, torch::Tensor b, torch::Tensor c) {
  CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf)

  // Rank is validated before any size(1) call: on a 1-D tensor size(1)
  // fails with an unrelated index error, and a higher-rank tensor could
  // otherwise slip past the two-dimension shape checks below.
  auto require_dense_cuda_matrix = [](const torch::Tensor& t,
                                      const char* name) {
    if (t.dim() != 2) {
      throw std::runtime_error(std::string(name) + " must be a 2-D tensor");
    }
    if (!t.is_cuda()) {
      throw std::runtime_error(std::string(name) + " must be a CUDA tensor");
    }
    if (!t.is_contiguous()) {
      throw std::runtime_error(std::string(name) + " must be contiguous");
    }
  };
  require_dense_cuda_matrix(a, "a");
  require_dense_cuda_matrix(b, "b");
  require_dense_cuda_matrix(c, "c");
  if (b.device() != a.device() || c.device() != a.device()) {
    throw std::runtime_error("a, b and c must be on the same CUDA device");
  }

  // int64_t matches Tensor::size() and avoids narrowing here.
  const int64_t M = a.size(0);
  const int64_t K = a.size(1);
  const int64_t N = b.size(1);
  CHECK_TORCH_TENSOR_SHAPE(a, M, K)
  CHECK_TORCH_TENSOR_SHAPE(b, K, N)
  CHECK_TORCH_TENSOR_SHAPE(c, M, N)
  cublas_tensor_op_row_major(
    reinterpret_cast<half*>(a.data_ptr()),
    reinterpret_cast<half*>(b.data_ptr()),
    reinterpret_cast<half*>(c.data_ptr()),
    static_cast<size_t>(M), static_cast<size_t>(N), static_cast<size_t>(K)
  );
}
// TODO: add cublas_tensor_op_col_major