From cf0cc85f6b35220467199750f826fe84240c69a5 Mon Sep 17 00:00:00 2001 From: Jakob Schloer Date: Fri, 22 Nov 2024 14:11:23 +0000 Subject: [PATCH 1/9] Refactor to instantiate normalization layer. --- src/anemoi/models/layers/block.py | 13 ++++++---- src/anemoi/models/layers/chunk.py | 10 ++++++-- src/anemoi/models/layers/mlp.py | 2 +- src/anemoi/models/layers/normalization.py | 31 +++++++++++++++++++++++ src/anemoi/models/layers/processor.py | 4 ++- src/anemoi/models/layers/utils.py | 15 ----------- 6 files changed, 51 insertions(+), 24 deletions(-) create mode 100644 src/anemoi/models/layers/normalization.py diff --git a/src/anemoi/models/layers/block.py b/src/anemoi/models/layers/block.py index 60446d6c..81ee3204 100644 --- a/src/anemoi/models/layers/block.py +++ b/src/anemoi/models/layers/block.py @@ -69,6 +69,7 @@ def __init__( activation: str, window_size: int, dropout_p: float = 0.0, + layer_norm: Optional[dict] = None ): super().__init__() @@ -78,7 +79,8 @@ def __init__( LOGGER.error("Activation function %s not supported", activation) raise RuntimeError from ae - self.layer_norm1 = nn.LayerNorm(num_channels) + self.layer_norm1 = layer_norm() if layer_norm else nn.LayerNorm(num_channels) + self.layer_norm2 = layer_norm() if layer_norm else nn.LayerNorm(num_channels) self.attention = MultiHeadSelfAttention( num_heads=num_heads, @@ -94,14 +96,15 @@ def __init__( act_func(), nn.Linear(hidden_dim, num_channels), ) - self.layer_norm2 = nn.LayerNorm(num_channels) def forward( - self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None + self, x: Tensor, shapes: list, batch_size: int, + model_comm_group: Optional[ProcessGroup] = None, + **kwargs ) -> Tensor: # Need to be out of place for gradient propagation - x = x + self.attention(self.layer_norm1(x), shapes, batch_size, model_comm_group=model_comm_group) - x = x + self.mlp(self.layer_norm2(x)) + x = x + self.attention(self.layer_norm1(x, **kwargs), shapes, batch_size, model_comm_group=model_comm_group) + x = x + self.mlp(self.layer_norm2(x, **kwargs)) return x diff --git a/src/anemoi/models/layers/chunk.py b/src/anemoi/models/layers/chunk.py index 5c4fae38..44dfc76a 100644 --- a/src/anemoi/models/layers/chunk.py +++ b/src/anemoi/models/layers/chunk.py @@ -75,6 +75,7 @@ def __init__( mlp_hidden_ratio: int = 4, activation: str = "GELU", dropout_p: float = 0.0, + layer_norm: Optional[dict] = None, ) -> None: """Initialize TransformerProcessor. @@ -103,13 +104,18 @@ def __init__( activation=activation, window_size=window_size, dropout_p=dropout_p, + layer_norm=layer_norm, ) def forward( - self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None + self, x: Tensor, shapes: list, batch_size: int, + model_comm_group: Optional[ProcessGroup] = None, + **kwargs, ) -> Tensor: for i in range(self.num_layers): - x = self.blocks[i](x, shapes, batch_size, model_comm_group=model_comm_group) + x = self.blocks[i](x, shapes, batch_size, + model_comm_group=model_comm_group, + **kwargs) return (x,) # return tuple for consistency with other processors diff --git a/src/anemoi/models/layers/mlp.py b/src/anemoi/models/layers/mlp.py index 4a1e7957..1d939704 100644 --- a/src/anemoi/models/layers/mlp.py +++ b/src/anemoi/models/layers/mlp.py @@ -13,7 +13,7 @@ import torch from torch import nn -from anemoi.models.layers.utils import AutocastLayerNorm +from anemoi.models.layers.normalization import AutocastLayerNorm from anemoi.models.layers.utils import CheckpointWrapper LOGGER = logging.getLogger(__name__) diff --git a/src/anemoi/models/layers/normalization.py b/src/anemoi/models/layers/normalization.py new file mode 100644 index 00000000..551d90c8 --- /dev/null +++ b/src/anemoi/models/layers/normalization.py @@ -0,0 +1,31 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations + +from abc import ABC +from abc import abstractmethod + +import torch +from torch import nn + + +class AutocastLayerNorm(nn.LayerNorm): + """LayerNorm that casts the output back to the input type.""" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def forward(self, x: Tensor) -> Tensor: + """Forward with explicit autocast back to the input type. + + This casts the output to (b)float16 (instead of float32) when we run in mixed + precision. + """ + return super().forward(x).type_as(x) diff --git a/src/anemoi/models/layers/processor.py b/src/anemoi/models/layers/processor.py index 4fd32311..0afdf927 100644 --- a/src/anemoi/models/layers/processor.py +++ b/src/anemoi/models/layers/processor.py @@ -97,6 +97,7 @@ def __init__( num_heads: int = 16, mlp_hidden_ratio: int = 4, dropout_p: float = 0.1, + layer_norm: Optional[dict] = None, **kwargs, ) -> None: """Initialize TransformerProcessor. @@ -138,6 +139,7 @@ def __init__( window_size=window_size, activation=activation, dropout_p=dropout_p, + layer_norm=layer_norm, ) self.offload_layers(cpu_offload) @@ -157,7 +159,7 @@ def forward( model_comm_group.size() == 1 or batch_size == 1 ), "Only batch size of 1 is supported when model is sharded accross GPUs" - (x,) = self.run_layers((x,), shape_nodes, batch_size, model_comm_group) + (x,) = self.run_layers((x,), shape_nodes, batch_size, model_comm_group, **kwargs) return x diff --git a/src/anemoi/models/layers/utils.py b/src/anemoi/models/layers/utils.py index e243874a..f35c2b8b 100644 --- a/src/anemoi/models/layers/utils.py +++ b/src/anemoi/models/layers/utils.py @@ -22,18 +22,3 @@ def __init__(self, module: nn.Module) -> None: def forward(self, *args, **kwargs): return checkpoint(self.module, *args, **kwargs, use_reentrant=False) - - -class AutocastLayerNorm(nn.LayerNorm): - """LayerNorm that casts the output back to the input type.""" - - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - - def forward(self, x: Tensor) -> Tensor: - """Forward with explicit autocast back to the input type. - - This casts the output to (b)float16 (instead of float32) when we run in mixed - precision. - """ - return super().forward(x).type_as(x) From 207984679d015852ca45b51f2aa2dc67a4aadf7b Mon Sep 17 00:00:00 2001 From: Jakob Schloer Date: Fri, 22 Nov 2024 15:12:33 +0000 Subject: [PATCH 2/9] Fix dependencies for development --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 6d473472..45d56709 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,8 @@ classifiers = [ dynamic = [ "version" ] dependencies = [ + # Fix certain dependencies during development + "anemoi-training @ git+https://github.com/ecmwf/anemoi-training.git@25abf5e143a29d5931ccb4ac42a5f83c5cd26851", "anemoi-utils>=0.1.9", "einops>=0.6.1", "hydra-core>=1.3", From 8bc7d79fc8d6e2a119974c77d9c8124a86e59077 Mon Sep 17 00:00:00 2001 From: Cathal OBrien Date: Wed, 4 Dec 2024 10:05:32 +0000 Subject: [PATCH 3/9] can pass arbitrary kernels via config --- src/anemoi/models/layers/attention.py | 8 ++- src/anemoi/models/layers/block.py | 71 +++++++++++++------ src/anemoi/models/layers/chunk.py | 21 +++++- src/anemoi/models/layers/mapper.py | 21 +++++- src/anemoi/models/layers/mlp.py | 17 +++-- src/anemoi/models/layers/processor.py | 9 ++- .../models/encoder_processor_decoder.py | 30 ++++++++ 7 files changed, 143 insertions(+), 34 deletions(-) diff --git a/src/anemoi/models/layers/attention.py b/src/anemoi/models/layers/attention.py index d7f54920..dc859b47 100644 --- a/src/anemoi/models/layers/attention.py +++ b/src/anemoi/models/layers/attention.py @@ -28,6 +28,8 @@ from anemoi.models.distributed.transformer import shard_heads from anemoi.models.distributed.transformer import shard_sequence +from anemoi.utils.config import DotDict + LOGGER = logging.getLogger(__name__) @@ -38,6 +40,7 @@ def __init__( self, num_heads: int, embed_dim: int, + layer_kernels: DotDict, bias: bool = False, is_causal: bool = False, window_size: Optional[int] = None, @@ -56,13 +59,14 @@ def __init__( self.dropout_p = dropout_p self.is_causal = is_causal - self.lin_qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias) + linear=layer_kernels["Linear"] + self.lin_qkv = linear(embed_dim, 3 * embed_dim, bias=bias) self.attention = attn_func if not _FLASH_ATTENTION_AVAILABLE: LOGGER.warning("Flash attention not available, falling back to pytorch scaled_dot_product_attention") - self.projection = nn.Linear(embed_dim, embed_dim, bias=True) + self.projection = linear(embed_dim, embed_dim, bias=True) def forward( self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None diff --git a/src/anemoi/models/layers/block.py b/src/anemoi/models/layers/block.py index 81ee3204..8296f1b4 100644 --- a/src/anemoi/models/layers/block.py +++ b/src/anemoi/models/layers/block.py @@ -32,6 +32,7 @@ from anemoi.models.layers.conv import GraphConv from anemoi.models.layers.conv import GraphTransformerConv from anemoi.models.layers.mlp import MLP +from anemoi.utils.config import DotDict LOGGER = logging.getLogger(__name__) @@ -68,8 +69,8 @@ def __init__( num_heads: int, activation: str, window_size: int, + layer_kernels: DotDict, dropout_p: float = 0.0, - layer_norm: Optional[dict] = None ): super().__init__() @@ -79,8 +80,8 @@ def __init__( LOGGER.error("Activation function %s not supported", activation) raise RuntimeError from ae - self.layer_norm1 = layer_norm() if layer_norm else nn.LayerNorm(num_channels) - self.layer_norm2 = layer_norm() if layer_norm else nn.LayerNorm(num_channels) + self.layer_norm1 = layer_kernels["LayerNorm"](num_channels) + self.layer_norm2 = layer_kernels["LayerNorm"](num_channels) self.attention = MultiHeadSelfAttention( num_heads=num_heads, @@ -89,12 +90,13 @@ def __init__( bias=False, is_causal=False, dropout_p=dropout_p, + layer_kernels=layer_kernels, ) self.mlp = nn.Sequential( - nn.Linear(num_channels, hidden_dim), + layer_kernels["Linear"](num_channels, hidden_dim), act_func(), - nn.Linear(hidden_dim, num_channels), + layer_kernels["Linear"](hidden_dim, num_channels), ) def forward( @@ -103,8 +105,8 @@ def forward( **kwargs ) -> Tensor: # Need to be out of place for gradient propagation - x = x + self.attention(self.layer_norm1(x, **kwargs), shapes, batch_size, model_comm_group=model_comm_group) - x = x + self.mlp(self.layer_norm2(x, **kwargs)) + x = x + self.attention(self.layer_norm1(x), shapes, batch_size, model_comm_group=model_comm_group) + x = x + self.mlp(self.layer_norm2(x)) return x @@ -115,6 +117,7 @@ def __init__( self, in_channels: int, out_channels: int, + layer_kernels: DotDict, mlp_extra_layers: int = 0, activation: str = "SiLU", update_src_nodes: bool = True, @@ -129,6 +132,9 @@ def __init__( Number of input channels. out_channels : int Number of output channels. + layer_kernels : DotDict + A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml mlp_extra_layers : int, optional Extra layers in MLP, by default 0 activation : str, optional @@ -147,6 +153,7 @@ def __init__( 2 * in_channels, out_channels, out_channels, + layer_kernels, n_extra_layers=mlp_extra_layers, activation=activation, ) @@ -176,6 +183,7 @@ def __ini__( self, in_channels: int, out_channels: int, + layer_kernels: DotDict, mlp_extra_layers: int = 0, activation: str = "SiLU", update_src_nodes: bool = True, @@ -190,6 +198,7 @@ def __ini__( activation=activation, update_src_nodes=update_src_nodes, num_chunks=num_chunks, + layer_kernels=layer_kernels, **kwargs, ) @@ -232,6 +241,7 @@ def __ini__( self, in_channels: int, out_channels: int, + layer_kernels: DotDict, mlp_extra_layers: int = 0, activation: str = "SiLU", update_src_nodes: bool = True, @@ -246,6 +256,7 @@ def __ini__( activation=activation, update_src_nodes=update_src_nodes, num_chunks=num_chunks, + layer_kernels=layer_kernels, **kwargs, ) @@ -298,6 +309,7 @@ def __init__( hidden_dim: int, out_channels: int, edge_dim: int, + layer_kernels: DotDict, num_heads: int = 16, bias: bool = True, activation: str = "GELU", @@ -315,6 +327,9 @@ def __init__( Number of output channels. edge_dim : int, Edge dimension + layer_kernels : DotDict + A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml num_heads : int, Number of heads bias : bool, by default True, @@ -333,15 +348,17 @@ def __init__( self.num_chunks = num_chunks - self.lin_key = nn.Linear(in_channels, num_heads * self.out_channels_conv) - self.lin_query = nn.Linear(in_channels, num_heads * self.out_channels_conv) - self.lin_value = nn.Linear(in_channels, num_heads * self.out_channels_conv) - self.lin_self = nn.Linear(in_channels, num_heads * self.out_channels_conv, bias=bias) - self.lin_edge = nn.Linear(edge_dim, num_heads * self.out_channels_conv) # , bias=False) + linear=layer_kernels['Linear'] + layerNorm=layer_kernels['LayerNorm'] + self.lin_key = linear(in_channels, num_heads * self.out_channels_conv) + self.lin_query = linear(in_channels, num_heads * self.out_channels_conv) + self.lin_value = linear(in_channels, num_heads * self.out_channels_conv) + self.lin_self = linear(in_channels, num_heads * self.out_channels_conv, bias=bias) + self.lin_edge = linear(edge_dim, num_heads * self.out_channels_conv) # , bias=False) self.conv = GraphTransformerConv(out_channels=self.out_channels_conv) - self.projection = nn.Linear(out_channels, out_channels) + self.projection = linear(out_channels, out_channels) try: act_func = getattr(nn, activation) @@ -350,20 +367,20 @@ def __init__( raise RuntimeError from ae self.node_dst_mlp = nn.Sequential( - nn.LayerNorm(out_channels), - nn.Linear(out_channels, hidden_dim), + layerNorm(out_channels), + linear(out_channels, hidden_dim), act_func(), - nn.Linear(hidden_dim, out_channels), + linear(hidden_dim, out_channels), ) - self.layer_norm1 = nn.LayerNorm(in_channels) + self.layer_norm1 = layerNorm(in_channels) if self.update_src_nodes: self.node_src_mlp = nn.Sequential( - nn.LayerNorm(out_channels), - nn.Linear(out_channels, hidden_dim), + layerNorm(out_channels), + linear(out_channels, hidden_dim), act_func(), - nn.Linear(hidden_dim, out_channels), + linear(hidden_dim, out_channels), ) def shard_qkve_heads( @@ -438,6 +455,7 @@ def __init__( hidden_dim: int, out_channels: int, edge_dim: int, + layer_kernels: DotDict, num_heads: int = 16, bias: bool = True, activation: str = "GELU", @@ -455,6 +473,9 @@ def __init__( Number of output channels. edge_dim : int, Edge dimension + layer_kernels : DotDict + A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml num_heads : int, Number of heads bias : bool, by default True, @@ -469,6 +490,7 @@ def __init__( hidden_dim=hidden_dim, out_channels=out_channels, edge_dim=edge_dim, + layer_kernels=layer_kernels, num_heads=num_heads, bias=bias, activation=activation, @@ -477,7 +499,7 @@ def __init__( **kwargs, ) - self.layer_norm2 = nn.LayerNorm(in_channels) + self.layer_norm2 = layer_kernels["LayerNorm"](in_channels) def forward( self, @@ -564,6 +586,7 @@ def __init__( hidden_dim: int, out_channels: int, edge_dim: int, + layer_kernels: DotDict, num_heads: int = 16, bias: bool = True, activation: str = "GELU", @@ -581,6 +604,9 @@ def __init__( Number of output channels. edge_dim : int, Edge dimension + layer_kernels : DotDict + A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml num_heads : int, Number of heads bias : bool, by default True, @@ -596,11 +622,12 @@ def __init__( hidden_dim=hidden_dim, out_channels=out_channels, edge_dim=edge_dim, + layer_kernels=layer_kernels, num_heads=num_heads, bias=bias, activation=activation, num_chunks=num_chunks, - update_src_nodes=update_src_nodes, + update_src_nodes=update_src_nodes **kwargs, ) diff --git a/src/anemoi/models/layers/chunk.py b/src/anemoi/models/layers/chunk.py index 44dfc76a..621101db 100644 --- a/src/anemoi/models/layers/chunk.py +++ b/src/anemoi/models/layers/chunk.py @@ -24,6 +24,7 @@ from anemoi.models.layers.block import GraphTransformerProcessorBlock from anemoi.models.layers.block import TransformerProcessorBlock from anemoi.models.layers.mlp import MLP +from anemoi.utils.config import DotDict LOGGER = logging.getLogger(__name__) @@ -37,6 +38,7 @@ def __init__( num_layers: int, *args, activation: str = "GELU", + layer_norm: Optional[dict] = None, **kwargs, ) -> None: """Initialize BaseProcessorChunk.""" @@ -71,11 +73,11 @@ def __init__( num_channels: int, num_layers: int, window_size: int, + layer_kernels: DotDict, num_heads: int = 16, mlp_hidden_ratio: int = 4, activation: str = "GELU", dropout_p: float = 0.0, - layer_norm: Optional[dict] = None, ) -> None: """Initialize TransformerProcessor. @@ -85,6 +87,11 @@ def __init__( Number of channels num_layers : int Number of layers + window_size: int, + 1/2 size of shifted window for attention computation + layer_kernels : DotDict + A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml num_heads: int Number of heads to use, default 16 mlp_hidden_ratio: int @@ -104,7 +111,7 @@ def __init__( activation=activation, window_size=window_size, dropout_p=dropout_p, - layer_norm=layer_norm, + layer_kernels=layer_kernels ) def forward( @@ -127,6 +134,7 @@ def __init__( self, num_channels: int, num_layers: int, + layer_kernels: DotDict, mlp_extra_layers: int = 0, activation: str = "SiLU", edge_dim: Optional[int] = None, @@ -139,6 +147,9 @@ def __init__( Channels of the message passing blocks. num_layers : int Number of message passing blocks. + layer_kernels : DotDict + A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml mlp_extra_layers : int, optional Extra num_layers in MLP, by default 0 activation : str, optional @@ -166,6 +177,7 @@ def __init__( num_channels, mlp_extra_layers=mlp_extra_layers, activation=activation, + layer_kernels=layer_kernels, ) def forward( @@ -194,6 +206,7 @@ def __init__( self, num_channels: int, num_layers: int, + layer_kernels: DotDict, num_heads: int = 16, mlp_hidden_ratio: int = 4, activation: str = "GELU", @@ -207,6 +220,9 @@ def __init__( Number of channels. num_layers : int Number of layers. + layer_kernels : DotDict + A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml num_heads: int Number of heads to use, default 16 mlp_hidden_ratio: int @@ -226,6 +242,7 @@ def __init__( num_heads=num_heads, edge_dim=edge_dim, activation=activation, + layer_kernels=layer_kernels, ) def forward( diff --git a/src/anemoi/models/layers/mapper.py b/src/anemoi/models/layers/mapper.py index 1ae45031..3ebcdff8 100644 --- a/src/anemoi/models/layers/mapper.py +++ b/src/anemoi/models/layers/mapper.py @@ -31,6 +31,7 @@ from anemoi.models.layers.block import GraphTransformerMapperBlock from anemoi.models.layers.graph import TrainableTensor from anemoi.models.layers.mlp import MLP +from anemoi.utils.config import DotDict LOGGER = logging.getLogger(__name__) @@ -190,6 +191,7 @@ def __init__( sub_graph_edge_attributes: Optional[list[str]] = None, src_grid_size: int = 0, dst_grid_size: int = 0, + layer_kernels: DotDict = None, ) -> None: """Initialize GraphTransformerBaseMapper. @@ -213,6 +215,9 @@ def __init__( Whether to offload processing to CPU, by default False out_channels_dst : Optional[int], optional Output channels of the destination node, by default None + layer_kernels : DotDict, optional + A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml """ super().__init__( in_channels_src, @@ -222,7 +227,11 @@ def __init__( num_chunks=num_chunks, cpu_offload=cpu_offload, activation=activation, + layer_kernels=layer_kernels, ) + + #Linear = layer_kernels.get("Linear", torch.nn.Linear) + Linear = layer_kernels["Linear"] self._register_edges(sub_graph, sub_graph_edge_attributes, src_grid_size, dst_grid_size, trainable_size) @@ -236,11 +245,12 @@ def __init__( edge_dim=self.edge_dim, activation=activation, num_chunks=num_chunks, + layer_kernels=layer_kernels ) self.offload_layers(cpu_offload) - self.emb_nodes_dst = nn.Linear(self.in_channels_dst, self.hidden_dim) + self.emb_nodes_dst = Linear(self.in_channels_dst, self.hidden_dim) def forward( self, @@ -291,6 +301,7 @@ def __init__( sub_graph_edge_attributes: Optional[list[str]] = None, src_grid_size: int = 0, dst_grid_size: int = 0, + layer_kernels: DotDict = None, ) -> None: """Initialize GraphTransformerForwardMapper. @@ -330,9 +341,10 @@ def __init__( sub_graph_edge_attributes=sub_graph_edge_attributes, src_grid_size=src_grid_size, dst_grid_size=dst_grid_size, + layer_kernels=layer_kernels, ) - self.emb_nodes_src = nn.Linear(self.in_channels_src, self.hidden_dim) + self.emb_nodes_src = layer_kernels["Linear"](self.in_channels_src, self.hidden_dim) def forward( self, @@ -364,6 +376,7 @@ def __init__( sub_graph_edge_attributes: Optional[list[str]] = None, src_grid_size: int = 0, dst_grid_size: int = 0, + layer_kernels: DotDict = None, ) -> None: """Initialize GraphTransformerBackwardMapper. @@ -387,6 +400,9 @@ def __init__( Whether to offload processing to CPU, by default False out_channels_dst : Optional[int], optional Output channels of the destination node, by default None + layer_kernels : DotDict + A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml """ super().__init__( in_channels_src, @@ -403,6 +419,7 @@ def __init__( sub_graph_edge_attributes=sub_graph_edge_attributes, src_grid_size=src_grid_size, dst_grid_size=dst_grid_size, + layer_kernels=layer_kernels, ) self.node_data_extractor = nn.Sequential( diff --git a/src/anemoi/models/layers/mlp.py b/src/anemoi/models/layers/mlp.py index 1d939704..e771d6e6 100644 --- a/src/anemoi/models/layers/mlp.py +++ b/src/anemoi/models/layers/mlp.py @@ -15,6 +15,7 @@ from anemoi.models.layers.normalization import AutocastLayerNorm from anemoi.models.layers.utils import CheckpointWrapper +from anemoi.utils.config import DotDict LOGGER = logging.getLogger(__name__) @@ -27,6 +28,7 @@ def __init__( in_features: int, hidden_dim: int, out_features: int, + layer_kernels: DotDict, n_extra_layers: int = 0, activation: str = "SiLU", final_activation: bool = False, @@ -43,6 +45,9 @@ def __init__( Hidden dimensions out_features : int Number of output features + layer_kernels : DotDict + A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml n_extra_layers : int, optional Number of extra layers in MLP, by default 0 activation : str, optional @@ -65,23 +70,27 @@ def __init__( If activation function is not supported """ super().__init__() + + Linear = layer_kernels["Linear"] + LayerNorm = layer_kernels["LayerNorm"] + try: act_func = getattr(nn, activation) except AttributeError as ae: LOGGER.error("Activation function %s not supported", activation) raise RuntimeError from ae - mlp1 = nn.Sequential(nn.Linear(in_features, hidden_dim), act_func()) + mlp1 = nn.Sequential(Linear(in_features, hidden_dim), act_func()) for _ in range(n_extra_layers + 1): - mlp1.append(nn.Linear(hidden_dim, hidden_dim)) + mlp1.append(Linear(hidden_dim, hidden_dim)) mlp1.append(act_func()) - mlp1.append(nn.Linear(hidden_dim, out_features)) + mlp1.append(Linear(hidden_dim, out_features)) if final_activation: mlp1.append(act_func()) if layer_norm: - mlp1.append(AutocastLayerNorm(out_features)) + mlp1.append(LayerNorm(out_features).as_type(out_features)) self.model = CheckpointWrapper(mlp1) if checkpoints else mlp1 diff --git a/src/anemoi/models/layers/processor.py b/src/anemoi/models/layers/processor.py index 0afdf927..a069ab3f 100644 --- a/src/anemoi/models/layers/processor.py +++ b/src/anemoi/models/layers/processor.py @@ -27,6 +27,7 @@ from anemoi.models.layers.chunk import TransformerProcessorChunk from anemoi.models.layers.graph import TrainableTensor from anemoi.models.layers.mapper import GraphEdgeMixin +from anemoi.utils.config import DotDict class BaseProcessor(nn.Module, ABC): @@ -88,6 +89,7 @@ class TransformerProcessor(BaseProcessor): def __init__( self, num_layers: int, + layer_kernels: DotDict, *args, window_size: Optional[int] = None, num_channels: int = 128, @@ -97,7 +99,6 @@ def __init__( num_heads: int = 16, mlp_hidden_ratio: int = 4, dropout_p: float = 0.1, - layer_norm: Optional[dict] = None, **kwargs, ) -> None: """Initialize TransformerProcessor. @@ -106,6 +107,9 @@ def __init__( ---------- num_layers : int Number of num_layers + layer_kernels : DotDict + A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml window_size: int, 1/2 size of shifted window for attention computation num_channels : int @@ -128,6 +132,7 @@ def __init__( cpu_offload=cpu_offload, num_heads=num_heads, mlp_hidden_ratio=mlp_hidden_ratio, + #layer_kernels=layer_kernels, ) self.build_layers( @@ -139,7 +144,7 @@ def __init__( window_size=window_size, activation=activation, dropout_p=dropout_p, - layer_norm=layer_norm, + layer_kernels=layer_kernels, ) self.offload_layers(cpu_offload) diff --git a/src/anemoi/models/models/encoder_processor_decoder.py b/src/anemoi/models/models/encoder_processor_decoder.py index c67c8c03..16dc9bd9 100644 --- a/src/anemoi/models/models/encoder_processor_decoder.py +++ b/src/anemoi/models/models/encoder_processor_decoder.py @@ -20,6 +20,7 @@ from torch.distributed.distributed_c10d import ProcessGroup from torch.utils.checkpoint import checkpoint from torch_geometric.data import HeteroData +from hydra.errors import InstantiationException from anemoi.models.distributed.shapes import get_shape_shards from anemoi.models.layers.graph import NamedNodesAttributes @@ -64,6 +65,9 @@ def __init__( self.node_attributes = NamedNodesAttributes(model_config.model.trainable_parameters.hidden, self._graph_data) input_dim = self.multi_step * self.num_input_channels + self.node_attributes.attr_ndims[self._graph_name_data] + + # read config.model.layer_kernels to get the implementation for certain layers + self._load_layer_kernels(model_config) # Encoder data -> hidden self.encoder = instantiate( @@ -74,6 +78,7 @@ def __init__( sub_graph=self._graph_data[(self._graph_name_data, "to", self._graph_name_hidden)], src_grid_size=self.node_attributes.num_nodes[self._graph_name_data], dst_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden], + layer_kernels=self.layer_kernels, ) # Processor hidden -> hidden @@ -83,6 +88,7 @@ def __init__( sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_hidden)], src_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden], dst_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden], + layer_kernels=self.layer_kernels, ) # Decoder hidden -> data @@ -95,6 +101,7 @@ def __init__( sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_data)], src_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden], dst_grid_size=self.node_attributes.num_nodes[self._graph_name_data], + layer_kernels=self.layer_kernels, ) # Instantiation of model output bounding functions (e.g., to ensure outputs like TP are positive definite) @@ -231,3 +238,26 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) -> x_out = bounding(x_out) return x_out + + def _load_layer_kernels(self, config: DotDict) -> None: + + # If self.layer_kernels entry is missing from the config, use torch.nn by default + default_kernels=DotDict() + default_kernels["Linear"] = DotDict({"_target_": "torch.nn.Linear", "_partial_": True}) + default_kernels["LayerNorm"] = DotDict({"_target_": "torch.nn.LayerNorm", "_partial_": True}) + + #self.layer_kernels = config.get("model.layer_kernels", default_kernels) #Always uses default kernels... + self.layer_kernels= config.model.layer_kernels + + # Loop through all kernels in the layer_kernels config entry and try import them + for kernel in self.layer_kernels: + kernel_entry = self.layer_kernels[kernel] + try: + instantiate(kernel_entry) + except InstantiationException: + LOGGER.info( + f"{kernel_entry['_target_']} not found! check your config.model.layer_kernel.{kernel} entry. Maybe your desired kernel is not installed or the import string is incorrect?" + ) + raise InstantiationException + else: + LOGGER.info(f"{kernel} kernel: {kernel_entry}") \ No newline at end of file From 5505e975a914bb18483e95d6e9adad4770167d1f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Dec 2024 12:48:22 +0000 Subject: [PATCH 4/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/anemoi/models/layers/attention.py | 6 +++--- src/anemoi/models/layers/block.py | 13 +++++-------- src/anemoi/models/layers/chunk.py | 13 +++++++------ src/anemoi/models/layers/mapper.py | 8 ++++---- src/anemoi/models/layers/mlp.py | 5 ++--- src/anemoi/models/layers/normalization.py | 3 --- src/anemoi/models/layers/processor.py | 4 ++-- src/anemoi/models/layers/utils.py | 1 - .../models/models/encoder_processor_decoder.py | 16 ++++++++-------- 9 files changed, 31 insertions(+), 38 deletions(-) diff --git a/src/anemoi/models/layers/attention.py b/src/anemoi/models/layers/attention.py index dc859b47..64cd7b69 100644 --- a/src/anemoi/models/layers/attention.py +++ b/src/anemoi/models/layers/attention.py @@ -25,11 +25,11 @@ else: _FLASH_ATTENTION_AVAILABLE = True +from anemoi.utils.config import DotDict + from anemoi.models.distributed.transformer import shard_heads from anemoi.models.distributed.transformer import shard_sequence -from anemoi.utils.config import DotDict - LOGGER = logging.getLogger(__name__) @@ -59,7 +59,7 @@ def __init__( self.dropout_p = dropout_p self.is_causal = is_causal - linear=layer_kernels["Linear"] + linear = layer_kernels["Linear"] self.lin_qkv = linear(embed_dim, 3 * embed_dim, bias=bias) self.attention = attn_func diff --git a/src/anemoi/models/layers/block.py b/src/anemoi/models/layers/block.py index 8296f1b4..edfd2f46 100644 --- a/src/anemoi/models/layers/block.py +++ b/src/anemoi/models/layers/block.py @@ -16,6 +16,7 @@ import einops import torch +from anemoi.utils.config import DotDict from torch import Tensor from torch import nn from torch.distributed.distributed_c10d import ProcessGroup @@ -32,7 +33,6 @@ from anemoi.models.layers.conv import GraphConv from anemoi.models.layers.conv import GraphTransformerConv from anemoi.models.layers.mlp import MLP -from anemoi.utils.config import DotDict LOGGER = logging.getLogger(__name__) @@ -100,9 +100,7 @@ def __init__( ) def forward( - self, x: Tensor, shapes: list, batch_size: int, - model_comm_group: Optional[ProcessGroup] = None, - **kwargs + self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None, **kwargs ) -> Tensor: # Need to be out of place for gradient propagation x = x + self.attention(self.layer_norm1(x), shapes, batch_size, model_comm_group=model_comm_group) @@ -348,8 +346,8 @@ def __init__( self.num_chunks = num_chunks - linear=layer_kernels['Linear'] - layerNorm=layer_kernels['LayerNorm'] + linear = layer_kernels["Linear"] + layerNorm = layer_kernels["LayerNorm"] self.lin_key = linear(in_channels, num_heads * self.out_channels_conv) self.lin_query = linear(in_channels, num_heads * self.out_channels_conv) self.lin_value = linear(in_channels, num_heads * self.out_channels_conv) @@ -627,8 +625,7 @@ def __init__( bias=bias, activation=activation, num_chunks=num_chunks, - update_src_nodes=update_src_nodes - **kwargs, + update_src_nodes=update_src_nodes**kwargs, ) def forward( diff --git a/src/anemoi/models/layers/chunk.py b/src/anemoi/models/layers/chunk.py index 621101db..2c2d0761 100644 --- a/src/anemoi/models/layers/chunk.py +++ b/src/anemoi/models/layers/chunk.py @@ -13,6 +13,7 @@ from abc import abstractmethod from typing import Optional +from anemoi.utils.config import DotDict from torch import Tensor from torch import nn from torch.distributed.distributed_c10d import ProcessGroup @@ -24,7 +25,6 @@ from anemoi.models.layers.block import GraphTransformerProcessorBlock from anemoi.models.layers.block import TransformerProcessorBlock from anemoi.models.layers.mlp import MLP -from anemoi.utils.config import DotDict LOGGER = logging.getLogger(__name__) @@ -111,18 +111,19 @@ def __init__( activation=activation, window_size=window_size, dropout_p=dropout_p, - layer_kernels=layer_kernels + layer_kernels=layer_kernels, ) def forward( - self, x: Tensor, shapes: list, batch_size: int, + self, + x: Tensor, + shapes: list, + batch_size: int, model_comm_group: Optional[ProcessGroup] = None, **kwargs, ) -> Tensor: for i in range(self.num_layers): - x = self.blocks[i](x, shapes, batch_size, - model_comm_group=model_comm_group, - **kwargs) + x = self.blocks[i](x, shapes, batch_size, model_comm_group=model_comm_group, **kwargs) return (x,) # return tuple for consistency with other processors diff --git a/src/anemoi/models/layers/mapper.py b/src/anemoi/models/layers/mapper.py index 3ebcdff8..c4537647 100644 --- a/src/anemoi/models/layers/mapper.py +++ b/src/anemoi/models/layers/mapper.py @@ -14,6 +14,7 @@ import numpy as np import torch +from anemoi.utils.config import DotDict from torch import Tensor from torch import nn from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import offload_wrapper @@ -31,7 +32,6 @@ from anemoi.models.layers.block import GraphTransformerMapperBlock from anemoi.models.layers.graph import TrainableTensor from anemoi.models.layers.mlp import MLP -from anemoi.utils.config import DotDict LOGGER = logging.getLogger(__name__) @@ -229,8 +229,8 @@ def __init__( activation=activation, layer_kernels=layer_kernels, ) - - #Linear = layer_kernels.get("Linear", torch.nn.Linear) + + # Linear = layer_kernels.get("Linear", torch.nn.Linear) Linear = layer_kernels["Linear"] self._register_edges(sub_graph, sub_graph_edge_attributes, src_grid_size, dst_grid_size, trainable_size) @@ -245,7 +245,7 @@ def __init__( edge_dim=self.edge_dim, activation=activation, num_chunks=num_chunks, - layer_kernels=layer_kernels + layer_kernels=layer_kernels, ) self.offload_layers(cpu_offload) diff --git a/src/anemoi/models/layers/mlp.py b/src/anemoi/models/layers/mlp.py index e771d6e6..5230b00c 100644 --- a/src/anemoi/models/layers/mlp.py +++ b/src/anemoi/models/layers/mlp.py @@ -11,11 +11,10 @@ import logging import torch +from anemoi.utils.config import DotDict from torch import nn -from anemoi.models.layers.normalization import AutocastLayerNorm from anemoi.models.layers.utils import CheckpointWrapper -from anemoi.utils.config import DotDict LOGGER = logging.getLogger(__name__) @@ -73,7 +72,7 @@ def __init__( Linear = layer_kernels["Linear"] LayerNorm = layer_kernels["LayerNorm"] - + try: act_func = getattr(nn, activation) except AttributeError as ae: diff --git a/src/anemoi/models/layers/normalization.py b/src/anemoi/models/layers/normalization.py index 551d90c8..be400f95 100644 --- a/src/anemoi/models/layers/normalization.py +++ b/src/anemoi/models/layers/normalization.py @@ -9,10 +9,7 @@ from __future__ import annotations -from abc import ABC -from abc import abstractmethod -import torch from torch import nn diff --git a/src/anemoi/models/layers/processor.py b/src/anemoi/models/layers/processor.py index a069ab3f..e892883b 100644 --- a/src/anemoi/models/layers/processor.py +++ b/src/anemoi/models/layers/processor.py @@ -11,6 +11,7 @@ from abc import ABC from typing import Optional +from anemoi.utils.config import DotDict from torch import Tensor from torch import nn from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import offload_wrapper @@ -27,7 +28,6 @@ from anemoi.models.layers.chunk import TransformerProcessorChunk from anemoi.models.layers.graph import TrainableTensor from anemoi.models.layers.mapper import GraphEdgeMixin -from anemoi.utils.config import DotDict class BaseProcessor(nn.Module, ABC): @@ -132,7 +132,7 @@ def __init__( cpu_offload=cpu_offload, num_heads=num_heads, mlp_hidden_ratio=mlp_hidden_ratio, - #layer_kernels=layer_kernels, + # layer_kernels=layer_kernels, ) self.build_layers( diff --git a/src/anemoi/models/layers/utils.py b/src/anemoi/models/layers/utils.py index f35c2b8b..6bec46aa 100644 --- a/src/anemoi/models/layers/utils.py +++ b/src/anemoi/models/layers/utils.py @@ -8,7 +8,6 @@ # nor does it submit to any jurisdiction. -from torch import Tensor from torch import nn from torch.utils.checkpoint import checkpoint diff --git a/src/anemoi/models/models/encoder_processor_decoder.py b/src/anemoi/models/models/encoder_processor_decoder.py index 16dc9bd9..fd5eb139 100644 --- a/src/anemoi/models/models/encoder_processor_decoder.py +++ b/src/anemoi/models/models/encoder_processor_decoder.py @@ -14,13 +14,13 @@ import einops import torch from anemoi.utils.config import DotDict +from hydra.errors import InstantiationException from hydra.utils import instantiate from torch import Tensor from torch import nn from torch.distributed.distributed_c10d import ProcessGroup from torch.utils.checkpoint import checkpoint from torch_geometric.data import HeteroData -from hydra.errors import InstantiationException from anemoi.models.distributed.shapes import get_shape_shards from anemoi.models.layers.graph import NamedNodesAttributes @@ -65,7 +65,7 @@ def __init__( self.node_attributes = NamedNodesAttributes(model_config.model.trainable_parameters.hidden, self._graph_data) input_dim = self.multi_step * self.num_input_channels + self.node_attributes.attr_ndims[self._graph_name_data] - + # read config.model.layer_kernels to get the implementation for certain layers self._load_layer_kernels(model_config) @@ -242,13 +242,13 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) -> def _load_layer_kernels(self, config: DotDict) -> None: # If self.layer_kernels entry is missing from the config, use torch.nn by default - default_kernels=DotDict() + default_kernels = DotDict() default_kernels["Linear"] = DotDict({"_target_": "torch.nn.Linear", "_partial_": True}) default_kernels["LayerNorm"] = DotDict({"_target_": "torch.nn.LayerNorm", "_partial_": True}) - - #self.layer_kernels = config.get("model.layer_kernels", default_kernels) #Always uses default kernels... - self.layer_kernels= config.model.layer_kernels - + + # self.layer_kernels = config.get("model.layer_kernels", default_kernels) #Always uses default kernels... + self.layer_kernels = config.model.layer_kernels + # Loop through all kernels in the layer_kernels config entry and try import them for kernel in self.layer_kernels: kernel_entry = self.layer_kernels[kernel] @@ -260,4 +260,4 @@ def _load_layer_kernels(self, config: DotDict) -> None: ) raise InstantiationException else: - LOGGER.info(f"{kernel} kernel: {kernel_entry}") \ No newline at end of file + LOGGER.info(f"{kernel} kernel: {kernel_entry}") From cb791973f8484000e78e0c88caee1bc5dbbd77ac Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 11 Dec 2024 14:04:05 +0000 Subject: [PATCH 5/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/anemoi/models/layers/normalization.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/anemoi/models/layers/normalization.py b/src/anemoi/models/layers/normalization.py index be400f95..7665c9ea 100644 --- a/src/anemoi/models/layers/normalization.py +++ b/src/anemoi/models/layers/normalization.py @@ -9,7 +9,6 @@ from __future__ import annotations - from torch import nn From ccafbb2d4e09ba57d6c22a7f68addf1e3d70aa1d Mon Sep 17 00:00:00 2001 From: Jakob Schloer Date: Thu, 12 Dec 2024 09:53:15 +0000 Subject: [PATCH 6/9] Set default behavior for layer_kernels. --- pyproject.toml | 1 - src/anemoi/models/layers/block.py | 28 ++++++++++--------- src/anemoi/models/layers/chunk.py | 3 +- src/anemoi/models/layers/processor.py | 2 +- .../models/encoder_processor_decoder.py | 15 +++++----- 5 files changed, 25 insertions(+), 24 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 45d56709..aacce2f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,6 @@ classifiers = [ dynamic = [ "version" ] dependencies = [ # Fix certain dependencies during development - "anemoi-training @ git+https://github.com/ecmwf/anemoi-training.git@25abf5e143a29d5931ccb4ac42a5f83c5cd26851", "anemoi-utils>=0.1.9", "einops>=0.6.1", "hydra-core>=1.3", diff --git a/src/anemoi/models/layers/block.py b/src/anemoi/models/layers/block.py index edfd2f46..cf04d815 100644 --- a/src/anemoi/models/layers/block.py +++ b/src/anemoi/models/layers/block.py @@ -13,6 +13,7 @@ from abc import ABC from abc import abstractmethod from typing import Optional +from hydra.utils import instantiate import einops import torch @@ -80,8 +81,8 @@ def __init__( LOGGER.error("Activation function %s not supported", activation) raise RuntimeError from ae - self.layer_norm1 = layer_kernels["LayerNorm"](num_channels) - self.layer_norm2 = layer_kernels["LayerNorm"](num_channels) + self.layer_norm_attention = layer_kernels["LayerNorm"](normalized_shape=num_channels) + self.layer_norm_mlp = layer_kernels["LayerNorm"](normalized_shape=num_channels) self.attention = MultiHeadSelfAttention( num_heads=num_heads, @@ -100,11 +101,11 @@ def __init__( ) def forward( - self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None, **kwargs + self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None ) -> Tensor: # Need to be out of place for gradient propagation - x = x + self.attention(self.layer_norm1(x), shapes, batch_size, model_comm_group=model_comm_group) - x = x + self.mlp(self.layer_norm2(x)) + x = x + self.attention(self.layer_norm_attention(x), shapes, batch_size, model_comm_group=model_comm_group) + x = x + self.mlp(self.layer_norm_mlp(x)) return x @@ -356,6 +357,7 @@ def __init__( self.conv = GraphTransformerConv(out_channels=self.out_channels_conv) + # Why does the GraphTransformer not have a layer_norm_mlp like the Transformer? self.projection = linear(out_channels, out_channels) try: @@ -365,17 +367,17 @@ def __init__( raise RuntimeError from ae self.node_dst_mlp = nn.Sequential( - layerNorm(out_channels), + layerNorm(normalized_shape=out_channels), linear(out_channels, hidden_dim), act_func(), linear(hidden_dim, out_channels), ) - self.layer_norm1 = layerNorm(in_channels) + self.layer_norm_attention = layerNorm(normalized_shape=in_channels) if self.update_src_nodes: self.node_src_mlp = nn.Sequential( - layerNorm(out_channels), + layerNorm(normlaized_shape=out_channels), linear(out_channels, hidden_dim), act_func(), linear(hidden_dim, out_channels), @@ -497,7 +499,7 @@ def __init__( **kwargs, ) - self.layer_norm2 = layer_kernels["LayerNorm"](in_channels) + self.layer_norm_attention_2 = layer_kernels["LayerNorm"](normalized_shape=in_channels) def forward( self, @@ -512,9 +514,9 @@ def forward( x_skip = x x = ( - self.layer_norm1(x[0]), - self.layer_norm2(x[1]), - ) # Why does this use layer_norm2? And only is a mapper thing? + self.layer_norm_attention(x[0]), + self.layer_norm_attention_2(x[1]), + ) # Why does this use layer_norm_attention_2? And only is a mapper thing? x_r = self.lin_self(x[1]) query = self.lin_query(x[1]) key = self.lin_key(x[0]) @@ -640,7 +642,7 @@ def forward( ): x_skip = x - x = self.layer_norm1(x) + x = self.layer_norm_attention(x) x_r = self.lin_self(x) query = self.lin_query(x) key = self.lin_key(x) diff --git a/src/anemoi/models/layers/chunk.py b/src/anemoi/models/layers/chunk.py index 2c2d0761..179b63b3 100644 --- a/src/anemoi/models/layers/chunk.py +++ b/src/anemoi/models/layers/chunk.py @@ -120,10 +120,9 @@ def forward( shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None, - **kwargs, ) -> Tensor: for i in range(self.num_layers): - x = self.blocks[i](x, shapes, batch_size, model_comm_group=model_comm_group, **kwargs) + x = self.blocks[i](x, shapes, batch_size, model_comm_group=model_comm_group) return (x,) # return tuple for consistency with other processors diff --git a/src/anemoi/models/layers/processor.py b/src/anemoi/models/layers/processor.py index e77be034..a90448a6 100644 --- a/src/anemoi/models/layers/processor.py +++ b/src/anemoi/models/layers/processor.py @@ -164,7 +164,7 @@ def forward( model_comm_group.size() == 1 or batch_size == 1 ), "Only batch size of 1 is supported when model is sharded accross GPUs" - (x,) = self.run_layers((x,), shape_nodes, batch_size, model_comm_group, **kwargs) + (x,) = self.run_layers((x,), shape_nodes, batch_size, model_comm_group) return x diff --git a/src/anemoi/models/models/encoder_processor_decoder.py b/src/anemoi/models/models/encoder_processor_decoder.py index fd5eb139..4c7a4406 100644 --- a/src/anemoi/models/models/encoder_processor_decoder.py +++ b/src/anemoi/models/models/encoder_processor_decoder.py @@ -16,6 +16,7 @@ from anemoi.utils.config import DotDict from hydra.errors import InstantiationException from hydra.utils import instantiate +from omegaconf import OmegaConf from torch import Tensor from torch import nn from torch.distributed.distributed_c10d import ProcessGroup @@ -241,13 +242,13 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) -> def _load_layer_kernels(self, config: DotDict) -> None: - # If self.layer_kernels entry is missing from the config, use torch.nn by default - default_kernels = DotDict() - default_kernels["Linear"] = DotDict({"_target_": "torch.nn.Linear", "_partial_": True}) - default_kernels["LayerNorm"] = DotDict({"_target_": "torch.nn.LayerNorm", "_partial_": True}) - - # self.layer_kernels = config.get("model.layer_kernels", default_kernels) #Always uses default kernels... - self.layer_kernels = config.model.layer_kernels + # If self.layer_kernels entry is missing from the config, use torch.nn kernels + default_kernels = { + "Linear": {"_target_": "torch.nn.Linear", "_partial_": True}, + "LayerNorm": {"_target_": "torch.nn.LayerNorm", "_partial_": True}, + } + user_kernel = OmegaConf.select(config, "model.layer_kernels") + self.layer_kernels = {**default_kernels, **user_kernel} # Loop through all kernels in the layer_kernels config entry and try import them for kernel in self.layer_kernels: From b9c1ff9f7ce5a211fc4cfb83222a8efe0e9aeddf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 12 Dec 2024 10:45:58 +0000 Subject: [PATCH 7/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/anemoi/models/layers/block.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/anemoi/models/layers/block.py b/src/anemoi/models/layers/block.py index cf04d815..a4ef1f6f 100644 --- a/src/anemoi/models/layers/block.py +++ b/src/anemoi/models/layers/block.py @@ -13,7 +13,6 @@ from abc import ABC from abc import abstractmethod from typing import Optional -from hydra.utils import instantiate import einops import torch From 4e350332e3a017a2b184dcc7cb4067078a4c403d Mon Sep 17 00:00:00 2001 From: Jakob Schloer Date: Wed, 18 Dec 2024 13:48:42 +0000 Subject: [PATCH 8/9] Add flexible layer kernels to GNN and GraphTransformer --- src/anemoi/models/layers/block.py | 18 +++++++++++------- src/anemoi/models/layers/chunk.py | 15 ++++++++------- src/anemoi/models/layers/conv.py | 6 ++++++ src/anemoi/models/layers/mapper.py | 21 ++++++++++++++++++++- src/anemoi/models/layers/mlp.py | 2 +- src/anemoi/models/layers/processor.py | 17 +++++++++-------- 6 files changed, 55 insertions(+), 24 deletions(-) diff --git a/src/anemoi/models/layers/block.py b/src/anemoi/models/layers/block.py index 988cef53..e06302f3 100644 --- a/src/anemoi/models/layers/block.py +++ b/src/anemoi/models/layers/block.py @@ -159,6 +159,7 @@ def __init__( self.conv = GraphConv( in_channels=in_channels, out_channels=out_channels, + layer_kernels=layer_kernels, mlp_extra_layers=mlp_extra_layers, activation=activation, ) @@ -192,11 +193,11 @@ def __ini__( self, in_channels=in_channels, out_channels=out_channels, + layer_kernels=layer_kernels, mlp_extra_layers=mlp_extra_layers, activation=activation, update_src_nodes=update_src_nodes, num_chunks=num_chunks, - layer_kernels=layer_kernels, **kwargs, ) @@ -250,11 +251,11 @@ def __ini__( self, in_channels=in_channels, out_channels=out_channels, + layer_kernels=layer_kernels, mlp_extra_layers=mlp_extra_layers, activation=activation, update_src_nodes=update_src_nodes, num_chunks=num_chunks, - layer_kernels=layer_kernels, **kwargs, ) @@ -365,18 +366,19 @@ def __init__( LOGGER.error("Activation function %s not supported", activation) raise RuntimeError from ae + self.layer_norm_attention = layerNorm(normalized_shape=in_channels) + self.layer_norm_mlp = layerNorm(normalized_shape=out_channels) + self.node_dst_mlp = nn.Sequential( - layerNorm(normalized_shape=out_channels), + self.layer_norm_mlp, linear(out_channels, hidden_dim), act_func(), linear(hidden_dim, out_channels), ) - self.layer_norm_attention = layerNorm(normalized_shape=in_channels) - if self.update_src_nodes: self.node_src_mlp = nn.Sequential( - layerNorm(normlaized_shape=out_channels), + self.layer_norm_mlp, linear(out_channels, hidden_dim), act_func(), linear(hidden_dim, out_channels), @@ -516,6 +518,7 @@ def forward( self.layer_norm_attention(x[0]), self.layer_norm_attention_2(x[1]), ) # Why does this use layer_norm_attention_2? And only is a mapper thing? + x_r = self.lin_self(x[1]) query = self.lin_query(x[1]) key = self.lin_key(x[0]) @@ -624,7 +627,8 @@ def __init__( bias=bias, activation=activation, num_chunks=num_chunks, - update_src_nodes=update_src_nodes**kwargs, + update_src_nodes=update_src_nodes, + **kwargs, ) def forward( diff --git a/src/anemoi/models/layers/chunk.py b/src/anemoi/models/layers/chunk.py index 179b63b3..ac0a0bce 100644 --- a/src/anemoi/models/layers/chunk.py +++ b/src/anemoi/models/layers/chunk.py @@ -72,8 +72,8 @@ def __init__( self, num_channels: int, num_layers: int, - window_size: int, layer_kernels: DotDict, + window_size: int, num_heads: int = 16, mlp_hidden_ratio: int = 4, activation: str = "GELU", @@ -87,11 +87,11 @@ def __init__( Number of channels num_layers : int Number of layers - window_size: int, - 1/2 size of shifted window for attention computation layer_kernels : DotDict A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" Defined in config/models/.yaml + window_size: int, + 1/2 size of shifted window for attention computation num_heads: int Number of heads to use, default 16 mlp_hidden_ratio: int @@ -110,8 +110,8 @@ def __init__( num_heads=num_heads, activation=activation, window_size=window_size, - dropout_p=dropout_p, layer_kernels=layer_kernels, + dropout_p=dropout_p, ) def forward( @@ -165,6 +165,7 @@ def __init__( in_features=edge_dim, hidden_dim=num_channels, out_features=num_channels, + layer_kernels=layer_kernels, n_extra_layers=mlp_extra_layers, activation=activation, ) @@ -175,9 +176,9 @@ def __init__( GraphConvProcessorBlock, num_channels, num_channels, + layer_kernels=layer_kernels, mlp_extra_layers=mlp_extra_layers, activation=activation, - layer_kernels=layer_kernels, ) def forward( @@ -239,10 +240,10 @@ def __init__( in_channels=num_channels, hidden_dim=mlp_hidden_ratio * num_channels, out_channels=num_channels, - num_heads=num_heads, edge_dim=edge_dim, - activation=activation, + num_heads=num_heads, layer_kernels=layer_kernels, + activation=activation, ) def forward( diff --git a/src/anemoi/models/layers/conv.py b/src/anemoi/models/layers/conv.py index 6b3a767e..5c354502 100644 --- a/src/anemoi/models/layers/conv.py +++ b/src/anemoi/models/layers/conv.py @@ -11,6 +11,7 @@ from typing import Optional import torch +from anemoi.utils.config import DotDict from torch import Tensor from torch.nn.functional import dropout from torch_geometric.nn.conv import MessagePassing @@ -31,6 +32,7 @@ def __init__( self, in_channels: int, out_channels: int, + layer_kernels: DotDict, mlp_extra_layers: int = 0, activation: str = "SiLU", **kwargs, @@ -43,6 +45,9 @@ def __init__( Number of input channels. out_channels : int Number of output channels. + layer_kernels : DotDict + A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml mlp_extra_layers : int, optional Extra layers in MLP, by default 0 activation : str, optional @@ -54,6 +59,7 @@ def __init__( 3 * in_channels, out_channels, out_channels, + layer_kernels=layer_kernels, n_extra_layers=mlp_extra_layers, activation=activation, ) diff --git a/src/anemoi/models/layers/mapper.py b/src/anemoi/models/layers/mapper.py index c4537647..38c0c152 100644 --- a/src/anemoi/models/layers/mapper.py +++ b/src/anemoi/models/layers/mapper.py @@ -227,7 +227,6 @@ def __init__( num_chunks=num_chunks, cpu_offload=cpu_offload, activation=activation, - layer_kernels=layer_kernels, ) # Linear = layer_kernels.get("Linear", torch.nn.Linear) @@ -453,6 +452,7 @@ def __init__( sub_graph_edge_attributes: Optional[list[str]] = None, src_grid_size: int = 0, dst_grid_size: int = 0, + layer_kernels: DotDict = None, ) -> None: """Initialize GNNBaseMapper. @@ -476,6 +476,9 @@ def __init__( Whether to offload processing to CPU, by default False out_channels_dst : Optional[int], optional Output channels of the destination node, by default None + layer_kernels : DotDict + A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml """ super().__init__( in_channels_src, @@ -493,6 +496,7 @@ def __init__( in_features=self.edge_dim, hidden_dim=hidden_dim, out_features=hidden_dim, + layer_kernels=layer_kernels, n_extra_layers=mlp_extra_layers, activation=activation, ) @@ -557,6 +561,7 @@ def __init__( sub_graph_edge_attributes: Optional[list[str]] = None, src_grid_size: int = 0, dst_grid_size: int = 0, + layer_kernels: DotDict = None, ) -> None: """Initialize GNNForwardMapper. @@ -580,6 +585,9 @@ def __init__( Whether to offload processing to CPU, by default False out_channels_dst : Optional[int], optional Output channels of the destination node, by default None + layer_kernels : DotDict + A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml """ super().__init__( in_channels_src, @@ -595,11 +603,13 @@ def __init__( sub_graph_edge_attributes=sub_graph_edge_attributes, src_grid_size=src_grid_size, dst_grid_size=dst_grid_size, + layer_kernels=layer_kernels, ) self.proc = GraphConvMapperBlock( hidden_dim, hidden_dim, + layer_kernels=layer_kernels, mlp_extra_layers=mlp_extra_layers, activation=activation, update_src_nodes=True, @@ -612,6 +622,7 @@ def __init__( in_features=in_channels_src, hidden_dim=hidden_dim, out_features=hidden_dim, + layer_kernels=layer_kernels, n_extra_layers=mlp_extra_layers, activation=activation, ) @@ -620,6 +631,7 @@ def __init__( in_features=in_channels_dst, hidden_dim=hidden_dim, out_features=hidden_dim, + layer_kernels=layer_kernels, n_extra_layers=mlp_extra_layers, activation=activation, ) @@ -643,6 +655,7 @@ def __init__( sub_graph_edge_attributes: Optional[list[str]] = None, src_grid_size: int = 0, dst_grid_size: int = 0, + layer_kernels: DotDict = None, ) -> None: """Initialize GNNBackwardMapper. @@ -666,6 +679,9 @@ def __init__( Whether to offload processing to CPU, by default False out_channels_dst : Optional[int], optional Output channels of the destination node, by default None + layer_kernels : DotDict + A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear" + Defined in config/models/.yaml """ super().__init__( in_channels_src, @@ -681,11 +697,13 @@ def __init__( sub_graph_edge_attributes=sub_graph_edge_attributes, src_grid_size=src_grid_size, dst_grid_size=dst_grid_size, + layer_kernels=layer_kernels, ) self.proc = GraphConvMapperBlock( hidden_dim, hidden_dim, + layer_kernels=layer_kernels, mlp_extra_layers=mlp_extra_layers, activation=activation, update_src_nodes=False, @@ -698,6 +716,7 @@ def __init__( in_features=self.hidden_dim, hidden_dim=self.hidden_dim, out_features=self.out_channels_dst, + layer_kernels=layer_kernels, n_extra_layers=mlp_extra_layers, activation=self.activation, layer_norm=False, diff --git a/src/anemoi/models/layers/mlp.py b/src/anemoi/models/layers/mlp.py index 5230b00c..af2ce74c 100644 --- a/src/anemoi/models/layers/mlp.py +++ b/src/anemoi/models/layers/mlp.py @@ -89,7 +89,7 @@ def __init__( mlp1.append(act_func()) if layer_norm: - mlp1.append(LayerNorm(out_features).as_type(out_features)) + mlp1.append(LayerNorm(normalized_shape=out_features)) self.model = CheckpointWrapper(mlp1) if checkpoints else mlp1 diff --git a/src/anemoi/models/layers/processor.py b/src/anemoi/models/layers/processor.py index a90448a6..5b77fdc0 100644 --- a/src/anemoi/models/layers/processor.py +++ b/src/anemoi/models/layers/processor.py @@ -124,27 +124,26 @@ def __init__( Dropout probability used for multi-head self attention, default 0.0 """ super().__init__( - num_channels=num_channels, num_layers=num_layers, + num_channels=num_channels, window_size=window_size, num_chunks=num_chunks, activation=activation, cpu_offload=cpu_offload, num_heads=num_heads, mlp_hidden_ratio=mlp_hidden_ratio, - # layer_kernels=layer_kernels, ) self.build_layers( TransformerProcessorChunk, num_channels=num_channels, + num_layers=self.chunk_size, + layer_kernels=layer_kernels, mlp_hidden_ratio=mlp_hidden_ratio, num_heads=num_heads, - num_layers=self.chunk_size, window_size=window_size, activation=activation, dropout_p=dropout_p, - layer_kernels=layer_kernels, ) self.offload_layers(cpu_offload) @@ -175,6 +174,7 @@ class GNNProcessor(GraphEdgeMixin, BaseProcessor): def __init__( self, num_layers: int, + layer_kernels: DotDict, *args, trainable_size: int = 8, num_channels: int = 128, @@ -219,16 +219,15 @@ def __init__( self.trainable = TrainableTensor(trainable_size=trainable_size, tensor_size=self.edge_attr.shape[0]) kwargs = { - "num_layers": self.chunk_size, "mlp_extra_layers": mlp_extra_layers, "activation": activation, "edge_dim": None, } - self.build_layers(GNNProcessorChunk, num_channels, **kwargs) + self.build_layers(GNNProcessorChunk, num_channels, self.chunk_size, layer_kernels, **kwargs) kwargs["edge_dim"] = self.edge_dim # Edge dim for first layer - self.proc[0] = GNNProcessorChunk(num_channels, **kwargs) + self.proc[0] = GNNProcessorChunk(num_channels, self.chunk_size, layer_kernels, **kwargs) self.offload_layers(cpu_offload) @@ -263,6 +262,7 @@ class GraphTransformerProcessor(GraphEdgeMixin, BaseProcessor): def __init__( self, num_layers: int, + layer_kernels: DotDict, trainable_size: int = 8, num_channels: int = 128, num_chunks: int = 2, @@ -296,8 +296,8 @@ def __init__( Whether to offload processing to CPU, by default False """ super().__init__( - num_layers=num_layers, num_channels=num_channels, + num_layers=num_layers, num_chunks=num_chunks, activation=activation, cpu_offload=cpu_offload, @@ -313,6 +313,7 @@ def __init__( GraphTransformerProcessorChunk, num_channels=num_channels, num_layers=self.chunk_size, + layer_kernels=layer_kernels, num_heads=num_heads, mlp_hidden_ratio=mlp_hidden_ratio, activation=activation, From 84dfdfc3d58243f5df48e341e5f7e9a21dc02ff3 Mon Sep 17 00:00:00 2001 From: Jakob Schloer Date: Wed, 18 Dec 2024 13:51:55 +0000 Subject: [PATCH 9/9] Add type annotation. --- src/anemoi/models/layers/normalization.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/anemoi/models/layers/normalization.py b/src/anemoi/models/layers/normalization.py index 7665c9ea..d3193c4a 100644 --- a/src/anemoi/models/layers/normalization.py +++ b/src/anemoi/models/layers/normalization.py @@ -9,6 +9,7 @@ from __future__ import annotations +from torch import Tensor from torch import nn