diff --git a/CHANGELOG.md b/CHANGELOG.md index a77dd6fa..661ed936 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,18 @@ Keep it human-readable, your future self will thank you! ### Removed +## 0.2.0 + +### Added + +- Option to choose the edge attributes + +### Changed + +- Updated to support new PyTorch Geometric HeteroData structure (defined by `anemoi-graphs` package). + +### Removed + ## 0.1.0 Initial Release ### Added diff --git a/src/anemoi/models/layers/mapper.py b/src/anemoi/models/layers/mapper.py index 9f5f90bf..04efdf0a 100644 --- a/src/anemoi/models/layers/mapper.py +++ b/src/anemoi/models/layers/mapper.py @@ -7,6 +7,7 @@ # nor does it submit to any jurisdiction. # +import logging from abc import ABC from typing import Optional @@ -30,6 +31,8 @@ from anemoi.models.layers.graph import TrainableTensor from anemoi.models.layers.mlp import MLP +LOGGER = logging.getLogger(__name__) + class BaseMapper(nn.Module, ABC): """Base Mapper from souce dimension to destination dimension.""" @@ -113,13 +116,17 @@ def pre_process(self, x, shard_shapes, model_comm_group=None): class GraphEdgeMixin: - def _register_edges(self, sub_graph: HeteroData, src_size: int, dst_size: int, trainable_size: int) -> None: + def _register_edges( + self, sub_graph: HeteroData, edge_attributes: list[str], src_size: int, dst_size: int, trainable_size: int + ) -> None: """Register edge dim, attr, index_base, and increment. Parameters ---------- sub_graph : HeteroData Sub graph of the full structure + edge_attributes : list[str] + Edge attributes to use. src_size : int Source size dst_size : int @@ -127,8 +134,13 @@ def _register_edges(self, sub_graph: HeteroData, src_size: int, dst_size: int, t trainable_size : int Trainable tensor size """ - self.edge_dim = sub_graph.edge_attr.shape[1] + trainable_size - self.register_buffer("edge_attr", sub_graph.edge_attr, persistent=False) + if edge_attributes is None: + raise ValueError("Edge attributes must be provided") + + edge_attr_tensor = torch.cat([sub_graph[attr] for attr in edge_attributes], axis=1) + + self.edge_dim = edge_attr_tensor.shape[1] + trainable_size + self.register_buffer("edge_attr", edge_attr_tensor, persistent=False) self.register_buffer("edge_index_base", sub_graph.edge_index, persistent=False) self.register_buffer( "edge_inc", torch.from_numpy(np.asarray([[src_size], [dst_size]], dtype=np.int64)), persistent=True @@ -174,6 +186,7 @@ def __init__( num_heads: int = 16, mlp_hidden_ratio: int = 4, sub_graph: Optional[HeteroData] = None, + sub_graph_edge_attributes: Optional[list[str]] = None, src_grid_size: int = 0, dst_grid_size: int = 0, ) -> None: @@ -210,7 +223,7 @@ def __init__( activation=activation, ) - self._register_edges(sub_graph, src_grid_size, dst_grid_size, trainable_size) + self._register_edges(sub_graph, sub_graph_edge_attributes, src_grid_size, dst_grid_size, trainable_size) self.trainable = TrainableTensor(trainable_size=trainable_size, tensor_size=self.edge_attr.shape[0]) @@ -274,6 +287,7 @@ def __init__( num_heads: int = 16, mlp_hidden_ratio: int = 4, sub_graph: Optional[HeteroData] = None, + sub_graph_edge_attributes: Optional[list[str]] = None, src_grid_size: int = 0, dst_grid_size: int = 0, ) -> None: @@ -312,6 +326,7 @@ def __init__( num_heads=num_heads, mlp_hidden_ratio=mlp_hidden_ratio, sub_graph=sub_graph, + sub_graph_edge_attributes=sub_graph_edge_attributes, src_grid_size=src_grid_size, dst_grid_size=dst_grid_size, ) @@ -345,6 +360,7 @@ def __init__( num_heads: int = 16, mlp_hidden_ratio: int = 4, sub_graph: Optional[HeteroData] = None, + sub_graph_edge_attributes: Optional[list[str]] = None, src_grid_size: int = 0, dst_grid_size: int = 0, ) -> None: @@ -383,6 +399,7 @@ def __init__( num_heads=num_heads, mlp_hidden_ratio=mlp_hidden_ratio, sub_graph=sub_graph, + sub_graph_edge_attributes=sub_graph_edge_attributes, src_grid_size=src_grid_size, dst_grid_size=dst_grid_size, ) @@ -415,6 +432,7 @@ def __init__( activation: str = "SiLU", mlp_extra_layers: int = 0, sub_graph: Optional[HeteroData] = None, + sub_graph_edge_attributes: Optional[list[str]] = None, src_grid_size: int = 0, dst_grid_size: int = 0, ) -> None: @@ -451,7 +469,7 @@ def __init__( activation=activation, ) - self._register_edges(sub_graph, src_grid_size, dst_grid_size, trainable_size) + self._register_edges(sub_graph, sub_graph_edge_attributes, src_grid_size, dst_grid_size, trainable_size) self.emb_edges = MLP( in_features=self.edge_dim, @@ -518,6 +536,7 @@ def __init__( activation: str = "SiLU", mlp_extra_layers: int = 0, sub_graph: Optional[HeteroData] = None, + sub_graph_edge_attributes: Optional[list[str]] = None, src_grid_size: int = 0, dst_grid_size: int = 0, ) -> None: @@ -555,6 +574,7 @@ def __init__( activation, mlp_extra_layers, sub_graph=sub_graph, + sub_graph_edge_attributes=sub_graph_edge_attributes, src_grid_size=src_grid_size, dst_grid_size=dst_grid_size, ) @@ -602,6 +622,7 @@ def __init__( activation: str = "SiLU", mlp_extra_layers: int = 0, sub_graph: Optional[HeteroData] = None, + sub_graph_edge_attributes: Optional[list[str]] = None, src_grid_size: int = 0, dst_grid_size: int = 0, ) -> None: @@ -639,6 +660,7 @@ def __init__( activation=activation, mlp_extra_layers=mlp_extra_layers, sub_graph=sub_graph, + sub_graph_edge_attributes=sub_graph_edge_attributes, src_grid_size=src_grid_size, dst_grid_size=dst_grid_size, ) diff --git a/src/anemoi/models/layers/processor.py b/src/anemoi/models/layers/processor.py index 39a6f24a..bb336091 100644 --- a/src/anemoi/models/layers/processor.py +++ b/src/anemoi/models/layers/processor.py @@ -171,6 +171,7 @@ def __init__( activation: str = "SiLU", cpu_offload: bool = False, sub_graph: Optional[HeteroData] = None, + sub_graph_edge_attributes: Optional[list[str]] = None, src_grid_size: int = 0, dst_grid_size: int = 0, **kwargs, @@ -201,7 +202,7 @@ def __init__( mlp_extra_layers=mlp_extra_layers, ) - self._register_edges(sub_graph, src_grid_size, dst_grid_size, trainable_size) + self._register_edges(sub_graph, sub_graph_edge_attributes, src_grid_size, dst_grid_size, trainable_size) self.trainable = TrainableTensor(trainable_size=trainable_size, tensor_size=self.edge_attr.shape[0]) @@ -258,6 +259,7 @@ def __init__( activation: str = "GELU", cpu_offload: bool = False, sub_graph: Optional[HeteroData] = None, + sub_graph_edge_attributes: Optional[list[str]] = None, src_grid_size: int = 0, dst_grid_size: int = 0, **kwargs, @@ -291,7 +293,7 @@ def __init__( mlp_hidden_ratio=mlp_hidden_ratio, ) - self._register_edges(sub_graph, src_grid_size, dst_grid_size, trainable_size) + self._register_edges(sub_graph, sub_graph_edge_attributes, src_grid_size, dst_grid_size, trainable_size) self.trainable = TrainableTensor(trainable_size=trainable_size, tensor_size=self.edge_attr.shape[0]) diff --git a/src/anemoi/models/models/encoder_processor_decoder.py b/src/anemoi/models/models/encoder_processor_decoder.py index 633ea6a4..0f374742 100644 --- a/src/anemoi/models/models/encoder_processor_decoder.py +++ b/src/anemoi/models/models/encoder_processor_decoder.py @@ -42,6 +42,8 @@ def __init__( ---------- config : DotDict Job configuration + data_indices : dict + Data indices graph_data : HeteroData Graph definition """ @@ -61,7 +63,7 @@ def __init__( # Create trainable tensors self._create_trainable_attributes() - # Register lat/lon + # Register lat/lon of nodes self._register_latlon("data", self._graph_name_data) self._register_latlon("hidden", self._graph_name_hidden) @@ -120,38 +122,25 @@ def _assert_matching_indices(self, data_indices: dict) -> None: ), f"Model indices must match {self._internal_input_idx} != {self._internal_output_idx}" def _define_tensor_sizes(self, config: DotDict) -> None: - # Define Sizes of different tensors - self._data_grid_size = self._graph_data[(self._graph_name_data, "to", self._graph_name_data)].ecoords_rad.shape[ - 0 - ] - self._hidden_grid_size = self._graph_data[ - (self._graph_name_hidden, "to", self._graph_name_hidden) - ].hcoords_rad.shape[0] + self._data_grid_size = self._graph_data[self._graph_name_data].num_nodes + self._hidden_grid_size = self._graph_data[self._graph_name_hidden].num_nodes self.trainable_data_size = config.model.trainable_parameters.data self.trainable_hidden_size = config.model.trainable_parameters.hidden - def _register_latlon(self, name: str, key: str) -> None: + def _register_latlon(self, name: str, nodes: str) -> None: """Register lat/lon buffers. Parameters ---------- name : str - Name of grid to map - key : str - Key of the grid + Name to store the lat-lon coordinates of the nodes. + nodes : str + Name of nodes to map """ - self.register_buffer( - f"latlons_{name}", - torch.cat( - [ - torch.sin(self._graph_data[(key, "to", key)][f"{key[:1]}coords_rad"]), - torch.cos(self._graph_data[(key, "to", key)][f"{key[:1]}coords_rad"]), - ], - dim=-1, - ), - persistent=True, - ) + coords = self._graph_data[nodes].x + sin_cos_coords = torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + self.register_buffer(f"latlons_{name}", sin_cos_coords, persistent=True) def _create_trainable_attributes(self) -> None: """Create all trainable attributes.""" diff --git a/tests/layers/mapper/test_base_mapper.py b/tests/layers/mapper/test_base_mapper.py index 5b82b658..3cc4ef0f 100644 --- a/tests/layers/mapper/test_base_mapper.py +++ b/tests/layers/mapper/test_base_mapper.py @@ -13,6 +13,12 @@ class TestBaseMapper: + """Test the BaseMapper class.""" + + NUM_EDGES: int = 100 + NUM_SRC_NODES: int = 100 + NUM_DST_NODES: int = 200 + @pytest.fixture def mapper_init(self): in_channels_src: int = 3 @@ -50,7 +56,8 @@ def mapper(self, mapper_init, fake_graph): out_channels_dst=out_channels_dst, cpu_offload=cpu_offload, activation=activation, - sub_graph=fake_graph, + sub_graph=fake_graph[("src", "to", "dst")], + sub_graph_edge_attributes=["edge_attr1", "edge_attr2"], trainable_size=trainable_size, ) @@ -71,10 +78,18 @@ def pair_tensor(self, mapper_init): ) @pytest.fixture - def fake_graph(self): + def fake_graph(self) -> HeteroData: + """Fake graph.""" graph = HeteroData() - graph.edge_attr = torch.rand((100, 128)) - graph.edge_index = torch.randint(0, 100, (2, 100)) + graph[("src", "to", "dst")].edge_index = torch.concat( + [ + torch.randint(0, self.NUM_SRC_NODES, (1, self.NUM_EDGES)), + torch.randint(0, self.NUM_DST_NODES, (1, self.NUM_EDGES)), + ], + axis=0, + ) + graph[("src", "to", "dst")].edge_attr1 = torch.rand((self.NUM_EDGES, 1)) + graph[("src", "to", "dst")].edge_attr2 = torch.rand((self.NUM_EDGES, 32)) return graph def test_initialization(self, mapper, mapper_init): diff --git a/tests/layers/mapper/test_graphconv_mapper.py b/tests/layers/mapper/test_graphconv_mapper.py index 480a4948..4be7130c 100644 --- a/tests/layers/mapper/test_graphconv_mapper.py +++ b/tests/layers/mapper/test_graphconv_mapper.py @@ -16,15 +16,18 @@ class TestGNNBaseMapper: - BIG_GRID_SIZE = 1000 - GRID_SIZE = 100 + """Test the GNNBaseMapper class.""" + + NUM_SRC_NODES: int = 200 + NUM_DST_NODES: int = 178 + NUM_EDGES: int = 300 @pytest.fixture def mapper_init(self): in_channels_src: int = 3 in_channels_dst: int = 4 hidden_dim: int = 256 - out_channels_dst: int = 5 + out_channels_dst: int = 8 cpu_offload: bool = False activation: str = "SiLU" trainable_size: int = 6 @@ -56,7 +59,8 @@ def mapper(self, mapper_init, fake_graph): out_channels_dst=out_channels_dst, cpu_offload=cpu_offload, activation=activation, - sub_graph=fake_graph, + sub_graph=fake_graph[("src", "to", "dst")], + sub_graph_edge_attributes=["edge_attr1", "edge_attr2"], trainable_size=trainable_size, ) @@ -72,15 +76,23 @@ def pair_tensor(self, mapper_init): _trainable_size, ) = mapper_init return ( - torch.rand(self.BIG_GRID_SIZE, in_channels_src), - torch.rand(self.GRID_SIZE, in_channels_dst), + torch.rand(self.NUM_SRC_NODES, in_channels_src), + torch.rand(self.NUM_DST_NODES, in_channels_dst), ) @pytest.fixture - def fake_graph(self): + def fake_graph(self) -> HeteroData: + """Fake graph.""" graph = HeteroData() - graph.edge_attr = torch.rand((self.GRID_SIZE, 128)) - graph.edge_index = torch.randint(0, self.GRID_SIZE, (2, self.GRID_SIZE)) + graph[("src", "to", "dst")].edge_index = torch.concat( + [ + torch.randint(0, self.NUM_SRC_NODES, (1, self.NUM_EDGES)), + torch.randint(0, self.NUM_DST_NODES, (1, self.NUM_EDGES)), + ], + axis=0, + ) + graph[("src", "to", "dst")].edge_attr1 = torch.rand((self.NUM_EDGES, 1)) + graph[("src", "to", "dst")].edge_attr2 = torch.rand((self.NUM_EDGES, 32)) return graph def test_initialization(self, mapper, mapper_init): @@ -138,6 +150,8 @@ def test_post_process(self, mapper, pair_tensor): class TestGNNForwardMapper(TestGNNBaseMapper): + """Test the GNNForwardMapper class.""" + @pytest.fixture def mapper(self, mapper_init, fake_graph): ( @@ -156,7 +170,8 @@ def mapper(self, mapper_init, fake_graph): out_channels_dst=out_channels_dst, cpu_offload=cpu_offload, activation=activation, - sub_graph=fake_graph, + sub_graph=fake_graph[("src", "to", "dst")], + sub_graph_edge_attributes=["edge_attr1", "edge_attr2"], trainable_size=trainable_size, ) @@ -174,16 +189,16 @@ def test_pre_process(self, mapper, mapper_init, pair_tensor): shard_shapes = [list(x[0].shape)], [list(x[1].shape)] x_src, x_dst, shapes_src, shapes_dst = mapper.pre_process(x, shard_shapes) - assert x_src.shape == torch.Size([self.BIG_GRID_SIZE, hidden_dim]), ( + assert x_src.shape == torch.Size([self.NUM_SRC_NODES, hidden_dim]), ( f"x_src.shape ({x_src.shape}) != torch.Size" - f"([self.BIG_GRID_SIZE, hidden_dim]) ({torch.Size([self.BIG_GRID_SIZE, hidden_dim])})" + f"([self.NUM_SRC_NODES, hidden_dim]) ({torch.Size([self.NUM_SRC_NODES, hidden_dim])})" ) - assert x_dst.shape == torch.Size([self.GRID_SIZE, hidden_dim]), ( + assert x_dst.shape == torch.Size([self.NUM_DST_NODES, hidden_dim]), ( f"x_dst.shape ({x_dst.shape}) != torch.Size" - "([self.GRID_SIZE, hidden_dim]) ({torch.Size([self.GRID_SIZE, hidden_dim])})" + "([self.NUM_DST_NODES, hidden_dim]) ({torch.Size([self.NUM_DST_NODES, hidden_dim])})" ) - assert shapes_src == [[self.BIG_GRID_SIZE, hidden_dim]] - assert shapes_dst == [[self.GRID_SIZE, hidden_dim]] + assert shapes_src == [[self.NUM_SRC_NODES, hidden_dim]] + assert shapes_dst == [[self.NUM_DST_NODES, hidden_dim]] def test_forward_backward(self, mapper_init, mapper, pair_tensor): ( @@ -200,11 +215,11 @@ def test_forward_backward(self, mapper_init, mapper, pair_tensor): shard_shapes = [list(x[0].shape)], [list(x[1].shape)] x_src, x_dst = mapper.forward(x, batch_size, shard_shapes) - assert x_src.shape == torch.Size([self.BIG_GRID_SIZE, hidden_dim]) - assert x_dst.shape == torch.Size([self.GRID_SIZE, hidden_dim]) + assert x_src.shape == torch.Size([self.NUM_SRC_NODES, hidden_dim]) + assert x_dst.shape == torch.Size([self.NUM_DST_NODES, hidden_dim]) # Dummy loss - target = torch.rand(self.GRID_SIZE, hidden_dim) + target = torch.rand(self.NUM_DST_NODES, hidden_dim) loss_fn = nn.MSELoss() loss = loss_fn(x_dst, target) @@ -226,6 +241,8 @@ def test_forward_backward(self, mapper_init, mapper, pair_tensor): class TestGNNBackwardMapper(TestGNNBaseMapper): + """Test the GNNBackwardMapper class.""" + @pytest.fixture def mapper(self, mapper_init, fake_graph): ( @@ -244,7 +261,8 @@ def mapper(self, mapper_init, fake_graph): out_channels_dst=out_channels_dst, cpu_offload=cpu_offload, activation=activation, - sub_graph=fake_graph, + sub_graph=fake_graph[("src", "to", "dst")], + sub_graph_edge_attributes=["edge_attr1", "edge_attr2"], trainable_size=trainable_size, ) @@ -262,16 +280,16 @@ def test_pre_process(self, mapper, mapper_init, pair_tensor): shard_shapes = [list(x[0].shape)], [list(x[1].shape)] x_src, x_dst, shapes_src, shapes_dst = mapper.pre_process(x, shard_shapes) - assert x_src.shape == torch.Size([self.BIG_GRID_SIZE, in_channels_src]), ( + assert x_src.shape == torch.Size([self.NUM_SRC_NODES, in_channels_src]), ( f"x_src.shape ({x_src.shape}) != torch.Size" - f"([self.BIG_GRID_SIZE, in_channels_src]) ({torch.Size([self.BIG_GRID_SIZE, in_channels_src])})" + f"([self.NUM_SRC_NODES, in_channels_src]) ({torch.Size([self.NUM_SRC_NODES, in_channels_src])})" ) - assert x_dst.shape == torch.Size([self.GRID_SIZE, in_channels_dst]), ( + assert x_dst.shape == torch.Size([self.NUM_DST_NODES, in_channels_dst]), ( f"x_dst.shape ({x_dst.shape}) != torch.Size" - f"([self.GRID_SIZE, in_channels_dst]) ({torch.Size([self.GRID_SIZE, in_channels_dst])})" + f"([self.NUM_DST_NODES, in_channels_dst]) ({torch.Size([self.NUM_DST_NODES, in_channels_dst])})" ) - assert shapes_src == [[self.BIG_GRID_SIZE, hidden_dim]] - assert shapes_dst == [[self.GRID_SIZE, hidden_dim]] + assert shapes_src == [[self.NUM_SRC_NODES, hidden_dim]] + assert shapes_dst == [[self.NUM_DST_NODES, hidden_dim]] def test_post_process(self, mapper, mapper_init): ( @@ -283,13 +301,13 @@ def test_post_process(self, mapper, mapper_init): _activation, _trainable_size, ) = mapper_init - x_dst = torch.rand(self.GRID_SIZE, hidden_dim) + x_dst = torch.rand(self.NUM_DST_NODES, hidden_dim) shapes_dst = [list(x_dst.shape)] result = mapper.post_process(x_dst, shapes_dst) assert ( - torch.Size([self.GRID_SIZE, out_channels_dst]) == result.shape - ), f"[self.GRID_SIZE, out_channels_dst] ({[self.GRID_SIZE, out_channels_dst]}) != result.shape ({result.shape})" + torch.Size([self.NUM_DST_NODES, out_channels_dst]) == result.shape + ), f"[self.NUM_DST_NODES, out_channels_dst] ({[self.NUM_DST_NODES, out_channels_dst]}) != result.shape ({result.shape})" def test_forward_backward(self, mapper_init, mapper, pair_tensor): ( @@ -306,15 +324,15 @@ def test_forward_backward(self, mapper_init, mapper, pair_tensor): batch_size = 1 x = ( - torch.rand(self.BIG_GRID_SIZE, hidden_dim), - torch.rand(self.GRID_SIZE, hidden_dim), + torch.rand(self.NUM_SRC_NODES, hidden_dim), + torch.rand(self.NUM_DST_NODES, hidden_dim), ) result = mapper.forward(x, batch_size, shard_shapes) - assert result.shape == torch.Size([self.GRID_SIZE, out_channels_dst]) + assert result.shape == torch.Size([self.NUM_DST_NODES, out_channels_dst]) # Dummy loss - target = torch.rand(self.GRID_SIZE, out_channels_dst) + target = torch.rand(self.NUM_DST_NODES, out_channels_dst) loss_fn = nn.MSELoss() loss = loss_fn(result, target) diff --git a/tests/layers/mapper/test_graphtransformer_mapper.py b/tests/layers/mapper/test_graphtransformer_mapper.py index 7fd0fc0c..c872422d 100644 --- a/tests/layers/mapper/test_graphtransformer_mapper.py +++ b/tests/layers/mapper/test_graphtransformer_mapper.py @@ -16,8 +16,11 @@ class TestGraphTransformerBaseMapper: - BIG_GRID_SIZE = 1000 - GRID_SIZE = 100 + """Test the GraphTransformerBaseMapper class.""" + + NUM_EDGES: int = 150 + NUM_SRC_NODES: int = 100 + NUM_DST_NODES: int = 200 @pytest.fixture def mapper_init(self): @@ -62,7 +65,8 @@ def mapper(self, mapper_init, fake_graph): out_channels_dst=out_channels_dst, cpu_offload=cpu_offload, activation=activation, - sub_graph=fake_graph, + sub_graph=fake_graph[("src", "to", "dst")], + sub_graph_edge_attributes=["edge_attr1", "edge_attr2"], trainable_size=trainable_size, num_heads=num_heads, mlp_hidden_ratio=mlp_hidden_ratio, @@ -82,15 +86,23 @@ def pair_tensor(self, mapper_init): _mlp_hidden_ratio, ) = mapper_init return ( - torch.rand(self.BIG_GRID_SIZE, in_channels_src), - torch.rand(self.GRID_SIZE, in_channels_dst), + torch.rand(self.NUM_SRC_NODES, in_channels_src), + torch.rand(self.NUM_DST_NODES, in_channels_dst), ) @pytest.fixture - def fake_graph(self): + def fake_graph(self) -> HeteroData: + """Fake graph.""" graph = HeteroData() - graph.edge_attr = torch.rand((self.GRID_SIZE, 128)) - graph.edge_index = torch.randint(0, self.GRID_SIZE, (2, self.GRID_SIZE)) + graph[("src", "to", "dst")].edge_index = torch.concat( + [ + torch.randint(0, self.NUM_SRC_NODES, (1, self.NUM_EDGES)), + torch.randint(0, self.NUM_DST_NODES, (1, self.NUM_EDGES)), + ], + axis=0, + ) + graph[("src", "to", "dst")].edge_attr1 = torch.rand((self.NUM_EDGES, 1)) + graph[("src", "to", "dst")].edge_attr2 = torch.rand((self.NUM_EDGES, 32)) return graph def test_initialization(self, mapper, mapper_init): @@ -152,6 +164,8 @@ def test_post_process(self, mapper, pair_tensor): class TestGraphTransformerForwardMapper(TestGraphTransformerBaseMapper): + """Test the GraphTransformerForwardMapper class.""" + @pytest.fixture def mapper(self, mapper_init, fake_graph): ( @@ -172,7 +186,8 @@ def mapper(self, mapper_init, fake_graph): out_channels_dst=out_channels_dst, cpu_offload=cpu_offload, activation=activation, - sub_graph=fake_graph, + sub_graph=fake_graph[("src", "to", "dst")], + sub_graph_edge_attributes=["edge_attr1", "edge_attr2"], trainable_size=trainable_size, num_heads=num_heads, mlp_hidden_ratio=mlp_hidden_ratio, @@ -194,16 +209,16 @@ def test_pre_process(self, mapper, mapper_init, pair_tensor): shard_shapes = [list(x[0].shape)], [list(x[1].shape)] x_src, x_dst, shapes_src, shapes_dst = mapper.pre_process(x, shard_shapes) - assert x_src.shape == torch.Size([self.BIG_GRID_SIZE, hidden_dim]), ( + assert x_src.shape == torch.Size([self.NUM_SRC_NODES, hidden_dim]), ( f"x_src.shape ({x_src.shape}) != torch.Size" - f"([self.BIG_GRID_SIZE, hidden_dim]) ({torch.Size([self.BIG_GRID_SIZE, hidden_dim])})" + f"([self.NUM_SRC_NODES, hidden_dim]) ({torch.Size([self.NUM_SRC_NODES, hidden_dim])})" ) - assert x_dst.shape == torch.Size([self.GRID_SIZE, hidden_dim]), ( + assert x_dst.shape == torch.Size([self.NUM_DST_NODES, hidden_dim]), ( f"x_dst.shape ({x_dst.shape}) != torch.Size" - "([self.GRID_SIZE, hidden_dim]) ({torch.Size([self.GRID_SIZE, hidden_dim])})" + "([self.NUM_DST_NODES, hidden_dim]) ({torch.Size([self.NUM_DST_NODES, hidden_dim])})" ) - assert shapes_src == [[self.BIG_GRID_SIZE, hidden_dim]] - assert shapes_dst == [[self.GRID_SIZE, hidden_dim]] + assert shapes_src == [[self.NUM_SRC_NODES, hidden_dim]] + assert shapes_dst == [[self.NUM_DST_NODES, hidden_dim]] def test_forward_backward(self, mapper_init, mapper, pair_tensor): ( @@ -222,11 +237,11 @@ def test_forward_backward(self, mapper_init, mapper, pair_tensor): shard_shapes = [list(x[0].shape)], [list(x[1].shape)] x_src, x_dst = mapper.forward(x, batch_size, shard_shapes) - assert x_src.shape == torch.Size([self.BIG_GRID_SIZE, in_channels_src]) - assert x_dst.shape == torch.Size([self.GRID_SIZE, hidden_dim]) + assert x_src.shape == torch.Size([self.NUM_SRC_NODES, in_channels_src]) + assert x_dst.shape == torch.Size([self.NUM_DST_NODES, hidden_dim]) # Dummy loss - target = torch.rand(self.GRID_SIZE, hidden_dim) + target = torch.rand(self.NUM_DST_NODES, hidden_dim) loss_fn = nn.MSELoss() loss = loss_fn(x_dst, target) @@ -248,6 +263,8 @@ def test_forward_backward(self, mapper_init, mapper, pair_tensor): class TestGraphTransformerBackwardMapper(TestGraphTransformerBaseMapper): + """Test the GraphTransformerBackwardMapper class.""" + @pytest.fixture def mapper(self, mapper_init, fake_graph): ( @@ -268,7 +285,8 @@ def mapper(self, mapper_init, fake_graph): out_channels_dst=out_channels_dst, cpu_offload=cpu_offload, activation=activation, - sub_graph=fake_graph, + sub_graph=fake_graph[("src", "to", "dst")], + sub_graph_edge_attributes=["edge_attr1", "edge_attr2"], trainable_size=trainable_size, ) @@ -288,16 +306,16 @@ def test_pre_process(self, mapper, mapper_init, pair_tensor): shard_shapes = [list(x[0].shape)], [list(x[1].shape)] x_src, x_dst, shapes_src, shapes_dst = mapper.pre_process(x, shard_shapes) - assert x_src.shape == torch.Size([self.BIG_GRID_SIZE, in_channels_src]), ( + assert x_src.shape == torch.Size([self.NUM_SRC_NODES, in_channels_src]), ( f"x_src.shape ({x_src.shape}) != torch.Size" - f"([self.BIG_GRID_SIZE, in_channels_src]) ({torch.Size([self.BIG_GRID_SIZE, in_channels_src])})" + f"([self.NUM_SRC_NODES, in_channels_src]) ({torch.Size([self.NUM_SRC_NODES, in_channels_src])})" ) - assert x_dst.shape == torch.Size([self.GRID_SIZE, hidden_dim]), ( + assert x_dst.shape == torch.Size([self.NUM_DST_NODES, hidden_dim]), ( f"x_dst.shape ({x_dst.shape}) != torch.Size" - f"([self.GRID_SIZE, hidden_dim]) ({torch.Size([self.GRID_SIZE, hidden_dim])})" + f"([self.NUM_DST_NODES, hidden_dim]) ({torch.Size([self.NUM_DST_NODES, hidden_dim])})" ) - assert shapes_src == [[self.BIG_GRID_SIZE, hidden_dim]] - assert shapes_dst == [[self.GRID_SIZE, hidden_dim]] + assert shapes_src == [[self.NUM_SRC_NODES, hidden_dim]] + assert shapes_dst == [[self.NUM_DST_NODES, hidden_dim]] def test_post_process(self, mapper, mapper_init): ( @@ -311,13 +329,13 @@ def test_post_process(self, mapper, mapper_init): _num_heads, _mlp_hidden_ratio, ) = mapper_init - x_dst = torch.rand(self.GRID_SIZE, hidden_dim) + x_dst = torch.rand(self.NUM_DST_NODES, hidden_dim) shapes_dst = [list(x_dst.shape)] result = mapper.post_process(x_dst, shapes_dst) assert ( - torch.Size([self.GRID_SIZE, out_channels_dst]) == result.shape - ), f"[self.GRID_SIZE, out_channels_dst] ({[self.GRID_SIZE, out_channels_dst]}) != result.shape ({result.shape})" + torch.Size([self.NUM_DST_NODES, out_channels_dst]) == result.shape + ), f"[self.NUM_DST_NODES, out_channels_dst] ({[self.NUM_DST_NODES, out_channels_dst]}) != result.shape ({result.shape})" def test_forward_backward(self, mapper_init, mapper, pair_tensor): ( @@ -337,15 +355,15 @@ def test_forward_backward(self, mapper_init, mapper, pair_tensor): # Different size for x_dst, as the Backward mapper changes the channels in shape in pre-processor x = ( - torch.rand(self.BIG_GRID_SIZE, hidden_dim), - torch.rand(self.GRID_SIZE, in_channels_src), + torch.rand(self.NUM_SRC_NODES, hidden_dim), + torch.rand(self.NUM_DST_NODES, in_channels_src), ) result = mapper.forward(x, batch_size, shard_shapes) - assert result.shape == torch.Size([self.GRID_SIZE, out_channels_dst]) + assert result.shape == torch.Size([self.NUM_DST_NODES, out_channels_dst]) # Dummy loss - target = torch.rand(self.GRID_SIZE, out_channels_dst) + target = torch.rand(self.NUM_DST_NODES, out_channels_dst) loss_fn = nn.MSELoss() loss = loss_fn(result, target) diff --git a/tests/layers/processor/test_graphconv_processor.py b/tests/layers/processor/test_graphconv_processor.py index 569319be..2505515c 100644 --- a/tests/layers/processor/test_graphconv_processor.py +++ b/tests/layers/processor/test_graphconv_processor.py @@ -13,129 +13,135 @@ from anemoi.models.layers.processor import GNNProcessor -@pytest.fixture -def fake_graph(): - graph = HeteroData() - graph.edge_attr = torch.rand((100, 128)) - graph.edge_index = torch.randint(0, 100, (2, 100)) - return graph - - -@pytest.fixture -def graphconv_init(fake_graph): - num_layers = 2 - num_channels = 128 - num_chunks = 2 - mlp_extra_layers = 0 - activation = "SiLU" - cpu_offload = False - sub_graph = fake_graph - src_grid_size = 0 - dst_grid_size = 0 - trainable_size = 8 - return ( - num_layers, - num_channels, - num_chunks, - mlp_extra_layers, - activation, - cpu_offload, - sub_graph, - src_grid_size, - dst_grid_size, - trainable_size, - ) - - -@pytest.fixture -def graphconv_processor(graphconv_init): - ( - num_layers, - num_channels, - num_chunks, - mlp_extra_layers, - activation, - cpu_offload, - sub_graph, - src_grid_size, - dst_grid_size, - trainable_size, - ) = graphconv_init - return GNNProcessor( - num_layers, - num_channels=num_channels, - num_chunks=num_chunks, - mlp_extra_layers=mlp_extra_layers, - activation=activation, - cpu_offload=cpu_offload, - sub_graph=sub_graph, - src_grid_size=src_grid_size, - dst_grid_size=dst_grid_size, - trainable_size=trainable_size, - ) - - -def test_graphconv_processor_init(graphconv_processor, graphconv_init): - ( - num_layers, - num_channels, - num_chunks, - _mlp_extra_layers, - _activation, - _cpu_offload, - _sub_graph, - _src_grid_size, - _dst_grid_size, - _trainable_size, - ) = graphconv_init - assert graphconv_processor.num_chunks == num_chunks - assert graphconv_processor.num_channels == num_channels - assert graphconv_processor.chunk_size == num_layers // num_chunks - assert isinstance(graphconv_processor.trainable, TrainableTensor) - - -def test_forward(graphconv_processor, graphconv_init): - gridpoints = 100 - batch_size = 1 - ( - _num_layers, - num_channels, - _num_chunks, - _mlp_extra_layers, - _activation, - _cpu_offload, - _sub_graph, - _src_grid_size, - _dst_grid_size, - trainable_size, - ) = graphconv_init - x = torch.rand((gridpoints, num_channels)) - shard_shapes = [list(x.shape)] - - # Run forward pass of processor - output = graphconv_processor.forward(x, batch_size, shard_shapes) - assert output.shape == (gridpoints, num_channels) - - # Generate dummy target and loss function - loss_fn = torch.nn.MSELoss() - target = torch.rand((gridpoints, num_channels)) - loss = loss_fn(output, target) - - # Check loss - assert loss.item() >= 0 - - # Backward pass - loss.backward() - - # Check gradients of trainable tensor - assert graphconv_processor.trainable.trainable.grad.shape == ( - gridpoints, - trainable_size, - ) - - # Check gradients of processor - for param in graphconv_processor.parameters(): - assert param.grad is not None, f"param.grad is None for {param}" - assert ( - param.grad.shape == param.shape - ), f"param.grad.shape ({param.grad.shape}) != param.shape ({param.shape}) for {param}" +class TestGNNProcessor: + """Test the GNNProcessor class.""" + + NUM_NODES: int = 100 + NUM_EDGES: int = 200 + + @pytest.fixture + def fake_graph(self) -> tuple[HeteroData, int]: + graph = HeteroData() + graph["nodes"].x = torch.rand((self.NUM_NODES, 2)) + graph[("nodes", "to", "nodes")].edge_index = torch.randint(0, self.NUM_NODES, (2, self.NUM_EDGES)) + graph[("nodes", "to", "nodes")].edge_attr1 = torch.rand((self.NUM_EDGES, 3)) + graph[("nodes", "to", "nodes")].edge_attr2 = torch.rand((self.NUM_EDGES, 4)) + return graph + + @pytest.fixture + def graphconv_init(self, fake_graph: HeteroData): + num_layers = 2 + num_channels = 128 + num_chunks = 2 + mlp_extra_layers = 0 + activation = "SiLU" + cpu_offload = False + sub_graph = fake_graph[("nodes", "to", "nodes")] + edge_attributes = ["edge_attr1", "edge_attr2"] + src_grid_size = 0 + dst_grid_size = 0 + trainable_size = 8 + return ( + num_layers, + num_channels, + num_chunks, + mlp_extra_layers, + activation, + cpu_offload, + sub_graph, + edge_attributes, + src_grid_size, + dst_grid_size, + trainable_size, + ) + + @pytest.fixture + def graphconv_processor(self, graphconv_init): + ( + num_layers, + num_channels, + num_chunks, + mlp_extra_layers, + activation, + cpu_offload, + sub_graph, + edge_attributes, + src_grid_size, + dst_grid_size, + trainable_size, + ) = graphconv_init + return GNNProcessor( + num_layers, + num_channels=num_channels, + num_chunks=num_chunks, + mlp_extra_layers=mlp_extra_layers, + activation=activation, + cpu_offload=cpu_offload, + sub_graph=sub_graph, + sub_graph_edge_attributes=edge_attributes, + src_grid_size=src_grid_size, + dst_grid_size=dst_grid_size, + trainable_size=trainable_size, + ) + + def test_graphconv_processor_init(self, graphconv_processor, graphconv_init): + ( + num_layers, + num_channels, + num_chunks, + _mlp_extra_layers, + _activation, + _cpu_offload, + _sub_graph, + _edge_attributes, + _src_grid_size, + _dst_grid_size, + _trainable_size, + ) = graphconv_init + assert graphconv_processor.num_chunks == num_chunks + assert graphconv_processor.num_channels == num_channels + assert graphconv_processor.chunk_size == num_layers // num_chunks + assert isinstance(graphconv_processor.trainable, TrainableTensor) + + def test_forward(self, graphconv_processor, graphconv_init): + batch_size = 1 + ( + _num_layers, + num_channels, + _num_chunks, + _mlp_extra_layers, + _activation, + _cpu_offload, + _sub_graph, + _edge_attributes, + _src_grid_size, + _dst_grid_size, + trainable_size, + ) = graphconv_init + x = torch.rand((self.NUM_EDGES, num_channels)) + shard_shapes = [list(x.shape)] + + # Run forward pass of processor + output = graphconv_processor.forward(x, batch_size, shard_shapes) + assert output.shape == (self.NUM_EDGES, num_channels) + + # Generate dummy target and loss function + loss_fn = torch.nn.MSELoss() + target = torch.rand((self.NUM_EDGES, num_channels)) + loss = loss_fn(output, target) + + # Check loss + assert loss.item() >= 0 + + # Backward pass + loss.backward() + + # Check gradients of trainable tensor + assert graphconv_processor.trainable.trainable.grad.shape == (self.NUM_EDGES, trainable_size) + + # Check gradients of processor + for param in graphconv_processor.parameters(): + assert param.grad is not None, f"param.grad is None for {param}" + assert ( + param.grad.shape == param.shape + ), f"param.grad.shape ({param.grad.shape}) != param.shape ({param.shape}) for {param}" diff --git a/tests/layers/processor/test_graphtransformer_processor.py b/tests/layers/processor/test_graphtransformer_processor.py index 81095a2e..dfba417e 100644 --- a/tests/layers/processor/test_graphtransformer_processor.py +++ b/tests/layers/processor/test_graphtransformer_processor.py @@ -13,135 +13,144 @@ from anemoi.models.layers.processor import GraphTransformerProcessor -@pytest.fixture -def fake_graph(): - graph = HeteroData() - graph.edge_attr = torch.rand((100, 128)) - graph.edge_index = torch.randint(0, 100, (2, 100)) - return graph - - -@pytest.fixture -def graphtransformer_init(fake_graph): - num_layers = 2 - num_channels = 128 - num_chunks = 2 - num_heads = 16 - mlp_hidden_ratio = 4 - activation = "GELU" - cpu_offload = False - sub_graph = fake_graph - src_grid_size = 0 - dst_grid_size = 0 - trainable_size = 6 - return ( - num_layers, - num_channels, - num_chunks, - num_heads, - mlp_hidden_ratio, - activation, - cpu_offload, - sub_graph, - src_grid_size, - dst_grid_size, - trainable_size, - ) - - -@pytest.fixture -def graphtransformer_processor(graphtransformer_init): - ( - num_layers, - num_channels, - num_chunks, - num_heads, - mlp_hidden_ratio, - activation, - cpu_offload, - sub_graph, - src_grid_size, - dst_grid_size, - trainable_size, - ) = graphtransformer_init - return GraphTransformerProcessor( - num_layers, - num_channels=num_channels, - num_chunks=num_chunks, - num_heads=num_heads, - mlp_hidden_ratio=mlp_hidden_ratio, - activation=activation, - cpu_offload=cpu_offload, - sub_graph=sub_graph, - src_grid_size=src_grid_size, - dst_grid_size=dst_grid_size, - trainable_size=trainable_size, - ) - - -def test_graphtransformer_processor_init(graphtransformer_processor, graphtransformer_init): - ( - num_layers, - num_channels, - num_chunks, - _num_heads, - _mlp_hidden_ratio, - _activation, - _cpu_offload, - _sub_graph, - _src_grid_size, - _dst_grid_size, - _trainable_size, - ) = graphtransformer_init - assert graphtransformer_processor.num_chunks == num_chunks - assert graphtransformer_processor.num_channels == num_channels - assert graphtransformer_processor.chunk_size == num_layers // num_chunks - assert isinstance(graphtransformer_processor.trainable, TrainableTensor) - - -def test_forward(graphtransformer_processor, graphtransformer_init): - gridpoints = 100 - batch_size = 1 - ( - _num_layers, - num_channels, - _num_chunks, - _num_heads, - _mlp_hidden_ratio, - _activation, - _cpu_offload, - _sub_graph, - _src_grid_size, - _dst_grid_size, - trainable_size, - ) = graphtransformer_init - x = torch.rand((gridpoints, num_channels)) - shard_shapes = [list(x.shape)] - - # Run forward pass of processor - output = graphtransformer_processor.forward(x, batch_size, shard_shapes) - assert output.shape == (gridpoints, num_channels) - - # Generate dummy target and loss function - loss_fn = torch.nn.MSELoss() - target = torch.rand((gridpoints, num_channels)) - loss = loss_fn(output, target) - - # Check loss - assert loss.item() >= 0 - - # Backward pass - loss.backward() - - # Check gradients of trainable tensor - assert graphtransformer_processor.trainable.trainable.grad.shape == ( - gridpoints, - trainable_size, - ) - - # Check gradients of processor - for param in graphtransformer_processor.parameters(): - assert param.grad is not None, f"param.grad is None for {param}" - assert ( - param.grad.shape == param.shape - ), f"param.grad.shape ({param.grad.shape}) != param.shape ({param.shape}) for {param}" +class TestGraphTransformerProcessor: + """Test the GraphTransformerProcessor class.""" + + NUM_NODES: int = 100 + NUM_EDGES: int = 200 + + @pytest.fixture + def fake_graph(self) -> tuple[HeteroData, int]: + graph = HeteroData() + graph["nodes"].x = torch.rand((self.NUM_NODES, 2)) + graph[("nodes", "to", "nodes")].edge_index = torch.randint(0, self.NUM_NODES, (2, self.NUM_EDGES)) + graph[("nodes", "to", "nodes")].edge_attr1 = torch.rand((self.NUM_EDGES, 3)) + graph[("nodes", "to", "nodes")].edge_attr2 = torch.rand((self.NUM_EDGES, 4)) + return graph + + @pytest.fixture + def graphtransformer_init(self, fake_graph: HeteroData): + num_layers = 2 + num_channels = 128 + num_chunks = 2 + num_heads = 16 + mlp_hidden_ratio = 4 + activation = "GELU" + cpu_offload = False + sub_graph = fake_graph[("nodes", "to", "nodes")] + edge_attributes = ["edge_attr1", "edge_attr2"] + src_grid_size = 0 + dst_grid_size = 0 + trainable_size = 6 + return ( + num_layers, + num_channels, + num_chunks, + num_heads, + mlp_hidden_ratio, + activation, + cpu_offload, + sub_graph, + edge_attributes, + src_grid_size, + dst_grid_size, + trainable_size, + ) + + @pytest.fixture + def graphtransformer_processor(self, graphtransformer_init): + ( + num_layers, + num_channels, + num_chunks, + num_heads, + mlp_hidden_ratio, + activation, + cpu_offload, + sub_graph, + edge_attributes, + src_grid_size, + dst_grid_size, + trainable_size, + ) = graphtransformer_init + return GraphTransformerProcessor( + num_layers, + num_channels=num_channels, + num_chunks=num_chunks, + num_heads=num_heads, + mlp_hidden_ratio=mlp_hidden_ratio, + activation=activation, + cpu_offload=cpu_offload, + sub_graph=sub_graph, + sub_graph_edge_attributes=edge_attributes, + src_grid_size=src_grid_size, + dst_grid_size=dst_grid_size, + trainable_size=trainable_size, + ) + + def test_graphtransformer_processor_init(self, graphtransformer_processor, graphtransformer_init): + ( + num_layers, + num_channels, + num_chunks, + _num_heads, + _mlp_hidden_ratio, + _activation, + _cpu_offload, + _sub_graph, + _edge_attributes, + _src_grid_size, + _dst_grid_size, + _trainable_size, + ) = graphtransformer_init + assert graphtransformer_processor.num_chunks == num_chunks + assert graphtransformer_processor.num_channels == num_channels + assert graphtransformer_processor.chunk_size == num_layers // num_chunks + assert isinstance(graphtransformer_processor.trainable, TrainableTensor) + + def test_forward(self, graphtransformer_processor, graphtransformer_init): + batch_size = 1 + ( + _num_layers, + num_channels, + _num_chunks, + _num_heads, + _mlp_hidden_ratio, + _activation, + _cpu_offload, + _sub_graph, + _edge_attributes, + _src_grid_size, + _dst_grid_size, + trainable_size, + ) = graphtransformer_init + x = torch.rand((self.NUM_EDGES, num_channels)) + shard_shapes = [list(x.shape)] + + # Run forward pass of processor + output = graphtransformer_processor.forward(x, batch_size, shard_shapes) + assert output.shape == (self.NUM_EDGES, num_channels) + + # Generate dummy target and loss function + loss_fn = torch.nn.MSELoss() + target = torch.rand((self.NUM_EDGES, num_channels)) + loss = loss_fn(output, target) + + # Check loss + assert loss.item() >= 0 + + # Backward pass + loss.backward() + + # Check gradients of trainable tensor + assert graphtransformer_processor.trainable.trainable.grad.shape == ( + self.NUM_EDGES, + trainable_size, + ) + + # Check gradients of processor + for param in graphtransformer_processor.parameters(): + assert param.grad is not None, f"param.grad is None for {param}" + assert ( + param.grad.shape == param.shape + ), f"param.grad.shape ({param.grad.shape}) != param.shape ({param.shape}) for {param}"