[Proposal] Optionally use flash attention. #378
Comments
Seems reasonable to me; I'd be happy for someone to add this.
Seems very useful for sparse autoencoder training. Docs are at https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html#conclusion in case anyone wants to take this (I'll pick it up at some point if no one does).
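For anyone picking this up, here is a minimal sketch of the fused path the linked tutorial covers, assuming a PyTorch >= 2.0 build with the flash kernel available for the device/dtype. The tensor names and shapes are illustrative, the backend-selection context manager is the 2.0/2.1-era API, and none of this is TransformerLens code:

```python
# Sketch only: PyTorch's fused scaled_dot_product_attention, restricted to the
# flash backend as described in the linked tutorial. Assumes a CUDA device and
# a dtype (fp16/bf16) supported by the flash kernel.
import torch
import torch.nn.functional as F

batch, n_heads, seq_len, d_head = 2, 12, 128, 64
q = torch.randn(batch, n_heads, seq_len, d_head, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Force the flash kernel (this raises if it can't be used); newer PyTorch
# releases expose an equivalent selector under torch.nn.attention.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # torch.Size([2, 12, 128, 64])
```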
I'd be quite keen to make a start on this soon. @alan-cooney, have you made a start already?
I haven't yet, so please feel free to!
It would be nice to have a flag to enable flash attention in models where that would make sense. This helps performance and memory usage in larger models: in my case, working with Pythia 12B, I get ~50% better performance and ~4x larger batch sizes with flash attention. I also find numerical stability in float16 is better with flash attention, probably because the model was trained with flash attention.
The downside of using flash attention in TransformerLens is that we would not have access to intermediate quantities in the attention calculation, such as the attention matrix itself. This is why I would suggest a default-off flag, so that users can choose whether they need those intermediate values to be available. In addition, when only a small subset of attention intermediates is needed, it is much faster to cache the input to the attention layer (or the residual stream) and recompute those intermediates when needed.
Thanks!
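A rough sketch of how the proposed default-off flag could gate between the explicit attention computation (which keeps the attention pattern around for hooks and caching) and the fused kernel. The flag name, module layout, and shapes here are hypothetical, not TransformerLens's actual API:

```python
# Hypothetical illustration of the proposal; not TransformerLens code.
import torch
import torch.nn.functional as F

class SketchAttention(torch.nn.Module):
    def __init__(self, d_head: int, use_flash_attention: bool = False):
        super().__init__()
        self.d_head = d_head
        # Proposed flag, off by default so intermediates stay hookable.
        self.use_flash_attention = use_flash_attention

    def forward(self, q, k, v):  # each [batch, n_heads, seq, d_head]
        if self.use_flash_attention:
            # Fused kernel: faster and far more memory-efficient, but the
            # attention scores/pattern are never materialised, so they can't
            # be cached or hooked.
            return F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # Explicit path: materialises the attention matrix so the usual
        # intermediate quantities remain available.
        seq = q.size(-2)
        scores = q @ k.transpose(-2, -1) / self.d_head ** 0.5
        causal = torch.tril(torch.ones(seq, seq, device=q.device)).bool()
        scores = scores.masked_fill(~causal, torch.finfo(scores.dtype).min)
        pattern = scores.softmax(dim=-1)
        return pattern @ v
```

When the flag is on, intermediates like the attention pattern could still be recovered on demand by caching the layer input (or the residual stream) and re-running only the explicit path where they are actually needed, as suggested above.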