Reduce memory usage #100

Merged · 1 commit · Jan 5, 2025
66 changes: 36 additions & 30 deletions examples/benchmark.py
@@ -26,6 +26,21 @@
 from attn_gym.mods import generate_alibi_bias, generate_tanh_softcap
 
 
+AVAILABLE_EXAMPLES = {
+    "causal": lambda: test_mask(mask_mod=causal_mask),
+    "alibi": lambda: test_mask(score_mod=generate_alibi_bias(16), skip_correctness=True),
+    "sliding_window": lambda: test_mask(mask_mod=generate_sliding_window(window_size=1024)),
+    "prefix_lm": lambda: test_mask(mask_mod=generate_prefix_lm_mask(prefix_length=1024)),
+    "document": lambda: run_document_masking(max_seq_len=32768, num_docs=12),
+    "softcap": lambda: test_mask(
+        score_mod=generate_tanh_softcap(30, approx=False), skip_correctness=True
+    ),
+    "softcap_approx": lambda: test_mask(
+        score_mod=generate_tanh_softcap(30, approx=True), skip_correctness=True
+    ),
+}
+
+
 torch.set_default_device("cuda")
 torch.manual_seed(0)
 
@@ -97,19 +112,15 @@ def test_mask(
     causal_fav2_flops = 0.5 * B * H * D * S * S
     flops = density * B * H * D * S * S
 
-    # Forward pass
-    causal_fa2_time = do_bench(causal_fa2)
-    sdpa_mask_time = do_bench(sdpa_mask)
-    flex_ms = do_bench(flex_attention_call)
+    times = []
+    for attn in (causal_fa2, sdpa_mask, flex_attention_call):
+        fwd_time = do_bench(attn)
+        fwd_out = attn()
+        bwd_time = do_bench(lambda: fwd_out.backward(gradOut, retain_graph=True))  # noqa: F821
+        times.append((fwd_time, bwd_time))
 
-    # Backward pass
-    causal_fa2_out = causal_fa2()
-    sdpa_mask_out = sdpa_mask()
-    flex_out = flex_attention_call()
-
-    causal_fa2_bw_time = do_bench(lambda: causal_fa2_out.backward(gradOut, retain_graph=True))
-    sdpa_mask_bw_time = do_bench(lambda: sdpa_mask_out.backward(gradOut, retain_graph=True))
-    flex_bw_ms = do_bench(lambda: flex_out.backward(gradOut, retain_graph=True))
+    del fwd_out
+    torch.cuda.empty_cache()
 
     print_header(
         f"{score_mod.__name__ if score_mod is not None else mask_mod.__name__}".replace(
@@ -140,6 +151,12 @@ def test_mask(
         torch.testing.assert_close(flex, sdpa_mask, atol=1e-1, rtol=1e-2)
 
         print("Correctness check passed ✅")
+
+    (
+        (causal_fa2_time, causal_fa2_bw_time),
+        (sdpa_mask_time, sdpa_mask_bw_time),
+        (flex_ms, flex_bw_ms),
+    ) = times
     # Usage in your results formatting:
     results = [
         [
@@ -210,28 +227,16 @@ def main(examples: List[str] = ["all"]):
     Args:
         examples: List of examples to run. If "all" is specified, all examples will be run.
     """
-    available_examples = {
-        "causal": lambda: test_mask(mask_mod=causal_mask),
-        "alibi": lambda: test_mask(score_mod=generate_alibi_bias(16), skip_correctness=True),
-        "sliding_window": lambda: test_mask(mask_mod=generate_sliding_window(window_size=1024)),
-        "prefix_lm": lambda: test_mask(mask_mod=generate_prefix_lm_mask(prefix_length=1024)),
-        "document": lambda: run_document_masking(max_seq_len=32768, num_docs=12),
-        "softcap": lambda: test_mask(
-            score_mod=generate_tanh_softcap(30, approx=False), skip_correctness=True
-        ),
-        "softcap_approx": lambda: test_mask(
-            score_mod=generate_tanh_softcap(30, approx=True), skip_correctness=True
-        ),
-    }
 
     if "all" in examples:
-        ex_to_run = list(available_examples.keys())
+        ex_to_run = list(AVAILABLE_EXAMPLES.keys())
     else:
         ex_to_run = examples
 
     for ex in ex_to_run:
-        if ex in available_examples:
-            available_examples[ex]()
+        if ex in AVAILABLE_EXAMPLES:
+            AVAILABLE_EXAMPLES[ex]()
+            torch.cuda.empty_cache()
         else:
             print(f"Warning: Unknown example key '{ex}'. Skipping.")
 
@@ -248,8 +253,9 @@ def main(examples: List[str] = ["all"]):
         nargs="+",
         default=["all"],
         help="List of examples to run. Use space to separate multiple examples. "
-        "Available options: causal, alibi, sliding_window, prefix_lm, "
-        "document, softcap, softcap_approx, or 'all' to run all examples.",
+        "Available options: "
+        + ", ".join(sorted(AVAILABLE_EXAMPLES.keys()))
+        + ", or 'all' to run all examples.",
     )
 
     args = parser.parse_args()
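
Hoisting the example table to module scope makes AVAILABLE_EXAMPLES a single source of truth: main dispatches from it, and the --help text is now generated from its keys, so adding a new example can no longer leave the help string stale. A stripped-down sketch of the same table-driven CLI pattern (the placeholder example bodies are illustrative, not from this PR):

import argparse

AVAILABLE_EXAMPLES = {
    "causal": lambda: print("running causal"),  # placeholder body
    "alibi": lambda: print("running alibi"),  # placeholder body
}

parser = argparse.ArgumentParser()
parser.add_argument(
    "--examples",
    nargs="+",
    default=["all"],
    help="Available options: "
    + ", ".join(sorted(AVAILABLE_EXAMPLES.keys()))
    + ", or 'all' to run all examples.",
)
args = parser.parse_args()

requested = list(AVAILABLE_EXAMPLES) if "all" in args.examples else args.examples
for key in requested:
    if key in AVAILABLE_EXAMPLES:
        AVAILABLE_EXAMPLES[key]()  # dispatch through the same table the help text uses
    else:
        print(f"Warning: Unknown example key '{key}'. Skipping.")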
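
Because this PR targets peak memory rather than latency, the improvement can be checked directly with the CUDA allocator's peak counters. A hedged sketch, not part of this PR; measure_peak_mib is an illustrative helper name:

import torch


def measure_peak_mib(fn) -> float:
    # Run fn once and report the peak CUDA memory it allocated, in MiB.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / (1024**2)

For example, comparing measure_peak_mib(lambda: test_mask(mask_mod=causal_mask)) before and after this change should show a lower peak.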