From a80f53854e15cf461d8b9ab4b5fc6fad64b6d7db Mon Sep 17 00:00:00 2001
From: Frank Odom
Date: Thu, 9 May 2024 13:55:38 -0500
Subject: [PATCH] Fix discrepancy between grouped-SDPA and SDPA

---
 grouped_query_attention_pytorch/attention.py | 34 +++++++-------------
 tests/test_attention.py                      | 18 ++++++++---
 2 files changed, 25 insertions(+), 27 deletions(-)

diff --git a/grouped_query_attention_pytorch/attention.py b/grouped_query_attention_pytorch/attention.py
index 4d319d0..ecd15cc 100644
--- a/grouped_query_attention_pytorch/attention.py
+++ b/grouped_query_attention_pytorch/attention.py
@@ -85,25 +85,13 @@ def scaled_dot_product_gqa(
     query = query / scale

     num_head_groups = hq // hk
-    if num_head_groups > 1 or force_grouped:
-        # Separate the query heads into 'num_head_groups' chunks, and fold the group
-        # dimension into the batch dimension. This allows us to compute the attention
-        # for each head in parallel, then sum over all of the groups at the end.
-        query = rearrange(query, "b (h g) n d -> b g h n d", g=num_head_groups)
-        similarity = einsum(query, key, "b g h n d, b h s d -> b h n s")
-    else:
-        # If the number of query/key heads is equal, we can skip grouping the queries,
-        # and just use the standard sdot product attention.
-        similarity = einsum(query, key, "b h n d, b h s d -> b h n s")
+    query = rearrange(query, "b (h g) n d -> b g h n d", g=num_head_groups)
+    similarity = einsum(query, key, "b g h n d, b h s d -> b g h n s")

     if is_causal:
         # Mask out the upper triangular portion of the attention matrix. This prevents
         # the model from attending to tokens in the future.
-        mask = torch.ones(
-            (bq, nq, nk),
-            device=query.device,
-            dtype=torch.bool,
-        ).tril_()
+        mask = torch.ones((bq, nq, nk), device=query.device, dtype=torch.bool).tril_()

     if mask is not None:
         # Expand mask to match the shape of the attention matrix.
@@ -115,28 +103,28 @@ def scaled_dot_product_gqa(
         # sequence dimension for each attention head (though I don't have a particular
         # use case in mind for that).
         if mask.ndim == 2:
-            mask = rearrange(mask, "b s -> b () () s")
+            mask = rearrange(mask, "b s -> b () () () s")
         elif mask.ndim == 3:
-            mask = rearrange(mask, "b n s -> b () n s")
+            mask = rearrange(mask, "b n s -> b () () n s")

         # Mask similarity values by setting them to negative infinity. This guarantees
         # that they will not contribute to the softmax computation below.
         similarity.masked_fill_(~mask, torch.finfo(similarity.dtype).min)

-    attention = F.softmax(similarity / scale, dim=-1)
+    attention = F.softmax(similarity, dim=-1)
     if dropout > 0.0:
         attention = F.dropout(attention, p=dropout)

     # Apply attention matrix to the value Tensor.
-    out = einsum(attention, value, "b h n s, b h s d -> b h n d")
+    out = einsum(attention, value, "b g h n s, b h s d -> b g h n d")
     # Move head dimension back to axis 2
-    out = rearrange(out, "b h n d -> b n h d")
+    out = rearrange(out, "b g h n d -> b n (h g) d")

     attn_weights: Optional[Tensor] = None
     if need_weights:
         # Move the sequence dimensions back to positions 1, 2. Move the head dimension
         # to position 3. This more closely matches the return shape of the attention
         # output: (b, n, h, d).
-        attn_weights = rearrange(attention, "b h n s -> b n s h")
+        attn_weights = rearrange(attention, "b g h n s -> b n s (h g)")
         if average_attn_weights:
             attn_weights = attn_weights.mean(dim=1)
@@ -222,13 +210,13 @@ def __init__(
         self.norm: Optional[nn.LayerNorm] = None
         if layer_norm:
             self.norm = nn.LayerNorm(
-                kv_embed_dim, eps=layer_norm_eps, device=device, dtype=dtype
+                embed_dim, eps=layer_norm_eps, device=device, dtype=dtype
             )
         # Grouped attention output will have the same embedding dimension as the
         # key/value Tensors. So the output projection layer needs to accept the
         # same dimension (kv_embed_dim).
         self.out_proj = nn.Linear(
-            kv_embed_dim, embed_dim, bias=bias, device=device, dtype=dtype
+            embed_dim, embed_dim, bias=bias, device=device, dtype=dtype
         )

         self._reset_parameters()
diff --git a/tests/test_attention.py b/tests/test_attention.py
index e6ca0cb..45144cf 100644
--- a/tests/test_attention.py
+++ b/tests/test_attention.py
@@ -1,17 +1,19 @@
 import pytest
 import torch
+import torch.nn.functional as F

 from grouped_query_attention_pytorch.attention import (
     MultiheadGQA,
     scaled_dot_product_gqa,
 )

+torch.backends.cudnn.deterministic = True
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
+DTYPE = torch.float64
 SEQ_LEN = 16


-@pytest.mark.parametrize("embed_dim", [32])
+@pytest.mark.parametrize("embed_dim", [64])
 @pytest.mark.parametrize("num_heads", [4, 8])
 @pytest.mark.parametrize("kv_heads", [4, 8])
 @pytest.mark.parametrize("is_causal", [True, False])
@@ -34,12 +36,20 @@ def test_grouped_scaled_dot_product_attention(
     )
     assert out.size(0) == 1
     assert out.size(1) == SEQ_LEN
-    assert out.size(2) == kv_heads
+    assert out.size(2) == num_heads
     assert out.size(3) == embed_dim
     assert attn_weights.size(0) == 1
     assert attn_weights.size(1) == SEQ_LEN
     assert attn_weights.size(2) == SEQ_LEN
-    assert attn_weights.size(3) == kv_heads
+    assert attn_weights.size(3) == num_heads
+
+    # Test that grouped SDPA is equivalent to SDPA if we duplicate the KV heads.
+    kv = kv.repeat_interleave(num_heads // kv_heads, dim=2)
+    kv = kv.permute(0, 2, 1, 3)
+    x = x.permute(0, 2, 1, 3)
+    out_vanilla = F.scaled_dot_product_attention(x, kv, kv, is_causal=is_causal)
+    out_vanilla = out_vanilla.permute(0, 2, 1, 3)
+    torch.testing.assert_close(out, out_vanilla)


 @torch.no_grad()
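For reference, the following is a minimal standalone sketch (not part of the patch) of the equivalence the new test asserts: the grouped "b g h n d, b h s d -> b g h n s" einsum introduced above should match torch's built-in scaled_dot_product_attention once each key/value head is repeated to cover its group of query heads. The helper name grouped_sdpa_sketch and the toy shapes are illustrative assumptions; tensors use the (batch, heads, seq, dim) layout, and masking/dropout are omitted.

import torch
import torch.nn.functional as F
from einops import einsum, rearrange


def grouped_sdpa_sketch(query, key, value):
    # query: (b, hq, n, d); key/value: (b, hk, s, d), with hq a multiple of hk.
    scale = query.size(-1) ** 0.5
    query = query / scale
    num_head_groups = query.size(1) // key.size(1)
    # Split the query heads into (kv_head, group) pairs, as in the patched code.
    query = rearrange(query, "b (h g) n d -> b g h n d", g=num_head_groups)
    similarity = einsum(query, key, "b g h n d, b h s d -> b g h n s")
    attention = F.softmax(similarity, dim=-1)
    out = einsum(attention, value, "b g h n s, b h s d -> b g h n d")
    # Merge the group axis back into the head axis, keeping the (b, h, n, d) layout.
    return rearrange(out, "b g h n d -> b (h g) n d")


b, hq, hk, n, d = 1, 8, 4, 16, 64
q = torch.randn(b, hq, n, d, dtype=torch.float64)
k = torch.randn(b, hk, n, d, dtype=torch.float64)
v = torch.randn(b, hk, n, d, dtype=torch.float64)

# Duplicate each KV head so vanilla SDPA sees hq heads, then compare outputs.
k_rep = k.repeat_interleave(hq // hk, dim=1)
v_rep = v.repeat_interleave(hq // hk, dim=1)
torch.testing.assert_close(
    grouped_sdpa_sketch(q, k, v),
    F.scaled_dot_product_attention(q, k_rep, v_rep),
)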