
Commit

Add utilities for padding and add to bench_padding.py
drisspg committed Nov 16, 2023
1 parent 52aed83 commit e2e6c16
Showing 2 changed files with 197 additions and 0 deletions.
152 changes: 152 additions & 0 deletions benchmarks/bench_padding.py
@@ -0,0 +1,152 @@
from dataclasses import dataclass
from typing import Optional

import fire

import torch
import torch.utils.benchmark as benchmark
from float8_experimental.float8_utils import pad_tensor_for_matmul
from tabulate import tabulate

# estimating TOPs for matmuls in fp32, fp16, fp8
# assuming A * B = C, with A being M * K, B being K * N, C being M * N

# H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/
h100_peak_flops_float32 = 67e12
h100_peak_flops_fp16_tc = 1979e12
h100_peak_tops_float8_tc = 3958e12

dtype_to_peak_tops = {
    torch.float32: h100_peak_flops_float32,
    torch.float16: h100_peak_flops_fp16_tc,
    torch.bfloat16: h100_peak_flops_fp16_tc,
    torch.float8_e4m3fn: h100_peak_tops_float8_tc,
    torch.float8_e5m2: h100_peak_tops_float8_tc,
}
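
# Worked example of the utilization math used below (illustrative numbers, not
# measured results): the (8192 x 2500 x 5000) shape in gen_configs() needs
#   2 * 8192 * 2500 * 5000 = 2.048e11 FLOPs.
# If an FP8 run of that shape took, say, 100 us, the achieved rate would be
#   2.048e11 / 100e-6 = 2.048e15 FLOP/s, i.e. ~52% of h100_peak_tops_float8_tc.
# get_tops_info() below performs exactly this division.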


def benchmark_fn_in_usec(f, *args, **kwargs):
    # Manual warmup
    for _ in range(4):
        f(*args, **kwargs)
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    measurement = t0.blocked_autorange()
    return measurement.mean * 1e6


def get_tops_info(tops, time, peak_tops):
    time_sec = time / 1e6
    tops_sec = float(tops) / time_sec
    pct_top_peak = tops_sec / peak_tops
    return tops_sec, pct_top_peak


def do_fp8_matmul(A, B, fp8_dtype, out_dtype):
    A_fp8 = A.to(fp8_dtype)
    B_fp8 = B.to(fp8_dtype).t()  # view

    A_pad = pad_tensor_for_matmul(A_fp8)  # mem copy
    B_pad = pad_tensor_for_matmul(B_fp8, both=True).contiguous().t()  # mem copy

    return torch._scaled_mm(A_pad, B_pad, out_dtype=out_dtype)[0][
        : A.shape[0], : B.shape[1]
    ]


def do_fp8_pad_first_matmul(A, B, fp8_dtype, out_dtype):
    A_pad = pad_tensor_for_matmul(A)  # mem copy
    B_pad = pad_tensor_for_matmul(B, both=True)  # mem copy

    A_pad = A_pad.to(fp8_dtype)  # mem copy
    B_pad = B_pad.to(fp8_dtype)  # mem copy

    # t().contiguous().t() keeps the shape but makes B_pad column-major,
    # the layout expected for the second operand of the FP8 matmul
    B_pad = B_pad.t().contiguous().t()  # mem copy

    return torch._scaled_mm(A_pad, B_pad, out_dtype=out_dtype)[0][
        : A.shape[0], : B.shape[1]
    ]


def do_hp_matmul(A, B):
    return torch.matmul(A, B)


@dataclass
class Experiment_config:
    M: int
    K: int
    N: int
    output_dtype: torch.dtype
    fp8_dtype: torch.dtype

    def __iter__(self):
        return iter((self.M, self.K, self.N, self.output_dtype, self.fp8_dtype))


def gen_configs():
    shapes = [(8192, 2500, 5000), (4096, 10, 4096)]
    output_dtype = torch.float32
    fp8_dtype = torch.float8_e4m3fn
    return [Experiment_config(*shape, output_dtype, fp8_dtype) for shape in shapes]


@torch.no_grad()
def run(compile: bool = False, n_limit: Optional[int] = None):
    device = "cuda"
    experiments = gen_configs()
    if n_limit is not None:
        # Optionally benchmark only the first n_limit shapes
        experiments = experiments[:n_limit]
    results = []
    tops_table = []
    tops_headers = [
        "Shape",
        "Ref Dtype",
        "Ref Tops",
        "FP8 Tops",
        "Ref % Peak",
        "FP8 % Peak",
    ]
    for experiment in experiments:
        M, K, N, output_dtype, fp8_dtype = experiment
        tops = 2 * M * N * K

        A_base = torch.rand(M, K, device=device, dtype=output_dtype)
        B_base = torch.rand(K, N, device=device, dtype=output_dtype)

        hp_func = torch.compile(do_hp_matmul) if compile else do_hp_matmul
        # The compiled path benchmarks the pad-first variant; eager uses cast-then-pad
        fp8_func = torch.compile(do_fp8_pad_first_matmul) if compile else do_fp8_matmul

        ref_time = benchmark_fn_in_usec(hp_func, A_base, B_base)
        fp8_time = benchmark_fn_in_usec(
            fp8_func, A_base, B_base, fp8_dtype, output_dtype
        )

        ref_tops_sec, ref_pct_top_peak = get_tops_info(
            tops, ref_time, dtype_to_peak_tops[output_dtype]
        )
        fp8_tops_sec, fp8_pct_top_peak = get_tops_info(
            tops, fp8_time, dtype_to_peak_tops[fp8_dtype]
        )
        tops_table.append(
            [
                f"({M}x{K}x{N})",
                f"{output_dtype}",
                f"{ref_tops_sec:.2E}",
                f"{fp8_tops_sec:.2E}",
                f"{ref_pct_top_peak:.3f}",
                f"{fp8_pct_top_peak:.3f}",
            ]
        )
        results.append(
            [(M, K, N), output_dtype, ref_time, fp8_time, ref_time / fp8_time]
        )

    print("TOPs".center(80, "*"))
    print(tabulate(tops_table, headers=tops_headers))
    print("Speed Results".center(80, "*"))
    headers = ["Shape", "Ref Dtype", "Ref Time (us)", "FP8 Time (us)", "Speedup"]
    print(tabulate(results, headers=headers, tablefmt="grid"))


if __name__ == "__main__":
    fire.Fire(run)
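
A quick way to exercise the script above (a sketch; it assumes an FP8-capable GPU such as an H100 and that it is launched from the repository root) is the fire-generated CLI, which maps run()'s keyword arguments to flags, e.g. python benchmarks/bench_padding.py --compile=True --n_limit=1. Equivalently, run(compile=True, n_limit=1) can be called directly from Python.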
45 changes: 45 additions & 0 deletions float8_experimental/float8_utils.py
@@ -94,3 +94,48 @@ def compute_error(x, y):
def is_row_major(stride):
    assert len(stride) == 2, "is_row_major only supports 2D tensors"
    return stride[0] > stride[1] and stride[1] == 1


def get_min_alignment(size: int, alignment_value: int):
    """
    Returns the smallest multiple of `alignment_value` that is greater than or equal to `size`.
    Args:
        size: The size of the data to be aligned.
        alignment_value: The alignment granularity to round up to.
    Returns:
        int: The smallest multiple of `alignment_value` that is >= `size`.
    """
    if size % alignment_value == 0:
        return size
    return (1 + (size // alignment_value)) * alignment_value
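
# Illustrative values (not part of the original file):
#   get_min_alignment(2500, 16) -> 2512  (157 * 16, the next multiple of 16)
#   get_min_alignment(4096, 16) -> 4096  (already a multiple of 16)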


def pad_tensor_for_matmul(tensor: torch.Tensor, both: bool = False) -> torch.Tensor:
    """
    Pads a 2D tensor with zeros so that its dimensions are multiples of 16,
    which FP8 matmuls (torch._scaled_mm) require on H100s.
    Args:
        tensor: The tensor to pad.
        both: Whether to pad both dimensions or only the second one.
    Returns:
        torch.Tensor: The padded tensor.
    """
    assert tensor.dim() == 2
    dim1, dim2 = tensor.shape

    # Calculate aligned dimensions
    dim2_aligned = get_min_alignment(dim2, 16)
    dim1_aligned = get_min_alignment(dim1, 16) if both else dim1

    # Check if padding is needed for either dimension
    if dim1 == dim1_aligned and dim2 == dim2_aligned:
        return tensor

    # Calculate padding values for both dimensions
    pad_dim1 = dim1_aligned - dim1
    pad_dim2 = dim2_aligned - dim2

    # F.pad's last-dim pair comes first: pad dim2 on the right, then dim1 at the bottom
    return torch.nn.functional.pad(tensor, (0, pad_dim2, 0, pad_dim1))
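
A minimal usage sketch for the new helper (illustrative only; the shapes and dtype are chosen to show the padding behavior, not taken from the commit):

import torch
from float8_experimental.float8_utils import pad_tensor_for_matmul

x = torch.randn(5, 17, dtype=torch.bfloat16)
print(pad_tensor_for_matmul(x).shape)             # torch.Size([5, 32]): only the last dim is padded
print(pad_tensor_for_matmul(x, both=True).shape)  # torch.Size([16, 32]): both dims padded to multiples of 16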
