Add Sequence Parallelism to llama
For reasons not yet understood, torch.compile does not work even though eager sequence
parallelism does, so it is not turned on by default for now.

ghstack-source-id: 73c094b08f6dda91bf79d19297bb88a2050c7286
Pull Request resolved: #32
wanchaol committed Feb 1, 2024
1 parent 782d8d6 commit 672308b
Showing 3 changed files with 118 additions and 7 deletions.
2 changes: 1 addition & 1 deletion run_llama_train.sh
@@ -9,4 +9,4 @@ NGPU=8
MP=4

torchrun --nproc_per_node=${NGPU} \
train.py --steps 10
train.py --steps 10 --compile
17 changes: 13 additions & 4 deletions torchtrain/models/llama/model.py
@@ -214,7 +214,10 @@ def forward(
torch.Tensor: Output tensor after attention.
"""
bsz, seqlen, _ = x.shape
seqlen, _ = freqs_cis.shape
bs_seqlen, _ = x.shape
bsz = bs_seqlen // seqlen

xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
@@ -236,7 +239,8 @@ def forward(
xq, xk, xv, is_causal=True
)
output = output.transpose(1, 2).contiguous() # (bs, seqlen, n_local_heads, head_dim)
output = output.view(bsz, seqlen, -1)
# output stays folded with the batch and sequence dimensions
output = output.view(bsz * seqlen, -1)
return self.wo(output)


@@ -301,7 +305,6 @@ def __init__(self, layer_id: int, args: ModelArgs):
super().__init__()
self.n_heads = args.n_heads
self.dim = args.dim
self.head_dim = args.dim // args.n_heads
self.attention = Attention(args)
self.feed_forward = FeedForward(
dim=args.dim,
@@ -390,14 +393,20 @@ def forward(self, tokens: torch.Tensor):
torch.Tensor: Output logits after applying the Transformer model.
"""
_bsz, seqlen = tokens.shape
bsz, seqlen = tokens.shape
h = self.tok_embeddings(tokens)
# fold batch and sequence dimension for more efficient allgather/reduce_scatter
h = h.view(-1, self.params.dim)
self.freqs_cis = self.freqs_cis.to(h.device)
freqs_cis = self.freqs_cis[0 : seqlen]

for layer in self.layers:
h = layer(h, freqs_cis)

h = self.norm(h)
# unfold batch and sequence dimension
bs_seqlen = h.shape[0]
h = h.view(bsz, bs_seqlen//bsz, self.params.dim)
output = self.output(h).float()
return output

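For reference, a minimal shape-only sketch of the fold/unfold trick introduced above, using plain tensors on CPU; the sizes are illustrative, not the model's real dimensions.

import torch

# Illustrative sizes only; the real model uses args.dim, args.n_heads, etc.
bsz, seqlen, dim = 2, 8, 32

h = torch.randn(bsz, seqlen, dim)

# fold batch and sequence dimensions, as done after tok_embeddings above
h_folded = h.view(-1, dim)                  # shape: (bsz * seqlen, dim)

# inside Attention.forward, bsz is recovered from the folded shape and
# the length of the rope table (freqs_cis has seqlen rows)
bs_seqlen, _ = h_folded.shape
recovered_bsz = bs_seqlen // seqlen         # == bsz

# unfold after the final norm, as done at the end of Transformer.forward
h_unfolded = h_folded.view(recovered_bsz, bs_seqlen // recovered_bsz, dim)
assert torch.equal(h, h_unfolded)

With activations folded to (bsz * seqlen, dim), a Shard(0) placement shards along the sequence, so the all-gather before attention and the reduce-scatter after the row-parallel projections each operate on a single flattened dimension.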
106 changes: 104 additions & 2 deletions torchtrain/parallelisms/parallelize_llama.py
@@ -5,6 +5,8 @@
import torch
import logging

from torch.distributed._tensor import DTensor, Shard, Replicate, distribute_tensor, distribute_module

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper as ptd_checkpoint_wrapper,
CheckpointImpl,
@@ -16,12 +18,44 @@
ShardingStrategy,
)
from torch.distributed.fsdp.wrap import enable_wrap, wrap
from torch.distributed.tensor.parallel import (
ColwiseParallel,
RowwiseParallel,
parallelize_module,
PrepareModuleInput,
)

from torchtrain.logging_utils import rank0_log
from typing import Dict

logger = logging.getLogger(__name__)

def distribute_rmsnorm(module, device_mesh):
# temp sharding API until PTD API is added
def prepare_input_fn(inputs, device_mesh):
if isinstance(inputs[0], DTensor):
return inputs
elif isinstance(inputs[0], torch.Tensor):
shard_tensor = DTensor.from_local(inputs[0], device_mesh, [Shard(0)], run_check=False)
return shard_tensor
else:
raise NotImplementedError(f"unsupported input type {type(inputs[0])} for distribute_rmsnorm")

def partition_fn(name, module, device_mesh):
for name, param in module.named_parameters():
dist_param = torch.nn.Parameter(
distribute_tensor(param, device_mesh, [Replicate()])
)
module.register_parameter(name, dist_param)

return distribute_module(
module,
device_mesh,
partition_fn,
input_fn=prepare_input_fn,
)
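
A hedged usage sketch of the helper above: it replicates the norm weight across the mesh and wraps sequence-sharded inputs as DTensors. The mesh size, dimension, and the stand-in ToyRMSNorm below are assumptions for illustration (not the torchtrain classes); it assumes distribute_rmsnorm from above is in scope, a recent PyTorch with init_device_mesh, and one CUDA device per rank under torchrun.

# Illustrative only; run under e.g. `torchrun --nproc_per_node=8`.
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh

class ToyRMSNorm(nn.Module):  # stand-in for the llama RMSNorm in model.py
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("sp",))
norm = ToyRMSNorm(4096).cuda()
distribute_rmsnorm(norm, mesh)          # weight becomes a Replicate() DTensor

# each rank feeds its local chunk of a Shard(0) activation; prepare_input_fn
# wraps it via DTensor.from_local(..., [Shard(0)]) before the forward runs
x_local = torch.randn(1024, 4096, device="cuda")
y = norm(x_local)                       # a DTensor (no output_fn is installed, so it stays distributed)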


# Uses PTD FSDP AC wrapper
def checkpoint_wrapper(module, config):
return ptd_checkpoint_wrapper(module, checkpoint_impl=CheckpointImpl.NO_REENTRANT, preserve_rng_state=False)
@@ -40,7 +74,75 @@ def parallelize_llama(model, world_mesh, parallel_dims, args):
if parallel_dims.pp_enabled:
raise NotImplementedError("PP not implemented yet.")
if parallel_dims.sp_enabled:
raise NotImplementedError("SP not implemented yet.")
# First we apply Sequence Parallelism if it's enabled
tp_mesh = world_mesh["sp"]
sp_degree = args.sp_degree
# First:
# 1. parallelize the first embedding and the last linear proj layer
# 2. shard the first layer of transformer block
# TODO: enable loss parallel once it's ready
model = parallelize_module(
model,
tp_mesh,
{
"tok_embeddings": RowwiseParallel(
input_layouts=Replicate(),
),
"output": ColwiseParallel(
input_layouts=Shard(0),
output_layouts=Replicate(),
),
"layers.0": PrepareModuleInput(
input_layouts=(Replicate(), None),
desired_input_layouts=(Shard(0), None),
use_local_output=True,
)
}
)

# apply sequence parallelism to every transformer block
for layer_id, transformer_block in enumerate(model.layers):
layer_plan = {
"attention": PrepareModuleInput(
input_layouts=(Shard(0), None),
desired_input_layouts=(Replicate(), None),
),
"attention.wq": ColwiseParallel(),
"attention.wk": ColwiseParallel(),
"attention.wv": ColwiseParallel(),
"attention.wo": RowwiseParallel(output_layouts=Shard(0)),
"feed_forward": PrepareModuleInput(
input_layouts=(Shard(0),),
desired_input_layouts=(Replicate(),),
),
"feed_forward.w1": ColwiseParallel(),
"feed_forward.w2": RowwiseParallel(output_layouts=Shard(0)),
"feed_forward.w3": ColwiseParallel(),
}
# if layer_id == 0:
# # in first transformer block we need to shard the input
# layer_plan[""] = PrepareModuleInput(
# input_layouts=(Replicate(), None),
# desired_input_layouts=(Shard(0), None),
# )

# adjust num_heads in attention layer to local heads
attn_layer = transformer_block.attention
attn_layer.n_heads = attn_layer.n_heads // sp_degree
attn_layer.n_kv_heads = attn_layer.n_kv_heads // sp_degree

# shard RMSNorm layers
distribute_rmsnorm(transformer_block.attention_norm, tp_mesh)
distribute_rmsnorm(transformer_block.ffn_norm, tp_mesh)

parallelize_module(
module=transformer_block,
device_mesh=tp_mesh,
parallelize_plan=layer_plan,
)

rank0_log(f"Applied Sequence Parallelism to the model...")

if parallel_dims.dp_enabled:
dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
assert dp_mesh.mesh_dim_names == ["dp"], dp_mesh.mesh_dim_names
@@ -70,6 +172,6 @@ def parallelize_llama(model, world_mesh, parallel_dims, args):
# wrap the rest layers with FSDP
model = wrap(model.cuda())

rank0_log(f"Applied parallelisms to the model...")
rank0_log(f"Applied FSDP to the model...")

return model
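
The per-block plan above follows the usual Megatron-style pairing: column-parallel projections feed row-parallel projections, and activations enter and leave each block sharded on the folded sequence dimension (Shard(0)). Below is a minimal sketch of the same pattern on a toy feed-forward block; the mesh size, dimensions, and module names are assumptions for illustration, not the llama model.

# Illustrative only; run under e.g. `torchrun --nproc_per_node=8`.
import torch
import torch.nn as nn
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
    parallelize_module,
)

class ToyFFN(nn.Module):
    def __init__(self, dim=1024, hidden=4096):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))

class ToyLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.feed_forward = ToyFFN()

    def forward(self, x):
        return x + self.feed_forward(x)

mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("sp",))
layer = ToyLayer().cuda()

parallelize_module(
    layer,
    mesh,
    {
        # all-gather the Shard(0) input before the column-parallel matmuls
        "feed_forward": PrepareModuleInput(
            input_layouts=(Shard(0),),
            desired_input_layouts=(Replicate(),),
        ),
        "feed_forward.w1": ColwiseParallel(),                         # shards the hidden dim
        "feed_forward.w2": RowwiseParallel(output_layouts=Shard(0)),  # reduce-scatter back
    },
)

# each rank holds a local chunk of the folded (bsz * seqlen, dim) activation
x_local = torch.randn(512, 1024, device="cuda")
y_local = layer(x_local)   # same local shape; compute stays sequence-parallel

In the real plan, "layers.0" additionally converts the replicated embedding output to Shard(0), so every subsequent transformer block sees sequence-sharded activations at its boundary.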
