Commit
refactor: split distributed files, add licenses
Co-authored-by: Simon Lang <[email protected]>
Co-authored-by: Helen Theissen <[email protected]>
1 parent fccc5a4, commit 6148711
Showing 11 changed files with 641 additions and 521 deletions.
@@ -0,0 +1,253 @@
# (C) Copyright 2024 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

import torch
from torch import Tensor
from torch.distributed.distributed_c10d import ProcessGroup

from anemoi.models.distributed.primitives import _gather
from anemoi.models.distributed.primitives import _reduce
from anemoi.models.distributed.primitives import _split

def shard_tensor(
    input_: Tensor, dim: int, shapes: tuple, mgroup: ProcessGroup, gather_in_backward: bool = True
) -> Tensor:
    """Shard tensor.

    Keeps only the part of the tensor that is relevant for the current rank.

    Parameters
    ----------
    input_ : Tensor
        Input
    dim : int
        dimension along which to shard
    shapes : tuple
        Shapes of sharded Tensors
    mgroup : ProcessGroup
        model communication group
    gather_in_backward : bool
        perform gather in backward, default True

    Returns
    -------
    Tensor
        Sharded tensor.
    """
    return _ShardParallelSection.apply(input_, dim, shapes, gather_in_backward, mgroup)
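
# Illustrative sketch, not part of this commit: given a model-parallel process
# group and per-rank shard shapes computed elsewhere, shard_tensor keeps only
# this rank's slice along `dim`, while gradients are gathered back across the
# group in the backward pass (gather_in_backward=True):
#
#   local = shard_tensor(x, dim=-1, shapes=shard_shapes, mgroup=model_comm_group)
#
# Here `x`, `shard_shapes` and `model_comm_group` are hypothetical names.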


def gather_tensor(input_: Tensor, dim: int, shapes: tuple, mgroup: ProcessGroup) -> Tensor:
    """Gather tensor.

    Gathers tensor shards from ranks.

    Parameters
    ----------
    input_ : Tensor
        Input
    dim : int
        dimension along which to gather
    shapes : tuple
        Shapes of sharded Tensors
    mgroup : ProcessGroup
        model communication group

    Returns
    -------
    Tensor
        Gathered tensor.
    """
    return _GatherParallelSection.apply(input_, dim, shapes, mgroup)


def reduce_tensor(input_: Tensor, mgroup: ProcessGroup) -> Tensor:
    """Reduce tensor.

    Reduces tensor across ranks.

    Parameters
    ----------
    input_ : Tensor
        Input
    mgroup : ProcessGroup
        model communication group

    Returns
    -------
    Tensor
        Reduced tensor.
    """
    return _ReduceParallelSection.apply(input_, mgroup)


def sync_tensor(input_: Tensor, dim: int, shapes: tuple, mgroup: ProcessGroup) -> Tensor:
    """Sync tensor.

    Perform a gather in the forward pass and an allreduce followed by a split in the backward pass.

    Parameters
    ----------
    input_ : Tensor
        Input
    dim : int
        dimension along which to gather
    shapes : tuple
        Shapes of sharded Tensors
    mgroup : ProcessGroup
        model communication group

    Returns
    -------
    Tensor
        Synced tensor.
    """
    return _SyncParallelSection.apply(input_, dim, shapes, mgroup)


def reduce_shard_tensor(input_: Tensor, dim: int, shapes: tuple, mgroup: ProcessGroup) -> Tensor:
    """Reduce and then shard tensor.

    Perform an allreduce followed by a split in the forward pass and a gather in the backward pass.

    Parameters
    ----------
    input_ : Tensor
        Input
    dim : int
        dimension along which to shard
    shapes : tuple
        Shapes of sharded Tensors
    mgroup : ProcessGroup
        model communication group

    Returns
    -------
    Tensor
        Reduced and sharded tensor.
    """
    return _ReduceShardParallelSection.apply(input_, dim, shapes, mgroup)
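
# Summary of the communication pattern implemented by the autograd functions
# below (forward collective / backward collective):
#   shard_tensor:        split              / gather (optional, see gather_in_backward)
#   gather_tensor:       gather             / split
#   reduce_tensor:       all-reduce         / identity
#   sync_tensor:         gather             / all-reduce + split
#   reduce_shard_tensor: all-reduce + split / gather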


class _SyncParallelSection(torch.autograd.Function):
    """Sync the input from the parallel section."""

    @staticmethod
    def forward(ctx, input_, dim_, shapes_, mgroup_):
        ctx.dim = dim_
        ctx.comm_group = mgroup_
        ctx.shapes = shapes_
        if mgroup_:
            return _gather(input_, dim_, shapes_, group=mgroup_)
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.comm_group:
            grad_output = _reduce(grad_output, group=ctx.comm_group)
            return (
                _split(grad_output, ctx.dim, ctx.shapes, group=ctx.comm_group),
                None,
                None,
                None,
            )
        return grad_output, None, None, None


class _ReduceShardParallelSection(torch.autograd.Function):
    """All-reduce and shard the input from the parallel section."""

    @staticmethod
    def forward(ctx, input_, dim_, shapes_, mgroup_):
        ctx.dim = dim_
        ctx.comm_group = mgroup_
        ctx.shapes = shapes_
        if mgroup_:
            input_ = _reduce(input_, group=mgroup_)
            return _split(input_, dim_, shapes_, group=mgroup_)
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.comm_group:
            return (
                _gather(grad_output, ctx.dim, ctx.shapes, group=ctx.comm_group),
                None,
                None,
                None,
            )
        return grad_output, None, None, None


class _ShardParallelSection(torch.autograd.Function):
    """Split the input and keep only the chunk relevant to this rank."""

    @staticmethod
    def forward(ctx, input_, dim_, shapes_, gather_in_backward_, mgroup_):
        ctx.dim = dim_
        ctx.comm_group = mgroup_
        ctx.shapes = shapes_
        ctx.gather_in_backward = gather_in_backward_
        if mgroup_:
            return _split(input_, dim_, shapes_, group=mgroup_)
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.comm_group:
            return (
                _gather(
                    grad_output, ctx.dim, ctx.shapes, gather_in_backward=ctx.gather_in_backward, group=ctx.comm_group
                ),
                None,
                None,
                None,
                None,
            )
        return grad_output, None, None, None, None


class _GatherParallelSection(torch.autograd.Function):
    """Gather the input from the parallel section and concatenate."""

    @staticmethod
    def forward(ctx, input_, dim_, shapes_, mgroup_):
        ctx.dim = dim_
        ctx.comm_group = mgroup_
        ctx.shapes = shapes_
        if mgroup_:
            return _gather(input_, dim_, shapes_, group=mgroup_)
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.comm_group:
            return (
                _split(grad_output, ctx.dim, ctx.shapes, group=ctx.comm_group),
                None,
                None,
                None,
            )
        return grad_output, None, None, None


class _ReduceParallelSection(torch.autograd.Function):
    """All-reduce the input from the parallel section."""

    @staticmethod
    def forward(ctx, input_, mgroup_):
        if mgroup_:
            return _reduce(input_, group=mgroup_)
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None
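
The file above adds thin, autograd-aware wrappers around the collectives in anemoi.models.distributed.primitives. A minimal usage sketch follows; it is illustrative only: it assumes torch.distributed is already initialised (e.g. a torchrun-style launch), assumes the module path anemoi.models.distributed.graph (the diff does not show the new file's name), and uses made-up names for the tensor, shard shapes, and process group.

import torch
import torch.distributed as dist

# Assumed module path; the diff above does not show the new file's name.
from anemoi.models.distributed.graph import gather_tensor, shard_tensor

dist.init_process_group(backend="gloo")  # assumes a torchrun-style launch
world_size = dist.get_world_size()
model_comm_group = dist.new_group(ranks=list(range(world_size)))

# Same full tensor on every rank; 8 columns per rank along dim=1.
full = torch.randn(4, 8 * world_size, requires_grad=True)

# Hypothetical shard shapes: one shape per rank (format assumed from the
# "Shapes of sharded Tensors" docstrings above).
shard_shapes = [(4, 8)] * world_size

# Keep only this rank's slice, then reassemble the full tensor.
local = shard_tensor(full, dim=1, shapes=shard_shapes, mgroup=model_comm_group)
restored = gather_tensor(local, dim=1, shapes=shard_shapes, mgroup=model_comm_group)

# Gradients flow back through the gather/split pair.
restored.sum().backward()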