[HGEMM] Add PyTorch HGEMM profile (#59)
* Create prof.py

* Update .gitignore

* Update .gitignore

* Update README.md
DefTruth authored Sep 30, 2024
1 parent cb869e2 commit 3f5ace3
Showing 4 changed files with 170 additions and 1 deletion.
3 changes: 2 additions & 1 deletion .gitignore
@@ -13,5 +13,6 @@ __pycache__
*.pt
*.pth
*.nsys*
*.sqlite
*.ncu*
*.sqlite*
*.engine
18 changes: 18 additions & 0 deletions hgemm/.gitignore
@@ -0,0 +1,18 @@
*.so
*.a
*.dylib
*.dll
*.lib
.DS_Store
build
*.whl
tmp
__pycache__
*.onnx
*.engine
*.pt
*.pth
*.nsys*
*.ncu*
*.sqlite*
*.engine
62 changes: 62 additions & 0 deletions hgemm/README.md
@@ -115,6 +115,68 @@ cudaDeviceSetSharedMemConfig(cudaSharedMemBankSizeEightByte);
```

## PyTorch HGEMM Profile

On the Ada architecture, when PyTorch 2.4 runs matmul on FP16 inputs, it dispatches to the ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_stages_32x1_nn kernel, which internally performs the computation with HMMA (Tensor Core) instructions.

```bash
ncu -o hgemm.prof -f python3 prof.py
nsys profile --stats=true -t cuda,osrt,nvtx -o hgemm.prof --force-overwrite true python3 prof.py
```
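
As a quick in-process cross-check (a minimal sketch, not part of this commit), torch.profiler can report the same kernel name that ncu/nsys show; on an Ada GPU the CUDA kernel table should contain an entry starting with ampere_fp16_s1688gemm_fp16_128x128:

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Same default shape as prof.py: M=N=1024, K=256, FP16 inputs.
a = torch.randn(1024, 256, device="cuda", dtype=torch.half)
b = torch.randn(256, 1024, device="cuda", dtype=torch.half)
torch.matmul(a, b)  # warmup so kernel selection settles

with profile(activities=[ProfilerActivity.CUDA]) as prof:
    torch.matmul(a, b)
    torch.cuda.synchronize()

# Expect a kernel named like "ampere_fp16_s1688gemm_fp16_128x128_..." near the top.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```
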
- Log

```bash
==PROF== Connected to process 367502 (/usr/bin/python3.10)
==PROF== Profiling "unrolled_elementwise_kernel" - 0: 0%....50%....100% - 8 passes
==PROF== Profiling "unrolled_elementwise_kernel" - 1: 0%....50%....100% - 8 passes
==PROF== Profiling "unrolled_elementwise_kernel" - 2: 0%....50%....100% - 8 passes
==PROF== Profiling "ampere_fp16_s1688gemm_fp16_12..." - 3: 0%....50%....100% - 8 passes
==PROF== Profiling "ampere_fp16_s1688gemm_fp16_12..." - 4: 0%....50%....100% - 8 passes
==PROF== Profiling "ampere_fp16_s1688gemm_fp16_12..." - 5: 0%....50%....100% - 8 passes
==PROF== Profiling "ampere_fp16_s1688gemm_fp16_12..." - 6: 0%....50%....100% - 8 passes
==PROF== Profiling "ampere_fp16_s1688gemm_fp16_12..." - 7: 0%....50%....100% - 8 passes
==PROF== Profiling "ampere_fp16_s1688gemm_fp16_12..." - 8: 0%....50%....100% - 8 passes
==PROF== Profiling "ampere_fp16_s1688gemm_fp16_12..." - 9: 0%....50%....100% - 8 passes
==PROF== Profiling "ampere_fp16_s1688gemm_fp16_12..." - 10: 0%....50%....100% - 8 passes
==PROF== Profiling "ampere_fp16_s1688gemm_fp16_12..." - 11: 0%....50%....100% - 8 passes
==PROF== Profiling "ampere_fp16_s1688gemm_fp16_12..." - 12: 0%....50%....100% - 8 passes
==PROF== Profiling "ampere_fp16_s1688gemm_fp16_12..." - 13: 0%....50%....100% - 8 passes
```

- SASS (the inner loop interleaves LDSM shared-memory matrix loads with HMMA.1688.F32 Tensor Core instructions)

```C
310 00007f41 37d5b850 LDSM.16.M88.4 R192, [R169+UR8+0x2000]
311 00007f41 37d5b860 LDSM.16.M88.4 R196, [R169+UR8+0x2800]
312 00007f41 37d5b870 @!P0 BRA.U 0x7f4137d5c3f0
313 00007f41 37d5b880 HMMA.1688.F32 R0, R176, R192, R0
314 00007f41 37d5b890 LDSM.16.MT88.4 R184, [R167+UR8+0x400]
315 00007f41 37d5b8a0 HMMA.1688.F32 R32, R178, R192, R32
316 00007f41 37d5b8b0 LDSM.16.M88.4 R200, [R170+UR8+0x2000]
317 00007f41 37d5b8c0 HMMA.1688.F32 R64, R180, R192, R64
318 00007f41 37d5b8d0 LDSM.16.MT88.4 R188, [R168+UR8+0x400]
319 00007f41 37d5b8e0 HMMA.1688.F32 R96, R182, R192, R96
320 00007f41 37d5b8f0 LDSM.16.M88.4 R204, [R170+UR8+0x2800]
321 00007f41 37d5b900 HMMA.1688.F32 R100, R182, R193, R100
322 00007f41 37d5b910 HMMA.1688.F32 R68, R180, R193, R68
323 00007f41 37d5b920 HMMA.1688.F32 R36, R178, R193, R36
324 00007f41 37d5b930 HMMA.1688.F32 R4, R176, R193, R4
325 00007f41 37d5b940 HMMA.1688.F32 R8, R176, R194, R8
326 00007f41 37d5b950 HMMA.1688.F32 R40, R178, R194, R40
327 00007f41 37d5b960 HMMA.1688.F32 R72, R180, R194, R72
328 00007f41 37d5b970 HMMA.1688.F32 R104, R182, R194, R104
329 00007f41 37d5b980 HMMA.1688.F32 R108, R182, R195, R108
330 00007f41 37d5b990 HMMA.1688.F32 R76, R180, R195, R76
331 00007f41 37d5b9a0 HMMA.1688.F32 R44, R178, R195, R44
332 00007f41 37d5b9b0 HMMA.1688.F32 R12, R176, R195, R12
333 00007f41 37d5b9c0 HMMA.1688.F32 R16, R176, R196, R16
334 00007f41 37d5b9d0 HMMA.1688.F32 R48, R178, R196, R48
335 00007f41 37d5b9e0 HMMA.1688.F32 R80, R180, R196, R80
336 00007f41 37d5b9f0 HMMA.1688.F32 R112, R182, R196, R112
337 00007f41 37d5ba00 HMMA.1688.F32 R116, R182, R197, R116
```



## References

88 changes: 88 additions & 0 deletions hgemm/prof.py
@@ -0,0 +1,88 @@
import torch
import time
from torch.utils.cpp_extension import load
from functools import partial
from typing import Optional

torch.set_grad_enabled(False)

# # Load the CUDA kernel as a python module
# lib = load(name='hgemm_lib',
# sources=['hgemm.cu'],
# extra_cuda_cflags=[
# "-O3",
# "-U__CUDA_NO_HALF_OPERATORS__",
# "-U__CUDA_NO_HALF_CONVERSIONS__",
# "-U__CUDA_NO_HALF2_OPERATORS__",
# "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
# "--expt-relaxed-constexpr",
# "--expt-extended-lambda",
# "--use_fast_math"
# ],
# extra_cflags=['-std=c++17'])


def run_benchmark(perf_func: callable,
                  a: torch.Tensor, b: torch.Tensor,
                  tag: str, out: Optional[torch.Tensor] = None,
                  warmup: int = 1, iters: int = 10,
                  show_all: bool = False):
    if out is not None:
        out.fill_(0)
    # warmup
    if out is not None:
        for i in range(warmup):
            perf_func(a, b, out)
    else:
        for i in range(warmup):
            _ = perf_func(a, b)

    torch.cuda.synchronize()
    start = time.time()
    # iters
    if out is not None:
        for i in range(iters):
            perf_func(a, b, out)
    else:
        for i in range(iters):
            out = perf_func(a, b)
    torch.cuda.synchronize()
    end = time.time()
    total_time = (end - start) * 1000 # ms
    mean_time = total_time / iters
    out_info = f"out_{tag}"
    # print the first 3 output values as a quick correctness check
    out_val = out.flatten().detach().cpu().numpy().tolist()[:3]
    out_val = [round(v, 8) for v in out_val]
    out_val = [f"{v:<12}" for v in out_val]
    print(f"{out_info:>32}: {out_val}, time:{mean_time:.6f}ms")
    if show_all: print(out)
    return out.clone(), mean_time


# Ms = [1024, 2048, 4096]
# Ns = [1024, 2048, 4096]
# Ks = [256, 512, 1024]
Ms = [1024]
Ns = [1024]
Ks = [256]
MNKs = [(M, N, K) for M in Ms for N in Ns for K in Ks]
for (M, N, K) in MNKs:
    print("-" * 110)
    print(" " * 45 + f"M={M}, N={N}, K={K}")
    a = torch.randn((M, K)).cuda().half().contiguous()
    b = torch.randn((K, N)).cuda().half().contiguous()
    c = torch.randn((M, N)).cuda().half().contiguous()
    # run_benchmark(lib.hgemm_naive_f16, a, b, "f16", c)
    # run_benchmark(lib.hgemm_sliced_k_f16, a, b, "f16(sk)", c)
    # run_benchmark(lib.hgemm_t_4x4_sliced_k_f16x4_pack_bcf, a, b, "f16x4pack(t4x4bcf)", c)
    # run_benchmark(lib.hgemm_t_4x4_sliced_k_f16x4_pack_bcf_offset, a, b, "f16x4pack(t4x4offset)", c)
    # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x4, a, b, "f16x4(t8x8sk)", c)
    # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x4_bcf, a, b, "f16x4(t8x8bcf)", c)
    # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x4_pack, a, b, "f16x4pack(t8x8sk)", c)
    # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x4_pack_bcf, a, b, "f16x4pack(bcf)", c)
    # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x4_pack_bcf_offset, a, b, "f16x4pack(bcf+offset)", c)
    # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x8_pack_bcf, a, b, "f16x8pack(bcf)", c)
    # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x8_pack_bcf_offset, a, b, "f16x8pack(bcf+offset)", c)
    # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x8_pack_bcf_dbuf, a, b, "f16x8pack(dbuf)", c)
    run_benchmark(partial(torch.matmul, out=c), a, b, "f16_th")
    print("-" * 110)
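
A possible follow-up (not part of this commit): converting the mean time that run_benchmark prints into achieved TFLOPS makes the cuBLAS-backed torch.matmul baseline easier to compare against hand-written kernels. A hypothetical helper:

```python
# Hypothetical helper, not in prof.py: a GEMM performs 2*M*N*K FLOPs (one mul + one add per MAC).
def tflops(M: int, N: int, K: int, mean_time_ms: float) -> float:
    return (2.0 * M * N * K) / (mean_time_ms * 1e-3) / 1e12

# Example: M=N=1024, K=256 finishing in 0.05 ms gives ~10.74 TFLOPS.
print(f"{tflops(1024, 1024, 256, 0.05):.2f} TFLOPS")
```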
