
Commit

chore: move khop functions
JesperDramsch committed May 27, 2024
1 parent 31b3ce5 commit fccc5a4
Showing 4 changed files with 104 additions and 89 deletions.
87 changes: 0 additions & 87 deletions src/anemoi/models/distributed/helpers.py
@@ -8,16 +8,11 @@
#

from typing import Optional
from typing import Union

import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed.distributed_c10d import ProcessGroup
from torch_geometric.typing import Adj
from torch_geometric.utils import bipartite_subgraph
from torch_geometric.utils import k_hop_subgraph
from torch_geometric.utils import mask_to_index


def shard_heads(input_: Tensor, shapes: list, mgroup: ProcessGroup) -> Tensor:
@@ -164,88 +159,6 @@ def sync_tensor(input_: Tensor, dim: int, shapes: tuple, mgroup: ProcessGroup) -
return _SyncParallelSection.apply(input_, dim, shapes, mgroup)


def get_k_hop_edges(nodes: Tensor, edge_attr: Tensor, edge_index: Adj, num_hops: int = 1) -> tuple[Tensor, Adj]:
    """Return the k-hop subgraph around the given destination nodes (1 hop by default).

    Parameters
    ----------
    nodes : Tensor
        destination nodes
    edge_attr : Tensor
        edge attributes
    edge_index : Adj
        edge index
    num_hops : int, optional
        number of required hops, by default 1

    Returns
    -------
    tuple[Tensor, Adj]
        edge attributes and edge index of the k-hop subgraph
    """
_, edge_index_k, _, edge_mask_k = k_hop_subgraph(
node_idx=nodes, num_hops=num_hops, edge_index=edge_index, directed=True
)

return edge_attr[mask_to_index(edge_mask_k)], edge_index_k


def sort_edges_1hop(
    num_nodes: Union[int, tuple[int, int]],
    edge_attr: Tensor,
    edge_index: Adj,
    mgroup: Optional[ProcessGroup] = None,
) -> tuple[Tensor, Adj, list, list]:
    """Rearranges edges into 1-hop neighbourhoods for sharding across GPUs.

    Parameters
    ----------
    num_nodes : Union[int, tuple[int, int]]
        number of (target) nodes in the graph
    edge_attr : Tensor
        edge attributes
    edge_index : Adj
        edge index
    mgroup : ProcessGroup, optional
        model communication group

    Returns
    -------
    tuple[Tensor, Adj, list, list]
        edge attributes of the sorted edges, edge indices sorted into 1-hop neighbourhoods,
        shapes of the edge attributes for partitioning between GPUs,
        shapes of the edge indices for partitioning between GPUs
    """
if mgroup:
num_chunks = dist.get_world_size(group=mgroup)

if isinstance(num_nodes, int):
node_chunks = torch.arange(num_nodes, device=edge_index.device).tensor_split(num_chunks)
else:
nodes_src = torch.arange(num_nodes[0], device=edge_index.device)
node_chunks = torch.arange(num_nodes[1], device=edge_index.device).tensor_split(num_chunks)

edge_index_list = []
edge_attr_list = []
for node_chunk in node_chunks:
if isinstance(num_nodes, int):
edge_attr_chunk, edge_index_chunk = get_k_hop_edges(node_chunk, edge_attr, edge_index)
else:
edge_index_chunk, edge_attr_chunk = bipartite_subgraph(
(nodes_src, node_chunk),
edge_index,
edge_attr,
size=(num_nodes[0], num_nodes[1]),
)
edge_index_list.append(edge_index_chunk)
edge_attr_list.append(edge_attr_chunk)
edge_index_shapes = [x.shape for x in edge_index_list]
edge_attr_shapes = [x.shape for x in edge_attr_list]

return torch.cat(edge_attr_list, dim=0), torch.cat(edge_index_list, dim=1), edge_attr_shapes, edge_index_shapes

return edge_attr, edge_index, [], []


def reduce_shard_tensor(input_: Tensor, dim: int, shapes: tuple, mgroup: ProcessGroup) -> Tensor:
"""Reduces and then shards tensor.
102 changes: 102 additions & 0 deletions src/anemoi/models/distributed/khop_edges.py
@@ -0,0 +1,102 @@
# (C) Copyright 2024 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

from typing import Optional
from typing import Union

import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed.distributed_c10d import ProcessGroup
from torch_geometric.typing import Adj
from torch_geometric.utils import bipartite_subgraph
from torch_geometric.utils import k_hop_subgraph
from torch_geometric.utils import mask_to_index


def get_k_hop_edges(nodes: Tensor, edge_attr: Tensor, edge_index: Adj, num_hops: int = 1) -> tuple[Tensor, Adj]:
    """Return the k-hop subgraph around the given destination nodes (1 hop by default).

    Parameters
    ----------
    nodes : Tensor
        destination nodes
    edge_attr : Tensor
        edge attributes
    edge_index : Adj
        edge index
    num_hops : int, optional
        number of required hops, by default 1

    Returns
    -------
    tuple[Tensor, Adj]
        edge attributes and edge index of the k-hop subgraph
    """
_, edge_index_k, _, edge_mask_k = k_hop_subgraph(
node_idx=nodes, num_hops=num_hops, edge_index=edge_index, directed=True
)

return edge_attr[mask_to_index(edge_mask_k)], edge_index_k


def sort_edges_1hop(
    num_nodes: Union[int, tuple[int, int]],
    edge_attr: Tensor,
    edge_index: Adj,
    mgroup: Optional[ProcessGroup] = None,
) -> tuple[Tensor, Adj, list, list]:
    """Rearranges edges into 1-hop neighbourhoods for sharding across GPUs.

    Parameters
    ----------
    num_nodes : Union[int, tuple[int, int]]
        number of (target) nodes in the graph
    edge_attr : Tensor
        edge attributes
    edge_index : Adj
        edge index
    mgroup : ProcessGroup, optional
        model communication group

    Returns
    -------
    tuple[Tensor, Adj, list, list]
        edge attributes of the sorted edges, edge indices sorted into 1-hop neighbourhoods,
        shapes of the edge attributes for partitioning between GPUs,
        shapes of the edge indices for partitioning between GPUs
    """
if mgroup:
num_chunks = dist.get_world_size(group=mgroup)

if isinstance(num_nodes, int):
node_chunks = torch.arange(num_nodes, device=edge_index.device).tensor_split(num_chunks)
else:
nodes_src = torch.arange(num_nodes[0], device=edge_index.device)
node_chunks = torch.arange(num_nodes[1], device=edge_index.device).tensor_split(num_chunks)

edge_index_list = []
edge_attr_list = []
for node_chunk in node_chunks:
if isinstance(num_nodes, int):
edge_attr_chunk, edge_index_chunk = get_k_hop_edges(node_chunk, edge_attr, edge_index)
else:
edge_index_chunk, edge_attr_chunk = bipartite_subgraph(
(nodes_src, node_chunk),
edge_index,
edge_attr,
size=(num_nodes[0], num_nodes[1]),
)
edge_index_list.append(edge_index_chunk)
edge_attr_list.append(edge_attr_chunk)
edge_index_shapes = [x.shape for x in edge_index_list]
edge_attr_shapes = [x.shape for x in edge_attr_list]

return torch.cat(edge_attr_list, dim=0), torch.cat(edge_index_list, dim=1), edge_attr_shapes, edge_index_shapes

return edge_attr, edge_index, [], []
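
For orientation, a minimal usage sketch of the relocated helpers on a toy graph. It assumes this commit of anemoi-models plus torch and torch_geometric are installed; the graph and attribute tensors are invented purely for illustration.

import torch

from anemoi.models.distributed.khop_edges import get_k_hop_edges
from anemoi.models.distributed.khop_edges import sort_edges_1hop

# Toy directed graph: 4 nodes, 5 edges, one scalar attribute per edge (made-up values).
edge_index = torch.tensor([[0, 1, 1, 2, 3], [1, 2, 3, 3, 0]])
edge_attr = torch.arange(5, dtype=torch.float32).unsqueeze(-1)

# Edge attributes and edge index of the 1-hop subgraph around destination nodes {2, 3}.
attr_k, index_k = get_k_hop_edges(torch.tensor([2, 3]), edge_attr, edge_index, num_hops=1)

# Without a model communication group, sort_edges_1hop is a passthrough:
# it returns the edges unchanged together with empty shape lists.
attr, index, attr_shapes, index_shapes = sort_edges_1hop(4, edge_attr, edge_index, mgroup=None)
assert attr_shapes == [] and index_shapes == []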
2 changes: 1 addition & 1 deletion src/anemoi/models/layers/mapper.py
@@ -24,7 +24,7 @@
from anemoi.models.distributed.helpers import gather_tensor
from anemoi.models.distributed.helpers import get_shape_shards
from anemoi.models.distributed.helpers import shard_tensor
from anemoi.models.distributed.helpers import sort_edges_1hop
from anemoi.models.distributed.khop_edges import sort_edges_1hop
from anemoi.models.layers.block import GraphConvMapperBlock
from anemoi.models.layers.block import GraphTransformerMapperBlock
from anemoi.models.layers.graph import TrainableTensor
2 changes: 1 addition & 1 deletion src/anemoi/models/layers/processor.py
@@ -20,7 +20,7 @@
from anemoi.models.distributed.helpers import change_channels_in_shape
from anemoi.models.distributed.helpers import get_shape_shards
from anemoi.models.distributed.helpers import shard_tensor
from anemoi.models.distributed.helpers import sort_edges_1hop
from anemoi.models.distributed.khop_edges import sort_edges_1hop
from anemoi.models.layers.chunk import GNNProcessorChunk
from anemoi.models.layers.chunk import GraphTransformerProcessorChunk
from anemoi.models.layers.chunk import TransformerProcessorChunk
