fix num_microbatches input for PP #781

Merged
2 commits merged on Jan 7, 2025
6 changes: 2 additions & 4 deletions torchtitan/parallelisms/pipeline_llama.py
@@ -64,9 +64,7 @@ def pipeline_llama_manual_split(
"""
pp_rank = pp_mesh.get_local_rank()
pp_size = pp_mesh.size()
microbatches = (
job_config.experimental.pipeline_parallel_microbatches or parallel_dims.pp
)

splits = (
job_config.experimental.pipeline_parallel_split_points
or generate_split_points(job_config, parallel_dims.pp, model_config)
@@ -117,7 +115,7 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=False):
)
logger.info(
f"PP rank {pp_rank} is building stage_idx {stage_idx}"
f" with start_layer {start_layer}, stop_layer {stop_layer}: model chunk \n{model_chunk}"
f" with start_layer {start_layer}, stop_layer {stop_layer}"
)
stages.append(stage)
models.append(model_chunk)
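For context, the default split points used in the call above come from generate_split_points, which is modified in the file below. A minimal sketch of the even-split idea, assuming an 8-layer model and a pipeline degree of 2; the helper name and the stages_per_rank argument are illustrative, not part of torchtitan:

def sketch_split_points(num_layers: int, pp_dim: int, stages_per_rank: int = 1) -> list:
    # Illustrative sketch only: divide the transformer layers evenly across all
    # pipeline stages and record the first layer of each stage after the first
    # as a split point of the form "layers.<index>".
    total_stages = pp_dim * stages_per_rank
    base_interval = num_layers // total_stages  # layers per stage; remainder is ignored here
    splits = []
    current_layer = 0
    for _ in range(total_stages - 1):
        current_layer += base_interval
        splits.append("layers." + str(current_layer))
    return splits

# Example: sketch_split_points(8, 2) -> ["layers.4"], i.e. two stages of 4 layers each.
# An uneven layer count leaves the remainder on the last stage, which is why
# generate_split_points logs that the generated splits may be sub-optimal.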
30 changes: 23 additions & 7 deletions torchtitan/parallelisms/pipelining_utils.py
@@ -4,18 +4,27 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
from typing import Tuple
from typing import List, Tuple

from torch.distributed.pipelining.schedules import (
_PipelineScheduleRuntime,
get_schedule_class,
PipelineScheduleMulti,
PipelineScheduleSingle,
)
from torchtitan.config_manager import JobConfig
from torchtitan.logging import logger
from torchtitan.models.llama.model import ModelArgs


def generate_split_points(job_config, pp_dim, model_config):
def generate_split_points(
job_config: JobConfig, pp_dim: int, model_config: ModelArgs
) -> List[str]:
"""
Generate default split points based on the number of layers and the
pipeline parallel dimension.
"""

schedule_class = get_schedule_class(
job_config.experimental.pipeline_parallel_schedule
)
@@ -51,7 +60,7 @@ def generate_split_points(job_config, pp_dim, model_config):
current_layer += base_interval
splits.append("layers." + str(current_layer))
logger.info(
f"No 'pipeline_parallel_split_points' so the generated splits are: {splits} \
f"No 'pipeline_parallel_split_points' provided so the generated splits are: {splits} \
This may be sub-optimal as the number of layers per stage may be unbalanced."
)
return splits
@@ -73,18 +82,25 @@ def build_pipeline_schedule(job_config, stages, loss_fn):
)

looped_schedule = issubclass(schedule_class, PipelineScheduleMulti)
logger.info(
f"Using pipeline schedule {job_config.experimental.pipeline_parallel_schedule}"
)
n_microbatches = job_config.experimental.pipeline_parallel_microbatches
num_stages = job_config.experimental.pipeline_parallel_degree * len(stages)
if n_microbatches is None:
n_microbatches = job_config.experimental.pipeline_parallel_degree
n_microbatches = num_stages
elif n_microbatches < num_stages:
logger.warning(
f"Number of microbatches ({n_microbatches}) is less than the total number \
of stages ({num_stages}) which may result in a bubble in the pipeline."
)

schedule = schedule_class(
stages if looped_schedule else stages[0],
n_microbatches=n_microbatches,
loss_fn=loss_fn,
)
logger.info(
f"Using pipeline schedule {job_config.experimental.pipeline_parallel_schedule} \
with {n_microbatches} microbatches and {num_stages} stages."
)

if pp_schedule_csv:
assert schedule_class in [
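The core behavioral change above is how build_pipeline_schedule now resolves n_microbatches: when pipeline_parallel_microbatches is unset, it defaults to the total number of stages (pipeline_parallel_degree * len(stages), i.e. pipeline degree times stages per rank) rather than to the pipeline degree alone, and an explicit value smaller than the stage count only triggers a warning about pipeline bubbles. A minimal standalone sketch of that decision logic, with a hypothetical resolve_n_microbatches helper standing in for the inline code:

import logging
from typing import Optional

logger = logging.getLogger(__name__)

def resolve_n_microbatches(
    configured: Optional[int], pp_degree: int, stages_per_rank: int
) -> int:
    # Sketch of the defaulting rule introduced by this PR (names are illustrative).
    num_stages = pp_degree * stages_per_rank
    if configured is None:
        # Default to one microbatch per stage so every stage can be kept busy.
        return num_stages
    if configured < num_stages:
        logger.warning(
            f"Number of microbatches ({configured}) is less than the total number "
            f"of stages ({num_stages}), which may result in a bubble in the pipeline."
        )
    return configured

# Example: with pipeline_parallel_degree=4 and an interleaved schedule that places
# 2 stages on each rank, resolve_n_microbatches(None, 4, 2) returns 8, while
# resolve_n_microbatches(4, 4, 2) keeps 4 but warns about a likely pipeline bubble.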