From 7a73979e0b712d0c9f2449189c179e524b7fc785 Mon Sep 17 00:00:00 2001
From: Wanchao Liang
Date: Tue, 6 Feb 2024 23:15:28 -0800
Subject: [PATCH] Add Sequence Parallelism to llama

Somehow torch.compile is not working even though eager-mode sequence
parallelism is working, so it is not turned on by default for now.

ghstack-source-id: 0d251f2efe36e71eae71549d863cb3e128e92634
Pull Request resolved: https://github.com/pytorch-labs/torchtrain/pull/32
---
 torchtrain/models/llama/model.py             |  16 ++-
 torchtrain/parallelisms/parallelize_llama.py | 114 ++++++++++++++++++-
 2 files changed, 125 insertions(+), 5 deletions(-)

diff --git a/torchtrain/models/llama/model.py b/torchtrain/models/llama/model.py
index ee504f7f..7485d32b 100644
--- a/torchtrain/models/llama/model.py
+++ b/torchtrain/models/llama/model.py
@@ -211,7 +211,10 @@ def forward(
             torch.Tensor: Output tensor after attention.
 
         """
-        bsz, seqlen, _ = x.shape
+        seqlen, _ = freqs_cis.shape
+        bs_seqlen, _ = x.shape
+        bsz = bs_seqlen // seqlen
+
         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
 
         xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
@@ -237,7 +240,8 @@ def forward(
         output = output.transpose(
             1, 2
         ).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
-        output = output.view(bsz, seqlen, -1)
+        # output stays folded with batch and sequence dimensions
+        output = output.view(bsz * seqlen, -1)
         return self.wo(output)
 
 
@@ -342,7 +346,6 @@ def __init__(self, layer_id: int, args: ModelArgs):
         super().__init__()
         self.n_heads = args.n_heads
         self.dim = args.dim
-        self.head_dim = args.dim // args.n_heads
         self.attention = Attention(args)
         self.feed_forward = FeedForward(
             dim=args.dim,
@@ -422,10 +425,17 @@ def forward(self, tokens: torch.Tensor):
 
         """
         h, freqs_cis = self.embeddings(tokens)
+        # fold batch and sequence dimensions for more efficient allgather/reduce_scatter
+        h = h.view(-1, self.params.dim)
 
         for layer in self.layers:
             h = layer(h, freqs_cis)
+
         h = self.norm(h)
+        # unfold batch and sequence dimensions
+        bsz = tokens.shape[0]
+        bs_seqlen = h.shape[0]
+        h = h.view(bsz, bs_seqlen // bsz, self.params.dim)
         output = self.output(h).float()
         return output
 
diff --git a/torchtrain/parallelisms/parallelize_llama.py b/torchtrain/parallelisms/parallelize_llama.py
index d4950ef6..dbf418ea 100644
--- a/torchtrain/parallelisms/parallelize_llama.py
+++ b/torchtrain/parallelisms/parallelize_llama.py
@@ -7,6 +7,13 @@
 import logging
 
 import torch
+from torch.distributed._tensor import (
+    distribute_module,
+    distribute_tensor,
+    DTensor,
+    Replicate,
+    Shard,
+)
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper as ptd_checkpoint_wrapper,
 )
@@ -19,11 +26,46 @@
     ShardingStrategy,
 )
 from torch.distributed.fsdp.wrap import enable_wrap, wrap
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    parallelize_module,
+    PrepareModuleInput,
+    RowwiseParallel,
+)
 
 from torchtrain.logging_utils import rank0_log
 
 logger = logging.getLogger(__name__)
 
+
+def distribute_rmsnorm(module, device_mesh):
+    # temp sharding API until PTD API is added
+    def prepare_input_fn(inputs, device_mesh):
+        if isinstance(inputs[0], DTensor):
+            return inputs
+        elif isinstance(inputs[0], torch.Tensor):
+            shard_tensor = DTensor.from_local(
+                inputs[0], device_mesh, [Shard(0)], run_check=False
+            )
+            return shard_tensor
+        else:
+            raise NotImplementedError("!!")
+
+    def partition_fn(name, module, device_mesh):
+        for name, param in module.named_parameters():
+            dist_param = torch.nn.Parameter(
+                distribute_tensor(param, device_mesh, [Replicate()])
+            )
+            module.register_parameter(name, dist_param)
+
+    return distribute_module(
+        module,
+        device_mesh,
+        partition_fn,
+        input_fn=prepare_input_fn,
+    )
+
+
 # Uses PTD FSDP AC wrapper
 def checkpoint_wrapper(module, config):
     return ptd_checkpoint_wrapper(
@@ -43,7 +85,75 @@ def parallelize_llama(model, world_mesh, parallel_dims, args):
     if parallel_dims.pp_enabled:
         raise NotImplementedError("PP not implemented yet.")
     if parallel_dims.sp_enabled:
-        raise NotImplementedError("SP not implemented yet.")
+        # First we apply Sequence Parallelism if it's enabled
+        tp_mesh = world_mesh["sp"] if world_mesh.ndim > 1 else world_mesh
+        sp_degree = args.sp_degree
+        # First:
+        # 1. parallelize the first embedding and the last linear proj layer
+        # 2. shard the first layer of transformer block
+        # TODO: enable loss parallel once it's ready
+        model = parallelize_module(
+            model,
+            tp_mesh,
+            {
+                "embeddings.tok_embeddings": RowwiseParallel(
+                    input_layouts=Replicate(),
+                ),
+                "output": ColwiseParallel(
+                    input_layouts=Shard(0),
+                    output_layouts=Replicate(),
+                ),
+                "layers.0": PrepareModuleInput(
+                    input_layouts=(Replicate(), None),
+                    desired_input_layouts=(Shard(0), None),
+                    use_local_output=True,
+                ),
+            },
+        )
+
+        # apply sequence parallelism to every transformer block
+        for layer_id, transformer_block in enumerate(model.layers):
+            layer_plan = {
+                "attention": PrepareModuleInput(
+                    input_layouts=(Shard(0), None),
+                    desired_input_layouts=(Replicate(), None),
+                ),
+                "attention.wq": ColwiseParallel(),
+                "attention.wk": ColwiseParallel(),
+                "attention.wv": ColwiseParallel(),
+                "attention.wo": RowwiseParallel(output_layouts=Shard(0)),
+                "feed_forward": PrepareModuleInput(
+                    input_layouts=(Shard(0),),
+                    desired_input_layouts=(Replicate(),),
+                ),
+                "feed_forward.w1": ColwiseParallel(),
+                "feed_forward.w2": RowwiseParallel(output_layouts=Shard(0)),
+                "feed_forward.w3": ColwiseParallel(),
+            }
+            # if layer_id == 0:
+            #     # in first transformer block we need to shard the input
+            #     layer_plan[""] = PrepareModuleInput(
+            #         input_layouts=(Replicate(), None),
+            #         desired_input_layouts=(Shard(0), None),
+            #     )
+
+            # adjust num_heads in attention layer to local heads
+            attn_layer = transformer_block.attention
+            attn_layer.n_heads = attn_layer.n_heads // sp_degree
+            attn_layer.n_kv_heads = attn_layer.n_kv_heads // sp_degree
+
+            # shard RMSNorm layers
+            distribute_rmsnorm(transformer_block.attention_norm, tp_mesh)
+            distribute_rmsnorm(transformer_block.ffn_norm, tp_mesh)
+
+            parallelize_module(
+                module=transformer_block,
+                device_mesh=tp_mesh,
+                parallelize_plan=layer_plan,
+            )
+
+        rank0_log("Applied Sequence Parallelism to the model...")
+
     if parallel_dims.dp_enabled:
         dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
         assert dp_mesh.mesh_dim_names == ["dp"], dp_mesh.mesh_dim_names
@@ -73,6 +183,6 @@ def parallelize_llama(model, world_mesh, parallel_dims, args):
         # wrap the rest layers with FSDP
         model = wrap(model.cuda())
 
-    rank0_log("Applied parallelisms to the model...")
+    rank0_log("Applied FSDP to the model...")
 
     return model
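
For readers less familiar with the DTensor tensor-parallel API used above: the layer plan marks wq/wk/wv/w1/w3 as ColwiseParallel and wo/w2 as RowwiseParallel with Shard(0) outputs, so activations leave each block sharded along the folded batch*sequence dimension and only the row-wise projections need a collective. Below is a minimal, self-contained sketch of the same parallelize_module plan pattern on a toy feed-forward block. It is illustrative only and not part of this patch: the ToyFFN module, the init_device_mesh setup, and the file name tp_sketch.py are assumptions; launch with something like torchrun --nproc_per_node=2 tp_sketch.py.

# Illustrative sketch (not part of the patch): the parallelize_module
# Colwise/Rowwise plan pattern from parallelize_llama.py, applied to a
# stand-in feed-forward block. ToyFFN and tp_sketch.py are hypothetical.
import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)


class ToyFFN(nn.Module):
    # stand-in for the llama FeedForward block; w1/w2 mirror the plan keys above
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))


if __name__ == "__main__":
    # torchrun sets LOCAL_RANK/WORLD_SIZE; init_device_mesh initializes the
    # default process group if it is not already initialized
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
    mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))

    model = ToyFFN(dim=16, hidden_dim=64).cuda()
    # w1 is sharded column-wise and w2 row-wise, so the hidden activation
    # stays sharded between them and only w2's output needs a collective
    model = parallelize_module(
        model,
        mesh,
        {"w1": ColwiseParallel(), "w2": RowwiseParallel()},
    )

    x = torch.randn(8, 16, device="cuda")
    out = model(x)  # replicated (identical) result on every rank
    print(out.shape)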