diff --git a/pyproject.toml b/pyproject.toml index 6d473472..aacce2f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ classifiers = [ dynamic = [ "version" ] dependencies = [ + # Fix certain dependencies during development "anemoi-utils>=0.1.9", "einops>=0.6.1", "hydra-core>=1.3", diff --git a/src/anemoi/models/layers/attention.py b/src/anemoi/models/layers/attention.py index d7f54920..64cd7b69 100644 --- a/src/anemoi/models/layers/attention.py +++ b/src/anemoi/models/layers/attention.py @@ -25,6 +25,8 @@ else: _FLASH_ATTENTION_AVAILABLE = True +from anemoi.utils.config import DotDict + from anemoi.models.distributed.transformer import shard_heads from anemoi.models.distributed.transformer import shard_sequence @@ -38,6 +40,7 @@ def __init__( self, num_heads: int, embed_dim: int, + layer_kernels: DotDict, bias: bool = False, is_causal: bool = False, window_size: Optional[int] = None, @@ -56,13 +59,14 @@ def __init__( self.dropout_p = dropout_p self.is_causal = is_causal - self.lin_qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias) + linear = layer_kernels["Linear"] + self.lin_qkv = linear(embed_dim, 3 * embed_dim, bias=bias) self.attention = attn_func if not _FLASH_ATTENTION_AVAILABLE: LOGGER.warning("Flash attention not available, falling back to pytorch scaled_dot_product_attention") - self.projection = nn.Linear(embed_dim, embed_dim, bias=True) + self.projection = linear(embed_dim, embed_dim, bias=True) def forward( self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None diff --git a/src/anemoi/models/layers/block.py b/src/anemoi/models/layers/block.py index 72e487d2..e06302f3 100644 --- a/src/anemoi/models/layers/block.py +++ b/src/anemoi/models/layers/block.py @@ -16,6 +16,7 @@ import einops import torch +from anemoi.utils.config import DotDict from torch import Tensor from torch import nn from torch.distributed.distributed_c10d import ProcessGroup @@ -68,6 +69,7 @@ def __init__( num_heads: int, activation: str, window_size: int, + layer_kernels: DotDict, dropout_p: float = 0.0, ): super().__init__() @@ -78,7 +80,8 @@ def __init__( LOGGER.error("Activation function %s not supported", activation) raise RuntimeError from ae - self.layer_norm1 = nn.LayerNorm(num_channels) + self.layer_norm_attention = layer_kernels["LayerNorm"](normalized_shape=num_channels) + self.layer_norm_mlp = layer_kernels["LayerNorm"](normalized_shape=num_channels) self.attention = MultiHeadSelfAttention( num_heads=num_heads, @@ -87,21 +90,21 @@ def __init__( bias=False, is_causal=False, dropout_p=dropout_p, + layer_kernels=layer_kernels, ) self.mlp = nn.Sequential( - nn.Linear(num_channels, hidden_dim), + layer_kernels["Linear"](num_channels, hidden_dim), act_func(), - nn.Linear(hidden_dim, num_channels), + layer_kernels["Linear"](hidden_dim, num_channels), ) - self.layer_norm2 = nn.LayerNorm(num_channels) def forward( self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None ) -> Tensor: # Need to be out of place for gradient propagation - x = x + self.attention(self.layer_norm1(x), shapes, batch_size, model_comm_group=model_comm_group) - x = x + self.mlp(self.layer_norm2(x)) + x = x + self.attention(self.layer_norm_attention(x), shapes, batch_size, model_comm_group=model_comm_group) + x = x + self.mlp(self.layer_norm_mlp(x)) return x @@ -112,6 +115,7 @@ def __init__( self, in_channels: int, out_channels: int, + layer_kernels: DotDict, mlp_extra_layers: int = 0, activation: str = "SiLU", update_src_nodes: bool = 
True, @@ -126,6 +130,9 @@ def __init__( Number of input channels. out_channels : int Number of output channels. + layer_kernels : DotDict + A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml mlp_extra_layers : int, optional Extra layers in MLP, by default 0 activation : str, optional @@ -144,6 +151,7 @@ def __init__( 2 * in_channels, out_channels, out_channels, + layer_kernels, n_extra_layers=mlp_extra_layers, activation=activation, ) @@ -151,6 +159,7 @@ def __init__( self.conv = GraphConv( in_channels=in_channels, out_channels=out_channels, + layer_kernels=layer_kernels, mlp_extra_layers=mlp_extra_layers, activation=activation, ) @@ -173,6 +182,7 @@ def __ini__( self, in_channels: int, out_channels: int, + layer_kernels: DotDict, mlp_extra_layers: int = 0, activation: str = "SiLU", update_src_nodes: bool = True, @@ -183,6 +193,7 @@ def __ini__( self, in_channels=in_channels, out_channels=out_channels, + layer_kernels=layer_kernels, mlp_extra_layers=mlp_extra_layers, activation=activation, update_src_nodes=update_src_nodes, @@ -229,6 +240,7 @@ def __ini__( self, in_channels: int, out_channels: int, + layer_kernels: DotDict, mlp_extra_layers: int = 0, activation: str = "SiLU", update_src_nodes: bool = True, @@ -239,6 +251,7 @@ def __ini__( self, in_channels=in_channels, out_channels=out_channels, + layer_kernels=layer_kernels, mlp_extra_layers=mlp_extra_layers, activation=activation, update_src_nodes=update_src_nodes, @@ -295,6 +308,7 @@ def __init__( hidden_dim: int, out_channels: int, edge_dim: int, + layer_kernels: DotDict, num_heads: int = 16, bias: bool = True, activation: str = "GELU", @@ -312,6 +326,9 @@ def __init__( Number of output channels. edge_dim : int, Edge dimension + layer_kernels : DotDict + A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml num_heads : int, Number of heads bias : bool, by default True, @@ -330,15 +347,18 @@ def __init__( self.num_chunks = num_chunks - self.lin_key = nn.Linear(in_channels, num_heads * self.out_channels_conv) - self.lin_query = nn.Linear(in_channels, num_heads * self.out_channels_conv) - self.lin_value = nn.Linear(in_channels, num_heads * self.out_channels_conv) - self.lin_self = nn.Linear(in_channels, num_heads * self.out_channels_conv, bias=bias) - self.lin_edge = nn.Linear(edge_dim, num_heads * self.out_channels_conv) # , bias=False) + linear = layer_kernels["Linear"] + layerNorm = layer_kernels["LayerNorm"] + self.lin_key = linear(in_channels, num_heads * self.out_channels_conv) + self.lin_query = linear(in_channels, num_heads * self.out_channels_conv) + self.lin_value = linear(in_channels, num_heads * self.out_channels_conv) + self.lin_self = linear(in_channels, num_heads * self.out_channels_conv, bias=bias) + self.lin_edge = linear(edge_dim, num_heads * self.out_channels_conv) # , bias=False) self.conv = GraphTransformerConv(out_channels=self.out_channels_conv) - self.projection = nn.Linear(out_channels, out_channels) + # Why does the GraphTransformer not have a layer_norm_mlp like the Transformer? 
+ self.projection = linear(out_channels, out_channels) try: act_func = getattr(nn, activation) @@ -346,21 +366,22 @@ def __init__( LOGGER.error("Activation function %s not supported", activation) raise RuntimeError from ae + self.layer_norm_attention = layerNorm(normalized_shape=in_channels) + self.layer_norm_mlp = layerNorm(normalized_shape=out_channels) + self.node_dst_mlp = nn.Sequential( - nn.LayerNorm(out_channels), - nn.Linear(out_channels, hidden_dim), + self.layer_norm_mlp, + linear(out_channels, hidden_dim), act_func(), - nn.Linear(hidden_dim, out_channels), + linear(hidden_dim, out_channels), ) - self.layer_norm1 = nn.LayerNorm(in_channels) - if self.update_src_nodes: self.node_src_mlp = nn.Sequential( - nn.LayerNorm(out_channels), - nn.Linear(out_channels, hidden_dim), + self.layer_norm_mlp, + linear(out_channels, hidden_dim), act_func(), - nn.Linear(hidden_dim, out_channels), + linear(hidden_dim, out_channels), ) def shard_qkve_heads( @@ -435,6 +456,7 @@ def __init__( hidden_dim: int, out_channels: int, edge_dim: int, + layer_kernels: DotDict, num_heads: int = 16, bias: bool = True, activation: str = "GELU", @@ -452,6 +474,9 @@ def __init__( Number of output channels. edge_dim : int, Edge dimension + layer_kernels : DotDict + A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml num_heads : int, Number of heads bias : bool, by default True, @@ -466,6 +491,7 @@ def __init__( hidden_dim=hidden_dim, out_channels=out_channels, edge_dim=edge_dim, + layer_kernels=layer_kernels, num_heads=num_heads, bias=bias, activation=activation, @@ -474,7 +500,7 @@ def __init__( **kwargs, ) - self.layer_norm2 = nn.LayerNorm(in_channels) + self.layer_norm_attention_2 = layer_kernels["LayerNorm"](normalized_shape=in_channels) def forward( self, @@ -489,9 +515,10 @@ def forward( x_skip = x x = ( - self.layer_norm1(x[0]), - self.layer_norm2(x[1]), - ) # Why does this use layer_norm2? And only is a mapper thing? + self.layer_norm_attention(x[0]), + self.layer_norm_attention_2(x[1]), + ) # Why does this use layer_norm_attention_2? And only is a mapper thing? + x_r = self.lin_self(x[1]) query = self.lin_query(x[1]) key = self.lin_key(x[0]) @@ -559,6 +586,7 @@ def __init__( hidden_dim: int, out_channels: int, edge_dim: int, + layer_kernels: DotDict, num_heads: int = 16, bias: bool = True, activation: str = "GELU", @@ -576,6 +604,9 @@ def __init__( Number of output channels. edge_dim : int, Edge dimension + layer_kernels : DotDict + A dict of layer implementations e.g. 
layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml num_heads : int, Number of heads bias : bool, by default True, @@ -591,6 +622,7 @@ def __init__( hidden_dim=hidden_dim, out_channels=out_channels, edge_dim=edge_dim, + layer_kernels=layer_kernels, num_heads=num_heads, bias=bias, activation=activation, @@ -611,7 +643,7 @@ def forward( ): x_skip = x - x = self.layer_norm1(x) + x = self.layer_norm_attention(x) x_r = self.lin_self(x) query = self.lin_query(x) key = self.lin_key(x) diff --git a/src/anemoi/models/layers/chunk.py b/src/anemoi/models/layers/chunk.py index 5c4fae38..ac0a0bce 100644 --- a/src/anemoi/models/layers/chunk.py +++ b/src/anemoi/models/layers/chunk.py @@ -13,6 +13,7 @@ from abc import abstractmethod from typing import Optional +from anemoi.utils.config import DotDict from torch import Tensor from torch import nn from torch.distributed.distributed_c10d import ProcessGroup @@ -37,6 +38,7 @@ def __init__( num_layers: int, *args, activation: str = "GELU", + layer_norm: Optional[dict] = None, **kwargs, ) -> None: """Initialize BaseProcessorChunk.""" @@ -70,6 +72,7 @@ def __init__( self, num_channels: int, num_layers: int, + layer_kernels: DotDict, window_size: int, num_heads: int = 16, mlp_hidden_ratio: int = 4, @@ -84,6 +87,11 @@ def __init__( Number of channels num_layers : int Number of layers + layer_kernels : DotDict + A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml + window_size: int, + 1/2 size of shifted window for attention computation num_heads: int Number of heads to use, default 16 mlp_hidden_ratio: int @@ -102,11 +110,16 @@ def __init__( num_heads=num_heads, activation=activation, window_size=window_size, + layer_kernels=layer_kernels, dropout_p=dropout_p, ) def forward( - self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None + self, + x: Tensor, + shapes: list, + batch_size: int, + model_comm_group: Optional[ProcessGroup] = None, ) -> Tensor: for i in range(self.num_layers): x = self.blocks[i](x, shapes, batch_size, model_comm_group=model_comm_group) @@ -121,6 +134,7 @@ def __init__( self, num_channels: int, num_layers: int, + layer_kernels: DotDict, mlp_extra_layers: int = 0, activation: str = "SiLU", edge_dim: Optional[int] = None, @@ -133,6 +147,9 @@ def __init__( Channels of the message passing blocks. num_layers : int Number of message passing blocks. + layer_kernels : DotDict + A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml mlp_extra_layers : int, optional Extra num_layers in MLP, by default 0 activation : str, optional @@ -148,6 +165,7 @@ def __init__( in_features=edge_dim, hidden_dim=num_channels, out_features=num_channels, + layer_kernels=layer_kernels, n_extra_layers=mlp_extra_layers, activation=activation, ) @@ -158,6 +176,7 @@ def __init__( GraphConvProcessorBlock, num_channels, num_channels, + layer_kernels=layer_kernels, mlp_extra_layers=mlp_extra_layers, activation=activation, ) @@ -188,6 +207,7 @@ def __init__( self, num_channels: int, num_layers: int, + layer_kernels: DotDict, num_heads: int = 16, mlp_hidden_ratio: int = 4, activation: str = "GELU", @@ -201,6 +221,9 @@ def __init__( Number of channels. num_layers : int Number of layers. + layer_kernels : DotDict + A dict of layer implementations e.g. 
layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml num_heads: int Number of heads to use, default 16 mlp_hidden_ratio: int @@ -217,8 +240,9 @@ def __init__( in_channels=num_channels, hidden_dim=mlp_hidden_ratio * num_channels, out_channels=num_channels, - num_heads=num_heads, edge_dim=edge_dim, + num_heads=num_heads, + layer_kernels=layer_kernels, activation=activation, ) diff --git a/src/anemoi/models/layers/conv.py b/src/anemoi/models/layers/conv.py index 6b3a767e..5c354502 100644 --- a/src/anemoi/models/layers/conv.py +++ b/src/anemoi/models/layers/conv.py @@ -11,6 +11,7 @@ from typing import Optional import torch +from anemoi.utils.config import DotDict from torch import Tensor from torch.nn.functional import dropout from torch_geometric.nn.conv import MessagePassing @@ -31,6 +32,7 @@ def __init__( self, in_channels: int, out_channels: int, + layer_kernels: DotDict, mlp_extra_layers: int = 0, activation: str = "SiLU", **kwargs, @@ -43,6 +45,9 @@ def __init__( Number of input channels. out_channels : int Number of output channels. + layer_kernels : DotDict + A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml mlp_extra_layers : int, optional Extra layers in MLP, by default 0 activation : str, optional @@ -54,6 +59,7 @@ def __init__( 3 * in_channels, out_channels, out_channels, + layer_kernels=layer_kernels, n_extra_layers=mlp_extra_layers, activation=activation, ) diff --git a/src/anemoi/models/layers/mapper.py b/src/anemoi/models/layers/mapper.py index 1ae45031..38c0c152 100644 --- a/src/anemoi/models/layers/mapper.py +++ b/src/anemoi/models/layers/mapper.py @@ -14,6 +14,7 @@ import numpy as np import torch +from anemoi.utils.config import DotDict from torch import Tensor from torch import nn from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import offload_wrapper @@ -190,6 +191,7 @@ def __init__( sub_graph_edge_attributes: Optional[list[str]] = None, src_grid_size: int = 0, dst_grid_size: int = 0, + layer_kernels: DotDict = None, ) -> None: """Initialize GraphTransformerBaseMapper. @@ -213,6 +215,9 @@ def __init__( Whether to offload processing to CPU, by default False out_channels_dst : Optional[int], optional Output channels of the destination node, by default None + layer_kernels : DotDict, optional + A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml """ super().__init__( in_channels_src, @@ -224,6 +229,9 @@ def __init__( activation=activation, ) + # Linear = layer_kernels.get("Linear", torch.nn.Linear) + Linear = layer_kernels["Linear"] + self._register_edges(sub_graph, sub_graph_edge_attributes, src_grid_size, dst_grid_size, trainable_size) self.trainable = TrainableTensor(trainable_size=trainable_size, tensor_size=self.edge_attr.shape[0]) @@ -236,11 +244,12 @@ def __init__( edge_dim=self.edge_dim, activation=activation, num_chunks=num_chunks, + layer_kernels=layer_kernels, ) self.offload_layers(cpu_offload) - self.emb_nodes_dst = nn.Linear(self.in_channels_dst, self.hidden_dim) + self.emb_nodes_dst = Linear(self.in_channels_dst, self.hidden_dim) def forward( self, @@ -291,6 +300,7 @@ def __init__( sub_graph_edge_attributes: Optional[list[str]] = None, src_grid_size: int = 0, dst_grid_size: int = 0, + layer_kernels: DotDict = None, ) -> None: """Initialize GraphTransformerForwardMapper. 
@@ -330,9 +340,10 @@ def __init__( sub_graph_edge_attributes=sub_graph_edge_attributes, src_grid_size=src_grid_size, dst_grid_size=dst_grid_size, + layer_kernels=layer_kernels, ) - self.emb_nodes_src = nn.Linear(self.in_channels_src, self.hidden_dim) + self.emb_nodes_src = layer_kernels["Linear"](self.in_channels_src, self.hidden_dim) def forward( self, @@ -364,6 +375,7 @@ def __init__( sub_graph_edge_attributes: Optional[list[str]] = None, src_grid_size: int = 0, dst_grid_size: int = 0, + layer_kernels: DotDict = None, ) -> None: """Initialize GraphTransformerBackwardMapper. @@ -387,6 +399,9 @@ def __init__( Whether to offload processing to CPU, by default False out_channels_dst : Optional[int], optional Output channels of the destination node, by default None + layer_kernels : DotDict + A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml """ super().__init__( in_channels_src, @@ -403,6 +418,7 @@ def __init__( sub_graph_edge_attributes=sub_graph_edge_attributes, src_grid_size=src_grid_size, dst_grid_size=dst_grid_size, + layer_kernels=layer_kernels, ) self.node_data_extractor = nn.Sequential( @@ -436,6 +452,7 @@ def __init__( sub_graph_edge_attributes: Optional[list[str]] = None, src_grid_size: int = 0, dst_grid_size: int = 0, + layer_kernels: DotDict = None, ) -> None: """Initialize GNNBaseMapper. @@ -459,6 +476,9 @@ def __init__( Whether to offload processing to CPU, by default False out_channels_dst : Optional[int], optional Output channels of the destination node, by default None + layer_kernels : DotDict + A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml """ super().__init__( in_channels_src, @@ -476,6 +496,7 @@ def __init__( in_features=self.edge_dim, hidden_dim=hidden_dim, out_features=hidden_dim, + layer_kernels=layer_kernels, n_extra_layers=mlp_extra_layers, activation=activation, ) @@ -540,6 +561,7 @@ def __init__( sub_graph_edge_attributes: Optional[list[str]] = None, src_grid_size: int = 0, dst_grid_size: int = 0, + layer_kernels: DotDict = None, ) -> None: """Initialize GNNForwardMapper. @@ -563,6 +585,9 @@ def __init__( Whether to offload processing to CPU, by default False out_channels_dst : Optional[int], optional Output channels of the destination node, by default None + layer_kernels : DotDict + A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml """ super().__init__( in_channels_src, @@ -578,11 +603,13 @@ def __init__( sub_graph_edge_attributes=sub_graph_edge_attributes, src_grid_size=src_grid_size, dst_grid_size=dst_grid_size, + layer_kernels=layer_kernels, ) self.proc = GraphConvMapperBlock( hidden_dim, hidden_dim, + layer_kernels=layer_kernels, mlp_extra_layers=mlp_extra_layers, activation=activation, update_src_nodes=True, @@ -595,6 +622,7 @@ def __init__( in_features=in_channels_src, hidden_dim=hidden_dim, out_features=hidden_dim, + layer_kernels=layer_kernels, n_extra_layers=mlp_extra_layers, activation=activation, ) @@ -603,6 +631,7 @@ def __init__( in_features=in_channels_dst, hidden_dim=hidden_dim, out_features=hidden_dim, + layer_kernels=layer_kernels, n_extra_layers=mlp_extra_layers, activation=activation, ) @@ -626,6 +655,7 @@ def __init__( sub_graph_edge_attributes: Optional[list[str]] = None, src_grid_size: int = 0, dst_grid_size: int = 0, + layer_kernels: DotDict = None, ) -> None: """Initialize GNNBackwardMapper. 
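The docstrings describe entries such as `layer_kernels['Linear'] = "torch.nn.Linear"`. In practice each entry is a Hydra node with a `_target_` import path and `_partial_: True`, so `hydra.utils.instantiate` returns a constructor rather than an instance. A small sketch of that mechanism (values are illustrative):

```python
from hydra.utils import instantiate

# _partial_: True makes instantiate return functools.partial(torch.nn.Linear)
# instead of trying to build an nn.Linear with no arguments.
linear_kernel = instantiate({"_target_": "torch.nn.Linear", "_partial_": True})

layer = linear_kernel(32, 64, bias=False)  # equivalent to torch.nn.Linear(32, 64, bias=False)
print(type(layer).__name__)                # Linear
```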
@@ -649,6 +679,9 @@ def __init__( Whether to offload processing to CPU, by default False out_channels_dst : Optional[int], optional Output channels of the destination node, by default None + layer_kernels : DotDict + A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml """ super().__init__( in_channels_src, @@ -664,11 +697,13 @@ def __init__( sub_graph_edge_attributes=sub_graph_edge_attributes, src_grid_size=src_grid_size, dst_grid_size=dst_grid_size, + layer_kernels=layer_kernels, ) self.proc = GraphConvMapperBlock( hidden_dim, hidden_dim, + layer_kernels=layer_kernels, mlp_extra_layers=mlp_extra_layers, activation=activation, update_src_nodes=False, @@ -681,6 +716,7 @@ def __init__( in_features=self.hidden_dim, hidden_dim=self.hidden_dim, out_features=self.out_channels_dst, + layer_kernels=layer_kernels, n_extra_layers=mlp_extra_layers, activation=self.activation, layer_norm=False, diff --git a/src/anemoi/models/layers/mlp.py b/src/anemoi/models/layers/mlp.py index 4a1e7957..af2ce74c 100644 --- a/src/anemoi/models/layers/mlp.py +++ b/src/anemoi/models/layers/mlp.py @@ -11,9 +11,9 @@ import logging import torch +from anemoi.utils.config import DotDict from torch import nn -from anemoi.models.layers.utils import AutocastLayerNorm from anemoi.models.layers.utils import CheckpointWrapper LOGGER = logging.getLogger(__name__) @@ -27,6 +27,7 @@ def __init__( in_features: int, hidden_dim: int, out_features: int, + layer_kernels: DotDict, n_extra_layers: int = 0, activation: str = "SiLU", final_activation: bool = False, @@ -43,6 +44,9 @@ def __init__( Hidden dimensions out_features : int Number of output features + layer_kernels : DotDict + A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml n_extra_layers : int, optional Number of extra layers in MLP, by default 0 activation : str, optional @@ -65,23 +69,27 @@ def __init__( If activation function is not supported """ super().__init__() + + Linear = layer_kernels["Linear"] + LayerNorm = layer_kernels["LayerNorm"] + try: act_func = getattr(nn, activation) except AttributeError as ae: LOGGER.error("Activation function %s not supported", activation) raise RuntimeError from ae - mlp1 = nn.Sequential(nn.Linear(in_features, hidden_dim), act_func()) + mlp1 = nn.Sequential(Linear(in_features, hidden_dim), act_func()) for _ in range(n_extra_layers + 1): - mlp1.append(nn.Linear(hidden_dim, hidden_dim)) + mlp1.append(Linear(hidden_dim, hidden_dim)) mlp1.append(act_func()) - mlp1.append(nn.Linear(hidden_dim, out_features)) + mlp1.append(Linear(hidden_dim, out_features)) if final_activation: mlp1.append(act_func()) if layer_norm: - mlp1.append(AutocastLayerNorm(out_features)) + mlp1.append(LayerNorm(normalized_shape=out_features)) self.model = CheckpointWrapper(mlp1) if checkpoints else mlp1 diff --git a/src/anemoi/models/layers/normalization.py b/src/anemoi/models/layers/normalization.py new file mode 100644 index 00000000..d3193c4a --- /dev/null +++ b/src/anemoi/models/layers/normalization.py @@ -0,0 +1,28 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +from __future__ import annotations + +from torch import Tensor +from torch import nn + + +class AutocastLayerNorm(nn.LayerNorm): + """LayerNorm that casts the output back to the input type.""" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def forward(self, x: Tensor) -> Tensor: + """Forward with explicit autocast back to the input type. + + This casts the output to (b)float16 (instead of float32) when we run in mixed + precision. + """ + return super().forward(x).type_as(x) diff --git a/src/anemoi/models/layers/processor.py b/src/anemoi/models/layers/processor.py index 8dba1f66..5b77fdc0 100644 --- a/src/anemoi/models/layers/processor.py +++ b/src/anemoi/models/layers/processor.py @@ -11,6 +11,7 @@ from abc import ABC from typing import Optional +from anemoi.utils.config import DotDict from torch import Tensor from torch import nn from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import offload_wrapper @@ -88,6 +89,7 @@ class TransformerProcessor(BaseProcessor): def __init__( self, num_layers: int, + layer_kernels: DotDict, *args, window_size: Optional[int] = None, num_channels: int = 128, @@ -105,6 +107,9 @@ def __init__( ---------- num_layers : int Number of num_layers + layer_kernels : DotDict + A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml window_size: int, 1/2 size of shifted window for attention computation num_channels : int @@ -119,8 +124,8 @@ def __init__( Dropout probability used for multi-head self attention, default 0.0 """ super().__init__( - num_channels=num_channels, num_layers=num_layers, + num_channels=num_channels, window_size=window_size, num_chunks=num_chunks, activation=activation, @@ -132,9 +137,10 @@ def __init__( self.build_layers( TransformerProcessorChunk, num_channels=num_channels, + num_layers=self.chunk_size, + layer_kernels=layer_kernels, mlp_hidden_ratio=mlp_hidden_ratio, num_heads=num_heads, - num_layers=self.chunk_size, window_size=window_size, activation=activation, dropout_p=dropout_p, @@ -168,6 +174,7 @@ class GNNProcessor(GraphEdgeMixin, BaseProcessor): def __init__( self, num_layers: int, + layer_kernels: DotDict, *args, trainable_size: int = 8, num_channels: int = 128, @@ -212,16 +219,15 @@ def __init__( self.trainable = TrainableTensor(trainable_size=trainable_size, tensor_size=self.edge_attr.shape[0]) kwargs = { - "num_layers": self.chunk_size, "mlp_extra_layers": mlp_extra_layers, "activation": activation, "edge_dim": None, } - self.build_layers(GNNProcessorChunk, num_channels, **kwargs) + self.build_layers(GNNProcessorChunk, num_channels, self.chunk_size, layer_kernels, **kwargs) kwargs["edge_dim"] = self.edge_dim # Edge dim for first layer - self.proc[0] = GNNProcessorChunk(num_channels, **kwargs) + self.proc[0] = GNNProcessorChunk(num_channels, self.chunk_size, layer_kernels, **kwargs) self.offload_layers(cpu_offload) @@ -256,6 +262,7 @@ class GraphTransformerProcessor(GraphEdgeMixin, BaseProcessor): def __init__( self, num_layers: int, + layer_kernels: DotDict, trainable_size: int = 8, num_channels: int = 128, num_chunks: int = 2, @@ -289,8 +296,8 @@ def __init__( Whether to offload processing to CPU, by default False """ super().__init__( - num_layers=num_layers, num_channels=num_channels, + num_layers=num_layers, num_chunks=num_chunks, activation=activation, cpu_offload=cpu_offload, @@ -306,6 +313,7 @@ def __init__( GraphTransformerProcessorChunk, num_channels=num_channels, num_layers=self.chunk_size, + 
layer_kernels=layer_kernels, num_heads=num_heads, mlp_hidden_ratio=mlp_hidden_ratio, activation=activation, diff --git a/src/anemoi/models/layers/utils.py b/src/anemoi/models/layers/utils.py index e243874a..6bec46aa 100644 --- a/src/anemoi/models/layers/utils.py +++ b/src/anemoi/models/layers/utils.py @@ -8,7 +8,6 @@ # nor does it submit to any jurisdiction. -from torch import Tensor from torch import nn from torch.utils.checkpoint import checkpoint @@ -22,18 +21,3 @@ def __init__(self, module: nn.Module) -> None: def forward(self, *args, **kwargs): return checkpoint(self.module, *args, **kwargs, use_reentrant=False) - - -class AutocastLayerNorm(nn.LayerNorm): - """LayerNorm that casts the output back to the input type.""" - - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - - def forward(self, x: Tensor) -> Tensor: - """Forward with explicit autocast back to the input type. - - This casts the output to (b)float16 (instead of float32) when we run in mixed - precision. - """ - return super().forward(x).type_as(x) diff --git a/src/anemoi/models/models/encoder_processor_decoder.py b/src/anemoi/models/models/encoder_processor_decoder.py index c67c8c03..4c7a4406 100644 --- a/src/anemoi/models/models/encoder_processor_decoder.py +++ b/src/anemoi/models/models/encoder_processor_decoder.py @@ -14,7 +14,9 @@ import einops import torch from anemoi.utils.config import DotDict +from hydra.errors import InstantiationException from hydra.utils import instantiate +from omegaconf import OmegaConf from torch import Tensor from torch import nn from torch.distributed.distributed_c10d import ProcessGroup @@ -65,6 +67,9 @@ def __init__( input_dim = self.multi_step * self.num_input_channels + self.node_attributes.attr_ndims[self._graph_name_data] + # read config.model.layer_kernels to get the implementation for certain layers + self._load_layer_kernels(model_config) + # Encoder data -> hidden self.encoder = instantiate( model_config.model.encoder, @@ -74,6 +79,7 @@ def __init__( sub_graph=self._graph_data[(self._graph_name_data, "to", self._graph_name_hidden)], src_grid_size=self.node_attributes.num_nodes[self._graph_name_data], dst_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden], + layer_kernels=self.layer_kernels, ) # Processor hidden -> hidden @@ -83,6 +89,7 @@ def __init__( sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_hidden)], src_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden], dst_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden], + layer_kernels=self.layer_kernels, ) # Decoder hidden -> data @@ -95,6 +102,7 @@ def __init__( sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_data)], src_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden], dst_grid_size=self.node_attributes.num_nodes[self._graph_name_data], + layer_kernels=self.layer_kernels, ) # Instantiation of model output bounding functions (e.g., to ensure outputs like TP are positive definite) @@ -231,3 +239,26 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) -> x_out = bounding(x_out) return x_out + + def _load_layer_kernels(self, config: DotDict) -> None: + + # If self.layer_kernels entry is missing from the config, use torch.nn kernels + default_kernels = { + "Linear": {"_target_": "torch.nn.Linear", "_partial_": True}, + "LayerNorm": {"_target_": "torch.nn.LayerNorm", "_partial_": True}, + } + user_kernel = OmegaConf.select(config, "model.layer_kernels") + 
self.layer_kernels = {**default_kernels, **(user_kernel or {})} + + # Loop through all kernels in the layer_kernels config entry and try to import them + for kernel in self.layer_kernels: + kernel_entry = self.layer_kernels[kernel] + try: + instantiate(kernel_entry) + except InstantiationException: + LOGGER.error( + f"{kernel_entry['_target_']} not found! Check your config.model.layer_kernels.{kernel} entry. The kernel may not be installed or the import path may be incorrect." + ) + raise + else: + LOGGER.info(f"{kernel} kernel: {kernel_entry}")
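Putting the pieces together, here is a sketch of how a `config.model.layer_kernels` override might look and how `_load_layer_kernels` resolves it. The exact config layout is an assumption based on the docstrings (which point to a YAML file under config/models/); `AutocastLayerNorm` is the class added in `anemoi/models/layers/normalization.py` above.

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Hypothetical config excerpt: only LayerNorm is overridden,
# Linear falls back to the torch.nn default.
config = OmegaConf.create(
    {
        "model": {
            "layer_kernels": {
                "LayerNorm": {
                    "_target_": "anemoi.models.layers.normalization.AutocastLayerNorm",
                    "_partial_": True,
                },
            },
        },
    }
)

default_kernels = {
    "Linear": {"_target_": "torch.nn.Linear", "_partial_": True},
    "LayerNorm": {"_target_": "torch.nn.LayerNorm", "_partial_": True},
}
user_kernels = OmegaConf.select(config, "model.layer_kernels")
layer_kernels = {**default_kernels, **(user_kernels or {})}

# _load_layer_kernels only calls instantiate to validate the import paths; the entries
# themselves become constructors via Hydra's recursive instantiation when the
# encoder/processor/decoder are built. The comprehension below makes that step explicit.
kernels = {name: instantiate(entry) for name, entry in layer_kernels.items()}

norm = kernels["LayerNorm"](normalized_shape=64)  # AutocastLayerNorm(64)
proj = kernels["Linear"](64, 64)                  # torch.nn.Linear(64, 64)
print(type(norm).__name__, type(proj).__name__)   # AutocastLayerNorm Linear
```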