Ungrouping GQA #713
Conversation
Two notes on this. First, can you spell out `ungroup_gqa`? Secondly, we need a new test written for this to make sure it does not break at any point. The difference here is very subtle, and these sorts of things can break very easily due to typos and oversights. We need to make sure there is a test ensuring that the routing to the right tensors is always maintained. The actual implementation looks great though! We just want to make sure it always stays functioning properly.
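For context, a minimal sketch of the kind of equivalence test being requested (not the actual test added in this PR) might look like the following. The model choice and the way the flag is toggled on `model.cfg` are assumptions, and the flag name used is the spelled-out `ungroup_grouped_query_attention` adopted later in this PR.

```python
# Sketch of a regression test: ungrouping GQA should only change *where* the
# key/value expansion happens, not the model's output, so grouped and
# ungrouped runs must agree. Model name and flag-setting mechanism are
# assumptions; adjust to whatever GQA model / config API is available.
import torch
from transformer_lens import HookedTransformer


def test_ungrouped_gqa_matches_grouped_output():
    # Any model that uses grouped-query attention; assumed to be supported.
    model = HookedTransformer.from_pretrained("mistralai/Mistral-7B-v0.1")
    prompt = "Grouped-query attention shares key/value weights across heads."

    # Baseline: grouped (default) behaviour.
    model.cfg.ungroup_grouped_query_attention = False
    grouped_logits = model(prompt)

    # Ungrouped: k/v inputs are expanded to one slot per query head up front.
    model.cfg.ungroup_grouped_query_attention = True
    ungrouped_logits = model(prompt)

    # If the routing to the right tensors is maintained, the outputs match.
    torch.testing.assert_close(grouped_logits, ungrouped_logits, atol=1e-4, rtol=1e-4)
```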
Great! I've just done the renaming of the option to `ungroup_grouped_query_attention`.
Hi @bryce13950! I spoke with @hannamw and I'll be adding the tests for this. I'll let you know once it's done.
@bryce13950 Pushed tests for the new `ungroup_grouped_query_attention` flag.
I love it! I am going to merge this, and get this in a release by tomorrow morning.
* adding option to ungroup gqa
* adding option to ungroup gqa
* updating gqa / rebasing
* Update HookedTransformerConfig.py
* formatting fix
* renaming ungroup_gqa option
* Add tests for ungroup_grouped_query_attention flag

---------

Co-authored-by: Michael Hanna <[email protected]>
Co-authored-by: Ivan Arcuschin <[email protected]>
* Redo of #713 (#722)
  * adding option to ungroup gqa
  * adding option to ungroup gqa
  * updating gqa / rebasing
  * Update HookedTransformerConfig.py
  * formatting fix
  * renaming ungroup_gqa option
  * Add tests for ungroup_grouped_query_attention flag

  Co-authored-by: Michael Hanna <[email protected]>
  Co-authored-by: Ivan Arcuschin <[email protected]>
* fixed typo

---------

Co-authored-by: Michael Hanna <[email protected]>
Co-authored-by: Ivan Arcuschin <[email protected]>
Description
Background: Normal attention layers have $N$ attention heads, each with unique query/key/value weights. In contrast, GQA has $N$ unique sets of query weights but only $M<N$ unique sets of key/value weights; thus, some of the key/value weights are shared across heads, improving efficiency. Currently, this means that while the dimensionality of `hook_k_input` and `hook_v_input` is `(batch, pos, N, d_model)` for most models, it is `(batch, pos, M, d_model)` for GQA models. This is a bit annoying if we want to edit the k/v inputs to each attention head separately, even though their weights are shared.

Proposal: I have added an option, `cfg.ungroup_gqa` (default: `False`), which causes the input to be fed into `hook_k_input` and `hook_v_input` as `(batch, pos, N, d_model)`, instead of only expanding to this size later, at `calculate_attention_scores` or `calculate_z_scores`. This loses the benefits of GQA, but it is nice for interp (see the usage sketch after the checklist below).

Type of change
Checklist:
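To illustrate the shape change described above, here is a minimal usage sketch. It assumes the option under its later, spelled-out name `ungroup_grouped_query_attention` (see the commit list), assumes it can be toggled directly on `model.cfg`, and uses `mistralai/Mistral-7B-v0.1` purely as an example of a GQA model; `use_split_qkv_input` is enabled so that `hook_k_input` appears in the activation cache.

```python
# Minimal sketch: inspect how the flag changes the shape of hook_k_input.
# Assumptions: the option is set via cfg (flag name taken from the commit
# list), the chosen model uses GQA and is supported by TransformerLens, and
# use_split_qkv_input is enabled so hook_k_input shows up in the cache.
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("mistralai/Mistral-7B-v0.1")
model.set_use_split_qkv_input(True)
prompt = "Hello GQA"

# Default (grouped): hook_k_input has M (n_key_value_heads) head slots.
model.cfg.ungroup_grouped_query_attention = False
_, cache = model.run_with_cache(prompt)
print(cache["blocks.0.hook_k_input"].shape)  # (batch, pos, M, d_model)

# Ungrouped: hook_k_input is expanded to N (n_heads) head slots up front,
# so each query head's k/v input can be edited independently.
model.cfg.ungroup_grouped_query_attention = True
_, cache = model.run_with_cache(prompt)
print(cache["blocks.0.hook_k_input"].shape)  # (batch, pos, N, d_model)
```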