From e89fd2e3ae3318c31a96b3ca93fcc4089f782948 Mon Sep 17 00:00:00 2001
From: theissenhelen
Date: Wed, 18 Dec 2024 15:25:52 +0000
Subject: [PATCH] chore: refactor SDPAAttention update_mask method

---
 src/anemoi/models/layers/attention.py | 30 ++++++++++++---------------
 1 file changed, 13 insertions(+), 17 deletions(-)

diff --git a/src/anemoi/models/layers/attention.py b/src/anemoi/models/layers/attention.py
index 80823849..9ee8e104 100644
--- a/src/anemoi/models/layers/attention.py
+++ b/src/anemoi/models/layers/attention.py
@@ -111,7 +111,7 @@ def set_attention_function(self):
         attn_funcs = {
             "flash_attention": FlashAttentionWrapper,
             "flex_attention": FlexAttentionWrapper,
-            "scaled_dot_product_attention": TorchAttentionWrapper,
+            "scaled_dot_product_attention": SDPAAttentionWrapper,
         }
         assert (
             self.attention_implementation in attn_funcs
@@ -168,7 +168,7 @@ def forward(
         return out


-class TorchAttentionWrapper(nn.Module):
+class SDPAAttentionWrapper(nn.Module):
     """Wrapper for Pytorch scaled dot product attention"""

     def __init__(self):
@@ -181,18 +181,13 @@ def __init__(self):
         self.window_size = None

     def update_mask(self, seq_len, window_size: int, device: str):
-        update_mask = (
-            self.mask is None or self.window_size != window_size or tuple(self.mask.shape) != (seq_len, seq_len)
-        )
-        if update_mask:
-            self.window_size = window_size
-            self.mask = (
-                torch.abs(
-                    torch.arange(seq_len, device=device).unsqueeze(0)
-                    - torch.arange(seq_len, device=device).unsqueeze(1)
-                )
-                <= window_size
+
+        self.mask = (
+            torch.abs(
+                torch.arange(seq_len, device=device).unsqueeze(0) - torch.arange(seq_len, device=device).unsqueeze(1)
             )
+            <= window_size
+        )

     def forward(
         self,
@@ -214,10 +209,11 @@ def forward(
             NotImplementedError(
                 "Alibi slopes not supported by Pytorchs SDPA. please switch to flash attention or disable alibi slopes."
             )
-        if window_size is not None:
-            self.update_mask(query.shape[-2], window_size=window_size, device=query.device)
-        else:
-            self.mask = None
+
+        sequence_len = query.shape[-2]
+
+        if window_size is not None and (self.mask is None or tuple(self.mask.shape) != (sequence_len, sequence_len)):
+            self.update_mask(sequence_len, window_size=window_size, device=query.device)

         with torch.nn.attention.sdpa_kernel(backends=[torch.nn.attention.SDPBackend.MATH]):
             out = self.attention(
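
Note (not part of the patch): a minimal standalone sketch of the banded mask that the
refactored update_mask builds, and of how a boolean attn_mask feeds into PyTorch's
scaled_dot_product_attention. The helper name sliding_window_mask and the tensor shapes
are illustrative assumptions, not code from this repository.

    # Sketch of the sliding-window attention mask used in the patch above.
    # True where |i - j| <= window_size, i.e. each query position may attend
    # only to key positions within the window; True entries take part in
    # attention, False entries are masked out.
    import torch
    import torch.nn.functional as F

    def sliding_window_mask(seq_len: int, window_size: int, device=None) -> torch.Tensor:
        # Hypothetical helper mirroring the expression in update_mask.
        idx = torch.arange(seq_len, device=device)
        return (idx.unsqueeze(0) - idx.unsqueeze(1)).abs() <= window_size

    # Made-up shapes: (batch, heads, seq_len, head_dim)
    q = k = v = torch.randn(1, 8, 16, 64)
    mask = sliding_window_mask(seq_len=16, window_size=2)

    # A (seq_len, seq_len) boolean mask broadcasts over batch and heads.
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    print(out.shape)  # torch.Size([1, 8, 16, 64])

Moving the rebuild condition into forward means the (seq_len, seq_len) mask tensor is
only reconstructed when it is missing or its shape no longer matches the incoming
sequence length, rather than being recomputed on every call.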