diff --git a/tests/integration/test_grouped_query_attention.py b/tests/integration/test_grouped_query_attention.py
index e5e603454..23d4535e2 100644
--- a/tests/integration/test_grouped_query_attention.py
+++ b/tests/integration/test_grouped_query_attention.py
@@ -1,6 +1,7 @@
 import einops
 import torch
 
+from transformer_lens import HookedTransformer
 from transformer_lens.components import Attention, GroupedQueryAttention
 from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
 
@@ -55,7 +56,7 @@ def test_grouped_query_attention_output_is_correct():
         "mask": regular_attention.state_dict()["mask"],
         "IGNORE": regular_attention.state_dict()["IGNORE"],
     }
-    grouped_query_attemtion_state_dict = {
+    grouped_query_attention_state_dict = {
         "W_Q": W_Q,
         "b_Q": b_Q,
         "W_O": W_O,
@@ -69,7 +70,7 @@ def test_grouped_query_attention_output_is_correct():
     }
 
     regular_attention.load_state_dict(regular_attention_state_dict)
-    grouped_query_attention.load_state_dict(grouped_query_attemtion_state_dict)
+    grouped_query_attention.load_state_dict(grouped_query_attention_state_dict)
 
     query_input = torch.rand((1, 5, d_model))
     key_input = torch.rand((1, 5, d_model))
@@ -92,3 +93,143 @@ def test_grouped_query_attention_output_is_correct():
     )
 
     assert torch.allclose(regular_attn_output, split_grouped_query_attn_output, rtol=1e-6)
+
+
+def test_ungroup_grouped_query_attention_flag_produces_same_result():
+    d_model = 512
+    d_head = 32
+    n_heads = 16
+    n_ctx = 128
+    n_key_value_heads = 4
+    n_layers = 1
+
+    cfg_flag_off = HookedTransformerConfig(
+        d_model=d_model,
+        d_head=d_head,
+        n_heads=n_heads,
+        n_ctx=n_ctx,
+        n_key_value_heads=n_key_value_heads,
+        n_layers=n_layers,
+        act_fn="silu",
+        ungroup_grouped_query_attention=False,
+    )
+    grouped_query_attention_flag_off = GroupedQueryAttention(cfg_flag_off)
+
+    cfg_flag_on = HookedTransformerConfig(
+        d_model=d_model,
+        d_head=d_head,
+        n_heads=n_heads,
+        n_ctx=n_ctx,
+        n_key_value_heads=n_key_value_heads,
+        n_layers=n_layers,
+        act_fn="silu",
+        ungroup_grouped_query_attention=True,
+    )
+    grouped_query_attention_flag_on = GroupedQueryAttention(cfg_flag_on)
+
+    W_Q = torch.rand((n_heads, d_model, d_head))
+    b_Q = torch.rand((n_heads, d_head))
+    _W_K = torch.rand((n_key_value_heads, d_model, d_head))
+    _b_K = torch.rand((n_key_value_heads, d_head))
+    _W_V = torch.rand((n_key_value_heads, d_model, d_head))
+    _b_V = torch.rand((n_key_value_heads, d_head))
+    W_O = torch.rand((n_heads, d_head, d_model))
+    b_O = torch.rand(d_model)
+
+    grouped_query_attention_state_dict = {
+        "W_Q": W_Q,
+        "b_Q": b_Q,
+        "W_O": W_O,
+        "b_O": b_O,
+        "_W_K": _W_K,
+        "_b_K": _b_K,
+        "_W_V": _W_V,
+        "_b_V": _b_V,
+        "mask": grouped_query_attention_flag_off.state_dict()["mask"],
+        "IGNORE": grouped_query_attention_flag_off.state_dict()["IGNORE"],
+    }
+
+    grouped_query_attention_flag_off.load_state_dict(grouped_query_attention_state_dict)
+    grouped_query_attention_flag_on.load_state_dict(grouped_query_attention_state_dict)
+
+    query_input = torch.rand((1, 5, d_model))
+    key_input = torch.rand((1, 5, d_model))
+    value_input = torch.rand((1, 5, d_model))
+
+    grouped_query_attn_flag_off_output = grouped_query_attention_flag_off(
+        query_input, key_input, value_input
+    )
+    grouped_query_attn_flag_on_output = grouped_query_attention_flag_on(
+        query_input, key_input, value_input
+    )
+
+    assert torch.equal(grouped_query_attn_flag_off_output, grouped_query_attn_flag_on_output)
+
+
+def test_ungroup_grouped_query_attention_flag_changes_k_v_hooks_shape():
+    d_model = 512
+    d_head = 32
+    n_heads = 16
+    n_ctx = 128
+    n_key_value_heads = 4
+    n_layers = 1
+    d_vocab = 10
+
+    cfg = HookedTransformerConfig(
+        d_model=d_model,
+        d_head=d_head,
+        n_heads=n_heads,
+        n_ctx=n_ctx,
+        n_key_value_heads=n_key_value_heads,
+        n_layers=n_layers,
+        act_fn="silu",
+        d_vocab=d_vocab,
+        use_split_qkv_input=True,
+        ungroup_grouped_query_attention=False,
+    )
+
+    model = HookedTransformer(cfg)
+    assert model.cfg.ungroup_grouped_query_attention is False
+
+    x = torch.arange(1, 9).unsqueeze(0)
+    flag_off_output, flag_off_cache = model.run_with_cache(
+        x,
+        names_filter=[
+            "blocks.0.attn.hook_k",
+            "blocks.0.attn.hook_v",
+            "blocks.0.hook_k_input",
+            "blocks.0.hook_v_input",
+        ],
+    )
+
+    model.set_ungroup_grouped_query_attention(True)
+    assert model.cfg.ungroup_grouped_query_attention is True
+
+    flag_on_output, flag_on_cache = model.run_with_cache(
+        x,
+        names_filter=[
+            "blocks.0.attn.hook_k",
+            "blocks.0.attn.hook_v",
+            "blocks.0.hook_k_input",
+            "blocks.0.hook_v_input",
+        ],
+    )
+
+    assert (
+        flag_on_cache["blocks.0.attn.hook_k"].shape[2]
+        == flag_off_cache["blocks.0.attn.hook_k"].shape[2] * n_key_value_heads
+    )
+    assert (
+        flag_on_cache["blocks.0.attn.hook_v"].shape[2]
+        == flag_off_cache["blocks.0.attn.hook_v"].shape[2] * n_key_value_heads
+    )
+    assert (
+        flag_on_cache["blocks.0.hook_k_input"].shape[2]
+        == flag_off_cache["blocks.0.hook_k_input"].shape[2] * n_key_value_heads
+    )
+    assert (
+        flag_on_cache["blocks.0.hook_v_input"].shape[2]
+        == flag_off_cache["blocks.0.hook_v_input"].shape[2] * n_key_value_heads
+    )
+
+    assert torch.equal(flag_off_output, flag_on_output)
diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py
index 1c2a4a481..a5c53b222 100644
--- a/transformer_lens/HookedTransformer.py
+++ b/transformer_lens/HookedTransformer.py
@@ -1940,6 +1940,12 @@ def set_use_attn_in(self, use_attn_in: bool):
         """
         self.cfg.use_attn_in = use_attn_in
 
+    def set_ungroup_grouped_query_attention(self, ungroup_grouped_query_attention: bool):
+        """
+        Toggles whether to ungroup the grouped key and value heads in models with grouped query attention (GQA).
+        """
+        self.cfg.ungroup_grouped_query_attention = ungroup_grouped_query_attention
+
     def process_weights_(
         self,
         fold_ln: bool = True,
diff --git a/transformer_lens/HookedTransformerConfig.py b/transformer_lens/HookedTransformerConfig.py
index cfca7fb72..e2fdc532e 100644
--- a/transformer_lens/HookedTransformerConfig.py
+++ b/transformer_lens/HookedTransformerConfig.py
@@ -54,6 +54,8 @@ class HookedTransformerConfig:
             attention head separately, with a hook. Defaults to false to save memory
         use_attn_scale (bool): whether to scale the attention weights by
             1/sqrt(d_head)
+        ungroup_grouped_query_attention (bool): whether to ungroup key and value heads, for models that use
+            grouped query attention.
         attn_scale (float): The amount to divide attention scores by (if applicable). Defaults to
             sqrt(d_head)
         model_name (str): the name of the model, used to load
@@ -199,6 +201,7 @@ class HookedTransformerConfig:
     use_hook_mlp_in: bool = False
     use_attn_in: bool = False
     use_local_attn: bool = False
+    ungroup_grouped_query_attention: bool = False
     original_architecture: Optional[str] = None
    from_checkpoint: bool = False
     checkpoint_index: Optional[int] = None
diff --git a/transformer_lens/components/grouped_query_attention.py b/transformer_lens/components/grouped_query_attention.py
index 0681518f7..d86740033 100644
--- a/transformer_lens/components/grouped_query_attention.py
+++ b/transformer_lens/components/grouped_query_attention.py
@@ -126,11 +126,16 @@ def calculate_qkv_matrices(
         q = self.hook_q(
             attn_fn(query_input, self.W_Q, self.b_Q)
         )  # [batch, pos, head_index, d_head]
+
         k = self.hook_k(
-            attn_fn(key_input, self._W_K, self._b_K)
+            attn_fn(key_input, self.W_K, self.b_K)
+            if self.cfg.ungroup_grouped_query_attention
+            else attn_fn(key_input, self._W_K, self._b_K)
         )  # [batch, pos, head_index, d_head]
         v = self.hook_v(
-            attn_fn(value_input, self._W_V, self._b_V)
+            attn_fn(value_input, self.W_V, self.b_V)
+            if self.cfg.ungroup_grouped_query_attention
+            else attn_fn(value_input, self._W_V, self._b_V)
         )  # [batch, pos, head_index, d_head]
 
         return q, k, v
@@ -149,7 +154,8 @@ def calculate_attention_scores(
         Returns:
             Float[torch.Tensor, "batch head_index query_pos key_pos"]: The attention scores.
         """
-        k = torch.repeat_interleave(k, dim=2, repeats=self.repeat_kv_heads)
+        if not self.cfg.ungroup_grouped_query_attention:
+            k = torch.repeat_interleave(k, dim=2, repeats=self.repeat_kv_heads)
         return super().calculate_attention_scores(q, k)
 
     def calculate_z_scores(
@@ -167,5 +173,6 @@ def calculate_z_scores(
         Returns:
             Float[torch.Tensor, "batch head_index query_pos key_pos"]: The z scores.
         """
-        v = torch.repeat_interleave(v, dim=2, repeats=self.repeat_kv_heads)
+        if not self.cfg.ungroup_grouped_query_attention:
+            v = torch.repeat_interleave(v, dim=2, repeats=self.repeat_kv_heads)
         return super().calculate_z_scores(v, pattern)
diff --git a/transformer_lens/components/transformer_block.py b/transformer_lens/components/transformer_block.py
index 6db16a195..469fe66e1 100644
--- a/transformer_lens/components/transformer_block.py
+++ b/transformer_lens/components/transformer_block.py
@@ -136,6 +136,7 @@ def forward(
             n_kv_heads = (
                 self.cfg.n_key_value_heads
                 if self.cfg.n_key_value_heads is not None
+                and not self.cfg.ungroup_grouped_query_attention
                 else self.cfg.n_heads
             )
             query_input = self.hook_q_input(
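
Note for reviewers (not part of the patch): a minimal sketch of how the new flag is intended to be used, reusing only the config values, hook names, and methods that appear in the tests above; the printed shapes follow the [batch, pos, head_index, d_head] layout documented in grouped_query_attention.py.

import torch

from transformer_lens import HookedTransformer
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig

# Toy GQA model mirroring the configuration used in the tests above.
cfg = HookedTransformerConfig(
    d_model=512,
    d_head=32,
    n_heads=16,
    n_ctx=128,
    n_key_value_heads=4,
    n_layers=1,
    act_fn="silu",
    d_vocab=10,
)
model = HookedTransformer(cfg)
tokens = torch.arange(1, 9).unsqueeze(0)  # [batch=1, pos=8]

# Default behaviour: hook_k / hook_v expose the grouped heads (n_key_value_heads).
_, cache = model.run_with_cache(tokens, names_filter="blocks.0.attn.hook_k")
print(cache["blocks.0.attn.hook_k"].shape)  # torch.Size([1, 8, 4, 32])

# With the flag on, the K/V hooks expose one slot per query head (n_heads),
# so head-level interventions no longer need to account for the grouping.
model.set_ungroup_grouped_query_attention(True)
_, cache = model.run_with_cache(tokens, names_filter="blocks.0.attn.hook_k")
print(cache["blocks.0.attn.hook_k"].shape)  # torch.Size([1, 8, 16, 32])

The flag can equally be set at construction time by passing ungroup_grouped_query_attention=True to HookedTransformerConfig, as the first test does.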