diff --git a/transformer_lens/components.py b/transformer_lens/components.py index cd6708679..dae734baa 100644 --- a/transformer_lens/components.py +++ b/transformer_lens/components.py @@ -692,10 +692,7 @@ def rotate_every_two( GPT-NeoX and GPT-J do rotary subtly differently, see calculate_sin_cos_rotary for details. """ rot_x = x.clone() - if ( - self.cfg.original_architecture - in ["GPTNeoXForCausalLM", "LlamaForCausalLM"], - ): + if self.cfg.original_architecture in ["GPTNeoXForCausalLM", "LlamaForCausalLM"]: n = x.size(-1) // 2 rot_x[..., :n] = -x[..., n:] rot_x[..., n:] = x[..., :n]