This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Feature/44 make flash attention configurable #47

Open · wants to merge 73 commits into develop from feature/44-make-flash-attention-configurable

Commits (73)
539e8a2
feat: FlashMultiHeadSelfAttention
theissenhelen Sep 17, 2024
3317138
Chore/multiple fixes ci precommit (#41)
theissenhelen Sep 18, 2024
3186a8e
11 add configurability to dropout in multiheadselfattention module (#12)
theissenhelen Sep 18, 2024
a86c9a8
chore!: drop support for scaled_dot_product_attention
theissenhelen Sep 20, 2024
105443f
feat: add softcap
theissenhelen Sep 20, 2024
e82a59e
test: add softcap
theissenhelen Sep 20, 2024
e648eb0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 20, 2024
6271cd8
feat: flash attention lazy import
theissenhelen Sep 23, 2024
d4940e7
feat: make alibi slopes configurable
theissenhelen Sep 27, 2024
9ff6cb9
chore(deps): add flash-attn
theissenhelen Sep 27, 2024
bbd89dc
feat: use scaled_dot_product as default
theissenhelen Oct 1, 2024
91533c6
feat: make alibi_slope cinfigurable in block, chunk processor
theissenhelen Oct 1, 2024
0eb5c50
chore(deps): remove flash-attn
theissenhelen Oct 1, 2024
c04e641
feat: get alibi_slopes
theissenhelen Oct 2, 2024
6523b47
docs: update docstrings
theissenhelen Oct 3, 2024
22623cc
fix: bias shape
theissenhelen Oct 3, 2024
ed07e34
fix: softcap optional
theissenhelen Oct 3, 2024
c841324
fix: import annotations from future
theissenhelen Oct 3, 2024
6c12dda
fix: annotation error
theissenhelen Oct 3, 2024
b7b8f2e
docs: update changelog
theissenhelen Oct 3, 2024
df353d9
fix: type annotation
theissenhelen Oct 7, 2024
fc335c7
feat: catch low flash-attn version
theissenhelen Oct 7, 2024
663fea0
feat: FlashMultiHeadSelfAttention
theissenhelen Sep 17, 2024
a8b3f9d
Chore/multiple fixes ci precommit (#41)
theissenhelen Sep 18, 2024
6595ca1
11 add configurability to dropout in multiheadselfattention module (#12)
theissenhelen Sep 18, 2024
0c55a9c
chore!: drop support for scaled_dot_product_attention
theissenhelen Sep 20, 2024
ea665be
feat: add softcap
theissenhelen Sep 20, 2024
ffa2d99
test: add softcap
theissenhelen Sep 20, 2024
7c2d634
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 20, 2024
d2ed932
feat: flash attention lazy import
theissenhelen Sep 23, 2024
3295159
feat: make alibi slopes configurable
theissenhelen Sep 27, 2024
ebde686
chore(deps): add flash-attn
theissenhelen Sep 27, 2024
5102d9a
feat: use scaled_dot_product as default
theissenhelen Oct 1, 2024
3abc286
feat: make alibi_slope cinfigurable in block, chunk processor
theissenhelen Oct 1, 2024
673a25d
chore(deps): remove flash-attn
theissenhelen Oct 1, 2024
f606058
feat: get alibi_slopes
theissenhelen Oct 2, 2024
ef34771
docs: update docstrings
theissenhelen Oct 3, 2024
5136fb3
fix: bias shape
theissenhelen Oct 3, 2024
892c269
fix: softcap optional
theissenhelen Oct 3, 2024
4c42171
fix: import annotations from future
theissenhelen Oct 3, 2024
4bdf464
fix: annotation error
theissenhelen Oct 3, 2024
5a670b2
docs: update changelog
theissenhelen Oct 3, 2024
34db6e4
fix: type annotation
theissenhelen Oct 7, 2024
d424c75
feat: catch low flash-attn version
theissenhelen Oct 7, 2024
222b7d8
feat: attention wrapper
theissenhelen Oct 25, 2024
c2aca14
fix: remove duplicate version check
theissenhelen Oct 25, 2024
b75d225
merge conflict
cathalobrien Nov 1, 2024
147e772
added flex attn wrapper
cathalobrien Nov 1, 2024
f0c24e8
fix: alibi_slopes unassigned
theissenhelen Nov 6, 2024
3c4572b
adding causal wip
cathalobrien Nov 6, 2024
fb731f7
Merge branch 'feature/44-make-flash-attention-configurable' of github…
cathalobrien Nov 8, 2024
f0308f2
added flex attn module
cathalobrien Nov 12, 2024
6dee265
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 12, 2024
7fb0b62
Bump min torch version to be able to use Flex Attn
cathalobrien Nov 12, 2024
739aa65
added input parameter checks
cathalobrien Nov 12, 2024
2a2ed11
precommit fix
cathalobrien Nov 12, 2024
fa1474c
merge
cathalobrien Nov 12, 2024
a703688
fix: typo
theissenhelen Nov 26, 2024
f1be563
test: adjust tests
theissenhelen Nov 27, 2024
0dda5d6
fix: no self.use_alibi_slopes
theissenhelen Nov 27, 2024
12facf0
fix: use_alibi_slope default to false
theissenhelen Nov 28, 2024
60e32f1
feat: Add sliding window support for TorchAttention via mask
japols Dec 9, 2024
07d9684
fix: set default flash_attention
japols Dec 10, 2024
9a1827a
fix: pytest
japols Dec 10, 2024
ca8c9fa
fix: tests
japols Dec 13, 2024
ac897ea
Merge branch 'feature/44-make-flash-attention-configurable' of github…
cathalobrien Dec 16, 2024
7ec8142
docs: improve docstrings in MultiHeadSelfAttention
theissenhelen Dec 18, 2024
972d3c5
fix: error instead of SystemExit
theissenhelen Dec 18, 2024
e89fd2e
chore: refactor SDPAAttention update_mask method
theissenhelen Dec 18, 2024
2d122df
feat: add missing pytest.ini
theissenhelen Dec 18, 2024
d4510f6
chore: remove explicit float typing
theissenhelen Dec 19, 2024
6057004
Merge branch 'feature/44-make-flash-attention-configurable' of github…
cathalobrien Dec 19, 2024
8656cae
support running without window size
cathalobrien Dec 19, 2024
9 changes: 9 additions & 0 deletions CHANGELOG.md
@@ -39,15 +39,24 @@ Keep it human-readable, your future self will thank you!

### Added

- CI workflow to update the changelog on release
- Add configurability of flash attention (#47)
- Configurability of the dropout probability in the MultiHeadSelfAttention module
- CI workflow to update the changelog on release
- Remapper: Preprocessor for remapping one variable to multiple ones. Includes changes to the data indices since the remapper changes the number of variables. With optional config keywords.
- Codeowners file
- Pygrep precommit hooks
- Docsig precommit hooks
- Changelog merge strategy


### Changed

- Update CI to inherit from common infrastructure reusable workflows
- run downstream-ci only when src and tests folders have changed
- New error messages for wrong graphs.
- Feature: Change model to be instantiatable in the interface, addressing [#28](https://github.com/ecmwf/anemoi-models/issues/28) through [#45](https://github.com/ecmwf/anemoi-models/pulls/45)
- Bugfixes for CI

### Removed

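The headline entry above is the new attention backend configurability. For orientation, here is a minimal, hypothetical sketch of where the option could live in a model config, written as a plain Python dict for illustration (the project uses Hydra for configuration per pyproject.toml; only the model.processor.attention_implementation key path is taken from the error message in attention.py, the sibling keys simply mirror the new layer arguments and are assumptions):

# Hypothetical config fragment, shown as a Python dict for illustration only.
model_config = {
    "model": {
        "processor": {
            "attention_implementation": "flash_attention",  # or "scaled_dot_product_attention", "flex_attention"
            "softcap": None,          # assumed key name, mirroring the layer argument
            "use_alibi_slopes": False,  # assumed key name, mirroring the layer argument
            "window_size": 512,         # assumed key name, mirroring the layer argument
        },
    },
}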
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -45,7 +45,7 @@ dependencies = [
"anemoi-utils>=0.1.9",
"einops>=0.6.1",
"hydra-core>=1.3",
"torch>=2.2",
"torch>=2.5",
"torch-geometric>=2.3,<2.5",
]
optional-dependencies.all = [ ]
296 changes: 271 additions & 25 deletions src/anemoi/models/layers/attention.py
@@ -8,31 +8,27 @@
# nor does it submit to any jurisdiction.


from __future__ import annotations

import logging
import math
from typing import Optional

import einops
import torch
from packaging import version
from torch import Tensor
from torch import nn
from torch.distributed.distributed_c10d import ProcessGroup

from anemoi.models.distributed.transformer import shard_heads
from anemoi.models.distributed.transformer import shard_sequence

LOGGER = logging.getLogger(__name__)


class MultiHeadSelfAttention(nn.Module):
    """Multi Head Self Attention Pytorch Layer using flash attention, see https://github.com/Dao-AILab/flash-attention"""

    def __init__(
        self,

@@ -42,31 +38,85 @@ def __init__(
        is_causal: bool = False,
        window_size: Optional[int] = None,
        dropout_p: float = 0.0,
        attention_implementation: str = "flash_attention",
        softcap: float = None,
        use_alibi_slopes: bool = False,
    ):
        """Initialize MultiHeadSelfAttention.

        Parameters
        ----------
        num_heads : int
            number of heads
        embed_dim : int
            embedding dimension
        bias : bool, optional
            bias, by default False
        is_causal : bool, optional
            apply causal attention mask, by default False
        window_size : Optional[int], optional
            attention window size, by default None
        dropout_p : float, optional
            dropout probability, by default 0.0
        softcap : float, optional
            Anything > 0 activates softcapping attention, by default None
        attention_implementation : str, optional
            A predefined string which selects the underlying attention
            implementation, by default "flash_attention"
        use_alibi_slopes : bool, optional
            Adds bias of (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            to the attention score of query i and key j, where alibi_slope
            is calculated using get_alibi_slopes, by default False
        """
        super().__init__()

        assert (
            embed_dim % num_heads == 0
        ), f"Embedding dimension ({embed_dim}) must be divisible by number of heads ({num_heads})"

        self.attention_implementation = attention_implementation
        self.use_alibi_slopes = use_alibi_slopes
        self.set_attention_function()

        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.head_dim = embed_dim // num_heads  # q k v
        self.window_size = window_size
        self.dropout_p = dropout_p
        self.is_causal = is_causal
        self.softcap = softcap

        if self.use_alibi_slopes:
            self.alibi_slopes = get_alibi_slopes(num_heads)
            assert self.alibi_slopes.shape[0] == num_heads
        else:
            self.alibi_slopes = None

        self.lin_qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)

        self.projection = nn.Linear(embed_dim, embed_dim, bias=True)

    def set_attention_function(self):
        attn_funcs = {
            "flash_attention": FlashAttentionWrapper,
            "flex_attention": FlexAttentionWrapper,
            "scaled_dot_product_attention": TorchAttentionWrapper,
        }
        if self.attention_implementation in attn_funcs:
            LOGGER.info(f"attention.py: using {self.attention_implementation}")
            # initialise the attention function here
            self.attention = attn_funcs[self.attention_implementation]()
        else:
            # requested attention implementation is not supported
            raise SystemExit(
                f"attention.py: Error! {self.attention_implementation} not supported. "
                f"Please change model.processor.attention_implementation in the config to one of: {attn_funcs.keys()}"
            )

    def forward(
        self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None
    ) -> Tensor:

        query, key, value = self.lin_qkv(x).chunk(3, -1)

        if model_comm_group:

@@ -89,24 +139,220 @@ def forward(
            value = shard_heads(value, shapes=shapes, mgroup=model_comm_group)
        dropout_p = self.dropout_p if self.training else 0.0

        out = self.attention(
            query,
            key,
            value,
            batch_size,
            causal=False,
            window_size=self.window_size,
            dropout_p=dropout_p,
            softcap=self.softcap,
            alibi_slopes=self.alibi_slopes,
        )

        out = shard_sequence(out, shapes=shapes, mgroup=model_comm_group)
        out = einops.rearrange(out, "batch heads grid vars -> (batch grid) (heads vars)")

        out = self.projection(out)

        return out

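To make the new interface concrete, here is a small, hypothetical usage sketch of the layer as defined in this diff. The argument values are illustrative, and the import path follows the file location src/anemoi/models/layers/attention.py (assuming the package from this repository is installed). The pure PyTorch backend is selected so the sketch does not require the optional flash-attn dependency:

from anemoi.models.layers.attention import MultiHeadSelfAttention

# Illustrative sizes: 8 heads, a 256-dim embedding, a 512-point sliding window.
layer = MultiHeadSelfAttention(
    num_heads=8,
    embed_dim=256,
    window_size=512,
    dropout_p=0.0,
    attention_implementation="scaled_dot_product_attention",
    use_alibi_slopes=False,
)

# set_attention_function() resolves the string to a wrapper instance;
# an unsupported string would raise SystemExit instead.
print(type(layer.attention).__name__)  # TorchAttentionWrapper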

class TorchAttentionWrapper(nn.Module):
    """Wrapper for PyTorch's scaled dot product attention."""

    def __init__(self):
        super().__init__()

        from torch.nn.functional import scaled_dot_product_attention

        self.attention = scaled_dot_product_attention
        self.mask = None
        self.window_size = None

    def update_mask(self, seq_len, window_size: int, device: str):
        update_mask = (
            self.mask is None or self.window_size != window_size or tuple(self.mask.shape) != (seq_len, seq_len)
        )
        if update_mask:
            self.window_size = window_size
            self.mask = (
                torch.abs(
                    torch.arange(seq_len, device=device).unsqueeze(0)
                    - torch.arange(seq_len, device=device).unsqueeze(1)
                )
                <= window_size
            )

Review comment (on the update_mask condition): nit: why is this saved as variable?

    def forward(
        self,
        query,
        key,
        value,
        batch_size: int,
        causal=False,
        window_size=None,
        dropout_p=0.0,
        softcap=None,
        alibi_slopes=None,
    ):
        if softcap is not None:
            raise SystemError(
                "Error. Softcap not supported by PyTorch's SDPA. Please switch to flash attention or disable softcap."
            )
        if alibi_slopes is not None:
            raise SystemError(
                "Error. Alibi slopes not supported by PyTorch's SDPA. Please switch to flash attention or disable alibi slopes."
            )
        if window_size is not None:
            self.update_mask(query.shape[-2], window_size=window_size, device=query.device)
        else:
            self.mask = None

        with torch.nn.attention.sdpa_kernel(backends=[torch.nn.attention.SDPBackend.MATH]):
            out = self.attention(
                query,
                key,
                value,
                attn_mask=self.mask,
                is_causal=causal,
                dropout_p=dropout_p,
            )

        return out

Review comment (on the window_size check): Ok, so we're already checking if we should update the mask and then in the update_mask we do a bunch of other checks. I think this needs a small refactor to be more tidy.

Review comment (on restricting the SDPA backend): What is the reasoning to limit the backend?

Reply (japols, Dec 19, 2024): @JesperDramsch the idea is to use the MATH backend because it comes with pytorch and doesn't rely on external libraries like the other two options. I would rather explicitly support the latter than switching via SDPA backend.

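As an aside on the SDPA path above: the mask built in update_mask is a simple band matrix in which position i may attend to position j whenever |i - j| <= window_size (with a boolean attn_mask, True means "attend"). A standalone sketch with small, illustrative values:

import torch

seq_len, window_size = 6, 2  # illustrative values

idx = torch.arange(seq_len)
mask = torch.abs(idx.unsqueeze(0) - idx.unsqueeze(1)) <= window_size  # True where |i - j| <= window_size

print(mask.int())
# tensor([[1, 1, 1, 0, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 0],
#         [0, 1, 1, 1, 1, 1],
#         [0, 0, 1, 1, 1, 1],
#         [0, 0, 0, 1, 1, 1]])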
class FlexAttentionWrapper(nn.Module):
    """Wrapper for PyTorch flex attention."""

    def __init__(self):
        super().__init__()

        if version.parse(torch.__version__) < version.parse("2.5.0"):
            raise SystemExit("Error: torch version is too low. Update to 2.5.0 or higher to use Flex Attention.")

        # We compile flex attention once, at the first iteration, because we need
        # to know the sequence length to compute the mask mod for the sliding window.
        self.is_attn_compiled = False

    def forward(
        self,
        query,
        key,
        value,
        batch_size: int,
        causal: bool = False,
        window_size: int = None,
        dropout_p: float = 0.0,
        softcap: float = None,
        alibi_slopes: torch.Tensor = None,
    ):

        if alibi_slopes is not None:
            raise SystemExit("Error. Alibi_slopes not yet implemented in FlexAttn in Anemoi.")
        if softcap is not None:
            raise SystemExit("Error. Softcap not yet implemented in FlexAttn in Anemoi.")
        if dropout_p != 0.0:
            raise SystemExit("Error. Dropout not yet implemented in FlexAttn in Anemoi.")
        if causal:
            raise SystemExit("Error. Causal not yet implemented in FlexAttn in Anemoi.")

        # This assumes seq_len never changes across iterations and stages.
        # Could add something like
        #   if query.shape[2] != prev_seq_len:
        #       self.is_attn_compiled = False
        # to trigger a recompilation.
        if not self.is_attn_compiled:
            import functools

            from torch.nn.attention.flex_attention import create_block_mask  # should this be after the version check?
            from torch.nn.attention.flex_attention import flex_attention

            def sliding_window_mask(b, h, q_idx, kv_idx):
                return abs(q_idx - kv_idx) <= window_size

            seq_len = query.shape[2]
            self.block_mask = create_block_mask(
                sliding_window_mask, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, _compile=True
            )
            self.attention = functools.partial(
                flex_attention, block_mask=self.block_mask
            )  # cache the block mask (recommended in the flex attention blog post)
            self.attention = torch.compile(self.attention)
            self.is_attn_compiled = True

        # TODO: test how this impacts scaling at large model counts
        torch._dynamo.config.optimize_ddp = False
        out = self.attention(query, key, value)
        torch._dynamo.config.optimize_ddp = True

        return out

Review comment (on the TODO): Who is this a TODO for?


class FlashAttentionWrapper(nn.Module):
    """Wrapper for Flash attention."""

    def __init__(self):
        super().__init__()
        import flash_attn

        if version.parse(flash_attn.__version__) < version.parse("2.6.0"):
            raise SystemExit("Error: Flash-attn version is too low. Update to 2.6.0 or higher.")
        else:
            self.attention = flash_attn.flash_attn_func

    def forward(
        self,
        query,
        key,
        value,
        batch_size: int,
        causal: bool = False,
        window_size: int = None,
        dropout_p: float = 0.0,
        softcap: float = None,
        alibi_slopes: torch.Tensor = None,
    ):
        query, key, value = (
            einops.rearrange(t, "batch heads grid vars -> batch grid heads vars") for t in (query, key, value)
        )

        alibi_slopes = alibi_slopes.repeat(batch_size, 1).to(query.device) if alibi_slopes is not None else None

        out = self.attention(
            query,
            key,
            value,
            causal=False,
            window_size=(window_size, window_size),
            dropout_p=dropout_p,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
        )
        out = einops.rearrange(out, "batch grid heads vars -> batch heads grid vars")
        return out


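A note on the layout juggling in FlashAttentionWrapper above: flash_attn's flash_attn_func expects tensors laid out as (batch, seqlen, heads, head_dim), while the rest of this module works in (batch, heads, seqlen, head_dim), hence the two einops.rearrange calls. A small, self-contained sketch of that round trip with illustrative shapes:

import einops
import torch

q = torch.randn(2, 4, 16, 8)  # (batch, heads, grid, vars) as used inside MultiHeadSelfAttention

q_flash = einops.rearrange(q, "batch heads grid vars -> batch grid heads vars")  # layout for flash_attn_func
assert q_flash.shape == (2, 16, 4, 8)

q_back = einops.rearrange(q_flash, "batch grid heads vars -> batch heads grid vars")
assert torch.equal(q, q_back)  # a pure permutation of axes, fully reversible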
def get_alibi_slopes(num_heads: int) -> Tensor:
    """Calculates geometrically decreasing slopes for ALiBi attention.

    Parameters
    ----------
    num_heads : int
        number of attention heads

    Returns
    -------
    Tensor
        ALiBi slopes
    """
    n = 2 ** math.floor(math.log2(num_heads))
    slope_0 = 2.0 ** (-8.0 / n)
    alibi_slopes = torch.pow(slope_0, torch.arange(1, 1 + n))
    if n < num_heads:
        slope_hat_0 = 2.0 ** (-4.0 / n)
        alibi_slopes_hat = torch.pow(slope_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2))
        alibi_slopes = torch.cat([alibi_slopes, alibi_slopes_hat])
    return alibi_slopes

Review comment (on the power-of-two computation): Since num_heads is an integer, we could be using bit-shifting here: n = 1 << (num_heads.bit_length() - 1). Not sure how necessary speed is here though, as a trade-off against readability. It would definitely need a comment.

Reply (theissenhelen, Dec 19, 2024): Speed is not an issue as it is only calculated once. So, I would go for readability.
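For a quick sanity check of get_alibi_slopes, two hand-computed examples (the printed values follow directly from the formula above; the second case exercises the non-power-of-two branch, and the import path assumes the package from this repository is installed):

from anemoi.models.layers.attention import get_alibi_slopes

# Power-of-two head count: a geometric sequence with ratio 2 ** (-8 / num_heads).
print(get_alibi_slopes(8))
# tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])

# Six heads: slopes for the largest power of two (n = 4) come first,
# then the extra "hat" slopes 0.5 ** 1 and 0.5 ** 3 are appended.
print(get_alibi_slopes(6))
# tensor([0.2500, 0.0625, 0.0156, 0.0039, 0.5000, 0.1250])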