
Commit

Merge pull request #6 from fkodom/bug-fix/kv-groups
Bug Fix: Discrepancy between grouped-SDPA and SDPA
fkodom authored May 9, 2024
2 parents 0b5ab63 + a80f538 commit 7a69d44
Showing 2 changed files with 25 additions and 27 deletions.
34 changes: 11 additions & 23 deletions grouped_query_attention_pytorch/attention.py
@@ -85,25 +85,13 @@ def scaled_dot_product_gqa(
     query = query / scale
 
     num_head_groups = hq // hk
-    if num_head_groups > 1 or force_grouped:
-        # Separate the query heads into 'num_head_groups' chunks, and fold the group
-        # dimension into the batch dimension. This allows us to compute the attention
-        # for each head in parallel, then sum over all of the groups at the end.
-        query = rearrange(query, "b (h g) n d -> b g h n d", g=num_head_groups)
-        similarity = einsum(query, key, "b g h n d, b h s d -> b h n s")
-    else:
-        # If the number of query/key heads is equal, we can skip grouping the queries,
-        # and just use the standard sdot product attention.
-        similarity = einsum(query, key, "b h n d, b h s d -> b h n s")
+    query = rearrange(query, "b (h g) n d -> b g h n d", g=num_head_groups)
+    similarity = einsum(query, key, "b g h n d, b h s d -> b g h n s")
 
     if is_causal:
         # Mask out the upper triangular portion of the attention matrix. This prevents
         # the model from attending to tokens in the future.
-        mask = torch.ones(
-            (bq, nq, nk),
-            device=query.device,
-            dtype=torch.bool,
-        ).tril_()
+        mask = torch.ones((bq, nq, nk), device=query.device, dtype=torch.bool).tril_()
 
     if mask is not None:
         # Expand mask to match the shape of the attention matrix.
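Note on the hunk above: the old grouped branch reduced over the group axis g inside its einsum (the "-> b h n s" output pattern drops g), while the ungrouped branch skipped grouping entirely; the new code always folds the query heads into num_head_groups groups and keeps g as an explicit axis. A minimal shape sketch of the new computation, with assumed example sizes (illustration only, not part of the diff):

# Shape sketch of the new grouped similarity (assumed example sizes).
import torch
from einops import einsum, rearrange

b, hq, hk, n, d = 2, 8, 2, 16, 32          # batch, query heads, kv heads, seq, head dim
query = torch.randn(b, hq, n, d)
key = torch.randn(b, hk, n, d)

num_head_groups = hq // hk                 # 4 query heads share each kv head
grouped = rearrange(query, "b (h g) n d -> b g h n d", g=num_head_groups)
print(grouped.shape)                       # torch.Size([2, 4, 2, 16, 32])

similarity = einsum(grouped, key, "b g h n d, b h s d -> b g h n s")
print(similarity.shape)                    # torch.Size([2, 4, 2, 16, 16])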
@@ -115,28 +103,28 @@ def scaled_dot_product_gqa(
         # sequence dimension for each attention head (though I don't have a particular
         # use case in mind for that).
         if mask.ndim == 2:
-            mask = rearrange(mask, "b s -> b () () s")
+            mask = rearrange(mask, "b s -> b () () () s")
         elif mask.ndim == 3:
-            mask = rearrange(mask, "b n s -> b () n s")
+            mask = rearrange(mask, "b n s -> b () () n s")
         # Mask similarity values by setting them to negative infinity. This guarantees
         # that they will not contribute to the softmax computation below.
         similarity.masked_fill_(~mask, torch.finfo(similarity.dtype).min)
 
-    attention = F.softmax(similarity / scale, dim=-1)
+    attention = F.softmax(similarity, dim=-1)
     if dropout > 0.0:
         attention = F.dropout(attention, p=dropout)
 
     # Apply attention matrix to the value Tensor.
-    out = einsum(attention, value, "b h n s, b h s d -> b h n d")
+    out = einsum(attention, value, "b g h n s, b h s d -> b g h n d")
     # Move head dimension back to axis 2
-    out = rearrange(out, "b h n d -> b n h d")
+    out = rearrange(out, "b g h n d -> b n (h g) d")
 
     attn_weights: Optional[Tensor] = None
     if need_weights:
         # Move the sequence dimensions back to positions 1, 2. Move the head dimension
         # to position 3. This more closely matches the return shape of the attention
         # output: (b, n, h, d).
-        attn_weights = rearrange(attention, "b h n s -> b n s h")
+        attn_weights = rearrange(attention, "b g h n s -> b n s (h g)")
         if average_attn_weights:
             attn_weights = attn_weights.mean(dim=1)

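Taken together, this hunk keeps the group axis all the way through: a (batch, seq) mask gains three singleton axes so it broadcasts over (b, g, h, n, s), softmax no longer divides by scale a second time (the query was already scaled above), and the groups are merged back into the head dimension at the end. A self-contained sketch with illustrative sizes (assumptions, not repository code):

# Sketch of the 5-D mask broadcast, softmax, and head merge (illustrative sizes).
import torch
from einops import einsum, rearrange

b, g, h, n, s, d = 2, 4, 2, 16, 16, 32
similarity = torch.randn(b, g, h, n, s)    # grouped similarity, as computed above
value = torch.randn(b, h, s, d)            # (batch, kv_heads, seq, head_dim)

# An all-True (batch, seq) padding mask, reshaped so it broadcasts over g, h, n.
mask = torch.ones(b, s, dtype=torch.bool)
mask = rearrange(mask, "b s -> b () () () s")
similarity = similarity.masked_fill(~mask, torch.finfo(similarity.dtype).min)

attention = similarity.softmax(dim=-1)     # (b, g, h, n, s)
out = einsum(attention, value, "b g h n s, b h s d -> b g h n d")
out = rearrange(out, "b g h n d -> b n (h g) d")
print(out.shape)                           # torch.Size([2, 16, 8, 32]): all h * g heads restored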
@@ -222,13 +210,13 @@ def __init__(
         self.norm: Optional[nn.LayerNorm] = None
         if layer_norm:
             self.norm = nn.LayerNorm(
-                kv_embed_dim, eps=layer_norm_eps, device=device, dtype=dtype
+                embed_dim, eps=layer_norm_eps, device=device, dtype=dtype
             )
         # Grouped attention output will have the same embedding dimension as the
         # key/value Tensors. So the output projection layer needs to accept the
         # same dimension (kv_embed_dim).
         self.out_proj = nn.Linear(
-            kv_embed_dim, embed_dim, bias=bias, device=device, dtype=dtype
+            embed_dim, embed_dim, bias=bias, device=device, dtype=dtype
         )
 
         self._reset_parameters()
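On the dimension change above: since the grouped output is now merged back into all h * g query heads, the attention output carries embed_dim features per token rather than the narrower key/value width, so the layer norm and output projection operate on embed_dim. A small dimension check with assumed sizes (hypothetical values, not the module's actual constructor):

# Hypothetical dimension check: the attention output is num_heads * head_dim wide,
# which equals embed_dim, so embed_dim-sized norm/projection layers now match it.
import torch
import torch.nn as nn

embed_dim, num_heads, kv_heads = 64, 8, 2  # assumed example sizes
head_dim = embed_dim // num_heads
kv_embed_dim = kv_heads * head_dim         # assumed definition of the pre-fix width (16 here)

attn_out = torch.randn(2, 16, num_heads * head_dim)   # (batch, seq, embed_dim)
norm = nn.LayerNorm(embed_dim)
out_proj = nn.Linear(embed_dim, embed_dim)
print(out_proj(norm(attn_out)).shape)                  # torch.Size([2, 16, 64])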
18 changes: 14 additions & 4 deletions tests/test_attention.py
@@ -1,17 +1,19 @@
 import pytest
 import torch
+import torch.nn.functional as F
 
 from grouped_query_attention_pytorch.attention import (
     MultiheadGQA,
     scaled_dot_product_gqa,
 )
 
+torch.backends.cudnn.deterministic = True
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
+DTYPE = torch.float64
 SEQ_LEN = 16
 
 
-@pytest.mark.parametrize("embed_dim", [32])
+@pytest.mark.parametrize("embed_dim", [64])
 @pytest.mark.parametrize("num_heads", [4, 8])
 @pytest.mark.parametrize("kv_heads", [4, 8])
 @pytest.mark.parametrize("is_causal", [True, False])
@@ -34,12 +36,20 @@ def test_grouped_scaled_dot_product_attention(
     )
     assert out.size(0) == 1
     assert out.size(1) == SEQ_LEN
-    assert out.size(2) == kv_heads
+    assert out.size(2) == num_heads
     assert out.size(3) == embed_dim
     assert attn_weights.size(0) == 1
     assert attn_weights.size(1) == SEQ_LEN
     assert attn_weights.size(2) == SEQ_LEN
-    assert attn_weights.size(3) == kv_heads
+    assert attn_weights.size(3) == num_heads
 
+    # Test that grouped SDPA is equivalent to SDPA if we duplicate the KV heads.
+    kv = kv.repeat_interleave(num_heads // kv_heads, dim=2)
+    kv = kv.permute(0, 2, 1, 3)
+    x = x.permute(0, 2, 1, 3)
+    out_vanilla = F.scaled_dot_product_attention(x, kv, kv, is_causal=is_causal)
+    out_vanilla = out_vanilla.permute(0, 2, 1, 3)
+    torch.testing.assert_close(out, out_vanilla)
+
 
 @torch.no_grad()
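The new assertions check that grouped SDPA matches PyTorch's scaled_dot_product_attention once each KV head is duplicated across its group of query heads. Below is a standalone sketch of that property, with assumed sizes and a (batch, heads, seq, head_dim) layout; it re-implements the patched einsum patterns inline rather than calling the repository's function:

# Standalone equivalence sketch: grouped attention vs. vanilla SDPA with duplicated KV heads.
import torch
import torch.nn.functional as F
from einops import einsum, rearrange

torch.manual_seed(0)
b, hq, hk, n, d = 1, 8, 4, 16, 64
q = torch.randn(b, hq, n, d, dtype=torch.float64)
k = torch.randn(b, hk, n, d, dtype=torch.float64)
v = torch.randn(b, hk, n, d, dtype=torch.float64)

# Grouped path, mirroring the patched rearrange/einsum patterns.
g = hq // hk
qg = rearrange(q / d**0.5, "b (h g) n d -> b g h n d", g=g)
sim = einsum(qg, k, "b g h n d, b h s d -> b g h n s")
attn = sim.softmax(dim=-1)
out = einsum(attn, v, "b g h n s, b h s d -> b g h n d")
out = rearrange(out, "b g h n d -> b (h g) n d")       # back to (batch, heads, seq, head_dim)

# Vanilla path: duplicate each KV head across its group, then use torch's SDPA.
k_full = k.repeat_interleave(g, dim=1)
v_full = v.repeat_interleave(g, dim=1)
out_vanilla = F.scaled_dot_product_attention(q, k_full, v_full)

torch.testing.assert_close(out, out_vanilla)
print("grouped SDPA matches vanilla SDPA")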
