Check correctness for score_mod implementations #103

Draft
wants to merge 6 commits into base: main
Conversation

oraluben (Contributor)

No description provided.

@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on Jan 13, 2025
@oraluben (Contributor, Author)

Hi @drisspg, I'd like to propose a change to the torch.nn.attention.flex_attention.create_mask[1] function in the main torch repo: have it return the real bias instead of the mask.

This would let this PR be finished with minimal effort and, IMO, would also help other people validate their flex attention implementations.
If the change is adopted, the current behaviour can be recovered simply with torch.where(torch.isneginf(out), False, True), whereas it is not easy to get the bias back from the mask unless the user copy-pastes create_mask. A small sketch of this asymmetry follows.
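A minimal, self-contained sketch of that asymmetry (the values below are made up for illustration and are not part of this PR): the mask is recoverable from the bias with one torch.where, but a non-trivial bias term is lost once it has been collapsed into a boolean mask.

import torch

# Toy bias that a score_mod might produce: one masked-out entry (-inf)
# and one finite, non-zero additive term.
bias = torch.tensor([[0.0, float("-inf")],
                     [-1.5, 0.0]])

# bias -> mask is a one-liner ...
mask = torch.where(torch.isneginf(bias), False, True)

# ... but mask -> bias is lossy: the -1.5 term cannot be reconstructed.
bias_from_mask = torch.where(mask, 0.0, float("-inf"))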

The change might look like this while keeping backward compatibility:

diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py
index 7b0a3354642..ebd17dde5c4 100644
--- a/torch/nn/attention/flex_attention.py
+++ b/torch/nn/attention/flex_attention.py
@@ -711,6 +711,7 @@ def create_mask(
     Q_LEN: int,
     KV_LEN: int,
     device: str = "cuda",
+    _return_score: bool = False,
     _compile: bool = False,
 ) -> Tensor:
     r"""This function creates a mask tensor from a mod_fn function.
@@ -747,7 +748,9 @@ def create_mask(
             score_mod = mod_fn
             score_mod = _vmap_for_bhqkv(score_mod, prefix=(0,))  # first input is score
             out = score_mod(torch.zeros(B, H, Q_LEN, KV_LEN, device=device), b, h, m, n)
-            mask = torch.where(torch.isneginf(out), False, True)
+            if _return_score:
+                return out
+            mask = torch.where(torch.isneginf(out), False, True)
             return mask
         elif mod_type == _ModificationType.MASK_MOD:
             mask_mod = mod_fn
         elif mod_type == _ModificationType.MASK_MOD:
             mask_mod = mod_fn

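For context, here is a rough usage sketch of how the returned bias could then be used to check a score_mod implementation, assuming the proposed _return_score flag is adopted (the causal score_mod and the shapes below are made up for illustration):

import torch
from torch.nn.attention.flex_attention import create_mask

def causal_score_mod(score, b, h, q_idx, kv_idx):
    # Toy score_mod: mask out future positions with -inf.
    return torch.where(q_idx >= kv_idx, score, float("-inf"))

B, H, Q_LEN, KV_LEN = 1, 1, 8, 8

# With the proposed flag, create_mask would return the materialized bias ...
bias = create_mask(causal_score_mod, B, H, Q_LEN, KV_LEN,
                   device="cpu", _return_score=True)

# ... which can be compared directly against a reference dense bias, while the
# current boolean-mask behaviour stays one torch.where away.
mask = torch.where(torch.isneginf(bias), False, True)
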
Looking forward to hearing whether this makes sense to you, thanks!

[1] https://github.com/pytorch/pytorch/blob/87843ee9ab50778a98eda62fd7498d44c69488bd/torch/nn/attention/flex_attention.py#L810-L815
