
Add xformers support #20

Open
xuzhao9 opened this issue Oct 28, 2024 · 4 comments

Comments

xuzhao9 commented Oct 28, 2024

Add xformers built from source, similar to fbgemm: https://github.com/facebookresearch/xformers

Make sure FA3 (Flash Attention 3) is available.

@antferdom

Simple op availability assertion:

import xformers.ops

# Check whether the Flash Attention 3 forward op is available in this build.
op = xformers.ops.fmha.flash3.FwOp
HAS_FLASH3 = op.is_available()
print(f"xformers_ops_fmha_flash3 supported: {HAS_FLASH3}")


xuzhao9 commented Oct 29, 2024

#23 should fix this

antferdom commented Oct 29, 2024

Looks good to me, but building xformers from source with FA3 support might trigger recompilation in the existing environment and conflict with a previous Flash Attention v3 installation.

A colleague (@ohwi) and I found a point of conflict between the xformers FA3 Torch custom-op wrapper logic and flashattn_hopper_cuda, which led to CUDA errors:

TypeError: fwd(): incompatible function arguments. The following argument types are supported:
    1. (arg0: torch.Tensor, arg1: torch.Tensor, arg2: torch.Tensor, arg3: Optional[torch.Tensor], arg4: float, arg5: Optional[torch.Tensor], arg6: Optional[torch.Tensor], arg7: Optional[torch.Tensor], arg8: bool, arg9: int, arg10: int) -> list[torch.Tensor]

Our understanding of the conflict:

  • The current version of the fwd function in flashattn_hopper_cuda requires the non-optional arguments window_size_left and window_size_right, but the xformers-registered custom `mha_fwd` does not include this update.

There is also a code block in xformers that imports flashattn_hopper_cuda as a fallback, which effectively makes only one of xformers or flash-attn usable at a time (see the sketch after the links below).
See: https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_api.cpp#L463-L475
and
https://github.com/facebookresearch/xformers/blob/68b7fd14df5eb1d2558c52842b4206a14d2d20e9/xformers/ops/fmha/flash3.py#L48-L82
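
For context, the fallback logic is roughly of the following shape (a simplified sketch only, not the actual xformers source; the extension module names are approximations):

# Simplified sketch of the import-fallback pattern linked above
# (illustrative only; module names here are approximations, not the real ones).
try:
    # Preferred path: the FA3 extension that ships with / is built by xformers.
    import _C_flashattention3 as _flash_lib  # hypothetical module name
except ImportError:
    try:
        # Fallback: the standalone Flash Attention 3 "hopper" build.
        import flashattn_hopper_cuda as _flash_lib
    except ImportError:
        _flash_lib = None

# The custom-op wrapper then calls _flash_lib.fwd(...) with a fixed positional
# signature. If the installed flashattn_hopper_cuda was built from a newer
# flash_api.cpp that added required window_size_left / window_size_right
# arguments, that call no longer matches and raises the TypeError shown above.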

Therefore, although xformers reports FLASH3 as an available operator, we still need to assert its actual execution (a minimal check is sketched below). I made it work with:
flashattn-hopper==3.0.0b1
torch==2.4.1+cu124
xformers==0.0.29
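
Since is_available() alone does not catch the signature mismatch, here is a minimal execution check (a sketch under my own assumptions: small fp16 BMHK tensors, a Hopper GPU, and the op=(FwOp, None) inference form of memory_efficient_attention):

import torch
import xformers.ops as xops
from xformers.ops import fmha

def assert_flash3_executes():
    """Run a tiny forward pass through the FLASH3 op, not just is_available()."""
    op = fmha.flash3.FwOp
    assert op.is_available(), "xformers reports FLASH3 as unavailable"

    # Small dummy inputs in BMHK layout: (batch, seqlen, heads, head_dim), fp16 on GPU.
    q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # Force the FLASH3 forward op; a broken flashattn_hopper_cuda fallback
    # surfaces here as a TypeError / CUDA error instead of at import time.
    out = xops.memory_efficient_attention(q, k, v, op=(op, None))
    assert out.shape == q.shape
    print("FLASH3 executed successfully")

assert_flash3_executes()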

This might be worth filing as a proper issue in the xformers repo; what do you think, @xuzhao9?

xuzhao9 commented Oct 29, 2024

Yes, I think it is a valid issue to post to the xformers repo, @antferdom.
