Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
Include changes from aifs-mono: feature/graph_refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
JPXKQX committed May 30, 2024
1 parent 6f053a8 commit c27d597
Show file tree
Hide file tree
Showing 9 changed files with 127 additions and 144 deletions.
4 changes: 2 additions & 2 deletions src/anemoi/models/interface/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
import torch
from anemoi.utils.config import DotConfig
from hydra.utils import instantiate
from torch_geometric.data import HeteroData

from anemoi.models.data_indices.collection import IndexCollection
from anemoi.models.models.encoder_processor_decoder import AnemoiModelEncProcDec
from anemoi.models.preprocessing import Processors

Expand All @@ -22,7 +22,7 @@ class AnemoiModelInterface(torch.nn.Module):
"""Anemoi model on torch level."""

def __init__(
self, *, config: DotConfig, graph_data: HeteroData, statistics: dict, data_indices: dict, metadata: dict
self, *, config: DotConfig, graph_data: dict, statistics: dict, data_indices: IndexCollection, metadata: dict
) -> None:
super().__init__()
self.config = config
Expand Down
23 changes: 11 additions & 12 deletions src/anemoi/models/layers/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from torch import nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import offload_wrapper
from torch.distributed.distributed_c10d import ProcessGroup
from torch_geometric.data import HeteroData
from torch_geometric.typing import Adj
from torch_geometric.typing import PairTensor

Expand Down Expand Up @@ -113,12 +112,12 @@ def pre_process(self, x, shard_shapes, model_comm_group=None):


class GraphEdgeMixin:
def _register_edges(self, sub_graph: HeteroData, src_size: int, dst_size: int, trainable_size: int) -> None:
def _register_edges(self, sub_graph: dict, src_size: int, dst_size: int, trainable_size: int) -> None:
"""Register edge dim, attr, index_base, and increment.
Parameters
----------
sub_graph : HeteroData
sub_graph : dict
Sub graph of the full structure
src_size : int
Source size
Expand All @@ -127,9 +126,9 @@ def _register_edges(self, sub_graph: HeteroData, src_size: int, dst_size: int, t
trainable_size : int
Trainable tensor size
"""
self.edge_dim = sub_graph.edge_attr.shape[1] + trainable_size
self.register_buffer("edge_attr", sub_graph.edge_attr, persistent=False)
self.register_buffer("edge_index_base", sub_graph.edge_index, persistent=False)
self.edge_dim = sub_graph["edge_attr"].shape[1] + trainable_size
self.register_buffer("edge_attr", sub_graph["edge_attr"], persistent=False)
self.register_buffer("edge_index_base", sub_graph["edge_index"], persistent=False)
self.register_buffer(
"edge_inc", torch.from_numpy(np.asarray([[src_size], [dst_size]], dtype=np.int64)), persistent=True
)
Expand Down Expand Up @@ -173,7 +172,7 @@ def __init__(
activation: str = "GELU",
num_heads: int = 16,
mlp_hidden_ratio: int = 4,
sub_graph: Optional[HeteroData] = None,
sub_graph: Optional[dict] = None,
src_grid_size: int = 0,
dst_grid_size: int = 0,
) -> None:
Expand Down Expand Up @@ -273,7 +272,7 @@ def __init__(
activation: str = "GELU",
num_heads: int = 16,
mlp_hidden_ratio: int = 4,
sub_graph: Optional[HeteroData] = None,
sub_graph: Optional[dict] = None,
src_grid_size: int = 0,
dst_grid_size: int = 0,
) -> None:
Expand Down Expand Up @@ -344,7 +343,7 @@ def __init__(
activation: str = "GELU",
num_heads: int = 16,
mlp_hidden_ratio: int = 4,
sub_graph: Optional[HeteroData] = None,
sub_graph: Optional[dict] = None,
src_grid_size: int = 0,
dst_grid_size: int = 0,
) -> None:
Expand Down Expand Up @@ -414,7 +413,7 @@ def __init__(
cpu_offload: bool = False,
activation: str = "SiLU",
mlp_extra_layers: int = 0,
sub_graph: Optional[HeteroData] = None,
sub_graph: Optional[dict] = None,
src_grid_size: int = 0,
dst_grid_size: int = 0,
) -> None:
Expand Down Expand Up @@ -517,7 +516,7 @@ def __init__(
cpu_offload: bool = False,
activation: str = "SiLU",
mlp_extra_layers: int = 0,
sub_graph: Optional[HeteroData] = None,
sub_graph: Optional[dict] = None,
src_grid_size: int = 0,
dst_grid_size: int = 0,
) -> None:
Expand Down Expand Up @@ -601,7 +600,7 @@ def __init__(
cpu_offload: bool = False,
activation: str = "SiLU",
mlp_extra_layers: int = 0,
sub_graph: Optional[HeteroData] = None,
sub_graph: Optional[dict] = None,
src_grid_size: int = 0,
dst_grid_size: int = 0,
) -> None:
Expand Down
5 changes: 2 additions & 3 deletions src/anemoi/models/layers/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import offload_wrapper
from torch.distributed.distributed_c10d import ProcessGroup
from torch.utils.checkpoint import checkpoint
from torch_geometric.data import HeteroData

from anemoi.models.distributed.graph import shard_tensor
from anemoi.models.distributed.khop_edges import sort_edges_1hop
Expand Down Expand Up @@ -170,7 +169,7 @@ def __init__(
mlp_extra_layers: int = 0,
activation: str = "SiLU",
cpu_offload: bool = False,
sub_graph: Optional[HeteroData] = None,
sub_graph: Optional[dict] = None,
src_grid_size: int = 0,
dst_grid_size: int = 0,
**kwargs,
Expand Down Expand Up @@ -257,7 +256,7 @@ def __init__(
mlp_hidden_ratio: int = 4,
activation: str = "GELU",
cpu_offload: bool = False,
sub_graph: Optional[HeteroData] = None,
sub_graph: Optional[dict] = None,
src_grid_size: int = 0,
dst_grid_size: int = 0,
**kwargs,
Expand Down
Loading

0 comments on commit c27d597

Please sign in to comment.