This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Feature/44 make flash attention configurable #47

Open · theissenhelen wants to merge 73 commits into develop from feature/44-make-flash-attention-configurable
Changes from 22 commits
Commits (73)
539e8a2
feat: FlashMultiHeadSelfAttention
theissenhelen Sep 17, 2024
3317138
Chore/multiple fixes ci precommit (#41)
theissenhelen Sep 18, 2024
3186a8e
11 add configurability to dropout in multiheadselfattention module (#12)
theissenhelen Sep 18, 2024
a86c9a8
chore!: drop support for scaled_dot_product_attention
theissenhelen Sep 20, 2024
105443f
feat: add softcap
theissenhelen Sep 20, 2024
e82a59e
test: add softcap
theissenhelen Sep 20, 2024
e648eb0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 20, 2024
6271cd8
feat: flash attention lazy import
theissenhelen Sep 23, 2024
d4940e7
feat: make alibi slopes configurable
theissenhelen Sep 27, 2024
9ff6cb9
chore(deps): add flash-attn
theissenhelen Sep 27, 2024
bbd89dc
feat: use scaled_dot_product as default
theissenhelen Oct 1, 2024
91533c6
feat: make alibi_slope configurable in block, chunk processor
theissenhelen Oct 1, 2024
0eb5c50
chore(deps): remove flash-attn
theissenhelen Oct 1, 2024
c04e641
feat: get alibi_slopes
theissenhelen Oct 2, 2024
6523b47
docs: update docstrings
theissenhelen Oct 3, 2024
22623cc
fix: bias shape
theissenhelen Oct 3, 2024
ed07e34
fix: softcap optional
theissenhelen Oct 3, 2024
c841324
fix: import annotations from future
theissenhelen Oct 3, 2024
6c12dda
fix: annotation error
theissenhelen Oct 3, 2024
b7b8f2e
docs: update changelog
theissenhelen Oct 3, 2024
df353d9
fix: type annotation
theissenhelen Oct 7, 2024
fc335c7
feat: catch low flash-attn version
theissenhelen Oct 7, 2024
663fea0
feat: FlashMultiHeadSelfAttention
theissenhelen Sep 17, 2024
a8b3f9d
Chore/multiple fixes ci precommit (#41)
theissenhelen Sep 18, 2024
6595ca1
11 add configurability to dropout in multiheadselfattention module (#12)
theissenhelen Sep 18, 2024
0c55a9c
chore!: drop support for scaled_dot_product_attention
theissenhelen Sep 20, 2024
ea665be
feat: add softcap
theissenhelen Sep 20, 2024
ffa2d99
test: add softcap
theissenhelen Sep 20, 2024
7c2d634
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 20, 2024
d2ed932
feat: flash attention lazy import
theissenhelen Sep 23, 2024
3295159
feat: make alibi slopes configurable
theissenhelen Sep 27, 2024
ebde686
chore(deps): add flash-attn
theissenhelen Sep 27, 2024
5102d9a
feat: use scaled_dot_product as default
theissenhelen Oct 1, 2024
3abc286
feat: make alibi_slope configurable in block, chunk processor
theissenhelen Oct 1, 2024
673a25d
chore(deps): remove flash-attn
theissenhelen Oct 1, 2024
f606058
feat: get alibi_slopes
theissenhelen Oct 2, 2024
ef34771
docs: update docstrings
theissenhelen Oct 3, 2024
5136fb3
fix: bias shape
theissenhelen Oct 3, 2024
892c269
fix: softcap optional
theissenhelen Oct 3, 2024
4c42171
fix: import annotations from future
theissenhelen Oct 3, 2024
4bdf464
fix: annotation error
theissenhelen Oct 3, 2024
5a670b2
docs: update changelog
theissenhelen Oct 3, 2024
34db6e4
fix: type annotation
theissenhelen Oct 7, 2024
d424c75
feat: catch low flash-attn version
theissenhelen Oct 7, 2024
222b7d8
feat: attention wrapper
theissenhelen Oct 25, 2024
c2aca14
fix: remove duplicate version check
theissenhelen Oct 25, 2024
b75d225
merge conflict
cathalobrien Nov 1, 2024
147e772
added flex attn wrapper
cathalobrien Nov 1, 2024
f0c24e8
fix: alibi_slopes unassigned
theissenhelen Nov 6, 2024
3c4572b
adding causal wip
cathalobrien Nov 6, 2024
fb731f7
Merge branch 'feature/44-make-flash-attention-configurable' of github…
cathalobrien Nov 8, 2024
f0308f2
added flex attn module
cathalobrien Nov 12, 2024
6dee265
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 12, 2024
7fb0b62
Bump min torch version to be able to use Flex Attn
cathalobrien Nov 12, 2024
739aa65
added input parameter checks
cathalobrien Nov 12, 2024
2a2ed11
precommit fix
cathalobrien Nov 12, 2024
fa1474c
merge
cathalobrien Nov 12, 2024
a703688
fix: typo
theissenhelen Nov 26, 2024
f1be563
test: adjust tests
theissenhelen Nov 27, 2024
0dda5d6
fix: no self.use_alibi_slopes
theissenhelen Nov 27, 2024
12facf0
fix: use_alibi_slope default to false
theissenhelen Nov 28, 2024
60e32f1
feat: Add sliding window support for TorchAttention via mask
japols Dec 9, 2024
07d9684
fix: set default flash_attention
japols Dec 10, 2024
9a1827a
fix: pytest
japols Dec 10, 2024
ca8c9fa
fix: tests
japols Dec 13, 2024
ac897ea
Merge branch 'feature/44-make-flash-attention-configurable' of github…
cathalobrien Dec 16, 2024
7ec8142
docs: improve docstrings in MultiHeadSelfAttention
theissenhelen Dec 18, 2024
972d3c5
fix: error instead of SystemExit
theissenhelen Dec 18, 2024
e89fd2e
chore: refactor SDPAAttention update_mask method
theissenhelen Dec 18, 2024
2d122df
feat: add missing pytest.ini
theissenhelen Dec 18, 2024
d4510f6
chore: remove explicit float typing
theissenhelen Dec 19, 2024
6057004
Merge branch 'feature/44-make-flash-attention-configurable' of github…
cathalobrien Dec 19, 2024
8656cae
support running without window size
cathalobrien Dec 19, 2024
9 changes: 9 additions & 0 deletions CHANGELOG.md
@@ -27,15 +27,24 @@ Keep it human-readable, your future self will thank you!

### Added

- CI workflow to update the changelog on release
- Add configurability of flash attention (#47)
- Configurability of the dropout probability in the MultiHeadSelfAttention module
- Remapper: Preprocessor for remapping one variable to multiple ones. Includes changes to the data indices since the remapper changes the number of variables. With optional config keywords.
- Codeowners file
- Pygrep precommit hooks
- Docsig precommit hooks
- Changelog merge strategy

### Changed

- Update CI to inherit from common infrastructure reusable workflows
- Run downstream-ci only when src and tests folders have changed
- New error messages for wrong graphs.
- Feature: Change model to be instantiable in the interface, addressing [#28](https://github.com/ecmwf/anemoi-models/issues/28) through [#45](https://github.com/ecmwf/anemoi-models/pulls/45)
- Bugfixes for CI

### Removed

111 changes: 94 additions & 17 deletions src/anemoi/models/layers/attention.py
@@ -7,31 +7,27 @@
# nor does it submit to any jurisdiction.
#

from __future__ import annotations

import logging
import math
from typing import Optional

import einops
import torch
from packaging import version
from torch import Tensor
from torch import nn
from torch.distributed.distributed_c10d import ProcessGroup

try:
from flash_attn import flash_attn_func as attn_func
except ImportError:
from torch.nn.functional import scaled_dot_product_attention as attn_func

_FLASH_ATTENTION_AVAILABLE = False
else:
_FLASH_ATTENTION_AVAILABLE = True

from anemoi.models.distributed.transformer import shard_heads
from anemoi.models.distributed.transformer import shard_sequence

LOGGER = logging.getLogger(__name__)


class MultiHeadSelfAttention(nn.Module):
"""Multi Head Self Attention Pytorch Layer."""
"""Multi Head Self Attention Pytorch Layer using flash attention, see https://github.com/Dao-AILab/flash-attention"""

def __init__(
self,
@@ -41,31 +37,77 @@ def __init__(
is_causal: bool = False,
window_size: Optional[int] = None,
dropout_p: float = 0.0,
use_flash_attention: bool = False,
softcap: float = None,
use_alibi_slopes: bool = None,
):
"""Initialize MultiHeadSelfAttention.

Parameters
----------
num_heads : int
number of heads
embed_dim : int
embedding dimension
bias : bool, optional
bias, by default False
is_causal : bool, optional
apply causal attention mask, by default False
window_size : Optional[int], optional
window_size, by default None
dropout_p : float, optional
dropout probability, by default 0.0
softcap : float, optional
Anything > 0 activates softcapping attention, by default None
use_alibi_slopes : bool, optional
Adds bias of (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
to the attention score of query i and key j, where alibi_slope
is calculated using get_alibi_slopes, by default None
"""
super().__init__()

assert (
embed_dim % num_heads == 0
), f"Embedding dimension ({embed_dim}) must be divisible by number of heads ({num_heads})"

self.use_flash_attention = use_flash_attention
self.set_attention_function()

self.num_heads = num_heads
self.embed_dim = embed_dim
self.head_dim = embed_dim // num_heads # q k v
self.window_size = (window_size, window_size) # flash attention
self.dropout_p = dropout_p
self.is_causal = is_causal
self.softcap = softcap
self.use_alibi_slopes = use_alibi_slopes

self.lin_qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
self.attention = attn_func
if self.use_alibi_slopes is not None:
self.alibi_slopes = get_alibi_slopes(num_heads)
assert self.alibi_slopes.shape[0] == num_heads

if not _FLASH_ATTENTION_AVAILABLE:
LOGGER.warning("Flash attention not available, falling back to pytorch scaled_dot_product_attention")
self.lin_qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)

self.projection = nn.Linear(embed_dim, embed_dim, bias=True)

def set_attention_function(self):

if self.use_flash_attention:
import flash_attn

if version.parse(flash_attn.__version__) < version.parse("2.6.0"):
raise SystemExit("Error: Flash-attn version is too low. Update to 2.6.0 or higher.")
else:
self.attention = flash_attn.flash_attn_func
else:
from torch.nn.functional import scaled_dot_product_attention

self.attention = scaled_dot_product_attention

def forward(
self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None
) -> Tensor:

query, key, value = self.lin_qkv(x).chunk(3, -1)

if model_comm_group:
@@ -88,11 +130,23 @@ def forward(
value = shard_heads(value, shapes=shapes, mgroup=model_comm_group)
dropout_p = self.dropout_p if self.training else 0.0

if _FLASH_ATTENTION_AVAILABLE:
if self.use_flash_attention:
query, key, value = (
einops.rearrange(t, "batch heads grid vars -> batch grid heads vars") for t in (query, key, value)
)
out = self.attention(query, key, value, causal=False, window_size=self.window_size, dropout_p=dropout_p)

alibi_slopes = self.alibi_slopes.repeat(batch_size, 1).to(query.device) if self.use_alibi_slopes else None

out = self.attention(
query,
key,
value,
causal=False,
window_size=self.window_size,
dropout_p=dropout_p,
softcap=self.softcap,
alibi_slopes=alibi_slopes,
)
out = einops.rearrange(out, "batch grid heads vars -> batch heads grid vars")
else:
out = self.attention(
@@ -101,11 +155,34 @@
value,
is_causal=False,
dropout_p=dropout_p,
) # expects (batch heads grid variable) format
)

out = shard_sequence(out, shapes=shapes, mgroup=model_comm_group)
out = einops.rearrange(out, "batch heads grid vars -> (batch grid) (heads vars)")

out = self.projection(out)

return out


def get_alibi_slopes(num_heads: int) -> Tensor:
"""Calculates linearly decreasing slopes for alibi attention.

Parameters
----------
num_heads : int
number of attention heads

Returns
-------
Tensor
aLiBi slopes
"""
n = 2 ** math.floor(math.log2(num_heads))
Reviewer comment (Member): Since num_heads is an integer, we could be using bit-shifting here: n = 1 << (num_heads.bit_length() - 1). Not sure how necessary speed is here though, as a trade-off against readability. It would definitely need a comment.

theissenhelen (Collaborator, author) replied on Dec 19, 2024: Speed is not an issue as it is only calculated once. So, I would go for readability.

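A quick hedged check (not part of the PR) confirming that the two expressions discussed above agree for positive integer head counts:

    import math

    # Both forms yield the largest power of two <= num_heads,
    # e.g. 12 -> 8 and 16 -> 16 (assumes num_heads >= 1).
    for num_heads in (1, 2, 3, 8, 12, 16, 24, 48):
        via_log = 2 ** math.floor(math.log2(num_heads))
        via_shift = 1 << (num_heads.bit_length() - 1)
        assert via_log == via_shift, (num_heads, via_log, via_shift)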
slope_0 = 2.0 ** (-8.0 / n)
alibi_slopes = torch.pow(slope_0, torch.arange(1, 1 + n))
if n < num_heads:
slope_hat_0 = 2.0 ** (-4.0 / n)
alibi_slopes_hat = torch.pow(slope_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2))
alibi_slopes = torch.cat([alibi_slopes, alibi_slopes_hat])
return alibi_slopes
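A minimal usage sketch of the reworked layer (not part of the diff; the import path follows the file changed here, the values are illustrative, and the SDPA backend is chosen so the snippet runs without a CUDA flash-attn build):

    from anemoi.models.layers.attention import MultiHeadSelfAttention, get_alibi_slopes

    # Slopes form a per-head geometric sequence, e.g. 0.5, 0.25, ..., 1/256 for 8 heads.
    print(get_alibi_slopes(8))

    layer = MultiHeadSelfAttention(
        num_heads=8,
        embed_dim=64,
        bias=False,
        is_causal=False,
        window_size=32,
        dropout_p=0.0,
        use_flash_attention=False,  # fall back to torch scaled_dot_product_attention
        softcap=None,
        use_alibi_slopes=False,
    )
    print(layer.attention)  # backend picked by set_attention_function()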
6 changes: 6 additions & 0 deletions src/anemoi/models/layers/block.py
@@ -63,6 +63,9 @@ def __init__(
activation: str,
window_size: int,
dropout_p: float = 0.0,
use_flash_attention: bool = False,
softcap: float = None,
use_alibi_slopes: bool = None,
):
super().__init__()

@@ -81,6 +84,9 @@ def __init__(
bias=False,
is_causal=False,
dropout_p=dropout_p,
use_flash_attention=use_flash_attention,
softcap=softcap,
use_alibi_slopes=use_alibi_slopes,
)

self.mlp = nn.Sequential(
10 changes: 10 additions & 0 deletions src/anemoi/models/layers/chunk.py
@@ -74,6 +74,9 @@ def __init__(
mlp_hidden_ratio: int = 4,
activation: str = "GELU",
dropout_p: float = 0.0,
use_flash_attention: bool = False,
softcap: float = None,
use_alibi_slopes: bool = None,
) -> None:
"""Initialize TransformerProcessor.

@@ -91,6 +94,10 @@ def __init__(
Activation function, by default "GELU"
dropout_p: float
Dropout probability used for multi-head self attention, default 0.0
softcap : float, optional
Anything > 0 activates softcapping flash attention, by default None
Reviewer comment (Member): What does "Anything > 0" mean here? Please adjust this explanation across docstrings to be more informative to someone that hasn't worked with the attention implementation yet.

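For context on that question: as I understand the flash-attn >= 2.6 option, softcapping squashes the raw attention scores through a scaled tanh before the softmax, so any value > 0 both enables the feature and bounds the scores to (-softcap, softcap). A hedged illustration (not part of the diff):

    import torch

    def softcapped(scores: torch.Tensor, softcap: float) -> torch.Tensor:
        # Gemma-2-style capping, as used by flash-attn's softcap argument.
        return softcap * torch.tanh(scores / softcap)

    logits = torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0])
    print(softcapped(logits, softcap=50.0))  # ~[-48.2, -1.0, 0.0, 1.0, 48.2]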
use_alibi_slopes : bool, optional
Use aLiBI option, only used for flash attention, by default None
"""
super().__init__(num_channels=num_channels, num_layers=num_layers)

@@ -102,6 +109,9 @@ def __init__(
activation=activation,
window_size=window_size,
dropout_p=dropout_p,
use_flash_attention=use_flash_attention,
softcap=softcap,
use_alibi_slopes=use_alibi_slopes,
)

def forward(
10 changes: 10 additions & 0 deletions src/anemoi/models/layers/processor.py
@@ -96,6 +96,9 @@ def __init__(
num_heads: int = 16,
mlp_hidden_ratio: int = 4,
dropout_p: float = 0.1,
use_flash_attention: bool = False,
softcap: float = 0.0,
use_alibi_slopes: bool = None,
**kwargs,
) -> None:
"""Initialize TransformerProcessor.
@@ -116,6 +119,10 @@ def __init__(
Activation function, by default "GELU"
dropout_p: float, optional
Dropout probability used for multi-head self attention, default 0.0
softcap : float, optional
Anything > 0 activates softcapping flash attention, by default None
use_alibi_slopes : bool, optional
Use aLiBI option, only used for flash attention, by default None
"""
super().__init__(
num_channels=num_channels,
@@ -137,6 +144,9 @@ def __init__(
window_size=window_size,
activation=activation,
dropout_p=dropout_p,
use_flash_attention=use_flash_attention,
softcap=softcap,
use_alibi_slopes=use_alibi_slopes,
)

self.offload_layers(cpu_offload)
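A hedged sketch of how the new flags could be passed when constructing the processor directly; the argument names come from the constructor above and the test fixture below, while the values are illustrative only:

    from anemoi.models.layers.processor import TransformerProcessor

    processor = TransformerProcessor(
        num_layers=2,
        window_size=10,
        num_channels=128,
        num_chunks=2,
        activation="GELU",
        cpu_offload=False,
        num_heads=16,
        mlp_hidden_ratio=4,
        dropout_p=0.1,
        use_flash_attention=False,  # keep the PyTorch SDPA fallback
        softcap=0.0,                # per the docstring, only values > 0 enable softcapping
        use_alibi_slopes=False,
    )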
12 changes: 7 additions & 5 deletions tests/layers/block/test_block_transformer.py
@@ -30,12 +30,13 @@ class TestTransformerProcessorBlock:
activation=st.sampled_from(["ReLU", "GELU", "Tanh"]),
window_size=st.integers(min_value=1, max_value=512),
dropout_p=st.floats(min_value=0.0, max_value=1.0),
softcap=st.floats(min_value=0.0, max_value=1.0),
)
@settings(max_examples=10)
def test_init(self, factor_attention_heads, hidden_dim, num_heads, activation, window_size, dropout_p):
def test_init(self, factor_attention_heads, hidden_dim, num_heads, activation, window_size, dropout_p, softcap):
num_channels = num_heads * factor_attention_heads
block = TransformerProcessorBlock(
num_channels, hidden_dim, num_heads, activation, window_size, dropout_p=dropout_p
num_channels, hidden_dim, num_heads, activation, window_size, dropout_p=dropout_p, softcap=softcap
)
assert isinstance(block, TransformerProcessorBlock)

@@ -53,6 +54,7 @@ def test_init(self, factor_attention_heads, hidden_dim, num_heads, activation, w
shapes=st.lists(st.integers(min_value=1, max_value=10), min_size=3, max_size=3),
batch_size=st.integers(min_value=1, max_value=40),
dropout_p=st.floats(min_value=0.0, max_value=1.0),
softcap=st.floats(min_value=0.0, max_value=1.0),
)
@settings(max_examples=10)
def test_forward_output(
@@ -65,14 +67,14 @@ def test_forward_output(
shapes,
batch_size,
dropout_p,
softcap,
):
num_channels = num_heads * factor_attention_heads
block = TransformerProcessorBlock(
num_channels, hidden_dim, num_heads, activation, window_size, dropout_p=dropout_p
num_channels, hidden_dim, num_heads, activation, window_size, dropout_p=dropout_p, softcap=softcap
)

x = torch.randn((batch_size, num_channels))

x = torch.randn((batch_size, num_channels)) # .to(torch.float16, non_blocking=True)
Reviewer comment (Member): nit: What is the comment for?

output = block.forward(x, shapes, batch_size)
assert isinstance(output, torch.Tensor)
assert output.shape == (batch_size, num_channels)
6 changes: 6 additions & 0 deletions tests/layers/processor/test_transformer_processor.py
@@ -22,6 +22,7 @@ def transformer_processor_init():
num_heads = 16
mlp_hidden_ratio = 4
dropout_p = 0.1
softcap = 0.5
return (
num_layers,
window_size,
@@ -32,6 +33,7 @@
num_heads,
mlp_hidden_ratio,
dropout_p,
softcap,
)


@@ -47,6 +49,7 @@ def transformer_processor(transformer_processor_init):
num_heads,
mlp_hidden_ratio,
dropout_p,
softcap,
) = transformer_processor_init
return TransformerProcessor(
num_layers=num_layers,
@@ -58,6 +61,7 @@
num_heads=num_heads,
mlp_hidden_ratio=mlp_hidden_ratio,
dropout_p=dropout_p,
softcap=softcap,
)


@@ -72,6 +76,7 @@ def test_transformer_processor_init(transformer_processor, transformer_processor
_num_heads,
_mlp_hidden_ratio,
_dropout_p,
_softcap,
) = transformer_processor_init
assert isinstance(transformer_processor, TransformerProcessor)
assert transformer_processor.num_chunks == num_chunks
@@ -90,6 +95,7 @@ def test_transformer_processor_forward(transformer_processor, transformer_proces
_num_heads,
_mlp_hidden_ratio,
_dropout_p,
_softcap,
) = transformer_processor_init
gridsize = 100
batch_size = 1