diff --git a/.gitignore b/.gitignore
index 63fab501..4b4ec220 100644
--- a/.gitignore
+++ b/.gitignore
@@ -13,5 +13,6 @@ __pycache__
 *.pt
 *.pth
 *.nsys*
-*.sqlite
+*.ncu*
+*.sqlite*
 *.engine
diff --git a/hgemm/.gitignore b/hgemm/.gitignore
index e69de29b..02f3909b 100755
--- a/hgemm/.gitignore
+++ b/hgemm/.gitignore
@@ -0,0 +1,18 @@
+*.so
+*.a
+*.dylib
+*.dll
+*.lib
+.DS_Store
+build
+*.whl
+tmp
+__pycache__
+*.onnx
+*.engine
+*.pt
+*.pth
+*.nsys*
+*.ncu*
+*.sqlite*
+*.engine
diff --git a/hgemm/README.md b/hgemm/README.md
index 8f9202fd..c7fb10a4 100755
--- a/hgemm/README.md
+++ b/hgemm/README.md
@@ -115,6 +115,68 @@
 cudaDeviceSetSharedMemConfig(cudaSharedMemBankSizeEightByte);
 ```
 
+## PyTorch HGEMM Profile
+
+On the Ada architecture, PyTorch 2.4 dispatches an FP16 `matmul` to the `ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_stages_32x1_nn` kernel, which internally performs the computation with HMMA (Tensor Core) instructions.
+
+```bash
+ncu -o hgemm.prof -f python3 prof.py
+nsys profile --stats=true -t cuda,osrt,nvtx -o hgemm.prof --force-overwrite true python3 prof.py
+```
+- Log
+
+```bash
+==PROF== Connected to process 367502 (/usr/bin/python3.10)
+==PROF== Profiling "unrolled_elementwise_kernel" - 0: 0%....50%....100% - 8 passes
+==PROF== Profiling "unrolled_elementwise_kernel" - 1: 0%....50%....100% - 8 passes
+==PROF== Profiling "unrolled_elementwise_kernel" - 2: 0%....50%....100% - 8 passes
+==PROF== Profiling "ampere_fp16_s1688gemm_fp16_12..." - 3: 0%....50%....100% - 8 passes
+==PROF== Profiling "ampere_fp16_s1688gemm_fp16_12..." - 4: 0%....50%....100% - 8 passes
+==PROF== Profiling "ampere_fp16_s1688gemm_fp16_12..." - 5: 0%....50%....100% - 8 passes
+==PROF== Profiling "ampere_fp16_s1688gemm_fp16_12..." - 6: 0%....50%....100% - 8 passes
+==PROF== Profiling "ampere_fp16_s1688gemm_fp16_12..." - 7: 0%....50%....100% - 8 passes
+==PROF== Profiling "ampere_fp16_s1688gemm_fp16_12..." - 8: 0%....50%....100% - 8 passes
+==PROF== Profiling "ampere_fp16_s1688gemm_fp16_12..." - 9: 0%....50%....100% - 8 passes
+==PROF== Profiling "ampere_fp16_s1688gemm_fp16_12..." - 10: 0%....50%....100% - 8 passes
+==PROF== Profiling "ampere_fp16_s1688gemm_fp16_12..." - 11: 0%....50%....100% - 8 passes
+==PROF== Profiling "ampere_fp16_s1688gemm_fp16_12..." - 12: 0%....50%....100% - 8 passes
+==PROF== Profiling "ampere_fp16_s1688gemm_fp16_12..." - 13: 0%....50%....100% - 8 passes
+```
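+
+The kernel in the log above is triggered by a plain FP16 `torch.matmul`; `prof.py` is simply what gets launched under `ncu`/`nsys`. The snippet below is only a minimal sketch of that call (the shapes are illustrative, matching the M=N=1024, K=256 case in `prof.py`):
+
+```python
+import torch
+
+# FP16 operands on the GPU; torch.matmul dispatches to a cuBLAS Tensor Core (HMMA) kernel
+a = torch.randn((1024, 256), dtype=torch.half, device="cuda")
+b = torch.randn((256, 1024), dtype=torch.half, device="cuda")
+c = a @ b  # on Ada with PyTorch 2.4 this is reported as ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_stages_32x1_nn
+torch.cuda.synchronize()  # make sure the kernel has actually executed before the process exits
+```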
+
+- SASS
+
+```C
+310 00007f41 37d5b850 LDSM.16.M88.4 R192, [R169+UR8+0x2000]
+311 00007f41 37d5b860 LDSM.16.M88.4 R196, [R169+UR8+0x2800]
+312 00007f41 37d5b870 @!P0 BRA.U 0x7f4137d5c3f0
+313 00007f41 37d5b880 HMMA.1688.F32 R0, R176, R192, R0
+314 00007f41 37d5b890 LDSM.16.MT88.4 R184, [R167+UR8+0x400]
+315 00007f41 37d5b8a0 HMMA.1688.F32 R32, R178, R192, R32
+316 00007f41 37d5b8b0 LDSM.16.M88.4 R200, [R170+UR8+0x2000]
+317 00007f41 37d5b8c0 HMMA.1688.F32 R64, R180, R192, R64
+318 00007f41 37d5b8d0 LDSM.16.MT88.4 R188, [R168+UR8+0x400]
+319 00007f41 37d5b8e0 HMMA.1688.F32 R96, R182, R192, R96
+320 00007f41 37d5b8f0 LDSM.16.M88.4 R204, [R170+UR8+0x2800]
+321 00007f41 37d5b900 HMMA.1688.F32 R100, R182, R193, R100
+322 00007f41 37d5b910 HMMA.1688.F32 R68, R180, R193, R68
+323 00007f41 37d5b920 HMMA.1688.F32 R36, R178, R193, R36
+324 00007f41 37d5b930 HMMA.1688.F32 R4, R176, R193, R4
+325 00007f41 37d5b940 HMMA.1688.F32 R8, R176, R194, R8
+326 00007f41 37d5b950 HMMA.1688.F32 R40, R178, R194, R40
+327 00007f41 37d5b960 HMMA.1688.F32 R72, R180, R194, R72
+328 00007f41 37d5b970 HMMA.1688.F32 R104, R182, R194, R104
+329 00007f41 37d5b980 HMMA.1688.F32 R108, R182, R195, R108
+330 00007f41 37d5b990 HMMA.1688.F32 R76, R180, R195, R76
+331 00007f41 37d5b9a0 HMMA.1688.F32 R44, R178, R195, R44
+332 00007f41 37d5b9b0 HMMA.1688.F32 R12, R176, R195, R12
+333 00007f41 37d5b9c0 HMMA.1688.F32 R16, R176, R196, R16
+334 00007f41 37d5b9d0 HMMA.1688.F32 R48, R178, R196, R48
+335 00007f41 37d5b9e0 HMMA.1688.F32 R80, R180, R196, R80
+336 00007f41 37d5b9f0 HMMA.1688.F32 R112, R182, R196, R112
+337 00007f41 37d5ba00 HMMA.1688.F32 R116, R182, R197, R116
+```
+
+
 ## References
diff --git a/hgemm/prof.py b/hgemm/prof.py
new file mode 100644
index 00000000..14012338
--- /dev/null
+++ b/hgemm/prof.py
@@ -0,0 +1,88 @@
+import torch
+import time
+from torch.utils.cpp_extension import load
+from functools import partial
+from typing import Optional
+
+torch.set_grad_enabled(False)
+
+# # Load the CUDA kernel as a python module
+# lib = load(name='hgemm_lib',
+#            sources=['hgemm.cu'],
+#            extra_cuda_cflags=[
+#                "-O3",
+#                "-U__CUDA_NO_HALF_OPERATORS__",
+#                "-U__CUDA_NO_HALF_CONVERSIONS__",
+#                "-U__CUDA_NO_HALF2_OPERATORS__",
+#                "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
+#                "--expt-relaxed-constexpr",
+#                "--expt-extended-lambda",
+#                "--use_fast_math"
+#            ],
+#            extra_cflags=['-std=c++17'])
+
+
+def run_benchmark(perf_func: callable,
+                  a: torch.Tensor, b: torch.Tensor,
+                  tag: str, out: Optional[torch.Tensor] = None,
+                  warmup: int = 1, iters: int = 10,
+                  show_all: bool = False):
+    # if a preallocated `out` tensor is given, the kernel writes into it in place;
+    # otherwise the tensor returned by perf_func is used as the result
+    if out is not None:
+        out.fill_(0)
+    # warmup
+    if out is not None:
+        for i in range(warmup):
+            perf_func(a, b, out)
+    else:
+        for i in range(warmup):
+            _ = perf_func(a, b)
+
+    torch.cuda.synchronize()
+    start = time.time()
+    # timed iterations
+    if out is not None:
+        for i in range(iters):
+            perf_func(a, b, out)
+    else:
+        for i in range(iters):
+            out = perf_func(a, b)
+    torch.cuda.synchronize()
+    end = time.time()
+    total_time = (end - start) * 1000  # ms
+    mean_time = total_time / iters
+    out_info = f"out_{tag}"
+    out_val = out.flatten().detach().cpu().numpy().tolist()[:3]
+    out_val = [round(v, 8) for v in out_val]
+    out_val = [f"{v:<12}" for v in out_val]
+    print(f"{out_info:>32}: {out_val}, time:{mean_time:.6f}ms")
+    if show_all: print(out)
+    return out.clone(), mean_time
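+
+
+# The timing above brackets the launches with torch.cuda.synchronize() and uses
+# time.time(). An equivalent sketch based on CUDA events is shown below for
+# reference only; it is not called anywhere in this script and the name
+# run_benchmark_events is purely illustrative.
+def run_benchmark_events(perf_func: callable, a: torch.Tensor, b: torch.Tensor,
+                         iters: int = 10):
+    start_evt = torch.cuda.Event(enable_timing=True)
+    end_evt = torch.cuda.Event(enable_timing=True)
+    start_evt.record()
+    for _ in range(iters):
+        out = perf_func(a, b)
+    end_evt.record()
+    torch.cuda.synchronize()  # wait until both events have been recorded on the GPU
+    return out, start_evt.elapsed_time(end_evt) / iters  # mean latency in ms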
+
+
+# Ms = [1024, 2048, 4096]
+# Ns = [1024, 2048, 4096]
+# Ks = [256, 512, 1024]
+Ms = [1024]
+Ns = [1024]
+Ks = [256]
+MNKs = [(M, N, K) for M in Ms for N in Ns for K in Ks]
+
+for (M, N, K) in MNKs:
+    print("-" * 110)
+    print(" " * 45 + f"M={M}, N={N}, K={K}")
+    a = torch.randn((M, K)).cuda().half().contiguous()
+    b = torch.randn((K, N)).cuda().half().contiguous()
+    c = torch.randn((M, N)).cuda().half().contiguous()
+    # run_benchmark(lib.hgemm_naive_f16, a, b, "f16", c)
+    # run_benchmark(lib.hgemm_sliced_k_f16, a, b, "f16(sk)", c)
+    # run_benchmark(lib.hgemm_t_4x4_sliced_k_f16x4_pack_bcf, a, b, "f16x4pack(t4x4bcf)", c)
+    # run_benchmark(lib.hgemm_t_4x4_sliced_k_f16x4_pack_bcf_offset, a, b, "f16x4pack(t4x4offset)", c)
+    # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x4, a, b, "f16x4(t8x8sk)", c)
+    # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x4_bcf, a, b, "f16x4(t8x8bcf)", c)
+    # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x4_pack, a, b, "f16x4pack(t8x8sk)", c)
+    # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x4_pack_bcf, a, b, "f16x4pack(bcf)", c)
+    # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x4_pack_bcf_offset, a, b, "f16x4pack(bcf+offset)", c)
+    # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x8_pack_bcf, a, b, "f16x8pack(bcf)", c)
+    # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x8_pack_bcf_offset, a, b, "f16x8pack(bcf+offset)", c)
+    # run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x8_pack_bcf_dbuf, a, b, "f16x8pack(dbuf)", c)
+    run_benchmark(partial(torch.matmul, out=c), a, b, "f16_th")
+    print("-" * 110)
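+
+
+# To translate the mean time printed by run_benchmark into achieved throughput,
+# the standard 2*M*N*K FLOP count for a GEMM can be used. This helper is an
+# optional sketch and is not called above; the name tflops is illustrative.
+def tflops(M: int, N: int, K: int, mean_time_ms: float) -> float:
+    # 2*M*N*K floating-point operations, time converted from milliseconds to seconds
+    return (2.0 * M * N * K) / (mean_time_ms * 1e-3) / 1e12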