forked from DefTruth/CUDA-Learn-Notes
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhgemm.py
221 lines (204 loc) · 11 KB
/
hgemm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
import argparse
import os
import time
from functools import partial
from typing import Optional

import torch
from torch.utils.cpp_extension import load
# This script only runs forward GEMMs, so turn off autograd globally; torch ops
# such as the torch.matmul baseline then do not record a backward graph.
torch.set_grad_enabled(False)
def get_args(argv: Optional[list] = None) -> argparse.Namespace:
    """Build the hgemm benchmark command line and parse it.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``; passing an explicit list allows the parser
            to be used programmatically (e.g. from tests).

    Returns:
        The parsed ``argparse.Namespace``. ``M``/``N``/``K`` are ``None`` when
        not supplied; every ``--enable-*`` / ``--disable-*`` / ``--show-all``
        flag is a ``store_true`` switch that defaults to ``False``.
    """
    parser = argparse.ArgumentParser(description="hgemm benchmark")
    # Problem size overrides (None means "not given on the command line").
    parser.add_argument("--M", type=int, default=None, help="Matrix M size")
    parser.add_argument("--N", type=int, default=None, help="Matrix N size")
    parser.add_argument("--K", type=int, default=None, help="Matrix K size")
    # Benchmark loop control.
    parser.add_argument("--warmup", "--w", type=int, default=5, help="Warmup iters")
    parser.add_argument("--iters", "--i", type=int, default=20, help="Benchmark iters")
    parser.add_argument("--show-all", "--show", action="store_true", help="Show all matrix values ")
    # Kernel-family selection switches.
    parser.add_argument("--enable-mma", "--mma", action="store_true", help="Enable MMA kernel tests")
    parser.add_argument("--enable-wmma", "--wmma", action="store_true", help="Enable WMMA kernel tests")
    parser.add_argument("--enable-cuda", "--cuda", action="store_true", help="Enable CUDA kernel tests")
    parser.add_argument("--enable-mma-all", "--mma-all", action="store_true", help="Enable all MMA kernel tests")
    parser.add_argument("--enable-wmma-all", "--wmma-all", action="store_true", help="Enable all WMMA kernel tests")
    parser.add_argument("--enable-cuda-all", "--cuda-all", action="store_true", help="Enable all CUDA kernel tests")
    parser.add_argument("--enable-torch", "--torch", action="store_true", help="Enable torch matmul")
    parser.add_argument("--disable-cublas", "--no-cublas", action="store_true", help="Disable cublas hgemm")
    return parser.parse_args(argv)
args = get_args()
print(args)
# Load the CUDA kernels as a python module (torch JIT-compiles the sources).
print("Loading hgemm lib ...")
# Resolve the .cu sources relative to this script instead of the current
# working directory, so the benchmark can be launched from any directory.
# NOTE(review): assumes the .cu files live next to hgemm.py — confirm.
_HGEMM_SRC_DIR = os.path.dirname(os.path.abspath(__file__))
lib = load(name='hgemm_lib',
           sources=[os.path.join(_HGEMM_SRC_DIR, src) for src in
                    ('hgemm.cu', 'hgemm_async.cu', 'hgemm_wmma.cu',
                     'hgemm_wmma_stage.cu', 'hgemm_cublas.cu',
                     'hgemm_mma.cu', 'hgemm_mma_stage.cu')],
           extra_cuda_cflags=[
               "-O3",
               "-U__CUDA_NO_HALF_OPERATORS__",
               "-U__CUDA_NO_HALF_CONVERSIONS__",
               "-U__CUDA_NO_HALF2_OPERATORS__",
               "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
               "--expt-relaxed-constexpr",
               "--expt-extended-lambda",
               "--use_fast_math"
           ],
           extra_cflags=['-std=c++17'],
           verbose=False)
# Best TFLOPS seen so far; run_benchmark reports each new best as "+x%" over it.
MAX_TFLOPS = -1


