diff --git a/examples/benchmark.py b/examples/benchmark.py
index a9dec37..885f7a2 100644
--- a/examples/benchmark.py
+++ b/examples/benchmark.py
@@ -26,6 +26,21 @@
 from attn_gym.mods import generate_alibi_bias, generate_tanh_softcap
 
 
+AVAILABLE_EXAMPLES = {
+    "causal": lambda: test_mask(mask_mod=causal_mask),
+    "alibi": lambda: test_mask(score_mod=generate_alibi_bias(16), skip_correctness=True),
+    "sliding_window": lambda: test_mask(mask_mod=generate_sliding_window(window_size=1024)),
+    "prefix_lm": lambda: test_mask(mask_mod=generate_prefix_lm_mask(prefix_length=1024)),
+    "document": lambda: run_document_masking(max_seq_len=32768, num_docs=12),
+    "softcap": lambda: test_mask(
+        score_mod=generate_tanh_softcap(30, approx=False), skip_correctness=True
+    ),
+    "softcap_approx": lambda: test_mask(
+        score_mod=generate_tanh_softcap(30, approx=True), skip_correctness=True
+    ),
+}
+
+
 torch.set_default_device("cuda")
 torch.manual_seed(0)
 
@@ -97,18 +112,17 @@ def test_mask(
     causal_fav2_flops = 0.5 * B * H * D * S * S
     flops = density * B * H * D * S * S
 
-    # Forward pass
+    # Group each attn impl to reduce memory usage
     causal_fa2_time = do_bench(causal_fa2)
-    sdpa_mask_time = do_bench(sdpa_mask)
-    flex_ms = do_bench(flex_attention_call)
-
-    # Backward pass
     causal_fa2_out = causal_fa2()
-    sdpa_mask_out = sdpa_mask()
-    flex_out = flex_attention_call()
-
     causal_fa2_bw_time = do_bench(lambda: causal_fa2_out.backward(gradOut, retain_graph=True))
+
+    sdpa_mask_time = do_bench(sdpa_mask)
+    sdpa_mask_out = sdpa_mask()
     sdpa_mask_bw_time = do_bench(lambda: sdpa_mask_out.backward(gradOut, retain_graph=True))
+
+    flex_ms = do_bench(flex_attention_call)
+    flex_out = flex_attention_call()
     flex_bw_ms = do_bench(lambda: flex_out.backward(gradOut, retain_graph=True))
 
     print_header(
@@ -210,28 +224,16 @@ def main(examples: List[str] = ["all"]):
     Args:
         examples: List of examples to run. If "all" is specified, all examples will be run.
     """
-    available_examples = {
-        "causal": lambda: test_mask(mask_mod=causal_mask),
-        "alibi": lambda: test_mask(score_mod=generate_alibi_bias(16), skip_correctness=True),
-        "sliding_window": lambda: test_mask(mask_mod=generate_sliding_window(window_size=1024)),
-        "prefix_lm": lambda: test_mask(mask_mod=generate_prefix_lm_mask(prefix_length=1024)),
-        "document": lambda: run_document_masking(max_seq_len=32768, num_docs=12),
-        "softcap": lambda: test_mask(
-            score_mod=generate_tanh_softcap(30, approx=False), skip_correctness=True
-        ),
-        "softcap_approx": lambda: test_mask(
-            score_mod=generate_tanh_softcap(30, approx=True), skip_correctness=True
-        ),
-    }
 
     if "all" in examples:
-        ex_to_run = list(available_examples.keys())
+        ex_to_run = list(AVAILABLE_EXAMPLES.keys())
     else:
         ex_to_run = examples
 
     for ex in ex_to_run:
-        if ex in available_examples:
-            available_examples[ex]()
+        if ex in AVAILABLE_EXAMPLES:
+            AVAILABLE_EXAMPLES[ex]()
+            torch.cuda.empty_cache()
         else:
             print(f"Warning: Unknown example key '{ex}'. Skipping.")
 
@@ -248,8 +250,8 @@ def main(examples: List[str] = ["all"]):
         nargs="+",
         default=["all"],
         help="List of examples to run. Use space to separate multiple examples. "
-        "Available options: causal, alibi, sliding_window, prefix_lm, "
-        "document, softcap, softcap_approx, or 'all' to run all examples.",
+        "Available options: " + ", ".join(sorted(AVAILABLE_EXAMPLES.keys()))
+        + ", or 'all' to run all examples.",
     )
 
     args = parser.parse_args()
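
Note on the regrouping in test_mask: the old layout timed all three forward passes, then materialized all three outputs, then timed all three backwards, which keeps every implementation's output and autograd graph resident at once. The new layout runs each implementation's forward and backward back-to-back, so at most one graph is alive at a time. A minimal sketch of that pattern, assuming a do_bench-style timer; the names `impls`, `grad_out`, and `bench_grouped` are hypothetical and not part of this diff:

    import torch

    def bench_grouped(impls, grad_out, do_bench):
        # impls: hypothetical {name: zero-arg callable returning a tensor
        # that requires grad}; do_bench: a timing helper like the one the
        # benchmark script already uses.
        results = {}
        for name, fn in impls.items():
            fwd_ms = do_bench(fn)
            out = fn()  # only this impl's output/graph is alive now
            bwd_ms = do_bench(lambda: out.backward(grad_out, retain_graph=True))
            results[name] = (fwd_ms, bwd_ms)
            del out  # drop the autograd graph before benchmarking the next impl
            torch.cuda.empty_cache()  # return freed blocks, as main() now does per example
        return results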
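
For reference, a typical invocation after this change (hedged: the argparse flag name is not visible in these hunks; `--examples` is an assumption based on the help text):

    python examples/benchmark.py --examples causal sliding_window softcap_approx
    python examples/benchmark.py --examples all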