Skip to content

Commit

Permalink
Validate alibi (with local patch to return score)
Browse files Browse the repository at this point in the history
  • Loading branch information
oraluben committed Jan 13, 2025
1 parent be9c92f commit e0c9613
Showing 1 changed file with 9 additions and 11 deletions.
20 changes: 9 additions & 11 deletions examples/benchmark.py
Original file line number Diff line number Diff line change
@@ -29,7 +29,7 @@
AVAILABLE_EXAMPLES = {
"causal": lambda: test_mask(mask_mod=causal_mask),
"causal_score": lambda: test_mask(score_mod=lambda score, b, h, q_idx, kv_idx: torch.where(causal_mask(b, h, q_idx, kv_idx), score, torch.finfo(score.dtype))),
"alibi": lambda: test_mask(score_mod=generate_alibi_bias(16), skip_correctness=True),
"alibi": lambda: test_mask(score_mod=generate_alibi_bias(16), skip_correctness=False),
"sliding_window": lambda: test_mask(mask_mod=generate_sliding_window(window_size=1024)),
"prefix_lm": lambda: test_mask(mask_mod=generate_prefix_lm_mask(prefix_length=1024)),
"document": lambda: run_document_masking(max_seq_len=32768, num_docs=12),
@@ -88,16 +88,14 @@ def test_mask(
block_mask = create_block_mask_cached(mask_mod, 1, 1, S, S, device=device)
else:
block_mask = None
sdpa_mask_fn = mask_mod if mask_mod is not None else score_mod
mask = create_mask(sdpa_mask_fn, 1, 1, S, S, device=device)
if score_mod:
mask = torch.where(mask, score_mod(
torch.zeros_like(mask, dtype=data_type),
torch.tensor([1], dtype=data_type),
torch.tensor([1], dtype=data_type),
torch.tensor([[s for i in range(S)] for s in range(S)], dtype=torch.int64),
torch.tensor([[i for i in range(S)] for s in range(S)], dtype=torch.int64),
), torch.finfo(data_type).min)
mask = create_mask(mask_mod, 1, H, S, S, device=device) if mask_mod else None
bias = create_mask(score_mod, 1, H, S, S, device=device) if score_mod else None
if bias is not None:
bias = bias.to(dtype=data_type)
if mask:
mask = bias.where(mask, torch.finfo(data_type).min)
else:
assert mask is not None

qkv = [
torch.randn(B, H, S, D, device=device, dtype=data_type, requires_grad=True)

0 comments on commit e0c9613

Please sign in to comment.