diff --git a/docs/modules/distributed.rst b/docs/modules/distributed.rst index de8a98a6..13c33d07 100644 --- a/docs/modules/distributed.rst +++ b/docs/modules/distributed.rst @@ -2,10 +2,68 @@ Distributed ############# +******* + graph +******* +.. automodule:: anemoi.models.distributed.graph + :members: + :no-undoc-members: + :show-inheritance: + +************ + khop_edges +************ + +.. automodule:: anemoi.models.distributed.khop_edges + :members: + :no-undoc-members: + :show-inheritance: + +.. + ************* + +.. + primitives + +.. + ************* + +.. + .. automodule:: anemoi.models.distributed.primitives + +.. + :members: + +.. + :no-undoc-members: + +.. + :show-inheritance: + +******** + shapes +******** + +.. automodule:: anemoi.models.distributed.shapes + :members: + :no-undoc-members: + :show-inheritance: + +************* + transformer +************* + +.. automodule:: anemoi.models.distributed.transformer + :members: + :no-undoc-members: + :show-inheritance: +******* + utils +******* -.. automodule:: anemoi.models.distributed +.. automodule:: anemoi.models.distributed.utils :members: :no-undoc-members: :show-inheritance: diff --git a/pyproject.toml b/pyproject.toml index eb66d0b2..2282b95f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,4 +1,3 @@ -#!/usr/bin/env python # (C) Copyright 2024 ECMWF. # # This software is licensed under the terms of the Apache Licence Version 2.0 @@ -10,12 +9,13 @@ # https://packaging.python.org/en/latest/guides/writing-pyproject-toml/ [build-system] -requires = ["setuptools>=60", "setuptools-scm>=8.0"] +requires = ["setuptools>=61", "setuptools-scm>=8.0"] +build-backend = "setuptools.build_meta" [project] description = "A package to hold various functions to support training of ML models." 
name = "anemoi-models" - +readme = "README.md" dynamic = ["version"] license = { file = "LICENSE" } requires-python = ">=3.9" @@ -39,19 +39,34 @@ classifiers = [ "Operating System :: OS Independent", ] -dependencies = [] +dependencies = [ + "torch==2.3", + "torch-geometric==2.4", + "einops==0.6.1", + "hydra-core==1.3", + "anemoi-datasets==0.2.1", + "anemoi-utils==0.1.9", +] [project.optional-dependencies] docs = [ # For building the documentation - "sphinx", "sphinx_rtd_theme", "nbsphinx", "pandoc", "sphinx_argparse" ] all = [] -dev = [] +tests = ["pytest", "hypothesis"] + +dev = [ + "sphinx", + "sphinx_rtd_theme", + "nbsphinx", + "pandoc", + "pytest", + "hypothesis", +] [project.urls] Homepage = "https://github.com/ecmwf/anemoi-models/" diff --git a/src/anemoi/models/processors/__init__.py b/src/anemoi/models/data_indices/__init__.py similarity index 100% rename from src/anemoi/models/processors/__init__.py rename to src/anemoi/models/data_indices/__init__.py diff --git a/src/anemoi/models/data_indices/collection.py b/src/anemoi/models/data_indices/collection.py new file mode 100644 index 00000000..3e325aaa --- /dev/null +++ b/src/anemoi/models/data_indices/collection.py @@ -0,0 +1,74 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
class IndexCollection:
    """Collection of data and model indices.

    Builds a ``DataIndex`` (positions in the full dataset variable list) and a
    ``ModelIndex`` (positions in the model input/output tensors, where
    diagnostic variables are removed from the input and forcing variables are
    removed from the output).

    Parameters
    ----------
    config
        OmegaConf configuration. ``config.data.forcing`` and
        ``config.data.diagnostic`` are read; either may be ``None``, which is
        treated as an empty list.
    name_to_index
        Mapping of variable name to its position in the dataset.

    Raises
    ------
    AssertionError
        If a variable is listed as both diagnostic and forcing.
    """

    def __init__(self, config, name_to_index) -> None:
        self.config = OmegaConf.to_container(config, resolve=True)

        self.forcing = [] if config.data.forcing is None else OmegaConf.to_container(config.data.forcing, resolve=True)
        self.diagnostic = (
            [] if config.data.diagnostic is None else OmegaConf.to_container(config.data.diagnostic, resolve=True)
        )

        # The message is a single string built by implicit concatenation; a comma
        # between the two parts would turn it into a tuple and garble the error.
        overlap = set(self.diagnostic).intersection(self.forcing)
        assert not overlap, (
            f"Diagnostic and forcing variables overlap: {overlap}. "
            "Please drop them at a dataset-level to exclude them from the training data."
        )

        # Order the mapping by index so the enumerations below follow dataset order.
        self.name_to_index = dict(sorted(name_to_index.items(), key=operator.itemgetter(1)))
        name_to_index_model_input = {
            name: i for i, name in enumerate(key for key in self.name_to_index if key not in self.diagnostic)
        }
        name_to_index_model_output = {
            name: i for i, name in enumerate(key for key in self.name_to_index if key not in self.forcing)
        }

        self.data = DataIndex(self.diagnostic, self.forcing, self.name_to_index)
        self.model = ModelIndex(self.diagnostic, self.forcing, name_to_index_model_input, name_to_index_model_output)

    def __repr__(self) -> str:
        return f"IndexCollection(config={self.config}, name_to_index={self.name_to_index})"

    def __eq__(self, other):
        if not isinstance(other, IndexCollection):
            # don't attempt to compare against unrelated types
            return NotImplemented

        return self.model == other.model and self.data == other.data

    def __getitem__(self, key):
        # Allows collection["data"] as an alias for collection.data.
        return getattr(self, key)

    def todict(self):
        """Return the data and model indices as a nested dictionary."""
        return {
            "data": self.data.todict(),
            "model": self.model.todict(),
        }

    @staticmethod
    def representer(dumper, data):
        """YAML representer: emit the object as a tagged scalar holding its repr."""
        return dumper.represent_scalar(f"!{data.__class__.__name__}", repr(data))
class BaseIndex:
    """Base class for data and model indices.

    Subclasses assign ``input`` and ``output`` index objects that expose
    ``todict()`` and support equality comparison.
    """

    def __init__(self) -> None:
        # Placeholders only; DataIndex and ModelIndex assign the real index objects.
        self.input = NotImplementedError
        self.output = NotImplementedError

    def __eq__(self, other):
        if not isinstance(other, BaseIndex):
            # don't attempt to compare against unrelated types
            return NotImplemented

        return self.input == other.input and self.output == other.output

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(input={self.input}, output={self.output})"

    def __getitem__(self, key):
        # Allows index["input"] as an alias for index.input.
        return getattr(self, key)

    def todict(self):
        """Return the input and output indices as a nested dictionary."""
        return {
            "input": self.input.todict(),
            "output": self.output.todict(),
        }

    @staticmethod
    def representer(dumper, data):
        """YAML representer: emit the object as a tagged scalar holding its repr."""
        return dumper.represent_scalar(f"!{data.__class__.__name__}", repr(data))


class DataIndex(BaseIndex):
    """Indexing for data variables.

    Input and output indices both refer to positions in the same
    ``name_to_index`` mapping; diagnostics are excluded from the input and
    forcings are excluded from the output.
    """

    def __init__(self, diagnostic, forcing, name_to_index) -> None:
        self._diagnostic = diagnostic
        self._forcing = forcing
        self._name_to_index = name_to_index
        self.input = InputTensorIndex(
            includes=forcing,
            excludes=diagnostic,
            name_to_index=name_to_index,
        )

        self.output = OutputTensorIndex(
            includes=diagnostic,
            excludes=forcing,
            name_to_index=name_to_index,
        )

    def __repr__(self) -> str:
        # Uses the attributes actually stored in __init__ (_diagnostic / _forcing).
        return f"{self.__class__.__name__}(diagnostic={self._diagnostic}, forcing={self._forcing}, name_to_index={self._name_to_index})"


class ModelIndex(BaseIndex):
    """Indexing for model variables.

    Input and output use separate name-to-index mappings, so nothing has to be
    excluded when the tensor indices are built.
    """

    def __init__(self, diagnostic, forcing, name_to_index_model_input, name_to_index_model_output) -> None:
        self._diagnostic = diagnostic
        self._forcing = forcing
        self._name_to_index_model_input = name_to_index_model_input
        self._name_to_index_model_output = name_to_index_model_output
        self.input = InputTensorIndex(
            includes=forcing,
            excludes=[],
            name_to_index=name_to_index_model_input,
        )

        self.output = OutputTensorIndex(
            includes=diagnostic,
            excludes=[],
            name_to_index=name_to_index_model_output,
        )

    def __repr__(self) -> str:
        # Uses the attributes actually stored in __init__ (_diagnostic / _forcing).
        return (
            f"{self.__class__.__name__}(diagnostic={self._diagnostic}, forcing={self._forcing}, "
            f"name_to_index_model_input={self._name_to_index_model_input}, "
            f"name_to_index_model_output={self._name_to_index_model_output})"
        )


class BaseTensorIndex:
    """Indexing for variables in index as Tensor."""

    def __init__(self, *, includes: list[str], excludes: list[str], name_to_index: dict[str, int]) -> None:
        """Initialize indexing tensors from includes and excludes using name_to_index.

        Parameters
        ----------
        includes : list[str]
            Variables to include in the indexing that are exclusive to this indexing.
            e.g. Forcing variables for the input indexing, diagnostic variables for the output indexing
        excludes : list[str]
            Variables to exclude from the indexing.
            e.g. Diagnostic variables for the input indexing, forcing variables for the output indexing
        name_to_index : dict[str, int]
            Dictionary mapping variable names to their index in the Tensor.

        Raises
        ------
        AssertionError
            If ``includes`` or ``excludes`` names a variable missing from ``name_to_index``.
        """
        self.includes = includes
        self.excludes = excludes
        self.name_to_index = name_to_index

        assert set(self.excludes).issubset(
            self.name_to_index.keys(),
        ), f"Data indexing has invalid entries {[var for var in self.excludes if var not in self.name_to_index]}, not in dataset."
        assert set(self.includes).issubset(
            self.name_to_index.keys(),
        ), f"Data indexing has invalid entries {[var for var in self.includes if var not in self.name_to_index]}, not in dataset."

        self.full = self._build_idx_from_excludes()
        self._only = self._build_idx_from_includes()
        self._removed = self._build_idx_from_includes(self.excludes)
        self.prognostic = self._build_idx_prognostic()
        # Placeholders; InputTensorIndex / OutputTensorIndex map these to _only / _removed.
        self.diagnostic = NotImplementedError
        self.forcing = NotImplementedError

    def __len__(self) -> int:
        return len(self.full)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(includes={self.includes}, excludes={self.excludes}, name_to_index={self.name_to_index})"

    @staticmethod
    def _same_index(left, right) -> bool:
        """Compare two index attributes exactly.

        ``torch.equal`` returns False for tensors of different shape, whereas
        ``torch.allclose`` raises on non-broadcastable shapes and treats an empty
        tensor as equal to a one-element tensor. Non-tensor placeholders are
        compared by identity.
        """
        if isinstance(left, torch.Tensor) and isinstance(right, torch.Tensor):
            return torch.equal(left, right)
        return left is right

    def __eq__(self, other):
        if not isinstance(other, BaseTensorIndex):
            # don't attempt to compare against unrelated types
            return NotImplemented

        return (
            self._same_index(self.full, other.full)
            and self._same_index(self._only, other._only)
            and self._same_index(self._removed, other._removed)
            and self._same_index(self.prognostic, other.prognostic)
            and self._same_index(self.diagnostic, other.diagnostic)
            and self._same_index(self.forcing, other.forcing)
            and self.includes == other.includes
            and self.excludes == other.excludes
        )

    def __getitem__(self, key):
        # Allows index["full"] as an alias for index.full.
        return getattr(self, key)

    def todict(self):
        """Return the public index tensors keyed by name."""
        return {
            "full": self.full,
            "prognostic": self.prognostic,
            "diagnostic": self.diagnostic,
            "forcing": self.forcing,
        }

    @staticmethod
    def representer(dumper, data):
        """YAML representer: emit the object as a tagged scalar holding its repr."""
        return dumper.represent_scalar(f"!{data.__class__.__name__}", repr(data))

    def _build_idx_from_excludes(self, excludes=None) -> torch.Tensor:
        """Return the sorted indices of every variable not named in ``excludes``."""
        if excludes is None:
            excludes = self.excludes
        excluded = set(excludes)
        # Build the integer tensor directly instead of round-tripping through float32.
        return torch.tensor(
            sorted(i for name, i in self.name_to_index.items() if name not in excluded),
            dtype=torch.int,
        )

    def _build_idx_from_includes(self, includes=None) -> torch.Tensor:
        """Return the sorted indices of the variables named in ``includes``."""
        if includes is None:
            includes = self.includes
        return torch.tensor(sorted(self.name_to_index[name] for name in includes), dtype=torch.int)

    def _build_idx_prognostic(self) -> torch.Tensor:
        """Return the indices of variables that are neither included-only nor excluded."""
        return self._build_idx_from_excludes([*self.includes, *self.excludes])


class InputTensorIndex(BaseTensorIndex):
    """Indexing for input variables."""

    def __init__(self, *, includes: list[str], excludes: list[str], name_to_index: dict[str, int]) -> None:
        super().__init__(includes=includes, excludes=excludes, name_to_index=name_to_index)
        # For inputs, the included-only variables are forcings and the excluded ones diagnostics.
        self.forcing = self._only
        self.diagnostic = self._removed


class OutputTensorIndex(BaseTensorIndex):
    """Indexing for output variables."""

    def __init__(self, *, includes: list[str], excludes: list[str], name_to_index: dict[str, int]) -> None:
        super().__init__(includes=includes, excludes=excludes, name_to_index=name_to_index)
        # For outputs, the excluded variables are forcings and the included-only ones diagnostics.
        self.forcing = self._removed
        self.diagnostic = self._only
+# + + +import torch +from torch import Tensor +from torch.distributed.distributed_c10d import ProcessGroup + +from anemoi.models.distributed.primitives import _gather +from anemoi.models.distributed.primitives import _reduce +from anemoi.models.distributed.primitives import _split + + +def shard_tensor( + input_: Tensor, dim: int, shapes: tuple, mgroup: ProcessGroup, gather_in_backward: bool = True +) -> Tensor: + """Shard tensor. + + Keeps only part of the tensor that is relevant for the current rank. + + Parameters + ---------- + input_ : Tensor + Input + dim : int + dimension along which to shard + shapes : tuple + Shapes of sharded Tensors + mgroup : ProcessGroup + model communication group + gather_in_backward : bool + perform gather in backward, default True + + Returns + ------- + Tensor + Sharded tensor. + """ + return _ShardParallelSection.apply(input_, dim, shapes, gather_in_backward, mgroup) + + +def gather_tensor(input_: Tensor, dim: int, shapes: tuple, mgroup: ProcessGroup) -> Tensor: + """Gather tensor. + + Gathers tensor shards from ranks. + + Parameters + ---------- + input_ : Tensor + Input + dim : int + dimension along which to gather + shapes : tuple + Shapes of sharded Tensors + mgroup : ProcessGroup + model communication group + + Returns + ------- + Tensor + Gathered tensor. + """ + return _GatherParallelSection.apply(input_, dim, shapes, mgroup) + + +def reduce_tensor(input_: Tensor, mgroup: ProcessGroup) -> Tensor: + """Reduce tensor. + + Reduces tensor across ranks. + + Parameters + ---------- + input_ : Tensor + Input + mgroup : ProcessGroup + model communication group + + Returns + ------- + Tensor + Reduced tensor. + """ + return _ReduceParallelSection.apply(input_, mgroup) + + +def sync_tensor(input_: Tensor, dim: int, shapes: tuple, mgroup: ProcessGroup) -> Tensor: + """Sync tensor. + + Perform a gather in the forward pass and an allreduce followed by a split in the backward pass. 
+ + Parameters + ---------- + input_ : Tensor + Input + dim : int + dimension along which to gather + shapes : tuple + Shapes of sharded Tensors + mgroup : ProcessGroup + model communication group + + Returns + ------- + Tensor + Synced tensor. + """ + return _SyncParallelSection.apply(input_, dim, shapes, mgroup) + + +def reduce_shard_tensor(input_: Tensor, dim: int, shapes: tuple, mgroup: ProcessGroup) -> Tensor: + """Reduces and then shards tensor. + + Perform an allreduce followed by a split in the forward pass and a gather in the backward pass. + + Parameters + ---------- + input_ : Tensor + Input + dim : int + dimension along which to gather + shapes : tuple + Shapes of sharded Tensors + mgroup : ProcessGroup + model communication group + + Returns + ------- + Tensor + Reduced sharded tensor. + """ + return _ReduceShardParallelSection.apply(input_, dim, shapes, mgroup) + + +class _SyncParallelSection(torch.autograd.Function): + """Sync the input from parallel section.""" + + @staticmethod + def forward(ctx, input_, dim_, shapes_, mgroup_): + ctx.dim = dim_ + ctx.comm_group = mgroup_ + ctx.shapes = shapes_ + if mgroup_: + return _gather(input_, dim_, shapes_, group=mgroup_) + return input_ + + @staticmethod + def backward(ctx, grad_output): + if ctx.comm_group: + grad_output = _reduce(grad_output, group=ctx.comm_group) + return ( + _split(grad_output, ctx.dim, ctx.shapes, group=ctx.comm_group), + None, + None, + None, + ) + return grad_output, None, None, None + + +class _ReduceShardParallelSection(torch.autograd.Function): + """All-reduce and shard the input from the parallel section.""" + + # Modified from + # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + # + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. 
class _ShardParallelSection(torch.autograd.Function):
    """Split the input and keep only the chunk relevant to the current rank.

    Forward delegates to ``_split``; backward delegates to ``_gather``. With no
    communication group, both directions pass the tensor through unchanged.
    """

    # Modified from
    # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.

    @staticmethod
    def forward(ctx, input_, dim_, shapes_, gather_in_backward_, mgroup_):
        # Stash everything backward() needs on the autograd context.
        ctx.dim, ctx.shapes = dim_, shapes_
        ctx.gather_in_backward = gather_in_backward_
        ctx.comm_group = mgroup_

        if not mgroup_:
            return input_
        return _split(input_, dim_, shapes_, group=mgroup_)

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient slot per forward() argument; only input_ is differentiable.
        if not ctx.comm_group:
            return grad_output, None, None, None, None

        gathered = _gather(
            grad_output,
            ctx.dim,
            ctx.shapes,
            gather_in_backward=ctx.gather_in_backward,
            group=ctx.comm_group,
        )
        return gathered, None, None, None, None
def get_k_hop_edges(nodes: Tensor, edge_attr: Tensor, edge_index: Adj, num_hops: int = 1) -> tuple[Tensor, Adj]:
    """Return the edges of the k-hop subgraph around ``nodes``.

    Parameters
    ----------
    nodes : Tensor
        destination nodes
    edge_attr : Tensor
        edge attributes
    edge_index : Adj
        edge index
    num_hops: int, Optional, by default 1
        number of required hops

    Returns
    -------
    tuple[Tensor, Adj]
        Edge attributes of the k-hop subgraph, followed by its edge index
        (attributes first, index second).
    """
    _, edge_index_k, _, edge_mask_k = k_hop_subgraph(
        node_idx=nodes, num_hops=num_hops, edge_index=edge_index, directed=True
    )

    # Select the attributes of exactly those edges kept by the subgraph mask.
    return edge_attr[mask_to_index(edge_mask_k)], edge_index_k


def sort_edges_1hop(
    num_nodes: Union[int, tuple[int, int]],
    edge_attr: Tensor,
    edge_index: Adj,
    mgroup: Optional[ProcessGroup] = None,
) -> tuple[Tensor, Adj, list, list]:
    """Rearranges edges into 1 hop neighbourhoods for sharding across GPUs.

    Parameters
    ----------
    num_nodes : Union[int, tuple[int, int]]
        Number of (target) nodes in Graph; a ``(source, target)`` pair selects
        the bipartite code path.
    edge_attr : Tensor
        edge attributes
    edge_index : Adj
        edge index
    mgroup : ProcessGroup
        model communication group; when falsy the inputs are returned unchanged

    Returns
    -------
    tuple[Tensor, Adj, list, list]
        edge attributes of sorted edges, edges sorted according to k hop neigh.,
        shapes of edge attr for partitioning between GPUs, shapes of edge indices
        for partitioning between GPUs. Both shape lists are empty without ``mgroup``.
    """
    if not mgroup:
        return edge_attr, edge_index, [], []

    num_chunks = dist.get_world_size(group=mgroup)

    bipartite = not isinstance(num_nodes, int)
    if bipartite:
        nodes_src = torch.arange(num_nodes[0], device=edge_index.device)
        num_targets = num_nodes[1]
    else:
        num_targets = num_nodes
    # One chunk of target nodes per rank in the communication group.
    node_chunks = torch.arange(num_targets, device=edge_index.device).tensor_split(num_chunks)

    edge_index_list = []
    edge_attr_list = []
    for node_chunk in node_chunks:
        if bipartite:
            # NOTE: bipartite_subgraph returns (index, attr); get_k_hop_edges returns (attr, index).
            edge_index_chunk, edge_attr_chunk = bipartite_subgraph(
                (nodes_src, node_chunk),
                edge_index,
                edge_attr,
                size=(num_nodes[0], num_nodes[1]),
            )
        else:
            edge_attr_chunk, edge_index_chunk = get_k_hop_edges(node_chunk, edge_attr, edge_index)
        edge_index_list.append(edge_index_chunk)
        edge_attr_list.append(edge_attr_chunk)

    edge_index_shapes = [x.shape for x in edge_index_list]
    edge_attr_shapes = [x.shape for x in edge_attr_list]

    return torch.cat(edge_attr_list, dim=0), torch.cat(edge_index_list, dim=1), edge_attr_shapes, edge_index_shapes
def _gather(
    input_: Tensor,
    dim_: int,
    shapes: tuple,
    gather_in_backward: Optional[bool] = True,
    group: Optional[ProcessGroup] = None,
) -> Tensor:
    """Gather the per-rank tensors and concatenate them along ``dim_``.

    Returns ``input_`` untouched when the group has a single rank. When
    ``gather_in_backward`` is falsy the collective is skipped, so only this
    rank's slot of the result holds data; the other slots come from
    ``torch.empty`` and are uninitialised.
    """
    # Modified from
    # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.

    # remember the memory format so the result can be returned in the same one
    memory_format = get_memory_format(input_)

    world_size = dist.get_world_size(group=group)
    if world_size == 1:
        # Nothing to gather on a single GPU.
        return input_

    # sanity checks
    assert dim_ < input_.dim(), f"Error, cannot gather along {dim_} for tensor with {input_.dim()} dimensions."

    this_rank = dist.get_rank(group=group)
    input_ = input_.contiguous(memory_format=memory_format)

    # One receive buffer per rank, shaped according to that rank's entry in `shapes`.
    buffers = []
    for rank in range(world_size):
        buffers.append(
            torch.empty(
                shapes[rank],
                dtype=input_.dtype,
                layout=input_.layout,
                device=input_.device,
                memory_format=memory_format,
            )
        )

    # The local slot always carries the local tensor, with or without the collective.
    buffers[this_rank] = input_
    if gather_in_backward:
        dist.all_gather(buffers, input_, group=group)

    # Note: torch.cat already creates a contiguous tensor.
    return torch.cat(buffers, dim=dim_).contiguous(memory_format=memory_format)
def get_shape_shards(tensor: Tensor, dim: int, model_comm_group: Optional[ProcessGroup] = None) -> list:
    """Get shape of tensor shards.

    Parameters
    ----------
    tensor : Tensor
        Tensor to be sharded.
    dim : int
        Dimension along which the tensor is split.
    model_comm_group : ProcessGroup, optional
        Model communication group; its world size is the number of shards.
        Without a group a single shard (the whole tensor) is assumed.

    Returns
    -------
    list
        One ``list[int]`` shape per shard, as produced by ``torch.tensor_split``.

    Raises
    ------
    AssertionError
        If ``dim`` is not smaller than the number of tensor dimensions.
    """
    assert dim < tensor.dim(), f"Error, tensor dimension is {tensor.dim()} which cannot be split along {dim}"

    comm_size = dist.get_world_size(group=model_comm_group) if model_comm_group else 1
    return [list(shard.shape) for shard in torch.tensor_split(tensor, comm_size, dim=dim)]


def change_channels_in_shape(shape_list: list, channels: int) -> list:
    """Change the number of channels in the tensor shape definition list.

    Each shape may be any sequence (``list``, ``tuple`` or ``torch.Size``); its
    last entry is replaced by ``channels`` and the result is always a list of
    lists. A falsy ``shape_list`` (empty or ``None``) yields an empty list.
    """
    if not shape_list:
        return []
    # Unpacking avoids the tuple + list TypeError that `shape[:-1] + [channels]` hits.
    return [[*shape[:-1], channels] for shape in shape_list]
def _headsalltoall(input_: Tensor, shapes: list, group: Optional[ProcessGroup] = None) -> Tensor:
    """Apply all_to_all along the head dimension.

    Splits ``input_`` into one chunk per rank along dimension -3, exchanges the
    chunks with ``dist.all_to_all`` and concatenates the received chunks along
    dimension -2. Returns ``input_`` unchanged when the group has a single rank.

    Parameters
    ----------
    input_ : Tensor
        Tensor to exchange; the inline comment below gives the layout as (b ... h n c).
    shapes : list
        One entry per rank; only the first element of each entry is read and
        used as the size of dimension -2 of the buffer received from that rank.
    group : Optional[ProcessGroup]
        Process group used for the exchange.

    Returns
    -------
    Tensor
        Concatenation of the received chunks, in the memory format of ``input_``.
    """
    comm_size = dist.get_world_size(group=group)
    # Bypass the function if we are using only 1 GPU.
    if comm_size == 1:
        return input_

    # get input format
    input_format = get_memory_format(input_)

    input_list = [x.contiguous() for x in torch.tensor_split(input_, comm_size, dim=-3)]  # do we need contiguous?

    input_shape = [x.shape for x in input_list]  # (b ... h n c)
    heads_per_rank = [x.shape[-3] for x in input_list]
    channels_per_rank = [x.shape[-1] for x in input_list]
    seq_per_rank = [x[0] for x in shapes]

    # One receive buffer per peer rank.
    # NOTE(review): the buffer for peer `rank` takes its head count from the chunk
    # *sent* to that peer (heads_per_rank[rank]); this presumably relies on the heads
    # splitting evenly across ranks so that all chunks have the same size - confirm.
    output_list = [
        torch.empty(
            (*input_shape[rank][:-3], heads_per_rank[rank], seq_per_rank[rank], channels_per_rank[rank]),
            dtype=input_.dtype,
            layout=input_.layout,
            device=input_.device,
            memory_format=input_format,
        )
        for rank in range(comm_size)
    ]

    dist.all_to_all(output_list, input_list, group=group)

    # Note: torch.cat already creates a contiguous tensor.
    return torch.cat(output_list, dim=-2).contiguous(memory_format=input_format)


def _seqalltoall(input_: Tensor, shapes: list, group: Optional[ProcessGroup] = None) -> Tensor:
    """Apply all_to_all along the sequence dimension.

    Splits ``input_`` into one chunk per rank along dimension -2, exchanges the
    chunks with ``dist.all_to_all`` and concatenates the received chunks along
    dimension -3. Returns ``input_`` unchanged when the group has a single rank.

    Parameters
    ----------
    input_ : Tensor
        Tensor to exchange.
    shapes : list
        Accepted for signature parity with ``_headsalltoall``; not read in this function.
    group : Optional[ProcessGroup]
        Process group used for the exchange.

    Returns
    -------
    Tensor
        Concatenation of the received chunks, in the memory format of ``input_``.
    """
    comm_size = dist.get_world_size(group=group)
    # Bypass the function if we are using only 1 GPU.
    if comm_size == 1:
        return input_

    comm_rank = dist.get_rank(group=group)

    # get input format
    input_format = get_memory_format(input_)

    input_list = [x.contiguous() for x in torch.tensor_split(input_, comm_size, dim=-2)]  # do we need contiguous?

    # Every receive buffer is shaped like the chunk this rank keeps for itself.
    # NOTE(review): presumably assumes each peer's chunk for this rank has that same
    # shape (i.e. equal head counts on all ranks) - confirm.
    output_list = [torch.empty_like(input_list[comm_rank]) for _ in range(comm_size)]

    dist.all_to_all(output_list, input_list, group=group)

    # Note: torch.cat already creates a contiguous tensor.
    return torch.cat(output_list, dim=-3).contiguous(memory_format=input_format)


def shard_heads(input_: Tensor, shapes: list, mgroup: ProcessGroup) -> Tensor:
    """Shard a tensor along the head dimension.

    Gathers e.g query, key or value tensor along sequence dimension via all to all communication
    and shards along head dimension for parallel self-attention computation.
    Expected format is (batch_size, ... heads, sequence_length, channels)

    The operation is differentiable: it is applied through
    ``_SplitHeadsParallelSection``, whose backward pass runs the inverse exchange.

    Parameters
    ----------
    input_ : Tensor
        Input
    shapes: list
        shapes of shards
    mgroup : ProcessGroup
        model communication group; when falsy the input is returned as is

    Returns
    -------
    Tensor
        Sharded heads.
    """
    return _SplitHeadsParallelSection.apply(input_, shapes, mgroup)


def shard_sequence(input_: Tensor, shapes: list, mgroup: ProcessGroup) -> Tensor:
    """Shard a tensor along the sequence dimension.

    Gathers e.g query, key or value tensor along head dimension via all to all communication
    and shards along sequence dimension for parallel mlp and layernorm computation.
    Expected format is (batch_size, ... heads, sequence_length, channels)

    The operation is differentiable: it is applied through
    ``_SplitSequenceParallelSection``, whose backward pass runs the inverse exchange.

    Parameters
    ----------
    input_ : Tensor
        Input
    shapes: list
        shapes of shards
    mgroup : ProcessGroup
        model communication group; when falsy the input is returned as is

    Returns
    -------
    Tensor
        Sharded sequence
    """
    return _SplitSequenceParallelSection.apply(input_, shapes, mgroup)


class _SplitHeadsParallelSection(torch.autograd.Function):
    """Autograd function behind ``shard_heads``.

    Forward applies ``_headsalltoall``; backward applies ``_seqalltoall`` to the
    incoming gradient. Both directions pass the tensor through untouched when
    no communication group was given.
    """

    @staticmethod
    def forward(ctx, input_, shapes_, mgroup_):
        # Keep the shard shapes and the group for the backward exchange.
        ctx.shapes = shapes_
        ctx.comm_group = mgroup_
        if mgroup_:
            return _headsalltoall(input_, shapes_, group=mgroup_)
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward argument; shapes and group get no gradient.
        if ctx.comm_group:
            return (
                _seqalltoall(grad_output, ctx.shapes, group=ctx.comm_group),
                None,
                None,
            )
        return grad_output, None, None


class _SplitSequenceParallelSection(torch.autograd.Function):
    """Autograd function behind ``shard_sequence``.

    Forward applies ``_seqalltoall``; backward applies ``_headsalltoall`` to the
    incoming gradient. Both directions pass the tensor through untouched when
    no communication group was given.
    """

    @staticmethod
    def forward(ctx, input_, shapes_, mgroup_):
        # Keep the shard shapes and the group for the backward exchange.
        ctx.shapes = shapes_
        ctx.comm_group = mgroup_
        if mgroup_:
            return _seqalltoall(input_, shapes_, group=mgroup_)
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward argument; shapes and group get no gradient.
        if ctx.comm_group:
            return (
                _headsalltoall(grad_output, ctx.shapes, group=ctx.comm_group),
                None,
                None,
            )
        return grad_output, None, None
def get_memory_format(tensor: Tensor):
    """Return the memory format matching the layout of ``tensor``.

    Parameters
    ----------
    tensor : Tensor
        Tensor to inspect.

    Returns
    -------
    torch.memory_format
        ``torch.channels_last`` when the tensor is contiguous in that format,
        ``torch.contiguous_format`` in every other case.
    """
    # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    channels_last = torch.channels_last
    is_channels_last = tensor.is_contiguous(memory_format=channels_last)
    return channels_last if is_channels_last else torch.contiguous_format
class AnemoiModelInterface(torch.nn.Module):
    """Wrapper that bundles an Anemoi model with its pre- and post-processors.

    The constructor stores its arguments, builds the processors described in
    ``config.data.processors`` and instantiates the underlying model. Calling
    the interface forwards directly to the model's ``forward``.

    Attributes
    ----------
    config : DotDict
        Configuration settings for the model.
    id : str
        A unique identifier for the model instance (a random UUID4 string).
    multi_step
        Value of ``config.training.multistep_input``; ``predict_step`` uses it
        as the number of leading timesteps taken from the batch.
    graph_data : HeteroData
        Graph data for the model.
    statistics : dict
        Statistics for the data.
    metadata : dict
        Metadata for the model.
    data_indices : dict
        Indices for the data.
    pre_processors : Processors
        Pre-processing steps to apply to the data before passing it to the model.
    post_processors : Processors
        Post-processing steps to apply to the model's output (built with ``inverse=True``).
    model : AnemoiModelEncProcDec
        The underlying Anemoi model.
    """

    def __init__(
        self, *, config: DotDict, graph_data: HeteroData, statistics: dict, data_indices: dict, metadata: dict
    ) -> None:
        super().__init__()
        self.config = config
        self.id = str(uuid.uuid4())
        self.multi_step = self.config.training.multistep_input
        self.graph_data = graph_data
        self.statistics = statistics
        self.metadata = metadata
        self.data_indices = data_indices
        self._build_model()

    def _build_model(self) -> None:
        """Create the pre-/post-processors and the underlying model."""
        # Build every configured processor once; both wrappers share the same list.
        processors = []
        for name, processor in self.config.data.processors.items():
            built = instantiate(processor, statistics=self.statistics, data_indices=self.data_indices)
            processors.append([name, built])

        self.pre_processors = Processors(processors)
        self.post_processors = Processors(processors, inverse=True)

        # Can be generalised to other models in the future; for now AnemoiModelEncProcDec is used.
        self.model = AnemoiModelEncProcDec(
            config=self.config,
            data_indices=self.data_indices,
            graph_data=self.graph_data,
        )

        # Calling the interface runs the model's forward method directly.
        self.forward = self.model.forward

    def predict_step(self, batch: torch.Tensor) -> torch.Tensor:
        """Run one prediction: pre-process, apply the model without gradients, post-process.

        Parameters
        ----------
        batch : torch.Tensor
            Input batched data; must be 4-dimensional after pre-processing.

        Returns
        -------
        torch.Tensor
            Predicted data after post-processing.
        """
        batch = self.pre_processors(batch, in_place=False)

        with torch.no_grad():
            n_dims = len(batch.shape)
            assert (
                n_dims == 4
            ), f"The input tensor has an incorrect shape: expected a 4-dimensional tensor, got {batch.shape}!"

            # Dimensions are: batch, timesteps, horizontal space, variables.
            # Keep the first `multi_step` timesteps and insert a dummy ensemble
            # dimension as the 3rd index.
            model_input = batch[:, : self.multi_step, None, ...]
            y_hat = self(model_input)

        return self.post_processors(y_hat, in_place=False)
value), + query, key, value = ( + einops.rearrange(t, "batch heads grid vars -> batch grid heads vars") for t in (query, key, value) ) out = self.attention(query, key, value, causal=False, window_size=self.window_size) out = einops.rearrange(out, "batch grid heads vars -> batch heads grid vars") diff --git a/src/anemoi/models/layers/block.py b/src/anemoi/models/layers/block.py index f30323eb..c43d0592 100644 --- a/src/anemoi/models/layers/block.py +++ b/src/anemoi/models/layers/block.py @@ -11,7 +11,6 @@ from abc import ABC from abc import abstractmethod from typing import Optional -from typing import Tuple import einops import torch @@ -22,10 +21,10 @@ from torch_geometric.typing import OptPairTensor from torch_geometric.typing import Size -from anemoi.models.distributed.helpers import shard_heads -from anemoi.models.distributed.helpers import shard_sequence -from anemoi.models.distributed.helpers import shard_tensor -from anemoi.models.distributed.helpers import sync_tensor +from anemoi.models.distributed.graph import shard_tensor +from anemoi.models.distributed.graph import sync_tensor +from anemoi.models.distributed.transformer import shard_heads +from anemoi.models.distributed.transformer import shard_sequence from anemoi.models.layers.attention import MultiHeadSelfAttention from anemoi.models.layers.conv import GraphConv from anemoi.models.layers.conv import GraphTransformerConv @@ -46,11 +45,11 @@ def forward( x: OptPairTensor, edge_attr: torch.Tensor, edge_index: Adj, - shapes: Tuple, + shapes: tuple, batch_size: int, size: Optional[Size] = None, model_comm_group: Optional[ProcessGroup] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: ... + ) -> tuple[torch.Tensor, torch.Tensor]: ... 
class TransformerProcessorBlock(BaseBlock): @@ -148,10 +147,10 @@ def forward( x: OptPairTensor, edge_attr: Tensor, edge_index: Adj, - shapes: Tuple, + shapes: tuple, model_comm_group: Optional[ProcessGroup] = None, size: Optional[Size] = None, - ) -> Tuple[Tensor, Tensor]: ... + ) -> tuple[Tensor, Tensor]: ... class GraphConvProcessorBlock(GraphConvBaseBlock): @@ -182,10 +181,10 @@ def forward( x: OptPairTensor, edge_attr: Tensor, edge_index: Adj, - shapes: Tuple, + shapes: tuple, model_comm_group: Optional[ProcessGroup] = None, size: Optional[Size] = None, - ) -> Tuple[Tensor, Tensor]: + ) -> tuple[Tensor, Tensor]: x_in = sync_tensor(x, 0, shapes[1], model_comm_group) @@ -197,9 +196,8 @@ def forward( out1, edges_out1 = self.conv(x_in, edge_attr_list[i], edge_index_list[i], size=size) edges_out.append(edges_out1) if i == 0: - out = out1 - else: - out = out + out1 + out = torch.zeros_like(out1) + out = out + out1 edges_new = torch.cat(edges_out, dim=0) else: out, edges_new = self.conv(x_in, edge_attr, edge_index, size=size) @@ -239,10 +237,10 @@ def forward( x: OptPairTensor, edge_attr: Tensor, edge_index: Adj, - shapes: Tuple, + shapes: tuple, model_comm_group: Optional[ProcessGroup] = None, size: Optional[Size] = None, - ) -> Tuple[Tensor, Tensor]: + ) -> tuple[Tensor, Tensor]: x_src = sync_tensor(x[0], 0, shapes[0], model_comm_group) x_dst = sync_tensor(x[1], 0, shapes[1], model_comm_group) @@ -256,9 +254,8 @@ def forward( out1, edges_out1 = self.conv(x_in, edge_attr_list[i], edge_index_list[i], size=size) edges_out.append(edges_out1) if i == 0: - out = out1 - else: - out = out + out1 + out = torch.zeros_like(out1) + out = out + out1 edges_new = torch.cat(edges_out, dim=0) else: out, edges_new = self.conv(x_in, edge_attr, edge_index, size=size) @@ -267,10 +264,8 @@ def forward( nodes_new_dst = self.node_mlp(torch.cat([x[1], out], dim=1)) + x[1] - if self.update_src_nodes: # update only needed in forward mapper - nodes_new_src = self.node_mlp(torch.cat([x[0], 
x[0]], dim=1)) + x[0] - else: - nodes_new_src = x[0] + # update only needed in forward mapper + nodes_new_src = x[0] if not self.update_src_nodes else self.node_mlp(torch.cat([x[0], x[0]], dim=1)) + x[0] nodes_new = (nodes_new_src, nodes_new_dst) @@ -360,41 +355,42 @@ def shard_qkve_heads( key: Tensor, value: Tensor, edges: Tensor, - shapes: Tuple, + shapes: tuple, batch_size: int, model_comm_group: Optional[ProcessGroup] = None, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + ) -> tuple[Tensor, Tensor, Tensor, Tensor]: """Shards qkv and edges along head dimension.""" - shape_src_nodes, shape_dst_nodes, shape_edges = shapes - query, key, value, edges = map( - lambda t: einops.rearrange( + query, key, value, edges = ( + einops.rearrange( t, "(batch grid) (heads vars) -> batch heads grid vars", heads=self.num_heads, vars=self.out_channels_conv, batch=batch_size, - ), - (query, key, value, edges), + ) + for t in (query, key, value, edges) ) query = shard_heads(query, shapes=shape_dst_nodes, mgroup=model_comm_group) key = shard_heads(key, shapes=shape_src_nodes, mgroup=model_comm_group) value = shard_heads(value, shapes=shape_src_nodes, mgroup=model_comm_group) edges = shard_heads(edges, shapes=shape_edges, mgroup=model_comm_group) - query, key, value, edges = map( - lambda t: einops.rearrange(t, "batch heads grid vars -> (batch grid) heads vars"), - (query, key, value, edges), + query, key, value, edges = ( + einops.rearrange(t, "batch heads grid vars -> (batch grid) heads vars") for t in (query, key, value, edges) ) return query, key, value, edges def shard_output_seq( - self, out: Tensor, shapes: Tuple, batch_size: int, model_comm_group: Optional[ProcessGroup] = None + self, + out: Tensor, + shapes: tuple, + batch_size: int, + model_comm_group: Optional[ProcessGroup] = None, ) -> Tensor: """Shards Tensor sequence dimension.""" - shape_dst_nodes = shapes[1] out = einops.rearrange(out, "(batch grid) heads vars -> batch heads grid vars", batch=batch_size) @@ -409,7 
+405,7 @@ def forward( x: OptPairTensor, edge_attr: Tensor, edge_index: Adj, - shapes: Tuple, + shapes: tuple, batch_size: int, model_comm_group: Optional[ProcessGroup] = None, size: Optional[Size] = None, @@ -471,7 +467,7 @@ def forward( x: OptPairTensor, edge_attr: Tensor, edge_index: Adj, - shapes: Tuple, + shapes: tuple, batch_size: int, model_comm_group: Optional[ProcessGroup] = None, size: Optional[Size] = None, @@ -495,11 +491,8 @@ def forward( query, key, value, edges = self.shard_qkve_heads(query, key, value, edges, shapes, batch_size, model_comm_group) - # TODO: Is this alright? - if self.training: - num_chunks = self.num_chunks - else: - num_chunks = 4 # reduce memory for inference + # TODO: remove magic number + num_chunks = self.num_chunks if self.training else 4 # reduce memory for inference if num_chunks > 1: edge_index_list = torch.tensor_split(edge_index, num_chunks, dim=1) @@ -514,9 +507,8 @@ def forward( size=size, ) if i == 0: - out = out1 - else: - out = out + out1 + out = torch.zeros_like(out1) + out = out + out1 else: out = self.conv(query=query, key=key, value=value, edge_attr=edges, edge_index=edge_index, size=size) @@ -526,10 +518,7 @@ def forward( out = out + x_skip[1] nodes_new_dst = self.node_dst_mlp(out) + out - if self.update_src_nodes: - nodes_new_src = self.node_src_mlp(x_skip[0]) + x_skip[0] - else: - nodes_new_src = x_skip[0] + nodes_new_src = self.node_src_mlp(x_skip[0]) + x_skip[0] if self.update_src_nodes else x_skip[0] nodes_new = (nodes_new_src, nodes_new_dst) @@ -590,7 +579,7 @@ def forward( x: OptPairTensor, edge_attr: Tensor, edge_index: Adj, - shapes: Tuple, + shapes: tuple, batch_size: int, model_comm_group: Optional[ProcessGroup] = None, size: Optional[Size] = None, @@ -613,10 +602,7 @@ def forward( query, key, value, edges = self.shard_qkve_heads(query, key, value, edges, shapes, batch_size, model_comm_group) # TODO: Is this alright? 
- if self.training: - num_chunks = self.num_chunks - else: - num_chunks = 4 # reduce memory for inference + num_chunks = self.num_chunks if self.training else 4 # reduce memory for inference if num_chunks > 1: edge_index_list = torch.tensor_split(edge_index, num_chunks, dim=1) @@ -631,9 +617,8 @@ def forward( size=size, ) if i == 0: - out = out1 - else: - out = out + out1 + out = torch.zeros_like(out1) + out = out + out1 else: out = self.conv(query=query, key=key, value=value, edge_attr=edges, edge_index=edge_index, size=size) diff --git a/src/anemoi/models/layers/chunk.py b/src/anemoi/models/layers/chunk.py index 1e7202ec..61dec34b 100644 --- a/src/anemoi/models/layers/chunk.py +++ b/src/anemoi/models/layers/chunk.py @@ -7,10 +7,10 @@ # nor does it submit to any jurisdiction. # +import logging from abc import ABC from abc import abstractmethod from typing import Optional -from typing import Tuple from torch import Tensor from torch import nn @@ -24,6 +24,8 @@ from anemoi.models.layers.block import TransformerProcessorBlock from anemoi.models.layers.mlp import MLP +LOGGER = logging.getLogger(__name__) + class BaseProcessorChunk(nn.Module, ABC): """Base Processor Chunk.""" @@ -42,17 +44,16 @@ def __init__( self.num_channels = num_channels self.num_layers = num_layers - def build_blocks(self, Block: nn.Module, *args, **kwargs) -> None: + def build_blocks(self, block: nn.Module, *args, **kwargs) -> None: """Build Layers.""" - self.blocks = nn.ModuleList( [ - Block( + block( *args, **kwargs, ) for _ in range(self.num_layers) - ] + ], ) @abstractmethod @@ -161,7 +162,7 @@ def forward( x: OptPairTensor, edge_attr: Tensor, edge_index: Adj, - shapes: Tuple, + shapes: tuple, model_comm_group: Optional[ProcessGroup] = None, size: Optional[Size] = None, ) -> OptPairTensor: @@ -185,7 +186,7 @@ def __init__( num_heads: int = 16, mlp_hidden_ratio: int = 4, activation: str = "GELU", - edge_dim: int = None, + edge_dim: Optional[int] = None, ) -> None: """Initialize 
GraphTransformerProcessorChunk. @@ -221,7 +222,7 @@ def forward( x: OptPairTensor, edge_attr: Tensor, edge_index: Adj, - shapes: Tuple, + shapes: tuple, batch_size: int, model_comm_group: Optional[ProcessGroup] = None, size: Optional[Size] = None, diff --git a/src/anemoi/models/layers/conv.py b/src/anemoi/models/layers/conv.py index f10f6edb..e5e42298 100644 --- a/src/anemoi/models/layers/conv.py +++ b/src/anemoi/models/layers/conv.py @@ -9,11 +9,10 @@ import math from typing import Optional -from typing import Tuple import torch -import torch.nn.functional as F from torch import Tensor +from torch.nn.functional import dropout from torch_geometric.nn.conv import MessagePassing from torch_geometric.typing import Adj from torch_geometric.typing import OptPairTensor @@ -60,10 +59,7 @@ def __init__( ) def forward(self, x: OptPairTensor, edge_attr: Tensor, edge_index: Adj, size: Optional[Size] = None): - if isinstance(x, Tensor): - dim_size = x.shape[0] - else: - dim_size = x[1].shape[0] + dim_size = x.shape[0] if isinstance(x, Tensor) else x[1].shape[0] out, edges_new = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size, dim_size=dim_size) @@ -74,7 +70,7 @@ def message(self, x_i: Tensor, x_j: Tensor, edge_attr: Tensor, dim_size: Optiona return edges_new - def aggregate(self, edges_new: Tensor, edge_index: Adj, dim_size: Optional[int] = None) -> Tuple[Tensor, Tensor]: + def aggregate(self, edges_new: Tensor, edge_index: Adj, dim_size: Optional[int] = None) -> tuple[Tensor, Tensor]: out = scatter(edges_new, edge_index[1], dim=0, dim_size=dim_size, reduce="sum") return out, edges_new @@ -137,6 +133,6 @@ def message( alpha = (query_i * key_j).sum(dim=-1) / math.sqrt(self.out_channels) alpha = softmax(alpha, index, ptr, size_i) - alpha = F.dropout(alpha, p=self.dropout, training=self.training) + alpha = dropout(alpha, p=self.dropout, training=self.training) return (value_j + edge_attr) * alpha.view(-1, heads, 1) diff --git a/src/anemoi/models/layers/graph.py 
b/src/anemoi/models/layers/graph.py index dfd2124d..71703d9f 100644 --- a/src/anemoi/models/layers/graph.py +++ b/src/anemoi/models/layers/graph.py @@ -25,7 +25,7 @@ def __init__(self, tensor_size: int, trainable_size: int) -> None: torch.empty( tensor_size, trainable_size, - ) + ), ) nn.init.constant_(trainable, 0) else: diff --git a/src/anemoi/models/layers/mapper.py b/src/anemoi/models/layers/mapper.py index cf9c1fa3..9f5f90bf 100644 --- a/src/anemoi/models/layers/mapper.py +++ b/src/anemoi/models/layers/mapper.py @@ -9,7 +9,6 @@ from abc import ABC from typing import Optional -from typing import Tuple import numpy as np import torch @@ -21,11 +20,11 @@ from torch_geometric.typing import Adj from torch_geometric.typing import PairTensor -from anemoi.models.distributed.helpers import change_channels_in_shape -from anemoi.models.distributed.helpers import gather_tensor -from anemoi.models.distributed.helpers import get_shape_shards -from anemoi.models.distributed.helpers import shard_tensor -from anemoi.models.distributed.helpers import sort_edges_1hop +from anemoi.models.distributed.graph import gather_tensor +from anemoi.models.distributed.graph import shard_tensor +from anemoi.models.distributed.khop_edges import sort_edges_1hop +from anemoi.models.distributed.shapes import change_channels_in_shape +from anemoi.models.distributed.shapes import get_shape_shards from anemoi.models.layers.block import GraphConvMapperBlock from anemoi.models.layers.block import GraphTransformerMapperBlock from anemoi.models.layers.graph import TrainableTensor @@ -62,7 +61,7 @@ def offload_layers(self, cpu_offload): if cpu_offload: self.proc = nn.ModuleList([offload_wrapper(x) for x in self.proc]) - def pre_process(self, x, shard_shapes, model_comm_group=None) -> Tuple[Tensor, Tensor, Tuple[int], Tuple[int]]: + def pre_process(self, x, shard_shapes, model_comm_group=None) -> tuple[Tensor, Tensor, tuple[int], tuple[int]]: """Pre-processing for the Mappers. 
Splits the tuples into src and dst nodes and shapes as the base operation. @@ -136,8 +135,7 @@ def _register_edges(self, sub_graph: HeteroData, src_size: int, dst_size: int, t ) def _expand_edges(self, edge_index: Adj, edge_inc: Tensor, batch_size: int) -> Adj: - """Expand edge index correct number of times while adding the proper number to - the edge index. + """Expand edge index while incrementing to the edge index. Parameters ---------- @@ -234,7 +232,7 @@ def forward( self, x: PairTensor, batch_size: int, - shard_shapes: Tuple[Tuple[int], Tuple[int]], + shard_shapes: tuple[tuple[int], tuple[int]], model_comm_group: Optional[ProcessGroup] = None, ) -> PairTensor: size = (sum(x[0] for x in shard_shapes[0]), sum(x[0] for x in shard_shapes[1])) @@ -324,7 +322,7 @@ def forward( self, x: PairTensor, batch_size: int, - shard_shapes: Tuple[Tuple[int], Tuple[int]], + shard_shapes: tuple[tuple[int], tuple[int]], model_comm_group: Optional[ProcessGroup] = None, ) -> PairTensor: x_dst = super().forward(x, batch_size, shard_shapes, model_comm_group) @@ -481,7 +479,7 @@ def forward( self, x: PairTensor, batch_size: int, - shard_shapes: Tuple[Tuple[int], Tuple[int]], + shard_shapes: tuple[tuple[int], tuple[int]], model_comm_group: Optional[ProcessGroup] = None, ) -> PairTensor: @@ -492,7 +490,12 @@ def forward( x_src, x_dst, shapes_src, shapes_dst = self.pre_process(x, shard_shapes, model_comm_group) (x_src, x_dst), edge_attr = self.proc( - (x_src, x_dst), edge_attr, edge_index, (shapes_src, shapes_dst), model_comm_group, size=size + (x_src, x_dst), + edge_attr, + edge_index, + (shapes_src, shapes_dst), + model_comm_group, + size=size, ) x_dst = self.post_process(x_dst, shapes_dst, model_comm_group) @@ -671,7 +674,7 @@ def forward( self, x: PairTensor, batch_size: int, - shard_shapes: Tuple[Tuple[int], Tuple[int]], + shard_shapes: tuple[tuple[int], tuple[int]], model_comm_group: Optional[ProcessGroup] = None, ) -> Tensor: diff --git a/src/anemoi/models/layers/mlp.py 
b/src/anemoi/models/layers/mlp.py index d3440bbf..4de53001 100644 --- a/src/anemoi/models/layers/mlp.py +++ b/src/anemoi/models/layers/mlp.py @@ -63,7 +63,6 @@ def __init__( RuntimeError If activation function is not supported """ - super().__init__() try: act_func = getattr(nn, activation) diff --git a/src/anemoi/models/layers/processor.py b/src/anemoi/models/layers/processor.py index 35910e87..39a6f24a 100644 --- a/src/anemoi/models/layers/processor.py +++ b/src/anemoi/models/layers/processor.py @@ -9,7 +9,6 @@ from abc import ABC from typing import Optional -from typing import Tuple from torch import Tensor from torch import nn @@ -18,10 +17,10 @@ from torch.utils.checkpoint import checkpoint from torch_geometric.data import HeteroData -from anemoi.models.distributed.helpers import change_channels_in_shape -from anemoi.models.distributed.helpers import get_shape_shards -from anemoi.models.distributed.helpers import shard_tensor -from anemoi.models.distributed.helpers import sort_edges_1hop +from anemoi.models.distributed.graph import shard_tensor +from anemoi.models.distributed.khop_edges import sort_edges_1hop +from anemoi.models.distributed.shapes import change_channels_in_shape +from anemoi.models.distributed.shapes import get_shape_shards from anemoi.models.layers.chunk import GNNProcessorChunk from anemoi.models.layers.chunk import GraphTransformerProcessorChunk from anemoi.models.layers.chunk import TransformerProcessorChunk @@ -45,7 +44,8 @@ def __init__( """Initialize BaseProcessor.""" super().__init__() - self.num_layers = num_layers + # Each Processor divides the layers into chunks that get assigned to each ProcessorChunk + self.num_chunks = num_chunks self.num_channels = num_channels self.chunk_size = num_layers // num_chunks @@ -57,20 +57,19 @@ def offload_layers(self, cpu_offload): if cpu_offload: self.proc = nn.ModuleList([offload_wrapper(x) for x in self.proc]) - def build_layers(self, ProcessorChunk, *args, **kwargs) -> None: + def 
build_layers(self, processor_chunk_class, *args, **kwargs) -> None: """Build Layers.""" - self.proc = nn.ModuleList( [ - ProcessorChunk( + processor_chunk_class( *args, **kwargs, ) - for _ in range(self.num_layers) - ] + for _ in range(self.num_chunks) + ], ) - def run_layers(self, data: Tuple, *args, **kwargs) -> Tensor: + def run_layers(self, data: tuple, *args, **kwargs) -> Tensor: """Run Layers with checkpoint.""" for layer in self.proc: data = checkpoint(layer, *data, *args, **kwargs, use_reentrant=False) @@ -142,7 +141,7 @@ def forward( self, x: Tensor, batch_size: int, - shard_shapes: Tuple[Tuple[int], ...], + shard_shapes: tuple[tuple[int], ...], model_comm_group: Optional[ProcessGroup] = None, *args, **kwargs, @@ -206,12 +205,12 @@ def __init__( self.trainable = TrainableTensor(trainable_size=trainable_size, tensor_size=self.edge_attr.shape[0]) - kwargs = dict( - num_layers=self.chunk_size, - mlp_extra_layers=mlp_extra_layers, - activation=activation, - edge_dim=None, - ) + kwargs = { + "num_layers": self.chunk_size, + "mlp_extra_layers": mlp_extra_layers, + "activation": activation, + "edge_dim": None, + } self.build_layers(GNNProcessorChunk, num_channels, **kwargs) @@ -224,7 +223,7 @@ def forward( self, x: Tensor, batch_size: int, - shard_shapes: Tuple[Tuple[int], Tuple[int]], + shard_shapes: tuple[tuple[int], tuple[int]], model_comm_group: Optional[ProcessGroup] = None, ) -> Tensor: shape_nodes = change_channels_in_shape(shard_shapes, self.num_channels) @@ -232,7 +231,10 @@ def forward( edge_index = self._expand_edges(self.edge_index_base, self.edge_inc, batch_size) target_nodes = sum(x[0] for x in shape_nodes) edge_attr, edge_index, shapes_edge_attr, shapes_edge_idx = sort_edges_1hop( - target_nodes, edge_attr, edge_index, model_comm_group + target_nodes, + edge_attr, + edge_index, + model_comm_group, ) edge_index = shard_tensor(edge_index, 1, shapes_edge_idx, model_comm_group) edge_attr = shard_tensor(edge_attr, 0, shapes_edge_attr, model_comm_group) 
@@ -309,7 +311,7 @@ def forward( self, x: Tensor, batch_size: int, - shard_shapes: Tuple[Tuple[int], Tuple[int]], + shard_shapes: tuple[tuple[int], tuple[int]], model_comm_group: Optional[ProcessGroup] = None, *args, **kwargs, diff --git a/src/anemoi/models/layers/utils.py b/src/anemoi/models/layers/utils.py index f4caf7ae..90cdcc9a 100644 --- a/src/anemoi/models/layers/utils.py +++ b/src/anemoi/models/layers/utils.py @@ -26,7 +26,7 @@ def forward(self, *args, **kwargs): class AutocastLayerNorm(nn.LayerNorm): """LayerNorm that casts the output back to the input type.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) def forward(self, x: Tensor) -> Tensor: diff --git a/src/anemoi/models/models/encoder_processor_decoder.py b/src/anemoi/models/models/encoder_processor_decoder.py new file mode 100644 index 00000000..3bab80e6 --- /dev/null +++ b/src/anemoi/models/models/encoder_processor_decoder.py @@ -0,0 +1,264 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+# + +import logging +from typing import Optional + +import einops +import torch +from anemoi.utils.config import DotConfig +from hydra.utils import instantiate +from torch import Tensor +from torch import nn +from torch.distributed.distributed_c10d import ProcessGroup +from torch.utils.checkpoint import checkpoint +from torch_geometric.data import HeteroData + +from anemoi.models.distributed.shapes import get_shape_shards +from anemoi.models.layers.graph import TrainableTensor + +LOGGER = logging.getLogger(__name__) + + +class AnemoiModelEncProcDec(nn.Module): + """Message passing graph neural network.""" + + def __init__( + self, + *, + config: DotConfig, + data_indices: dict, + graph_data: HeteroData, + ) -> None: + """Initializes the graph neural network. + + Parameters + ---------- + config : DictConfig + Job configuration + graph_data : HeteroData + Graph definition + """ + super().__init__() + + self._graph_data = graph_data + self._graph_name_data = config.graph.data + self._graph_name_hidden = config.graph.hidden + + self._calculate_shapes_and_indices(data_indices) + self._assert_matching_indices(data_indices) + + self.multi_step = config.training.multistep_input + + self._define_tensor_sizes(config) + + # Create trainable tensors + self._create_trainable_attributes() + + # Register lat/lon + self._register_latlon("data", self._graph_name_data) + self._register_latlon("hidden", self._graph_name_hidden) + + self.num_channels = config.model.num_channels + + input_dim = self.multi_step * self.num_input_channels + self.latlons_data.shape[1] + self.trainable_data_size + + # Encoder data -> hidden + self.encoder = instantiate( + config.model.encoder, + in_channels_src=input_dim, + in_channels_dst=self.latlons_hidden.shape[1] + self.trainable_hidden_size, + hidden_dim=self.num_channels, + sub_graph=self._graph_data[(self._graph_name_data, "to", self._graph_name_hidden)], + src_grid_size=self._data_grid_size, + dst_grid_size=self._hidden_grid_size, + ) + + # 
Processor hidden -> hidden + self.processor = instantiate( + config.model.processor, + num_channels=self.num_channels, + sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_hidden)], + src_grid_size=self._hidden_grid_size, + dst_grid_size=self._hidden_grid_size, + ) + + # Decoder hidden -> data + self.decoder = instantiate( + config.model.decoder, + in_channels_src=self.num_channels, + in_channels_dst=input_dim, + hidden_dim=self.num_channels, + out_channels_dst=self.num_output_channels, + sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_data)], + src_grid_size=self._hidden_grid_size, + dst_grid_size=self._data_grid_size, + ) + + def _calculate_shapes_and_indices(self, data_indices: dict) -> None: + self.num_input_channels = len(data_indices.model.input) + self.num_output_channels = len(data_indices.model.output) + self._internal_input_idx = data_indices.model.input.prognostic + self._internal_output_idx = data_indices.model.output.prognostic + + def _assert_matching_indices(self, data_indices: dict) -> None: + + assert len(self._internal_output_idx) == len(data_indices.model.output.full) - len( + data_indices.model.output.diagnostic + ), ( + f"Mismatch between the internal data indices ({len(self._internal_output_idx)}) and the output indices excluding " + f"diagnostic variables ({len(data_indices.model.output.full) - len(data_indices.model.output.diagnostic)})", + ) + assert len(self._internal_input_idx) == len( + self._internal_output_idx, + ), f"Model indices must match {self._internal_input_idx} != {self._internal_output_idx}" + + def _define_tensor_sizes(self, config: DotConfig) -> None: + # Define Sizes of different tensors + self._data_grid_size = self._graph_data[(self._graph_name_data, "to", self._graph_name_data)].ecoords_rad.shape[ + 0 + ] + self._hidden_grid_size = self._graph_data[ + (self._graph_name_hidden, "to", self._graph_name_hidden) + ].hcoords_rad.shape[0] + + self.trainable_data_size = 
config.model.trainable_parameters.data + self.trainable_hidden_size = config.model.trainable_parameters.hidden + + def _register_latlon(self, name: str, key: str) -> None: + """Register lat/lon buffers. + + Parameters + ---------- + name : str + Name of grid to map + key : str + Key of the grid + """ + self.register_buffer( + f"latlons_{name}", + torch.cat( + [ + torch.sin(self._graph_data[(key, "to", key)][f"{key[:1]}coords_rad"]), + torch.cos(self._graph_data[(key, "to", key)][f"{key[:1]}coords_rad"]), + ], + dim=-1, + ), + persistent=True, + ) + + def _create_trainable_attributes(self) -> None: + """Create all trainable attributes.""" + self.trainable_data = TrainableTensor(trainable_size=self.trainable_data_size, tensor_size=self._data_grid_size) + self.trainable_hidden = TrainableTensor( + trainable_size=self.trainable_hidden_size, tensor_size=self._hidden_grid_size + ) + + def _run_mapper( + self, + mapper: nn.Module, + data: tuple[Tensor], + batch_size: int, + shard_shapes: tuple[tuple[int, int], tuple[int, int]], + model_comm_group: Optional[ProcessGroup] = None, + use_reentrant: bool = False, + ) -> Tensor: + """Run mapper with activation checkpoint. 
+ + Parameters + ---------- + mapper : nn.Module + Which processor to use + data : tuple[Tensor] + tuple of data to pass in + batch_size: int, + Batch size + shard_shapes : tuple[tuple[int, int], tuple[int, int]] + Shard shapes for the data + model_comm_group : ProcessGroup + model communication group, specifies which GPUs work together + in one model instance + use_reentrant : bool, optional + Use reentrant, by default False + + Returns + ------- + Tensor + Mapped data + """ + return checkpoint( + mapper, + data, + batch_size=batch_size, + shard_shapes=shard_shapes, + model_comm_group=model_comm_group, + use_reentrant=use_reentrant, + ) + + def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) -> Tensor: + batch_size = x.shape[0] + ensemble_size = x.shape[2] + + # add data positional info (lat/lon) + x_data_latent = torch.cat( + ( + einops.rearrange(x, "batch time ensemble grid vars -> (batch ensemble grid) (time vars)"), + self.trainable_data(self.latlons_data, batch_size=batch_size), + ), + dim=-1, # feature dimension + ) + + x_hidden_latent = self.trainable_hidden(self.latlons_hidden, batch_size=batch_size) + + # get shard shapes + shard_shapes_data = get_shape_shards(x_data_latent, 0, model_comm_group) + shard_shapes_hidden = get_shape_shards(x_hidden_latent, 0, model_comm_group) + + # Run encoder + x_data_latent, x_latent = self._run_mapper( + self.encoder, + (x_data_latent, x_hidden_latent), + batch_size=batch_size, + shard_shapes=(shard_shapes_data, shard_shapes_hidden), + model_comm_group=model_comm_group, + ) + + x_latent_proc = self.processor( + x_latent, + batch_size=batch_size, + shard_shapes=shard_shapes_hidden, + model_comm_group=model_comm_group, + ) + + # add skip connection (hidden -> hidden) + x_latent_proc = x_latent_proc + x_latent + + # Run decoder + x_out = self._run_mapper( + self.decoder, + (x_latent_proc, x_data_latent), + batch_size=batch_size, + shard_shapes=(shard_shapes_hidden, shard_shapes_data), + 
model_comm_group=model_comm_group, + ) + + x_out = ( + einops.rearrange( + x_out, + "(batch ensemble grid) vars -> batch ensemble grid vars", + batch=batch_size, + ensemble=ensemble_size, + ) + .to(dtype=x.dtype) + .clone() + ) + + # residual connection (just for the prognostic variables) + x_out[..., self._internal_output_idx] += x[:, -1, :, :, self._internal_input_idx] + return x_out diff --git a/src/anemoi/models/preprocessing/__init__.py b/src/anemoi/models/preprocessing/__init__.py new file mode 100644 index 00000000..c23c15e8 --- /dev/null +++ b/src/anemoi/models/preprocessing/__init__.py @@ -0,0 +1,167 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import logging +from typing import Optional + +import torch +from torch import Tensor +from torch import nn + +LOGGER = logging.getLogger(__name__) + + +class BasePreprocessor(nn.Module): + """Base class for data pre- and post-processors.""" + + def __init__( + self, + config=None, + statistics: Optional[dict] = None, + data_indices: Optional[dict] = None, + ) -> None: + """Initialize the preprocessor. 
+ + Parameters + ---------- + config : Dotconfig + configuration object + statistics : dict + Data statistics dictionary + data_indices : dict + Data indices for input and output variables + """ + super().__init__() + + self.default, self.method_config = self._process_config(config) + self.methods = self._invert_key_value_list(self.method_config) + + self.data_indices = data_indices + + def _process_config(self, config): + default = config.get("default", "none") + method_config = {k: v for k, v in config.items() if k != "default" and v is not None and v != "none"} + + if not method_config: + LOGGER.warning( + f"{self.__class__.__name__}: Using default method {default} for all variables not specified in the config.", + ) + + return default, method_config + + def _invert_key_value_list(self, method_config: dict[str, list[str]]) -> dict[str, str]: + """Invert a dictionary of methods with lists of variables. + + Parameters + ---------- + method_config : dict[str, list[str]] + dictionary of the methods with lists of variables + + Returns + ------- + dict[str, str] + dictionary of the variables with methods + """ + return { + variable: method + for method, variables in method_config.items() + if not isinstance(variables, str) + for variable in variables + } + + def forward(self, x, in_place: bool = True, inverse: bool = False) -> Tensor: + """Process the input tensor. 
+ + Parameters + ---------- + x : torch.Tensor + Input tensor + in_place : bool + Whether to process the tensor in place + inverse : bool + Whether to inverse transform the input + + Returns + ------- + torch.Tensor + Processed tensor + """ + if inverse: + return self.inverse_transform(x, in_place=in_place) + return self.transform(x, in_place=in_place) + + def transform(self, x, in_place: bool = True) -> Tensor: + """Process the input tensor.""" + if not in_place: + x = x.clone() + return x + + def inverse_transform(self, x, in_place: bool = True) -> Tensor: + """Inverse process the input tensor.""" + if not in_place: + x = x.clone() + return x + + +class Processors(nn.Module): + """A collection of processors.""" + + def __init__(self, processors: list, inverse: bool = False) -> None: + """Initialize the processors. + + Parameters + ---------- + processors : list + List of processors + """ + super().__init__() + + self.inverse = inverse + self.first_run = True + + if inverse: + # Reverse the order of processors for inverse transformation + # e.g. first impute then normalise forward but denormalise then de-impute for inverse + processors = processors[::-1] + + self.processors = nn.ModuleDict(processors) + + def __repr__(self) -> str: + return f"{self.__class__.__name__} [{'inverse' if self.inverse else 'forward'}]({self.processors})" + + def forward(self, x, in_place: bool = True) -> Tensor: + """Process the input tensor. 
+ + Parameters + ---------- + x : torch.Tensor + Input tensor + in_place : bool + Whether to process the tensor in place + + Returns + ------- + torch.Tensor + Processed tensor + """ + for processor in self.processors.values(): + x = processor(x, in_place=in_place, inverse=self.inverse) + + if self.first_run: + self.first_run = False + self._run_checks(x) + return x + + def _run_checks(self, x): + """Run checks on the processed tensor.""" + if not self.inverse: + # Forward transformation checks: + assert not torch.isnan( + x + ).any(), f"NaNs ({torch.isnan(x).sum()}) found in processed tensor after {self.__class__.__name__}." diff --git a/src/anemoi/models/preprocessing/imputer.py b/src/anemoi/models/preprocessing/imputer.py new file mode 100644 index 00000000..a2e5bd63 --- /dev/null +++ b/src/anemoi/models/preprocessing/imputer.py @@ -0,0 +1,210 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import logging +from abc import ABC +from typing import Optional + +import torch + +from anemoi.models.data_indices.collection import IndexCollection +from anemoi.models.preprocessing import BasePreprocessor + +LOGGER = logging.getLogger(__name__) + + +class BaseImputer(BasePreprocessor, ABC): + """Base class for Imputers.""" + + def __init__( + self, + config=None, + data_indices: Optional[IndexCollection] = None, + statistics: Optional[dict] = None, + ) -> None: + """Initialize the imputer. 
+ + Parameters + ---------- + config : Dotconfig + configuration object + statistics : dict + Data statistics dictionary + data_indices : dict + Data indices for input and output variables + """ + super().__init__(config, statistics, data_indices) + + self.nan_locations = None + self.data_indices = data_indices + + def _validate_indices(self): + assert len(self.index_training_input) == len(self.index_inference_input) <= len(self.replacement), ( + f"Error creating imputation indices {len(self.index_training_input)}, " + f"{len(self.index_inference_input)}, {len(self.replacement)}" + ) + assert len(self.index_training_output) == len(self.index_inference_output) <= len(self.replacement), ( + f"Error creating imputation indices {len(self.index_training_output)}, " + f"{len(self.index_inference_output)}, {len(self.replacement)}" + ) + + def _create_imputation_indices( + self, + statistics=None, + ): + """Create the indices for imputation.""" + name_to_index_training_input = self.data_indices.data.input.name_to_index + name_to_index_inference_input = self.data_indices.model.input.name_to_index + name_to_index_training_output = self.data_indices.data.output.name_to_index + name_to_index_inference_output = self.data_indices.model.output.name_to_index + + self.num_training_input_vars = len(name_to_index_training_input) + self.num_inference_input_vars = len(name_to_index_inference_input) + self.num_training_output_vars = len(name_to_index_training_output) + self.num_inference_output_vars = len(name_to_index_inference_output) + + ( + self.index_training_input, + self.index_inference_input, + self.index_training_output, + self.index_inference_output, + self.replacement, + ) = ([], [], [], [], []) + + # Create indices for imputation + for name in name_to_index_training_input: + + method = self.methods.get(name, self.default) + if method == "none": + LOGGER.debug(f"Imputer: skipping {name} as no imputation method is specified") + continue + + 
self.index_training_input.append(name_to_index_training_input[name]) + self.index_training_output.append(name_to_index_training_output.get(name, None)) + self.index_inference_input.append(name_to_index_inference_input.get(name, None)) + self.index_inference_output.append(name_to_index_inference_output.get(name, None)) + + if statistics is None: + self.replacement.append(method) + elif isinstance(statistics, dict): + assert method in statistics, f"{method} is not a method in the statistics metadata" + self.replacement.append(statistics[method][name_to_index_training_input[name]]) + else: + raise TypeError(f"Statistics {type(statistics)} is optional and not a dictionary") + + LOGGER.debug(f"Imputer: replacing NaNs in {name} with value {self.replacement[-1]}") + + def _expand_subset_mask(self, x: torch.Tensor, idx_src: int) -> torch.Tensor: + """Expand the subset of the mask to the correct shape.""" + return self.nan_locations[:, idx_src].expand(*x.shape[:-2], -1) + + def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor: + """Impute missing values in the input tensor.""" + if not in_place: + x = x.clone() + + # Initialize mask once + if self.nan_locations is None: + # The mask is only saved for the last two dimensions (grid, variable) + idx = [slice(0, 1)] * (x.ndim - 2) + [slice(None), slice(None)] + self.nan_locations = torch.isnan(x[idx].squeeze()) + + # Choose correct index based on number of variables + if x.shape[-1] == self.num_training_input_vars: + index = self.index_training_input + elif x.shape[-1] == self.num_inference_input_vars: + index = self.index_inference_input + else: + raise ValueError( + f"Input tensor ({x.shape[-1]}) does not match the training " + f"({self.num_training_input_vars}) or inference shape ({self.num_inference_input_vars})", + ) + + # Replace values + for idx_src, (idx_dst, value) in zip(self.index_training_input, zip(index, self.replacement)): + if idx_dst is not None: + x[..., 
idx_dst][self._expand_subset_mask(x, idx_src)] = value + return x + + def inverse_transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor: + """Set the originally missing values back to NaN in the output tensor.""" + if not in_place: + x = x.clone() + + # Replace original nans with nan again + if x.shape[-1] == self.num_training_output_vars: + index = self.index_training_output + elif x.shape[-1] == self.num_inference_output_vars: + index = self.index_inference_output + else: + raise ValueError( + f"Input tensor ({x.shape[-1]}) does not match the training " + f"({self.num_training_output_vars}) or inference shape ({self.num_inference_output_vars})", + ) + + # Replace values + for idx_src, idx_dst in zip(self.index_training_input, index): + if idx_dst is not None: + x[..., idx_dst][self._expand_subset_mask(x, idx_src)] = torch.nan + return x + + +class InputImputer(BaseImputer): + """Imputes missing values using the statistics supplied. + + Expects the config to have keys corresponding to available statistics + and values as lists of variables to impute: + ``` + default: "none" + mean: + - y + maximum: + - x + minimum: + - q + ``` + """ + + def __init__( + self, + config=None, + statistics: Optional[dict] = None, + data_indices: Optional[dict] = None, + ) -> None: + super().__init__(config, data_indices, statistics) + + self._create_imputation_indices(statistics) + + self._validate_indices() + + +class ConstantImputer(BaseImputer): + """Imputes missing values using the constant value. 
+ + Expects the config to have keys corresponding to available statistics + and values as lists of variables to impute.: + ``` + default: "none" + 1: + - y + 5.0: + - x + 3.14: + - q + ``` + """ + + def __init__( + self, config=None, statistics: Optional[dict] = None, data_indices: Optional[IndexCollection] = None + ) -> None: + super().__init__(config, data_indices, statistics) + + self._create_imputation_indices() + + self._validate_indices() diff --git a/src/anemoi/models/preprocessing/normalizer.py b/src/anemoi/models/preprocessing/normalizer.py new file mode 100644 index 00000000..8a7dd614 --- /dev/null +++ b/src/anemoi/models/preprocessing/normalizer.py @@ -0,0 +1,182 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import logging +import warnings +from typing import Optional + +import numpy as np +import torch + +from anemoi.models.data_indices.collection import IndexCollection +from anemoi.models.preprocessing import BasePreprocessor + +LOGGER = logging.getLogger(__name__) + + +class InputNormalizer(BasePreprocessor): + """Normalizes input data with a configurable method.""" + + def __init__( + self, + config=None, + data_indices: Optional[IndexCollection] = None, + statistics: Optional[dict] = None, + ) -> None: + """Initialize the normalizer. 
+ + Parameters + ---------- + config : Dotconfig + configuration object + statistics : dict + Data statistics dictionary + data_indices : dict + Data indices for input and output variables + """ + super().__init__(config, statistics, data_indices) + + name_to_index_training_input = self.data_indices.data.input.name_to_index + + minimum = statistics["minimum"] + maximum = statistics["maximum"] + mean = statistics["mean"] + stdev = statistics["stdev"] + + self._validate_normalization_inputs(name_to_index_training_input, minimum, maximum, mean, stdev) + + _norm_add = np.zeros((minimum.size,), dtype=np.float32) + _norm_mul = np.ones((minimum.size,), dtype=np.float32) + + for name, i in name_to_index_training_input.items(): + method = self.methods.get(name, self.default) + if method == "mean-std": + LOGGER.debug(f"Normalizing: {name} is mean-std-normalised.") + if stdev[i] < (mean[i] * 1e-6): + warnings.warn(f"Normalizing: the field seems to have only one value {mean[i]}") + _norm_mul[i] = 1 / stdev[i] + _norm_add[i] = -mean[i] / stdev[i] + + elif method == "min-max": + LOGGER.debug(f"Normalizing: {name} is min-max-normalised to [0, 1].") + x = maximum[i] - minimum[i] + if x < 1e-9: + warnings.warn(f"Normalizing: the field {name} seems to have only one value {maximum[i]}.") + _norm_mul[i] = 1 / x + _norm_add[i] = -minimum[i] / x + + elif method == "max": + LOGGER.debug(f"Normalizing: {name} is max-normalised to [0, 1].") + _norm_mul[i] = 1 / maximum[i] + + elif method == "none": + LOGGER.info(f"Normalizing: {name} is not normalized.") + + else: + raise ValueError(f"Unknown normalisation method for {name}: {method}") + + # register buffer - this will ensure they get copied to the correct device(s) + self.register_buffer("_norm_mul", torch.from_numpy(_norm_mul), persistent=True) + self.register_buffer("_norm_add", torch.from_numpy(_norm_add), persistent=True) + self.register_buffer("_input_idx", data_indices.data.input.full, persistent=True) + 
self.register_buffer("_output_idx", self.data_indices.data.output.full, persistent=True) + + def _validate_normalization_inputs(self, name_to_index_training_input: dict, minimum, maximum, mean, stdev): + assert len(self.methods) == sum(len(v) for v in self.method_config.values()), ( + f"Error parsing methods in InputNormalizer methods ({len(self.methods)}) " + f"and entries in config ({sum(len(v) for v in self.method_config.values())}) do not match." + ) + n = minimum.size + assert maximum.size == n, (maximum.size, n) + assert mean.size == n, (mean.size, n) + assert stdev.size == n, (stdev.size, n) + + assert isinstance(self.methods, dict) + for name, method in self.methods.items(): + assert name in name_to_index_training_input, f"{name} is not a valid variable name" + assert method in [ + "mean-std", + # "robust", + "min-max", + "max", + "none", + ], f"{method} is not a valid normalisation method" + + def transform( + self, x: torch.Tensor, in_place: bool = True, data_index: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Normalizes an input tensor x of shape [..., nvars]. + + Normalization done in-place unless specified otherwise. + + The default use case assumes either the full batch tensor or the full input tensor. + A data index based on the full data can be supplied to choose which variables to normalise. 
+ + Parameters + ---------- + x : torch.Tensor + Data to normalize + in_place : bool, optional + Normalize in-place, by default True + data_index : Optional[torch.Tensor], optional + Normalize only the specified indices, by default None + + Returns + ------- + torch.Tensor + Normalized data + """ + if not in_place: + x = x.clone() + + if data_index is not None: + x[..., :] = x[..., :] * self._norm_mul[data_index] + self._norm_add[data_index] + elif x.shape[-1] == len(self._input_idx): + x[..., :] = x[..., :] * self._norm_mul[self._input_idx] + self._norm_add[self._input_idx] + else: + x[..., :] = x[..., :] * self._norm_mul + self._norm_add + return x + + def inverse_transform( + self, x: torch.Tensor, in_place: bool = True, data_index: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Denormalizes an input tensor x of shape [..., nvars | nvars_pred]. + + Denormalization done in-place unless specified otherwise. + + The default use case assumes either the full batch tensor or the full output tensor. + A data index based on the full data can be supplied to choose which variables to denormalise. 
+ + Parameters + ---------- + x : torch.Tensor + Data to denormalize + in_place : bool, optional + Denormalize in-place, by default True + data_index : Optional[torch.Tensor], optional + Denormalize only the specified indices, by default None + + Returns + ------- + torch.Tensor + Denormalized data + """ + if not in_place: + x = x.clone() + + # Denormalize dynamic or full tensors + # input and predicted tensors have different shapes + # hence, we mask out the forcing indices + if data_index is not None: + x[..., :] = (x[..., :] - self._norm_add[data_index]) / self._norm_mul[data_index] + elif x.shape[-1] == len(self._output_idx): + x[..., :] = (x[..., :] - self._norm_add[self._output_idx]) / self._norm_mul[self._output_idx] + else: + x[..., :] = (x[..., :] - self._norm_add) / self._norm_mul + return x diff --git a/tests/data_indices/test_collection.py b/tests/data_indices/test_collection.py new file mode 100644 index 00000000..5558c91d --- /dev/null +++ b/tests/data_indices/test_collection.py @@ -0,0 +1,92 @@ +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +import pytest +import torch +from omegaconf import DictConfig + +from anemoi.models.data_indices.collection import IndexCollection + + +@pytest.fixture() +def data_indices(): + config = DictConfig( + { + "data": { + "forcing": ["x"], + "diagnostic": ["z", "q"], + }, + }, + ) + name_to_index = {"x": 0, "y": 1, "z": 2, "q": 3, "other": 4} + return IndexCollection(config=config, name_to_index=name_to_index) + + +def test_dataindices_init(data_indices) -> None: + assert data_indices.data.input.includes == ["x"] + assert data_indices.data.input.excludes == ["z", "q"] + assert data_indices.data.output.includes == ["z", "q"] + assert data_indices.data.output.excludes == ["x"] + assert data_indices.model.input.includes == ["x"] + assert data_indices.model.input.excludes == [] + assert data_indices.model.output.includes == ["z", "q"] + assert data_indices.model.output.excludes == [] + assert data_indices.data.input.name_to_index == {"x": 0, "y": 1, "z": 2, "q": 3, "other": 4} + assert data_indices.data.output.name_to_index == {"x": 0, "y": 1, "z": 2, "q": 3, "other": 4} + assert data_indices.model.input.name_to_index == {"x": 0, "y": 1, "other": 2} + assert data_indices.model.output.name_to_index == {"y": 0, "z": 1, "q": 2, "other": 3} + + +def test_dataindices_max(data_indices) -> None: + assert max(data_indices.data.input.full) == max(data_indices.data.input.name_to_index.values()) + assert max(data_indices.data.output.full) == max(data_indices.data.output.name_to_index.values()) + assert max(data_indices.model.input.full) == max(data_indices.model.input.name_to_index.values()) + assert max(data_indices.model.output.full) == max(data_indices.model.output.name_to_index.values()) + + +def test_dataindices_todict(data_indices) -> None: + expected_output = { + "input": { + "full": torch.Tensor([0, 1, 4]).to(torch.int), + "forcing": torch.Tensor([0]).to(torch.int), + "diagnostic": torch.Tensor([2, 3]).to(torch.int), + "prognostic": torch.Tensor([1, 4]).to(torch.int), + }, + 
"output": { + "full": torch.Tensor([1, 2, 3, 4]).to(torch.int), + "forcing": torch.Tensor([0]).to(torch.int), + "diagnostic": torch.Tensor([2, 3]).to(torch.int), + "prognostic": torch.Tensor([1, 4]).to(torch.int), + }, + } + + for key in ["output", "input"]: + for subkey, value in data_indices.data.todict()[key].items(): + assert subkey in expected_output[key] + assert torch.allclose(value, expected_output[key][subkey]) + + +def test_modelindices_todict(data_indices) -> None: + expected_output = { + "input": { + "full": torch.Tensor([0, 1, 2]).to(torch.int), + "forcing": torch.Tensor([0]).to(torch.int), + "diagnostic": torch.Tensor([]).to(torch.int), + "prognostic": torch.Tensor([1, 2]).to(torch.int), + }, + "output": { + "full": torch.Tensor([0, 1, 2, 3]).to(torch.int), + "forcing": torch.Tensor([]).to(torch.int), + "diagnostic": torch.Tensor([1, 2]).to(torch.int), + "prognostic": torch.Tensor([0, 3]).to(torch.int), + }, + } + + for key in ["output", "input"]: + for subkey, value in data_indices.model.todict()[key].items(): + assert subkey in expected_output[key] + assert torch.allclose(value, expected_output[key][subkey]) diff --git a/tests/data_indices/test_data_indices.py b/tests/data_indices/test_data_indices.py new file mode 100644 index 00000000..e447832f --- /dev/null +++ b/tests/data_indices/test_data_indices.py @@ -0,0 +1,136 @@ +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +import pytest +import torch + +from anemoi.models.data_indices.index import DataIndex +from anemoi.models.data_indices.tensor import InputTensorIndex +from anemoi.models.data_indices.tensor import OutputTensorIndex + + +@pytest.fixture() +def fake_data(): + name_to_index = {"x": 0, "y": 1, "z": 2, "q": 3, "other": 4} + forcing = ["x", "y"] + diagnostic = ["z"] + return forcing, diagnostic, name_to_index + + +@pytest.fixture() +def input_tensor_index(fake_data): + forcing, diagnostic, name_to_index = fake_data + return InputTensorIndex(includes=forcing, excludes=diagnostic, name_to_index=name_to_index) + + +@pytest.fixture() +def output_tensor_index(fake_data): + forcing, diagnostic, name_to_index = fake_data + return OutputTensorIndex(includes=diagnostic, excludes=forcing, name_to_index=name_to_index) + + +def test_dataindex_init(fake_data, input_tensor_index, output_tensor_index) -> None: + forcing, diagnostic, name_to_index = fake_data + data_index = DataIndex(forcing=forcing, diagnostic=diagnostic, name_to_index=name_to_index) + assert data_index.input == input_tensor_index + assert data_index.output == output_tensor_index + + +def test_output_tensor_index_full(output_tensor_index) -> None: + expected_output = torch.Tensor([2, 3, 4]).to(torch.int) + assert torch.allclose(output_tensor_index.full, expected_output) + + +def test_output_tensor_index_only(output_tensor_index) -> None: + expected_output = torch.Tensor([2]).to(torch.int) + assert torch.allclose(output_tensor_index._only, expected_output) + + +def test_output_tensor_index_prognostic(output_tensor_index) -> None: + expected_output = torch.Tensor([3, 4]).to(torch.int) + assert torch.allclose(output_tensor_index.prognostic, expected_output) + + +def test_output_tensor_index_todict(output_tensor_index) -> None: + expected_output = { + "full": torch.Tensor([2, 3, 4]).to(torch.int), + "diagnostic": torch.Tensor([2]).to(torch.int), + "forcing": torch.Tensor([0, 1]).to(torch.int), + "prognostic": 
torch.Tensor([3, 4]).to(torch.int), + } + for key, value in output_tensor_index.todict().items(): + assert key in expected_output + assert torch.allclose(value, expected_output[key]) + + +def test_output_tensor_index_getattr(output_tensor_index) -> None: + assert output_tensor_index.full is not None + with pytest.raises(AttributeError): + output_tensor_index.z + + +def test_output_tensor_index_build_idx_from_excludes(output_tensor_index) -> None: + expected_output = torch.Tensor([2, 3, 4]).to(torch.int) + assert torch.allclose(output_tensor_index._build_idx_from_excludes(), expected_output) + + +def test_output_tensor_index_build_idx_from_includes(output_tensor_index) -> None: + expected_output = torch.Tensor([2]).to(torch.int) + assert torch.allclose(output_tensor_index._build_idx_from_includes(), expected_output) + + +def test_output_tensor_index_build_idx_prognostic(output_tensor_index) -> None: + expected_output = torch.Tensor([3, 4]).to(torch.int) + assert torch.allclose(output_tensor_index._build_idx_prognostic(), expected_output) + + +def test_input_tensor_index_full(input_tensor_index) -> None: + expected_output = torch.Tensor([0, 1, 3, 4]).to(torch.int) + assert torch.allclose(input_tensor_index.full, expected_output) + + +def test_input_tensor_index_only(input_tensor_index) -> None: + expected_output = torch.Tensor([0, 1]).to(torch.int) + assert torch.allclose(input_tensor_index._only, expected_output) + + +def test_input_tensor_index_prognostic(input_tensor_index) -> None: + expected_output = torch.Tensor([3, 4]).to(torch.int) + assert torch.allclose(input_tensor_index.prognostic, expected_output) + + +def test_input_tensor_index_todict(input_tensor_index) -> None: + expected_output = { + "full": torch.Tensor([0, 1, 3, 4]).to(torch.int), + "diagnostic": torch.Tensor([2]).to(torch.int), + "forcing": torch.Tensor([0, 1]).to(torch.int), + "prognostic": torch.Tensor([3, 4]).to(torch.int), + } + for key, value in input_tensor_index.todict().items(): + assert 
key in expected_output + assert torch.allclose(value, expected_output[key]) + + +def test_input_tensor_index_getattr(input_tensor_index) -> None: + assert input_tensor_index.full is not None + with pytest.raises(AttributeError): + input_tensor_index.z + + +def test_input_tensor_index_build_idx_from_excludes(input_tensor_index) -> None: + expected_output = torch.Tensor([0, 1, 3, 4]).to(torch.int) + assert torch.allclose(input_tensor_index._build_idx_from_excludes(), expected_output) + + +def test_input_tensor_index_build_idx_from_includes(input_tensor_index) -> None: + expected_output = torch.Tensor([0, 1]).to(torch.int) + assert torch.allclose(input_tensor_index._build_idx_from_includes(), expected_output) + + +def test_input_tensor_index_build_idx_prognostic(input_tensor_index) -> None: + expected_output = torch.Tensor([3, 4]).to(torch.int) + assert torch.allclose(input_tensor_index._build_idx_prognostic(), expected_output) diff --git a/tests/layers/block/test_block_graphconv.py b/tests/layers/block/test_block_graphconv.py index 38653e3d..34083e6e 100644 --- a/tests/layers/block/test_block_graphconv.py +++ b/tests/layers/block/test_block_graphconv.py @@ -5,14 +5,15 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
-from aifs.layers.block import MLP -from aifs.layers.block import GraphConvMapperBlock -from aifs.layers.block import GraphConvProcessorBlock -from aifs.layers.conv import GraphConv from hypothesis import given from hypothesis import settings from hypothesis import strategies as st +from anemoi.models.layers.block import MLP +from anemoi.models.layers.block import GraphConvMapperBlock +from anemoi.models.layers.block import GraphConvProcessorBlock +from anemoi.models.layers.conv import GraphConv + class TestGraphConvProcessorBlock: @given( diff --git a/tests/layers/block/test_block_graphtransformer.py b/tests/layers/block/test_block_graphtransformer.py index 93b20e55..00af258f 100644 --- a/tests/layers/block/test_block_graphtransformer.py +++ b/tests/layers/block/test_block_graphtransformer.py @@ -8,9 +8,10 @@ import pytest import torch import torch.nn as nn -from aifs.layers.block import GraphTransformerMapperBlock -from aifs.layers.block import GraphTransformerProcessorBlock -from aifs.layers.conv import GraphTransformerConv + +from anemoi.models.layers.block import GraphTransformerMapperBlock +from anemoi.models.layers.block import GraphTransformerProcessorBlock +from anemoi.models.layers.conv import GraphTransformerConv @pytest.fixture diff --git a/tests/layers/block/test_block_transformer.py b/tests/layers/block/test_block_transformer.py index 79bbf24e..97c274f6 100644 --- a/tests/layers/block/test_block_transformer.py +++ b/tests/layers/block/test_block_transformer.py @@ -5,19 +5,21 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
+import logging + import torch -from aifs.layers.attention import MultiHeadSelfAttention -from aifs.layers.block import MLP -from aifs.layers.block import GraphConvProcessorBlock -from aifs.layers.block import TransformerProcessorBlock -from aifs.layers.conv import GraphConv -from aifs.utils.logger import get_code_logger from hypothesis import given from hypothesis import settings from hypothesis import strategies as st from torch import nn -LOGGER = get_code_logger(__name__) +from anemoi.models.layers.attention import MultiHeadSelfAttention +from anemoi.models.layers.block import MLP +from anemoi.models.layers.block import GraphConvProcessorBlock +from anemoi.models.layers.block import TransformerProcessorBlock +from anemoi.models.layers.conv import GraphConv + +LOGGER = logging.getLogger(__name__) class TestTransformerProcessorBlock: diff --git a/tests/layers/chunk/test_chunk_gnn.py b/tests/layers/chunk/test_chunk_gnn.py index f12a1091..64cc04b7 100644 --- a/tests/layers/chunk/test_chunk_gnn.py +++ b/tests/layers/chunk/test_chunk_gnn.py @@ -6,9 +6,10 @@ # nor does it submit to any jurisdiction. import pytest -from aifs.layers.block import GraphConvProcessorBlock -from aifs.layers.chunk import GNNProcessorChunk -from aifs.layers.mlp import MLP + +from anemoi.models.layers.block import GraphConvProcessorBlock +from anemoi.models.layers.chunk import GNNProcessorChunk +from anemoi.models.layers.mlp import MLP class TestGNNProcessorChunk: diff --git a/tests/layers/chunk/test_chunk_graphtransformer.py b/tests/layers/chunk/test_chunk_graphtransformer.py index b51127c8..1249e237 100644 --- a/tests/layers/chunk/test_chunk_graphtransformer.py +++ b/tests/layers/chunk/test_chunk_graphtransformer.py @@ -6,8 +6,9 @@ # nor does it submit to any jurisdiction. 
import pytest -from aifs.layers.block import GraphTransformerProcessorBlock -from aifs.layers.chunk import GraphTransformerProcessorChunk + +from anemoi.models.layers.block import GraphTransformerProcessorBlock +from anemoi.models.layers.chunk import GraphTransformerProcessorChunk class TestGraphTransformerProcessorChunk: diff --git a/tests/layers/chunk/test_chunk_transformer.py b/tests/layers/chunk/test_chunk_transformer.py index fd5d8fc9..1fe7c6d7 100644 --- a/tests/layers/chunk/test_chunk_transformer.py +++ b/tests/layers/chunk/test_chunk_transformer.py @@ -6,8 +6,9 @@ # nor does it submit to any jurisdiction. import pytest -from aifs.layers.block import TransformerProcessorBlock -from aifs.layers.chunk import TransformerProcessorChunk + +from anemoi.models.layers.block import TransformerProcessorBlock +from anemoi.models.layers.chunk import TransformerProcessorChunk class TestGraphTransformerProcessorChunk: diff --git a/tests/layers/mapper/test_base_mapper.py b/tests/layers/mapper/test_base_mapper.py index efb20a13..5b82b658 100644 --- a/tests/layers/mapper/test_base_mapper.py +++ b/tests/layers/mapper/test_base_mapper.py @@ -7,9 +7,10 @@ import pytest import torch -from aifs.layers.mapper import BaseMapper from torch_geometric.data import HeteroData +from anemoi.models.layers.mapper import BaseMapper + class TestBaseMapper: @pytest.fixture diff --git a/tests/layers/mapper/test_graphconv_mapper.py b/tests/layers/mapper/test_graphconv_mapper.py index ff5f664d..480a4948 100644 --- a/tests/layers/mapper/test_graphconv_mapper.py +++ b/tests/layers/mapper/test_graphconv_mapper.py @@ -7,12 +7,13 @@ import pytest import torch -from aifs.layers.mapper import GNNBackwardMapper -from aifs.layers.mapper import GNNBaseMapper -from aifs.layers.mapper import GNNForwardMapper from torch import nn from torch_geometric.data import HeteroData +from anemoi.models.layers.mapper import GNNBackwardMapper +from anemoi.models.layers.mapper import GNNBaseMapper +from 
anemoi.models.layers.mapper import GNNForwardMapper + class TestGNNBaseMapper: BIG_GRID_SIZE = 1000 diff --git a/tests/layers/mapper/test_graphtransformer_mapper.py b/tests/layers/mapper/test_graphtransformer_mapper.py index bd553a2c..7fd0fc0c 100644 --- a/tests/layers/mapper/test_graphtransformer_mapper.py +++ b/tests/layers/mapper/test_graphtransformer_mapper.py @@ -7,12 +7,13 @@ import pytest import torch -from aifs.layers.mapper import GraphTransformerBackwardMapper -from aifs.layers.mapper import GraphTransformerBaseMapper -from aifs.layers.mapper import GraphTransformerForwardMapper from torch import nn from torch_geometric.data import HeteroData +from anemoi.models.layers.mapper import GraphTransformerBackwardMapper +from anemoi.models.layers.mapper import GraphTransformerBaseMapper +from anemoi.models.layers.mapper import GraphTransformerForwardMapper + class TestGraphTransformerBaseMapper: BIG_GRID_SIZE = 1000 diff --git a/tests/layers/processor/test_base_processor.py b/tests/layers/processor/test_base_processor.py index 918fb4c7..4af3c7bc 100644 --- a/tests/layers/processor/test_base_processor.py +++ b/tests/layers/processor/test_base_processor.py @@ -6,7 +6,8 @@ # nor does it submit to any jurisdiction. 
import pytest -from aifs.layers.processor import BaseProcessor + +from anemoi.models.layers.processor import BaseProcessor @pytest.fixture @@ -19,7 +20,7 @@ def processor_init(): return num_layers, num_channels, num_chunks, activation, cpu_offload -@pytest.fixture +@pytest.fixture() def base_processor(processor_init): num_layers, num_channels, num_chunks, activation, cpu_offload = processor_init return BaseProcessor( @@ -32,14 +33,14 @@ def base_processor(processor_init): def test_base_processor_init(processor_init, base_processor): - num_layers, num_channels, num_chunks, _activation, _cpu_offload = processor_init + num_layers, num_channels, num_chunks, *_ = processor_init - assert isinstance(base_processor.num_layers, int), "num_layers should be an integer" + assert isinstance(base_processor.num_chunks, int), "num_chunks should be an integer" assert isinstance(base_processor.num_channels, int), "num_channels should be an integer" assert ( - base_processor.num_layers == num_layers - ), f"num_layers ({base_processor.num_layers}) should be equal to the input num_layers ({num_layers})" + base_processor.num_chunks == num_chunks + ), f"num_chunks ({base_processor.num_chunks}) should be equal to the input num_chunks ({num_chunks})" assert ( base_processor.num_channels == num_channels ), f"num_channels ({base_processor.num_channels}) should be equal to the input num_channels ({num_channels})" diff --git a/tests/layers/processor/test_graphconv_processor.py b/tests/layers/processor/test_graphconv_processor.py index 801b5f6f..569319be 100644 --- a/tests/layers/processor/test_graphconv_processor.py +++ b/tests/layers/processor/test_graphconv_processor.py @@ -7,10 +7,11 @@ import pytest import torch -from aifs.layers.graph import TrainableTensor -from aifs.layers.processor import GNNProcessor from torch_geometric.data import HeteroData +from anemoi.models.layers.graph import TrainableTensor +from anemoi.models.layers.processor import GNNProcessor + @pytest.fixture def 
fake_graph(): @@ -87,7 +88,7 @@ def test_graphconv_processor_init(graphconv_processor, graphconv_init): _dst_grid_size, _trainable_size, ) = graphconv_init - assert graphconv_processor.num_layers == num_layers + assert graphconv_processor.num_chunks == num_chunks assert graphconv_processor.num_channels == num_channels assert graphconv_processor.chunk_size == num_layers // num_chunks assert isinstance(graphconv_processor.trainable, TrainableTensor) diff --git a/tests/layers/processor/test_graphtransformer_processor.py b/tests/layers/processor/test_graphtransformer_processor.py index 278ccbee..81095a2e 100644 --- a/tests/layers/processor/test_graphtransformer_processor.py +++ b/tests/layers/processor/test_graphtransformer_processor.py @@ -7,10 +7,11 @@ import pytest import torch -from aifs.layers.graph import TrainableTensor -from aifs.layers.processor import GraphTransformerProcessor from torch_geometric.data import HeteroData +from anemoi.models.layers.graph import TrainableTensor +from anemoi.models.layers.processor import GraphTransformerProcessor + @pytest.fixture def fake_graph(): @@ -92,7 +93,7 @@ def test_graphtransformer_processor_init(graphtransformer_processor, graphtransf _dst_grid_size, _trainable_size, ) = graphtransformer_init - assert graphtransformer_processor.num_layers == num_layers + assert graphtransformer_processor.num_chunks == num_chunks assert graphtransformer_processor.num_channels == num_channels assert graphtransformer_processor.chunk_size == num_layers // num_chunks assert isinstance(graphtransformer_processor.trainable, TrainableTensor) diff --git a/tests/layers/processor/test_transformer_processor.py b/tests/layers/processor/test_transformer_processor.py index b6f364bd..d359c270 100644 --- a/tests/layers/processor/test_transformer_processor.py +++ b/tests/layers/processor/test_transformer_processor.py @@ -7,7 +7,8 @@ import pytest import torch -from aifs.layers.processor import TransformerProcessor + +from anemoi.models.layers.processor 
import TransformerProcessor @pytest.fixture @@ -68,7 +69,7 @@ def test_transformer_processor_init(transformer_processor, transformer_processor _mlp_hidden_ratio, ) = transformer_processor_init assert isinstance(transformer_processor, TransformerProcessor) - assert transformer_processor.num_layers == num_layers + assert transformer_processor.num_chunks == num_chunks assert transformer_processor.num_channels == num_channels assert transformer_processor.chunk_size == num_layers // num_chunks diff --git a/tests/layers/test_attention.py b/tests/layers/test_attention.py index a3a0e320..ffeaebcf 100644 --- a/tests/layers/test_attention.py +++ b/tests/layers/test_attention.py @@ -9,10 +9,11 @@ import pytest import torch import torch.nn as nn -from aifs.layers.attention import MultiHeadSelfAttention from hypothesis import given from hypothesis import settings +from anemoi.models.layers.attention import MultiHeadSelfAttention + @given( num_heads=st.integers(min_value=1, max_value=50), diff --git a/tests/layers/test_graph.py b/tests/layers/test_graph.py index 159faa29..a17ebfd4 100644 --- a/tests/layers/test_graph.py +++ b/tests/layers/test_graph.py @@ -7,9 +7,10 @@ import pytest import torch -from aifs.layers.graph import TrainableTensor from torch import nn +from anemoi.models.layers.graph import TrainableTensor + class TestTrainableTensor: @pytest.fixture diff --git a/tests/layers/test_mlp.py b/tests/layers/test_mlp.py index eaec8ad7..e47a0b98 100644 --- a/tests/layers/test_mlp.py +++ b/tests/layers/test_mlp.py @@ -7,7 +7,8 @@ import pytest import torch -from aifs.layers.mlp import MLP + +from anemoi.models.layers.mlp import MLP @pytest.fixture diff --git a/tests/test_models.py b/tests/models/test_models.py similarity index 100% rename from tests/test_models.py rename to tests/models/test_models.py diff --git a/tests/preprocessing/test_preprocessor_imputer.py b/tests/preprocessing/test_preprocessor_imputer.py new file mode 100644 index 00000000..ea04b9a4 --- /dev/null +++ 
b/tests/preprocessing/test_preprocessor_imputer.py @@ -0,0 +1,313 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import numpy as np +import pytest +import torch +from omegaconf import DictConfig + +from anemoi.models.data_indices.collection import IndexCollection +from anemoi.models.preprocessing.imputer import ConstantImputer +from anemoi.models.preprocessing.imputer import InputImputer + + +@pytest.fixture() +def non_default_input_imputer(): + config = DictConfig( + { + "diagnostics": {"log": {"code": {"level": "DEBUG"}}}, + "data": { + "imputer": {"default": "none", "mean": ["y"], "maximum": ["x"], "none": ["z"], "minimum": ["q"]}, + "forcing": ["z", "q"], + "diagnostic": ["other"], + }, + }, + ) + statistics = { + "mean": np.array([1.0, 2.0, 3.0, 4.5, 3.0]), + "stdev": np.array([0.5, 0.5, 0.5, 1, 14]), + "minimum": np.array([1.0, 1.0, 1.0, 1.0, 1.0]), + "maximum": np.array([11.0, 10.0, 10.0, 10.0, 10.0]), + } + name_to_index = {"x": 0, "y": 1, "z": 2, "q": 3, "other": 4} + data_indices = IndexCollection(config=config, name_to_index=name_to_index) + return InputImputer(config=config.data.imputer, statistics=statistics, data_indices=data_indices) + + +@pytest.fixture() +def default_input_imputer(): + config = DictConfig( + { + "diagnostics": {"log": {"code": {"level": "DEBUG"}}}, + "data": { + "imputer": {"default": "minimum"}, + "forcing": ["z", "q"], + "diagnostic": ["other"], + }, + }, + ) + statistics = { + "mean": np.array([1.0, 2.0, 3.0, 4.5, 3.0]), + "stdev": np.array([0.5, 0.5, 0.5, 1, 14]), + "minimum": np.array([1.0, 1.0, 1.0, 1.0, 1.0]), + "maximum": np.array([11.0, 10.0, 10.0, 10.0, 10.0]), + } + name_to_index = 
{"x": 0, "y": 1, "z": 2, "q": 3, "other": 4} + data_indices = IndexCollection(config=config, name_to_index=name_to_index) + return InputImputer(config=config.data.imputer, statistics=statistics, data_indices=data_indices) + + +@pytest.fixture() +def non_default_input_data(): + base = torch.Tensor([[1.0, 2.0, 3.0, np.nan, 5.0], [6.0, np.nan, 8.0, 9.0, 10.0]]) + expected = torch.Tensor([[1.0, 2.0, 3.0, 1.0, 5.0], [6.0, 2.0, 8.0, 9.0, 10.0]]) + return base, expected + + +@pytest.fixture() +def default_input_data(): + base = torch.Tensor([[1.0, 2.0, 3.0, np.nan, 5.0], [6.0, np.nan, 8.0, 9.0, 0]]) + expected = torch.Tensor([[1.0, 2.0, 3.0, 1.0, 5.0], [6.0, 1.0, 8.0, 9.0, 0]]) + return base, expected + + +@pytest.fixture() +def default_constant_imputer(): + config = DictConfig( + { + "diagnostics": {"log": {"code": {"level": "DEBUG"}}}, + "data": { + "imputer": {"default": "none", 0: ["x"], 3.0: ["y"], 22.7: ["z"], 10: ["q"]}, + "forcing": ["z", "q"], + "diagnostic": ["other"], + }, + }, + ) + name_to_index = {"x": 0, "y": 1, "z": 2, "q": 3, "other": 4} + data_indices = IndexCollection(config=config, name_to_index=name_to_index) + return ConstantImputer(config=config.data.imputer, statistics=None, data_indices=data_indices) + + +@pytest.fixture() +def non_default_constant_imputer(): + config = DictConfig( + { + "diagnostics": {"log": {"code": {"level": "DEBUG"}}}, + "data": { + "imputer": {"default": 22.7}, + "forcing": ["z", "q"], + "diagnostic": ["other"], + }, + }, + ) + name_to_index = {"x": 0, "y": 1, "z": 2, "q": 3, "other": 4} + data_indices = IndexCollection(config=config, name_to_index=name_to_index) + return ConstantImputer(config=config.data.imputer, statistics=None, data_indices=data_indices) + + +@pytest.fixture() +def non_default_constant_data(): + base = torch.Tensor([[1.0, 2.0, 3.0, np.nan, 5.0], [6.0, np.nan, 8.0, 9.0, 0]]) + expected = torch.Tensor([[1.0, 2.0, 3.0, 22.7, 5.0], [6.0, 22.7, 8.0, 9.0, 0]]) + return base, expected + + +@pytest.fixture() 
+def default_constant_data(): + base = torch.Tensor([[1.0, 2.0, 3.0, np.nan, 5.0], [6.0, np.nan, 8.0, 9.0, 0]]) + expected = torch.Tensor([[1.0, 2.0, 3.0, 10, 5.0], [6.0, 3.0, 8.0, 9.0, 0]]) + return base, expected + + +@pytest.mark.parametrize( + ("imputer_fixture", "data_fixture"), + [ + ("default_constant_imputer", "default_constant_data"), + ("non_default_constant_imputer", "non_default_constant_data"), + ("default_input_imputer", "default_input_data"), + ("non_default_input_imputer", "non_default_input_data"), + ], +) +def test_imputer_not_inplace(imputer_fixture, data_fixture, request) -> None: + """Check that the imputer does not modify the input tensor when in_place=False.""" + x, _ = request.getfixturevalue(data_fixture) + imputer = request.getfixturevalue(imputer_fixture) + x_old = x.clone() + imputer(x, in_place=False) + assert torch.allclose(x, x_old, equal_nan=True), "Imputer does not handle in_place=False correctly." + + +@pytest.mark.parametrize( + ("imputer_fixture", "data_fixture"), + [ + ("default_constant_imputer", "default_constant_data"), + ("non_default_constant_imputer", "non_default_constant_data"), + ("default_input_imputer", "default_input_data"), + ("non_default_input_imputer", "non_default_input_data"), + ], +) +def test_imputer_inplace(imputer_fixture, data_fixture, request) -> None: + """Check that the imputer modifies the input tensor when in_place=True.""" + x, _ = request.getfixturevalue(data_fixture) + imputer = request.getfixturevalue(imputer_fixture) + x_old = x.clone() + out = imputer(x, in_place=True) + assert not torch.allclose(x, x_old, equal_nan=True) + assert torch.allclose(x, out, equal_nan=True) + + +@pytest.mark.parametrize( + ("imputer_fixture", "data_fixture"), + [ + ("default_constant_imputer", "default_constant_data"), + ("non_default_constant_imputer", "non_default_constant_data"), + ("default_input_imputer", "default_input_data"), + ("non_default_input_imputer", "non_default_input_data"), + ], +) +def 
test_transform_with_nan(imputer_fixture, data_fixture, request): + """Check that the imputer correctly transforms a tensor with NaNs.""" + x, expected = request.getfixturevalue(data_fixture) + imputer = request.getfixturevalue(imputer_fixture) + transformed = imputer.transform(x) + assert torch.allclose(transformed, expected, equal_nan=True), "Transform does not handle NaNs correctly." + + +@pytest.mark.parametrize( + ("imputer_fixture", "data_fixture"), + [ + ("default_constant_imputer", "default_constant_data"), + ("non_default_constant_imputer", "non_default_constant_data"), + ("default_input_imputer", "default_input_data"), + ("non_default_input_imputer", "non_default_input_data"), + ], +) +def test_transform_with_nan_small(imputer_fixture, data_fixture, request): + """Check that the imputer correctly transforms a tensor with NaNs and a reduced-size slice of it.""" + x, expected = request.getfixturevalue(data_fixture) + imputer = request.getfixturevalue(imputer_fixture) + transformed = imputer.transform(x, in_place=False) + assert torch.allclose(transformed, expected, equal_nan=True), "Transform does not handle NaNs correctly." + x_small = x[..., [0, 1, 2, 3]] + expected_small = expected[..., [0, 1, 2, 3]] + transformed_small = imputer.transform(x_small, in_place=False) + assert torch.allclose( + transformed_small, + expected_small, + equal_nan=True, + ), "Transform (in inference) does not handle NaNs correctly." 
+ + +@pytest.mark.parametrize( + ("imputer_fixture", "data_fixture"), + [ + ("default_constant_imputer", "default_constant_data"), + ("non_default_constant_imputer", "non_default_constant_data"), + ("default_input_imputer", "default_input_data"), + ("non_default_input_imputer", "non_default_input_data"), + ], +) +def test_transform_with_nan_inference(imputer_fixture, data_fixture, request): + """Check that the imputer correctly transforms and inverts inference-size tensors with NaNs.""" + x, expected = request.getfixturevalue(data_fixture) + imputer = request.getfixturevalue(imputer_fixture) + transformed = imputer.transform(x, in_place=False) + assert torch.allclose(transformed, expected, equal_nan=True), "Transform does not handle NaNs correctly." + # Split data to "inference size": input drops diagnostics, output drops forcings + x_small_in = x[..., [0, 1, 2, 3]] + x_small_out = x[..., [0, 1, 4]] + expected_small_in = expected[..., [0, 1, 2, 3]] + expected_small_out = expected[..., [0, 1, 4]] + transformed_small = imputer.transform(x_small_in, in_place=False) + assert torch.allclose( + transformed_small, + expected_small_in, + equal_nan=True, + ), "Transform (in inference) does not handle NaNs correctly." + # Check that the inverse also performs correctly + restored = imputer.inverse_transform(expected_small_out, in_place=False) + assert torch.allclose( + restored, x_small_out, equal_nan=True + ), "Inverse transform does not restore NaNs correctly in inference."
+ + +@pytest.mark.parametrize( + ("imputer_fixture", "data_fixture"), + [ + ("default_constant_imputer", "default_constant_data"), + ("non_default_constant_imputer", "non_default_constant_data"), + ("default_input_imputer", "default_input_data"), + ("non_default_input_imputer", "non_default_input_data"), + ], +) +def test_transform_noop(imputer_fixture, data_fixture, request): + """Check that the imputer does not modify a tensor without NaNs.""" + x, expected = request.getfixturevalue(data_fixture) + imputer = request.getfixturevalue(imputer_fixture) + _ = imputer.transform(x) + transformed = imputer.transform(expected.clone()) + assert torch.allclose(transformed, expected), "Transform modifies a tensor without NaNs." + + +@pytest.mark.parametrize( + ("imputer_fixture", "data_fixture"), + [ + ("default_constant_imputer", "default_constant_data"), + ("non_default_constant_imputer", "non_default_constant_data"), + ("default_input_imputer", "default_input_data"), + ("non_default_input_imputer", "non_default_input_data"), + ], +) +def test_inverse_transform(imputer_fixture, data_fixture, request): + """Check that the imputer correctly inverts the transformation.""" + x, expected = request.getfixturevalue(data_fixture) + imputer = request.getfixturevalue(imputer_fixture) + transformed = imputer.transform(x, in_place=False) + assert torch.allclose(transformed, expected, equal_nan=True), "Transform does not handle NaNs correctly." + restored = imputer.inverse_transform(transformed, in_place=False) + assert torch.allclose(restored, x, equal_nan=True), "Inverse transform does not restore NaNs correctly."
+ + +@pytest.mark.parametrize( + ("imputer_fixture", "data_fixture"), + [ + ("default_constant_imputer", "default_constant_data"), + ("non_default_constant_imputer", "non_default_constant_data"), + ("default_input_imputer", "default_input_data"), + ("non_default_input_imputer", "non_default_input_data"), + ], +) +def test_mask_saving(imputer_fixture, data_fixture, request): + """Check that the imputer saves the NaN mask correctly.""" + x, _ = request.getfixturevalue(data_fixture) + imputer = request.getfixturevalue(imputer_fixture) + expected_mask = torch.isnan(x) + imputer.transform(x) + assert torch.equal(imputer.nan_locations, expected_mask), "Mask not saved correctly after first run." + + +@pytest.mark.parametrize( + ("imputer_fixture", "data_fixture"), + [ + ("default_constant_imputer", "default_constant_data"), + ("non_default_constant_imputer", "non_default_constant_data"), + ("default_input_imputer", "default_input_data"), + ("non_default_input_imputer", "non_default_input_data"), + ], +) +def test_reuse_imputer(imputer_fixture, data_fixture, request): + """Check that the imputer reuses the mask correctly on subsequent runs.""" + x, expected = request.getfixturevalue(data_fixture) + imputer = request.getfixturevalue(imputer_fixture) + x2 = x**2.0 + _ = imputer.transform(x2, in_place=False) + transformed2 = imputer.transform(x, in_place=False) + assert torch.allclose( + transformed2, expected, equal_nan=True + ), "Imputer does not reuse mask correctly on subsequent runs." diff --git a/tests/preprocessing/test_preprocessor_normalizer.py b/tests/preprocessing/test_preprocessor_normalizer.py new file mode 100644 index 00000000..787079d8 --- /dev/null +++ b/tests/preprocessing/test_preprocessor_normalizer.py @@ -0,0 +1,88 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
+# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import numpy as np +import pytest +import torch +from omegaconf import DictConfig + +from anemoi.models.data_indices.collection import IndexCollection +from anemoi.models.preprocessing.normalizer import InputNormalizer + + +@pytest.fixture() +def input_normalizer(): + config = DictConfig( + { + "diagnostics": {"log": {"code": {"level": "DEBUG"}}}, + "data": { + "normalizer": {"default": "mean-std", "min-max": ["x"], "max": ["y"], "none": ["z"], "mean-std": ["q"]}, + "forcing": ["z", "q"], + "diagnostic": ["other"], + }, + }, + ) + statistics = { + "mean": np.array([1.0, 2.0, 3.0, 4.5, 3.0]), + "stdev": np.array([0.5, 0.5, 0.5, 1, 14]), + "minimum": np.array([1.0, 1.0, 1.0, 1.0, 1.0]), + "maximum": np.array([11.0, 10.0, 10.0, 10.0, 10.0]), + } + name_to_index = {"x": 0, "y": 1, "z": 2, "q": 3, "other": 4} + data_indices = IndexCollection(config=config, name_to_index=name_to_index) + return InputNormalizer(config=config.data.normalizer, statistics=statistics, data_indices=data_indices) + + +def test_normalizer_not_inplace(input_normalizer) -> None: + x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0]]) + input_normalizer(x, in_place=False) + assert torch.allclose(x, torch.Tensor([[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0]])) + + +def test_normalizer_inplace(input_normalizer) -> None: + x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0]]) + out = input_normalizer(x, in_place=True) + assert not torch.allclose(x, torch.Tensor([[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0]])) + assert torch.allclose(x, out) + + +def test_normalize(input_normalizer) -> None: + x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0]]) + expected_output = torch.Tensor([[0.0, 0.2, 3.0, -0.5, 1 / 7], [0.5, 0.7, 
8.0, 4.5, 0.5]]) + assert torch.allclose(input_normalizer.transform(x), expected_output) + + +def test_normalize_small(input_normalizer) -> None: + x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0]]) + expected_output = torch.Tensor([[0.0, 0.2, 3.0, -0.5], [0.5, 0.7, 8.0, 4.5]]) + assert torch.allclose( + input_normalizer.transform(x[..., [0, 1, 2, 3]], data_index=[0, 1, 2, 3], in_place=False), + expected_output, + ) + assert torch.allclose(input_normalizer.transform(x[..., [0, 1, 2, 3]]), expected_output) + + +def test_inverse_transform_small(input_normalizer) -> None: + expected_output = torch.Tensor([[1.0, 2.0, 5.0], [6.0, 7.0, 10.0]]) + x = torch.Tensor([[0.0, 0.2, 1 / 7], [0.5, 0.7, 0.5]]) + assert torch.allclose(input_normalizer.inverse_transform(x, data_index=[0, 1, 4], in_place=False), expected_output) + assert torch.allclose(input_normalizer.inverse_transform(x), expected_output) + + +def test_inverse_transform(input_normalizer) -> None: + x = torch.Tensor([[0.0, 0.2, 3.0, -0.5, 1 / 7], [0.5, 0.7, 8.0, 4.5, 0.5]]) + expected_output = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0]]) + assert torch.allclose(input_normalizer.inverse_transform(x), expected_output) + + +def test_normalize_inverse_transform(input_normalizer) -> None: + x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0]]) + assert torch.allclose( + input_normalizer.inverse_transform(input_normalizer.transform(x, in_place=False), in_place=False), x + )