From c45f408af56b03c04cc34b591a6ee679752c6989 Mon Sep 17 00:00:00 2001
From: Howard Huang
Date: Tue, 7 Jan 2025 13:52:50 -0800
Subject: [PATCH 1/2] fix num_microbatches input for PP

[ghstack-poisoned]
---
 torchtitan/parallelisms/pipeline_llama.py   |  6 ++---
 torchtitan/parallelisms/pipelining_utils.py | 30 ++++++++++++++++-----
 2 files changed, 25 insertions(+), 11 deletions(-)

diff --git a/torchtitan/parallelisms/pipeline_llama.py b/torchtitan/parallelisms/pipeline_llama.py
index c4582846..6605a57d 100644
--- a/torchtitan/parallelisms/pipeline_llama.py
+++ b/torchtitan/parallelisms/pipeline_llama.py
@@ -64,9 +64,7 @@ def pipeline_llama_manual_split(
     """
     pp_rank = pp_mesh.get_local_rank()
     pp_size = pp_mesh.size()
-    microbatches = (
-        job_config.experimental.pipeline_parallel_microbatches or parallel_dims.pp
-    )
+
     splits = (
         job_config.experimental.pipeline_parallel_split_points
         or generate_split_points(job_config, parallel_dims.pp, model_config)
@@ -117,7 +115,7 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=Fal
         )
         logger.info(
             f"PP rank {pp_rank} is building stage_idx {stage_idx}"
-            f" with start_layer {start_layer}, stop_layer {stop_layer}: model chunk \n{model_chunk}"
+            f" with start_layer {start_layer}, stop_layer {stop_layer}"
         )
         stages.append(stage)
         models.append(model_chunk)
diff --git a/torchtitan/parallelisms/pipelining_utils.py b/torchtitan/parallelisms/pipelining_utils.py
index 1c90a0ea..4a88f779 100644
--- a/torchtitan/parallelisms/pipelining_utils.py
+++ b/torchtitan/parallelisms/pipelining_utils.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 import os
-from typing import Tuple
+from typing import List, Tuple
 
 from torch.distributed.pipelining.schedules import (
     _PipelineScheduleRuntime,
@@ -12,10 +12,19 @@
     PipelineScheduleMulti,
     PipelineScheduleSingle,
 )
+from torchtitan.config_manager import JobConfig
 from torchtitan.logging import logger
+from torchtitan.models.llama.model import ModelArgs
 
 
-def generate_split_points(job_config, pp_dim, model_config):
+def generate_split_points(
+    job_config: JobConfig, pp_dim: int, model_config: ModelArgs
+) -> List[str]:
+    """
+    Generate default split points based on the number of layers and
+    the pipeline parallel dimension.
+    """
+
     schedule_class = get_schedule_class(
         job_config.experimental.pipeline_parallel_schedule
     )
@@ -51,7 +60,7 @@ def generate_split_points(job_config, pp_dim, model_config):
         current_layer += base_interval
         splits.append("layers." + str(current_layer))
     logger.info(
-        f"No 'pipeline_parallel_split_points' so the generated splits are: {splits} \
+        f"No 'pipeline_parallel_split_points' provided so the generated splits are: {splits} \
 This may be sub-optimal as the number of layers per stage may be unbalanced."
     )
     return splits
@@ -73,18 +82,25 @@ def build_pipeline_schedule(job_config, stages, loss_fn):
         )
 
     looped_schedule = issubclass(schedule_class, PipelineScheduleMulti)
-    logger.info(
-        f"Using pipeline schedule {job_config.experimental.pipeline_parallel_schedule}"
-    )
     n_microbatches = job_config.experimental.pipeline_parallel_microbatches
+    num_stages = job_config.experimental.pipeline_parallel_degree * len(stages)
     if n_microbatches is None:
-        n_microbatches = job_config.experimental.pipeline_parallel_degree
+        n_microbatches = num_stages
+    elif n_microbatches < num_stages:
+        logger.warning(
+            f"Number of microbatches ({n_microbatches}) is less than the total number \
+of stages ({num_stages}) which may result in a bubble in the pipeline."
+        )
 
     schedule = schedule_class(
         stages if looped_schedule else stages[0],
         n_microbatches=n_microbatches,
         loss_fn=loss_fn,
     )
+    logger.info(
+        f"Using pipeline schedule {job_config.experimental.pipeline_parallel_schedule} \
+with {n_microbatches} microbatches and {num_stages} stages."
+    )
 
     if pp_schedule_csv:
         assert schedule_class in [

From 40d2aec795e9ac052e9004f208bca86518a71649 Mon Sep 17 00:00:00 2001
From: Howard Huang
Date: Tue, 7 Jan 2025 15:02:08 -0800
Subject: [PATCH 2/2] Update on "fix num_microbatches input for PP"

PP was using `pipeline_parallel_degree` as the default number of
microbatches, when it should be `pipeline_parallel_degree * len(stages)`
for multi-stage schedules. Also fixed up some logging and added new
logging.

[ghstack-poisoned]
---
 torchtitan/parallelisms/pipelining_utils.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/torchtitan/parallelisms/pipelining_utils.py b/torchtitan/parallelisms/pipelining_utils.py
index 4a88f779..fb5b565f 100644
--- a/torchtitan/parallelisms/pipelining_utils.py
+++ b/torchtitan/parallelisms/pipelining_utils.py
@@ -83,13 +83,14 @@ def build_pipeline_schedule(job_config, stages, loss_fn):
     looped_schedule = issubclass(schedule_class, PipelineScheduleMulti)
 
     n_microbatches = job_config.experimental.pipeline_parallel_microbatches
-    num_stages = job_config.experimental.pipeline_parallel_degree * len(stages)
+    # We expect that the number of local stages (`len(stages)`) is the same across all ranks
+    num_total_stages = job_config.experimental.pipeline_parallel_degree * len(stages)
     if n_microbatches is None:
-        n_microbatches = num_stages
-    elif n_microbatches < num_stages:
+        n_microbatches = num_total_stages
+    elif n_microbatches < num_total_stages:
         logger.warning(
             f"Number of microbatches ({n_microbatches}) is less than the total number \
-of stages ({num_stages}) which may result in a bubble in the pipeline."
+of stages ({num_total_stages}) which may result in a bubble in the pipeline."
         )
 
     schedule = schedule_class(
@@ -99,7 +100,7 @@ def build_pipeline_schedule(job_config, stages, loss_fn):
     )
     logger.info(
         f"Using pipeline schedule {job_config.experimental.pipeline_parallel_schedule} \
-with {n_microbatches} microbatches and {num_stages} stages."
+with {n_microbatches} microbatches and {num_total_stages} stages."
     )
 
     if pp_schedule_csv:
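
A minimal standalone sketch of the sizing rule this series changes, assuming
a looped schedule where each pipeline rank holds several model chunks (the
function and variable names below are illustrative, not from torchtitan):

    # Default microbatch count: one per stage, not one per pipeline rank.
    # `pp_degree` is the pipeline parallel degree; `stages_per_rank` is the
    # number of model chunks each rank holds (> 1 for looped schedules).
    def default_n_microbatches(pp_degree: int, stages_per_rank: int) -> int:
        num_total_stages = pp_degree * stages_per_rank
        # The old default, pp_degree, gave looped schedules fewer
        # microbatches than stages, i.e. a guaranteed pipeline bubble;
        # defaulting to the total stage count keeps every stage fed.
        return num_total_stages

    # Example: pp_degree=4 with 2 stages per rank now defaults to
    # 8 microbatches instead of 4.
    assert default_n_microbatches(4, 2) == 8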