Add a context manager for activation sharding.
luyug committed Jun 12, 2023
1 parent 164cc0f commit a04e4b8
Showing 1 changed file with 44 additions and 0 deletions.
44 changes: 44 additions & 0 deletions fairscale/nn/model_parallel/utils.py
@@ -23,6 +23,9 @@
from typing import Tuple

import torch
from torch.autograd.graph import saved_tensors_hooks
from .initialize import get_model_parallel_rank, get_model_parallel_world_size
from .mappings import get_model_parallel_group


def ensure_divisibility(numerator: int, denominator: int) -> None:
@@ -76,3 +79,44 @@ def vocab_range_from_per_partition_vocab_size(
def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int) -> Tuple[int, int]:
    per_partition_vocab_size = divide_and_check_no_remainder(global_vocab_size, world_size)
    return VocabUtility.vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size)


def _pack_over_mp(tensor):
    mp_world_size = get_model_parallel_world_size()
    if mp_world_size == 1:
        return tensor  # no-op for mp=1
    full_tensor_shape = list(tensor.shape)
    shard = tensor.view(-1).chunk(mp_world_size, dim=0)[get_model_parallel_rank()]
    shard = shard.detach().clone().contiguous()  # clone to explicitly release memory of the full tensor
    del tensor
    return shard, full_tensor_shape


def _unpack_over_mp(sharded_tensor):
    mp_world_size = get_model_parallel_world_size()
    if mp_world_size == 1:
        # no-op for mp=1; the check must come before the tuple unpack because
        # _pack_over_mp returned the tensor unchanged in this case
        return sharded_tensor
    sharded_tensor, full_tensor_shape = sharded_tensor
    full_tensor = torch.empty(
        *full_tensor_shape,
        dtype=sharded_tensor.dtype,
        device=sharded_tensor.device)

    torch.distributed.all_gather_into_tensor(
        full_tensor.view(-1), sharded_tensor, group=get_model_parallel_group()
    )

    return full_tensor


class shard_over_mp_group(saved_tensors_hooks):
    """Context manager for activation sharding.

    This context manager shards tensors saved by autograd over the
    model-parallel group in the forward pass and unshards them in the
    backward pass. Useful for removing redundancy in long-lived
    activation tensors.
    """

    def __init__(self):
        super().__init__(_pack_over_mp, _unpack_over_mp)
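
For reference, a minimal usage sketch (not part of the commit): tensors that autograd saves while the context is active are sharded across the model-parallel ranks at save time and all-gathered again when backward needs them, so the backward call itself can sit outside the block. The sketch assumes the model-parallel group has already been initialized (e.g. with fairscale.nn.model_parallel.initialize_model_parallel); train_step, model, inputs, labels, and optimizer are hypothetical placeholders.

import torch.nn.functional as F
from fairscale.nn.model_parallel.utils import shard_over_mp_group

def train_step(model, inputs, labels, optimizer):
    # model, inputs, labels, optimizer are placeholders for this sketch.
    # Activations saved for backward inside this block are sharded over the
    # model-parallel group (pack hook); backward all-gathers them on demand
    # (unpack hook), which is why loss.backward() can run outside the block.
    with shard_over_mp_group():
        logits = model(inputs)
        loss = F.cross_entropy(logits, labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.detach()

Since each rank keeps only its 1/world_size slice of every saved activation, the memory held by long-lived activations shrinks accordingly, at the cost of one all-gather per saved tensor during the backward pass.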
