Merge pull request #723 from TransformerLensOrg/dev
2.6
bryce13950 authored Sep 13, 2024
2 parents be334fb + 87edf1d commit e64888d
Showing 5 changed files with 164 additions and 6 deletions.
145 changes: 143 additions & 2 deletions tests/integration/test_grouped_query_attention.py
@@ -1,6 +1,7 @@
import einops
import torch

from transformer_lens import HookedTransformer
from transformer_lens.components import Attention, GroupedQueryAttention
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig

@@ -55,7 +56,7 @@ def test_grouped_query_attention_output_is_correct():
"mask": regular_attention.state_dict()["mask"],
"IGNORE": regular_attention.state_dict()["IGNORE"],
}
- grouped_query_attemtion_state_dict = {
+ grouped_query_attention_state_dict = {
"W_Q": W_Q,
"b_Q": b_Q,
"W_O": W_O,
@@ -69,7 +70,7 @@ def test_grouped_query_attention_output_is_correct():
}

regular_attention.load_state_dict(regular_attention_state_dict)
- grouped_query_attention.load_state_dict(grouped_query_attemtion_state_dict)
+ grouped_query_attention.load_state_dict(grouped_query_attention_state_dict)

query_input = torch.rand((1, 5, d_model))
key_input = torch.rand((1, 5, d_model))
@@ -92,3 +93,143 @@ def test_grouped_query_attention_output_is_correct():
)

assert torch.allclose(regular_attn_output, split_grouped_query_attn_output, rtol=1e-6)


def test_ungroup_grouped_query_attention_flag_produces_same_result():
d_model = 512
d_head = 32
n_heads = 16
n_ctx = 128
n_key_value_heads = 4
n_layers = 1

cfg_flag_off = HookedTransformerConfig(
d_model=d_model,
d_head=d_head,
n_heads=n_heads,
n_ctx=n_ctx,
n_key_value_heads=n_key_value_heads,
n_layers=n_layers,
act_fn="silu",
ungroup_grouped_query_attention=False,
)
grouped_query_attention_flag_off = GroupedQueryAttention(cfg_flag_off)

cfg_flag_on = HookedTransformerConfig(
d_model=d_model,
d_head=d_head,
n_heads=n_heads,
n_ctx=n_ctx,
n_key_value_heads=n_key_value_heads,
n_layers=n_layers,
act_fn="silu",
ungroup_grouped_query_attention=True,
)
grouped_query_attention_flag_on = GroupedQueryAttention(cfg_flag_on)

W_Q = torch.rand((n_heads, d_model, d_head))
b_Q = torch.rand((n_heads, d_head))
_W_K = torch.rand((n_key_value_heads, d_model, d_head))
_b_K = torch.rand((n_key_value_heads, d_head))
_W_V = torch.rand((n_key_value_heads, d_model, d_head))
_b_V = torch.rand((n_key_value_heads, d_head))
W_O = torch.rand((n_heads, d_head, d_model))
b_O = torch.rand(d_model)

grouped_query_attention_state_dict = {
"W_Q": W_Q,
"b_Q": b_Q,
"W_O": W_O,
"b_O": b_O,
"_W_K": _W_K,
"_b_K": _b_K,
"_W_V": _W_V,
"_b_V": _b_V,
"mask": grouped_query_attention_flag_off.state_dict()["mask"],
"IGNORE": grouped_query_attention_flag_off.state_dict()["IGNORE"],
}

grouped_query_attention_flag_off.load_state_dict(grouped_query_attention_state_dict)
grouped_query_attention_flag_on.load_state_dict(grouped_query_attention_state_dict)

query_input = torch.rand((1, 5, d_model))
key_input = torch.rand((1, 5, d_model))
value_input = torch.rand((1, 5, d_model))

grouped_query_attn_flag_off_output = grouped_query_attention_flag_off(
query_input, key_input, value_input
)
grouped_query_attn_flag_on_output = grouped_query_attention_flag_on(
query_input, key_input, value_input
)

assert torch.equal(grouped_query_attn_flag_off_output, grouped_query_attn_flag_on_output)


def test_ungroup_grouped_query_attention_flag_changes_k_v_hooks_shape():
d_model = 512
d_head = 32
n_heads = 16
n_ctx = 128
n_key_value_heads = 4
n_layers = 1
d_vocab = 10

cfg = HookedTransformerConfig(
d_model=d_model,
d_head=d_head,
n_heads=n_heads,
n_ctx=n_ctx,
n_key_value_heads=n_key_value_heads,
n_layers=n_layers,
act_fn="silu",
d_vocab=d_vocab,
use_split_qkv_input=True,
ungroup_grouped_query_attention=False,
)

model = HookedTransformer(cfg)
assert model.cfg.ungroup_grouped_query_attention is False

x = torch.arange(1, 9).unsqueeze(0)
flag_off_output, flag_off_cache = model.run_with_cache(
x,
names_filter=[
"blocks.0.attn.hook_k",
"blocks.0.attn.hook_v",
"blocks.0.hook_k_input",
"blocks.0.hook_v_input",
],
)

model.set_ungroup_grouped_query_attention(True)
assert model.cfg.ungroup_grouped_query_attention is True

flag_on_output, flag_on_cache = model.run_with_cache(
x,
names_filter=[
"blocks.0.attn.hook_k",
"blocks.0.attn.hook_v",
"blocks.0.hook_k_input",
"blocks.0.hook_v_input",
],
)

assert (
flag_on_cache["blocks.0.attn.hook_k"].shape[2]
== flag_off_cache["blocks.0.attn.hook_k"].shape[2] * n_key_value_heads
)
assert (
flag_on_cache["blocks.0.attn.hook_v"].shape[2]
== flag_off_cache["blocks.0.attn.hook_v"].shape[2] * n_key_value_heads
)
assert (
flag_on_cache["blocks.0.hook_k_input"].shape[2]
== flag_off_cache["blocks.0.hook_k_input"].shape[2] * n_key_value_heads
)
assert (
flag_on_cache["blocks.0.hook_v_input"].shape[2]
== flag_off_cache["blocks.0.hook_v_input"].shape[2] * n_key_value_heads
)

assert torch.equal(flag_off_output, flag_on_output)
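
For readers skimming the new tests, here is a condensed sketch of the behaviour they exercise, using the same toy configuration. It is illustrative only; the shapes in the comments are what the assertions above imply rather than captured output.

import torch
from transformer_lens import HookedTransformer
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig

# Toy GQA config mirroring the tests: 16 query heads sharing 4 key/value heads.
cfg = HookedTransformerConfig(
    d_model=512, d_head=32, n_heads=16, n_ctx=128, n_key_value_heads=4,
    n_layers=1, act_fn="silu", d_vocab=10,
)
model = HookedTransformer(cfg)
tokens = torch.arange(1, 9).unsqueeze(0)  # [batch=1, pos=8]

# Flag off (the default): hook_k keeps the grouped shape [batch, pos, n_key_value_heads, d_head].
_, cache = model.run_with_cache(tokens, names_filter="blocks.0.attn.hook_k")
print(cache["blocks.0.attn.hook_k"].shape)  # expected: torch.Size([1, 8, 4, 32])

# Flag on: keys/values are expanded to one head per query head before the hook fires.
model.set_ungroup_grouped_query_attention(True)
_, cache = model.run_with_cache(tokens, names_filter="blocks.0.attn.hook_k")
print(cache["blocks.0.attn.hook_k"].shape)  # expected: torch.Size([1, 8, 16, 32])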
6 changes: 6 additions & 0 deletions transformer_lens/HookedTransformer.py
@@ -1940,6 +1940,12 @@ def set_use_attn_in(self, use_attn_in: bool):
"""
self.cfg.use_attn_in = use_attn_in

def set_ungroup_grouped_query_attention(self, ungroup_grouped_query_attention: bool):
"""
Toggles whether to ungroup the grouped key and value heads in models with grouped query attention (GQA).
"""
self.cfg.ungroup_grouped_query_attention = ungroup_grouped_query_attention

def process_weights_(
self,
fold_ln: bool = True,
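As the method body above shows, the new setter is a thin convenience wrapper over the config flag, so the two lines in this sketch should be interchangeable:

model.set_ungroup_grouped_query_attention(True)
model.cfg.ungroup_grouped_query_attention = True  # what the setter assigns internally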
3 changes: 3 additions & 0 deletions transformer_lens/HookedTransformerConfig.py
@@ -54,6 +54,8 @@ class HookedTransformerConfig:
attention head separately, with a hook. Defaults to false to save memory
use_attn_scale (bool): whether to scale the attention weights by
1/sqrt(d_head)
ungroup_grouped_query_attention (bool): whether to ungroup key and value heads, for models that use
grouped query attention.
attn_scale (float): The amount to divide attention scores by (if applicable). Defaults to
sqrt(d_head)
model_name (str): the name of the model, used to load
@@ -199,6 +201,7 @@ class HookedTransformerConfig:
use_hook_mlp_in: bool = False
use_attn_in: bool = False
use_local_attn: bool = False
ungroup_grouped_query_attention: bool = False
original_architecture: Optional[str] = None
from_checkpoint: bool = False
checkpoint_index: Optional[int] = None
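For reference, this flag should change nothing unless n_key_value_heads is also set, since it is only read on the grouped query attention code paths shown later in this diff. A minimal illustrative config, with values chosen to match the tests above:

gqa_cfg = HookedTransformerConfig(
    d_model=512, d_head=32, n_heads=16, n_ctx=128, n_layers=1, act_fn="silu",
    n_key_value_heads=4,                    # model uses grouped query attention
    ungroup_grouped_query_attention=True,   # expose K/V per query head in hooks
)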
15 changes: 11 additions & 4 deletions transformer_lens/components/grouped_query_attention.py
@@ -126,11 +126,16 @@ def calculate_qkv_matrices(
q = self.hook_q(
attn_fn(query_input, self.W_Q, self.b_Q)
) # [batch, pos, head_index, d_head]

k = self.hook_k(
-            attn_fn(key_input, self._W_K, self._b_K)
+            attn_fn(key_input, self.W_K, self.b_K)
+            if self.cfg.ungroup_grouped_query_attention
+            else attn_fn(key_input, self._W_K, self._b_K)
) # [batch, pos, head_index, d_head]
v = self.hook_v(
-            attn_fn(value_input, self._W_V, self._b_V)
+            attn_fn(value_input, self.W_V, self.b_V)
+            if self.cfg.ungroup_grouped_query_attention
+            else attn_fn(value_input, self._W_V, self._b_V)
) # [batch, pos, head_index, d_head]
return q, k, v

@@ -149,7 +154,8 @@ def calculate_attention_scores(
Returns:
Float[torch.Tensor, "batch head_index query_pos key_pos"]: The attention scores.
"""
-        k = torch.repeat_interleave(k, dim=2, repeats=self.repeat_kv_heads)
+        if not self.cfg.ungroup_grouped_query_attention:
+            k = torch.repeat_interleave(k, dim=2, repeats=self.repeat_kv_heads)
return super().calculate_attention_scores(q, k)

def calculate_z_scores(
@@ -167,5 +173,6 @@ def calculate_z_scores(
Returns:
Float[torch.Tensor, "batch head_index query_pos key_pos"]: The z scores.
"""
-        v = torch.repeat_interleave(v, dim=2, repeats=self.repeat_kv_heads)
+        if not self.cfg.ungroup_grouped_query_attention:
+            v = torch.repeat_interleave(v, dim=2, repeats=self.repeat_kv_heads)
return super().calculate_z_scores(v, pattern)
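
To make the change concrete: with the flag off, calculate_attention_scores and calculate_z_scores expand the grouped key/value tensors on the fly with torch.repeat_interleave, so hook_k and hook_v still see the grouped shape; with the flag on, the expansion has already happened in calculate_qkv_matrices (through the W_K/W_V attributes, which expose the key/value weights repeated out to n_heads), so the repeat here is skipped. A standalone sketch of what that expansion does, using toy sizes rather than TransformerLens code:

import torch

batch, pos, d_head = 1, 5, 32
n_heads, n_key_value_heads = 16, 4
repeat_kv_heads = n_heads // n_key_value_heads  # 4 query heads share each K/V head

k_grouped = torch.rand(batch, pos, n_key_value_heads, d_head)
k_ungrouped = torch.repeat_interleave(k_grouped, dim=2, repeats=repeat_kv_heads)

print(k_ungrouped.shape)  # torch.Size([1, 5, 16, 32])
# Heads 0-3 are copies of grouped head 0, heads 4-7 of grouped head 1, and so on.
assert torch.equal(k_ungrouped[:, :, 0], k_grouped[:, :, 0])
assert torch.equal(k_ungrouped[:, :, 3], k_grouped[:, :, 0])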
1 change: 1 addition & 0 deletions transformer_lens/components/transformer_block.py
@@ -136,6 +136,7 @@ def forward(
n_kv_heads = (
self.cfg.n_key_value_heads
if self.cfg.n_key_value_heads is not None
and not self.cfg.ungroup_grouped_query_attention
else self.cfg.n_heads
)
query_input = self.hook_q_input(
