diff --git a/tests/layers/mapper/test_base_mapper.py b/tests/layers/mapper/test_base_mapper.py
index c9e051f8..a229b1e7 100644
--- a/tests/layers/mapper/test_base_mapper.py
+++ b/tests/layers/mapper/test_base_mapper.py
@@ -13,9 +13,10 @@ class TestBaseMapper:
-    NUM_EDGES = 100
-    NUM_SRC_NODES = 100
-    NUM_DST_NODES = 200
+    """Test the BaseMapper class."""
+    NUM_EDGES: int = 100
+    NUM_SRC_NODES: int = 100
+    NUM_DST_NODES: int = 200
 
     @pytest.fixture
     def mapper_init(self):
diff --git a/tests/layers/mapper/test_graphconv_mapper.py b/tests/layers/mapper/test_graphconv_mapper.py
index 9fd34fbf..1e4986a9 100644
--- a/tests/layers/mapper/test_graphconv_mapper.py
+++ b/tests/layers/mapper/test_graphconv_mapper.py
@@ -16,9 +16,10 @@ class TestGNNBaseMapper:
-    NUM_SRC_NODES = 200
-    NUM_DST_NODES = 178
-    NUM_EDGES = 300
+    """Test the GNNBaseMapper class."""
+    NUM_SRC_NODES: int = 200
+    NUM_DST_NODES: int = 178
+    NUM_EDGES: int = 300
 
     @pytest.fixture
     def mapper_init(self):
@@ -148,6 +149,7 @@ def test_post_process(self, mapper, pair_tensor):
 
 
 class TestGNNForwardMapper(TestGNNBaseMapper):
