diff --git a/src/anemoi/models/layers/mapper.py b/src/anemoi/models/layers/mapper.py index 8197a0a7..04efdf0a 100644 --- a/src/anemoi/models/layers/mapper.py +++ b/src/anemoi/models/layers/mapper.py @@ -135,8 +135,7 @@ def _register_edges( Trainable tensor size """ if edge_attributes is None: - LOGGER.warning("No edge attributes provided.") - edge_attributes = [] + raise ValueError("Edge attributes must be provided") edge_attr_tensor = torch.cat([sub_graph[attr] for attr in edge_attributes], axis=1) @@ -288,6 +287,7 @@ def __init__( num_heads: int = 16, mlp_hidden_ratio: int = 4, sub_graph: Optional[HeteroData] = None, + sub_graph_edge_attributes: Optional[list[str]] = None, src_grid_size: int = 0, dst_grid_size: int = 0, ) -> None: @@ -326,6 +326,7 @@ def __init__( num_heads=num_heads, mlp_hidden_ratio=mlp_hidden_ratio, sub_graph=sub_graph, + sub_graph_edge_attributes=sub_graph_edge_attributes, src_grid_size=src_grid_size, dst_grid_size=dst_grid_size, ) @@ -359,6 +360,7 @@ def __init__( num_heads: int = 16, mlp_hidden_ratio: int = 4, sub_graph: Optional[HeteroData] = None, + sub_graph_edge_attributes: Optional[list[str]] = None, src_grid_size: int = 0, dst_grid_size: int = 0, ) -> None: @@ -397,6 +399,7 @@ def __init__( num_heads=num_heads, mlp_hidden_ratio=mlp_hidden_ratio, sub_graph=sub_graph, + sub_graph_edge_attributes=sub_graph_edge_attributes, src_grid_size=src_grid_size, dst_grid_size=dst_grid_size, ) @@ -533,6 +536,7 @@ def __init__( activation: str = "SiLU", mlp_extra_layers: int = 0, sub_graph: Optional[HeteroData] = None, + sub_graph_edge_attributes: Optional[list[str]] = None, src_grid_size: int = 0, dst_grid_size: int = 0, ) -> None: @@ -570,6 +574,7 @@ def __init__( activation, mlp_extra_layers, sub_graph=sub_graph, + sub_graph_edge_attributes=sub_graph_edge_attributes, src_grid_size=src_grid_size, dst_grid_size=dst_grid_size, ) @@ -617,6 +622,7 @@ def __init__( activation: str = "SiLU", mlp_extra_layers: int = 0, sub_graph: Optional[HeteroData] = None, + sub_graph_edge_attributes: Optional[list[str]] = None, src_grid_size: int = 0, dst_grid_size: int = 0, ) -> None: @@ -654,6 +660,7 @@ def __init__( activation=activation, mlp_extra_layers=mlp_extra_layers, sub_graph=sub_graph, + sub_graph_edge_attributes=sub_graph_edge_attributes, src_grid_size=src_grid_size, dst_grid_size=dst_grid_size, ) diff --git a/src/anemoi/models/layers/processor.py b/src/anemoi/models/layers/processor.py index 5a0ca324..bb336091 100644 --- a/src/anemoi/models/layers/processor.py +++ b/src/anemoi/models/layers/processor.py @@ -171,6 +171,7 @@ def __init__( activation: str = "SiLU", cpu_offload: bool = False, sub_graph: Optional[HeteroData] = None, + sub_graph_edge_attributes: Optional[list[str]] = None, src_grid_size: int = 0, dst_grid_size: int = 0, **kwargs, @@ -201,7 +202,7 @@ def __init__( mlp_extra_layers=mlp_extra_layers, ) - self._register_edges(sub_graph, src_grid_size, dst_grid_size, trainable_size) + self._register_edges(sub_graph, sub_graph_edge_attributes, src_grid_size, dst_grid_size, trainable_size) self.trainable = TrainableTensor(trainable_size=trainable_size, tensor_size=self.edge_attr.shape[0]) diff --git a/tests/layers/mapper/test_base_mapper.py b/tests/layers/mapper/test_base_mapper.py index a0cbd643..c9e051f8 100644 --- a/tests/layers/mapper/test_base_mapper.py +++ b/tests/layers/mapper/test_base_mapper.py @@ -55,6 +55,7 @@ def mapper(self, mapper_init, fake_graph): cpu_offload=cpu_offload, activation=activation, sub_graph=fake_graph[("src", "to", "dst")], + sub_graph_edge_attributes=["edge_attr1", "edge_attr2"], trainable_size=trainable_size, ) diff --git a/tests/layers/mapper/test_graphconv_mapper.py b/tests/layers/mapper/test_graphconv_mapper.py index 58852fe6..9fd34fbf 100644 --- a/tests/layers/mapper/test_graphconv_mapper.py +++ b/tests/layers/mapper/test_graphconv_mapper.py @@ -58,6 +58,7 @@ def mapper(self, mapper_init, fake_graph): cpu_offload=cpu_offload, activation=activation, sub_graph=fake_graph[("src", "to", "dst")], + sub_graph_edge_attributes=["edge_attr1", "edge_attr2"], trainable_size=trainable_size, ) @@ -166,6 +167,7 @@ def mapper(self, mapper_init, fake_graph): cpu_offload=cpu_offload, activation=activation, sub_graph=fake_graph[("src", "to", "dst")], + sub_graph_edge_attributes=["edge_attr1", "edge_attr2"], trainable_size=trainable_size, ) @@ -254,6 +256,7 @@ def mapper(self, mapper_init, fake_graph): cpu_offload=cpu_offload, activation=activation, sub_graph=fake_graph[("src", "to", "dst")], + sub_graph_edge_attributes=["edge_attr1", "edge_attr2"], trainable_size=trainable_size, ) diff --git a/tests/layers/mapper/test_graphtransformer_mapper.py b/tests/layers/mapper/test_graphtransformer_mapper.py index 75b37a1f..b3386a23 100644 --- a/tests/layers/mapper/test_graphtransformer_mapper.py +++ b/tests/layers/mapper/test_graphtransformer_mapper.py @@ -64,6 +64,7 @@ def mapper(self, mapper_init, fake_graph): cpu_offload=cpu_offload, activation=activation, sub_graph=fake_graph[("src", "to", "dst")], + sub_graph_edge_attributes=["edge_attr1", "edge_attr2"], trainable_size=trainable_size, num_heads=num_heads, mlp_hidden_ratio=mlp_hidden_ratio, @@ -182,6 +183,7 @@ def mapper(self, mapper_init, fake_graph): cpu_offload=cpu_offload, activation=activation, sub_graph=fake_graph[("src", "to", "dst")], + sub_graph_edge_attributes=["edge_attr1", "edge_attr2"], trainable_size=trainable_size, num_heads=num_heads, mlp_hidden_ratio=mlp_hidden_ratio, @@ -278,6 +280,7 @@ def mapper(self, mapper_init, fake_graph): cpu_offload=cpu_offload, activation=activation, sub_graph=fake_graph[("src", "to", "dst")], + sub_graph_edge_attributes=["edge_attr1", "edge_attr2"], trainable_size=trainable_size, ) diff --git a/tests/layers/processor/test_graphconv_processor.py b/tests/layers/processor/test_graphconv_processor.py index 3b4168a0..ecfc15e8 100644 --- a/tests/layers/processor/test_graphconv_processor.py +++ b/tests/layers/processor/test_graphconv_processor.py @@ -35,6 +35,7 @@ def graphconv_init(fake_graph: HeteroData): activation = "SiLU" cpu_offload = False sub_graph = fake_graph[("nodes", "to", "nodes")] + edge_attributes = ["edge_attr1", "edge_attr2"] src_grid_size = 0 dst_grid_size = 0 trainable_size = 8 @@ -46,6 +47,7 @@ def graphconv_init(fake_graph: HeteroData): activation, cpu_offload, sub_graph, + edge_attributes, src_grid_size, dst_grid_size, trainable_size, @@ -62,6 +64,7 @@ def graphconv_processor(graphconv_init): activation, cpu_offload, sub_graph, + edge_attributes, src_grid_size, dst_grid_size, trainable_size, @@ -74,6 +77,7 @@ def graphconv_processor(graphconv_init): activation=activation, cpu_offload=cpu_offload, sub_graph=sub_graph, + sub_graph_edge_attributes=edge_attributes, src_grid_size=src_grid_size, dst_grid_size=dst_grid_size, trainable_size=trainable_size, @@ -89,6 +93,7 @@ def test_graphconv_processor_init(graphconv_processor, graphconv_init): _activation, _cpu_offload, _sub_graph, + _edge_attributes, _src_grid_size, _dst_grid_size, _trainable_size, @@ -109,6 +114,7 @@ def test_forward(graphconv_processor, graphconv_init): _activation, _cpu_offload, _sub_graph, + _edge_attributes, _src_grid_size, _dst_grid_size, trainable_size, diff --git a/tests/layers/processor/test_graphtransformer_processor.py b/tests/layers/processor/test_graphtransformer_processor.py index 57bc0146..06a392d8 100644 --- a/tests/layers/processor/test_graphtransformer_processor.py +++ b/tests/layers/processor/test_graphtransformer_processor.py @@ -36,6 +36,7 @@ def graphtransformer_init(fake_graph: HeteroData): activation = "GELU" cpu_offload = False sub_graph = fake_graph[("nodes", "to", "nodes")] + edge_attributes = ["edge_attr1", "edge_attr2"] src_grid_size = 0 dst_grid_size = 0 trainable_size = 6 @@ -48,6 +49,7 @@ def graphtransformer_init(fake_graph: HeteroData): activation, cpu_offload, sub_graph, + edge_attributes, src_grid_size, dst_grid_size, trainable_size, @@ -65,6 +67,7 @@ def graphtransformer_processor(graphtransformer_init): activation, cpu_offload, sub_graph, + edge_attributes, src_grid_size, dst_grid_size, trainable_size, @@ -78,6 +81,7 @@ def graphtransformer_processor(graphtransformer_init): activation=activation, cpu_offload=cpu_offload, sub_graph=sub_graph, + sub_graph_edge_attributes=edge_attributes, src_grid_size=src_grid_size, dst_grid_size=dst_grid_size, trainable_size=trainable_size, @@ -94,6 +98,7 @@ def test_graphtransformer_processor_init(graphtransformer_processor, graphtransf _activation, _cpu_offload, _sub_graph, + _edge_attributes, _src_grid_size, _dst_grid_size, _trainable_size, @@ -115,6 +120,7 @@ def test_forward(graphtransformer_processor, graphtransformer_init): _activation, _cpu_offload, _sub_graph, + _edge_attributes, _src_grid_size, _dst_grid_size, trainable_size,