Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Switch layer kernel implementation in the config #35

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/anemoi/models/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
else:
_FLASH_ATTENTION_AVAILABLE = True

from anemoi.utils.config import DotDict

from anemoi.models.distributed.transformer import shard_heads
from anemoi.models.distributed.transformer import shard_sequence

Expand All @@ -41,6 +43,7 @@ def __init__(
is_causal: bool = False,
window_size: Optional[int] = None,
dropout: float = 0.0,
layer_kernels: DotDict = None,
):
super().__init__()

Expand All @@ -54,14 +57,16 @@ def __init__(
self.head_dim = embed_dim // num_heads # q k v
self.window_size = (window_size, window_size) # flash attention
self.is_causal = is_causal
Linear = layer_kernels.get("Linear", nn.Linear)
LayerNorm = layer_kernels.get("LayerNorm", nn.LayerNorm)

self.lin_qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
self.lin_qkv = Linear(embed_dim, 3 * embed_dim, bias=bias)
self.attention = attn_func

if not _FLASH_ATTENTION_AVAILABLE:
LOGGER.warning("Flash attention not available, falling back to pytorch scaled_dot_product_attention")

self.projection = nn.Linear(embed_dim, embed_dim, bias=True)
self.projection = Linear(embed_dim, embed_dim, bias=True)

def forward(
self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None
Expand Down
70 changes: 51 additions & 19 deletions src/anemoi/models/layers/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import einops
import torch
from anemoi.utils.config import DotDict
from torch import Tensor
from torch import nn
from torch.distributed.distributed_c10d import ProcessGroup
Expand Down Expand Up @@ -55,7 +56,7 @@ def forward(
class TransformerProcessorBlock(BaseBlock):
"""Transformer block with MultiHeadSelfAttention and MLPs."""

def __init__(self, num_channels, hidden_dim, num_heads, activation, window_size: int):
def __init__(self, num_channels, layer_kernels, hidden_dim, num_heads, activation, window_size: int):
super().__init__()

try:
Expand All @@ -64,7 +65,12 @@ def __init__(self, num_channels, hidden_dim, num_heads, activation, window_size:
LOGGER.error("Activation function %s not supported", activation)
raise RuntimeError from ae

self.layer_norm1 = nn.LayerNorm(num_channels)
# Uses the implementation defined in config.model.layer_kernels.<kernel>
# (unless it is not availible, in which case it will fall back to torch.nn.<kernel>)
Linear = layer_kernels.get("Linear", torch.nn.Linear)
LayerNorm = layer_kernels.get("LayerNorm", torch.nn.LayerNorm)

self.layer_norm1 = LayerNorm(num_channels)

self.attention = MultiHeadSelfAttention(
num_heads=num_heads,
Expand All @@ -73,14 +79,15 @@ def __init__(self, num_channels, hidden_dim, num_heads, activation, window_size:
bias=False,
is_causal=False,
dropout=0.0,
layer_kernels=layer_kernels,
)

self.mlp = nn.Sequential(
nn.Linear(num_channels, hidden_dim),
Linear(num_channels, hidden_dim),
act_func(),
nn.Linear(hidden_dim, num_channels),
Linear(hidden_dim, num_channels),
)
self.layer_norm2 = nn.LayerNorm(num_channels)
self.layer_norm2 = LayerNorm(num_channels)

def forward(
self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None
Expand All @@ -98,6 +105,7 @@ def __init__(
self,
in_channels: int,
out_channels: int,
layer_kernels: any,
mlp_extra_layers: int = 0,
activation: str = "SiLU",
update_src_nodes: bool = True,
Expand Down Expand Up @@ -132,13 +140,15 @@ def __init__(
out_channels,
n_extra_layers=mlp_extra_layers,
activation=activation,
layer_kernels=layer_kernels,
)

self.conv = GraphConv(
in_channels=in_channels,
out_channels=out_channels,
mlp_extra_layers=mlp_extra_layers,
activation=activation,
layer_kernels=layer_kernels,
)

@abstractmethod
Expand Down Expand Up @@ -215,6 +225,7 @@ def __ini__(
self,
in_channels: int,
out_channels: int,
layer_kernels: any,
mlp_extra_layers: int = 0,
activation: str = "SiLU",
update_src_nodes: bool = True,
Expand All @@ -225,6 +236,7 @@ def __ini__(
self,
in_channels=in_channels,
out_channels=out_channels,
layer_kernels=layer_kernels,
mlp_extra_layers=mlp_extra_layers,
activation=activation,
update_src_nodes=update_src_nodes,
Expand Down Expand Up @@ -281,6 +293,7 @@ def __init__(
hidden_dim: int,
out_channels: int,
edge_dim: int,
layer_kernels: DotDict,
num_heads: int = 16,
bias: bool = True,
activation: str = "GELU",
Expand All @@ -298,6 +311,8 @@ def __init__(
Number of output channels.
edge_dim : int,
Edge dimension
layer_kernels : DotDict,
A dict of layer implementations e.g. layer_kernels['Linear'] = "Module.submodule.Linear". Defined in config/models/<model>.yaml
num_heads : int,
Number of heads
bias : bool, by default True,
Expand All @@ -316,15 +331,20 @@ def __init__(

self.num_chunks = num_chunks

self.lin_key = nn.Linear(in_channels, num_heads * self.out_channels_conv)
self.lin_query = nn.Linear(in_channels, num_heads * self.out_channels_conv)
self.lin_value = nn.Linear(in_channels, num_heads * self.out_channels_conv)
self.lin_self = nn.Linear(in_channels, num_heads * self.out_channels_conv, bias=bias)
self.lin_edge = nn.Linear(edge_dim, num_heads * self.out_channels_conv) # , bias=False)
# Uses the implementation defined in config.model.layer_kernels.<kernel>
# (unless it is not availible, in which case it will fall back to torch.nn.<kernel>)
Linear = layer_kernels.get("Linear", torch.nn.Linear)
LayerNorm = layer_kernels.get("LayerNorm", torch.nn.LayerNorm)

self.lin_key = Linear(in_channels, num_heads * self.out_channels_conv)
self.lin_query = Linear(in_channels, num_heads * self.out_channels_conv)
self.lin_value = Linear(in_channels, num_heads * self.out_channels_conv)
self.lin_self = Linear(in_channels, num_heads * self.out_channels_conv, bias=bias)
self.lin_edge = Linear(edge_dim, num_heads * self.out_channels_conv) # , bias=False)

self.conv = GraphTransformerConv(out_channels=self.out_channels_conv)

self.projection = nn.Linear(out_channels, out_channels)
self.projection = Linear(out_channels, out_channels)

try:
act_func = getattr(nn, activation)
Expand All @@ -333,20 +353,20 @@ def __init__(
raise RuntimeError from ae

self.node_dst_mlp = nn.Sequential(
nn.LayerNorm(out_channels),
nn.Linear(out_channels, hidden_dim),
LayerNorm(out_channels),
Linear(out_channels, hidden_dim),
act_func(),
nn.Linear(hidden_dim, out_channels),
Linear(hidden_dim, out_channels),
)

self.layer_norm1 = nn.LayerNorm(in_channels)
self.layer_norm1 = LayerNorm(in_channels)

if self.update_src_nodes:
self.node_src_mlp = nn.Sequential(
nn.LayerNorm(out_channels),
nn.Linear(out_channels, hidden_dim),
LayerNorm(out_channels),
Linear(out_channels, hidden_dim),
act_func(),
nn.Linear(hidden_dim, out_channels),
Linear(hidden_dim, out_channels),
)

def shard_qkve_heads(
Expand Down Expand Up @@ -421,6 +441,7 @@ def __init__(
hidden_dim: int,
out_channels: int,
edge_dim: int,
layer_kernels: any,
num_heads: int = 16,
bias: bool = True,
activation: str = "GELU",
Expand All @@ -438,6 +459,8 @@ def __init__(
Number of output channels.
edge_dim : int,
Edge dimension
layer_kernels : any,
A dict of layer implementations e.g. layer_kernels['Linear'] = "Module.submodule.Linear". Defined in config/models/<model>.yaml
num_heads : int,
Number of heads
bias : bool, by default True,
Expand All @@ -457,10 +480,15 @@ def __init__(
activation=activation,
num_chunks=num_chunks,
update_src_nodes=update_src_nodes,
layer_kernels=layer_kernels,
**kwargs,
)

self.layer_norm2 = nn.LayerNorm(in_channels)
# Uses the implementation defined in config.model.layer_kernels.<kernel>
# (unless it is not availible, in which case it will fall back to torch.nn.<kernel>)
LayerNorm = layer_kernels.get("LayerNorm", torch.nn.LayerNorm)

self.layer_norm2 = LayerNorm(in_channels)

def forward(
self,
Expand Down Expand Up @@ -513,6 +541,7 @@ def __init__(
hidden_dim: int,
out_channels: int,
edge_dim: int,
layer_kernels: any,
num_heads: int = 16,
bias: bool = True,
activation: str = "GELU",
Expand All @@ -530,6 +559,8 @@ def __init__(
Number of output channels.
edge_dim : int,
Edge dimension
layer_kernels : any,
A dict of layer implementations e.g. layer_kernels['Linear'] = "Module.submodule.Linear". Defined in config/models/<model>.yaml
num_heads : int,
Number of heads
bias : bool, by default True,
Expand All @@ -545,6 +576,7 @@ def __init__(
hidden_dim=hidden_dim,
out_channels=out_channels,
edge_dim=edge_dim,
layer_kernels=layer_kernels,
num_heads=num_heads,
bias=bias,
activation=activation,
Expand Down
5 changes: 5 additions & 0 deletions src/anemoi/models/layers/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from abc import abstractmethod
from typing import Optional

from anemoi.utils.config import DotDict
from torch import Tensor
from torch import nn
from torch.distributed.distributed_c10d import ProcessGroup
Expand Down Expand Up @@ -70,6 +71,7 @@ def __init__(
num_channels: int,
num_layers: int,
window_size: int,
layer_kernels: DotDict,
num_heads: int = 16,
mlp_hidden_ratio: int = 4,
activation: str = "GELU",
Expand All @@ -82,6 +84,8 @@ def __init__(
Number of channels
num_layers : int
Number of layers
layer_kernels : DotDict,
A dict of layer implementations e.g. layer_kernels['Linear'] = "Module.submodule.Linear". Defined in config/models/<model>.yaml
num_heads: int
Number of heads to use, default 16
mlp_hidden_ratio: int
Expand All @@ -94,6 +98,7 @@ def __init__(
self.build_blocks(
TransformerProcessorBlock,
num_channels=num_channels,
layer_kernels=layer_kernels,
hidden_dim=(mlp_hidden_ratio * num_channels),
num_heads=num_heads,
activation=activation,
Expand Down
5 changes: 5 additions & 0 deletions src/anemoi/models/layers/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Optional

import torch
from anemoi.utils.config import DotDict
from torch import Tensor
from torch.nn.functional import dropout
from torch_geometric.nn.conv import MessagePassing
Expand All @@ -30,6 +31,7 @@ def __init__(
self,
in_channels: int,
out_channels: int,
layer_kernels: DotDict,
mlp_extra_layers: int = 0,
activation: str = "SiLU",
**kwargs,
Expand All @@ -42,6 +44,8 @@ def __init__(
Number of input channels.
out_channels : int
Number of output channels.
layer_kernels : DotDict,
A dict of layer implementations e.g. layer_kernels['Linear'] = "Module.submodule.Linear". Defined in config/models/<model>.yaml
mlp_extra_layers : int, optional
Extra layers in MLP, by default 0
activation : str, optional
Expand All @@ -55,6 +59,7 @@ def __init__(
out_channels,
n_extra_layers=mlp_extra_layers,
activation=activation,
layer_kernels=layer_kernels,
)

def forward(self, x: OptPairTensor, edge_attr: Tensor, edge_index: Adj, size: Optional[Size] = None):
Expand Down
Loading