
Ungrouping GQA #713

Merged · 8 commits merged into TransformerLensOrg:main on Sep 12, 2024
Conversation

hannamw (Contributor) commented on Sep 6, 2024

Description

Background: Normal attention layers have $N$ attention heads, each with unique query/key/value weights. In contrast, GQA has $N$ unique sets of query weights but only $M<N$ unique sets of key/value weights; thus, some of the key/value weights are shared across heads, improving efficiency. Currently, this means that while the dimensionality of hook_k_input and hook_v_input is (batch, pos, N, d_model) for most models, it is (batch, pos, M, d_model) for GQA models. This is a bit annoying if we want to edit the k/v inputs to each attention head separately, even though their weights are shared.
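
To make the shape difference concrete, here is a minimal sketch (illustrative sizes only, not tied to any particular model) of how the M grouped key/value slots have to be expanded before each of the N query heads can be addressed individually:

```python
import torch

batch, pos, d_model = 2, 5, 512
N, M = 8, 2  # 8 query heads sharing 2 key/value groups (illustrative numbers)

# With GQA, the per-head key/value input only has M slots along the head dimension.
k_input_grouped = torch.randn(batch, pos, M, d_model)

# To address every query head individually, each of the M shared inputs is
# repeated N // M times along the head dimension.
k_input_ungrouped = k_input_grouped.repeat_interleave(N // M, dim=2)
print(k_input_ungrouped.shape)  # torch.Size([2, 5, 8, 512])
```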

Proposal: I have added an option, cfg.ungroup_gqa (default: False), which causes the input to hook_k_input and hook_v_input to have shape (batch, pos, N, d_model), instead of only being expanded to this size later, in calculate_attention_scores or calculate_z_scores. This loses the efficiency benefits of GQA, but it is nice for interpretability work.
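
For context, a minimal usage sketch (written against the flag's final name, ungroup_grouped_query_attention, adopted later in this thread; the model name is a placeholder, and whether the flag can simply be toggled on an already-loaded model is an assumption):

```python
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("some-gqa-model")  # placeholder model name
model.set_use_split_qkv_input(True)  # hook_k_input / hook_v_input require split qkv inputs
model.cfg.ungroup_grouped_query_attention = True  # the option added in this PR

_, cache = model.run_with_cache("Hello world")
# With the flag on, the head dimension is n_heads (N) rather than n_key_value_heads (M).
print(cache["blocks.0.hook_k_input"].shape)  # (batch, pos, n_heads, d_model)
```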

Type of change

  • New feature (non-breaking change which adds functionality)

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

bryce13950 (Collaborator)

Two notes on this. Can you spell out gqa in your config variable? I think most people will know what it means; however, the configuration is done separately from the component itself, so having it spelt out as a param will make the functionality obvious immediately, without someone having to reference the documentation.

Secondly, we need a new test written for this to make sure it does not break at any point. The difference here is very subtle, and those sorts of things can break very easily due to typos and oversights. We need to make sure there is a test ensuring that the routing to the right tensors is always maintained.

The actual implementation looks great though! We just want to make sure it always stays functioning properly.

hannamw (Contributor, Author) commented on Sep 12, 2024

Great! I've just renamed ungroup_gqa to ungroup_grouped_query_attention. As for the test: is this something I should be writing? Unfortunately, I'm not very familiar with the TransformerLens testing infrastructure. I guess the important things to verify are that 1) the model logits are the same whether the setting is on or off, and 2) when the setting is on, the shapes of the k/v inputs are as expected (a sketch of both checks follows below).
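
A hypothetical sketch of those two checks (the model name, fixture-free structure, and direct cfg toggling are assumptions, not the tests that were ultimately added):

```python
import torch
from transformer_lens import HookedTransformer

MODEL_NAME = "some-gqa-model"  # placeholder for a small GQA checkpoint
PROMPT = "The quick brown fox"

def test_logits_unchanged_by_ungrouping():
    grouped = HookedTransformer.from_pretrained(MODEL_NAME)
    ungrouped = HookedTransformer.from_pretrained(MODEL_NAME)
    ungrouped.cfg.ungroup_grouped_query_attention = True
    # 1) The flag should only change hook shapes, not the model's output.
    torch.testing.assert_close(grouped(PROMPT), ungrouped(PROMPT))

def test_kv_input_shape_is_ungrouped():
    model = HookedTransformer.from_pretrained(MODEL_NAME)
    model.set_use_split_qkv_input(True)
    model.cfg.ungroup_grouped_query_attention = True
    _, cache = model.run_with_cache(PROMPT)
    # 2) The head dimension of the k/v inputs should be n_heads, not n_key_value_heads.
    assert cache["blocks.0.hook_k_input"].shape[2] == model.cfg.n_heads
    assert cache["blocks.0.hook_v_input"].shape[2] == model.cfg.n_heads
```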

FlyingPumba (Contributor)

Hi @bryce13950! I spoke with @hannamw and I'll be adding the tests for this. I'll let you know once it's done.

FlyingPumba (Contributor)

@bryce13950 Pushed tests for the new ungroup_grouped_query_attention flag. We also now have 100% coverage for the GroupedQueryAttention class.

bryce13950 (Collaborator)

I love it! I am going to merge this and get it into a release by tomorrow morning.

bryce13950 merged commit a127e74 into TransformerLensOrg:main on Sep 12, 2024 · 12 checks passed
bryce13950 mentioned this pull request on Sep 12, 2024
bryce13950 added a commit that referenced this pull request Sep 12, 2024
* adding option to ungroup gqa

* adding option to ungroup gqa

* updating gqa / rebasing

* Update HookedTransformerConfig.py

* formatting fix

* renaming ungroup_gqa option

* Add tests for ungroup_grouped_query_attention flag

---------

Co-authored-by: Michael Hanna <[email protected]>
Co-authored-by: Ivan Arcuschin <[email protected]>
bryce13950 added a commit that referenced this pull request Sep 26, 2024
* Redo of #713 (#722)

* adding option to ungroup gqa

* adding option to ungroup gqa

* updating gqa / rebasing

* Update HookedTransformerConfig.py

* formatting fix

* renaming ungroup_gqa option

* Add tests for ungroup_grouped_query_attention flag

---------

Co-authored-by: Michael Hanna <[email protected]>
Co-authored-by: Ivan Arcuschin <[email protected]>

* fixed typo

---------

Co-authored-by: Michael Hanna <[email protected]>
Co-authored-by: Ivan Arcuschin <[email protected]>
bryce13950 mentioned this pull request on Sep 26, 2024