From e0c9613c46072bf289c19ccf711ba67eef35d2d2 Mon Sep 17 00:00:00 2001 From: Yichen Yan Date: Mon, 13 Jan 2025 14:28:36 +0800 Subject: [PATCH] Validate alibi (with local patch to return score) --- examples/benchmark.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/examples/benchmark.py b/examples/benchmark.py index 518dd3c..5aac7b4 100644 --- a/examples/benchmark.py +++ b/examples/benchmark.py @@ -29,7 +29,7 @@ AVAILABLE_EXAMPLES = { "causal": lambda: test_mask(mask_mod=causal_mask), "causal_score": lambda: test_mask(score_mod=lambda score, b, h, q_idx, kv_idx: torch.where(causal_mask(b, h, q_idx, kv_idx), score, torch.finfo(score.dtype))), - "alibi": lambda: test_mask(score_mod=generate_alibi_bias(16), skip_correctness=True), + "alibi": lambda: test_mask(score_mod=generate_alibi_bias(16), skip_correctness=False), "sliding_window": lambda: test_mask(mask_mod=generate_sliding_window(window_size=1024)), "prefix_lm": lambda: test_mask(mask_mod=generate_prefix_lm_mask(prefix_length=1024)), "document": lambda: run_document_masking(max_seq_len=32768, num_docs=12), @@ -88,16 +88,14 @@ def test_mask( block_mask = create_block_mask_cached(mask_mod, 1, 1, S, S, device=device) else: block_mask = None - sdpa_mask_fn = mask_mod if mask_mod is not None else score_mod - mask = create_mask(sdpa_mask_fn, 1, 1, S, S, device=device) - if score_mod: - mask = torch.where(mask, score_mod( - torch.zeros_like(mask, dtype=data_type), - torch.tensor([1], dtype=data_type), - torch.tensor([1], dtype=data_type), - torch.tensor([[s for i in range(S)] for s in range(S)], dtype=torch.int64), - torch.tensor([[i for i in range(S)] for s in range(S)], dtype=torch.int64), - ), torch.finfo(data_type).min) + mask = create_mask(mask_mod, 1, H, S, S, device=device) if mask_mod else None + bias = create_mask(score_mod, 1, H, S, S, device=device) if score_mod else None + if bias is not None: + bias = bias.to(dtype=data_type) + if mask: + mask = bias.where(mask, torch.finfo(data_type).min) + else: + assert mask is not None qkv = [ torch.randn(B, H, S, D, device=device, dtype=data_type, requires_grad=True)