def _launch(perf_func, a, b, out, stages, swizzle, swizzle_stride):
    """Call ``perf_func`` once with the calling convention implied by the arguments.

    - ``out is None``: ``perf_func(a, b)``; its return value is the result.
    - ``stages > 1``: ``perf_func(a, b, out, stages, swizzle, swizzle_stride)``.
    - otherwise: ``perf_func(a, b, out)``.

    Returns the tensor holding the result (``out`` itself when it was given).
    """
    if out is None:
        return perf_func(a, b)
    if stages > 1:
        perf_func(a, b, out, stages, swizzle, swizzle_stride)
    else:
        perf_func(a, b, out)
    return out


def run_benchmark(perf_func: callable,
                  a: torch.Tensor, b: torch.Tensor,
                  tag: str, out: Optional[torch.Tensor] = None,
                  stages: int = -1, swizzle: bool = False,
                  swizzle_stride: int = 1,
                  warmup: int = args.warmup,
                  iters: int = args.iters,
                  show_all: bool = args.show_all):
    """Time ``perf_func`` computing ``a @ b`` and print one result line.

    Args:
        perf_func: Kernel entry point. Called as ``perf_func(a, b)`` when
            ``out`` is None (must return the result), as
            ``perf_func(a, b, out, stages, swizzle, swizzle_stride)`` when
            ``stages > 1``, otherwise as ``perf_func(a, b, out)``.
        a: Left operand; ``a.size(0)`` is M and ``a.size(1)`` is K.
        b: Right operand; ``b.size(1)`` is N.
        tag: Label printed in front of the result.
        out: Optional preallocated output, zero-filled before warmup.
        stages: Pipeline stage count forwarded when ``> 1``.
        swizzle: Request thread-block swizzle; the stride is derived from N
            and swizzle is turned off when the derived stride is < 256.
        swizzle_stride: Ignored on input; always recomputed here.
        warmup: Untimed warmup calls.
        iters: Timed calls. For large shapes this is capped at 10, and it is
            always at least 1.
        show_all: Also print the whole output tensor.

    Returns:
        ``(out, mean_time)`` where ``mean_time`` is the per-call time in ms,
        formatted as a string truncated to 8 characters.
    """
    global MAX_TFLOPS
    M = a.size(0)
    K = a.size(1)
    N = b.size(1)
    # Large problems are slow: cap (rather than force) the timed iterations, so
    # a smaller --iters request is still honoured.
    if M > 1024 or K >= 1024 or N > 1024:
        iters = min(iters, 10)
    # At least one timed call is needed to produce a result and a mean time.
    iters = max(iters, 1)
    if swizzle:
        # Stride = N/4 rounded down to a multiple of 256; anything smaller than
        # 256 disables swizzling.
        swizzle_stride = (N // 4) // 256 * 256
        if swizzle_stride < 256:
            swizzle_stride = 1
            swizzle = False
    else:
        swizzle_stride = 1  # means no thread block swizzle
    if out is not None:
        out.fill_(0)
    for _ in range(warmup):
        _launch(perf_func, a, b, out, stages, swizzle, swizzle_stride)
    torch.cuda.synchronize()
    # perf_counter is monotonic and high resolution, unlike time.time().
    start = time.perf_counter()
    result = out
    for _ in range(iters):
        result = _launch(perf_func, a, b, out, stages, swizzle, swizzle_stride)
    torch.cuda.synchronize()
    end = time.perf_counter()
    out = result
    total_time = (end - start) * 1000  # ms
    mean_time = total_time / iters
    out_info = f"{tag}"
    out_val = out.flatten()[:2].detach().cpu().numpy().tolist()
    out_val = [f"{round(v, 8):<12}"[:10] for v in out_val]
    # 2*M*N*K flops; mean_time is in ms, hence 1e-9 instead of 1e-12.
    TFLOPS = (2 * M * N * K) * 1e-9 / mean_time
    mean_time = str(f"{mean_time:<12}")[:8]
    swizzle_stride = 'NOOP' if swizzle_stride == 1 else swizzle_stride
    # caculate TFLOPS improved over the best seen so far.
    if TFLOPS > MAX_TFLOPS:
        if MAX_TFLOPS > 0:
            improve = round(((TFLOPS - MAX_TFLOPS) / MAX_TFLOPS) * 100, 2)
        else:
            improve = 0
        MAX_TFLOPS = TFLOPS
        print(f"{out_info:>40}: {out_val}, time:{mean_time}ms, "
              f"swizzle: {swizzle_stride:<4}, TFLOPS: {TFLOPS:<6.2f}(+{improve:.2f}%)")
    else:
        print(f"{out_info:>40}: {out_val}, time:{mean_time}ms, "
              f"swizzle: {swizzle_stride:<4}, TFLOPS: {TFLOPS:<6.2f}")
    if show_all:
        print(out)
    time.sleep(0.1)
    return out, mean_time
# Default problem-size sweep: every combination of Ms x Ns x Ks is benchmarked.
Ms = [4096, 8192, 16384]
Ns = [4096, 8192, 16384]
Ks = [2048, 4096, 8192]
# --M/--N/--K each pin their own dimension; a dimension that is not given keeps
# its default sweep, so a partial override (e.g. only --K) is honoured too.
if args.M:
    Ms = [args.M]
if args.N:
    Ns = [args.N]
if args.K:
    Ks = [args.K]
MAX_M, MAX_N, MAX_K = max(Ms), max(Ns), max(Ks)
# pre allocate for fast profiling: operands for the largest shape are created
# once and every case below benchmarks a contiguous copy of the leading block.
torch.cuda.synchronize()
start = time.perf_counter()
print(f"pre allocate for fast profiling start, MAX_M={MAX_M}, MAX_N={MAX_N}, MAX_K={MAX_K}")
# Sample directly on the GPU: avoids a host-side fill plus host-to-device copy.
A = torch.randn((MAX_M, MAX_K), dtype=torch.half, device="cuda")
B = torch.randn((MAX_K, MAX_N), dtype=torch.half, device="cuda")
C = torch.randn((MAX_M, MAX_N), dtype=torch.half, device="cuda")
torch.cuda.synchronize()
end = time.perf_counter()
print(f"pre allocate for fast profiling done, time: {(end - start) * 1000} ms")
MNKs = [(M, N, K) for M in Ms for N in Ns for K in Ks]
PERF_COUNT = 0
for PERF_COUNT, (M, N, K) in enumerate(MNKs, start=1):
    # run_benchmark reports "+x%" against the best TFLOPS seen so far; start
    # every shape from scratch.
    MAX_TFLOPS = -1
    print("-" * 130)
    # NOTE: Iters is the requested count; run_benchmark may override it for large shapes.
    print(" " * 40 + f"M={M}, N={N}, K={K}, Warmup={args.warmup}, Iters={args.iters}, {PERF_COUNT}/{len(MNKs)}")
    print("-" * 130)
    a = A[:M, :K].contiguous()
    b = B[:K, :N].contiguous()
    c = C[:M, :N].contiguous()
    torch.cuda.synchronize()
    if args.enable_cuda_all:  # more cuda cores kernel tests.
        # CUDA Cores FP16
        run_benchmark(lib.hgemm_naive_f16, a, b, "(naive)", c)
        run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x8_pack_bcf, a, b, "(f16x8pack+t8x8+bcf)", c)
    if args.enable_cuda or args.enable_cuda_all:
        run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x8_pack_bcf_dbuf, a, b, "(f16x8pack+t8x8+dbuf)", c)
        run_benchmark(lib.hgemm_t_8x8_sliced_k16_f16x8_pack_dbuf, a, b, "(f16x8pack+t8x8+k16+dbuf)", c)
    if args.enable_wmma or args.enable_wmma_all:
        print("-" * 68 + "WMMA" + "-" * 58)
        # wmma api, stages, dsmem, swizzle
        run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2, a, b, "(mma4x2)", c)
        run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4, a, b, "(mma4x2+warp2x4)", c)
        # prefer on NVIDIA L20 device.
        run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "(mma4x2+warp2x4+stage3)", c, stages=3)
        run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "(mma4x2+warp2x4+stage2)", c, stages=2)
        run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "(mma4x2+warp2x4+stage3+dsmem)", c, stages=3)
        run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "(mma4x2+warp2x4+stage2+dsmem)", c, stages=2)
        # thread block swizzle
        run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "(mma4x2+warp2x4+stage3+swizzle)", c, stages=3, swizzle=True)
        run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages, a, b, "(mma4x2+warp2x4+stage2+swizzle)", c, stages=2, swizzle=True)
        run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "(mma4x2+warp2x4+stage3+dsmem+swizzle)", c, stages=3, swizzle=True)
        run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem, a, b, "(mma4x2+warp2x4+stage2+dsmem+swizzle)", c, stages=2, swizzle=True)
        if args.enable_wmma_all:  # more wmma kernel tests.
            # TODO: add more stages tests for mma2x4/mma4x4, 4,5 etc.
            # prefer on NVIDIA RTX 3080 Laptop 16GB GDDR6 device.
            run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "(mma4x4+warp4x4+stage3+dsmem)", c, stages=3)
            run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "(mma4x4+warp4x4+stage2+dsmem)", c, stages=2)
            run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem, a, b, "(mma4x2+warp4x4+stage3+dsmem)", c, stages=3)
            run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem, a, b, "(mma4x2+warp4x4+stage2+dsmem)", c, stages=2)
            # thread block swizzle
            run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "(mma4x4+warp4x4+stage3+dsmem+swizzle)", c, stages=3, swizzle=True)
            run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "(mma4x4+warp4x4+stage2+dsmem+swizzle)", c, stages=2, swizzle=True)
            run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem, a, b, "(mma4x2+warp4x4+stage3+dsmem+swizzle)", c, stages=3, swizzle=True)
            run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem, a, b, "(mma4x2+warp4x4+stage2+dsmem+swizzle)", c, stages=2, swizzle=True)
    if args.enable_mma or args.enable_mma_all:
        print("-" * 68 + "MMA" + "-" * 59)
        run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4, a, b, "(mma2x4+warp4x4)", c)
        run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages, a, b, "(mma2x4+warp4x4+stage3)", c, stages=3)
        run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages, a, b, "(mma2x4+warp4x4+stage2)", c, stages=2)
        run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem, a, b, "(mma2x4+warp4x4+stage3+dsmem)", c, stages=3)
        run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem, a, b, "(mma2x4+warp4x4+stage2+dsmem)", c, stages=2)
        # thread block swizzle
        run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages, a, b, "(mma2x4+warp4x4+stage3+swizzle)", c, stages=3, swizzle=True)
        run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages, a, b, "(mma2x4+warp4x4+stage2+swizzle)", c, stages=2, swizzle=True)
        run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem, a, b, "(mma2x4+warp4x4+stage3+dsmem+swizzle)", c, stages=3, swizzle=True)
        run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem, a, b, "(mma2x4+warp4x4+stage2+dsmem+swizzle)", c, stages=2, swizzle=True)
        # TODO: add extra MMA kernel variants under --enable-mma-all.
    if not args.disable_cublas:
        run_benchmark(lib.hgemm_cublas_tensor_op_row_major, a, b, "(cublas)", c)
    if args.enable_torch:
        run_benchmark(partial(torch.matmul, out=c), a, b, "(torch)")
    torch.cuda.synchronize()
    print("-" * 130)