
Commit

fix(components): fix bug for models with rotary embedding
soheeyang committed Jul 30, 2023
1 parent bf5ed27 commit 32d55c1
Showing 1 changed file with 1 addition and 4 deletions.
transformer_lens/components.py (5 changes: 1 addition & 4 deletions)
@@ -692,10 +692,7 @@ def rotate_every_two(
         GPT-NeoX and GPT-J do rotary subtly differently, see calculate_sin_cos_rotary for details.
         """
         rot_x = x.clone()
-        if (
-            self.cfg.original_architecture
-            in ["GPTNeoXForCausalLM", "LlamaForCausalLM"],
-        ):
+        if self.cfg.original_architecture in ["GPTNeoXForCausalLM", "LlamaForCausalLM"]:
             n = x.size(-1) // 2
             rot_x[..., :n] = -x[..., n:]
             rot_x[..., n:] = x[..., :n]
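The removed condition wrapped the membership test in parentheses with a trailing comma, which turns it into a one-element tuple; a non-empty tuple is always truthy, so the NeoX/Llama rotation branch ran for every architecture. A minimal sketch of the difference, using a hypothetical `arch` value standing in for `self.cfg.original_architecture`:

```python
# Sketch of the bug: a trailing comma inside parentheses builds a tuple,
# and a non-empty tuple is truthy regardless of the comparison inside it.
arch = "GPTJForCausalLM"  # hypothetical value; GPT-J should NOT take this branch

buggy = (
    arch
    in ["GPTNeoXForCausalLM", "LlamaForCausalLM"],
)  # evaluates to the tuple (False,), which is truthy

fixed = arch in ["GPTNeoXForCausalLM", "LlamaForCausalLM"]  # plain bool: False

print(bool(buggy), bool(fixed))  # True False -> the old check always passed
```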
