Use anemoi-graphs HeteroData (#8)
- Support new PyTorch Geometric HeteroData structure (defined by anemoi-graphs)

Co-authored-by: Jesper Dramsch <[email protected]>
Co-authored-by: Helen Theissen <[email protected]>
3 people authored Jul 26, 2024
1 parent 09338e4 commit a34cb8b
Showing 9 changed files with 448 additions and 357 deletions.
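The changes below replace the flat graph object previously passed to mappers and processors with a PyTorch Geometric `HeteroData` edge sub-graph plus an explicit list of edge-attribute names. A minimal sketch of the new calling convention, modelled on the updated test fixture in this commit; the attribute names `edge_attr1`/`edge_attr2`, the node and edge counts, and the `mapper_kwargs` dict are illustrative only:

```python
import torch
from torch_geometric.data import HeteroData

# Build a graph in the anemoi-graphs layout: edge indices and edge
# attributes live under a ("src", "to", "dst") edge store.
graph = HeteroData()
graph[("src", "to", "dst")].edge_index = torch.stack(
    [
        torch.randint(0, 100, (50,)),  # source node indices
        torch.randint(0, 200, (50,)),  # destination node indices
    ],
    dim=0,
)
graph[("src", "to", "dst")].edge_attr1 = torch.rand((50, 1))
graph[("src", "to", "dst")].edge_attr2 = torch.rand((50, 32))

# Mappers and processors now receive the edge sub-graph itself and the names
# of the edge attributes to use, instead of a pre-built edge_attr tensor.
mapper_kwargs = dict(
    sub_graph=graph[("src", "to", "dst")],
    sub_graph_edge_attributes=["edge_attr1", "edge_attr2"],
)
```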
12 changes: 12 additions & 0 deletions CHANGELOG.md
@@ -16,6 +16,18 @@ Keep it human-readable, your future self will thank you!

### Removed

## 0.2.0

### Added

- Option to choose the edge attributes

### Changed

- Updated to support new PyTorch Geometric HeteroData structure (defined by `anemoi-graphs` package).

### Removed

## 0.1.0 Initial Release

### Added
32 changes: 27 additions & 5 deletions src/anemoi/models/layers/mapper.py
@@ -7,6 +7,7 @@
# nor does it submit to any jurisdiction.
#

import logging
from abc import ABC
from typing import Optional

@@ -30,6 +31,8 @@
from anemoi.models.layers.graph import TrainableTensor
from anemoi.models.layers.mlp import MLP

LOGGER = logging.getLogger(__name__)


class BaseMapper(nn.Module, ABC):
"""Base Mapper from souce dimension to destination dimension."""
@@ -113,22 +116,31 @@ def pre_process(self, x, shard_shapes, model_comm_group=None):


class GraphEdgeMixin:
def _register_edges(self, sub_graph: HeteroData, src_size: int, dst_size: int, trainable_size: int) -> None:
def _register_edges(
self, sub_graph: HeteroData, edge_attributes: list[str], src_size: int, dst_size: int, trainable_size: int
) -> None:
"""Register edge dim, attr, index_base, and increment.
Parameters
----------
sub_graph : HeteroData
Sub graph of the full structure
edge_attributes : list[str]
Edge attributes to use.
src_size : int
Source size
dst_size : int
Target size
trainable_size : int
Trainable tensor size
"""
self.edge_dim = sub_graph.edge_attr.shape[1] + trainable_size
self.register_buffer("edge_attr", sub_graph.edge_attr, persistent=False)
if edge_attributes is None:
raise ValueError("Edge attributes must be provided")

edge_attr_tensor = torch.cat([sub_graph[attr] for attr in edge_attributes], axis=1)

self.edge_dim = edge_attr_tensor.shape[1] + trainable_size
self.register_buffer("edge_attr", edge_attr_tensor, persistent=False)
self.register_buffer("edge_index_base", sub_graph.edge_index, persistent=False)
self.register_buffer(
"edge_inc", torch.from_numpy(np.asarray([[src_size], [dst_size]], dtype=np.int64)), persistent=True
@@ -174,6 +186,7 @@ def __init__(
num_heads: int = 16,
mlp_hidden_ratio: int = 4,
sub_graph: Optional[HeteroData] = None,
sub_graph_edge_attributes: Optional[list[str]] = None,
src_grid_size: int = 0,
dst_grid_size: int = 0,
) -> None:
@@ -210,7 +223,7 @@ def __init__(
activation=activation,
)

self._register_edges(sub_graph, src_grid_size, dst_grid_size, trainable_size)
self._register_edges(sub_graph, sub_graph_edge_attributes, src_grid_size, dst_grid_size, trainable_size)

self.trainable = TrainableTensor(trainable_size=trainable_size, tensor_size=self.edge_attr.shape[0])

@@ -274,6 +287,7 @@ def __init__(
num_heads: int = 16,
mlp_hidden_ratio: int = 4,
sub_graph: Optional[HeteroData] = None,
sub_graph_edge_attributes: Optional[list[str]] = None,
src_grid_size: int = 0,
dst_grid_size: int = 0,
) -> None:
@@ -312,6 +326,7 @@ def __init__(
num_heads=num_heads,
mlp_hidden_ratio=mlp_hidden_ratio,
sub_graph=sub_graph,
sub_graph_edge_attributes=sub_graph_edge_attributes,
src_grid_size=src_grid_size,
dst_grid_size=dst_grid_size,
)
@@ -345,6 +360,7 @@ def __init__(
num_heads: int = 16,
mlp_hidden_ratio: int = 4,
sub_graph: Optional[HeteroData] = None,
sub_graph_edge_attributes: Optional[list[str]] = None,
src_grid_size: int = 0,
dst_grid_size: int = 0,
) -> None:
@@ -383,6 +399,7 @@ def __init__(
num_heads=num_heads,
mlp_hidden_ratio=mlp_hidden_ratio,
sub_graph=sub_graph,
sub_graph_edge_attributes=sub_graph_edge_attributes,
src_grid_size=src_grid_size,
dst_grid_size=dst_grid_size,
)
@@ -415,6 +432,7 @@ def __init__(
activation: str = "SiLU",
mlp_extra_layers: int = 0,
sub_graph: Optional[HeteroData] = None,
sub_graph_edge_attributes: Optional[list[str]] = None,
src_grid_size: int = 0,
dst_grid_size: int = 0,
) -> None:
@@ -451,7 +469,7 @@ def __init__(
activation=activation,
)

self._register_edges(sub_graph, src_grid_size, dst_grid_size, trainable_size)
self._register_edges(sub_graph, sub_graph_edge_attributes, src_grid_size, dst_grid_size, trainable_size)

self.emb_edges = MLP(
in_features=self.edge_dim,
@@ -518,6 +536,7 @@ def __init__(
activation: str = "SiLU",
mlp_extra_layers: int = 0,
sub_graph: Optional[HeteroData] = None,
sub_graph_edge_attributes: Optional[list[str]] = None,
src_grid_size: int = 0,
dst_grid_size: int = 0,
) -> None:
@@ -555,6 +574,7 @@ def __init__(
activation,
mlp_extra_layers,
sub_graph=sub_graph,
sub_graph_edge_attributes=sub_graph_edge_attributes,
src_grid_size=src_grid_size,
dst_grid_size=dst_grid_size,
)
@@ -602,6 +622,7 @@ def __init__(
activation: str = "SiLU",
mlp_extra_layers: int = 0,
sub_graph: Optional[HeteroData] = None,
sub_graph_edge_attributes: Optional[list[str]] = None,
src_grid_size: int = 0,
dst_grid_size: int = 0,
) -> None:
@@ -639,6 +660,7 @@ def __init__(
activation=activation,
mlp_extra_layers=mlp_extra_layers,
sub_graph=sub_graph,
sub_graph_edge_attributes=sub_graph_edge_attributes,
src_grid_size=src_grid_size,
dst_grid_size=dst_grid_size,
)
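As a standalone illustration of the `GraphEdgeMixin._register_edges` change above: the selected edge attributes are concatenated along the feature dimension, and the resulting width plus the trainable size gives `edge_dim`. A sketch of just the core tensor logic, reusing the illustrative attribute names from the earlier example (not the actual module):

```python
import torch
from torch_geometric.data import HeteroData

graph = HeteroData()
graph[("src", "to", "dst")].edge_attr1 = torch.rand((50, 1))
graph[("src", "to", "dst")].edge_attr2 = torch.rand((50, 32))

sub_graph = graph[("src", "to", "dst")]
edge_attributes = ["edge_attr1", "edge_attr2"]
trainable_size = 8  # illustrative

if edge_attributes is None:
    raise ValueError("Edge attributes must be provided")

# Concatenate the chosen attributes into a single edge feature tensor.
edge_attr_tensor = torch.cat([sub_graph[attr] for attr in edge_attributes], dim=1)
edge_dim = edge_attr_tensor.shape[1] + trainable_size  # 1 + 32 + 8 = 41
```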
6 changes: 4 additions & 2 deletions src/anemoi/models/layers/processor.py
@@ -171,6 +171,7 @@ def __init__(
activation: str = "SiLU",
cpu_offload: bool = False,
sub_graph: Optional[HeteroData] = None,
sub_graph_edge_attributes: Optional[list[str]] = None,
src_grid_size: int = 0,
dst_grid_size: int = 0,
**kwargs,
@@ -201,7 +202,7 @@ def __init__(
mlp_extra_layers=mlp_extra_layers,
)

self._register_edges(sub_graph, src_grid_size, dst_grid_size, trainable_size)
self._register_edges(sub_graph, sub_graph_edge_attributes, src_grid_size, dst_grid_size, trainable_size)

self.trainable = TrainableTensor(trainable_size=trainable_size, tensor_size=self.edge_attr.shape[0])

@@ -258,6 +259,7 @@ def __init__(
activation: str = "GELU",
cpu_offload: bool = False,
sub_graph: Optional[HeteroData] = None,
sub_graph_edge_attributes: Optional[list[str]] = None,
src_grid_size: int = 0,
dst_grid_size: int = 0,
**kwargs,
@@ -291,7 +293,7 @@ def __init__(
mlp_hidden_ratio=mlp_hidden_ratio,
)

self._register_edges(sub_graph, src_grid_size, dst_grid_size, trainable_size)
self._register_edges(sub_graph, sub_graph_edge_attributes, src_grid_size, dst_grid_size, trainable_size)

self.trainable = TrainableTensor(trainable_size=trainable_size, tensor_size=self.edge_attr.shape[0])

35 changes: 12 additions & 23 deletions src/anemoi/models/models/encoder_processor_decoder.py
@@ -42,6 +42,8 @@ def __init__(
----------
config : DotDict
Job configuration
data_indices : dict
Data indices
graph_data : HeteroData
Graph definition
"""
@@ -61,7 +63,7 @@ def __init__(
# Create trainable tensors
self._create_trainable_attributes()

# Register lat/lon
# Register lat/lon of nodes
self._register_latlon("data", self._graph_name_data)
self._register_latlon("hidden", self._graph_name_hidden)

@@ -120,38 +122,25 @@ def _assert_matching_indices(self, data_indices: dict) -> None:
), f"Model indices must match {self._internal_input_idx} != {self._internal_output_idx}"

def _define_tensor_sizes(self, config: DotDict) -> None:
# Define Sizes of different tensors
self._data_grid_size = self._graph_data[(self._graph_name_data, "to", self._graph_name_data)].ecoords_rad.shape[
0
]
self._hidden_grid_size = self._graph_data[
(self._graph_name_hidden, "to", self._graph_name_hidden)
].hcoords_rad.shape[0]
self._data_grid_size = self._graph_data[self._graph_name_data].num_nodes
self._hidden_grid_size = self._graph_data[self._graph_name_hidden].num_nodes

self.trainable_data_size = config.model.trainable_parameters.data
self.trainable_hidden_size = config.model.trainable_parameters.hidden

def _register_latlon(self, name: str, key: str) -> None:
def _register_latlon(self, name: str, nodes: str) -> None:
"""Register lat/lon buffers.
Parameters
----------
name : str
Name of grid to map
key : str
Key of the grid
Name to store the lat-lon coordinates of the nodes.
nodes : str
Name of nodes to map
"""
self.register_buffer(
f"latlons_{name}",
torch.cat(
[
torch.sin(self._graph_data[(key, "to", key)][f"{key[:1]}coords_rad"]),
torch.cos(self._graph_data[(key, "to", key)][f"{key[:1]}coords_rad"]),
],
dim=-1,
),
persistent=True,
)
coords = self._graph_data[nodes].x
sin_cos_coords = torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
self.register_buffer(f"latlons_{name}", sin_cos_coords, persistent=True)

def _create_trainable_attributes(self) -> None:
"""Create all trainable attributes."""
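In the encoder/processor/decoder model above, grid sizes and coordinates are now read from the node stores of the `HeteroData` graph rather than from self-loop edge stores. A short sketch of the new access pattern, assuming (as the code implies) that node coordinates are stored in radians under the node attribute `x`; the node-set names and sizes here are illustrative:

```python
import torch
from torch_geometric.data import HeteroData

graph = HeteroData()
graph["data"].x = torch.rand((1000, 2)) * 3.14    # lat/lon in radians, illustrative
graph["hidden"].x = torch.rand((100, 2)) * 3.14

# Grid sizes come straight from the node stores.
data_grid_size = graph["data"].num_nodes      # replaces ...ecoords_rad.shape[0]
hidden_grid_size = graph["hidden"].num_nodes  # replaces ...hcoords_rad.shape[0]

# sin/cos embedding of the node coordinates, registered as the latlons_* buffers.
coords = graph["data"].x
sin_cos_coords = torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)  # shape (N, 4)
```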
23 changes: 19 additions & 4 deletions tests/layers/mapper/test_base_mapper.py
@@ -13,6 +13,12 @@


class TestBaseMapper:
"""Test the BaseMapper class."""

NUM_EDGES: int = 100
NUM_SRC_NODES: int = 100
NUM_DST_NODES: int = 200

@pytest.fixture
def mapper_init(self):
in_channels_src: int = 3
@@ -50,7 +56,8 @@ def mapper(self, mapper_init, fake_graph):
out_channels_dst=out_channels_dst,
cpu_offload=cpu_offload,
activation=activation,
sub_graph=fake_graph,
sub_graph=fake_graph[("src", "to", "dst")],
sub_graph_edge_attributes=["edge_attr1", "edge_attr2"],
trainable_size=trainable_size,
)

@@ -71,10 +78,18 @@ def pair_tensor(self, mapper_init):
)

@pytest.fixture
def fake_graph(self):
def fake_graph(self) -> HeteroData:
"""Fake graph."""
graph = HeteroData()
graph.edge_attr = torch.rand((100, 128))
graph.edge_index = torch.randint(0, 100, (2, 100))
graph[("src", "to", "dst")].edge_index = torch.concat(
[
torch.randint(0, self.NUM_SRC_NODES, (1, self.NUM_EDGES)),
torch.randint(0, self.NUM_DST_NODES, (1, self.NUM_EDGES)),
],
axis=0,
)
graph[("src", "to", "dst")].edge_attr1 = torch.rand((self.NUM_EDGES, 1))
graph[("src", "to", "dst")].edge_attr2 = torch.rand((self.NUM_EDGES, 32))
return graph

def test_initialization(self, mapper, mapper_init):