diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bbc225df..8f820a8d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -40,7 +40,7 @@ repos: - --force-single-line-imports - --profile black - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.7.2 + rev: v0.8.1 hooks: - id: ruff args: @@ -64,7 +64,7 @@ repos: hooks: - id: pyproject-fmt - repo: https://github.com/jshwi/docsig # Check docstrings against function sig - rev: v0.64.0 + rev: v0.65.0 hooks: - id: docsig args: @@ -74,6 +74,5 @@ repos: - --check-protected # Check protected methods - --check-class # Check class docstrings - --disable=E113 # Disable empty docstrings - - --summary # Print a summary ci: autoupdate_schedule: monthly diff --git a/CHANGELOG.md b/CHANGELOG.md index 125679af..801cd05f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,12 +8,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Please add your functional changes to the appropriate section in the PR. Keep it human-readable, your future self will thank you! -## [Unreleased](https://github.com/ecmwf/anemoi-models/compare/0.3.0...HEAD) +## [Unreleased](https://github.com/ecmwf/anemoi-models/compare/0.4.0...HEAD) -- Add synchronisation workflow +### Added + +- New AnemoiModelEncProcDecHierarchical class available in models [#37](https://github.com/ecmwf/anemoi-models/pull/37) +- Mask NaN values in training loss function [#56](https://github.com/ecmwf/anemoi-models/pull/56) +- Added dynamic NaN masking for the imputer class with two new classes: DynamicInputImputer, DynamicConstantImputer [#89](https://github.com/ecmwf/anemoi-models/pull/89) +- Reduced memory usage when using chunking in the mapper [#84](https://github.com/ecmwf/anemoi-models/pull/84) +- Added `supporting_arrays` argument, which contains arrays to store in checkpoints. 
[#97](https://github.com/ecmwf/anemoi-models/pull/97) + +## [0.4.0](https://github.com/ecmwf/anemoi-models/compare/0.3.0...0.4.0) - Improvements to Model Design ### Added +- Add synchronisation workflow [#60](https://github.com/ecmwf/anemoi-models/pull/60) - Add anemoi-transform link to documentation - Codeowners file - Pygrep precommit hooks @@ -22,7 +31,6 @@ Keep it human-readable, your future self will thank you! - configurabilty of the dropout probability in the the MultiHeadSelfAttention module - Variable Bounding as configurable model layers [#13](https://github.com/ecmwf/anemoi-models/issues/13) - GraphTransformerMapperBlock chunking to reduce memory usage during inference [#46](https://github.com/ecmwf/anemoi-models/pull/46) -- Mask NaN values in training loss function [#271](https://github.com/ecmwf-lab/aifs-mono/issues/271) - New `NamedNodesAttributes` class to handle node attributes in a more flexible way [#64](https://github.com/ecmwf/anemoi-models/pull/64) - Contributors file [#69](https://github.com/ecmwf/anemoi-models/pull/69) - Add sequence sharding strategy for TransformerProcessor [#90](https://github.com/ecmwf/anemoi-models/pull/90) diff --git a/docs/modules/models.rst b/docs/modules/models.rst index 392a9d61..416257df 100644 --- a/docs/modules/models.rst +++ b/docs/modules/models.rst @@ -13,3 +13,29 @@ encoder, processor, and decoder. :members: :no-undoc-members: :show-inheritance: + +********************************************** + Encoder Hierarchical processor Decoder Model +********************************************** + +This model extends the standard encoder-processor-decoder architecture +by introducing a **hierarchical processor**. + +Compared to the AnemoiModelEncProcDec model, this architecture requires +a predefined list of hidden nodes, `[hidden_1, ..., hidden_n]`. These +nodes must be sorted to match the expected flow of information `data -> +hidden_1 -> ... -> hidden_n -> ... -> hidden_1 -> data`. 
+ +A new argument is added to the configuration file: +`enable_hierarchical_level_processing`. This argument determines whether +a processor is added at each hierarchy level or only at the final level. + +By default, the number of channels for the mappers is defined as `2^n * +config.num_channels`, where `n` represents the hierarchy level. This +scaling ensures that the processing capacity grows proportionally with +the depth of the hierarchy, enabling efficient handling of data. + +.. automodule:: anemoi.models.models.hierarchical + :members: + :no-undoc-members: + :show-inheritance: diff --git a/src/anemoi/models/interface/__init__.py b/src/anemoi/models/interface/__init__.py index 25b7852a..261dec29 100644 --- a/src/anemoi/models/interface/__init__.py +++ b/src/anemoi/models/interface/__init__.py @@ -37,6 +37,8 @@ class AnemoiModelInterface(torch.nn.Module): Statistics for the data. metadata : dict Metadata for the model. + supporting_arrays : dict + Numpy arrays to store in the checkpoint. data_indices : dict Indices for the data. 
pre_processors : Processors @@ -48,7 +50,14 @@ class AnemoiModelInterface(torch.nn.Module): """ def __init__( - self, *, config: DotDict, graph_data: HeteroData, statistics: dict, data_indices: dict, metadata: dict + self, + *, + config: DotDict, + graph_data: HeteroData, + statistics: dict, + data_indices: dict, + metadata: dict, + supporting_arrays: dict = None, ) -> None: super().__init__() self.config = config @@ -57,6 +66,7 @@ def __init__( self.graph_data = graph_data self.statistics = statistics self.metadata = metadata + self.supporting_arrays = supporting_arrays if supporting_arrays is not None else {} self.data_indices = data_indices self._build_model() diff --git a/src/anemoi/models/layers/block.py b/src/anemoi/models/layers/block.py index 90f9979d..d68f2e0e 100644 --- a/src/anemoi/models/layers/block.py +++ b/src/anemoi/models/layers/block.py @@ -514,8 +514,9 @@ def forward( edge_attr_list, edge_index_list = sort_edges_1hop_chunks( num_nodes=size, edge_attr=edges, edge_index=edge_index, num_chunks=num_chunks ) + out = torch.zeros((x[1].shape[0], self.num_heads, self.out_channels_conv), device=x[1].device) for i in range(num_chunks): - out1 = self.conv( + out += self.conv( query=query, key=key, value=value, @@ -523,9 +524,6 @@ def forward( edge_index=edge_index_list[i], size=size, ) - if i == 0: - out = torch.zeros_like(out1, device=out1.device) - out = out + out1 else: out = self.conv(query=query, key=key, value=value, edge_attr=edges, edge_index=edge_index, size=size) diff --git a/src/anemoi/models/layers/processor.py b/src/anemoi/models/layers/processor.py index 6d729c7b..ffa91fcd 100644 --- a/src/anemoi/models/layers/processor.py +++ b/src/anemoi/models/layers/processor.py @@ -325,6 +325,7 @@ def forward( *args, **kwargs, ) -> Tensor: + shape_nodes = change_channels_in_shape(shard_shapes, self.num_channels) edge_attr = self.trainable(self.edge_attr, batch_size) diff --git a/src/anemoi/models/models/__init__.py b/src/anemoi/models/models/__init__.py 
index c167afa2..2072f12f 100644 --- a/src/anemoi/models/models/__init__.py +++ b/src/anemoi/models/models/__init__.py @@ -6,3 +6,8 @@ # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. + +from .encoder_processor_decoder import AnemoiModelEncProcDec +from .hierarchical import AnemoiModelEncProcDecHierarchical + +__all__ = ["AnemoiModelEncProcDec", "AnemoiModelEncProcDecHierarchical"] diff --git a/src/anemoi/models/models/hierarchical.py b/src/anemoi/models/models/hierarchical.py new file mode 100644 index 00000000..94e82581 --- /dev/null +++ b/src/anemoi/models/models/hierarchical.py @@ -0,0 +1,308 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+# + +import logging +from typing import Optional + +import einops +import torch +from anemoi.utils.config import DotDict +from hydra.utils import instantiate +from torch import Tensor +from torch import nn +from torch.distributed.distributed_c10d import ProcessGroup +from torch_geometric.data import HeteroData + +from anemoi.models.distributed.shapes import get_shape_shards +from anemoi.models.layers.graph import NamedNodesAttributes +from anemoi.models.layers.graph import TrainableTensor +from anemoi.models.models import AnemoiModelEncProcDec + +LOGGER = logging.getLogger(__name__) + + +class AnemoiModelEncProcDecHierarchical(AnemoiModelEncProcDec): + """Message passing hierarchical graph neural network.""" + + def __init__( + self, + *, + model_config: DotDict, + data_indices: dict, + graph_data: HeteroData, + ) -> None: + """Initializes the graph neural network. + + Parameters + ---------- + model_config : DotDict + Job configuration + data_indices : dict + Data indices + graph_data : HeteroData + Graph definition + """ + nn.Module.__init__(self) + + self._graph_data = graph_data + self._graph_name_data = model_config.graph.data + self._graph_hidden_names = model_config.graph.hidden + self.num_hidden = len(self._graph_hidden_names) + + # Unpack config for hierarchical graph + self.level_process = model_config.model.enable_hierarchical_level_processing + + # hidden_dims is the dimensionality of features at each depth + self.hidden_dims = { + hidden: model_config.model.num_channels * (2**i) for i, hidden in enumerate(self._graph_hidden_names) + } + + self._calculate_shapes_and_indices(data_indices) + self._assert_matching_indices(data_indices) + self.data_indices = data_indices + + self.multi_step = model_config.training.multistep_input + + # self.node_attributes = {hidden_name: NamedNodesAttributes(model_config.model.trainable_parameters[hidden_name], self._graph_data) + # for hidden_name in self._graph_hidden_names} + self.node_attributes = 
NamedNodesAttributes(model_config.model.trainable_parameters.hidden, self._graph_data) + + input_dim = self.multi_step * self.num_input_channels + self.node_attributes.attr_ndims[self._graph_name_data] + + # Encoder data -> hidden + self.encoder = instantiate( + model_config.model.encoder, + in_channels_src=input_dim, + in_channels_dst=self.node_attributes.attr_ndims[self._graph_hidden_names[0]], + hidden_dim=self.hidden_dims[self._graph_hidden_names[0]], + sub_graph=self._graph_data[(self._graph_name_data, "to", self._graph_hidden_names[0])], + src_grid_size=self.node_attributes.num_nodes[self._graph_name_data], + dst_grid_size=self.node_attributes.num_nodes[self._graph_hidden_names[0]], + ) + + # Level processors + if self.level_process: + self.down_level_processor = nn.ModuleDict() + self.up_level_processor = nn.ModuleDict() + + for i in range(0, self.num_hidden): + nodes_names = self._graph_hidden_names[i] + + self.down_level_processor[nodes_names] = instantiate( + model_config.model.processor, + num_channels=self.hidden_dims[nodes_names], + sub_graph=self._graph_data[(nodes_names, "to", nodes_names)], + src_grid_size=self.node_attributes.num_nodes[nodes_names], + dst_grid_size=self.node_attributes.num_nodes[nodes_names], + num_layers=model_config.model.level_process_num_layers, + ) + + self.up_level_processor[nodes_names] = instantiate( + model_config.model.processor, + num_channels=self.hidden_dims[nodes_names], + sub_graph=self._graph_data[(nodes_names, "to", nodes_names)], + src_grid_size=self.node_attributes.num_nodes[nodes_names], + dst_grid_size=self.node_attributes.num_nodes[nodes_names], + num_layers=model_config.model.level_process_num_layers, + ) + + # delete final upscale (does not exist): |->|->|<-|<-| + del self.up_level_processor[nodes_names] + + # Downscale + self.downscale = nn.ModuleDict() + + for i in range(0, self.num_hidden - 1): + src_nodes_name = self._graph_hidden_names[i] + dst_nodes_name = self._graph_hidden_names[i + 1] + + 
self.downscale[src_nodes_name] = instantiate( + model_config.model.encoder, + in_channels_src=self.hidden_dims[src_nodes_name], + in_channels_dst=self.node_attributes.attr_ndims[dst_nodes_name], + hidden_dim=self.hidden_dims[dst_nodes_name], + sub_graph=self._graph_data[(src_nodes_name, "to", dst_nodes_name)], + src_grid_size=self.node_attributes.num_nodes[src_nodes_name], + dst_grid_size=self.node_attributes.num_nodes[dst_nodes_name], + ) + + # Upscale + self.upscale = nn.ModuleDict() + + for i in range(1, self.num_hidden): + src_nodes_name = self._graph_hidden_names[i] + dst_nodes_name = self._graph_hidden_names[i - 1] + + self.upscale[src_nodes_name] = instantiate( + model_config.model.decoder, + in_channels_src=self.hidden_dims[src_nodes_name], + in_channels_dst=self.hidden_dims[dst_nodes_name], + hidden_dim=self.hidden_dims[src_nodes_name], + out_channels_dst=self.hidden_dims[dst_nodes_name], + sub_graph=self._graph_data[(src_nodes_name, "to", dst_nodes_name)], + src_grid_size=self.node_attributes.num_nodes[src_nodes_name], + dst_grid_size=self.node_attributes.num_nodes[dst_nodes_name], + ) + + # Decoder hidden -> data + self.decoder = instantiate( + model_config.model.decoder, + in_channels_src=self.hidden_dims[self._graph_hidden_names[0]], + in_channels_dst=input_dim, + hidden_dim=self.hidden_dims[self._graph_hidden_names[0]], + out_channels_dst=self.num_output_channels, + sub_graph=self._graph_data[(self._graph_hidden_names[0], "to", self._graph_name_data)], + src_grid_size=self.node_attributes.num_nodes[self._graph_hidden_names[0]], + dst_grid_size=self.node_attributes.num_nodes[self._graph_name_data], + ) + + # Instantiation of model output bounding functions (e.g., to ensure outputs like TP are positive definite) + self.boundings = nn.ModuleList( + [ + instantiate(cfg, name_to_index=self.data_indices.internal_model.output.name_to_index) + for cfg in getattr(model_config.model, "bounding", []) + ] + ) + + def _create_trainable_attributes(self) -> None: + 
"""Create all trainable attributes.""" + self.trainable_data = TrainableTensor(trainable_size=self.trainable_data_size, tensor_size=self._data_grid_size) + self.trainable_hidden = nn.ModuleDict() + + for hidden in self._graph_hidden_names: + self.trainable_hidden[hidden] = TrainableTensor( + trainable_size=self.trainable_hidden_size, tensor_size=self._hidden_grid_sizes[hidden] + ) + + def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) -> Tensor: + batch_size = x.shape[0] + ensemble_size = x.shape[2] + + # add data positional info (lat/lon) + x_trainable_data = torch.cat( + ( + einops.rearrange(x, "batch time ensemble grid vars -> (batch ensemble grid) (time vars)"), + self.node_attributes(self._graph_name_data, batch_size=batch_size), + ), + dim=-1, # feature dimension + ) + + # Get all trainable parameters for the hidden layers -> initialisation of each hidden, which becomes trainable bias + x_trainable_hiddens = {} + for hidden in self._graph_hidden_names: + x_trainable_hiddens[hidden] = self.node_attributes(hidden, batch_size=batch_size) + + # Get data and hidden shapes for sharding + shard_shapes_data = get_shape_shards(x_trainable_data, 0, model_comm_group) + shard_shapes_hiddens = {} + for hidden, x_latent in x_trainable_hiddens.items(): + shard_shapes_hiddens[hidden] = get_shape_shards(x_latent, 0, model_comm_group) + + # Run encoder + x_data_latent, curr_latent = self._run_mapper( + self.encoder, + (x_trainable_data, x_trainable_hiddens[self._graph_hidden_names[0]]), + batch_size=batch_size, + shard_shapes=(shard_shapes_data, shard_shapes_hiddens[self._graph_hidden_names[0]]), + model_comm_group=model_comm_group, + ) + + # Run processor + x_encoded_latents = {} + x_skip = {} + + ## Downscale + for i in range(0, self.num_hidden - 1): + src_hidden_name = self._graph_hidden_names[i] + dst_hidden_name = self._graph_hidden_names[i + 1] + + # Processing at same level + if self.level_process: + curr_latent = 
self.down_level_processor[src_hidden_name]( + curr_latent, + batch_size=batch_size, + shard_shapes=shard_shapes_hiddens[src_hidden_name], + model_comm_group=model_comm_group, + ) + + # store latents for skip connections + x_skip[src_hidden_name] = curr_latent + + # Encode to next hidden level + x_encoded_latents[src_hidden_name], curr_latent = self._run_mapper( + self.downscale[src_hidden_name], + (curr_latent, x_trainable_hiddens[dst_hidden_name]), + batch_size=batch_size, + shard_shapes=(shard_shapes_hiddens[src_hidden_name], shard_shapes_hiddens[dst_hidden_name]), + model_comm_group=model_comm_group, + ) + + # Processing hidden-most level + if self.level_process: + curr_latent = self.down_level_processor[dst_hidden_name]( + curr_latent, + batch_size=batch_size, + shard_shapes=shard_shapes_hiddens[dst_hidden_name], + model_comm_group=model_comm_group, + ) + + ## Upscale + for i in range(self.num_hidden - 1, 0, -1): + src_hidden_name = self._graph_hidden_names[i] + dst_hidden_name = self._graph_hidden_names[i - 1] + + # Process to next level + curr_latent = self._run_mapper( + self.upscale[src_hidden_name], + (curr_latent, x_encoded_latents[dst_hidden_name]), + batch_size=batch_size, + shard_shapes=(shard_shapes_hiddens[src_hidden_name], shard_shapes_hiddens[dst_hidden_name]), + model_comm_group=model_comm_group, + ) + + # Add skip connections + curr_latent = curr_latent + x_skip[dst_hidden_name] + + # Processing at same level + if self.level_process: + curr_latent = self.up_level_processor[dst_hidden_name]( + curr_latent, + batch_size=batch_size, + shard_shapes=shard_shapes_hiddens[dst_hidden_name], + model_comm_group=model_comm_group, + ) + + # Run decoder + x_out = self._run_mapper( + self.decoder, + (curr_latent, x_data_latent), + batch_size=batch_size, + shard_shapes=(shard_shapes_hiddens[self._graph_hidden_names[0]], shard_shapes_data), + model_comm_group=model_comm_group, + ) + + x_out = ( + einops.rearrange( + x_out, + "(batch ensemble grid) vars -> batch 
ensemble grid vars", + batch=batch_size, + ensemble=ensemble_size, + ) + .to(dtype=x.dtype) + .clone() + ) + + # residual connection (just for the prognostic variables) + x_out[..., self._internal_output_idx] += x[:, -1, :, :, self._internal_input_idx] + + for bounding in self.boundings: + # bounding performed in the order specified in the config file + x_out = bounding(x_out) + + return x_out diff --git a/src/anemoi/models/preprocessing/imputer.py b/src/anemoi/models/preprocessing/imputer.py index 5090c9f0..4bd3c0ae 100644 --- a/src/anemoi/models/preprocessing/imputer.py +++ b/src/anemoi/models/preprocessing/imputer.py @@ -9,6 +9,7 @@ import logging +import warnings from abc import ABC from typing import Optional @@ -106,6 +107,12 @@ def _expand_subset_mask(self, x: torch.Tensor, idx_src: int) -> torch.Tensor: """Expand the subset of the mask to the correct shape.""" return self.nan_locations[:, idx_src].expand(*x.shape[:-2], -1) + def get_nans(self, x: torch.Tensor) -> torch.Tensor: + """get NaN mask from data""" + # The mask is only saved for the last two dimensions (grid, variable) + idx = [slice(0, 1)] * (x.ndim - 2) + [slice(None), slice(None)] + return torch.isnan(x[idx].squeeze()) + def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor: """Impute missing values in the input tensor.""" if not in_place: @@ -113,9 +120,9 @@ def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor: # Initialize nan mask once if self.nan_locations is None: - # The mask is only saved for the last two dimensions (grid, variable) - idx = [slice(0, 1)] * (x.ndim - 2) + [slice(None), slice(None)] - self.nan_locations = torch.isnan(x[idx].squeeze()) + + # Get NaN locations + self.nan_locations = self.get_nans(x) # Initialize training loss mask to weigh imputed values with zeroes once self.loss_mask_training = torch.ones( @@ -222,3 +229,77 @@ def __init__( self._create_imputation_indices() self._validate_indices() + + +class DynamicMixin: + 
"""Mixin to add dynamic imputation behavior.""" + + def get_nans(self, x: torch.Tensor) -> torch.Tensor: + """Override to calculate NaN locations dynamically.""" + return torch.isnan(x) + + def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor: + """Impute missing values in the input tensor.""" + if not in_place: + x = x.clone() + + # Initilialize mask every time + nan_locations = self.get_nans(x) + + self.loss_mask_training = torch.ones( + (x.shape[-2], len(self.data_indices.model.output.name_to_index)), device=x.device + ) + + # Choose correct index based on number of variables + if x.shape[-1] == self.num_training_input_vars: + index = self.index_training_input + elif x.shape[-1] == self.num_inference_input_vars: + index = self.index_inference_input + else: + raise ValueError( + f"Input tensor ({x.shape[-1]}) does not match the training " + f"({self.num_training_input_vars}) or inference shape ({self.num_inference_input_vars})", + ) + + # Replace values + for idx_src, (idx_dst, value) in zip(self.index_training_input, zip(index, self.replacement)): + if idx_dst is not None: + x[..., idx_dst][nan_locations[..., idx_src]] = value + + return x + + def inverse_transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor: + """Impute missing values in the input tensor.""" + return x + + +class DynamicInputImputer(DynamicMixin, InputImputer): + "Imputes missing values using the statistics supplied and a dynamic NaN map." + + def __init__( + self, + config=None, + data_indices: Optional[IndexCollection] = None, + statistics: Optional[dict] = None, + ) -> None: + super().__init__(config, data_indices, statistics) + warnings.warn( + "You are using a dynamic Imputer: NaN values will not be present in the model predictions. \ + The model will be trained to predict imputed values. This might deteriorate performances." 
+ ) + + +class DynamicConstantImputer(DynamicMixin, ConstantImputer): + "Imputes missing values using the constant value and a dynamic NaN map." + + def __init__( + self, + config=None, + data_indices: Optional[IndexCollection] = None, + statistics: Optional[dict] = None, + ) -> None: + super().__init__(config, data_indices, statistics) + warnings.warn( + "You are using a dynamic Imputer: NaN values will not be present in the model predictions. \ + The model will be trained to predict imputed values. This might deteriorate performances." + )