Reduce memory usage
oraluben committed Jan 3, 2025
1 parent 60f0f87 commit 4bf8ecd
Showing 1 changed file with 28 additions and 26 deletions.
54 changes: 28 additions & 26 deletions examples/benchmark.py
@@ -26,6 +26,21 @@
 from attn_gym.mods import generate_alibi_bias, generate_tanh_softcap
 
 
+AVAILABLE_EXAMPLES = {
+    "causal": lambda: test_mask(mask_mod=causal_mask),
+    "alibi": lambda: test_mask(score_mod=generate_alibi_bias(16), skip_correctness=True),
+    "sliding_window": lambda: test_mask(mask_mod=generate_sliding_window(window_size=1024)),
+    "prefix_lm": lambda: test_mask(mask_mod=generate_prefix_lm_mask(prefix_length=1024)),
+    "document": lambda: run_document_masking(max_seq_len=32768, num_docs=12),
+    "softcap": lambda: test_mask(
+        score_mod=generate_tanh_softcap(30, approx=False), skip_correctness=True
+    ),
+    "softcap_approx": lambda: test_mask(
+        score_mod=generate_tanh_softcap(30, approx=True), skip_correctness=True
+    ),
+}
+
+
 torch.set_default_device("cuda")
 torch.manual_seed(0)
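Note: the registry stores zero-argument lambdas, so building the dict at import time runs nothing; a benchmark executes only when its entry is called. Hoisting it out of main() to module scope (its removal there appears in a later hunk) also lets the argparse help string reference the same keys. A minimal sketch of the deferred dispatch this enables, using names from the diff above:

# Sketch: the registry defers work until a value is called.
fn = AVAILABLE_EXAMPLES["causal"]   # no benchmark has run yet
fn()                                # runs test_mask(mask_mod=causal_mask) now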

@@ -97,18 +112,17 @@ def test_mask(
     causal_fav2_flops = 0.5 * B * H * D * S * S
     flops = density * B * H * D * S * S
 
-    # Forward pass
+    # Group each attn impl to reduce memory usage
     causal_fa2_time = do_bench(causal_fa2)
-    sdpa_mask_time = do_bench(sdpa_mask)
-    flex_ms = do_bench(flex_attention_call)
-
-    # Backward pass
     causal_fa2_out = causal_fa2()
-    sdpa_mask_out = sdpa_mask()
-    flex_out = flex_attention_call()
     causal_fa2_bw_time = do_bench(lambda: causal_fa2_out.backward(gradOut, retain_graph=True))
+
+    sdpa_mask_time = do_bench(sdpa_mask)
+    sdpa_mask_out = sdpa_mask()
     sdpa_mask_bw_time = do_bench(lambda: sdpa_mask_out.backward(gradOut, retain_graph=True))
+
+    flex_ms = do_bench(flex_attention_call)
+    flex_out = flex_attention_call()
     flex_bw_ms = do_bench(lambda: flex_out.backward(gradOut, retain_graph=True))
 
     print_header(
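The reordering in the hunk above is the memory fix: previously all three forward outputs, each with a retained autograd graph, stayed alive simultaneously while the backward passes were timed; now each implementation finishes its forward and backward before the next one allocates anything, so at most one output is live at a time. A minimal sketch of the grouped pattern, assuming do_bench is triton.testing.do_bench as imported elsewhere in this file (bench_impl is a hypothetical helper, not part of the commit):

# Sketch: group fwd + bwd per implementation so one impl's tensors
# become unreachable before the next impl allocates its own.
def bench_impl(fwd, grad_out):
    fwd_ms = do_bench(fwd)              # time the forward pass
    out = fwd()                         # single live output, graph retained
    bwd_ms = do_bench(lambda: out.backward(grad_out, retain_graph=True))
    return fwd_ms, bwd_ms               # out (and its graph) can be freed here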
@@ -210,28 +224,16 @@ def main(examples: List[str] = ["all"]):
     Args:
         examples: List of examples to run. If "all" is specified, all examples will be run.
     """
-    available_examples = {
-        "causal": lambda: test_mask(mask_mod=causal_mask),
-        "alibi": lambda: test_mask(score_mod=generate_alibi_bias(16), skip_correctness=True),
-        "sliding_window": lambda: test_mask(mask_mod=generate_sliding_window(window_size=1024)),
-        "prefix_lm": lambda: test_mask(mask_mod=generate_prefix_lm_mask(prefix_length=1024)),
-        "document": lambda: run_document_masking(max_seq_len=32768, num_docs=12),
-        "softcap": lambda: test_mask(
-            score_mod=generate_tanh_softcap(30, approx=False), skip_correctness=True
-        ),
-        "softcap_approx": lambda: test_mask(
-            score_mod=generate_tanh_softcap(30, approx=True), skip_correctness=True
-        ),
-    }
-
     if "all" in examples:
-        ex_to_run = list(available_examples.keys())
+        ex_to_run = list(AVAILABLE_EXAMPLES.keys())
     else:
         ex_to_run = examples
 
     for ex in ex_to_run:
-        if ex in available_examples:
-            available_examples[ex]()
+        if ex in AVAILABLE_EXAMPLES:
+            AVAILABLE_EXAMPLES[ex]()
+            torch.cuda.empty_cache()
         else:
             print(f"Warning: Unknown example key '{ex}'. Skipping.")

@@ -248,8 +250,8 @@ def main(examples: List[str] = ["all"]):
         nargs="+",
         default=["all"],
         help="List of examples to run. Use space to separate multiple examples. "
-        "Available options: causal, alibi, sliding_window, prefix_lm, "
-        "document, softcap, softcap_approx, or 'all' to run all examples.",
+        "Available options: " + ", ".join(sorted(AVAILABLE_EXAMPLES.keys())) +
+        ", or 'all' to run all examples.",
     )
 
     args = parser.parse_args()
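Because the option list in the help text is now derived from AVAILABLE_EXAMPLES, adding an entry to the dict updates both dispatch and the --help output with no further edits. With the keys above, the joined string evaluates to:

# Sketch: the dynamically built options string.
", ".join(sorted(AVAILABLE_EXAMPLES.keys()))
# -> 'alibi, causal, document, prefix_lm, sliding_window, softcap, softcap_approx'

(The flag name itself is defined above the lines shown in this hunk; if it is --examples, an invocation would look like: python examples/benchmark.py --examples causal softcap.)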
