Fix attention result projection (#666)
* Updated README to have Bryce as the maintainer

* Fix attention result projection

The current result projection for attention is incorrect: the type annotations suggest that `result` is not summed over `head_index`, but in fact it is. I've edited the function so that `result` is no longer summed over `head_index`.

Note: this bug caused the ARENA material for the first transformers chapter to fail; I've tested it and the chapter now works.
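
For illustration (not part of the commit), here is a minimal sketch of the shape bug with made-up dimensions. The old path flattens the heads and applies `F.linear`, which sums the head contributions into `d_model` before `hook_result` fires; the new `einops.einsum` keeps a separate `head_index` axis, so per-head outputs remain observable.

```python
import torch
import torch.nn.functional as F
import einops

# Hypothetical small dimensions, for demonstration only.
batch, pos, head_index, d_head, d_model = 2, 3, 4, 8, 16
z = torch.randn(batch, pos, head_index, d_head)
W_O = torch.randn(head_index, d_head, d_model)

# Old path: flatten heads, then one big matmul -- heads are summed away.
w_flat = einops.rearrange(W_O, "head_index d_head d_model -> d_model (head_index d_head)")
inp = einops.rearrange(z, "batch pos head_index d_head -> batch pos (head_index d_head)")
old_result = F.linear(inp, w_flat)  # [batch, pos, d_model] -- no head_index axis!

# New path: keep head_index so each head's contribution is visible.
w = einops.rearrange(W_O, "head_index d_head d_model -> d_model head_index d_head")
new_result = einops.einsum(
    z, w, "... head_index d_head, d_model head_index d_head -> ... head_index d_model"
)  # [batch, pos, head_index, d_model]

# Summing the new result over heads recovers the old (prematurely summed) result.
assert torch.allclose(new_result.sum(dim=-2), old_result, atol=1e-5)
```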

* fix formatting with black

---------

Co-authored-by: Bryce Meyer <[email protected]>
Co-authored-by: Neel Nanda <[email protected]>
3 people committed Jul 11, 2024
1 parent 9872334 commit 67ed0d6
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions transformer_lens/components/abstract_attention.py
@@ -297,12 +297,15 @@ def forward(
         else:
             w = einops.rearrange(
                 self.W_O,
-                "head_index d_head d_model -> d_model (head_index d_head)",
+                "head_index d_head d_model -> d_model head_index d_head",
             )
-            input = einops.rearrange(
-                z, "batch pos head_index d_head -> batch pos (head_index d_head)"
-            )
-            result = self.hook_result(F.linear(input, w))  # [batch, pos, head_index, d_model]
+            result = self.hook_result(
+                einops.einsum(
+                    z,
+                    w,
+                    "... head_index d_head, d_model head_index d_head -> ... head_index d_model",
+                )
+            )  # [batch, pos, head_index, d_model]
         out = (
             einops.reduce(result, "batch position index model->batch position model", "sum")
             + self.b_O
