This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

chore: refactor SDPAAttention update_mask method
theissenhelen committed Dec 18, 2024
1 parent 972d3c5 commit e89fd2e
Showing 1 changed file with 13 additions and 17 deletions.
30 changes: 13 additions & 17 deletions src/anemoi/models/layers/attention.py
@@ -111,7 +111,7 @@ def set_attention_function(self):
         attn_funcs = {
             "flash_attention": FlashAttentionWrapper,
             "flex_attention": FlexAttentionWrapper,
-            "scaled_dot_product_attention": TorchAttentionWrapper,
+            "scaled_dot_product_attention": SDPAAttentionWrapper,
         }
         assert (
             self.attention_implementation in attn_funcs
@@ -168,7 +168,7 @@ def forward(
         return out


-class TorchAttentionWrapper(nn.Module):
+class SDPAAttentionWrapper(nn.Module):
     """Wrapper for Pytorch scaled dot product attention"""

     def __init__(self):
@@ -181,18 +181,13 @@ def __init__(self):
         self.window_size = None

     def update_mask(self, seq_len, window_size: int, device: str):
-        update_mask = (
-            self.mask is None or self.window_size != window_size or tuple(self.mask.shape) != (seq_len, seq_len)
-        )
-        if update_mask:
-            self.window_size = window_size
-            self.mask = (
-                torch.abs(
-                    torch.arange(seq_len, device=device).unsqueeze(0)
-                    - torch.arange(seq_len, device=device).unsqueeze(1)
-                )
-                <= window_size
-            )
+
+        self.mask = (
+            torch.abs(
+                torch.arange(seq_len, device=device).unsqueeze(0) - torch.arange(seq_len, device=device).unsqueeze(1)
+            )
+            <= window_size
+        )

     def forward(
         self,
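The refactored update_mask now builds the sliding-window mask unconditionally: positions i and j may attend to one another iff |i - j| <= window_size, which yields a boolean band around the diagonal. A minimal standalone sketch of that computation (the seq_len and window_size values are chosen purely for illustration):

    import torch

    seq_len, window_size = 5, 1
    idx = torch.arange(seq_len)
    # |i - j| <= window_size: True inside the band, False outside
    mask = torch.abs(idx.unsqueeze(0) - idx.unsqueeze(1)) <= window_size
    # tensor([[ True,  True, False, False, False],
    #         [ True,  True,  True, False, False],
    #         [False,  True,  True,  True, False],
    #         [False, False,  True,  True,  True],
    #         [False, False, False,  True,  True]])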
@@ -214,10 +209,11 @@ def forward(
             NotImplementedError(
                 "Alibi slopes not supported by Pytorchs SDPA. please switch to flash attention or disable alibi slopes."
             )
-        if window_size is not None:
-            self.update_mask(query.shape[-2], window_size=window_size, device=query.device)
-        else:
-            self.mask = None
+
+        sequence_len = query.shape[-2]
+
+        if window_size is not None and (self.mask is None or tuple(self.mask.shape) != (sequence_len, sequence_len)):
+            self.update_mask(sequence_len, window_size=window_size, device=query.device)

         with torch.nn.attention.sdpa_kernel(backends=[torch.nn.attention.SDPBackend.MATH]):
             out = self.attention(
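The companion change moves cache invalidation out of update_mask and into forward: the mask is rebuilt only when windowed attention is requested and the cached mask is missing or its shape no longer matches the current sequence length (self.window_size is no longer refreshed, so the shape check alone guards the cache). The call then runs under PyTorch's scaled dot product attention with the MATH backend pinned, which accepts an arbitrary boolean attn_mask. A hedged end-to-end sketch of this path, with tensor shapes assumed for illustration:

    import torch
    from torch.nn.attention import SDPBackend, sdpa_kernel
    from torch.nn.functional import scaled_dot_product_attention

    batch, heads, seq_len, head_dim, window_size = 2, 4, 16, 8, 2
    q = torch.randn(batch, heads, seq_len, head_dim)
    k = torch.randn(batch, heads, seq_len, head_dim)
    v = torch.randn(batch, heads, seq_len, head_dim)

    mask = None  # stands in for the wrapper's cached self.mask
    # Rebuild only on a cache miss, mirroring the condition added to forward
    if mask is None or tuple(mask.shape) != (seq_len, seq_len):
        idx = torch.arange(seq_len, device=q.device)
        mask = torch.abs(idx.unsqueeze(0) - idx.unsqueeze(1)) <= window_size

    # The (seq_len, seq_len) boolean mask broadcasts over batch and head dims
    with sdpa_kernel(backends=[SDPBackend.MATH]):
        out = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
    # out has shape (batch, heads, seq_len, head_dim)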
