This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

fix: error instead of SystemExit
theissenhelen committed Dec 18, 2024
1 parent 7ec8142 commit 972d3c5
Showing 1 changed file with 20 additions and 24 deletions.
src/anemoi/models/layers/attention.py (20 additions, 24 deletions)
@@ -71,8 +71,6 @@ def __init__(
             window_size, by default None
         dropout_p : float, optional
             dropout probability, by default 0.0
-        softcap : float, optional
-            Anything > 0 activates softcapping attention, by default None
         attention_implementation: str, optional
             A predefined string which selects which underlying attention
             implementation, by default "flash_attention"
Expand Down Expand Up @@ -101,7 +99,7 @@ def __init__(

if self.use_alibi_slopes:
self.alibi_slopes = get_alibi_slopes(num_heads)
assert self.alibi_slopes.shape[0] == num_heads
assert self.alibi_slopes.shape[0] == num_heads, "Error: Number of alibi_slopes must match number of heads"
else:
self.alibi_slopes = None

@@ -115,16 +113,14 @@ def set_attention_function(self):
             "flex_attention": FlexAttentionWrapper,
             "scaled_dot_product_attention": TorchAttentionWrapper,
         }
-        if self.attention_implementation in attn_funcs:
-            LOGGER.info(f"attention.py: using {self.attention_implementation}")
-            # initalise the attn func here
-            self.attention = attn_funcs[self.attention_implementation]()
-        else:
-            # Requested attn implementation is not supported
-            raise SystemExit(
-                f"attention.py: Error! {self.attention_implementation} not supported. \
-                please change model.processor.attention_implementation in the config to one of: {attn_funcs.keys()}"
-            )
+        assert (
+            self.attention_implementation in attn_funcs
+        ), f"{self.attention_implementation} not supported. \
+            Please change model.processor.attention_implementation in the config to one of: {attn_funcs.keys()}"
+        LOGGER.info(f"Using {self.attention_implementation}")
+
+        # initalise the attn func here
+        self.attention = attn_funcs[self.attention_implementation]()

     def forward(
         self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None
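The core of the commit is this hunk: the if/else that raised SystemExit for an unknown implementation becomes an assert over the attn_funcs lookup table. SystemExit derives from BaseException, so a plain `except Exception` or a test harness does not catch it, whereas an AssertionError is an ordinary, catchable exception. A minimal, self-contained sketch of the pattern outside Anemoi (BACKENDS and pick_backend are illustrative names, not part of attention.py):

# Sketch: dict-based backend selection that fails with a catchable error.
BACKENDS = {
    "flash_attention": "FlashAttentionWrapper",
    "flex_attention": "FlexAttentionWrapper",
    "scaled_dot_product_attention": "TorchAttentionWrapper",
}

def pick_backend(name: str) -> str:
    # AssertionError is an ordinary Exception, so callers and pytest can report it;
    # SystemExit would instead tear down the interpreter unless caught explicitly.
    assert name in BACKENDS, f"{name} not supported, choose one of: {list(BACKENDS)}"
    return BACKENDS[name]

try:
    pick_backend("sliding_window")
except Exception as err:  # catches the AssertionError; would not catch SystemExit
    print(f"configuration error: {err}")

Note that assert statements are skipped when Python runs with -O, which is why some codebases prefer an explicit raise of ValueError for configuration validation.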
@@ -173,7 +169,7 @@ def forward(


 class TorchAttentionWrapper(nn.Module):
-    """Wrapper for Pytorch dot product attention"""
+    """Wrapper for Pytorch scaled dot product attention"""

     def __init__(self):
         super().__init__()
@@ -211,12 +207,12 @@ def forward(
         alibi_slopes=None,
     ):
         if softcap is not None:
-            SystemError(
-                "Error. Softcap not supported by Pytorchs SDPA. please switch to flash attention or disable softcap."
+            NotImplementedError(
+                "Softcap not supported by Pytorchs SDPA. please switch to flash attention or disable softcap."
             )
         if alibi_slopes is not None:
-            SystemError(
-                "Error. Alibi slopes not supported by Pytorchs SDPA. please switch to flash attention or disable alibi slopes."
+            NotImplementedError(
+                "Alibi slopes not supported by Pytorchs SDPA. please switch to flash attention or disable alibi slopes."
             )
         if window_size is not None:
             self.update_mask(query.shape[-2], window_size=window_size, device=query.device)
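Worth noting for both this wrapper and the FlexAttentionWrapper checks further down: the diff swaps SystemError for NotImplementedError, but in the old and new code alike the exception object is only constructed, not raised, so execution continues past the guard. A hedged sketch of a guard that actually aborts the call (check_sdpa_kwargs is an illustrative helper, not part of attention.py):

from typing import Optional

def check_sdpa_kwargs(softcap: Optional[float] = None, alibi_slopes=None) -> None:
    # Illustrative helper: reject features that
    # torch.nn.functional.scaled_dot_product_attention does not provide.
    # The explicit `raise` is what stops the call; constructing the exception alone does nothing.
    if softcap is not None:
        raise NotImplementedError(
            "Softcap not supported by PyTorch SDPA; switch to flash attention or disable softcap."
        )
    if alibi_slopes is not None:
        raise NotImplementedError(
            "Alibi slopes not supported by PyTorch SDPA; switch to flash attention or disable alibi slopes."
        )

check_sdpa_kwargs()  # passes silently when neither feature is requested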
@@ -243,7 +239,7 @@ def __init__(self):
         super().__init__()

         if version.parse(torch.__version__) < version.parse("2.5.0"):
-            raise SystemExit("Error: torch version is too low. Update to 2.5.0 or higher to use Flex Attention.")
+            raise RuntimeError("Error: torch version is too low. Update to 2.5.0 or higher to use Flex Attention.")

         # we compile flex attn once at the first iteration
         # This is bc we need to know the seq len to compute the mask mod for sliding window
@@ -263,13 +259,13 @@ def forward(
     ):

         if alibi_slopes is not None:
-            SystemExit("Error. Alibi_slopes not yet implemented in FlexAttn in Anemoi.")
+            NotImplementedError("Error. Alibi_slopes not yet implemented in FlexAttn in Anemoi.")
         if softcap is not None:
-            SystemExit("Error. Softcap not yet implemented in FlexAttn in Anemoi.")
+            NotImplementedError("Error. Softcap not yet implemented in FlexAttn in Anemoi.")
         if dropout_p != 0.0:
-            SystemExit("Error. Dropout not yet implemented in FlexAttn in Anemoi.")
+            NotImplementedError("Error. Dropout not yet implemented in FlexAttn in Anemoi.")
         if causal:
-            SystemExit("Error. Causal not yet implemented in FlexAttn in Anemoi.")
+            NotImplementedError("Error. Causal not yet implemented in FlexAttn in Anemoi.")

         # This assumes seq_len never changes
         # across iterations and stages
@@ -312,7 +308,7 @@ def __init__(self):
         import flash_attn

         if version.parse(flash_attn.__version__) < version.parse("2.6.0"):
-            raise SystemExit("Error: Flash-attn version is too low. Update to 2.6.0 or higher.")
+            raise RuntimeError("Error: Flash-attn version is too low. Update to 2.6.0 or higher.")
         else:
             self.attention = flash_attn.flash_attn_func
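The flash-attn hunk follows the same pattern as the Flex Attention one: a minimum-version gate that now raises a catchable RuntimeError instead of SystemExit. A standalone sketch of such a gate, assuming torch and the packaging library are installed (require_min_version is an illustrative helper, not part of attention.py):

import importlib

from packaging import version

def require_min_version(module_name: str, minimum: str):
    # Illustrative helper: import a module and fail with a catchable RuntimeError
    # (rather than SystemExit) if it is older than `minimum`.
    module = importlib.import_module(module_name)
    if version.parse(module.__version__) < version.parse(minimum):
        raise RuntimeError(f"{module_name} {module.__version__} is too old; {minimum}+ is required.")
    return module

torch = require_min_version("torch", "2.5.0")  # mirrors the Flex Attention check in __init__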

