This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Update tests
JPXKQX committed Jul 10, 2024
1 parent 1754690 commit 0edc033
Showing 5 changed files with 289 additions and 285 deletions.
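
The hunks below all move in the same direction: test configuration is attached to documented test classes. Constants become type-annotated class attributes, and in the processor test the module-level fixtures and test functions become methods of a new TestGNNProcessor class. A rough, illustrative sketch of the before/after shape, using names from this diff (the `...` bodies are placeholders, not the real fixture code):

import pytest


# Before (processor test): module-level configuration and a free-standing fixture.
num_edges = 200


@pytest.fixture
def fake_graph():
    ...


# After: a documented test class; constants are annotated class attributes
# and fixtures become methods that read configuration from self.
class TestGNNProcessor:
    """Test the GNNProcessor class."""

    NUM_EDGES: int = 200

    @pytest.fixture
    def fake_graph(self):
        ...
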
7 changes: 4 additions & 3 deletions tests/layers/mapper/test_base_mapper.py
@@ -13,9 +13,10 @@


class TestBaseMapper:
-    NUM_EDGES = 100
-    NUM_SRC_NODES = 100
-    NUM_DST_NODES = 200
+    """Test the BaseMapper class."""
+    NUM_EDGES: int = 100
+    NUM_SRC_NODES: int = 100
+    NUM_DST_NODES: int = 200

@pytest.fixture
def mapper_init(self):
9 changes: 6 additions & 3 deletions tests/layers/mapper/test_graphconv_mapper.py
@@ -16,9 +16,10 @@


class TestGNNBaseMapper:
-    NUM_SRC_NODES = 200
-    NUM_DST_NODES = 178
-    NUM_EDGES = 300
+    """Test the GNNBaseMapper class."""
+    NUM_SRC_NODES: int = 200
+    NUM_DST_NODES: int = 178
+    NUM_EDGES: int = 300

@pytest.fixture
def mapper_init(self):
@@ -148,6 +149,7 @@ def test_post_process(self, mapper, pair_tensor):


class TestGNNForwardMapper(TestGNNBaseMapper):
+    """Test the GNNForwardMapper class."""
@pytest.fixture
def mapper(self, mapper_init, fake_graph):
(
@@ -237,6 +239,7 @@ def test_forward_backward(self, mapper_init, mapper, pair_tensor):


class TestGNNBackwardMapper(TestGNNBaseMapper):
+    """Test the GNNBackwardMapper class."""
@pytest.fixture
def mapper(self, mapper_init, fake_graph):
(
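
TestGNNForwardMapper and TestGNNBackwardMapper extend TestGNNBaseMapper, and pytest collects tests and fixtures defined on a base class for every subclass, so overriding a fixture such as `mapper` makes the inherited tests run again against the subclass's object. A small, generic sketch of that mechanism (illustrative names, not the anemoi classes):

import pytest


class TestBase:
    """Tests defined here are re-collected for every subclass."""

    VALUE: int = 2

    @pytest.fixture
    def obj(self):
        # Object under test; subclasses override this fixture.
        return self.VALUE

    def test_positive(self, obj):
        assert obj > 0


class TestDerived(TestBase):
    """Inherits test_positive but supplies a different object."""

    @pytest.fixture
    def obj(self):
        return self.VALUE * 10


Running pytest on this sketch collects test_positive twice, once with obj == 2 and once with obj == 20.
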
9 changes: 6 additions & 3 deletions tests/layers/mapper/test_graphtransformer_mapper.py
@@ -16,9 +16,10 @@


class TestGraphTransformerBaseMapper:
-    NUM_EDGES = 150
-    NUM_SRC_NODES = 100
-    NUM_DST_NODES = 200
+    """Test the GraphTransformerBaseMapper class."""
+    NUM_EDGES: int = 150
+    NUM_SRC_NODES: int = 100
+    NUM_DST_NODES: int = 200

@pytest.fixture
def mapper_init(self):
@@ -162,6 +163,7 @@ def test_post_process(self, mapper, pair_tensor):


class TestGraphTransformerForwardMapper(TestGraphTransformerBaseMapper):
+    """Test the GraphTransformerForwardMapper class."""
@pytest.fixture
def mapper(self, mapper_init, fake_graph):
(
@@ -259,6 +261,7 @@ def test_forward_backward(self, mapper_init, mapper, pair_tensor):


class TestGraphTransformerBackwardMapper(TestGraphTransformerBaseMapper):
+    """Test the GraphTransformerBackwardMapper class."""
@pytest.fixture
def mapper(self, mapper_init, fake_graph):
(
264 changes: 131 additions & 133 deletions tests/layers/processor/test_graphconv_processor.py
@@ -12,137 +12,135 @@
from anemoi.models.layers.graph import TrainableTensor
from anemoi.models.layers.processor import GNNProcessor

-num_edges = 200
-
-
-@pytest.fixture
-def fake_graph() -> HeteroData:
-    num_nodes = 100
-    graph = HeteroData()
-    graph["nodes"].x = torch.rand((num_nodes, 2))
-    graph[("nodes", "to", "nodes")].edge_index = torch.randint(0, num_nodes, (2, num_edges))
-    graph[("nodes", "to", "nodes")].edge_attr1 = torch.rand((num_edges, 3))
-    graph[("nodes", "to", "nodes")].edge_attr2 = torch.rand((num_edges, 4))
-    return graph
-
-
-@pytest.fixture
-def graphconv_init(fake_graph: HeteroData):
-    num_layers = 2
-    num_channels = 128
-    num_chunks = 2
-    mlp_extra_layers = 0
-    activation = "SiLU"
-    cpu_offload = False
-    sub_graph = fake_graph[("nodes", "to", "nodes")]
-    edge_attributes = ["edge_attr1", "edge_attr2"]
-    src_grid_size = 0
-    dst_grid_size = 0
-    trainable_size = 8
-    return (
-        num_layers,
-        num_channels,
-        num_chunks,
-        mlp_extra_layers,
-        activation,
-        cpu_offload,
-        sub_graph,
-        edge_attributes,
-        src_grid_size,
-        dst_grid_size,
-        trainable_size,
-    )
-
-
-@pytest.fixture
-def graphconv_processor(graphconv_init):
-    (
-        num_layers,
-        num_channels,
-        num_chunks,
-        mlp_extra_layers,
-        activation,
-        cpu_offload,
-        sub_graph,
-        edge_attributes,
-        src_grid_size,
-        dst_grid_size,
-        trainable_size,
-    ) = graphconv_init
-    return GNNProcessor(
-        num_layers,
-        num_channels=num_channels,
-        num_chunks=num_chunks,
-        mlp_extra_layers=mlp_extra_layers,
-        activation=activation,
-        cpu_offload=cpu_offload,
-        sub_graph=sub_graph,
-        sub_graph_edge_attributes=edge_attributes,
-        src_grid_size=src_grid_size,
-        dst_grid_size=dst_grid_size,
-        trainable_size=trainable_size,
-    )
-
-
-def test_graphconv_processor_init(graphconv_processor, graphconv_init):
-    (
-        num_layers,
-        num_channels,
-        num_chunks,
-        _mlp_extra_layers,
-        _activation,
-        _cpu_offload,
-        _sub_graph,
-        _edge_attributes,
-        _src_grid_size,
-        _dst_grid_size,
-        _trainable_size,
-    ) = graphconv_init
-    assert graphconv_processor.num_chunks == num_chunks
-    assert graphconv_processor.num_channels == num_channels
-    assert graphconv_processor.chunk_size == num_layers // num_chunks
-    assert isinstance(graphconv_processor.trainable, TrainableTensor)
-
-
-def test_forward(graphconv_processor, graphconv_init):
-    batch_size = 1
-    (
-        _num_layers,
-        num_channels,
-        _num_chunks,
-        _mlp_extra_layers,
-        _activation,
-        _cpu_offload,
-        _sub_graph,
-        _edge_attributes,
-        _src_grid_size,
-        _dst_grid_size,
-        trainable_size,
-    ) = graphconv_init
-    x = torch.rand((num_edges, num_channels))
-    shard_shapes = [list(x.shape)]
-
-    # Run forward pass of processor
-    output = graphconv_processor.forward(x, batch_size, shard_shapes)
-    assert output.shape == (num_edges, num_channels)
-
-    # Generate dummy target and loss function
-    loss_fn = torch.nn.MSELoss()
-    target = torch.rand((num_edges, num_channels))
-    loss = loss_fn(output, target)
-
-    # Check loss
-    assert loss.item() >= 0
-
-    # Backward pass
-    loss.backward()
-
-    # Check gradients of trainable tensor
-    assert graphconv_processor.trainable.trainable.grad.shape == (num_edges, trainable_size)
-
-    # Check gradients of processor
-    for param in graphconv_processor.parameters():
-        assert param.grad is not None, f"param.grad is None for {param}"
-        assert (
-            param.grad.shape == param.shape
-        ), f"param.grad.shape ({param.grad.shape}) != param.shape ({param.shape}) for {param}"
+class TestGNNProcessor:
+    """Test the GNNProcessor class."""
+    NUM_NODES: int = 100
+    NUM_EDGES: int = 200
+
+    @pytest.fixture
+    def fake_graph(self) -> tuple[HeteroData, int]:
+        graph = HeteroData()
+        graph["nodes"].x = torch.rand((self.NUM_NODES, 2))
+        graph[("nodes", "to", "nodes")].edge_index = torch.randint(0, self.NUM_NODES, (2, self.NUM_EDGES))
+        graph[("nodes", "to", "nodes")].edge_attr1 = torch.rand((self.NUM_EDGES, 3))
+        graph[("nodes", "to", "nodes")].edge_attr2 = torch.rand((self.NUM_EDGES, 4))
+        return graph
+
+    @pytest.fixture
+    def graphconv_init(self, fake_graph: HeteroData):
+        num_layers = 2
+        num_channels = 128
+        num_chunks = 2
+        mlp_extra_layers = 0
+        activation = "SiLU"
+        cpu_offload = False
+        sub_graph = fake_graph[("nodes", "to", "nodes")]
+        edge_attributes = ["edge_attr1", "edge_attr2"]
+        src_grid_size = 0
+        dst_grid_size = 0
+        trainable_size = 8
+        return (
+            num_layers,
+            num_channels,
+            num_chunks,
+            mlp_extra_layers,
+            activation,
+            cpu_offload,
+            sub_graph,
+            edge_attributes,
+            src_grid_size,
+            dst_grid_size,
+            trainable_size,
+        )
+
+    @pytest.fixture
+    def graphconv_processor(self, graphconv_init):
+        (
+            num_layers,
+            num_channels,
+            num_chunks,
+            mlp_extra_layers,
+            activation,
+            cpu_offload,
+            sub_graph,
+            edge_attributes,
+            src_grid_size,
+            dst_grid_size,
+            trainable_size,
+        ) = graphconv_init
+        return GNNProcessor(
+            num_layers,
+            num_channels=num_channels,
+            num_chunks=num_chunks,
+            mlp_extra_layers=mlp_extra_layers,
+            activation=activation,
+            cpu_offload=cpu_offload,
+            sub_graph=sub_graph,
+            sub_graph_edge_attributes=edge_attributes,
+            src_grid_size=src_grid_size,
+            dst_grid_size=dst_grid_size,
+            trainable_size=trainable_size,
+        )
+
+    def test_graphconv_processor_init(self, graphconv_processor, graphconv_init):
+        (
+            num_layers,
+            num_channels,
+            num_chunks,
+            _mlp_extra_layers,
+            _activation,
+            _cpu_offload,
+            _sub_graph,
+            _edge_attributes,
+            _src_grid_size,
+            _dst_grid_size,
+            _trainable_size,
+        ) = graphconv_init
+        assert graphconv_processor.num_chunks == num_chunks
+        assert graphconv_processor.num_channels == num_channels
+        assert graphconv_processor.chunk_size == num_layers // num_chunks
+        assert isinstance(graphconv_processor.trainable, TrainableTensor)
+
+    def test_forward(self, graphconv_processor, graphconv_init):
+        batch_size = 1
+        (
+            _num_layers,
+            num_channels,
+            _num_chunks,
+            _mlp_extra_layers,
+            _activation,
+            _cpu_offload,
+            _sub_graph,
+            _edge_attributes,
+            _src_grid_size,
+            _dst_grid_size,
+            trainable_size,
+        ) = graphconv_init
+        x = torch.rand((self.NUM_EDGES, num_channels))
+        shard_shapes = [list(x.shape)]
+
+        # Run forward pass of processor
+        output = graphconv_processor.forward(x, batch_size, shard_shapes)
+        assert output.shape == (self.NUM_EDGES, num_channels)
+
+        # Generate dummy target and loss function
+        loss_fn = torch.nn.MSELoss()
+        target = torch.rand((self.NUM_EDGES, num_channels))
+        loss = loss_fn(output, target)
+
+        # Check loss
+        assert loss.item() >= 0
+
+        # Backward pass
+        loss.backward()
+
+        # Check gradients of trainable tensor
+        assert graphconv_processor.trainable.trainable.grad.shape == (self.NUM_EDGES, trainable_size)
+
+        # Check gradients of processor
+        for param in graphconv_processor.parameters():
+            assert param.grad is not None, f"param.grad is None for {param}"
+            assert (
+                param.grad.shape == param.shape
+            ), f"param.grad.shape ({param.grad.shape}) != param.shape ({param.shape}) for {param}"
Loading

0 comments on commit 0edc033
