simplify embedding + first transformer block TP (#314)
As titled, we can directly specify the rowwise parallel embedding's output layouts to be sharded on the sequence dim, so that we don't need the PrepareModuleInput for the first layer.

Switching to output_layouts=Shard(1) also triggers a reduce_scatter instead of an allreduce for the embedding layer, which could give some small perf wins.
wanchaol authored May 8, 2024
1 parent 3295448 commit f5a3ad7
Showing 1 changed file with 1 addition and 5 deletions.
6 changes: 1 addition & 5 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -160,18 +160,14 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
         {
             "tok_embeddings": RowwiseParallel(
                 input_layouts=Replicate(),
+                output_layouts=Shard(1),
             ),
             "output": col_parallel_strategy(
                 input_layouts=Shard(1),
                 output_layouts=(Shard(-1) if loss_parallel else Replicate()),
                 use_local_output=not loss_parallel,
             ),
             "norm": SequenceParallel(),
-            "layers.0": PrepareModuleInput(
-                input_layouts=(Replicate(), None),
-                desired_input_layouts=(Shard(1), None),
-                use_local_output=True,
-            ),
         },
     )

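Below is a minimal, self-contained sketch (not torchtitan code) of the pattern this commit relies on: a rowwise-parallel embedding with output_layouts=Shard(1) hands a sequence-sharded activation straight to a SequenceParallel module, so no PrepareModuleInput is needed in between. It assumes the PyTorch >= 2.3 tensor-parallel APIs (torch.distributed.tensor.parallel) and a torchrun launch; the ToyModel, emb, and norm names are illustrative only.

# Minimal sketch (not torchtitan code): a rowwise-parallel embedding with
# output_layouts=Shard(1) feeding a SequenceParallel norm directly, with no
# PrepareModuleInput in between. Launch with, e.g.:
#   torchrun --nproc_per_node=2 tp_embedding_sketch.py
import os

import torch
import torch.nn as nn
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    RowwiseParallel,
    SequenceParallel,
    parallelize_module,
)


class ToyModel(nn.Module):
    # Hypothetical stand-in for tok_embeddings + the first block's norm.
    def __init__(self, vocab_size: int = 128, dim: int = 64):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, dim)
        self.norm = nn.LayerNorm(dim)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.norm(self.emb(tokens))


def main() -> None:
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    if device_type == "cuda":
        torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", "0")))
    mesh = init_device_mesh(device_type, (world_size,))

    model = ToyModel().to(device_type)
    parallelize_module(
        model,
        mesh,
        {
            # Rowwise (vocab-parallel) embedding: replicated token ids in,
            # output sharded on the sequence dim. With Shard(1) the partial
            # results are combined via reduce_scatter rather than allreduce
            # (the reduce_scatter path assumes a backend that supports it,
            # e.g. NCCL).
            "emb": RowwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Shard(1),
            ),
            # SequenceParallel already expects sequence-sharded inputs, so the
            # embedding output flows in directly -- no PrepareModuleInput.
            "norm": SequenceParallel(),
        },
    )

    tokens = torch.randint(0, 128, (4, 16), device=device_type)  # (batch, seq)
    out = model(tokens)
    print(out.shape)  # global shape (4, 16, 64); each rank holds a seq shard


if __name__ == "__main__":
    main()

With output_layouts=Replicate() instead, the rowwise embedding would allreduce its partial results and the first block would need the PrepareModuleInput shown in the deleted lines above; Shard(1) lets the combine happen as a reduce_scatter, matching the commit message's note about the small perf win.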
