diff --git a/torchtitan/parallelisms/pipeline_llama.py b/torchtitan/parallelisms/pipeline_llama.py index c4582846..6605a57d 100644 --- a/torchtitan/parallelisms/pipeline_llama.py +++ b/torchtitan/parallelisms/pipeline_llama.py @@ -64,9 +64,7 @@ def pipeline_llama_manual_split( """ pp_rank = pp_mesh.get_local_rank() pp_size = pp_mesh.size() - microbatches = ( - job_config.experimental.pipeline_parallel_microbatches or parallel_dims.pp - ) + splits = ( job_config.experimental.pipeline_parallel_split_points or generate_split_points(job_config, parallel_dims.pp, model_config) @@ -117,7 +115,7 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=Fal ) logger.info( f"PP rank {pp_rank} is building stage_idx {stage_idx}" - f" with start_layer {start_layer}, stop_layer {stop_layer}: model chunk \n{model_chunk}" + f" with start_layer {start_layer}, stop_layer {stop_layer}" ) stages.append(stage) models.append(model_chunk) diff --git a/torchtitan/parallelisms/pipelining_utils.py b/torchtitan/parallelisms/pipelining_utils.py index 1c90a0ea..fb5b565f 100644 --- a/torchtitan/parallelisms/pipelining_utils.py +++ b/torchtitan/parallelisms/pipelining_utils.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import os -from typing import Tuple +from typing import List, Tuple from torch.distributed.pipelining.schedules import ( _PipelineScheduleRuntime, @@ -12,10 +12,19 @@ PipelineScheduleMulti, PipelineScheduleSingle, ) +from torchtitan.config_manager import JobConfig from torchtitan.logging import logger +from torchtitan.models.llama.model import ModelArgs -def generate_split_points(job_config, pp_dim, model_config): +def generate_split_points( + job_config: JobConfig, pp_dim: int, model_config: ModelArgs +) -> List[str]: + """ + Generate a default split point based on the number of layers and + pipeline parallel dimension. + """ + schedule_class = get_schedule_class( job_config.experimental.pipeline_parallel_schedule ) @@ -51,7 +60,7 @@ def generate_split_points(job_config, pp_dim, model_config): current_layer += base_interval splits.append("layers." + str(current_layer)) logger.info( - f"No 'pipeline_parallel_split_points' so the generated splits are: {splits} \ + f"No 'pipeline_parallel_split_points' provided so the generated splits are: {splits} \ This may be sub-optimal as the number of layers per stage may be unbalanced." ) return splits @@ -73,18 +82,26 @@ def build_pipeline_schedule(job_config, stages, loss_fn): ) looped_schedule = issubclass(schedule_class, PipelineScheduleMulti) - logger.info( - f"Using pipeline schedule {job_config.experimental.pipeline_parallel_schedule}" - ) n_microbatches = job_config.experimental.pipeline_parallel_microbatches + # We expect that the number of local stages (`len(stages)`) is the same across all ranks + num_total_stages = job_config.experimental.pipeline_parallel_degree * len(stages) if n_microbatches is None: - n_microbatches = job_config.experimental.pipeline_parallel_degree + n_microbatches = num_total_stages + elif n_microbatches < num_total_stages: + logger.warning( + f"Number of microbatches ({n_microbatches}) is less than the total number \ +of stages ({num_total_stages}) which may result in a bubble in the pipeline." + ) schedule = schedule_class( stages if looped_schedule else stages[0], n_microbatches=n_microbatches, loss_fn=loss_fn, ) + logger.info( + f"Using pipeline schedule {job_config.experimental.pipeline_parallel_schedule} \ +with {n_microbatches} and {num_total_stages} stages." + ) if pp_schedule_csv: assert schedule_class in [