Add Sequence Parallelism to llama
Somehow torch.compile is not working although eager sequence parallelism is
working, so currently it is not turned on by default.

ghstack-source-id: 0d251f2efe36e71eae71549d863cb3e128e92634
Pull Request resolved: #32
wanchaol committed Feb 7, 2024
1 parent 6bd9082 commit 7a73979
Showing 2 changed files with 125 additions and 5 deletions.
16 changes: 13 additions & 3 deletions torchtrain/models/llama/model.py
@@ -211,7 +211,10 @@ def forward(
torch.Tensor: Output tensor after attention.
"""
bsz, seqlen, _ = x.shape
seqlen, _ = freqs_cis.shape
bs_seqlen, _ = x.shape
bsz = bs_seqlen // seqlen

xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
@@ -237,7 +240,8 @@ def forward(
output = output.transpose(
1, 2
).contiguous() # (bs, seqlen, n_local_heads, head_dim)
output = output.view(bsz, seqlen, -1)
# output stays folded over the batch and sequence dimensions
output = output.view(bsz * seqlen, -1)
return self.wo(output)


@@ -342,7 +346,6 @@ def __init__(self, layer_id: int, args: ModelArgs):
super().__init__()
self.n_heads = args.n_heads
self.dim = args.dim
self.head_dim = args.dim // args.n_heads
self.attention = Attention(args)
self.feed_forward = FeedForward(
dim=args.dim,
@@ -422,10 +425,17 @@ def forward(self, tokens: torch.Tensor):
"""
h, freqs_cis = self.embeddings(tokens)
# fold batch and sequence dimensions for more efficient allgather/reduce_scatter
h = h.view(-1, self.params.dim)

for layer in self.layers:
h = layer(h, freqs_cis)

h = self.norm(h)
# unfold batch and sequence dimensions
bsz = tokens.shape[0]
bs_seqlen = h.shape[0]
h = h.view(bsz, bs_seqlen // bsz, self.params.dim)
output = self.output(h).float()
return output
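
For reference, a minimal, self-contained sketch (hypothetical sizes, not part of this diff) of the fold/unfold pattern the model.py changes introduce: activations flow through the transformer blocks as a 2-D (bsz * seqlen, dim) tensor so that sequence-parallel allgather/reduce_scatter act on a single leading dimension, and are only reshaped back to (bsz, seqlen, dim) before the final output projection.

import torch

# hypothetical sizes, just to illustrate the reshapes above
bsz, seqlen, dim = 2, 16, 32
h = torch.randn(bsz, seqlen, dim)

# fold batch and sequence into one leading dimension
h_folded = h.view(-1, dim)                               # (bsz * seqlen, dim)

# ... transformer blocks operate on the folded tensor here ...

# recover the shape from the folded tensor before the output projection
bs_seqlen = h_folded.shape[0]
h_unfolded = h_folded.view(bsz, bs_seqlen // bsz, dim)   # (bsz, seqlen, dim)
assert torch.equal(h, h_unfolded)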

114 changes: 112 additions & 2 deletions torchtrain/parallelisms/parallelize_llama.py
@@ -7,6 +7,13 @@
import logging

import torch
from torch.distributed._tensor import (
distribute_module,
distribute_tensor,
DTensor,
Replicate,
Shard,
)

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper as ptd_checkpoint_wrapper,
@@ -19,11 +26,46 @@
ShardingStrategy,
)
from torch.distributed.fsdp.wrap import enable_wrap, wrap
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
PrepareModuleInput,
RowwiseParallel,
)

from torchtrain.logging_utils import rank0_log

logger = logging.getLogger(__name__)


def distribute_rmsnorm(module, device_mesh):
# temp sharding API until PTD API is added
def prepare_input_fn(inputs, device_mesh):
if isinstance(inputs[0], DTensor):
return inputs
elif isinstance(inputs[0], torch.Tensor):
shard_tensor = DTensor.from_local(
inputs[0], device_mesh, [Shard(0)], run_check=False
)
return shard_tensor
else:
raise NotImplementedError("!!")

def partition_fn(name, module, device_mesh):
for name, param in module.named_parameters():
dist_param = torch.nn.Parameter(
distribute_tensor(param, device_mesh, [Replicate()])
)
module.register_parameter(name, dist_param)

return distribute_module(
module,
device_mesh,
partition_fn,
input_fn=prepare_input_fn,
)
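
A hedged usage sketch of the distribute_rmsnorm helper defined above (the RMSNorm stand-in, sizes, and launch command are assumptions, not part of this diff): each rank passes its local shard of the folded (bsz * seqlen, dim) activation, the input is annotated as Shard(0) on the mesh, and the norm weight is replicated.

# Hypothetical usage of distribute_rmsnorm; assumes launch via
#   torchrun --nproc-per-node=2 example.py
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh


class RMSNorm(nn.Module):
    # minimal stand-in for the model's RMSNorm
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight


mesh = init_device_mesh("cpu", (2,))           # 1-D mesh over 2 ranks
norm = distribute_rmsnorm(RMSNorm(32), mesh)   # weight replicated, inputs treated as Shard(0)

local_chunk = torch.randn(8, 32)               # this rank's shard of (bsz * seqlen, dim)
out = norm(local_chunk)                        # a DTensor sharded on dim 0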


# Uses PTD FSDP AC wrapper
def checkpoint_wrapper(module, config):
return ptd_checkpoint_wrapper(
@@ -43,7 +85,75 @@ def parallelize_llama(model, world_mesh, parallel_dims, args):
if parallel_dims.pp_enabled:
raise NotImplementedError("PP not implemented yet.")
if parallel_dims.sp_enabled:
raise NotImplementedError("SP not implemented yet.")
# First we apply Sequence Parallelism if it's enabled
tp_mesh = world_mesh["sp"] if world_mesh.ndim > 1 else world_mesh
sp_degree = args.sp_degree
# First:
# 1. parallelize the first embedding and the last linear proj layer
# 2. shard the input of the first transformer block
# TODO: enable loss parallel once it's ready
model = parallelize_module(
model,
tp_mesh,
{
"embeddings.tok_embeddings": RowwiseParallel(
input_layouts=Replicate(),
),
"output": ColwiseParallel(
input_layouts=Shard(0),
output_layouts=Replicate(),
),
"layers.0": PrepareModuleInput(
input_layouts=(Replicate(), None),
desired_input_layouts=(Shard(0), None),
use_local_output=True,
),
},
)

# apply sequence parallelism to every transformer block
for layer_id, transformer_block in enumerate(model.layers):
layer_plan = {
"attention": PrepareModuleInput(
input_layouts=(Shard(0), None),
desired_input_layouts=(Replicate(), None),
),
"attention.wq": ColwiseParallel(),
"attention.wk": ColwiseParallel(),
"attention.wv": ColwiseParallel(),
"attention.wo": RowwiseParallel(output_layouts=Shard(0)),
"feed_forward": PrepareModuleInput(
input_layouts=(Shard(0),),
desired_input_layouts=(Replicate(),),
),
"feed_forward.w1": ColwiseParallel(),
"feed_forward.w2": RowwiseParallel(output_layouts=Shard(0)),
"feed_forward.w3": ColwiseParallel(),
}
# if layer_id == 0:
# # in first transformer block we need to shard the input
# layer_plan[""] = PrepareModuleInput(
# input_layouts=(Replicate(), None),
# desired_input_layouts=(Shard(0), None),
# )

# adjust num_heads in attention layer to local heads
attn_layer = transformer_block.attention
attn_layer.n_heads = attn_layer.n_heads // sp_degree
attn_layer.n_kv_heads = attn_layer.n_kv_heads // sp_degree

# shard RMSNorm layers
distribute_rmsnorm(transformer_block.attention_norm, tp_mesh)
distribute_rmsnorm(transformer_block.ffn_norm, tp_mesh)

parallelize_module(
module=transformer_block,
device_mesh=tp_mesh,
parallelize_plan=layer_plan,
)

rank0_log("Applied Sequence Parallelism to the model...")

if parallel_dims.dp_enabled:
dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
assert dp_mesh.mesh_dim_names == ["dp"], dp_mesh.mesh_dim_names
@@ -73,6 +183,6 @@ def parallelize_llama(model, world_mesh, parallel_dims, args):
# wrap the rest layers with FSDP
model = wrap(model.cuda())

rank0_log("Applied parallelisms to the model...")
rank0_log("Applied FSDP to the model...")

return model
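
For orientation, a hedged sketch of the 2-D device mesh this function expects when both data parallelism and sequence parallelism are enabled; the mesh construction and the parallel_dims/args objects below are assumptions, not shown in this diff.

# Hypothetical mesh setup matching world_mesh["dp"] / world_mesh["sp"] above;
# assumes 4 ranks, e.g. torchrun --nproc-per-node=4 train.py
from torch.distributed.device_mesh import init_device_mesh

world_mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "sp"))
tp_mesh = world_mesh["sp"]   # 1-D sub-mesh used for sequence parallelism
dp_mesh = world_mesh["dp"]   # 1-D sub-mesh used for FSDP sharding

# parallel_dims / args assumed to be constructed elsewhere (not shown in this diff)
model = parallelize_llama(model, world_mesh, parallel_dims, args)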
