Commit

Update

[ghstack-poisoned]

wconstab committed May 21, 2024
1 parent 6665a52 commit d7463ab
Showing 3 changed files with 36 additions and 84 deletions.
7 changes: 4 additions & 3 deletions torchtitan/config_manager.py
@@ -270,9 +270,9 @@ def __init__(self):
type=int,
default=None,
help="""
How many microbatches to split the full training batch into when using pipeline parallelism.
How many microbatches to split the global training batch into when using pipeline parallelism.
The overall training batch size must be evenly divisible by the number of microbatches.
The global training batch size must be evenly divisible by the number of microbatches.
The default value will be the number of pipeline stages, if unspecified.
""",
@@ -500,7 +500,8 @@ def parse_args_from_command_line(
"--" + arg, action="store_true" if val else "store_false"
)
elif arg == "experimental.pipeline_parallel_split_points":
# type inference breaks here, since the type is just 'list' and it ends up flattening
# without this special case, type inference breaks here,
# since the inferred type is just 'list' and it ends up flattening
# e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...]
aux_parser.add_argument("--" + arg, type=string_list)
else:
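For context on the comment above: the aux parser infers each argument's type from its current value, and for a list-valued default that becomes type=list, which iterates a string character by character. Below is a short sketch of the failure mode and the special case; the flag name is simplified and string_list is assumed to split on commas (as its use above suggests), so treat this as illustrative rather than the repository's exact code:

import argparse

def string_list(raw):
    # Assumed behavior: parse a comma-separated CLI value into a list of names.
    return [s.strip() for s in raw.split(",") if s.strip()]

# With type=list (what naive type inference picks for a list default),
# argparse calls list("layers.0") and flattens the string into characters.
bad = argparse.ArgumentParser()
bad.add_argument("--split_points", type=list, default=[])
print(bad.parse_args(["--split_points", "layers.0"]).split_points)
# -> ['l', 'a', 'y', 'e', 'r', 's', '.', '0']

# The special case keeps the layer names intact.
good = argparse.ArgumentParser()
good.add_argument("--split_points", type=string_list, default=[])
print(good.parse_args(["--split_points", "layers.0,layers.1"]).split_points)
# -> ['layers.0', 'layers.1']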
92 changes: 32 additions & 60 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -35,7 +35,6 @@

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging_utils import logger
from torchtitan.parallelisms.pipelining_utils import split_stage_fqns

# for selective AC
no_recompute_list = {
@@ -134,19 +133,6 @@ def get_tp_parallel_strategy(
return RowwiseParallel, ColwiseParallel


def _llama_fqns(num_layers):
return (
[
"tok_embeddings",
]
+ [f"layers.{i}" for i in range(num_layers)]
+ [
"norm",
"output",
]
)


def pipeline_llama(
model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict
):
@@ -177,9 +163,12 @@ def pipeline_llama_manual(
model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict
):
"""
This API gets individual torch.nn.Module objects for each pipeline stage (including virtual stages).
This API extracts one torch.nn.Module object for the part of the model configured to run inside this stage.
The SPMD parallelisms should be applied to
It wraps the model chunk in a ManualPipelineStage object and returns both the stage and model objects.
The stage object is used to create a pipeline schedule, and the model object can be used for applying SPMD
parallelism.
"""
pp_mesh = world_mesh["pp"]
pp_rank = pp_mesh.get_local_rank()
@@ -188,74 +177,57 @@
job_config.experimental.pipeline_parallel_microbatches or parallel_dims.pp
)
stage_idx = pp_rank
this_stage_layer_names = split_stage_fqns(
_llama_fqns(len(model.layers)),
job_config.experimental.pipeline_parallel_split_points,
pp_rank,
)

if pp_rank < pp_size - 1:
model.norm = None
model.output = None
splits = job_config.experimental.pipeline_parallel_split_points
start_layer = splits[stage_idx - 1] if stage_idx > 0 else None
stop_layer = splits[stage_idx] if stage_idx < pp_size - 1 else None

if pp_rank > 0:
model.tok_embeddings = None
names = list(model.layers.keys())
for name in names:
if f"layers.{name}" not in this_stage_layer_names:

drop_layers = True
for name in list(model.layers.keys()):
# we keep layers in a contiguous region between start (inclusive) and stop (exclusive)
if start_layer is None or f"layers.{name}" == start_layer:
drop_layers = False
if stop_layer is not None and f"layers.{name}" == stop_layer:
drop_layers = True
if drop_layers:
del model.layers[name]

if pp_rank < pp_size - 1:
model.norm = None
model.output = None

logger.info(f"PP rank {pp_rank} is using this model chunk\n{model}")

# TODO(whc) once ManualPipelineStage supports lazy shape inference, we can leave model on meta device longer and
# get rid of the input shape hardcoded here. For now, it should not be a big deal since we only materialize the
# layers of the model that map to this stage, not the whole model.

mp_arg = job_config.training.mixed_precision_param
mp_dtype = TORCH_DTYPE_MAP[mp_arg] if parallel_dims.dp_enabled else torch.float32
batch_size = job_config.training.batch_size
local_seq_len = int(job_config.training.seq_len // parallel_dims.tp)
layers_io_shape = (batch_size, local_seq_len, model_config.dim)
output_layer_shape = (batch_size, local_seq_len, model_config.vocab_size)
if pp_rank == 0:
# first layer
input = torch.randint(
model_config.vocab_size,
size=(job_config.training.batch_size, job_config.training.seq_len),
size=(batch_size, job_config.training.seq_len),
dtype=torch.int64,
device=device,
)
else:
# later layers (assume all start w/ a transformer layer)
input = torch.rand(
size=(
job_config.training.batch_size,
int(job_config.training.seq_len // parallel_dims.tp),
model_config.dim,
),
dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param]
if parallel_dims.dp_enabled
else torch.float32,
device=device,
)
input = torch.rand(layers_io_shape, dtype=mp_dtype, device=device)

if pp_rank == pp_size - 1:
# last layer
output = torch.rand(
size=(
job_config.training.batch_size,
int(job_config.training.seq_len // parallel_dims.tp),
model_config.vocab_size,
),
dtype=torch.float32,
device=device,
)
output = torch.rand(output_layer_shape, dtype=torch.float32, device=device)
else:
# earlier layers (assume all end in a transformer layer)
output = torch.rand(
size=(
job_config.training.batch_size,
int(job_config.training.seq_len // parallel_dims.tp),
model_config.dim,
),
dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param]
if parallel_dims.dp_enabled
else torch.float32,
device=device,
)
output = torch.rand(layers_io_shape, dtype=mp_dtype, device=device)

model.to_empty(device=device)
stage = ManualPipelineStage(
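To make the new split-point semantics concrete (each stage keeps a contiguous run of layers, from its own start split point inclusive up to the next split point exclusive), here is a standalone sketch; the helper name and toy sizes are made up for illustration and simplify the logic above rather than reproducing it line for line:

def stage_layer_names(num_layers, split_points, stage_idx, num_stages):
    # Which "layers.N" names a given pipeline stage keeps.
    start = split_points[stage_idx - 1] if stage_idx > 0 else None
    stop = split_points[stage_idx] if stage_idx < num_stages - 1 else None
    kept = []
    dropping = start is not None  # stage 0 keeps layers from the beginning
    for name in (f"layers.{i}" for i in range(num_layers)):
        if name == start:
            dropping = False  # first layer of this stage (inclusive)
        if name == stop:
            dropping = True   # first layer of the next stage (exclusive)
        if not dropping:
            kept.append(name)
    return kept

# 8 layers split across 2 stages at "layers.4":
print(stage_layer_names(8, ["layers.4"], stage_idx=0, num_stages=2))  # layers.0 .. layers.3 (plus tok_embeddings)
print(stage_layer_names(8, ["layers.4"], stage_idx=1, num_stages=2))  # layers.4 .. layers.7 (plus norm and output)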
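A worked example of the shape bookkeeping used to seed the pipeline stage, with made-up hyperparameters rather than values from any torchtitan config:

# Hypothetical values, chosen only to illustrate the shape math above.
batch_size, seq_len, tp_degree = 8, 2048, 2
dim, vocab_size = 4096, 32000

local_seq_len = seq_len // tp_degree                          # 1024: activations are sequence-sharded across TP ranks
layers_io_shape = (batch_size, local_seq_len, dim)            # (8, 1024, 4096) inter-stage activations
output_layer_shape = (batch_size, local_seq_len, vocab_size)  # (8, 1024, 32000) final logits on the last stage
first_stage_input_shape = (batch_size, seq_len)               # (8, 2048) int64 token ids on the first stage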
21 changes: 0 additions & 21 deletions torchtitan/parallelisms/pipelining_utils.py
@@ -24,24 +24,3 @@ def build_pipeline_schedule(job_config, parallel_dims, stage, loss_fn):
n_microbatches=stage.chunks,
loss_fn=loss_fn,
)


def split_stage_fqns(fqns, split_points, stage_id):
"""Helper for splitting ordered list of layer names into layers per stage.
split_points is a list of layer names, each layer will be the first layer in a stage
"""
stages = []
cur = []

for name in fqns:
if name in split_points:
assert len(
cur
), f"{name} is not a valid split point, do not specify the first layer of stage 0"
stages.append(cur)
cur = []
cur.append(name)

stages.append(cur)
return stages[stage_id]
