From c27d597cba3027de065d0397bcd6ac0448f78d02 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Thu, 30 May 2024 14:21:46 +0000 Subject: [PATCH 1/7] Include changes from aifs-mono: feature/graph_refactor --- src/anemoi/models/interface/__init__.py | 4 +- src/anemoi/models/layers/mapper.py | 23 +- src/anemoi/models/layers/processor.py | 5 +- .../models/encoder_processor_decoder.py | 197 +++++++++--------- tests/layers/mapper/test_base_mapper.py | 8 +- tests/layers/mapper/test_graphconv_mapper.py | 8 +- .../mapper/test_graphtransformer_mapper.py | 8 +- .../processor/test_graphconv_processor.py | 9 +- .../test_graphtransformer_processor.py | 9 +- 9 files changed, 127 insertions(+), 144 deletions(-) diff --git a/src/anemoi/models/interface/__init__.py b/src/anemoi/models/interface/__init__.py index f4ff82cf..c15d2ca3 100644 --- a/src/anemoi/models/interface/__init__.py +++ b/src/anemoi/models/interface/__init__.py @@ -12,8 +12,8 @@ import torch from anemoi.utils.config import DotConfig from hydra.utils import instantiate -from torch_geometric.data import HeteroData +from anemoi.models.data_indices.collection import IndexCollection from anemoi.models.models.encoder_processor_decoder import AnemoiModelEncProcDec from anemoi.models.preprocessing import Processors @@ -22,7 +22,7 @@ class AnemoiModelInterface(torch.nn.Module): """Anemoi model on torch level.""" def __init__( - self, *, config: DotConfig, graph_data: HeteroData, statistics: dict, data_indices: dict, metadata: dict + self, *, config: DotConfig, graph_data: dict, statistics: dict, data_indices: IndexCollection, metadata: dict ) -> None: super().__init__() self.config = config diff --git a/src/anemoi/models/layers/mapper.py b/src/anemoi/models/layers/mapper.py index 9f5f90bf..4aaeb102 100644 --- a/src/anemoi/models/layers/mapper.py +++ b/src/anemoi/models/layers/mapper.py @@ -16,7 +16,6 @@ from torch import nn from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import offload_wrapper from torch.distributed.distributed_c10d import ProcessGroup -from torch_geometric.data import HeteroData from torch_geometric.typing import Adj from torch_geometric.typing import PairTensor @@ -113,12 +112,12 @@ def pre_process(self, x, shard_shapes, model_comm_group=None): class GraphEdgeMixin: - def _register_edges(self, sub_graph: HeteroData, src_size: int, dst_size: int, trainable_size: int) -> None: + def _register_edges(self, sub_graph: dict, src_size: int, dst_size: int, trainable_size: int) -> None: """Register edge dim, attr, index_base, and increment. 
Parameters ---------- - sub_graph : HeteroData + sub_graph : dict Sub graph of the full structure src_size : int Source size @@ -127,9 +126,9 @@ def _register_edges(self, sub_graph: HeteroData, src_size: int, dst_size: int, t trainable_size : int Trainable tensor size """ - self.edge_dim = sub_graph.edge_attr.shape[1] + trainable_size - self.register_buffer("edge_attr", sub_graph.edge_attr, persistent=False) - self.register_buffer("edge_index_base", sub_graph.edge_index, persistent=False) + self.edge_dim = sub_graph["edge_attr"].shape[1] + trainable_size + self.register_buffer("edge_attr", sub_graph["edge_attr"], persistent=False) + self.register_buffer("edge_index_base", sub_graph["edge_index"], persistent=False) self.register_buffer( "edge_inc", torch.from_numpy(np.asarray([[src_size], [dst_size]], dtype=np.int64)), persistent=True ) @@ -173,7 +172,7 @@ def __init__( activation: str = "GELU", num_heads: int = 16, mlp_hidden_ratio: int = 4, - sub_graph: Optional[HeteroData] = None, + sub_graph: Optional[dict] = None, src_grid_size: int = 0, dst_grid_size: int = 0, ) -> None: @@ -273,7 +272,7 @@ def __init__( activation: str = "GELU", num_heads: int = 16, mlp_hidden_ratio: int = 4, - sub_graph: Optional[HeteroData] = None, + sub_graph: Optional[dict] = None, src_grid_size: int = 0, dst_grid_size: int = 0, ) -> None: @@ -344,7 +343,7 @@ def __init__( activation: str = "GELU", num_heads: int = 16, mlp_hidden_ratio: int = 4, - sub_graph: Optional[HeteroData] = None, + sub_graph: Optional[dict] = None, src_grid_size: int = 0, dst_grid_size: int = 0, ) -> None: @@ -414,7 +413,7 @@ def __init__( cpu_offload: bool = False, activation: str = "SiLU", mlp_extra_layers: int = 0, - sub_graph: Optional[HeteroData] = None, + sub_graph: Optional[dict] = None, src_grid_size: int = 0, dst_grid_size: int = 0, ) -> None: @@ -517,7 +516,7 @@ def __init__( cpu_offload: bool = False, activation: str = "SiLU", mlp_extra_layers: int = 0, - sub_graph: Optional[HeteroData] = None, + sub_graph: Optional[dict] = None, src_grid_size: int = 0, dst_grid_size: int = 0, ) -> None: @@ -601,7 +600,7 @@ def __init__( cpu_offload: bool = False, activation: str = "SiLU", mlp_extra_layers: int = 0, - sub_graph: Optional[HeteroData] = None, + sub_graph: Optional[dict] = None, src_grid_size: int = 0, dst_grid_size: int = 0, ) -> None: diff --git a/src/anemoi/models/layers/processor.py b/src/anemoi/models/layers/processor.py index 39a6f24a..f1e91f64 100644 --- a/src/anemoi/models/layers/processor.py +++ b/src/anemoi/models/layers/processor.py @@ -15,7 +15,6 @@ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import offload_wrapper from torch.distributed.distributed_c10d import ProcessGroup from torch.utils.checkpoint import checkpoint -from torch_geometric.data import HeteroData from anemoi.models.distributed.graph import shard_tensor from anemoi.models.distributed.khop_edges import sort_edges_1hop @@ -170,7 +169,7 @@ def __init__( mlp_extra_layers: int = 0, activation: str = "SiLU", cpu_offload: bool = False, - sub_graph: Optional[HeteroData] = None, + sub_graph: Optional[dict] = None, src_grid_size: int = 0, dst_grid_size: int = 0, **kwargs, @@ -257,7 +256,7 @@ def __init__( mlp_hidden_ratio: int = 4, activation: str = "GELU", cpu_offload: bool = False, - sub_graph: Optional[HeteroData] = None, + sub_graph: Optional[dict] = None, src_grid_size: int = 0, dst_grid_size: int = 0, **kwargs, diff --git a/src/anemoi/models/models/encoder_processor_decoder.py b/src/anemoi/models/models/encoder_processor_decoder.py index 
3bab80e6..5fbf1b4e 100644 --- a/src/anemoi/models/models/encoder_processor_decoder.py +++ b/src/anemoi/models/models/encoder_processor_decoder.py @@ -18,7 +18,6 @@ from torch import nn from torch.distributed.distributed_c10d import ProcessGroup from torch.utils.checkpoint import checkpoint -from torch_geometric.data import HeteroData from anemoi.models.distributed.shapes import get_shape_shards from anemoi.models.layers.graph import TrainableTensor @@ -34,7 +33,7 @@ def __init__( *, config: DotConfig, data_indices: dict, - graph_data: HeteroData, + graph_data: dict, ) -> None: """Initializes the graph neural network. @@ -42,14 +41,20 @@ def __init__( ---------- config : DictConfig Job configuration - graph_data : HeteroData + graph_data : dict Graph definition """ super().__init__() self._graph_data = graph_data - self._graph_name_data = config.graph.data - self._graph_name_hidden = config.graph.hidden + self._graph_name_hidden = config.graphs.hidden_mesh.name + self._graph_mesh_names = [name for name in graph_data if isinstance(name, str)] + self._graph_input_meshes = [ + k[0] for k in graph_data if isinstance(k, tuple) and k[2] == self._graph_name_hidden and k[2] != k[0] + ] + self._graph_output_meshes = [ + k[2] for k in graph_data if isinstance(k, tuple) and k[0] == self._graph_name_hidden and k[2] != k[0] + ] self._calculate_shapes_and_indices(data_indices) self._assert_matching_indices(data_indices) @@ -58,48 +63,49 @@ def __init__( self._define_tensor_sizes(config) - # Create trainable tensors - self._create_trainable_attributes() - # Register lat/lon - self._register_latlon("data", self._graph_name_data) - self._register_latlon("hidden", self._graph_name_hidden) + for name in self._graph_mesh_names: + self._register_latlon(name) self.num_channels = config.model.num_channels - input_dim = self.multi_step * self.num_input_channels + self.latlons_data.shape[1] + self.trainable_data_size + input_dim = self.multi_step * self.num_input_channels # Encoder data -> hidden - self.encoder = instantiate( - config.model.encoder, - in_channels_src=input_dim, - in_channels_dst=self.latlons_hidden.shape[1] + self.trainable_hidden_size, - hidden_dim=self.num_channels, - sub_graph=self._graph_data[(self._graph_name_data, "to", self._graph_name_hidden)], - src_grid_size=self._data_grid_size, - dst_grid_size=self._hidden_grid_size, - ) + self.encoders = nn.ModuleDict() + for data in self._graph_input_meshes: + self.encoders[data] = instantiate( + config.model.encoder, + in_channels_src=input_dim + self.num_node_features[data] + self.num_trainable_params[data], + in_channels_dst=self.num_node_features[self._graph_name_hidden] + self.num_trainable_params[self._graph_name_hidden], + hidden_dim=self.num_channels, + sub_graph=self._graph_data[(data, "to", self._graph_name_hidden)], + src_grid_size=self.num_nodes[data], + dst_grid_size=self.num_nodes[self._graph_name_hidden], + ) # Processor hidden -> hidden self.processor = instantiate( config.model.processor, num_channels=self.num_channels, - sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_hidden)], - src_grid_size=self._hidden_grid_size, - dst_grid_size=self._hidden_grid_size, + sub_graph=self._graph_data.get((self._graph_name_hidden, "to", self._graph_name_hidden), None), + src_grid_size=self.num_nodes[self._graph_name_hidden], + dst_grid_size=self.num_nodes[self._graph_name_hidden], ) # Decoder hidden -> data - self.decoder = instantiate( - config.model.decoder, - in_channels_src=self.num_channels, - 
in_channels_dst=input_dim, - hidden_dim=self.num_channels, - out_channels_dst=self.num_output_channels, - sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_data)], - src_grid_size=self._hidden_grid_size, - dst_grid_size=self._data_grid_size, - ) + self.decoders = nn.ModuleDict() + for data in self._graph_output_meshes: + self.decoders[data] = instantiate( + config.model.decoder, + in_channels_src=self.num_channels, + in_channels_dst=input_dim + self.num_node_features[data] + self.num_trainable_params[data], + hidden_dim=self.num_channels, + out_channels_dst=self.num_output_channels, + sub_graph=self._graph_data[(self._graph_name_hidden, "to", data)], + src_grid_size=self.num_nodes[self._graph_name_hidden], + dst_grid_size=self.num_nodes[data], + ) def _calculate_shapes_and_indices(self, data_indices: dict) -> None: self.num_input_channels = len(data_indices.model.input) @@ -121,45 +127,30 @@ def _assert_matching_indices(self, data_indices: dict) -> None: def _define_tensor_sizes(self, config: DotConfig) -> None: # Define Sizes of different tensors - self._data_grid_size = self._graph_data[(self._graph_name_data, "to", self._graph_name_data)].ecoords_rad.shape[ - 0 - ] - self._hidden_grid_size = self._graph_data[ - (self._graph_name_hidden, "to", self._graph_name_hidden) - ].hcoords_rad.shape[0] - - self.trainable_data_size = config.model.trainable_parameters.data - self.trainable_hidden_size = config.model.trainable_parameters.hidden + self.num_nodes = {name: self._graph_data[name]["coords"].shape[0] for name in self._graph_mesh_names} + self.num_node_features = {name: 2 * self._graph_data[name]["coords"].shape[1] for name in self._graph_mesh_names} + self.num_trainable_params = { + name: config.model.trainable_parameters.get("data" if name != "hidden" else name, 0) for name in self._graph_mesh_names + } - def _register_latlon(self, name: str, key: str) -> None: + def _register_latlon(self, name: str) -> None: """Register lat/lon buffers. 
Parameters ---------- name : str Name of grid to map - key : str - Key of the grid """ + trainable_tensor = TrainableTensor( + trainable_size=self.num_trainable_params[name], tensor_size=self._graph_data[name]["coords"].shape[0] + ) + setattr(self, f"trainable_{name}", trainable_tensor) self.register_buffer( f"latlons_{name}", - torch.cat( - [ - torch.sin(self._graph_data[(key, "to", key)][f"{key[:1]}coords_rad"]), - torch.cos(self._graph_data[(key, "to", key)][f"{key[:1]}coords_rad"]), - ], - dim=-1, - ), + torch.cat([torch.sin(self._graph_data[name]["coords"]), torch.cos(self._graph_data[name]["coords"])], dim=-1), persistent=True, ) - def _create_trainable_attributes(self) -> None: - """Create all trainable attributes.""" - self.trainable_data = TrainableTensor(trainable_size=self.trainable_data_size, tensor_size=self._data_grid_size) - self.trainable_hidden = TrainableTensor( - trainable_size=self.trainable_hidden_size, tensor_size=self._hidden_grid_size - ) - def _run_mapper( self, mapper: nn.Module, @@ -202,32 +193,41 @@ def _run_mapper( ) def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) -> Tensor: - batch_size = x.shape[0] - ensemble_size = x.shape[2] + batch_size, _, ensemble_size, *_ = x.shape # add data positional info (lat/lon) - x_data_latent = torch.cat( - ( - einops.rearrange(x, "batch time ensemble grid vars -> (batch ensemble grid) (time vars)"), - self.trainable_data(self.latlons_data, batch_size=batch_size), - ), - dim=-1, # feature dimension - ) + x_data_latent = {} + for data in self._graph_input_meshes: + x_data_latent[data] = torch.cat( + ( + einops.rearrange(x, "batch time ensemble grid vars -> (batch ensemble grid) (time vars)"), + getattr(self, f"trainable_{data}")(getattr(self, f"latlons_{data}"), batch_size=batch_size), + ), + dim=-1, # feature dimension + ) x_hidden_latent = self.trainable_hidden(self.latlons_hidden, batch_size=batch_size) # get shard shapes - shard_shapes_data = get_shape_shards(x_data_latent, 0, model_comm_group) + shard_shapes_data = {} + for data in self._graph_input_meshes: + shard_shapes_data[data] = get_shape_shards(x_data_latent[data], 0, model_comm_group) shard_shapes_hidden = get_shape_shards(x_hidden_latent, 0, model_comm_group) - # Run encoder - x_data_latent, x_latent = self._run_mapper( - self.encoder, - (x_data_latent, x_hidden_latent), - batch_size=batch_size, - shard_shapes=(shard_shapes_data, shard_shapes_hidden), - model_comm_group=model_comm_group, - ) + # Run encoders + x_latents = [] + for data, encoder in self.encoders.items(): + x_data_latent[data], x_latent = self._run_mapper( + encoder, + (x_data_latent[data], x_hidden_latent), + batch_size=batch_size, + shard_shapes=(shard_shapes_data[data], shard_shapes_hidden), + model_comm_group=model_comm_group, + ) + x_latents.append(x_latent) + + # TODO: This operation can be a desing choice (sum, mean, attention, ...) 
+ x_latent = torch.stack(x_latents).sum(dim=0) if len(x_latents) > 1 else x_latents[0] x_latent_proc = self.processor( x_latent, @@ -239,26 +239,29 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) -> # add skip connection (hidden -> hidden) x_latent_proc = x_latent_proc + x_latent - # Run decoder - x_out = self._run_mapper( - self.decoder, - (x_latent_proc, x_data_latent), - batch_size=batch_size, - shard_shapes=(shard_shapes_hidden, shard_shapes_data), - model_comm_group=model_comm_group, - ) + # Run decoders + x_out = {} + for data, decoder in self.decoders.items(): + x_out[data] = self._run_mapper( + decoder, + (x_latent_proc, x_data_latent[data]), + batch_size=batch_size, + shard_shapes=(shard_shapes_hidden, shard_shapes_data[data]), + model_comm_group=model_comm_group, + ) - x_out = ( - einops.rearrange( - x_out, - "(batch ensemble grid) vars -> batch ensemble grid vars", - batch=batch_size, - ensemble=ensemble_size, + x_out[data] = ( + einops.rearrange( + x_out[data], + "(batch ensemble grid) vars -> batch ensemble grid vars", + batch=batch_size, + ensemble=ensemble_size, + ) + .to(dtype=x.dtype) + .clone() ) - .to(dtype=x.dtype) - .clone() - ) - # residual connection (just for the prognostic variables) - x_out[..., self._internal_output_idx] += x[:, -1, :, :, self._internal_input_idx] - return x_out + # residual connection (just for the prognostic variables) + if data in self._graph_input_meshes: + x_out[data][..., self._internal_output_idx] += x[:, -1, :, :, self._internal_input_idx] + return x_out[self._graph_output_meshes[0]] diff --git a/tests/layers/mapper/test_base_mapper.py b/tests/layers/mapper/test_base_mapper.py index 5b82b658..653ed9d2 100644 --- a/tests/layers/mapper/test_base_mapper.py +++ b/tests/layers/mapper/test_base_mapper.py @@ -7,7 +7,6 @@ import pytest import torch -from torch_geometric.data import HeteroData from anemoi.models.layers.mapper import BaseMapper @@ -71,11 +70,8 @@ def pair_tensor(self, mapper_init): ) @pytest.fixture - def fake_graph(self): - graph = HeteroData() - graph.edge_attr = torch.rand((100, 128)) - graph.edge_index = torch.randint(0, 100, (2, 100)) - return graph + def fake_graph(self) -> dict: + return {"edge_attr": torch.rand((100, 128)), "edge_index": torch.randint(0, 100, (2, 100))} def test_initialization(self, mapper, mapper_init): ( diff --git a/tests/layers/mapper/test_graphconv_mapper.py b/tests/layers/mapper/test_graphconv_mapper.py index 480a4948..422ceb3f 100644 --- a/tests/layers/mapper/test_graphconv_mapper.py +++ b/tests/layers/mapper/test_graphconv_mapper.py @@ -8,7 +8,6 @@ import pytest import torch from torch import nn -from torch_geometric.data import HeteroData from anemoi.models.layers.mapper import GNNBackwardMapper from anemoi.models.layers.mapper import GNNBaseMapper @@ -77,11 +76,8 @@ def pair_tensor(self, mapper_init): ) @pytest.fixture - def fake_graph(self): - graph = HeteroData() - graph.edge_attr = torch.rand((self.GRID_SIZE, 128)) - graph.edge_index = torch.randint(0, self.GRID_SIZE, (2, self.GRID_SIZE)) - return graph + def fake_graph(self) -> dict: + return {"edge_attr": torch.rand((self.GRID_SIZE, 128)), "edge_index": torch.randint(0, self.GRID_SIZE, (2, self.GRID_SIZE))} def test_initialization(self, mapper, mapper_init): ( diff --git a/tests/layers/mapper/test_graphtransformer_mapper.py b/tests/layers/mapper/test_graphtransformer_mapper.py index 7fd0fc0c..18dc6d87 100644 --- a/tests/layers/mapper/test_graphtransformer_mapper.py +++ 
b/tests/layers/mapper/test_graphtransformer_mapper.py @@ -8,7 +8,6 @@ import pytest import torch from torch import nn -from torch_geometric.data import HeteroData from anemoi.models.layers.mapper import GraphTransformerBackwardMapper from anemoi.models.layers.mapper import GraphTransformerBaseMapper @@ -87,11 +86,8 @@ def pair_tensor(self, mapper_init): ) @pytest.fixture - def fake_graph(self): - graph = HeteroData() - graph.edge_attr = torch.rand((self.GRID_SIZE, 128)) - graph.edge_index = torch.randint(0, self.GRID_SIZE, (2, self.GRID_SIZE)) - return graph + def fake_graph(self) -> dict: + return {"edge_attr": torch.rand((self.GRID_SIZE, 128)), "edge_index": torch.randint(0, self.GRID_SIZE, (2, self.GRID_SIZE))} def test_initialization(self, mapper, mapper_init): ( diff --git a/tests/layers/processor/test_graphconv_processor.py b/tests/layers/processor/test_graphconv_processor.py index 569319be..a1a54fb5 100644 --- a/tests/layers/processor/test_graphconv_processor.py +++ b/tests/layers/processor/test_graphconv_processor.py @@ -7,18 +7,15 @@ import pytest import torch -from torch_geometric.data import HeteroData from anemoi.models.layers.graph import TrainableTensor from anemoi.models.layers.processor import GNNProcessor @pytest.fixture -def fake_graph(): - graph = HeteroData() - graph.edge_attr = torch.rand((100, 128)) - graph.edge_index = torch.randint(0, 100, (2, 100)) - return graph +def fake_graph() -> dict: + return {"edge_attr": torch.rand((100, 128)), "edge_index": torch.randint(0, 100, (2, 100))} + @pytest.fixture diff --git a/tests/layers/processor/test_graphtransformer_processor.py b/tests/layers/processor/test_graphtransformer_processor.py index 81095a2e..6101c143 100644 --- a/tests/layers/processor/test_graphtransformer_processor.py +++ b/tests/layers/processor/test_graphtransformer_processor.py @@ -7,18 +7,15 @@ import pytest import torch -from torch_geometric.data import HeteroData from anemoi.models.layers.graph import TrainableTensor from anemoi.models.layers.processor import GraphTransformerProcessor @pytest.fixture -def fake_graph(): - graph = HeteroData() - graph.edge_attr = torch.rand((100, 128)) - graph.edge_index = torch.randint(0, 100, (2, 100)) - return graph +def fake_graph() -> dict: + return {"edge_attr": torch.rand((100, 128)), "edge_index": torch.randint(0, 100, (2, 100))} + @pytest.fixture From 0a699078672c7d0e72c86acf2bc8706fb11f78d8 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Thu, 30 May 2024 14:31:02 +0000 Subject: [PATCH 2/7] pre-commit --- .../models/models/encoder_processor_decoder.py | 14 ++++++++++---- tests/layers/mapper/test_graphconv_mapper.py | 5 ++++- .../layers/mapper/test_graphtransformer_mapper.py | 5 ++++- tests/layers/processor/test_graphconv_processor.py | 1 - .../processor/test_graphtransformer_processor.py | 1 - 5 files changed, 18 insertions(+), 8 deletions(-) diff --git a/src/anemoi/models/models/encoder_processor_decoder.py b/src/anemoi/models/models/encoder_processor_decoder.py index 5fbf1b4e..4d66d523 100644 --- a/src/anemoi/models/models/encoder_processor_decoder.py +++ b/src/anemoi/models/models/encoder_processor_decoder.py @@ -77,7 +77,8 @@ def __init__( self.encoders[data] = instantiate( config.model.encoder, in_channels_src=input_dim + self.num_node_features[data] + self.num_trainable_params[data], - in_channels_dst=self.num_node_features[self._graph_name_hidden] + self.num_trainable_params[self._graph_name_hidden], + in_channels_dst=self.num_node_features[self._graph_name_hidden] + + 
self.num_trainable_params[self._graph_name_hidden], hidden_dim=self.num_channels, sub_graph=self._graph_data[(data, "to", self._graph_name_hidden)], src_grid_size=self.num_nodes[data], @@ -128,9 +129,12 @@ def _assert_matching_indices(self, data_indices: dict) -> None: def _define_tensor_sizes(self, config: DotConfig) -> None: # Define Sizes of different tensors self.num_nodes = {name: self._graph_data[name]["coords"].shape[0] for name in self._graph_mesh_names} - self.num_node_features = {name: 2 * self._graph_data[name]["coords"].shape[1] for name in self._graph_mesh_names} + self.num_node_features = { + name: 2 * self._graph_data[name]["coords"].shape[1] for name in self._graph_mesh_names + } self.num_trainable_params = { - name: config.model.trainable_parameters.get("data" if name != "hidden" else name, 0) for name in self._graph_mesh_names + name: config.model.trainable_parameters.get("data" if name != "hidden" else name, 0) + for name in self._graph_mesh_names } def _register_latlon(self, name: str) -> None: @@ -147,7 +151,9 @@ def _register_latlon(self, name: str) -> None: setattr(self, f"trainable_{name}", trainable_tensor) self.register_buffer( f"latlons_{name}", - torch.cat([torch.sin(self._graph_data[name]["coords"]), torch.cos(self._graph_data[name]["coords"])], dim=-1), + torch.cat( + [torch.sin(self._graph_data[name]["coords"]), torch.cos(self._graph_data[name]["coords"])], dim=-1 + ), persistent=True, ) diff --git a/tests/layers/mapper/test_graphconv_mapper.py b/tests/layers/mapper/test_graphconv_mapper.py index 422ceb3f..7abcc816 100644 --- a/tests/layers/mapper/test_graphconv_mapper.py +++ b/tests/layers/mapper/test_graphconv_mapper.py @@ -77,7 +77,10 @@ def pair_tensor(self, mapper_init): @pytest.fixture def fake_graph(self) -> dict: - return {"edge_attr": torch.rand((self.GRID_SIZE, 128)), "edge_index": torch.randint(0, self.GRID_SIZE, (2, self.GRID_SIZE))} + return { + "edge_attr": torch.rand((self.GRID_SIZE, 128)), + "edge_index": torch.randint(0, self.GRID_SIZE, (2, self.GRID_SIZE)), + } def test_initialization(self, mapper, mapper_init): ( diff --git a/tests/layers/mapper/test_graphtransformer_mapper.py b/tests/layers/mapper/test_graphtransformer_mapper.py index 18dc6d87..f3a1cae9 100644 --- a/tests/layers/mapper/test_graphtransformer_mapper.py +++ b/tests/layers/mapper/test_graphtransformer_mapper.py @@ -87,7 +87,10 @@ def pair_tensor(self, mapper_init): @pytest.fixture def fake_graph(self) -> dict: - return {"edge_attr": torch.rand((self.GRID_SIZE, 128)), "edge_index": torch.randint(0, self.GRID_SIZE, (2, self.GRID_SIZE))} + return { + "edge_attr": torch.rand((self.GRID_SIZE, 128)), + "edge_index": torch.randint(0, self.GRID_SIZE, (2, self.GRID_SIZE)), + } def test_initialization(self, mapper, mapper_init): ( diff --git a/tests/layers/processor/test_graphconv_processor.py b/tests/layers/processor/test_graphconv_processor.py index a1a54fb5..a4d086ed 100644 --- a/tests/layers/processor/test_graphconv_processor.py +++ b/tests/layers/processor/test_graphconv_processor.py @@ -17,7 +17,6 @@ def fake_graph() -> dict: return {"edge_attr": torch.rand((100, 128)), "edge_index": torch.randint(0, 100, (2, 100))} - @pytest.fixture def graphconv_init(fake_graph): num_layers = 2 diff --git a/tests/layers/processor/test_graphtransformer_processor.py b/tests/layers/processor/test_graphtransformer_processor.py index 6101c143..90468e2b 100644 --- a/tests/layers/processor/test_graphtransformer_processor.py +++ b/tests/layers/processor/test_graphtransformer_processor.py @@ -17,7 +17,6 
@@ def fake_graph() -> dict: return {"edge_attr": torch.rand((100, 128)), "edge_index": torch.randint(0, 100, (2, 100))} - @pytest.fixture def graphtransformer_init(fake_graph): num_layers = 2 From d2644fba1e659b886e1cacbebbf62e3c3440d8e9 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Fri, 31 May 2024 10:54:18 +0000 Subject: [PATCH 3/7] fix: DotDict import --- src/anemoi/models/models/encoder_processor_decoder.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/anemoi/models/models/encoder_processor_decoder.py b/src/anemoi/models/models/encoder_processor_decoder.py index 4d66d523..3891b51b 100644 --- a/src/anemoi/models/models/encoder_processor_decoder.py +++ b/src/anemoi/models/models/encoder_processor_decoder.py @@ -12,7 +12,7 @@ import einops import torch -from anemoi.utils.config import DotConfig +from anemoi.utils.config import DotDict from hydra.utils import instantiate from torch import Tensor from torch import nn @@ -31,7 +31,7 @@ class AnemoiModelEncProcDec(nn.Module): def __init__( self, *, - config: DotConfig, + config: DotDict, data_indices: dict, graph_data: dict, ) -> None: @@ -126,7 +126,7 @@ def _assert_matching_indices(self, data_indices: dict) -> None: self._internal_output_idx, ), f"Model indices must match {self._internal_input_idx} != {self._internal_output_idx}" - def _define_tensor_sizes(self, config: DotConfig) -> None: + def _define_tensor_sizes(self, config: DotDict) -> None: # Define Sizes of different tensors self.num_nodes = {name: self._graph_data[name]["coords"].shape[0] for name in self._graph_mesh_names} self.num_node_features = { From bb9bbd67a6ed4dd4969ec92dbaeaa50f2321308e Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Fri, 31 May 2024 10:57:07 +0000 Subject: [PATCH 4/7] fix: DotDict import --- src/anemoi/models/interface/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/anemoi/models/interface/__init__.py b/src/anemoi/models/interface/__init__.py index c15d2ca3..85545e6e 100644 --- a/src/anemoi/models/interface/__init__.py +++ b/src/anemoi/models/interface/__init__.py @@ -10,7 +10,7 @@ import uuid import torch -from anemoi.utils.config import DotConfig +from anemoi.utils.config import DotDict from hydra.utils import instantiate from anemoi.models.data_indices.collection import IndexCollection @@ -22,7 +22,7 @@ class AnemoiModelInterface(torch.nn.Module): """Anemoi model on torch level.""" def __init__( - self, *, config: DotConfig, graph_data: dict, statistics: dict, data_indices: IndexCollection, metadata: dict + self, *, config: DotDict, graph_data: dict, statistics: dict, data_indices: IndexCollection, metadata: dict ) -> None: super().__init__() self.config = config From 8b63544ef39cf3b50ddaed077c26ebe69352c4d1 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Wed, 5 Jun 2024 10:03:33 +0000 Subject: [PATCH 5/7] fix: code review feedback --- .../models/encoder_processor_decoder.py | 112 +++++++++--------- 1 file changed, 58 insertions(+), 54 deletions(-) diff --git a/src/anemoi/models/models/encoder_processor_decoder.py b/src/anemoi/models/models/encoder_processor_decoder.py index 4d66d523..682e1058 100644 --- a/src/anemoi/models/models/encoder_processor_decoder.py +++ b/src/anemoi/models/models/encoder_processor_decoder.py @@ -19,6 +19,7 @@ from torch.distributed.distributed_c10d import ProcessGroup from torch.utils.checkpoint import checkpoint +from anemoi.models.data_indices.collection import IndexCollection from anemoi.models.distributed.shapes import get_shape_shards from 
anemoi.models.layers.graph import TrainableTensor @@ -32,7 +33,7 @@ def __init__( self, *, config: DotConfig, - data_indices: dict, + data_indices: IndexCollection, graph_data: dict, ) -> None: """Initializes the graph neural network. @@ -41,12 +42,13 @@ def __init__( ---------- config : DictConfig Job configuration + data_indices : IndexCollection + Data indices graph_data : dict Graph definition """ super().__init__() - self._graph_data = graph_data self._graph_name_hidden = config.graphs.hidden_mesh.name self._graph_mesh_names = [name for name in graph_data if isinstance(name, str)] self._graph_input_meshes = [ @@ -61,11 +63,13 @@ def __init__( self.multi_step = config.training.multistep_input - self._define_tensor_sizes(config) + self._define_tensor_sizes(config, graph_data) + + self._create_trainable_attributes() # Register lat/lon - for name in self._graph_mesh_names: - self._register_latlon(name) + for mesh_key in self._graph_mesh_names: + self._register_latlon(mesh_key, graph_data[mesh_key]["coords"]) self.num_channels = config.model.num_channels @@ -73,15 +77,15 @@ def __init__( # Encoder data -> hidden self.encoders = nn.ModuleDict() - for data in self._graph_input_meshes: - self.encoders[data] = instantiate( + for in_mesh in self._graph_input_meshes: + self.encoders[in_mesh] = instantiate( config.model.encoder, - in_channels_src=input_dim + self.num_node_features[data] + self.num_trainable_params[data], + in_channels_src=input_dim + self.num_node_features[in_mesh] + self.num_trainable_params[in_mesh], in_channels_dst=self.num_node_features[self._graph_name_hidden] + self.num_trainable_params[self._graph_name_hidden], hidden_dim=self.num_channels, - sub_graph=self._graph_data[(data, "to", self._graph_name_hidden)], - src_grid_size=self.num_nodes[data], + sub_graph=graph_data[(in_mesh, "to", self._graph_name_hidden)], + src_grid_size=self.num_nodes[in_mesh], dst_grid_size=self.num_nodes[self._graph_name_hidden], ) @@ -89,23 +93,23 @@ def __init__( self.processor = instantiate( config.model.processor, num_channels=self.num_channels, - sub_graph=self._graph_data.get((self._graph_name_hidden, "to", self._graph_name_hidden), None), + sub_graph=graph_data.get((self._graph_name_hidden, "to", self._graph_name_hidden), None), src_grid_size=self.num_nodes[self._graph_name_hidden], dst_grid_size=self.num_nodes[self._graph_name_hidden], ) # Decoder hidden -> data self.decoders = nn.ModuleDict() - for data in self._graph_output_meshes: - self.decoders[data] = instantiate( + for out_mesh in self._graph_output_meshes: + self.decoders[out_mesh] = instantiate( config.model.decoder, in_channels_src=self.num_channels, - in_channels_dst=input_dim + self.num_node_features[data] + self.num_trainable_params[data], + in_channels_dst=input_dim + self.num_node_features[out_mesh] + self.num_trainable_params[out_mesh], hidden_dim=self.num_channels, out_channels_dst=self.num_output_channels, - sub_graph=self._graph_data[(self._graph_name_hidden, "to", data)], + sub_graph=graph_data[(self._graph_name_hidden, "to", out_mesh)], src_grid_size=self.num_nodes[self._graph_name_hidden], - dst_grid_size=self.num_nodes[data], + dst_grid_size=self.num_nodes[out_mesh], ) def _calculate_shapes_and_indices(self, data_indices: dict) -> None: @@ -126,35 +130,35 @@ def _assert_matching_indices(self, data_indices: dict) -> None: self._internal_output_idx, ), f"Model indices must match {self._internal_input_idx} != {self._internal_output_idx}" - def _define_tensor_sizes(self, config: DotConfig) -> None: + def 
_define_tensor_sizes(self, config: DotConfig, graph_data: dict) -> None: # Define Sizes of different tensors - self.num_nodes = {name: self._graph_data[name]["coords"].shape[0] for name in self._graph_mesh_names} - self.num_node_features = { - name: 2 * self._graph_data[name]["coords"].shape[1] for name in self._graph_mesh_names - } + self.num_nodes = {name: graph_data[name]["coords"].shape[0] for name in self._graph_mesh_names} + self.num_node_features = {name: 2 * graph_data[name]["coords"].shape[1] for name in self._graph_mesh_names} self.num_trainable_params = { - name: config.model.trainable_parameters.get("data" if name != "hidden" else name, 0) + name: config.model.trainable_parameters["data" if name != "hidden" else name] for name in self._graph_mesh_names } - def _register_latlon(self, name: str) -> None: + def _create_trainable_attributes(self) -> None: + """Create all trainable attributes.""" + self.trainable_tensors = nn.ModuleDict() + for mesh in self._graph_mesh_names: + self.trainable_tensors[mesh] = TrainableTensor( + trainable_size=self.num_trainable_params[mesh], tensor_size=self.num_nodes[mesh] + ) + + def _register_latlon(self, name: str, coords: torch.Tensor) -> None: """Register lat/lon buffers. Parameters ---------- name : str Name of grid to map + coords: torch.Tensor + Coordinates of the grid """ - trainable_tensor = TrainableTensor( - trainable_size=self.num_trainable_params[name], tensor_size=self._graph_data[name]["coords"].shape[0] - ) - setattr(self, f"trainable_{name}", trainable_tensor) self.register_buffer( - f"latlons_{name}", - torch.cat( - [torch.sin(self._graph_data[name]["coords"]), torch.cos(self._graph_data[name]["coords"])], dim=-1 - ), - persistent=True, + f"latlons_{name}", torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1), persistent=True ) def _run_mapper( @@ -199,40 +203,39 @@ def _run_mapper( ) def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) -> Tensor: - batch_size, _, ensemble_size, *_ = x.shape + batch_size = x.shape[0] + ensemble_size = x.shape[2] # add data positional info (lat/lon) x_data_latent = {} - for data in self._graph_input_meshes: - x_data_latent[data] = torch.cat( + for in_mesh in self._graph_input_meshes: + x_data_latent[in_mesh] = torch.cat( ( einops.rearrange(x, "batch time ensemble grid vars -> (batch ensemble grid) (time vars)"), - getattr(self, f"trainable_{data}")(getattr(self, f"latlons_{data}"), batch_size=batch_size), + self.trainable_tensors[in_mesh](getattr(self, f"latlons_{in_mesh}"), batch_size=batch_size), ), dim=-1, # feature dimension ) - x_hidden_latent = self.trainable_hidden(self.latlons_hidden, batch_size=batch_size) + x_hidden_latent = self.trainable_tensors[self._graph_name_hidden](self.latlons_hidden, batch_size=batch_size) # get shard shapes - shard_shapes_data = {} - for data in self._graph_input_meshes: - shard_shapes_data[data] = get_shape_shards(x_data_latent[data], 0, model_comm_group) + shard_shapes_data = {name: get_shape_shards(data, 0, model_comm_group) for name, data in x_data_latent.items()} shard_shapes_hidden = get_shape_shards(x_hidden_latent, 0, model_comm_group) # Run encoders x_latents = [] - for data, encoder in self.encoders.items(): - x_data_latent[data], x_latent = self._run_mapper( + for in_data_name, encoder in self.encoders.items(): + x_data_latent[in_data_name], x_latent = self._run_mapper( encoder, - (x_data_latent[data], x_hidden_latent), + (x_data_latent[in_data_name], x_hidden_latent), batch_size=batch_size, - 
shard_shapes=(shard_shapes_data[data], shard_shapes_hidden), + shard_shapes=(shard_shapes_data[in_data_name], shard_shapes_hidden), model_comm_group=model_comm_group, ) x_latents.append(x_latent) - # TODO: This operation can be a desing choice (sum, mean, attention, ...) + # TODO: This operation can be a design choice (sum, mean, attention, ...) x_latent = torch.stack(x_latents).sum(dim=0) if len(x_latents) > 1 else x_latents[0] x_latent_proc = self.processor( @@ -247,18 +250,18 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) -> # Run decoders x_out = {} - for data, decoder in self.decoders.items(): - x_out[data] = self._run_mapper( + for out_data_name, decoder in self.decoders.items(): + x_out[out_data_name] = self._run_mapper( decoder, - (x_latent_proc, x_data_latent[data]), + (x_latent_proc, x_data_latent[out_data_name]), batch_size=batch_size, - shard_shapes=(shard_shapes_hidden, shard_shapes_data[data]), + shard_shapes=(shard_shapes_hidden, shard_shapes_data[out_data_name]), model_comm_group=model_comm_group, ) - x_out[data] = ( + x_out[out_data_name] = ( einops.rearrange( - x_out[data], + x_out[out_data_name], "(batch ensemble grid) vars -> batch ensemble grid vars", batch=batch_size, ensemble=ensemble_size, @@ -267,7 +270,8 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) -> .clone() ) - # residual connection (just for the prognostic variables) - if data in self._graph_input_meshes: - x_out[data][..., self._internal_output_idx] += x[:, -1, :, :, self._internal_input_idx] + if out_data_name in self._graph_input_meshes: # check if the mesh is in the input meshes + # residual connection (just for the prognostic variables) + x_out[out_data_name][..., self._internal_output_idx] += x[:, -1, :, :, self._internal_input_idx] + return x_out[self._graph_output_meshes[0]] From c84924dbafc7e77ae2a5fac4e27dab2f46fb4521 Mon Sep 17 00:00:00 2001 From: Matthew Chantry Date: Thu, 6 Jun 2024 12:43:49 +0000 Subject: [PATCH 6/7] Remove unnecessary anemoi-datasets dep --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 2282b95f..7abfa130 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,6 @@ dependencies = [ "torch-geometric==2.4", "einops==0.6.1", "hydra-core==1.3", - "anemoi-datasets==0.2.1", "anemoi-utils==0.1.9", ] From dc0dc7d38089d4c65f2e402b3e0970dcd0667c1b Mon Sep 17 00:00:00 2001 From: Matthew Chantry Date: Thu, 6 Jun 2024 12:50:35 +0000 Subject: [PATCH 7/7] Relax anemoi-utils version dependency --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 7abfa130..0dd90b01 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ dependencies = [ "torch-geometric==2.4", "einops==0.6.1", "hydra-core==1.3", - "anemoi-utils==0.1.9", + "anemoi-utils>=0.1.9", ] [project.optional-dependencies]
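
For reference, below is a minimal sketch of the plain-dict graph_data layout that this patch series moves AnemoiModelEncProcDec and the mappers onto, in place of torch_geometric's HeteroData. The mesh names ("data", "hidden") and the node/edge counts are assumptions chosen for illustration; only the key structure ("coords" on node entries, "edge_attr"/"edge_index" on (src, "to", dst) entries) is taken from the diff.

# Illustrative sketch only -- not part of the patches above.
# Node-set entries are keyed by mesh name and hold lat/lon coordinates in radians
# (the model builds sin/cos features from "coords"); edge-set entries are keyed by
# (source_mesh, "to", destination_mesh) tuples and hold "edge_attr" and "edge_index".
import torch

num_data, num_hidden, attr_dim = 40320, 10242, 128   # assumed sizes
n_enc, n_proc, n_dec = 3141, 2718, 3141              # assumed edge counts

graph_data = {
    # node sets: one entry per mesh name
    "data": {"coords": torch.rand(num_data, 2)},
    "hidden": {"coords": torch.rand(num_hidden, 2)},
    # encoder edges: data -> hidden (row 0 of edge_index = source nodes, row 1 = destination nodes)
    ("data", "to", "hidden"): {
        "edge_attr": torch.rand(n_enc, attr_dim),
        "edge_index": torch.stack(
            [torch.randint(0, num_data, (n_enc,)), torch.randint(0, num_hidden, (n_enc,))]
        ),
    },
    # processor edges: hidden -> hidden
    ("hidden", "to", "hidden"): {
        "edge_attr": torch.rand(n_proc, attr_dim),
        "edge_index": torch.randint(0, num_hidden, (2, n_proc)),
    },
    # decoder edges: hidden -> data
    ("hidden", "to", "data"): {
        "edge_attr": torch.rand(n_dec, attr_dim),
        "edge_index": torch.stack(
            [torch.randint(0, num_hidden, (n_dec,)), torch.randint(0, num_data, (n_dec,))]
        ),
    },
}

# With config.graphs.hidden_mesh.name == "hidden", the refactored model derives:
#   _graph_mesh_names    -> ["data", "hidden"]   (string keys = node sets)
#   _graph_input_meshes  -> ["data"]             (tuple keys ending at the hidden mesh)
#   _graph_output_meshes -> ["data"]             (tuple keys starting at the hidden mesh)
# and instantiates one encoder per input mesh and one decoder per output mesh.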