From fccc5a416362e0811c5eb3ca14169aed7022aacc Mon Sep 17 00:00:00 2001
From: Jesper Dramsch
Date: Mon, 27 May 2024 12:34:28 +0000
Subject: [PATCH] chore: move khop functions

---
 src/anemoi/models/distributed/helpers.py    |  87 -----------------
 src/anemoi/models/distributed/khop_edges.py | 102 ++++++++++++++++++++
 src/anemoi/models/layers/mapper.py          |   2 +-
 src/anemoi/models/layers/processor.py       |   2 +-
 4 files changed, 104 insertions(+), 89 deletions(-)
 create mode 100644 src/anemoi/models/distributed/khop_edges.py

diff --git a/src/anemoi/models/distributed/helpers.py b/src/anemoi/models/distributed/helpers.py
index 0d53e2ee..aaee73b0 100644
--- a/src/anemoi/models/distributed/helpers.py
+++ b/src/anemoi/models/distributed/helpers.py
@@ -8,16 +8,11 @@
 #

 from typing import Optional
-from typing import Union

 import torch
 import torch.distributed as dist
 from torch import Tensor
 from torch.distributed.distributed_c10d import ProcessGroup
-from torch_geometric.typing import Adj
-from torch_geometric.utils import bipartite_subgraph
-from torch_geometric.utils import k_hop_subgraph
-from torch_geometric.utils import mask_to_index


 def shard_heads(input_: Tensor, shapes: list, mgroup: ProcessGroup) -> Tensor:
@@ -164,88 +159,6 @@ def sync_tensor(input_: Tensor, dim: int, shapes: tuple, mgroup: ProcessGroup) -
     return _SyncParallelSection.apply(input_, dim, shapes, mgroup)


-def get_k_hop_edges(nodes: Tensor, edge_attr: Tensor, edge_index: Adj, num_hops: int = 1) -> tuple[Adj, Tensor]:
-    """Return 1 hop subgraph.
-
-    Parameters
-    ----------
-    nodes : Tensor
-        destination nodes
-    edge_attr : Tensor
-        edge attributes
-    edge_index : Adj
-        edge index
-    num_hops: int, Optional, by default 1
-        number of required hops
-
-    Returns
-    -------
-    tuple[Adj, Tensor]
-        K-hop subgraph of edge index and edge attributes
-    """
-    _, edge_index_k, _, edge_mask_k = k_hop_subgraph(
-        node_idx=nodes, num_hops=num_hops, edge_index=edge_index, directed=True
-    )
-
-    return edge_attr[mask_to_index(edge_mask_k)], edge_index_k
-
-
-def sort_edges_1hop(
-    num_nodes: Union[int, tuple[int, int]],
-    edge_attr: Tensor,
-    edge_index: Adj,
-    mgroup: Optional[ProcessGroup] = None,
-) -> tuple[Adj, Tensor, list, list]:
-    """Rearanges edges into 1 hop neighbourhoods for sharding across GPUs.
-
-    Parameters
-    ----------
-    num_nodes : Union[int, tuple[int, int]]
-        Number of (target) nodes in Graph
-    edge_attr : Tensor
-        edge attributes
-    edge_index : Adj
-        edge index
-    mgroup : ProcessGroup
-        model communication group
-
-    Returns
-    -------
-    tuple[Adj, Tensor, list, list]
-        edges sorted according to k hop neigh., edge attributes of sorted edges,
-        shapes of edge indices for partitioning between GPUs, shapes of edge attr for partitioning between GPUs
-    """
-    if mgroup:
-        num_chunks = dist.get_world_size(group=mgroup)
-
-        if isinstance(num_nodes, int):
-            node_chunks = torch.arange(num_nodes, device=edge_index.device).tensor_split(num_chunks)
-        else:
-            nodes_src = torch.arange(num_nodes[0], device=edge_index.device)
-            node_chunks = torch.arange(num_nodes[1], device=edge_index.device).tensor_split(num_chunks)
-
-        edge_index_list = []
-        edge_attr_list = []
-        for node_chunk in node_chunks:
-            if isinstance(num_nodes, int):
-                edge_attr_chunk, edge_index_chunk = get_k_hop_edges(node_chunk, edge_attr, edge_index)
-            else:
-                edge_index_chunk, edge_attr_chunk = bipartite_subgraph(
-                    (nodes_src, node_chunk),
-                    edge_index,
-                    edge_attr,
-                    size=(num_nodes[0], num_nodes[1]),
-                )
-            edge_index_list.append(edge_index_chunk)
-            edge_attr_list.append(edge_attr_chunk)
-        edge_index_shapes = [x.shape for x in edge_index_list]
-        edge_attr_shapes = [x.shape for x in edge_attr_list]
-
-        return torch.cat(edge_attr_list, dim=0), torch.cat(edge_index_list, dim=1), edge_attr_shapes, edge_index_shapes
-
-    return edge_attr, edge_index, [], []
-
-
 def reduce_shard_tensor(input_: Tensor, dim: int, shapes: tuple, mgroup: ProcessGroup) -> Tensor:
     """Reduces and then shards tensor.
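Note (not part of the patch): the core of the removed sort_edges_1hop is the per-rank chunking of target nodes via torch.Tensor.tensor_split; each rank then keeps only the edges incident to its chunk. A minimal single-process sketch of that mechanism, with toy sizes assumed:

import torch

# Split the target nodes evenly across a hypothetical model-parallel group,
# exactly as sort_edges_1hop does before extracting per-chunk subgraphs.
num_nodes, world_size = 10, 4  # toy values, not from the patch
node_chunks = torch.arange(num_nodes).tensor_split(world_size)
# tensor_split tolerates uneven divisions:
# [tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7]), tensor([8, 9])]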
diff --git a/src/anemoi/models/distributed/khop_edges.py b/src/anemoi/models/distributed/khop_edges.py
new file mode 100644
index 00000000..5b4bd815
--- /dev/null
+++ b/src/anemoi/models/distributed/khop_edges.py
@@ -0,0 +1,102 @@
+# (C) Copyright 2024 ECMWF.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+#
+
+from typing import Optional
+from typing import Union
+
+import torch
+import torch.distributed as dist
+from torch import Tensor
+from torch.distributed.distributed_c10d import ProcessGroup
+from torch_geometric.typing import Adj
+from torch_geometric.utils import bipartite_subgraph
+from torch_geometric.utils import k_hop_subgraph
+from torch_geometric.utils import mask_to_index
+
+
+def get_k_hop_edges(nodes: Tensor, edge_attr: Tensor, edge_index: Adj, num_hops: int = 1) -> tuple[Tensor, Adj]:
+    """Return the k-hop subgraph (1 hop by default).
+
+    Parameters
+    ----------
+    nodes : Tensor
+        destination nodes
+    edge_attr : Tensor
+        edge attributes
+    edge_index : Adj
+        edge index
+    num_hops : int, optional
+        number of required hops, by default 1
+
+    Returns
+    -------
+    tuple[Tensor, Adj]
+        edge attributes and edge index of the k-hop subgraph
+    """
+    _, edge_index_k, _, edge_mask_k = k_hop_subgraph(
+        node_idx=nodes, num_hops=num_hops, edge_index=edge_index, directed=True
+    )
+
+    return edge_attr[mask_to_index(edge_mask_k)], edge_index_k
+
+
+def sort_edges_1hop(
+    num_nodes: Union[int, tuple[int, int]],
+    edge_attr: Tensor,
+    edge_index: Adj,
+    mgroup: Optional[ProcessGroup] = None,
+) -> tuple[Tensor, Adj, list, list]:
+    """Rearranges edges into 1-hop neighbourhoods for sharding across GPUs.
+
+    Parameters
+    ----------
+    num_nodes : Union[int, tuple[int, int]]
+        Number of (target) nodes in Graph
+    edge_attr : Tensor
+        edge attributes
+    edge_index : Adj
+        edge index
+    mgroup : ProcessGroup, optional
+        model communication group
+
+    Returns
+    -------
+    tuple[Tensor, Adj, list, list]
+        edge attributes of sorted edges, edge indices sorted into 1-hop neighbourhoods,
+        shapes of edge attributes for partitioning between GPUs, shapes of edge indices for partitioning between GPUs
+    """
+    if mgroup:
+        num_chunks = dist.get_world_size(group=mgroup)
+
+        if isinstance(num_nodes, int):
+            node_chunks = torch.arange(num_nodes, device=edge_index.device).tensor_split(num_chunks)
+        else:
+            nodes_src = torch.arange(num_nodes[0], device=edge_index.device)
+            node_chunks = torch.arange(num_nodes[1], device=edge_index.device).tensor_split(num_chunks)
+
+        edge_index_list = []
+        edge_attr_list = []
+        for node_chunk in node_chunks:
+            if isinstance(num_nodes, int):
+                edge_attr_chunk, edge_index_chunk = get_k_hop_edges(node_chunk, edge_attr, edge_index)
+            else:
+                edge_index_chunk, edge_attr_chunk = bipartite_subgraph(
+                    (nodes_src, node_chunk),
+                    edge_index,
+                    edge_attr,
+                    size=(num_nodes[0], num_nodes[1]),
+                )
+            edge_index_list.append(edge_index_chunk)
+            edge_attr_list.append(edge_attr_chunk)
+        edge_index_shapes = [x.shape for x in edge_index_list]
+        edge_attr_shapes = [x.shape for x in edge_attr_list]
+
+        return torch.cat(edge_attr_list, dim=0), torch.cat(edge_index_list, dim=1), edge_attr_shapes, edge_index_shapes
+
+    return edge_attr, edge_index, [], []
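Note (not part of the patch): a usage sketch of the functions at their new location; the toy ring graph and attribute shapes below are assumptions for illustration only.

import torch

from anemoi.models.distributed.khop_edges import get_k_hop_edges
from anemoi.models.distributed.khop_edges import sort_edges_1hop

# Toy directed ring graph: 4 nodes, 4 edges, 2 attributes per edge.
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
edge_attr = torch.randn(4, 2)

# Attributes and indices of the edges reaching destination node 2 within one hop.
attr_k, index_k = get_k_hop_edges(torch.tensor([2]), edge_attr, edge_index)

# Without a model communication group, sort_edges_1hop is a passthrough
# and returns empty shape lists.
attr, index, attr_shapes, index_shapes = sort_edges_1hop(4, edge_attr, edge_index)
assert attr_shapes == [] and index_shapes == []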
diff --git a/src/anemoi/models/layers/mapper.py b/src/anemoi/models/layers/mapper.py
index 7cf71d69..5d77f321 100644
--- a/src/anemoi/models/layers/mapper.py
+++ b/src/anemoi/models/layers/mapper.py
@@ -24,7 +24,7 @@
 from anemoi.models.distributed.helpers import gather_tensor
 from anemoi.models.distributed.helpers import get_shape_shards
 from anemoi.models.distributed.helpers import shard_tensor
-from anemoi.models.distributed.helpers import sort_edges_1hop
+from anemoi.models.distributed.khop_edges import sort_edges_1hop
 from anemoi.models.layers.block import GraphConvMapperBlock
 from anemoi.models.layers.block import GraphTransformerMapperBlock
 from anemoi.models.layers.graph import TrainableTensor
diff --git a/src/anemoi/models/layers/processor.py b/src/anemoi/models/layers/processor.py
index c1f6efc8..28bedbbc 100644
--- a/src/anemoi/models/layers/processor.py
+++ b/src/anemoi/models/layers/processor.py
@@ -20,7 +20,7 @@
 from anemoi.models.distributed.helpers import change_channels_in_shape
 from anemoi.models.distributed.helpers import get_shape_shards
 from anemoi.models.distributed.helpers import shard_tensor
-from anemoi.models.distributed.helpers import sort_edges_1hop
+from anemoi.models.distributed.khop_edges import sort_edges_1hop
 from anemoi.models.layers.chunk import GNNProcessorChunk
 from anemoi.models.layers.chunk import GraphTransformerProcessorChunk
 from anemoi.models.layers.chunk import TransformerProcessorChunk
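Note (not part of the patch): for bipartite graphs (tuple num_nodes, as in the mapper between two grids) sort_edges_1hop delegates to torch_geometric's bipartite_subgraph rather than k_hop_subgraph. A standalone sketch of that branch, with a toy bipartite graph assumed:

import torch
from torch_geometric.utils import bipartite_subgraph

# 3 source nodes, 2 target nodes, 3 source->target edges.
edge_index = torch.tensor([[0, 1, 2], [0, 1, 1]])
edge_attr = torch.randn(3, 4)

nodes_src = torch.arange(3)
node_chunk = torch.tensor([1])  # target nodes owned by one hypothetical rank
index_chunk, attr_chunk = bipartite_subgraph(
    (nodes_src, node_chunk),
    edge_index,
    edge_attr,
    size=(3, 2),
)
# Only the edges into target node 1 survive: (1 -> 1) and (2 -> 1).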