From 9c1bbbb8bc4d36fc0e4bc45860538857721f3d83 Mon Sep 17 00:00:00 2001 From: Yichen Yan Date: Fri, 3 Jan 2025 14:16:03 +0800 Subject: [PATCH] Reduce memory usage --- examples/benchmark.py | 66 +++++++++++++++++++++++-------------------- 1 file changed, 36 insertions(+), 30 deletions(-) diff --git a/examples/benchmark.py b/examples/benchmark.py index a9dec37..50debe2 100644 --- a/examples/benchmark.py +++ b/examples/benchmark.py @@ -26,6 +26,21 @@ from attn_gym.mods import generate_alibi_bias, generate_tanh_softcap +AVAILABLE_EXAMPLES = { + "causal": lambda: test_mask(mask_mod=causal_mask), + "alibi": lambda: test_mask(score_mod=generate_alibi_bias(16), skip_correctness=True), + "sliding_window": lambda: test_mask(mask_mod=generate_sliding_window(window_size=1024)), + "prefix_lm": lambda: test_mask(mask_mod=generate_prefix_lm_mask(prefix_length=1024)), + "document": lambda: run_document_masking(max_seq_len=32768, num_docs=12), + "softcap": lambda: test_mask( + score_mod=generate_tanh_softcap(30, approx=False), skip_correctness=True + ), + "softcap_approx": lambda: test_mask( + score_mod=generate_tanh_softcap(30, approx=True), skip_correctness=True + ), +} + + torch.set_default_device("cuda") torch.manual_seed(0) @@ -97,19 +112,15 @@ def test_mask( causal_fav2_flops = 0.5 * B * H * D * S * S flops = density * B * H * D * S * S - # Forward pass - causal_fa2_time = do_bench(causal_fa2) - sdpa_mask_time = do_bench(sdpa_mask) - flex_ms = do_bench(flex_attention_call) + times = [] + for attn in (causal_fa2, sdpa_mask, flex_attention_call): + fwd_time = do_bench(attn) + fwd_out = attn() + bwd_time = do_bench(lambda: fwd_out.backward(gradOut, retain_graph=True)) # noqa: F821 + times.append((fwd_time, bwd_time)) - # Backward pass - causal_fa2_out = causal_fa2() - sdpa_mask_out = sdpa_mask() - flex_out = flex_attention_call() - - causal_fa2_bw_time = do_bench(lambda: causal_fa2_out.backward(gradOut, retain_graph=True)) - sdpa_mask_bw_time = do_bench(lambda: sdpa_mask_out.backward(gradOut, retain_graph=True)) - flex_bw_ms = do_bench(lambda: flex_out.backward(gradOut, retain_graph=True)) + del fwd_out + torch.cuda.empty_cache() print_header( f"{score_mod.__name__ if score_mod is not None else mask_mod.__name__}".replace( @@ -140,6 +151,12 @@ def test_mask( torch.testing.assert_close(flex, sdpa_mask, atol=1e-1, rtol=1e-2) print("Correctness check passed ✅") + + ( + (causal_fa2_time, causal_fa2_bw_time), + (sdpa_mask_time, sdpa_mask_bw_time), + (flex_ms, flex_bw_ms), + ) = times # Usage in your results formatting: results = [ [ @@ -210,28 +227,16 @@ def main(examples: List[str] = ["all"]): Args: examples: List of examples to run. If "all" is specified, all examples will be run. """ - available_examples = { - "causal": lambda: test_mask(mask_mod=causal_mask), - "alibi": lambda: test_mask(score_mod=generate_alibi_bias(16), skip_correctness=True), - "sliding_window": lambda: test_mask(mask_mod=generate_sliding_window(window_size=1024)), - "prefix_lm": lambda: test_mask(mask_mod=generate_prefix_lm_mask(prefix_length=1024)), - "document": lambda: run_document_masking(max_seq_len=32768, num_docs=12), - "softcap": lambda: test_mask( - score_mod=generate_tanh_softcap(30, approx=False), skip_correctness=True - ), - "softcap_approx": lambda: test_mask( - score_mod=generate_tanh_softcap(30, approx=True), skip_correctness=True - ), - } if "all" in examples: - ex_to_run = list(available_examples.keys()) + ex_to_run = list(AVAILABLE_EXAMPLES.keys()) else: ex_to_run = examples for ex in ex_to_run: - if ex in available_examples: - available_examples[ex]() + if ex in AVAILABLE_EXAMPLES: + AVAILABLE_EXAMPLES[ex]() + torch.cuda.empty_cache() else: print(f"Warning: Unknown example key '{ex}'. Skipping.") @@ -248,8 +253,9 @@ def main(examples: List[str] = ["all"]): nargs="+", default=["all"], help="List of examples to run. Use space to separate multiple examples. " - "Available options: causal, alibi, sliding_window, prefix_lm, " - "document, softcap, softcap_approx, or 'all' to run all examples.", + "Available options: " + + ", ".join(sorted(AVAILABLE_EXAMPLES.keys())) + + ", or 'all' to run all examples.", ) args = parser.parse_args()