
Commit

refactor: split distributed files, add licenses
Co-authored-by: Simon Lang <[email protected]>
Co-authored-by: Helen Theissen <[email protected]>
3 people committed May 27, 2024
1 parent fccc5a4 commit 6148711
Showing 11 changed files with 641 additions and 521 deletions.
253 changes: 253 additions & 0 deletions src/anemoi/models/distributed/graph.py
@@ -0,0 +1,253 @@
# (C) Copyright 2024 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#


import torch
from torch import Tensor
from torch.distributed.distributed_c10d import ProcessGroup

from anemoi.models.distributed.primitives import _gather
from anemoi.models.distributed.primitives import _reduce
from anemoi.models.distributed.primitives import _split


def shard_tensor(
    input_: Tensor, dim: int, shapes: tuple, mgroup: ProcessGroup, gather_in_backward: bool = True
) -> Tensor:
    """Shard tensor.

    Keeps only part of the tensor that is relevant for the current rank.

    Parameters
    ----------
    input_ : Tensor
        Input
    dim : int
        dimension along which to shard
    shapes : tuple
        Shapes of sharded Tensors
    mgroup : ProcessGroup
        model communication group
    gather_in_backward : bool
        perform gather in backward, default True

    Returns
    -------
    Tensor
        Sharded tensor.
    """
    return _ShardParallelSection.apply(input_, dim, shapes, gather_in_backward, mgroup)
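
# Usage sketch for shard_tensor (illustrative values; assumes torch.distributed is
# initialised and ``model_comm_group`` is a hypothetical 2-rank model-parallel ProcessGroup):
#
#     x = torch.randn(4, 8, 16)                     # full tensor, replicated on each rank
#     shard_shapes = ((4, 4, 16), (4, 4, 16))       # assumed per-rank shapes along dim=1
#     x_local = shard_tensor(x, dim=1, shapes=shard_shapes, mgroup=model_comm_group)
#     # each rank keeps only its (4, 4, 16) slice; gradients are gathered in backward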


def gather_tensor(input_: Tensor, dim: int, shapes: tuple, mgroup: ProcessGroup) -> Tensor:
    """Gather tensor.

    Gathers tensor shards from ranks.

    Parameters
    ----------
    input_ : Tensor
        Input
    dim : int
        dimension along which to gather
    shapes : tuple
        Shapes of sharded Tensors
    mgroup : ProcessGroup
        model communication group

    Returns
    -------
    Tensor
        Gathered tensor.
    """
    return _GatherParallelSection.apply(input_, dim, shapes, mgroup)
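
# Usage sketch for gather_tensor (hypothetical names, continuing the shard_tensor
# sketch above): the inverse operation, concatenating the shards along ``dim``:
#
#     x_full = gather_tensor(x_local, dim=1, shapes=shard_shapes, mgroup=model_comm_group)
#     # x_full has shape (4, 8, 16) again on every rank; backward splits the gradient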


def reduce_tensor(input_: Tensor, mgroup: ProcessGroup) -> Tensor:
    """Reduce tensor.

    Reduces tensor across ranks.

    Parameters
    ----------
    input_ : Tensor
        Input
    mgroup : ProcessGroup
        model communication group

    Returns
    -------
    Tensor
        Reduced tensor.
    """
    return _ReduceParallelSection.apply(input_, mgroup)
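
# Usage sketch for reduce_tensor (hypothetical names): an all-reduce across the model
# communication group, e.g. to combine partial results computed on each rank:
#
#     partial = some_partial_layer(x_local)         # hypothetical per-rank partial result
#     total = reduce_tensor(partial, mgroup=model_comm_group)
#     # every rank now holds the all-reduced result; backward passes the gradient through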


def sync_tensor(input_: Tensor, dim: int, shapes: tuple, mgroup: ProcessGroup) -> Tensor:
    """Sync tensor.

    Perform a gather in the forward pass and an allreduce followed by a split in the backward pass.

    Parameters
    ----------
    input_ : Tensor
        Input
    dim : int
        dimension along which to gather
    shapes : tuple
        Shapes of sharded Tensors
    mgroup : ProcessGroup
        model communication group

    Returns
    -------
    Tensor
        Synced tensor.
    """
    return _SyncParallelSection.apply(input_, dim, shapes, mgroup)
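
# Usage sketch for sync_tensor (hypothetical names): the forward pass gathers the shards,
# while the backward pass all-reduces and re-splits the gradient, so a layer that needs
# the full tensor can sit between sharded sections:
#
#     x_full = sync_tensor(x_local, dim=1, shapes=shard_shapes, mgroup=model_comm_group)
#     y = full_tensor_layer(x_full)                 # hypothetical layer acting on the full tensor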


def reduce_shard_tensor(input_: Tensor, dim: int, shapes: tuple, mgroup: ProcessGroup) -> Tensor:
    """Reduces and then shards tensor.

    Perform an allreduce followed by a split in the forward pass and a gather in the backward pass.

    Parameters
    ----------
    input_ : Tensor
        Input
    dim : int
        dimension along which to split the reduced tensor
    shapes : tuple
        Shapes of sharded Tensors
    mgroup : ProcessGroup
        model communication group

    Returns
    -------
    Tensor
        Reduced sharded tensor.
    """
    return _ReduceShardParallelSection.apply(input_, dim, shapes, mgroup)
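
# Usage sketch for reduce_shard_tensor (hypothetical names): the forward pass all-reduces
# and then splits, so each rank ends up with its shard of the reduced tensor, while the
# backward pass gathers the gradient shards:
#
#     y_local = reduce_shard_tensor(y, dim=1, shapes=shard_shapes, mgroup=model_comm_group)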


class _SyncParallelSection(torch.autograd.Function):
    """Sync the input from the parallel section."""

    @staticmethod
    def forward(ctx, input_, dim_, shapes_, mgroup_):
        ctx.dim = dim_
        ctx.comm_group = mgroup_
        ctx.shapes = shapes_
        if mgroup_:
            return _gather(input_, dim_, shapes_, group=mgroup_)
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.comm_group:
            grad_output = _reduce(grad_output, group=ctx.comm_group)
            return (
                _split(grad_output, ctx.dim, ctx.shapes, group=ctx.comm_group),
                None,
                None,
                None,
            )
        return grad_output, None, None, None


class _ReduceShardParallelSection(torch.autograd.Function):
    """All-reduce and shard the input from the parallel section."""

    @staticmethod
    def forward(ctx, input_, dim_, shapes_, mgroup_):
        ctx.dim = dim_
        ctx.comm_group = mgroup_
        ctx.shapes = shapes_
        if mgroup_:
            input_ = _reduce(input_, group=mgroup_)
            return _split(input_, dim_, shapes_, group=mgroup_)
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.comm_group:
            return (
                _gather(grad_output, ctx.dim, ctx.shapes, group=ctx.comm_group),
                None,
                None,
                None,
            )
        return grad_output, None, None, None


class _ShardParallelSection(torch.autograd.Function):
    """Split the input and keep only the chunk relevant to the current rank."""

    @staticmethod
    def forward(ctx, input_, dim_, shapes_, gather_in_backward_, mgroup_):
        ctx.dim = dim_
        ctx.comm_group = mgroup_
        ctx.shapes = shapes_
        ctx.gather_in_backward = gather_in_backward_
        if mgroup_:
            return _split(input_, dim_, shapes_, group=mgroup_)
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.comm_group:
            return (
                _gather(
                    grad_output, ctx.dim, ctx.shapes, gather_in_backward=ctx.gather_in_backward, group=ctx.comm_group
                ),
                None,
                None,
                None,
                None,
            )
        return grad_output, None, None, None, None


class _GatherParallelSection(torch.autograd.Function):
    """Gather the input from the parallel section and concatenate."""

    @staticmethod
    def forward(ctx, input_, dim_, shapes_, mgroup_):
        ctx.dim = dim_
        ctx.comm_group = mgroup_
        ctx.shapes = shapes_
        if mgroup_:
            return _gather(input_, dim_, shapes_, group=mgroup_)
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.comm_group:
            return (
                _split(grad_output, ctx.dim, ctx.shapes, group=ctx.comm_group),
                None,
                None,
                None,
            )
        return grad_output, None, None, None


class _ReduceParallelSection(torch.autograd.Function):
    """All-reduce the input from the parallel section."""

    @staticmethod
    def forward(ctx, input_, mgroup_):
        if mgroup_:
            return _reduce(input_, group=mgroup_)
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None
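
# Note: each autograd Function above pairs a forward collective with its adjoint in the
# backward pass (split <-> gather, gather <-> all-reduce + split, all-reduce <-> identity),
# and every one of them is a no-op when ``mgroup`` is not provided (falsy).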