forked from state-spaces/mamba
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
21 changed files
with
6,707 additions
and
118 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
__version__ = "1.2.2" | ||
__version__ = "2.0.0" | ||
|
||
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn | ||
from mamba_ssm.modules.mamba_simple import Mamba | ||
from mamba_ssm.modules.mamba2 import Mamba2 | ||
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
from typing import Optional | ||
|
||
import torch | ||
from torch import Tensor | ||
from torch.distributed import ProcessGroup | ||
|
||
# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for | ||
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent | ||
# version of PyTorch. The following 4 lines are for backward compatibility with | ||
# older PyTorch. | ||
if "all_gather_into_tensor" not in dir(torch.distributed): | ||
torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base | ||
if "reduce_scatter_tensor" not in dir(torch.distributed): | ||
torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base | ||
|
||
|
||
# Raw operation, does not support autograd, but does support async | ||
def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): | ||
world_size = torch.distributed.get_world_size(process_group) | ||
output = torch.empty( | ||
world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device | ||
) | ||
handle = torch.distributed.all_gather_into_tensor( | ||
output, input_.contiguous(), group=process_group, async_op=async_op | ||
) | ||
return output, handle | ||
|
||
|
||
# Raw operation, does not support autograd, but does support async | ||
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): | ||
world_size = torch.distributed.get_world_size(process_group) | ||
assert input_.shape[0] % world_size == 0 | ||
output = torch.empty( | ||
input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device | ||
) | ||
handle = torch.distributed.reduce_scatter_tensor( | ||
output, input_.contiguous(), group=process_group, async_op=async_op | ||
) | ||
return output, handle | ||
|
||
|
||
# Raw operation, does not support autograd, but does support async | ||
def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): | ||
input_ = input_.contiguous() | ||
handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op) | ||
return input_, handle | ||
|
||
|
||
class AllGatherFunc(torch.autograd.Function): | ||
"""Gather the input from sequence parallel region and concatenate.""" | ||
|
||
@staticmethod | ||
def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor: | ||
ctx.process_group = process_group | ||
output, _ = all_gather_raw(input_, process_group) | ||
return output | ||
|
||
@staticmethod | ||
def backward(ctx, grad_output: Tensor): | ||
grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group) | ||
return grad_input, None | ||
|
||
|
||
# Supports autograd, but does not support async | ||
all_gather = AllGatherFunc.apply | ||
|
||
|
||
class ReduceScatterFunc(torch.autograd.Function): | ||
"""Reduce scatter the input from the sequence parallel region and concatenate.""" | ||
|
||
@staticmethod | ||
def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor: | ||
ctx.process_group = process_group | ||
output, _ = reduce_scatter_raw(input_, process_group) | ||
return output | ||
|
||
@staticmethod | ||
def backward(ctx, grad_output: Tensor): | ||
grad_input, _ = all_gather_raw(grad_output, ctx.process_group) | ||
return grad_input, None | ||
|
||
|
||
# Supports autograd, but does not support async | ||
reduce_scatter = ReduceScatterFunc.apply | ||
|
||
|
||
class AllReduceFunc(torch.autograd.Function): | ||
"""Gather the input from sequence parallel region and concatenate.""" | ||
|
||
@staticmethod | ||
def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor: | ||
ctx.process_group = process_group | ||
output, _ = all_reduce_raw(input_, process_group) | ||
return output | ||
|
||
@staticmethod | ||
def backward(ctx, grad_output: Tensor): | ||
return grad_output, None | ||
|
||
|
||
# Supports autograd, but does not support async | ||
all_reduce = AllReduceFunc.apply | ||
|
||
|
||
def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup): | ||
# We want to iterate over parameters with _shared_params=True in the same order, | ||
# as different ranks might have different number of parameters (e.g., only rank 0 has bias). | ||
pamams_shared = { | ||
name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False) | ||
} | ||
for _, p in sorted(pamams_shared.items()): | ||
with torch.no_grad(): | ||
# Broadcast needs src to be global rank, not group rank | ||
torch.distributed.broadcast( | ||
p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group | ||
) | ||
|
||
|
||
# Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256 | ||
def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup): | ||
# We want to iterate over parameters with _sequence_parallel=True in the same order, | ||
# as different ranks might have different number of parameters (e.g., only rank 0 has bias). | ||
params_seqparallel = { | ||
name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False) | ||
} | ||
grads = [p.grad for _, p in sorted(params_seqparallel.items())] | ||
if grads: | ||
with torch.no_grad(): | ||
coalesced = torch._utils._flatten_dense_tensors(grads) | ||
torch.distributed.all_reduce(coalesced, group=process_group) | ||
for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)): | ||
buf.copy_(synced) | ||
|
||
|
||
def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int: | ||
"""Get the dim for the local rank derived from splitting dim on world_size processes. | ||
The split may not be even across the world_size processes. | ||
""" | ||
multiple = dim // multiple_of | ||
div = multiple // world_size | ||
mod = multiple % world_size | ||
local_multiple = div + int(local_rank < mod) | ||
return local_multiple * multiple_of |
Oops, something went wrong.