diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bfaf353f..8f820a8d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,7 +27,7 @@ repos: - id: python-check-blanket-noqa # Check for # noqa: all - id: python-no-log-warn # Check for log.warn - repo: https://github.com/psf/black-pre-commit-mirror - rev: 24.8.0 + rev: 24.10.0 hooks: - id: black args: [--line-length=120] @@ -40,7 +40,7 @@ repos: - --force-single-line-imports - --profile black - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.9 + rev: v0.8.1 hooks: - id: ruff args: @@ -60,11 +60,11 @@ repos: - id: rstfmt exclude: 'cli/.*' # Because we use argparse - repo: https://github.com/tox-dev/pyproject-fmt - rev: "2.2.4" + rev: "v2.5.0" hooks: - id: pyproject-fmt - repo: https://github.com/jshwi/docsig # Check docstrings against function sig - rev: v0.64.0 + rev: v0.65.0 hooks: - id: docsig args: @@ -74,6 +74,5 @@ repos: - --check-protected # Check protected methods - --check-class # Check class docstrings - --disable=E113 # Disable empty docstrings - - --summary # Print a summary ci: autoupdate_schedule: monthly diff --git a/CHANGELOG.md b/CHANGELOG.md index 07ee5709..ba61ebe3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,12 +8,22 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Please add your functional changes to the appropriate section in the PR. Keep it human-readable, your future self will thank you! -## [Unreleased](https://github.com/ecmwf/anemoi-models/compare/0.3.0...HEAD) +## [Unreleased](https://github.com/ecmwf/anemoi-models/compare/0.4.0...HEAD) -- Add synchronisation workflow +### Added + +- New AnemoiModelEncProcDecHierarchical class available in models [#37](https://github.com/ecmwf/anemoi-models/pull/37) +- Mask NaN values in training loss function [#56](https://github.com/ecmwf/anemoi-models/pull/56) +- Added dynamic NaN masking for the imputer class with two new classes: DynamicInputImputer, DynamicConstantImputer [#89](https://github.com/ecmwf/anemoi-models/pull/89) +- Reduced memory usage when using chunking in the mapper [#84](https://github.com/ecmwf/anemoi-models/pull/84) +- Added `supporting_arrays` argument, which contains arrays to store in checkpoints. [#97](https://github.com/ecmwf/anemoi-models/pull/97) +- Add remappers, e.g. link functions to apply during training to facilitate learning of variables with a difficult distribution [#88](https://github.com/ecmwf/anemoi-models/pull/88) + +## [0.4.0](https://github.com/ecmwf/anemoi-models/compare/0.3.0...0.4.0) - Improvements to Model Design ### Added +- Add synchronisation workflow [#60](https://github.com/ecmwf/anemoi-models/pull/60) - Add anemoi-transform link to documentation - Codeowners file - Pygrep precommit hooks diff --git a/docs/modules/models.rst b/docs/modules/models.rst index 392a9d61..416257df 100644 --- a/docs/modules/models.rst +++ b/docs/modules/models.rst @@ -13,3 +13,29 @@ encoder, processor, and decoder. :members: :no-undoc-members: :show-inheritance: + +********************************************** + Encoder Hierarchical processor Decoder Model +********************************************** + +This model extends the standard encoder-processor-decoder architecture +by introducing a **hierarchical processor**. + +Compared to the AnemoiModelEncProcDec model, this architecture requires +a predefined list of hidden nodes, `[hidden_1, ..., hidden_n]`. These +nodes must be sorted to match the expected flow of information `data -> +hidden_1 -> ... 
-> hidden_n -> ... -> hidden_1 -> data`. + +A new argument is added to the configuration file: +`enable_hierarchical_level_processing`. This argument determines whether +a processor is added at each hierarchy level or only at the final level. + +By default, the number of channels for the mappers is defined as `2^n * +config.num_channels`, where `n` represents the hierarchy level. This +scaling ensures that the processing capacity grows proportionally with +the depth of the hierarchy, enabling efficient handling of data. + +.. automodule:: anemoi.models.models.hierarchical + :members: + :no-undoc-members: + :show-inheritance: diff --git a/pyproject.toml b/pyproject.toml index ba0b9d33..6d473472 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] diff --git a/src/anemoi/models/interface/__init__.py b/src/anemoi/models/interface/__init__.py index 25b7852a..261dec29 100644 --- a/src/anemoi/models/interface/__init__.py +++ b/src/anemoi/models/interface/__init__.py @@ -37,6 +37,8 @@ class AnemoiModelInterface(torch.nn.Module): Statistics for the data. metadata : dict Metadata for the model. + supporting_arrays : dict + Numpy arraysto store in the checkpoint. data_indices : dict Indices for the data. pre_processors : Processors @@ -48,7 +50,14 @@ class AnemoiModelInterface(torch.nn.Module): """ def __init__( - self, *, config: DotDict, graph_data: HeteroData, statistics: dict, data_indices: dict, metadata: dict + self, + *, + config: DotDict, + graph_data: HeteroData, + statistics: dict, + data_indices: dict, + metadata: dict, + supporting_arrays: dict = None, ) -> None: super().__init__() self.config = config @@ -57,6 +66,7 @@ def __init__( self.graph_data = graph_data self.statistics = statistics self.metadata = metadata + self.supporting_arrays = supporting_arrays if supporting_arrays is not None else {} self.data_indices = data_indices self._build_model() diff --git a/src/anemoi/models/layers/block.py b/src/anemoi/models/layers/block.py index 60446d6c..72e487d2 100644 --- a/src/anemoi/models/layers/block.py +++ b/src/anemoi/models/layers/block.py @@ -512,8 +512,9 @@ def forward( edge_attr_list, edge_index_list = sort_edges_1hop_chunks( num_nodes=size, edge_attr=edges, edge_index=edge_index, num_chunks=num_chunks ) + out = torch.zeros((x[1].shape[0], self.num_heads, self.out_channels_conv), device=x[1].device) for i in range(num_chunks): - out1 = self.conv( + out += self.conv( query=query, key=key, value=value, @@ -521,9 +522,6 @@ def forward( edge_index=edge_index_list[i], size=size, ) - if i == 0: - out = torch.zeros_like(out1, device=out1.device) - out = out + out1 else: out = self.conv(query=query, key=key, value=value, edge_attr=edges, edge_index=edge_index, size=size) diff --git a/src/anemoi/models/layers/processor.py b/src/anemoi/models/layers/processor.py index 4fd32311..8dba1f66 100644 --- a/src/anemoi/models/layers/processor.py +++ b/src/anemoi/models/layers/processor.py @@ -323,6 +323,7 @@ def forward( *args, **kwargs, ) -> Tensor: + shape_nodes = change_channels_in_shape(shard_shapes, self.num_channels) edge_attr = self.trainable(self.edge_attr, batch_size) diff --git a/src/anemoi/models/models/__init__.py b/src/anemoi/models/models/__init__.py index 
c167afa2..2072f12f 100644 --- a/src/anemoi/models/models/__init__.py +++ b/src/anemoi/models/models/__init__.py @@ -6,3 +6,8 @@ # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. + +from .encoder_processor_decoder import AnemoiModelEncProcDec +from .hierarchical import AnemoiModelEncProcDecHierarchical + +__all__ = ["AnemoiModelEncProcDec", "AnemoiModelEncProcDecHierarchical"] diff --git a/src/anemoi/models/models/hierarchical.py b/src/anemoi/models/models/hierarchical.py new file mode 100644 index 00000000..94e82581 --- /dev/null +++ b/src/anemoi/models/models/hierarchical.py @@ -0,0 +1,308 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import logging +from typing import Optional + +import einops +import torch +from anemoi.utils.config import DotDict +from hydra.utils import instantiate +from torch import Tensor +from torch import nn +from torch.distributed.distributed_c10d import ProcessGroup +from torch_geometric.data import HeteroData + +from anemoi.models.distributed.shapes import get_shape_shards +from anemoi.models.layers.graph import NamedNodesAttributes +from anemoi.models.layers.graph import TrainableTensor +from anemoi.models.models import AnemoiModelEncProcDec + +LOGGER = logging.getLogger(__name__) + + +class AnemoiModelEncProcDecHierarchical(AnemoiModelEncProcDec): + """Message passing hierarchical graph neural network.""" + + def __init__( + self, + *, + model_config: DotDict, + data_indices: dict, + graph_data: HeteroData, + ) -> None: + """Initializes the graph neural network. 
+
+        Parameters
+        ----------
+        model_config : DotDict
+            Job configuration
+        data_indices : dict
+            Data indices
+        graph_data : HeteroData
+            Graph definition
+        """
+        nn.Module.__init__(self)
+
+        self._graph_data = graph_data
+        self._graph_name_data = model_config.graph.data
+        self._graph_hidden_names = model_config.graph.hidden
+        self.num_hidden = len(self._graph_hidden_names)
+
+        # Unpack config for hierarchical graph
+        self.level_process = model_config.model.enable_hierarchical_level_processing
+
+        # hidden_dims is the dimensionality of features at each depth
+        self.hidden_dims = {
+            hidden: model_config.model.num_channels * (2**i) for i, hidden in enumerate(self._graph_hidden_names)
+        }
+
+        self._calculate_shapes_and_indices(data_indices)
+        self._assert_matching_indices(data_indices)
+        self.data_indices = data_indices
+
+        self.multi_step = model_config.training.multistep_input
+
+        # self.node_attributes = {hidden_name: NamedNodesAttributes(model_config.model.trainable_parameters[hidden_name], self._graph_data)
+        #                         for hidden_name in self._graph_hidden_names}
+        self.node_attributes = NamedNodesAttributes(model_config.model.trainable_parameters.hidden, self._graph_data)
+
+        input_dim = self.multi_step * self.num_input_channels + self.node_attributes.attr_ndims[self._graph_name_data]
+
+        # Encoder data -> hidden
+        self.encoder = instantiate(
+            model_config.model.encoder,
+            in_channels_src=input_dim,
+            in_channels_dst=self.node_attributes.attr_ndims[self._graph_hidden_names[0]],
+            hidden_dim=self.hidden_dims[self._graph_hidden_names[0]],
+            sub_graph=self._graph_data[(self._graph_name_data, "to", self._graph_hidden_names[0])],
+            src_grid_size=self.node_attributes.num_nodes[self._graph_name_data],
+            dst_grid_size=self.node_attributes.num_nodes[self._graph_hidden_names[0]],
+        )
+
+        # Level processors
+        if self.level_process:
+            self.down_level_processor = nn.ModuleDict()
+            self.up_level_processor = nn.ModuleDict()
+
+            for i in range(0, self.num_hidden):
+                nodes_names = self._graph_hidden_names[i]
+
+                self.down_level_processor[nodes_names] = instantiate(
+                    model_config.model.processor,
+                    num_channels=self.hidden_dims[nodes_names],
+                    sub_graph=self._graph_data[(nodes_names, "to", nodes_names)],
+                    src_grid_size=self.node_attributes.num_nodes[nodes_names],
+                    dst_grid_size=self.node_attributes.num_nodes[nodes_names],
+                    num_layers=model_config.model.level_process_num_layers,
+                )
+
+                self.up_level_processor[nodes_names] = instantiate(
+                    model_config.model.processor,
+                    num_channels=self.hidden_dims[nodes_names],
+                    sub_graph=self._graph_data[(nodes_names, "to", nodes_names)],
+                    src_grid_size=self.node_attributes.num_nodes[nodes_names],
+                    dst_grid_size=self.node_attributes.num_nodes[nodes_names],
+                    num_layers=model_config.model.level_process_num_layers,
+                )
+
+            # delete final up-level processor (does not exist): |->|->|<-|<-|
+            del self.up_level_processor[nodes_names]
+
+        # Downscale
+        self.downscale = nn.ModuleDict()
+
+        for i in range(0, self.num_hidden - 1):
+            src_nodes_name = self._graph_hidden_names[i]
+            dst_nodes_name = self._graph_hidden_names[i + 1]
+
+            self.downscale[src_nodes_name] = instantiate(
+                model_config.model.encoder,
+                in_channels_src=self.hidden_dims[src_nodes_name],
+                in_channels_dst=self.node_attributes.attr_ndims[dst_nodes_name],
+                hidden_dim=self.hidden_dims[dst_nodes_name],
+                sub_graph=self._graph_data[(src_nodes_name, "to", dst_nodes_name)],
+                src_grid_size=self.node_attributes.num_nodes[src_nodes_name],
+                dst_grid_size=self.node_attributes.num_nodes[dst_nodes_name],
+            )
+
+        # Upscale
+
self.upscale = nn.ModuleDict() + + for i in range(1, self.num_hidden): + src_nodes_name = self._graph_hidden_names[i] + dst_nodes_name = self._graph_hidden_names[i - 1] + + self.upscale[src_nodes_name] = instantiate( + model_config.model.decoder, + in_channels_src=self.hidden_dims[src_nodes_name], + in_channels_dst=self.hidden_dims[dst_nodes_name], + hidden_dim=self.hidden_dims[src_nodes_name], + out_channels_dst=self.hidden_dims[dst_nodes_name], + sub_graph=self._graph_data[(src_nodes_name, "to", dst_nodes_name)], + src_grid_size=self.node_attributes.num_nodes[src_nodes_name], + dst_grid_size=self.node_attributes.num_nodes[dst_nodes_name], + ) + + # Decoder hidden -> data + self.decoder = instantiate( + model_config.model.decoder, + in_channels_src=self.hidden_dims[self._graph_hidden_names[0]], + in_channels_dst=input_dim, + hidden_dim=self.hidden_dims[self._graph_hidden_names[0]], + out_channels_dst=self.num_output_channels, + sub_graph=self._graph_data[(self._graph_hidden_names[0], "to", self._graph_name_data)], + src_grid_size=self.node_attributes.num_nodes[self._graph_hidden_names[0]], + dst_grid_size=self.node_attributes.num_nodes[self._graph_name_data], + ) + + # Instantiation of model output bounding functions (e.g., to ensure outputs like TP are positive definite) + self.boundings = nn.ModuleList( + [ + instantiate(cfg, name_to_index=self.data_indices.internal_model.output.name_to_index) + for cfg in getattr(model_config.model, "bounding", []) + ] + ) + + def _create_trainable_attributes(self) -> None: + """Create all trainable attributes.""" + self.trainable_data = TrainableTensor(trainable_size=self.trainable_data_size, tensor_size=self._data_grid_size) + self.trainable_hidden = nn.ModuleDict() + + for hidden in self._graph_hidden_names: + self.trainable_hidden[hidden] = TrainableTensor( + trainable_size=self.trainable_hidden_size, tensor_size=self._hidden_grid_sizes[hidden] + ) + + def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) -> Tensor: + batch_size = x.shape[0] + ensemble_size = x.shape[2] + + # add data positional info (lat/lon) + x_trainable_data = torch.cat( + ( + einops.rearrange(x, "batch time ensemble grid vars -> (batch ensemble grid) (time vars)"), + self.node_attributes(self._graph_name_data, batch_size=batch_size), + ), + dim=-1, # feature dimension + ) + + # Get all trainable parameters for the hidden layers -> initialisation of each hidden, which becomes trainable bias + x_trainable_hiddens = {} + for hidden in self._graph_hidden_names: + x_trainable_hiddens[hidden] = self.node_attributes(hidden, batch_size=batch_size) + + # Get data and hidden shapes for sharding + shard_shapes_data = get_shape_shards(x_trainable_data, 0, model_comm_group) + shard_shapes_hiddens = {} + for hidden, x_latent in x_trainable_hiddens.items(): + shard_shapes_hiddens[hidden] = get_shape_shards(x_latent, 0, model_comm_group) + + # Run encoder + x_data_latent, curr_latent = self._run_mapper( + self.encoder, + (x_trainable_data, x_trainable_hiddens[self._graph_hidden_names[0]]), + batch_size=batch_size, + shard_shapes=(shard_shapes_data, shard_shapes_hiddens[self._graph_hidden_names[0]]), + model_comm_group=model_comm_group, + ) + + # Run processor + x_encoded_latents = {} + x_skip = {} + + ## Downscale + for i in range(0, self.num_hidden - 1): + src_hidden_name = self._graph_hidden_names[i] + dst_hidden_name = self._graph_hidden_names[i + 1] + + # Processing at same level + if self.level_process: + curr_latent = self.down_level_processor[src_hidden_name]( + 
curr_latent, + batch_size=batch_size, + shard_shapes=shard_shapes_hiddens[src_hidden_name], + model_comm_group=model_comm_group, + ) + + # store latents for skip connections + x_skip[src_hidden_name] = curr_latent + + # Encode to next hidden level + x_encoded_latents[src_hidden_name], curr_latent = self._run_mapper( + self.downscale[src_hidden_name], + (curr_latent, x_trainable_hiddens[dst_hidden_name]), + batch_size=batch_size, + shard_shapes=(shard_shapes_hiddens[src_hidden_name], shard_shapes_hiddens[dst_hidden_name]), + model_comm_group=model_comm_group, + ) + + # Processing hidden-most level + if self.level_process: + curr_latent = self.down_level_processor[dst_hidden_name]( + curr_latent, + batch_size=batch_size, + shard_shapes=shard_shapes_hiddens[dst_hidden_name], + model_comm_group=model_comm_group, + ) + + ## Upscale + for i in range(self.num_hidden - 1, 0, -1): + src_hidden_name = self._graph_hidden_names[i] + dst_hidden_name = self._graph_hidden_names[i - 1] + + # Process to next level + curr_latent = self._run_mapper( + self.upscale[src_hidden_name], + (curr_latent, x_encoded_latents[dst_hidden_name]), + batch_size=batch_size, + shard_shapes=(shard_shapes_hiddens[src_hidden_name], shard_shapes_hiddens[dst_hidden_name]), + model_comm_group=model_comm_group, + ) + + # Add skip connections + curr_latent = curr_latent + x_skip[dst_hidden_name] + + # Processing at same level + if self.level_process: + curr_latent = self.up_level_processor[dst_hidden_name]( + curr_latent, + batch_size=batch_size, + shard_shapes=shard_shapes_hiddens[dst_hidden_name], + model_comm_group=model_comm_group, + ) + + # Run decoder + x_out = self._run_mapper( + self.decoder, + (curr_latent, x_data_latent), + batch_size=batch_size, + shard_shapes=(shard_shapes_hiddens[self._graph_hidden_names[0]], shard_shapes_data), + model_comm_group=model_comm_group, + ) + + x_out = ( + einops.rearrange( + x_out, + "(batch ensemble grid) vars -> batch ensemble grid vars", + batch=batch_size, + ensemble=ensemble_size, + ) + .to(dtype=x.dtype) + .clone() + ) + + # residual connection (just for the prognostic variables) + x_out[..., self._internal_output_idx] += x[:, -1, :, :, self._internal_input_idx] + + for bounding in self.boundings: + # bounding performed in the order specified in the config file + x_out = bounding(x_out) + + return x_out diff --git a/src/anemoi/models/preprocessing/__init__.py b/src/anemoi/models/preprocessing/__init__.py index 4c7f6153..e26505f3 100644 --- a/src/anemoi/models/preprocessing/__init__.py +++ b/src/anemoi/models/preprocessing/__init__.py @@ -57,25 +57,32 @@ def __init__( super().__init__() - self.default, self.method_config = self._process_config(config) + self.default, self.remap, self.method_config = self._process_config(config) self.methods = self._invert_key_value_list(self.method_config) self.data_indices = data_indices - def _process_config(self, config): + @classmethod + def _process_config(cls, config): _special_keys = ["default", "remap"] # Keys that do not contain a list of variables in a preprocessing method. 
         default = config.get("default", "none")
-        self.remap = config.get("remap", {})
+        remap = config.get("remap", {})
         method_config = {k: v for k, v in config.items() if k not in _special_keys and v is not None and v != "none"}

         if not method_config:
             LOGGER.warning(
-                f"{self.__class__.__name__}: Using default method {default} for all variables not specified in the config.",
+                f"{cls.__name__}: Using default method {default} for all variables not specified in the config.",
             )

+        for m in method_config:
+            if isinstance(method_config[m], str):
+                method_config[m] = {method_config[m]: f"{m}_{method_config[m]}"}
+            elif isinstance(method_config[m], list):
+                method_config[m] = {method: f"{m}_{method}" for method in method_config[m]}

-        return default, method_config
+        return default, remap, method_config

-    def _invert_key_value_list(self, method_config: dict[str, list[str]]) -> dict[str, str]:
+    @staticmethod
+    def _invert_key_value_list(method_config: dict[str, list[str]]) -> dict[str, str]:
         """Invert a dictionary of methods with lists of variables.

         Parameters
diff --git a/src/anemoi/models/preprocessing/imputer.py b/src/anemoi/models/preprocessing/imputer.py
index 5f9c1b9d..4bd3c0ae 100644
--- a/src/anemoi/models/preprocessing/imputer.py
+++ b/src/anemoi/models/preprocessing/imputer.py
@@ -9,6 +9,7 @@
 import logging
+import warnings
 from abc import ABC
 from typing import Optional
@@ -43,6 +44,8 @@ def __init__(
         super().__init__(config, data_indices, statistics)

         self.nan_locations = None
+        # weight imputed values with zero in loss calculation
+        self.loss_mask_training = None

     def _validate_indices(self):
         assert len(self.index_training_input) == len(self.index_inference_input) <= len(self.replacement), (
@@ -104,16 +107,31 @@ def _expand_subset_mask(self, x: torch.Tensor, idx_src: int) -> torch.Tensor:
         """Expand the subset of the mask to the correct shape."""
         return self.nan_locations[:, idx_src].expand(*x.shape[:-2], -1)

+    def get_nans(self, x: torch.Tensor) -> torch.Tensor:
+        """Get the NaN mask from the data."""
+        # The mask is only saved for the last two dimensions (grid, variable)
+        idx = [slice(0, 1)] * (x.ndim - 2) + [slice(None), slice(None)]
+        return torch.isnan(x[idx].squeeze())
+
     def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
         """Impute missing values in the input tensor."""
         if not in_place:
             x = x.clone()

-        # Initilialize mask once
+        # Initialize nan mask once
         if self.nan_locations is None:
-            # The mask is only saved for the last two dimensions (grid, variable)
-            idx = [slice(0, 1)] * (x.ndim - 2) + [slice(None), slice(None)]
-            self.nan_locations = torch.isnan(x[idx].squeeze())
+
+            # Get NaN locations
+            self.nan_locations = self.get_nans(x)
+
+            # Initialize training loss mask to weigh imputed values with zeroes once
+            self.loss_mask_training = torch.ones(
+                (x.shape[-2], len(self.data_indices.model.output.name_to_index)), device=x.device
+            )  # shape (grid, n_outputs)
+            # for all variables that are imputed and part of the model output, set the loss weight to zero
+            for idx_src, idx_dst in zip(self.index_training_input, self.index_inference_output):
+                if idx_dst is not None:
+                    self.loss_mask_training[:, idx_dst] = (~self.nan_locations[:, idx_src]).int()

         # Choose correct index based on number of variables
         if x.shape[-1] == self.num_training_input_vars:
@@ -211,3 +229,77 @@ def __init__(
         self._create_imputation_indices()

         self._validate_indices()
+
+
+class DynamicMixin:
+    """Mixin to add dynamic imputation behavior."""
+
+    def get_nans(self, x: torch.Tensor) -> torch.Tensor:
"""Override to calculate NaN locations dynamically.""" + return torch.isnan(x) + + def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor: + """Impute missing values in the input tensor.""" + if not in_place: + x = x.clone() + + # Initilialize mask every time + nan_locations = self.get_nans(x) + + self.loss_mask_training = torch.ones( + (x.shape[-2], len(self.data_indices.model.output.name_to_index)), device=x.device + ) + + # Choose correct index based on number of variables + if x.shape[-1] == self.num_training_input_vars: + index = self.index_training_input + elif x.shape[-1] == self.num_inference_input_vars: + index = self.index_inference_input + else: + raise ValueError( + f"Input tensor ({x.shape[-1]}) does not match the training " + f"({self.num_training_input_vars}) or inference shape ({self.num_inference_input_vars})", + ) + + # Replace values + for idx_src, (idx_dst, value) in zip(self.index_training_input, zip(index, self.replacement)): + if idx_dst is not None: + x[..., idx_dst][nan_locations[..., idx_src]] = value + + return x + + def inverse_transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor: + """Impute missing values in the input tensor.""" + return x + + +class DynamicInputImputer(DynamicMixin, InputImputer): + "Imputes missing values using the statistics supplied and a dynamic NaN map." + + def __init__( + self, + config=None, + data_indices: Optional[IndexCollection] = None, + statistics: Optional[dict] = None, + ) -> None: + super().__init__(config, data_indices, statistics) + warnings.warn( + "You are using a dynamic Imputer: NaN values will not be present in the model predictions. \ + The model will be trained to predict imputed values. This might deteriorate performances." + ) + + +class DynamicConstantImputer(DynamicMixin, ConstantImputer): + "Imputes missing values using the constant value and a dynamic NaN map." + + def __init__( + self, + config=None, + data_indices: Optional[IndexCollection] = None, + statistics: Optional[dict] = None, + ) -> None: + super().__init__(config, data_indices, statistics) + warnings.warn( + "You are using a dynamic Imputer: NaN values will not be present in the model predictions. \ + The model will be trained to predict imputed values. This might deteriorate performances." + ) diff --git a/src/anemoi/models/preprocessing/mappings.py b/src/anemoi/models/preprocessing/mappings.py new file mode 100644 index 00000000..dab46734 --- /dev/null +++ b/src/anemoi/models/preprocessing/mappings.py @@ -0,0 +1,75 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import torch + + +def noop(x): + """No operation.""" + return x + + +def cos_converter(x): + """Convert angle in degree to cos.""" + return torch.cos(x / 180 * torch.pi) + + +def sin_converter(x): + """Convert angle in degree to sin.""" + return torch.sin(x / 180 * torch.pi) + + +def atan2_converter(x): + """Convert cos and sin to angle in degree. 
+
+    Input:
+        x[..., 0]: cos
+        x[..., 1]: sin
+    """
+    return torch.remainder(torch.atan2(x[..., 1], x[..., 0]) * 180 / torch.pi, 360)
+
+
+def log1p_converter(x):
+    """Convert positive var into log(1+var)."""
+    return torch.log1p(x)
+
+
+def boxcox_converter(x, lambd=0.5):
+    """Convert positive var into boxcox(var)."""
+    pos_lam = (torch.pow(x, lambd) - 1) / lambd
+    null_lam = torch.log(x)
+    if lambd == 0:
+        return null_lam
+    else:
+        return pos_lam
+
+
+def sqrt_converter(x):
+    """Convert positive var into sqrt(var)."""
+    return torch.sqrt(x)
+
+
+def expm1_converter(x):
+    """Convert back log(1+var) to var."""
+    return torch.expm1(x)
+
+
+def square_converter(x):
+    """Convert back sqrt(var) to var."""
+    return x**2
+
+
+def inverse_boxcox_converter(x, lambd=0.5):
+    """Convert back boxcox(var) to var."""
+    pos_lam = torch.pow(x * lambd + 1, 1 / lambd)
+    null_lam = torch.exp(x)
+    if lambd == 0:
+        return null_lam
+    else:
+        return pos_lam
diff --git a/src/anemoi/models/preprocessing/monomapper.py b/src/anemoi/models/preprocessing/monomapper.py
new file mode 100644
index 00000000..0359a4c3
--- /dev/null
+++ b/src/anemoi/models/preprocessing/monomapper.py
@@ -0,0 +1,150 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+
+import logging
+from abc import ABC
+from typing import Optional
+
+import torch
+
+from anemoi.models.data_indices.collection import IndexCollection
+from anemoi.models.preprocessing import BasePreprocessor
+from anemoi.models.preprocessing.mappings import boxcox_converter
+from anemoi.models.preprocessing.mappings import expm1_converter
+from anemoi.models.preprocessing.mappings import inverse_boxcox_converter
+from anemoi.models.preprocessing.mappings import log1p_converter
+from anemoi.models.preprocessing.mappings import noop
+from anemoi.models.preprocessing.mappings import sqrt_converter
+from anemoi.models.preprocessing.mappings import square_converter
+
+LOGGER = logging.getLogger(__name__)
+
+
+class Monomapper(BasePreprocessor, ABC):
+    """Remap and convert single variables (one-to-one mappings)."""
+
+    supported_methods = {
+        method: [f, inv]
+        for method, f, inv in zip(
+            ["log1p", "sqrt", "boxcox", "none"],
+            [log1p_converter, sqrt_converter, boxcox_converter, noop],
+            [expm1_converter, square_converter, inverse_boxcox_converter, noop],
+        )
+    }
+
+    def __init__(
+        self,
+        config=None,
+        data_indices: Optional[IndexCollection] = None,
+        statistics: Optional[dict] = None,
+    ) -> None:
+        super().__init__(config, data_indices, statistics)
+        self._create_remapping_indices(statistics)
+        self._validate_indices()
+
+    def _validate_indices(self):
+        assert (
+            len(self.index_training_input)
+            == len(self.index_inference_input)
+            == len(self.index_inference_output)
+            == len(self.index_training_out)
+            == len(self.remappers)
+        ), (
+            f"Error creating conversion indices {len(self.index_training_input)}, "
+            f"{len(self.index_inference_input)}, {len(self.index_inference_output)}, {len(self.index_training_out)}, {len(self.remappers)}"
+        )
+
+    def _create_remapping_indices(
+        self,
+        statistics=None,
+    ):
+        """Create the parameter indices for remapping."""
+        # lists for training and inference mode, as the position of parameters can change
+        name_to_index_training_input = self.data_indices.data.input.name_to_index
+        name_to_index_inference_input = self.data_indices.model.input.name_to_index
+        name_to_index_training_output = self.data_indices.data.output.name_to_index
+        name_to_index_inference_output = self.data_indices.model.output.name_to_index
+        self.num_training_input_vars = len(name_to_index_training_input)
+        self.num_inference_input_vars = len(name_to_index_inference_input)
+        self.num_training_output_vars = len(name_to_index_training_output)
+        self.num_inference_output_vars = len(name_to_index_inference_output)
+
+        (
+            self.remappers,
+            self.backmappers,
+            self.index_training_input,
+            self.index_training_out,
+            self.index_inference_input,
+            self.index_inference_output,
+        ) = (
+            [],
+            [],
+            [],
+            [],
+            [],
+            [],
+        )
+
+        # Create parameter indices for remapping variables
+        for name in name_to_index_training_input:
+            method = self.methods.get(name, self.default)
+            if method in self.supported_methods:
+                self.remappers.append(self.supported_methods[method][0])
+                self.backmappers.append(self.supported_methods[method][1])
+                self.index_training_input.append(name_to_index_training_input[name])
+                if name in name_to_index_training_output:
+                    self.index_training_out.append(name_to_index_training_output[name])
+                else:
+                    self.index_training_out.append(None)
+                if name in name_to_index_inference_input:
+                    self.index_inference_input.append(name_to_index_inference_input[name])
+                else:
+                    self.index_inference_input.append(None)
+                if name in name_to_index_inference_output:
+                    self.index_inference_output.append(name_to_index_inference_output[name])
+                else:
+                    # this is a forcing variable. It is not in the inference output.
+                    self.index_inference_output.append(None)
+            else:
+                raise KeyError(f"Unknown remapping method for {name}: {method}")
+
+    def transform(self, x, in_place: bool = True) -> torch.Tensor:
+        if not in_place:
+            x = x.clone()
+        if x.shape[-1] == self.num_training_input_vars:
+            idx = self.index_training_input
+        elif x.shape[-1] == self.num_inference_input_vars:
+            idx = self.index_inference_input
+        else:
+            raise ValueError(
+                f"Input tensor ({x.shape[-1]}) does not match the training "
+                f"({self.num_training_input_vars}) or inference shape ({self.num_inference_input_vars})",
+            )
+        for i, remapper in zip(idx, self.remappers):
+            if i is not None:
+                x[..., i] = remapper(x[..., i])
+        return x
+
+    def inverse_transform(self, x, in_place: bool = True) -> torch.Tensor:
+        if not in_place:
+            x = x.clone()
+        if x.shape[-1] == self.num_training_output_vars:
+            idx = self.index_training_out
+        elif x.shape[-1] == self.num_inference_output_vars:
+            idx = self.index_inference_output
+        else:
+            raise ValueError(
+                f"Input tensor ({x.shape[-1]}) does not match the training "
+                f"({self.num_training_output_vars}) or inference shape ({self.num_inference_output_vars})",
+            )
+        for i, backmapper in zip(idx, self.backmappers):
+            if i is not None:
+                x[..., i] = backmapper(x[..., i])
+        return x
diff --git a/src/anemoi/models/preprocessing/multimapper.py b/src/anemoi/models/preprocessing/multimapper.py
new file mode 100644
index 00000000..f7772e48
--- /dev/null
+++ b/src/anemoi/models/preprocessing/multimapper.py
@@ -0,0 +1,306 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+import logging
+from abc import ABC
+from typing import Optional
+
+import torch
+
+from anemoi.models.data_indices.collection import IndexCollection
+from anemoi.models.preprocessing import BasePreprocessor
+from anemoi.models.preprocessing.mappings import atan2_converter
+from anemoi.models.preprocessing.mappings import cos_converter
+from anemoi.models.preprocessing.mappings import sin_converter
+
+LOGGER = logging.getLogger(__name__)
+
+
+class Multimapper(BasePreprocessor, ABC):
+    """Remap a single variable to two or more variables, or the other way around.
+
+    cos_sin:
+        Remap the variable to cosine and sine.
+        Map output back to degrees.
+
+        ```
+        cos_sin:
+            "mwd" : ["cos_mwd", "sin_mwd"]
+        ```
+    """
+
+    supported_methods = {
+        method: [f, inv]
+        for method, f, inv in zip(
+            ["cos_sin"],
+            [[cos_converter, sin_converter]],
+            [atan2_converter],
+        )
+    }
+
+    def __init__(
+        self,
+        config=None,
+        data_indices: Optional[IndexCollection] = None,
+        statistics: Optional[dict] = None,
+    ) -> None:
+        """Initialize the remapper.
+
+        Parameters
+        ----------
+        config : DotDict
+            Configuration object of the processor
+        data_indices : IndexCollection
+            Data indices for input and output variables
+        statistics : dict
+            Data statistics dictionary
+        """
+        super().__init__(config, data_indices, statistics)
+        self.printed_preprocessor_warning, self.printed_postprocessor_warning = False, False
+        self._create_remapping_indices(statistics)
+        self._validate_indices()
+
+    def _validate_indices(self):
+        assert len(self.index_training_input) == len(self.index_inference_input) <= len(self.remappers), (
+            f"Error creating conversion indices {len(self.index_training_input)}, "
+            f"{len(self.index_inference_input)}, {len(self.remappers)}"
+        )
+        assert len(self.index_training_output) == len(self.index_inference_output) <= len(self.remappers), (
+            f"Error creating conversion indices {len(self.index_training_output)}, "
+            f"{len(self.index_inference_output)}, {len(self.remappers)}"
+        )
+        assert len(set(self.index_training_input + self.indices_keep_training_input)) == self.num_training_input_vars, (
+            "Error creating conversion indices: variables remapped in config.data.remapped "
+            "that have no remapping function defined. Preprocessed tensors contain empty columns."
+ ) + + def _create_remapping_indices( + self, + statistics=None, + ): + """Create the parameter indices for remapping.""" + # list for training and inference mode as position of parameters can change + name_to_index_training_input = self.data_indices.data.input.name_to_index + name_to_index_inference_input = self.data_indices.model.input.name_to_index + name_to_index_training_remapped_input = self.data_indices.internal_data.input.name_to_index + name_to_index_inference_remapped_input = self.data_indices.internal_model.input.name_to_index + name_to_index_training_remapped_output = self.data_indices.internal_data.output.name_to_index + name_to_index_inference_remapped_output = self.data_indices.internal_model.output.name_to_index + name_to_index_training_output = self.data_indices.data.output.name_to_index + name_to_index_inference_output = self.data_indices.model.output.name_to_index + + self.num_training_input_vars = len(name_to_index_training_input) + self.num_inference_input_vars = len(name_to_index_inference_input) + self.num_remapped_training_input_vars = len(name_to_index_training_remapped_input) + self.num_remapped_inference_input_vars = len(name_to_index_inference_remapped_input) + self.num_remapped_training_output_vars = len(name_to_index_training_remapped_output) + self.num_remapped_inference_output_vars = len(name_to_index_inference_remapped_output) + self.num_training_output_vars = len(name_to_index_training_output) + self.num_inference_output_vars = len(name_to_index_inference_output) + self.indices_keep_training_input = [] + for key, item in self.data_indices.data.input.name_to_index.items(): + if key in self.data_indices.internal_data.input.name_to_index: + self.indices_keep_training_input.append(item) + self.indices_keep_inference_input = [] + for key, item in self.data_indices.model.input.name_to_index.items(): + if key in self.data_indices.internal_model.input.name_to_index: + self.indices_keep_inference_input.append(item) + self.indices_keep_training_output = [] + for key, item in self.data_indices.data.output.name_to_index.items(): + if key in self.data_indices.internal_data.output.name_to_index: + self.indices_keep_training_output.append(item) + self.indices_keep_inference_output = [] + for key, item in self.data_indices.model.output.name_to_index.items(): + if key in self.data_indices.internal_model.output.name_to_index: + self.indices_keep_inference_output.append(item) + + ( + self.index_training_input, + self.index_training_remapped_input, + self.index_inference_input, + self.index_inference_remapped_input, + self.index_training_output, + self.index_training_backmapped_output, + self.index_inference_output, + self.index_inference_backmapped_output, + self.remappers, + self.backmappers, + ) = ([], [], [], [], [], [], [], [], [], []) + + # Create parameter indices for remapping variables + for name in name_to_index_training_input: + + method = self.methods.get(name, self.default) + + if method == "none": + continue + + if method == "cos_sin": + self.index_training_input.append(name_to_index_training_input[name]) + self.index_training_output.append(name_to_index_training_output[name]) + self.index_inference_input.append(name_to_index_inference_input[name]) + if name in name_to_index_inference_output: + self.index_inference_output.append(name_to_index_inference_output[name]) + else: + # this is a forcing variable. It is not in the inference output. 
+                    self.index_inference_output.append(None)
+                multiple_training_output, multiple_inference_output = [], []
+                multiple_training_input, multiple_inference_input = [], []
+                for name_dst in self.method_config[method][name]:
+                    assert name_dst in self.data_indices.internal_data.input.name_to_index, (
+                        f"Trying to remap {name} to {name_dst}, but {name_dst} is not a variable. "
+                        f"Remap {name} to {name_dst} in config.data.remapped. "
+                    )
+                    multiple_training_input.append(name_to_index_training_remapped_input[name_dst])
+                    multiple_training_output.append(name_to_index_training_remapped_output[name_dst])
+                    multiple_inference_input.append(name_to_index_inference_remapped_input[name_dst])
+                    if name_dst in name_to_index_inference_remapped_output:
+                        multiple_inference_output.append(name_to_index_inference_remapped_output[name_dst])
+                    else:
+                        # this is a forcing variable. It is not in the inference output.
+                        multiple_inference_output.append(None)
+
+                self.index_training_remapped_input.append(multiple_training_input)
+                self.index_inference_remapped_input.append(multiple_inference_input)
+                self.index_training_backmapped_output.append(multiple_training_output)
+                self.index_inference_backmapped_output.append(multiple_inference_output)
+
+                self.remappers.append([cos_converter, sin_converter])
+                self.backmappers.append(atan2_converter)
+
+                LOGGER.info(f"Map {name} to cosine and sine and save result in {self.method_config[method][name]}.")
+
+            else:
+                raise ValueError(f"Unknown remapping method for {name}: {method}")
+
+    def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
+        """Remap and convert the input tensor.
+
+        ```
+        x : torch.Tensor
+            Input tensor
+        in_place : bool
+            Whether to process the tensor in place.
+            in_place is not possible for this preprocessor.
+        ```
+        """
+        # Choose correct index based on number of variables
+        if x.shape[-1] == self.num_training_input_vars:
+            index = self.index_training_input
+            indices_remapped = self.index_training_remapped_input
+            indices_keep = self.indices_keep_training_input
+            target_number_columns = self.num_remapped_training_input_vars
+
+        elif x.shape[-1] == self.num_inference_input_vars:
+            index = self.index_inference_input
+            indices_remapped = self.index_inference_remapped_input
+            indices_keep = self.indices_keep_inference_input
+            target_number_columns = self.num_remapped_inference_input_vars
+
+        else:
+            raise ValueError(
+                f"Input tensor ({x.shape[-1]}) does not match the training "
+                f"({self.num_training_input_vars}) or inference shape ({self.num_inference_input_vars})",
+            )
+
+        # create new tensor with target number of columns
+        x_remapped = torch.zeros(x.shape[:-1] + (target_number_columns,), dtype=x.dtype, device=x.device)
+        if in_place and not self.printed_preprocessor_warning:
+            LOGGER.warning(
+                "Remapper (preprocessor) called with in_place=True. This preprocessor cannot be applied in_place as new columns are added to the tensors.",
+            )
+            self.printed_preprocessor_warning = True
+
+        # copy variables that are not remapped
+        x_remapped[..., : len(indices_keep)] = x[..., indices_keep]
+
+        # Remap variables
+        for idx_dst, remapper, idx_src in zip(indices_remapped, self.remappers, index):
+            if idx_src is not None:
+                for jj, ii in enumerate(idx_dst):
+                    x_remapped[..., ii] = remapper[jj](x[..., idx_src])
+
+        return x_remapped
+
+    def inverse_transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
+        """Convert and remap the output tensor.
+ + ``` + x : torch.Tensor + Input tensor + in_place : bool + Whether to process the tensor in place. + in_place is not possible for this postprocessor. + ``` + """ + # Choose correct index based on number of variables + if x.shape[-1] == self.num_remapped_training_output_vars: + index = self.index_training_output + indices_remapped = self.index_training_backmapped_output + indices_keep = self.indices_keep_training_output + target_number_columns = self.num_training_output_vars + + elif x.shape[-1] == self.num_remapped_inference_output_vars: + index = self.index_inference_output + indices_remapped = self.index_inference_backmapped_output + indices_keep = self.indices_keep_inference_output + target_number_columns = self.num_inference_output_vars + + else: + raise ValueError( + f"Input tensor ({x.shape[-1]}) does not match the training " + f"({self.num_remapped_training_output_vars}) or inference shape ({self.num_remapped_inference_output_vars})", + ) + + # create new tensor with target number of columns + x_remapped = torch.zeros(x.shape[:-1] + (target_number_columns,), dtype=x.dtype, device=x.device) + if in_place and not self.printed_postprocessor_warning: + LOGGER.warning( + "Remapper (preprocessor) called with in_place=True. This preprocessor cannot be applied in_place as new columns are added to the tensors.", + ) + self.printed_postprocessor_warning = True + + # copy variables that are not remapped + x_remapped[..., indices_keep] = x[..., : len(indices_keep)] + + # Backmap variables + for idx_dst, backmapper, idx_src in zip(index, self.backmappers, indices_remapped): + if idx_dst is not None: + x_remapped[..., idx_dst] = backmapper(x[..., idx_src]) + + return x_remapped + + def transform_loss_mask(self, mask: torch.Tensor) -> torch.Tensor: + """Remap the loss mask. + + ``` + x : torch.Tensor + Loss mask + ``` + """ + # use indices at model output level + index = self.index_inference_backmapped_output + indices_remapped = self.index_inference_output + indices_keep = self.indices_keep_inference_output + + # create new loss mask with target number of columns + mask_remapped = torch.zeros( + mask.shape[:-1] + (mask.shape[-1] + len(indices_remapped),), dtype=mask.dtype, device=mask.device + ) + + # copy loss mask for variables that are not remapped + mask_remapped[..., : len(indices_keep)] = mask[..., indices_keep] + + # remap loss mask for rest of variables + for idx_src, idx_dst in zip(indices_remapped, index): + if idx_dst is not None: + for ii in idx_dst: + mask_remapped[..., ii] = mask[..., idx_src] + + return mask_remapped diff --git a/src/anemoi/models/preprocessing/remapper.py b/src/anemoi/models/preprocessing/remapper.py index cc888222..c3f39c2e 100644 --- a/src/anemoi/models/preprocessing/remapper.py +++ b/src/anemoi/models/preprocessing/remapper.py @@ -12,290 +12,36 @@ from abc import ABC from typing import Optional -import torch - from anemoi.models.data_indices.collection import IndexCollection from anemoi.models.preprocessing import BasePreprocessor +from anemoi.models.preprocessing.monomapper import Monomapper +from anemoi.models.preprocessing.multimapper import Multimapper LOGGER = logging.getLogger(__name__) -def cos_converter(x): - """Convert angle in degree to cos.""" - return torch.cos(x / 180 * torch.pi) - - -def sin_converter(x): - """Convert angle in degree to sin.""" - return torch.sin(x / 180 * torch.pi) - - -def atan2_converter(x): - """Convert cos and sin to angle in degree. 
- - Input: - x[..., 0]: cos - x[..., 1]: sin - """ - return torch.remainder(torch.atan2(x[..., 1], x[..., 0]) * 180 / torch.pi, 360) - +class Remapper(BasePreprocessor, ABC): + """Remap and convert variables for single variables.""" -class BaseRemapperVariable(BasePreprocessor, ABC): - """Base class for Remapping Variables.""" - - def __init__( - self, + def __new__( + cls, config=None, data_indices: Optional[IndexCollection] = None, statistics: Optional[dict] = None, ) -> None: - """Initialize the remapper. - - Parameters - ---------- - config : DotDict - configuration object of the processor - data_indices : IndexCollection - Data indices for input and output variables - statistics : dict - Data statistics dictionary - """ - super().__init__(config, data_indices, statistics) - - def _validate_indices(self): - assert len(self.index_training_input) == len(self.index_inference_input) <= len(self.remappers), ( - f"Error creating conversion indices {len(self.index_training_input)}, " - f"{len(self.index_inference_input)}, {len(self.remappers)}" - ) - assert len(self.index_training_output) == len(self.index_inference_output) <= len(self.remappers), ( - f"Error creating conversion indices {len(self.index_training_output)}, " - f"{len(self.index_inference_output)}, {len(self.remappers)}" - ) - assert len(set(self.index_training_input + self.indices_keep_training_input)) == self.num_training_input_vars, ( - "Error creating conversion indices: variables remapped in config.data.remapped " - "that have no remapping function defined. Preprocessed tensors contains empty columns." - ) - - def _create_remapping_indices( - self, - statistics=None, - ): - """Create the parameter indices for remapping.""" - # list for training and inference mode as position of parameters can change - name_to_index_training_input = self.data_indices.data.input.name_to_index - name_to_index_inference_input = self.data_indices.model.input.name_to_index - name_to_index_training_remapped_input = self.data_indices.internal_data.input.name_to_index - name_to_index_inference_remapped_input = self.data_indices.internal_model.input.name_to_index - name_to_index_training_remapped_output = self.data_indices.internal_data.output.name_to_index - name_to_index_inference_remapped_output = self.data_indices.internal_model.output.name_to_index - name_to_index_training_output = self.data_indices.data.output.name_to_index - name_to_index_inference_output = self.data_indices.model.output.name_to_index - - self.num_training_input_vars = len(name_to_index_training_input) - self.num_inference_input_vars = len(name_to_index_inference_input) - self.num_remapped_training_input_vars = len(name_to_index_training_remapped_input) - self.num_remapped_inference_input_vars = len(name_to_index_inference_remapped_input) - self.num_remapped_training_output_vars = len(name_to_index_training_remapped_output) - self.num_remapped_inference_output_vars = len(name_to_index_inference_remapped_output) - self.num_training_output_vars = len(name_to_index_training_output) - self.num_inference_output_vars = len(name_to_index_inference_output) - self.indices_keep_training_input = [] - for key, item in self.data_indices.data.input.name_to_index.items(): - if key in self.data_indices.internal_data.input.name_to_index: - self.indices_keep_training_input.append(item) - self.indices_keep_inference_input = [] - for key, item in self.data_indices.model.input.name_to_index.items(): - if key in self.data_indices.internal_model.input.name_to_index: - 
self.indices_keep_inference_input.append(item) - self.indices_keep_training_output = [] - for key, item in self.data_indices.data.output.name_to_index.items(): - if key in self.data_indices.internal_data.output.name_to_index: - self.indices_keep_training_output.append(item) - self.indices_keep_inference_output = [] - for key, item in self.data_indices.model.output.name_to_index.items(): - if key in self.data_indices.internal_model.output.name_to_index: - self.indices_keep_inference_output.append(item) - - ( - self.index_training_input, - self.index_training_remapped_input, - self.index_inference_input, - self.index_inference_remapped_input, - self.index_training_output, - self.index_training_backmapped_output, - self.index_inference_output, - self.index_inference_backmapped_output, - self.remappers, - self.backmappers, - ) = ([], [], [], [], [], [], [], [], [], []) - - # Create parameter indices for remapping variables - for name in name_to_index_training_input: - - method = self.methods.get(name, self.default) - - if method == "none": - continue - - if method == "cos_sin": - self.index_training_input.append(name_to_index_training_input[name]) - self.index_training_output.append(name_to_index_training_output[name]) - self.index_inference_input.append(name_to_index_inference_input[name]) - if name in name_to_index_inference_output: - self.index_inference_output.append(name_to_index_inference_output[name]) - else: - # this is a forcing variable. It is not in the inference output. - self.index_inference_output.append(None) - multiple_training_output, multiple_inference_output = [], [] - multiple_training_input, multiple_inference_input = [], [] - for name_dst in self.method_config[method][name]: - assert name_dst in self.data_indices.internal_data.input.name_to_index, ( - f"Trying to remap {name} to {name_dst}, but {name_dst} not a variable. " - f"Remap {name} to {name_dst} in config.data.remapped. " - ) - multiple_training_input.append(name_to_index_training_remapped_input[name_dst]) - multiple_training_output.append(name_to_index_training_remapped_output[name_dst]) - multiple_inference_input.append(name_to_index_inference_remapped_input[name_dst]) - if name_dst in name_to_index_inference_remapped_output: - multiple_inference_output.append(name_to_index_inference_remapped_output[name_dst]) - else: - # this is a forcing variable. It is not in the inference output. - multiple_inference_output.append(None) - - self.index_training_remapped_input.append(multiple_training_input) - self.index_inference_remapped_input.append(multiple_inference_input) - self.index_training_backmapped_output.append(multiple_training_output) - self.index_inference_backmapped_output.append(multiple_inference_output) - - self.remappers.append([cos_converter, sin_converter]) - self.backmappers.append(atan2_converter) - - LOGGER.info(f"Map {name} to cosine and sine and save result in {self.method_config[method][name]}.") - - else: - raise ValueError[f"Unknown remapping method for {name}: {method}"] - - def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor: - """Remap and convert the input tensor. - - ``` - x : torch.Tensor - Input tensor - in_place : bool - Whether to process the tensor in place. - in_place is not possible for this preprocessor. 
- ``` - """ - # Choose correct index based on number of variables - if x.shape[-1] == self.num_training_input_vars: - index = self.index_training_input - indices_remapped = self.index_training_remapped_input - indices_keep = self.indices_keep_training_input - target_number_columns = self.num_remapped_training_input_vars - - elif x.shape[-1] == self.num_inference_input_vars: - index = self.index_inference_input - indices_remapped = self.index_inference_remapped_input - indices_keep = self.indices_keep_inference_input - target_number_columns = self.num_remapped_inference_input_vars - + _, _, method_config = cls._process_config(config) + monomappings = Monomapper.supported_methods + multimappings = Multimapper.supported_methods + if all(method in monomappings for method in method_config): + return Monomapper(config, data_indices, statistics) + elif all(method in multimappings for method in method_config): + return Multimapper(config, data_indices, statistics) + elif not ( + any(method in monomappings for method in method_config) + or any(method in multimappings for method in method_config) + ): + raise ValueError("No valid remapping method found.") else: - raise ValueError( - f"Input tensor ({x.shape[-1]}) does not match the training " - f"({self.num_training_input_vars}) or inference shape ({self.num_inference_input_vars})", - ) - - # create new tensor with target number of columns - x_remapped = torch.zeros(x.shape[:-1] + (target_number_columns,), dtype=x.dtype, device=x.device) - if in_place and not self.printed_preprocessor_warning: - LOGGER.warning( - "Remapper (preprocessor) called with in_place=True. This preprocessor cannot be applied in_place as new columns are added to the tensors.", + raise NotImplementedError( + f"Not implemented: method_config contains a mix of monomapper and multimapper methods: {list(method_config.keys())}" ) - self.printed_preprocessor_warning = True - - # copy variables that are not remapped - x_remapped[..., : len(indices_keep)] = x[..., indices_keep] - - # Remap variables - for idx_dst, remapper, idx_src in zip(indices_remapped, self.remappers, index): - if idx_src is not None: - for jj, ii in enumerate(idx_dst): - x_remapped[..., ii] = remapper[jj](x[..., idx_src]) - - return x_remapped - - def inverse_transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor: - """Convert and remap the output tensor. - - ``` - x : torch.Tensor - Input tensor - in_place : bool - Whether to process the tensor in place. - in_place is not possible for this postprocessor. 
- ``` - """ - # Choose correct index based on number of variables - if x.shape[-1] == self.num_remapped_training_output_vars: - index = self.index_training_output - indices_remapped = self.index_training_backmapped_output - indices_keep = self.indices_keep_training_output - target_number_columns = self.num_training_output_vars - - elif x.shape[-1] == self.num_remapped_inference_output_vars: - index = self.index_inference_output - indices_remapped = self.index_inference_backmapped_output - indices_keep = self.indices_keep_inference_output - target_number_columns = self.num_inference_output_vars - - else: - raise ValueError( - f"Input tensor ({x.shape[-1]}) does not match the training " - f"({self.num_remapped_training_output_vars}) or inference shape ({self.num_remapped_inference_output_vars})", - ) - - # create new tensor with target number of columns - x_remapped = torch.zeros(x.shape[:-1] + (target_number_columns,), dtype=x.dtype, device=x.device) - if in_place and not self.printed_postprocessor_warning: - LOGGER.warning( - "Remapper (preprocessor) called with in_place=True. This preprocessor cannot be applied in_place as new columns are added to the tensors.", - ) - self.printed_postprocessor_warning = True - - # copy variables that are not remapped - x_remapped[..., indices_keep] = x[..., : len(indices_keep)] - - # Backmap variables - for idx_dst, backmapper, idx_src in zip(index, self.backmappers, indices_remapped): - if idx_dst is not None: - x_remapped[..., idx_dst] = backmapper(x[..., idx_src]) - - return x_remapped - - -class Remapper(BaseRemapperVariable): - """Remap and convert variables. - - cos_sin: - Remap the variable to cosine and sine. - Map output back to degrees. - - ``` - cos_sin: - "mwd" : ["cos_mwd", "sin_mwd"] - ``` - """ - - def __init__( - self, - config=None, - data_indices: Optional[IndexCollection] = None, - statistics: Optional[dict] = None, - ) -> None: - super().__init__(config, data_indices, statistics) - - self.printed_preprocessor_warning, self.printed_postprocessor_warning = False, False - - self._create_remapping_indices(statistics) - - self._validate_indices() diff --git a/tests/preprocessing/test_preprocessor_imputer.py b/tests/preprocessing/test_preprocessor_imputer.py index a22e261c..5d1035eb 100644 --- a/tests/preprocessing/test_preprocessor_imputer.py +++ b/tests/preprocessing/test_preprocessor_imputer.py @@ -297,6 +297,26 @@ def test_mask_saving(imputer_fixture, data_fixture, request): assert torch.equal(imputer.nan_locations, expected_mask), "Mask not saved correctly after first run." +@pytest.mark.parametrize( + ("imputer_fixture", "data_fixture"), + [ + ("default_constant_imputer", "default_constant_data"), + ("non_default_constant_imputer", "non_default_constant_data"), + ("default_input_imputer", "default_input_data"), + ("non_default_input_imputer", "non_default_input_data"), + ], +) +def test_loss_nan_mask(imputer_fixture, data_fixture, request): + """Check that the imputer correctly transforms a tensor with NaNs.""" + x, _ = request.getfixturevalue(data_fixture) + expected = torch.tensor([[1.0, 1.0, 1.0], [1.0, 0.0, 1.0]]) # only prognostic and diagnostic variables + imputer = request.getfixturevalue(imputer_fixture) + imputer.transform(x) + assert torch.allclose( + imputer.loss_mask_training, expected + ), "Transform does not calculate NaN-mask for loss function scaling correctly." 
+
+
 @pytest.mark.parametrize(
     ("imputer_fixture", "data_fixture"),
     [
diff --git a/tests/preprocessing/test_preprocessor_remapper.py b/tests/preprocessing/test_preprocessor_remapper.py
index a0ece2a3..6d27f906 100644
--- a/tests/preprocessing/test_preprocessor_remapper.py
+++ b/tests/preprocessing/test_preprocessor_remapper.py
@@ -8,11 +8,13 @@
 # nor does it submit to any jurisdiction.


+import numpy as np
 import pytest
 import torch
 from omegaconf import DictConfig

 from anemoi.models.data_indices.collection import IndexCollection
+from anemoi.models.preprocessing.imputer import InputImputer
 from anemoi.models.preprocessing.remapper import Remapper


@@ -41,22 +43,82 @@ def input_remapper():
     return Remapper(config=config.data.remapper, data_indices=data_indices, statistics=statistics)


+@pytest.fixture()
+def input_remapper_1d():
+    config = DictConfig(
+        {
+            "diagnostics": {"log": {"code": {"level": "DEBUG"}}},
+            "data": {
+                "remapper": {
+                    "log1p": "d",
+                    "sqrt": "q",
+                },
+                "forcing": ["z", "q"],
+                "diagnostic": ["other"],
+            },
+        },
+    )
+    statistics = {}
+    name_to_index = {"x": 0, "y": 1, "z": 2, "q": 3, "d": 4, "other": 5}
+    data_indices = IndexCollection(config=config, name_to_index=name_to_index)
+    return Remapper(config=config.data.remapper, data_indices=data_indices, statistics=statistics)
+
+
+@pytest.fixture()
+def input_imputer():
+    config = DictConfig(
+        {
+            "diagnostics": {"log": {"code": {"level": "DEBUG"}}},
+            "data": {
+                "remapper": {
+                    "cos_sin": {
+                        "d": ["cos_d", "sin_d"],
+                    }
+                },
+                "imputer": {"default": "none", "mean": ["y", "d"]},
+                "forcing": ["z", "q"],
+                "diagnostic": ["other"],
+                "remapped": {
+                    "d": ["cos_d", "sin_d"],
+                },
+            },
+        },
+    )
+    statistics = {
+        "mean": np.array([1.0, 2.0, 3.0, 4.5, 3.0, 1.0]),
+    }
+    name_to_index = {"x": 0, "y": 1, "z": 2, "q": 3, "d": 4, "other": 5}
+    data_indices = IndexCollection(config=config, name_to_index=name_to_index)
+    return InputImputer(config=config.data.imputer, data_indices=data_indices, statistics=statistics)
+
+
 def test_remap_not_inplace(input_remapper) -> None:
     x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]])
     input_remapper(x, in_place=False)
-    assert torch.allclose(x, torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]]))
+    assert torch.allclose(
+        x,
+        torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]]),
+    )


 def test_remap(input_remapper) -> None:
     x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]])
     expected_output = torch.Tensor(
-        [[1.0, 2.0, 3.0, 4.0, 5.0, -0.8660254, 0.5], [6.0, 7.0, 8.0, 9.0, 10.0, -0.93358043, -0.35836795]]
+        [
+            [1.0, 2.0, 3.0, 4.0, 5.0, -0.8660254, 0.5],
+            [6.0, 7.0, 8.0, 9.0, 10.0, -0.93358043, -0.35836795],
+        ]
     )
     assert torch.allclose(input_remapper.transform(x), expected_output)


 def test_inverse_transform(input_remapper) -> None:
-    x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 5.0, -0.8660254, 0.5], [6.0, 7.0, 8.0, 9.0, 10.0, -0.93358043, -0.35836795]])
+    x = torch.Tensor(
+        [
+            [1.0, 2.0, 3.0, 4.0, 5.0, -0.8660254, 0.5],
+            [6.0, 7.0, 8.0, 9.0, 10.0, -0.93358043, -0.35836795],
+        ]
+    )
     expected_output = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]])
     assert torch.allclose(input_remapper.inverse_transform(x), expected_output)


@@ -64,5 +126,77 @@ def test_inverse_transform(input_remapper) -> None:
 def test_remap_inverse_transform(input_remapper) -> None:
     x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]])
     assert torch.allclose(
-        input_remapper.inverse_transform(input_remapper.transform(x, in_place=False), in_place=False), x
+        input_remapper.inverse_transform(input_remapper.transform(x, in_place=False), in_place=False),
+        x,
+    )
+
+
+def test_transform_loss_mask(input_imputer, input_remapper) -> None:
+    x = torch.Tensor([[1.0, np.nan, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, np.nan, 10.0]])
+    expected_output = torch.Tensor([[1.0, 0.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0, 0.0]])
+    input_imputer.transform(x)
+    input_remapper.transform(x)
+    loss_mask_training = input_imputer.loss_mask_training
+    loss_mask_training = input_remapper.transform_loss_mask(loss_mask_training)
+    assert torch.allclose(loss_mask_training, expected_output)
+
+
+def test_monomap_transform(input_remapper_1d) -> None:
+    x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]])
+    expected_output = torch.Tensor(
+        [
+            [1.0, 2.0, 3.0, np.sqrt(4.0), np.log1p(150.0), 5.0],
+            [6.0, 7.0, 8.0, np.sqrt(9.0), np.log1p(201.0), 10.0],
+        ]
+    )
+    assert torch.allclose(input_remapper_1d.transform(x, in_place=False), expected_output)
+    # inference mode (without prognostic variables)
+    assert torch.allclose(
+        input_remapper_1d.transform(
+            x[..., input_remapper_1d.data_indices.data.todict()["input"]["full"]], in_place=False
+        ),
+        expected_output[..., input_remapper_1d.data_indices.data.todict()["input"]["full"]],
+    )
+    # this call actually changes the values in x, so it needs to be last
+    assert torch.allclose(input_remapper_1d.transform(x), expected_output)
+
+
+def test_monomap_inverse_transform(input_remapper_1d) -> None:
+    expected_output = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]])
+    y = torch.Tensor(
+        [
+            [1.0, 2.0, 3.0, np.sqrt(4.0), np.log1p(150.0), 5.0],
+            [6.0, 7.0, 8.0, np.sqrt(9.0), np.log1p(201.0), 10.0],
+        ]
+    )
+    assert torch.allclose(input_remapper_1d.inverse_transform(y, in_place=False), expected_output)
+    # inference mode (without prognostic variables)
+    assert torch.allclose(
+        input_remapper_1d.inverse_transform(
+            y[..., input_remapper_1d.data_indices.data.todict()["output"]["full"]], in_place=False
+        ),
+        expected_output[..., input_remapper_1d.data_indices.data.todict()["output"]["full"]],
+    )
+
+
+def test_unsupported_remapper():
+    config = DictConfig(
+        {
+            "diagnostics": {"log": {"code": {"level": "DEBUG"}}},
+            "data": {
+                "remapper": {"log1p": "q", "cos_sin": "d"},
+                "forcing": [],
+                "diagnostic": [],
+            },
+        }
+    )
+    statistics = {}
+    name_to_index = {"x": 0, "y": 1, "q": 2, "d": 3}
+    data_indices = IndexCollection(config=config, name_to_index=name_to_index)
+
+    with pytest.raises(NotImplementedError):
+        Remapper(
+            config=config.data.remapper,
+            data_indices=data_indices,
+            statistics=statistics,
+        )
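
For reference, a minimal usage sketch of the new one-to-one remapping path, mirroring the `input_remapper_1d` fixture above. The config keys, variable names, and the empty statistics dict are taken directly from the tests added in this diff; everything here is illustrative rather than canonical.

```
# Sketch only: mirrors the input_remapper_1d test fixture from this change set.
import torch
from omegaconf import DictConfig

from anemoi.models.data_indices.collection import IndexCollection
from anemoi.models.preprocessing.remapper import Remapper

config = DictConfig(
    {
        "diagnostics": {"log": {"code": {"level": "DEBUG"}}},
        "data": {
            # All methods here are one-to-one, so the Remapper factory returns a Monomapper.
            "remapper": {"log1p": "d", "sqrt": "q"},
            "forcing": ["z", "q"],
            "diagnostic": ["other"],
        },
    },
)
name_to_index = {"x": 0, "y": 1, "z": 2, "q": 3, "d": 4, "other": 5}
data_indices = IndexCollection(config=config, name_to_index=name_to_index)
remapper = Remapper(config=config.data.remapper, data_indices=data_indices, statistics={})

x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0]])
y = remapper.transform(x, in_place=False)            # sqrt applied to "q", log1p applied to "d"
x_roundtrip = remapper.inverse_transform(y, in_place=False)
```

Mixing a one-to-one method with a one-to-many method such as `cos_sin` in the same config raises `NotImplementedError`, as covered by `test_unsupported_remapper`.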