+    """Test the GNNForwardMapper class."""
     @pytest.fixture
     def mapper(self, mapper_init, fake_graph):
         (
@@ -237,6 +239,7 @@ def test_forward_backward(self, mapper_init, mapper, pair_tensor):
 
 
 class TestGNNBackwardMapper(TestGNNBaseMapper):
+    """Test the GNNBackwardMapper class."""
     @pytest.fixture
     def mapper(self, mapper_init, fake_graph):
         (
diff --git a/tests/layers/mapper/test_graphtransformer_mapper.py b/tests/layers/mapper/test_graphtransformer_mapper.py
index b3386a23..86e6be87 100644
--- a/tests/layers/mapper/test_graphtransformer_mapper.py
+++ b/tests/layers/mapper/test_graphtransformer_mapper.py
@@ -16,9 +16,10 @@ class TestGraphTransformerBaseMapper:
-    NUM_EDGES = 150
-    NUM_SRC_NODES = 100
-    NUM_DST_NODES = 200
+    """Test the GraphTransformerBaseMapper class."""
+    NUM_EDGES: int = 150
+    NUM_SRC_NODES: int = 100
+    NUM_DST_NODES: int = 200
 
     @pytest.fixture
     def mapper_init(self):
@@ -162,6 +163,7 @@ def test_post_process(self, mapper, pair_tensor):
 
 
 class TestGraphTransformerForwardMapper(TestGraphTransformerBaseMapper):
+    """Test the GraphTransformerForwardMapper class."""
     @pytest.fixture
     def mapper(self, mapper_init, fake_graph):
         (
@@ -259,6 +261,7 @@ def test_forward_backward(self, mapper_init, mapper, pair_tensor):
 
 
 class TestGraphTransformerBackwardMapper(TestGraphTransformerBaseMapper):
+    """Test the GraphTransformerBackwardMapper class."""
     @pytest.fixture
     def mapper(self, mapper_init, fake_graph):
         (
diff --git a/tests/layers/processor/test_graphconv_processor.py b/tests/layers/processor/test_graphconv_processor.py
index ecfc15e8..84938a73 100644
--- a/tests/layers/processor/test_graphconv_processor.py
+++ b/tests/layers/processor/test_graphconv_processor.py
@@ -12,137 +12,135 @@ from anemoi.models.layers.graph import TrainableTensor
 from anemoi.models.layers.processor import GNNProcessor
 
-num_edges = 200
-
-@pytest.fixture
-def fake_graph() -> HeteroData:
-    num_nodes = 100
-    graph = HeteroData()
-    graph["nodes"].x = torch.rand((num_nodes, 2))
-    graph[("nodes", "to", "nodes")].edge_index = torch.randint(0, num_nodes, (2, num_edges))
-    graph[("nodes", "to", "nodes")].edge_attr1 = torch.rand((num_edges, 3))
-    graph[("nodes", "to", "nodes")].edge_attr2 = torch.rand((num_edges, 4))
-    return graph
-
-
-@pytest.fixture
-def graphconv_init(fake_graph: HeteroData):
-    num_layers = 2
-    num_channels = 128
-    num_chunks = 2
-    mlp_extra_layers = 0
-    activation = "SiLU"
-    cpu_offload = False
-    sub_graph = fake_graph[("nodes", "to", "nodes")]
"nodes")] - edge_attributes = ["edge_attr1", "edge_attr2"] - src_grid_size = 0 - dst_grid_size = 0 - trainable_size = 8 - return ( - num_layers, - num_channels, - num_chunks, - mlp_extra_layers, - activation, - cpu_offload, - sub_graph, - edge_attributes, - src_grid_size, - dst_grid_size, - trainable_size, - ) - - -@pytest.fixture -def graphconv_processor(graphconv_init): - ( - num_layers, - num_channels, - num_chunks, - mlp_extra_layers, - activation, - cpu_offload, - sub_graph, - edge_attributes, - src_grid_size, - dst_grid_size, - trainable_size, - ) = graphconv_init - return GNNProcessor( - num_layers, - num_channels=num_channels, - num_chunks=num_chunks, - mlp_extra_layers=mlp_extra_layers, - activation=activation, - cpu_offload=cpu_offload, - sub_graph=sub_graph, - sub_graph_edge_attributes=edge_attributes, - src_grid_size=src_grid_size, - dst_grid_size=dst_grid_size, - trainable_size=trainable_size, - ) - - -def test_graphconv_processor_init(graphconv_processor, graphconv_init): - ( - num_layers, - num_channels, - num_chunks, - _mlp_extra_layers, - _activation, - _cpu_offload, - _sub_graph, - _edge_attributes, - _src_grid_size, - _dst_grid_size, - _trainable_size, - ) = graphconv_init - assert graphconv_processor.num_chunks == num_chunks - assert graphconv_processor.num_channels == num_channels - assert graphconv_processor.chunk_size == num_layers // num_chunks - assert isinstance(graphconv_processor.trainable, TrainableTensor) - - -def test_forward(graphconv_processor, graphconv_init): - batch_size = 1 - ( - _num_layers, - num_channels, - _num_chunks, - _mlp_extra_layers, - _activation, - _cpu_offload, - _sub_graph, - _edge_attributes, - _src_grid_size, - _dst_grid_size, - trainable_size, - ) = graphconv_init - x = torch.rand((num_edges, num_channels)) - shard_shapes = [list(x.shape)] - - # Run forward pass of processor - output = graphconv_processor.forward(x, batch_size, shard_shapes) - assert output.shape == (num_edges, num_channels) - - # Generate dummy target and loss function - loss_fn = torch.nn.MSELoss() - target = torch.rand((num_edges, num_channels)) - loss = loss_fn(output, target) - - # Check loss - assert loss.item() >= 0 - - # Backward pass - loss.backward() - - # Check gradients of trainable tensor - assert graphconv_processor.trainable.trainable.grad.shape == (num_edges, trainable_size) - - # Check gradients of processor - for param in graphconv_processor.parameters(): - assert param.grad is not None, f"param.grad is None for {param}" - assert ( - param.grad.shape == param.shape - ), f"param.grad.shape ({param.grad.shape}) != param.shape ({param.shape}) for {param}" +class TestGNNProcessor: + """Test the GNNProcessor class.""" + NUM_NODES: int = 100 + NUM_EDGES: int = 200 + + @pytest.fixture + def fake_graph(self) -> tuple[HeteroData, int]: + graph = HeteroData() + graph["nodes"].x = torch.rand((self.NUM_NODES, 2)) + graph[("nodes", "to", "nodes")].edge_index = torch.randint(0, self.NUM_NODES, (2, self.NUM_EDGES)) + graph[("nodes", "to", "nodes")].edge_attr1 = torch.rand((self.NUM_EDGES, 3)) + graph[("nodes", "to", "nodes")].edge_attr2 = torch.rand((self.NUM_EDGES, 4)) + return graph + + @pytest.fixture + def graphconv_init(self, fake_graph: HeteroData): + num_layers = 2 + num_channels = 128 + num_chunks = 2 + mlp_extra_layers = 0 + activation = "SiLU" + cpu_offload = False + sub_graph = fake_graph[("nodes", "to", "nodes")] + edge_attributes = ["edge_attr1", "edge_attr2"] + src_grid_size = 0 + dst_grid_size = 0 + trainable_size = 8 + return ( + num_layers, + 
+            num_channels,
+            num_chunks,
+            mlp_extra_layers,
+            activation,
+            cpu_offload,
+            sub_graph,
+            edge_attributes,
+            src_grid_size,
+            dst_grid_size,
+            trainable_size,
+        )
+
+    @pytest.fixture
+    def graphconv_processor(self, graphconv_init):
+        (
+            num_layers,
+            num_channels,
+            num_chunks,
+            mlp_extra_layers,
+            activation,
+            cpu_offload,
+            sub_graph,
+            edge_attributes,
+            src_grid_size,
+            dst_grid_size,
+            trainable_size,
+        ) = graphconv_init
+        return GNNProcessor(
+            num_layers,
+            num_channels=num_channels,
+            num_chunks=num_chunks,
+            mlp_extra_layers=mlp_extra_layers,
+            activation=activation,
+            cpu_offload=cpu_offload,
+            sub_graph=sub_graph,
+            sub_graph_edge_attributes=edge_attributes,
+            src_grid_size=src_grid_size,
+            dst_grid_size=dst_grid_size,
+            trainable_size=trainable_size,
+        )
+
+    def test_graphconv_processor_init(self, graphconv_processor, graphconv_init):
+        (
+            num_layers,
+            num_channels,
+            num_chunks,
+            _mlp_extra_layers,
+            _activation,
+            _cpu_offload,
+            _sub_graph,
+            _edge_attributes,
+            _src_grid_size,
+            _dst_grid_size,
+            _trainable_size,
+        ) = graphconv_init
+        assert graphconv_processor.num_chunks == num_chunks
+        assert graphconv_processor.num_channels == num_channels
+        assert graphconv_processor.chunk_size == num_layers // num_chunks
+        assert isinstance(graphconv_processor.trainable, TrainableTensor)
+
+    def test_forward(self, graphconv_processor, graphconv_init):
+        batch_size = 1
+        (
+            _num_layers,
+            num_channels,
+            _num_chunks,
+            _mlp_extra_layers,
+            _activation,
+            _cpu_offload,
+            _sub_graph,
+            _edge_attributes,
+            _src_grid_size,
+            _dst_grid_size,
+            trainable_size,
+        ) = graphconv_init
+        x = torch.rand((self.NUM_EDGES, num_channels))
+        shard_shapes = [list(x.shape)]
+
+        # Run forward pass of processor
+        output = graphconv_processor.forward(x, batch_size, shard_shapes)
+        assert output.shape == (self.NUM_EDGES, num_channels)
+
+        # Generate dummy target and loss function
+        loss_fn = torch.nn.MSELoss()
+        target = torch.rand((self.NUM_EDGES, num_channels))
+        loss = loss_fn(output, target)
+
+        # Check loss
+        assert loss.item() >= 0
+
+        # Backward pass
+        loss.backward()
+
+        # Check gradients of trainable tensor
+        assert graphconv_processor.trainable.trainable.grad.shape == (self.NUM_EDGES, trainable_size)
+
+        # Check gradients of processor
+        for param in graphconv_processor.parameters():
+            assert param.grad is not None, f"param.grad is None for {param}"
+            assert (
+                param.grad.shape == param.shape
+            ), f"param.grad.shape ({param.grad.shape}) != param.shape ({param.shape}) for {param}"
diff --git a/tests/layers/processor/test_graphtransformer_processor.py b/tests/layers/processor/test_graphtransformer_processor.py
index 06a392d8..bb39fae0 100644
--- a/tests/layers/processor/test_graphtransformer_processor.py
+++ b/tests/layers/processor/test_graphtransformer_processor.py
@@ -12,146 +12,145 @@ from anemoi.models.layers.graph import TrainableTensor
 from anemoi.models.layers.processor import GraphTransformerProcessor
 
-num_edges: int = 200
-
-
-@pytest.fixture
-def fake_graph() -> HeteroData:
-    num_nodes = 100
-    graph = HeteroData()
-    graph["nodes"].x = torch.rand((num_nodes, 2))
-    graph[("nodes", "to", "nodes")].edge_index = torch.randint(0, num_nodes, (2, num_edges))
-    graph[("nodes", "to", "nodes")].edge_attr1 = torch.rand((num_edges, 3))
-    graph[("nodes", "to", "nodes")].edge_attr2 = torch.rand((num_edges, 4))
-    return graph
-
-
-@pytest.fixture
-def graphtransformer_init(fake_graph: HeteroData):
-    num_layers = 2
-    num_channels = 128
-    num_chunks = 2
-    num_heads = 16
-    mlp_hidden_ratio = 4
activation = "GELU" - cpu_offload = False - sub_graph = fake_graph[("nodes", "to", "nodes")] - edge_attributes = ["edge_attr1", "edge_attr2"] - src_grid_size = 0 - dst_grid_size = 0 - trainable_size = 6 - return ( - num_layers, - num_channels, - num_chunks, - num_heads, - mlp_hidden_ratio, - activation, - cpu_offload, - sub_graph, - edge_attributes, - src_grid_size, - dst_grid_size, - trainable_size, - ) - - -@pytest.fixture -def graphtransformer_processor(graphtransformer_init): - ( - num_layers, - num_channels, - num_chunks, - num_heads, - mlp_hidden_ratio, - activation, - cpu_offload, - sub_graph, - edge_attributes, - src_grid_size, - dst_grid_size, - trainable_size, - ) = graphtransformer_init - return GraphTransformerProcessor( - num_layers, - num_channels=num_channels, - num_chunks=num_chunks, - num_heads=num_heads, - mlp_hidden_ratio=mlp_hidden_ratio, - activation=activation, - cpu_offload=cpu_offload, - sub_graph=sub_graph, - sub_graph_edge_attributes=edge_attributes, - src_grid_size=src_grid_size, - dst_grid_size=dst_grid_size, - trainable_size=trainable_size, - ) - - -def test_graphtransformer_processor_init(graphtransformer_processor, graphtransformer_init): - ( - num_layers, - num_channels, - num_chunks, - _num_heads, - _mlp_hidden_ratio, - _activation, - _cpu_offload, - _sub_graph, - _edge_attributes, - _src_grid_size, - _dst_grid_size, - _trainable_size, - ) = graphtransformer_init - assert graphtransformer_processor.num_chunks == num_chunks - assert graphtransformer_processor.num_channels == num_channels - assert graphtransformer_processor.chunk_size == num_layers // num_chunks - assert isinstance(graphtransformer_processor.trainable, TrainableTensor) - - -def test_forward(graphtransformer_processor, graphtransformer_init): - batch_size = 1 - ( - _num_layers, - num_channels, - _num_chunks, - _num_heads, - _mlp_hidden_ratio, - _activation, - _cpu_offload, - _sub_graph, - _edge_attributes, - _src_grid_size, - _dst_grid_size, - trainable_size, - ) = graphtransformer_init - x = torch.rand((num_edges, num_channels)) - shard_shapes = [list(x.shape)] - - # Run forward pass of processor - output = graphtransformer_processor.forward(x, batch_size, shard_shapes) - assert output.shape == (num_edges, num_channels) - - # Generate dummy target and loss function - loss_fn = torch.nn.MSELoss() - target = torch.rand((num_edges, num_channels)) - loss = loss_fn(output, target) - - # Check loss - assert loss.item() >= 0 - - # Backward pass - loss.backward() - - # Check gradients of trainable tensor - assert graphtransformer_processor.trainable.trainable.grad.shape == ( - num_edges, - trainable_size, - ) - - # Check gradients of processor - for param in graphtransformer_processor.parameters(): - assert param.grad is not None, f"param.grad is None for {param}" - assert ( - param.grad.shape == param.shape - ), f"param.grad.shape ({param.grad.shape}) != param.shape ({param.shape}) for {param}" + +class TestGraphTransformerProcessor: + """Test the GraphTransformerProcessor class.""" + NUM_NODES: int = 100 + NUM_EDGES: int = 200 + + @pytest.fixture + def fake_graph(self) -> tuple[HeteroData, int]: + graph = HeteroData() + graph["nodes"].x = torch.rand((self.NUM_NODES, 2)) + graph[("nodes", "to", "nodes")].edge_index = torch.randint(0, self.NUM_NODES, (2, self.NUM_EDGES)) + graph[("nodes", "to", "nodes")].edge_attr1 = torch.rand((self.NUM_EDGES, 3)) + graph[("nodes", "to", "nodes")].edge_attr2 = torch.rand((self.NUM_EDGES, 4)) + return graph + + @pytest.fixture + def graphtransformer_init(self, 
+        num_layers = 2
+        num_channels = 128
+        num_chunks = 2
+        num_heads = 16
+        mlp_hidden_ratio = 4
+        activation = "GELU"
+        cpu_offload = False
+        sub_graph = fake_graph[("nodes", "to", "nodes")]
+        edge_attributes = ["edge_attr1", "edge_attr2"]
+        src_grid_size = 0
+        dst_grid_size = 0
+        trainable_size = 6
+        return (
+            num_layers,
+            num_channels,
+            num_chunks,
+            num_heads,
+            mlp_hidden_ratio,
+            activation,
+            cpu_offload,
+            sub_graph,
+            edge_attributes,
+            src_grid_size,
+            dst_grid_size,
+            trainable_size,
+        )
+
+
+    @pytest.fixture
+    def graphtransformer_processor(self, graphtransformer_init):
+        (
+            num_layers,
+            num_channels,
+            num_chunks,
+            num_heads,
+            mlp_hidden_ratio,
+            activation,
+            cpu_offload,
+            sub_graph,
+            edge_attributes,
+            src_grid_size,
+            dst_grid_size,
+            trainable_size,
+        ) = graphtransformer_init
+        return GraphTransformerProcessor(
+            num_layers,
+            num_channels=num_channels,
+            num_chunks=num_chunks,
+            num_heads=num_heads,
+            mlp_hidden_ratio=mlp_hidden_ratio,
+            activation=activation,
+            cpu_offload=cpu_offload,
+            sub_graph=sub_graph,
+            sub_graph_edge_attributes=edge_attributes,
+            src_grid_size=src_grid_size,
+            dst_grid_size=dst_grid_size,
+            trainable_size=trainable_size,
+        )
+
+    def test_graphtransformer_processor_init(self, graphtransformer_processor, graphtransformer_init):
+        (
+            num_layers,
+            num_channels,
+            num_chunks,
+            _num_heads,
+            _mlp_hidden_ratio,
+            _activation,
+            _cpu_offload,
+            _sub_graph,
+            _edge_attributes,
+            _src_grid_size,
+            _dst_grid_size,
+            _trainable_size,
+        ) = graphtransformer_init
+        assert graphtransformer_processor.num_chunks == num_chunks
+        assert graphtransformer_processor.num_channels == num_channels
+        assert graphtransformer_processor.chunk_size == num_layers // num_chunks
+        assert isinstance(graphtransformer_processor.trainable, TrainableTensor)
+
+    def test_forward(self, graphtransformer_processor, graphtransformer_init):
+        batch_size = 1
+        (
+            _num_layers,
+            num_channels,
+            _num_chunks,
+            _num_heads,
+            _mlp_hidden_ratio,
+            _activation,
+            _cpu_offload,
+            _sub_graph,
+            _edge_attributes,
+            _src_grid_size,
+            _dst_grid_size,
+            trainable_size,
+        ) = graphtransformer_init
+        x = torch.rand((self.NUM_EDGES, num_channels))
+        shard_shapes = [list(x.shape)]
+
+        # Run forward pass of processor
+        output = graphtransformer_processor.forward(x, batch_size, shard_shapes)
+        assert output.shape == (self.NUM_EDGES, num_channels)
+
+        # Generate dummy target and loss function
+        loss_fn = torch.nn.MSELoss()
+        target = torch.rand((self.NUM_EDGES, num_channels))
+        loss = loss_fn(output, target)
+
+        # Check loss
+        assert loss.item() >= 0
+
+        # Backward pass
+        loss.backward()
+
+        # Check gradients of trainable tensor
+        assert graphtransformer_processor.trainable.trainable.grad.shape == (
+            self.NUM_EDGES,
+            trainable_size,
+        )
+
+        # Check gradients of processor
+        for param in graphtransformer_processor.parameters():
+            assert param.grad is not None, f"param.grad is None for {param}"
+            assert (
+                param.grad.shape == param.shape
+            ), f"param.grad.shape ({param.grad.shape}) != param.shape ({param.shape}) for {param}"