diff --git a/CHANGELOG.md b/CHANGELOG.md
index acf1aaed..2bcfd5d9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -39,8 +39,15 @@ Keep it human-readable, your future self will thank you!
 
 ### Added
 
+- Configurability of flash attention (#47)
+- Configurability of the dropout probability in the MultiHeadSelfAttention module
 - CI workflow to update the changelog on release
 - Remapper: Preprocessor for remapping one variable to multiple ones. Includes changes to the data indices since the remapper changes the number of variables. With optional config keywords.
+- Codeowners file
+- Pygrep pre-commit hooks
+- Docsig pre-commit hooks
+- Changelog merge strategy
+
 
 ### Changed
 
@@ -48,6 +55,7 @@ Keep it human-readable, your future self will thank you!
 - run downstream-ci only when src and tests folders have changed
 - New error messages for wrongs graphs.
 - Feature: Change model to be instantiatable in the interface, addressing [#28](https://github.com/ecmwf/anemoi-models/issues/28) through [#45](https://github.com/ecmwf/anemoi-models/pulls/45)
+- Bugfixes for CI
 
 ### Removed
 
diff --git a/pyproject.toml b/pyproject.toml
index 214f82c1..ecdfd9a9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -45,7 +45,7 @@ dependencies = [
   "anemoi-utils>=0.1.9",
   "einops>=0.6.1",
   "hydra-core>=1.3",
-  "torch>=2.2",
+  "torch>=2.5",
   "torch-geometric>=2.3,<2.5",
 ]
 optional-dependencies.all = [
 ]
diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 00000000..f77701dd
--- /dev/null
+++ b/pytest.ini
@@ -0,0 +1,7 @@
+[pytest]
+markers =
+    data_dependent: marks tests depending on data (deselect with '-m "not data_dependent"')
+    auth: marks tests that require authentication (deselect with '-m "not auth"')
+    gpu: marks tests that require a GPU (deselect with '-m "not gpu"')
+
+tmp_path_retention_policy = none
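The markers registered in `pytest.ini` are attached to tests with `pytest.mark` decorators and deselected via `-m`. A minimal, hypothetical sketch of how they would be used (the test functions below are illustrative and not part of this diff):

```python
# Illustrative only: how the markers declared in pytest.ini are used.
import pytest
import torch


@pytest.mark.gpu  # deselect with: pytest -m "not gpu"
def test_requires_gpu() -> None:
    assert torch.cuda.is_available()


@pytest.mark.data_dependent  # deselect with: pytest -m "not data_dependent"
def test_requires_sample_data() -> None:
    pytest.skip("placeholder: a real test would load a data-dependent fixture here")
```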
diff --git a/src/anemoi/models/layers/attention.py b/src/anemoi/models/layers/attention.py
index d7f54920..1ec5a5a9 100644
--- a/src/anemoi/models/layers/attention.py
+++ b/src/anemoi/models/layers/attention.py
@@ -8,23 +8,19 @@
 # nor does it submit to any jurisdiction.
 
+from __future__ import annotations
+
 import logging
+import math
 from typing import Optional
 
 import einops
+import torch
+from packaging import version
 from torch import Tensor
 from torch import nn
 from torch.distributed.distributed_c10d import ProcessGroup
 
-try:
-    from flash_attn import flash_attn_func as attn_func
-except ImportError:
-    from torch.nn.functional import scaled_dot_product_attention as attn_func
-
-    _FLASH_ATTENTION_AVAILABLE = False
-else:
-    _FLASH_ATTENTION_AVAILABLE = True
-
 from anemoi.models.distributed.transformer import shard_heads
 from anemoi.models.distributed.transformer import shard_sequence
 
@@ -32,7 +28,13 @@
 
 
 class MultiHeadSelfAttention(nn.Module):
-    """Multi Head Self Attention Pytorch Layer."""
+    """Multi Head Self Attention PyTorch layer.
+
+    Allows for three different attention implementations:
+    - scaled dot product attention, see https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
+    - flash attention, see https://github.com/Dao-AILab/flash-attention
+    - flex attention, see https://pytorch.org/blog/flexattention/
+    """
 
     def __init__(
         self,
@@ -42,31 +44,88 @@ def __init__(
         is_causal: bool = False,
         window_size: Optional[int] = None,
         dropout_p: float = 0.0,
+        attention_implementation: str = "flash_attention",
+        softcap: float = None,
+        use_alibi_slopes: bool = False,
     ):
+        """Initialize MultiHeadSelfAttention.
+
+        For the flash attention implementation, two additional parameters are available: softcap and use_alibi_slopes.
+
+        softcap: Softcapping prevents the logits from growing excessively large.
+
+        use_alibi_slopes: Adds a bias of (-alibi_slope * |i + seqlen_k - seqlen_q - j|) to the attention score of
+        query i and key j, where alibi_slope is calculated using get_alibi_slopes.
+
+        Parameters
+        ----------
+        num_heads : int
+            number of heads
+        embed_dim : int
+            embedding dimension
+        bias : bool, optional
+            bias, by default False
+        is_causal : bool, optional
+            apply causal attention mask, by default False
+        window_size : Optional[int], optional
+            window_size, by default None
+        dropout_p : float, optional
+            dropout probability, by default 0.0
+        attention_implementation: str, optional
+            A predefined string which selects the underlying attention
+            implementation, by default "flash_attention"
+        softcap : float, optional
+            Anything > 0 activates softcapping attention, by default None
+        use_alibi_slopes : bool, optional
+            Adds ALiBi slopes bias to the attention scores, by default False
+        """
         super().__init__()
 
         assert (
             embed_dim % num_heads == 0
         ), f"Embedding dimension ({embed_dim}) must be divisible by number of heads ({num_heads})"
 
+        self.attention_implementation = attention_implementation
+        self.use_alibi_slopes = use_alibi_slopes
+        self.set_attention_function()
+
         self.num_heads = num_heads
         self.embed_dim = embed_dim
         self.head_dim = embed_dim // num_heads  # q k v
-        self.window_size = (window_size, window_size)  # flash attention
+        self.window_size = window_size
         self.dropout_p = dropout_p
         self.is_causal = is_causal
+        self.softcap = softcap
 
-        self.lin_qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
-        self.attention = attn_func
+        if self.use_alibi_slopes:
+            self.alibi_slopes = get_alibi_slopes(num_heads)
+            assert self.alibi_slopes.shape[0] == num_heads, "Error: Number of alibi_slopes must match number of heads"
+        else:
+            self.alibi_slopes = None
 
-        if not _FLASH_ATTENTION_AVAILABLE:
-            LOGGER.warning("Flash attention not available, falling back to pytorch scaled_dot_product_attention")
+        self.lin_qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
 
         self.projection = nn.Linear(embed_dim, embed_dim, bias=True)
 
+    def set_attention_function(self):
+        attn_funcs = {
+            "flash_attention": FlashAttentionWrapper,
+            "flex_attention": FlexAttentionWrapper,
+            "scaled_dot_product_attention": SDPAAttentionWrapper,
+        }
+        assert (
+            self.attention_implementation in attn_funcs
+        ), f"{self.attention_implementation} not supported. \
+        Please change model.processor.attention_implementation in the config to one of: {attn_funcs.keys()}"
+        LOGGER.info(f"Using {self.attention_implementation}")
+
+        # initialise the chosen attention wrapper here
+        self.attention = attn_funcs[self.attention_implementation]()
+
     def forward(
         self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None
     ) -> Tensor:
+
         query, key, value = self.lin_qkv(x).chunk(3, -1)
 
         if model_comm_group:
@@ -89,24 +148,218 @@ def forward(
         value = shard_heads(value, shapes=shapes, mgroup=model_comm_group)
         dropout_p = self.dropout_p if self.training else 0.0
 
-        if _FLASH_ATTENTION_AVAILABLE:
-            query, key, value = (
-                einops.rearrange(t, "batch heads grid vars -> batch grid heads vars") for t in (query, key, value)
+        out = self.attention(
+            query,
+            key,
+            value,
+            batch_size,
+            causal=False,
+            window_size=self.window_size,
+            dropout_p=dropout_p,
+            softcap=self.softcap,
+            alibi_slopes=self.alibi_slopes,
+        )
+
+        out = shard_sequence(out, shapes=shapes, mgroup=model_comm_group)
+        out = einops.rearrange(out, "batch heads grid vars -> (batch grid) (heads vars)")
+
+        out = self.projection(out)
+
+        return out
+
+
+class SDPAAttentionWrapper(nn.Module):
+    """Wrapper for PyTorch scaled dot product attention."""
+
+    def __init__(self):
+        super().__init__()
+
+        from torch.nn.functional import scaled_dot_product_attention
+
+        self.attention = scaled_dot_product_attention
+        self.mask = None
+        self.window_size = None
+
+    def update_mask(self, seq_len, window_size: int, device: str):
+
+        self.mask = (
+            torch.abs(
+                torch.arange(seq_len, device=device).unsqueeze(0) - torch.arange(seq_len, device=device).unsqueeze(1)
             )
-            out = self.attention(query, key, value, causal=False, window_size=self.window_size, dropout_p=dropout_p)
-            out = einops.rearrange(out, "batch grid heads vars -> batch heads grid vars")
-        else:
+            <= window_size
+        )
+
+    def forward(
+        self,
+        query,
+        key,
+        value,
+        batch_size: int,
+        causal=False,
+        window_size=None,
+        dropout_p=0.0,
+        softcap=None,
+        alibi_slopes=None,
+    ):
+        if softcap is not None:
+            # Softcap is not supported by PyTorch's SDPA; the value is ignored here.
+            LOGGER.warning(
+                "Softcap is not supported by PyTorch's SDPA and is ignored. Switch to flash attention to use softcap."
+            )
+        if alibi_slopes is not None:
+            raise NotImplementedError(
+                "ALiBi slopes are not supported by PyTorch's SDPA. Please switch to flash attention or disable alibi slopes."
+            )
+
+        sequence_len = query.shape[-2]
+
+        if window_size is not None and (self.mask is None or tuple(self.mask.shape) != (sequence_len, sequence_len)):
+            self.update_mask(sequence_len, window_size=window_size, device=query.device)
+
+        with torch.nn.attention.sdpa_kernel(backends=[torch.nn.attention.SDPBackend.MATH]):
             out = self.attention(
                 query,
                 key,
                 value,
-                is_causal=False,
+                attn_mask=self.mask,
+                is_causal=causal,
                 dropout_p=dropout_p,
-            )  # expects (batch heads grid variable) format
+            )
 
-        out = shard_sequence(out, shapes=shapes, mgroup=model_comm_group)
-        out = einops.rearrange(out, "batch heads grid vars -> (batch grid) (heads vars)")
+        return out
 
-        out = self.projection(out)
+
+class FlexAttentionWrapper(nn.Module):
+    """Wrapper for PyTorch flex attention."""
+
+    def __init__(self):
+        super().__init__()
+
+        if version.parse(torch.__version__) < version.parse("2.5.0"):
+            raise RuntimeError("Torch version is too low. Update to 2.5.0 or higher to use flex attention.")
+
+        # We compile flex attention once, at the first iteration, because the
+        # sequence length is needed to compute the mask mod for the sliding window.
+        self.is_attn_compiled = False
+
+    def forward(
+        self,
+        query,
+        key,
+        value,
+        batch_size: int,
+        causal: bool = False,
+        window_size: int = None,
+        dropout_p: float = 0.0,
+        softcap: float = None,
+        alibi_slopes: torch.Tensor = None,
+    ):
+
+        if alibi_slopes is not None:
+            raise NotImplementedError("alibi_slopes are not yet implemented for flex attention in Anemoi.")
+        if softcap is not None:
+            raise NotImplementedError("Softcap is not yet implemented for flex attention in Anemoi.")
+        if dropout_p != 0.0:
+            raise NotImplementedError("Dropout is not yet implemented for flex attention in Anemoi.")
+        if causal:
+            raise NotImplementedError("Causal attention is not yet implemented for flex attention in Anemoi.")
+
+        # This assumes seq_len never changes across iterations and stages.
+        # A recompilation could be triggered with something like:
+        #
+        #     if query.shape[2] != prev_seq_len:
+        #         self.is_attn_compiled = False
+        #
+        if not self.is_attn_compiled:
+            import functools
+
+            from torch.nn.attention.flex_attention import create_block_mask
+            from torch.nn.attention.flex_attention import flex_attention
+
+            if window_size is not None:
+                def sliding_window_mask(b, h, q_idx, kv_idx):
+                    return abs(q_idx - kv_idx) <= window_size
+
+                seq_len = query.shape[2]
+                self.block_mask = create_block_mask(
+                    sliding_window_mask, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, _compile=True
+                )
+                self.attention = functools.partial(
+                    flex_attention, block_mask=self.block_mask
+                )  # Cache the block mask (recommended in the FlexAttention blog post)
+            else:
+                self.attention = flex_attention
+            self.attention = torch.compile(self.attention)
+            self.is_attn_compiled = True
+
+        torch._dynamo.config.optimize_ddp = False
+        out = self.attention(query, key, value)
+        torch._dynamo.config.optimize_ddp = True
 
         return out
+
+
+class FlashAttentionWrapper(nn.Module):
+    """Wrapper for flash attention."""
+
+    def __init__(self):
+        super().__init__()
+        import flash_attn
+
+        if version.parse(flash_attn.__version__) < version.parse("2.6.0"):
+            raise RuntimeError("Flash-attn version is too low. Update to 2.6.0 or higher.")
+        else:
+            self.attention = flash_attn.flash_attn_func
+
+    def forward(
+        self,
+        query,
+        key,
+        value,
+        batch_size: int,
+        causal: bool = False,
+        window_size: int = None,
+        dropout_p: float = 0.0,
+        softcap: float = None,
+        alibi_slopes: torch.Tensor = None,
+    ):
+        query, key, value = (
+            einops.rearrange(t, "batch heads grid vars -> batch grid heads vars") for t in (query, key, value)
+        )
+
+        alibi_slopes = alibi_slopes.repeat(batch_size, 1).to(query.device) if alibi_slopes is not None else None
+
+        out = self.attention(
+            query,
+            key,
+            value,
+            causal=False,
+            window_size=(window_size, window_size),
+            dropout_p=dropout_p,
+            softcap=softcap,
+            alibi_slopes=alibi_slopes,
+        )
+        out = einops.rearrange(out, "batch grid heads vars -> batch heads grid vars")
+        return out
+
+
+def get_alibi_slopes(num_heads: int) -> Tensor:
+    """Calculates linearly decreasing slopes for ALiBi attention.
+
+    Parameters
+    ----------
+    num_heads : int
+        number of attention heads
+
+    Returns
+    -------
+    Tensor
+        ALiBi slopes
+    """
+    n = 2 ** math.floor(math.log2(num_heads))
+    slope_0 = 2 ** (-8 / n)
+    alibi_slopes = torch.pow(slope_0, torch.arange(1, 1 + n))
+    if n < num_heads:
+        slope_hat_0 = 2 ** (-4 / n)
+        alibi_slopes_hat = torch.pow(slope_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2))
+        alibi_slopes = torch.cat([alibi_slopes, alibi_slopes_hat])
+    return alibi_slopes
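As a sanity check on the slope formula above, here is a small standalone sketch (not part of the diff) that mirrors `get_alibi_slopes` and spells out the expected values, including the non-power-of-two case:

```python
# Standalone sketch mirroring get_alibi_slopes() above, for illustration only.
import math

import torch


def alibi_slopes_reference(num_heads: int) -> torch.Tensor:
    n = 2 ** math.floor(math.log2(num_heads))  # largest power of two <= num_heads
    slopes = torch.pow(2 ** (-8 / n), torch.arange(1, 1 + n))
    if n < num_heads:
        # interleaved extra slopes for the remaining heads
        extra = torch.pow(2 ** (-4 / n), torch.arange(1, 1 + 2 * (num_heads - n), 2))
        slopes = torch.cat([slopes, extra])
    return slopes


# num_heads = 8 (power of two): geometric sequence 1/2, 1/4, ..., 1/256
print(alibi_slopes_reference(8))
# num_heads = 6: four "base" slopes [0.25, 0.0625, 0.015625, 0.00390625]
# plus two extra slopes [0.5, 0.125], giving one slope per head
print(alibi_slopes_reference(6))
```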
diff --git a/src/anemoi/models/layers/block.py b/src/anemoi/models/layers/block.py
index 60446d6c..8a88b1fe 100644
--- a/src/anemoi/models/layers/block.py
+++ b/src/anemoi/models/layers/block.py
@@ -69,6 +69,9 @@ def __init__(
         activation: str,
         window_size: int,
         dropout_p: float = 0.0,
+        attention_implementation: str = "flash_attention",
+        softcap: float = None,
+        use_alibi_slopes: bool = None,
     ):
         super().__init__()
 
@@ -87,6 +90,9 @@ def __init__(
             bias=False,
             is_causal=False,
             dropout_p=dropout_p,
+            attention_implementation=attention_implementation,
+            softcap=softcap,
+            use_alibi_slopes=use_alibi_slopes,
         )
 
         self.mlp = nn.Sequential(
diff --git a/src/anemoi/models/layers/chunk.py b/src/anemoi/models/layers/chunk.py
index 5c4fae38..4f2a9b20 100644
--- a/src/anemoi/models/layers/chunk.py
+++ b/src/anemoi/models/layers/chunk.py
@@ -75,6 +75,9 @@ def __init__(
         mlp_hidden_ratio: int = 4,
         activation: str = "GELU",
         dropout_p: float = 0.0,
+        attention_implementation: str = "flash_attention",
+        softcap: float = None,
+        use_alibi_slopes: bool = None,
     ) -> None:
         """Initialize TransformerProcessor.
 
@@ -92,6 +95,13 @@ def __init__(
             Activation function, by default "GELU"
         dropout_p: float
             Dropout probability used for multi-head self attention, default 0.0
+        attention_implementation: str, optional
+            A predefined string which selects the underlying attention
+            implementation, by default "flash_attention"
+        softcap : float, optional
+            Anything > 0 activates softcapping flash attention, by default None
+        use_alibi_slopes : bool, optional
+            Use the ALiBi option, only used for flash attention, by default None
         """
         super().__init__(num_channels=num_channels, num_layers=num_layers)
 
@@ -103,6 +113,9 @@ def __init__(
             activation=activation,
             window_size=window_size,
             dropout_p=dropout_p,
+            attention_implementation=attention_implementation,
+            softcap=softcap,
+            use_alibi_slopes=use_alibi_slopes,
         )
 
     def forward(
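The block and chunk layers simply forward these options down to MultiHeadSelfAttention. A minimal usage sketch of the attention layer itself, with shapes mirroring tests/layers/test_attention.py (illustrative only, not part of the diff):

```python
# Illustrative sketch: construct MultiHeadSelfAttention with the SDPA backend.
import torch

from anemoi.models.layers.attention import MultiHeadSelfAttention

num_heads, embed_dim, batch_size = 4, 64, 2
mhsa = MultiHeadSelfAttention(
    num_heads,
    embed_dim,
    window_size=None,
    dropout_p=0.0,
    attention_implementation="scaled_dot_product_attention",
    softcap=None,
    use_alibi_slopes=False,
)

x = torch.randn(batch_size * 2, embed_dim)  # (batch * grid, embed_dim)
shapes = [list(x.shape)]
out = mhsa(x, shapes, batch_size)
assert out.shape == x.shape
```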
diff --git a/src/anemoi/models/layers/processor.py b/src/anemoi/models/layers/processor.py
index 4fd32311..bb9547e8 100644
--- a/src/anemoi/models/layers/processor.py
+++ b/src/anemoi/models/layers/processor.py
@@ -97,6 +97,9 @@ def __init__(
         num_heads: int = 16,
         mlp_hidden_ratio: int = 4,
         dropout_p: float = 0.1,
+        attention_implementation: str = "flash_attention",
+        softcap: float = 0.0,
+        use_alibi_slopes: bool = None,
         **kwargs,
     ) -> None:
         """Initialize TransformerProcessor.
@@ -117,6 +120,13 @@
             Activation function, by default "GELU"
         dropout_p: float, optional
             Dropout probability used for multi-head self attention, default 0.0
+        attention_implementation: str, optional
+            A predefined string which selects the underlying attention
+            implementation, by default "flash_attention"
+        softcap : float, optional
+            Anything > 0 activates softcapping flash attention, by default 0.0
+        use_alibi_slopes : bool, optional
+            Use the ALiBi option, only used for flash attention, by default None
         """
         super().__init__(
             num_channels=num_channels,
@@ -138,6 +148,9 @@ def __init__(
             window_size=window_size,
             activation=activation,
             dropout_p=dropout_p,
+            attention_implementation=attention_implementation,
+            softcap=softcap,
+            use_alibi_slopes=use_alibi_slopes,
        )
 
         self.offload_layers(cpu_offload)
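At the processor level, the new options are plumbed through to every attention block. A hedged instantiation sketch (keyword names follow the diff and the test fixture below; the numeric values are illustrative assumptions):

```python
# Illustrative only: instantiate the processor with the SDPA backend.
from anemoi.models.layers.processor import TransformerProcessor

processor = TransformerProcessor(
    num_layers=2,
    window_size=64,
    num_channels=128,
    num_chunks=2,
    activation="GELU",
    cpu_offload=False,
    num_heads=16,
    mlp_hidden_ratio=4,
    dropout_p=0.1,
    attention_implementation="scaled_dot_product_attention",
    softcap=None,  # softcapping is only meaningful for flash attention
    use_alibi_slopes=False,
)
```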
diff --git a/tests/layers/block/test_block_transformer.py b/tests/layers/block/test_block_transformer.py
index 46541e08..087a351c 100644
--- a/tests/layers/block/test_block_transformer.py
+++ b/tests/layers/block/test_block_transformer.py
@@ -33,12 +33,20 @@ class TestTransformerProcessorBlock:
         activation=st.sampled_from(["ReLU", "GELU", "Tanh"]),
         window_size=st.integers(min_value=1, max_value=512),
         dropout_p=st.floats(min_value=0.0, max_value=1.0),
+        softcap=st.floats(min_value=0.0, max_value=1.0),
     )
     @settings(max_examples=10)
-    def test_init(self, factor_attention_heads, hidden_dim, num_heads, activation, window_size, dropout_p):
+    def test_init(self, factor_attention_heads, hidden_dim, num_heads, activation, window_size, dropout_p, softcap):
         num_channels = num_heads * factor_attention_heads
         block = TransformerProcessorBlock(
-            num_channels, hidden_dim, num_heads, activation, window_size, dropout_p=dropout_p
+            num_channels,
+            hidden_dim,
+            num_heads,
+            activation,
+            window_size,
+            dropout_p=dropout_p,
+            attention_implementation="scaled_dot_product_attention",
+            softcap=softcap,
         )
         assert isinstance(block, TransformerProcessorBlock)
 
@@ -56,6 +64,7 @@ def test_init(self, factor_attention_heads, hidden_dim, num_heads, activation, w
         shapes=st.lists(st.integers(min_value=1, max_value=10), min_size=3, max_size=3),
         batch_size=st.integers(min_value=1, max_value=40),
         dropout_p=st.floats(min_value=0.0, max_value=1.0),
+        softcap=st.floats(min_value=0.0, max_value=1.0),
     )
     @settings(max_examples=10)
     def test_forward_output(
@@ -68,14 +77,21 @@ def test_forward_output(
         shapes,
         batch_size,
         dropout_p,
+        softcap,
     ):
         num_channels = num_heads * factor_attention_heads
         block = TransformerProcessorBlock(
-            num_channels, hidden_dim, num_heads, activation, window_size, dropout_p=dropout_p
+            num_channels,
+            hidden_dim,
+            num_heads,
+            activation,
+            window_size,
+            dropout_p=dropout_p,
+            attention_implementation="scaled_dot_product_attention",
+            softcap=softcap,
         )
 
-        x = torch.randn((batch_size, num_channels))
-
+        x = torch.randn((batch_size, num_channels))
         output = block.forward(x, shapes, batch_size)
         assert isinstance(output, torch.Tensor)
         assert output.shape == (batch_size, num_channels)
diff --git a/tests/layers/chunk/test_chunk_transformer.py b/tests/layers/chunk/test_chunk_transformer.py
index 86989486..553500b6 100644
--- a/tests/layers/chunk/test_chunk_transformer.py
+++ b/tests/layers/chunk/test_chunk_transformer.py
@@ -24,6 +24,7 @@ def init(self):
         activation: str = "GELU"
         window_size: int = 13
         dropout_p: float = 0.1
+        attention_implementation = "scaled_dot_product_attention"
 
         # num_heads must be evenly divisible by num_channels for MHSA
         return (
@@ -34,6 +35,7 @@ def init(self):
             activation,
             window_size,
             dropout_p,
+            attention_implementation,
         )
 
     @pytest.fixture
@@ -46,6 +48,7 @@ def processor_chunk(self, init):
             activation,
             window_size,
             dropout_p,
+            attention_implementation,
         ) = init
         return TransformerProcessorChunk(
             num_channels=num_channels,
@@ -55,6 +58,7 @@ def processor_chunk(self, init):
             activation=activation,
             window_size=window_size,
             dropout_p=dropout_p,
+            attention_implementation=attention_implementation,
         )
 
     def test_all_blocks(self, processor_chunk):
diff --git a/tests/layers/processor/test_transformer_processor.py b/tests/layers/processor/test_transformer_processor.py
index b94ff63f..0968465e 100644
--- a/tests/layers/processor/test_transformer_processor.py
+++ b/tests/layers/processor/test_transformer_processor.py
@@ -25,6 +25,8 @@ def transformer_processor_init():
     num_heads = 16
     mlp_hidden_ratio = 4
     dropout_p = 0.1
+    softcap = 0.5
+    attention_implementation = "scaled_dot_product_attention"
     return (
         num_layers,
         window_size,
@@ -35,6 +37,8 @@ def transformer_processor_init():
         num_heads,
         mlp_hidden_ratio,
         dropout_p,
+        softcap,
+        attention_implementation,
     )
 
 
@@ -50,6 +54,8 @@ def transformer_processor(transformer_processor_init):
         num_heads,
         mlp_hidden_ratio,
         dropout_p,
+        softcap,
+        attention_implementation,
     ) = transformer_processor_init
     return TransformerProcessor(
         num_layers=num_layers,
@@ -61,6 +67,8 @@ def transformer_processor(transformer_processor_init):
         num_heads=num_heads,
         mlp_hidden_ratio=mlp_hidden_ratio,
         dropout_p=dropout_p,
+        attention_implementation=attention_implementation,
+        softcap=softcap,
     )
 
 
@@ -75,6 +83,8 @@ def test_transformer_processor_init(transformer_processor, transformer_processor
         _num_heads,
         _mlp_hidden_ratio,
         _dropout_p,
+        _softcap,
+        _attention_implementation,
     ) = transformer_processor_init
     assert isinstance(transformer_processor, TransformerProcessor)
     assert transformer_processor.num_chunks == num_chunks
@@ -93,6 +103,8 @@ def test_transformer_processor_forward(transformer_processor, transformer_proces
         _num_heads,
         _mlp_hidden_ratio,
         _dropout_p,
+        _softcap,
+        _attention_implementation,
     ) = transformer_processor_init
     gridsize = 100
     batch_size = 1
diff --git a/tests/layers/test_attention.py b/tests/layers/test_attention.py
index a1b40540..9ef23485 100644
--- a/tests/layers/test_attention.py
+++ b/tests/layers/test_attention.py
@@ -22,12 +22,16 @@
     num_heads=st.integers(min_value=1, max_value=50),
     embed_dim_multiplier=st.integers(min_value=1, max_value=10),
     dropout_p=st.floats(min_value=0.0, max_value=1.0),
+    softcap=st.floats(min_value=0.0, max_value=1.0),
+    attention_implementation=st.sampled_from(["scaled_dot_product_attention", "flex_attention"]),
 )
-def test_multi_head_self_attention_init(num_heads, embed_dim_multiplier, dropout_p):
+def test_multi_head_self_attention_init(num_heads, embed_dim_multiplier, dropout_p, softcap, attention_implementation):
     embed_dim = (
         num_heads * embed_dim_multiplier
     )  # TODO: Make assert in MHSA to check if embed_dim is divisible by num_heads
-    mhsa = MultiHeadSelfAttention(num_heads, embed_dim, dropout_p=dropout_p)
+    mhsa = MultiHeadSelfAttention(
+        num_heads, embed_dim, dropout_p=dropout_p, attention_implementation=attention_implementation, softcap=softcap
+    )
 
     assert isinstance(mhsa, nn.Module)
     assert mhsa.num_heads == num_heads
@@ -42,11 +46,17 @@ def test_multi_head_self_attention_init(num_heads, embed_dim_multiplier, dropout
     num_heads=st.integers(min_value=1, max_value=20),
     embed_dim_multiplier=st.integers(min_value=1, max_value=10),
     dropout_p=st.floats(min_value=0.0, max_value=1.0),
+    softcap=st.floats(min_value=0.0, max_value=1.0),
+    attention_implementation=st.sampled_from(["scaled_dot_product_attention"]),
 )
 @settings(deadline=None)
-def test_multi_head_self_attention_forward(batch_size, num_heads, embed_dim_multiplier, dropout_p):
+def test_multi_head_self_attention_forward(
+    batch_size, num_heads, embed_dim_multiplier, dropout_p, softcap, attention_implementation
+):
     embed_dim = num_heads * embed_dim_multiplier
-    mhsa = MultiHeadSelfAttention(num_heads, embed_dim, dropout_p=dropout_p)
+    mhsa = MultiHeadSelfAttention(
+        num_heads, embed_dim, dropout_p=dropout_p, attention_implementation=attention_implementation, softcap=softcap
+    )
 
     x = torch.randn(batch_size * 2, embed_dim)
     shapes = [list(x.shape)]
@@ -61,10 +71,16 @@ def test_multi_head_self_attention_forward(batch_size, num_heads, embed_dim_mult
     num_heads=st.integers(min_value=1, max_value=20),
     embed_dim_multiplier=st.integers(min_value=1, max_value=10),
     dropout_p=st.floats(min_value=0.0, max_value=1.0),
+    softcap=st.floats(min_value=0.0, max_value=1.0),
+    attention_implementation=st.sampled_from(["scaled_dot_product_attention"]),
 )
-def test_multi_head_self_attention_backward(batch_size, num_heads, embed_dim_multiplier, dropout_p):
+def test_multi_head_self_attention_backward(
+    batch_size, num_heads, embed_dim_multiplier, dropout_p, softcap, attention_implementation
+):
     embed_dim = num_heads * embed_dim_multiplier
-    mhsa = MultiHeadSelfAttention(num_heads, embed_dim, dropout_p=dropout_p)
+    mhsa = MultiHeadSelfAttention(
+        num_heads, embed_dim, dropout_p=dropout_p, attention_implementation=attention_implementation, softcap=softcap
+    )
 
     x = torch.randn(batch_size * 2, embed_dim, requires_grad=True)
     shapes = [list(x.shape)]
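The tests above only exercise the CPU-friendly backends. If the flash attention path should also be covered, a GPU-marked test could be added along these lines; this is a hypothetical sketch (the test name, value choices, and half-precision cast are assumptions, not part of the diff), using the `gpu` marker registered in `pytest.ini`:

```python
# Hypothetical sketch: exercise the flash_attention implementation on a GPU.
import pytest
import torch

from anemoi.models.layers.attention import MultiHeadSelfAttention


@pytest.mark.gpu  # deselect with: pytest -m "not gpu"
def test_multi_head_self_attention_forward_flash_attention():
    num_heads, embed_dim, batch_size = 4, 64, 2
    mhsa = MultiHeadSelfAttention(
        num_heads,
        embed_dim,
        window_size=512,
        dropout_p=0.0,
        attention_implementation="flash_attention",
        softcap=0.5,
        use_alibi_slopes=True,
    ).to("cuda", dtype=torch.float16)  # flash-attn expects half precision inputs

    x = torch.randn(batch_size * 2, embed_dim, device="cuda", dtype=torch.float16)
    shapes = [list(x.shape)]
    output = mhsa(x, shapes, batch_size)
    assert output.shape == x.shape
```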