Skip to content

Commit

Permalink
clear cache after each run
Browse files · Browse the repository at this point in the history
  • Loading branch information
oraluben committed Jan 3, 2025
1 parent 84b718e commit 1930929
Showing 1 changed file with 14 additions and 11 deletions.
25 changes: 14 additions & 11 deletions examples/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,18 +112,15 @@ def test_mask(
causal_fav2_flops = 0.5 * B * H * D * S * S
flops = density * B * H * D * S * S

# Group each attn impl to reduce memory usage
causal_fa2_time = do_bench(causal_fa2)
causal_fa2_out = causal_fa2()
causal_fa2_bw_time = do_bench(lambda: causal_fa2_out.backward(gradOut, retain_graph=True))
times = []
for attn in (causal_fa2, sdpa_mask, flex_attention_call):
fwd_time = do_bench(attn)
fwd_out = attn()
bwd_time = do_bench(lambda: fwd_out.backward(gradOut, retain_graph=True))
times.append((fwd_time, bwd_time))

sdpa_mask_time = do_bench(sdpa_mask)
sdpa_mask_out = sdpa_mask()
sdpa_mask_bw_time = do_bench(lambda: sdpa_mask_out.backward(gradOut, retain_graph=True))

flex_ms = do_bench(flex_attention_call)
flex_out = flex_attention_call()
flex_bw_ms = do_bench(lambda: flex_out.backward(gradOut, retain_graph=True))
del fwd_out
torch.cuda.empty_cache()

print_header(
f"{score_mod.__name__ if score_mod is not None else mask_mod.__name__}".replace(
Expand Down Expand Up @@ -154,6 +151,12 @@ def test_mask(
torch.testing.assert_close(flex, sdpa_mask, atol=1e-1, rtol=1e-2)

print("Correctness check passed ✅")

(
(causal_fa2_time, causal_fa2_bw_time),
(sdpa_mask_time, sdpa_mask_bw_time),
(flex_ms, flex_bw_ms),
) = times
# Usage in your results formatting:
results = [
[
Expand Down

0 comments on commit 1930929

Please sign in to comment.