Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Flexible normalization layers #95

Open
wants to merge 12 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ classifiers = [

dynamic = [ "version" ]
dependencies = [
# Fix certain dependencies during development
"anemoi-utils>=0.1.9",
"einops>=0.6.1",
"hydra-core>=1.3",
Expand Down
8 changes: 6 additions & 2 deletions src/anemoi/models/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
else:
_FLASH_ATTENTION_AVAILABLE = True

from anemoi.utils.config import DotDict

from anemoi.models.distributed.transformer import shard_heads
from anemoi.models.distributed.transformer import shard_sequence

Expand All @@ -38,6 +40,7 @@ def __init__(
self,
num_heads: int,
embed_dim: int,
layer_kernels: DotDict,
bias: bool = False,
is_causal: bool = False,
window_size: Optional[int] = None,
Expand All @@ -56,13 +59,14 @@ def __init__(
self.dropout_p = dropout_p
self.is_causal = is_causal

self.lin_qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
linear = layer_kernels["Linear"]
self.lin_qkv = linear(embed_dim, 3 * embed_dim, bias=bias)
self.attention = attn_func

if not _FLASH_ATTENTION_AVAILABLE:
LOGGER.warning("Flash attention not available, falling back to pytorch scaled_dot_product_attention")

self.projection = nn.Linear(embed_dim, embed_dim, bias=True)
self.projection = linear(embed_dim, embed_dim, bias=True)

def forward(
self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None
Expand Down
82 changes: 57 additions & 25 deletions src/anemoi/models/layers/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import einops
import torch
from anemoi.utils.config import DotDict
from torch import Tensor
from torch import nn
from torch.distributed.distributed_c10d import ProcessGroup
Expand Down Expand Up @@ -68,6 +69,7 @@ def __init__(
num_heads: int,
activation: str,
window_size: int,
layer_kernels: DotDict,
dropout_p: float = 0.0,
):
super().__init__()
Expand All @@ -78,7 +80,8 @@ def __init__(
LOGGER.error("Activation function %s not supported", activation)
raise RuntimeError from ae

self.layer_norm1 = nn.LayerNorm(num_channels)
self.layer_norm_attention = layer_kernels["LayerNorm"](normalized_shape=num_channels)
self.layer_norm_mlp = layer_kernels["LayerNorm"](normalized_shape=num_channels)

self.attention = MultiHeadSelfAttention(
num_heads=num_heads,
Expand All @@ -87,21 +90,21 @@ def __init__(
bias=False,
is_causal=False,
dropout_p=dropout_p,
layer_kernels=layer_kernels,
)

self.mlp = nn.Sequential(
nn.Linear(num_channels, hidden_dim),
layer_kernels["Linear"](num_channels, hidden_dim),
act_func(),
nn.Linear(hidden_dim, num_channels),
layer_kernels["Linear"](hidden_dim, num_channels),
)
self.layer_norm2 = nn.LayerNorm(num_channels)

def forward(
self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None
) -> Tensor:
# Need to be out of place for gradient propagation
x = x + self.attention(self.layer_norm1(x), shapes, batch_size, model_comm_group=model_comm_group)
x = x + self.mlp(self.layer_norm2(x))
x = x + self.attention(self.layer_norm_attention(x), shapes, batch_size, model_comm_group=model_comm_group)
x = x + self.mlp(self.layer_norm_mlp(x))
return x


Expand All @@ -112,6 +115,7 @@ def __init__(
self,
in_channels: int,
out_channels: int,
layer_kernels: DotDict,
mlp_extra_layers: int = 0,
activation: str = "SiLU",
update_src_nodes: bool = True,
Expand All @@ -126,6 +130,9 @@ def __init__(
Number of input channels.
out_channels : int
Number of output channels.
layer_kernels : DotDict
A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear"
Defined in config/models/<model>.yaml
mlp_extra_layers : int, optional
Extra layers in MLP, by default 0
activation : str, optional
Expand All @@ -144,13 +151,15 @@ def __init__(
2 * in_channels,
out_channels,
out_channels,
layer_kernels,
n_extra_layers=mlp_extra_layers,
activation=activation,
)

self.conv = GraphConv(
in_channels=in_channels,
out_channels=out_channels,
layer_kernels=layer_kernels,
mlp_extra_layers=mlp_extra_layers,
activation=activation,
)
Expand All @@ -173,6 +182,7 @@ def __ini__(
self,
in_channels: int,
out_channels: int,
layer_kernels: DotDict,
mlp_extra_layers: int = 0,
activation: str = "SiLU",
update_src_nodes: bool = True,
Expand All @@ -183,6 +193,7 @@ def __ini__(
self,
in_channels=in_channels,
out_channels=out_channels,
layer_kernels=layer_kernels,
mlp_extra_layers=mlp_extra_layers,
activation=activation,
update_src_nodes=update_src_nodes,
Expand Down Expand Up @@ -229,6 +240,7 @@ def __ini__(
self,
in_channels: int,
out_channels: int,
layer_kernels: DotDict,
mlp_extra_layers: int = 0,
activation: str = "SiLU",
update_src_nodes: bool = True,
Expand All @@ -239,6 +251,7 @@ def __ini__(
self,
in_channels=in_channels,
out_channels=out_channels,
layer_kernels=layer_kernels,
mlp_extra_layers=mlp_extra_layers,
activation=activation,
update_src_nodes=update_src_nodes,
Expand Down Expand Up @@ -295,6 +308,7 @@ def __init__(
hidden_dim: int,
out_channels: int,
edge_dim: int,
layer_kernels: DotDict,
num_heads: int = 16,
bias: bool = True,
activation: str = "GELU",
Expand All @@ -312,6 +326,9 @@ def __init__(
Number of output channels.
edge_dim : int,
Edge dimension
layer_kernels : DotDict
A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear"
Defined in config/models/<model>.yaml
num_heads : int,
Number of heads
bias : bool, by default True,
Expand All @@ -330,37 +347,41 @@ def __init__(

self.num_chunks = num_chunks

self.lin_key = nn.Linear(in_channels, num_heads * self.out_channels_conv)
self.lin_query = nn.Linear(in_channels, num_heads * self.out_channels_conv)
self.lin_value = nn.Linear(in_channels, num_heads * self.out_channels_conv)
self.lin_self = nn.Linear(in_channels, num_heads * self.out_channels_conv, bias=bias)
self.lin_edge = nn.Linear(edge_dim, num_heads * self.out_channels_conv) # , bias=False)
linear = layer_kernels["Linear"]
layerNorm = layer_kernels["LayerNorm"]
self.lin_key = linear(in_channels, num_heads * self.out_channels_conv)
self.lin_query = linear(in_channels, num_heads * self.out_channels_conv)
self.lin_value = linear(in_channels, num_heads * self.out_channels_conv)
self.lin_self = linear(in_channels, num_heads * self.out_channels_conv, bias=bias)
self.lin_edge = linear(edge_dim, num_heads * self.out_channels_conv) # , bias=False)

self.conv = GraphTransformerConv(out_channels=self.out_channels_conv)

self.projection = nn.Linear(out_channels, out_channels)
# Why does the GraphTransformer not have a layer_norm_mlp like the Transformer?
self.projection = linear(out_channels, out_channels)

try:
act_func = getattr(nn, activation)
except AttributeError as ae:
LOGGER.error("Activation function %s not supported", activation)
raise RuntimeError from ae

self.layer_norm_attention = layerNorm(normalized_shape=in_channels)
self.layer_norm_mlp = layerNorm(normalized_shape=out_channels)

self.node_dst_mlp = nn.Sequential(
nn.LayerNorm(out_channels),
nn.Linear(out_channels, hidden_dim),
self.layer_norm_mlp,
linear(out_channels, hidden_dim),
act_func(),
nn.Linear(hidden_dim, out_channels),
linear(hidden_dim, out_channels),
)

self.layer_norm1 = nn.LayerNorm(in_channels)

if self.update_src_nodes:
self.node_src_mlp = nn.Sequential(
nn.LayerNorm(out_channels),
nn.Linear(out_channels, hidden_dim),
self.layer_norm_mlp,
linear(out_channels, hidden_dim),
act_func(),
nn.Linear(hidden_dim, out_channels),
linear(hidden_dim, out_channels),
)

def shard_qkve_heads(
Expand Down Expand Up @@ -435,6 +456,7 @@ def __init__(
hidden_dim: int,
out_channels: int,
edge_dim: int,
layer_kernels: DotDict,
num_heads: int = 16,
bias: bool = True,
activation: str = "GELU",
Expand All @@ -452,6 +474,9 @@ def __init__(
Number of output channels.
edge_dim : int,
Edge dimension
layer_kernels : DotDict
A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear"
Defined in config/models/<model>.yaml
num_heads : int,
Number of heads
bias : bool, by default True,
Expand All @@ -466,6 +491,7 @@ def __init__(
hidden_dim=hidden_dim,
out_channels=out_channels,
edge_dim=edge_dim,
layer_kernels=layer_kernels,
num_heads=num_heads,
bias=bias,
activation=activation,
Expand All @@ -474,7 +500,7 @@ def __init__(
**kwargs,
)

self.layer_norm2 = nn.LayerNorm(in_channels)
self.layer_norm_attention_2 = layer_kernels["LayerNorm"](normalized_shape=in_channels)

def forward(
self,
Expand All @@ -489,9 +515,10 @@ def forward(
x_skip = x

x = (
self.layer_norm1(x[0]),
self.layer_norm2(x[1]),
) # Why does this use layer_norm2? And only is a mapper thing?
self.layer_norm_attention(x[0]),
self.layer_norm_attention_2(x[1]),
) # Why does this use layer_norm_attention_2? And only is a mapper thing?

x_r = self.lin_self(x[1])
query = self.lin_query(x[1])
key = self.lin_key(x[0])
Expand Down Expand Up @@ -559,6 +586,7 @@ def __init__(
hidden_dim: int,
out_channels: int,
edge_dim: int,
layer_kernels: DotDict,
num_heads: int = 16,
bias: bool = True,
activation: str = "GELU",
Expand All @@ -576,6 +604,9 @@ def __init__(
Number of output channels.
edge_dim : int,
Edge dimension
layer_kernels : DotDict
A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear"
Defined in config/models/<model>.yaml
num_heads : int,
Number of heads
bias : bool, by default True,
Expand All @@ -591,6 +622,7 @@ def __init__(
hidden_dim=hidden_dim,
out_channels=out_channels,
edge_dim=edge_dim,
layer_kernels=layer_kernels,
num_heads=num_heads,
bias=bias,
activation=activation,
Expand All @@ -611,7 +643,7 @@ def forward(
):
x_skip = x

x = self.layer_norm1(x)
x = self.layer_norm_attention(x)
x_r = self.lin_self(x)
query = self.lin_query(x)
key = self.lin_key(x)
Expand Down
Loading
Loading