Add Pipeline Parallel (and 2D PP+FSDP) support #161
Conversation
ghstack-source-id: 14902407f0c573a4b4e9f615495b805af0ed8afc Pull Request resolved: #161
train.py
Outdated
logger.info(
    f"{Color.blue}Extracting pipeline module for stage {pp_mesh.get_local_rank()}{Color.reset}"
)
model = pmod.get_stage_module(pp_mesh.get_local_rank())
nit: watch out for rank-stage inequality in case of Interleaved 1F1B.
yea, i need to switch to an interleaved schedule and clean this up
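For context on the rank-stage inequality: with an interleaved schedule each pipeline rank owns several virtual stages, so the stage index is no longer equal to the PP rank. A minimal sketch of the usual round-robin mapping (all values below are illustrative assumptions, not code from this PR):

```python
# Illustrative assumption: Interleaved 1F1B with 2 virtual stages per rank.
pp_rank = 1          # local rank in the PP mesh
pp_size = 4          # PP degree
virtual_stages = 2   # stages owned by each rank

# Round-robin assignment: rank r owns stages r, r + pp_size, r + 2*pp_size, ...
stage_ids = [pp_rank + v * pp_size for v in range(virtual_stages)]
print(stage_ids)  # [1, 5]
```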
Thanks a lot for the demo! LGTM!
traced module is burning in a 'meta' device arg for one 'ones' op which breaks runtime after moving model to 'cuda'. Haven't worked on loss fn yet. ghstack-source-id: 47735f666b6086e179699b1bbfb06168b488d4d4 Pull Request resolved: #161
Haven't worked on loss fn yet. ghstack-source-id: 4c438ddd2989e427489c4e2d5a9ddd35711bdb78 Pull Request resolved: #161
(fake) Loss now runs and propagates to logger ghstack-source-id: b5a290878909ebc67bbcfda25809be439e222523 Pull Request resolved: #161
Loss now runs and propagates to logger, but optimizer isn't working ghstack-source-id: 56b0ef0ed92d181126e6866a153316f00431c7e7 Pull Request resolved: #161
Loss now runs and propagates to logger, but optimizer isn't working ghstack-source-id: 4ede08f5a9d1bc994448cb057bb491d24866d078 Pull Request resolved: #161
- uses pipeline tracer frontend to extract a graph and partition it into chunks per stage - hardcodes one schedule (1F1B) for now - supports 1D parallelism currently. WIP: support 2D/3D parallel and clean up seed-checkpoint ux ghstack-source-id: 7055ffe515b79fa6edad58a72543d9bc8e866f80 Pull Request resolved: #161
- uses pipeline tracer frontend to extract a graph and partition it into chunks per stage - hardcodes one schedule (1F1B) for now (need to expose option to switch schedule and test other schedules) - supports 2D parallelism currently, 3D (TP) is work in progress ghstack-source-id: 6bd801399be3f77a45d1dda11bc87e9a90b92df4 Pull Request resolved: #161
LGTM!
Thanks for pulling PP in!
print("labels: ", labels.shape, labels.dtype) | ||
|
||
# Create a pipeline representation from the model | ||
pipe = pipeline(model, parallel_dims.pp, example_args=(input_ids,)) |
nit: strictly speaking, the second arg is the number of microbatches -- it is okay to use the PP dim to represent it for now. Longer term I think it should be exposed as a field in the config file.
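A minimal sketch of what exposing that as a config field could look like, reusing the names from the diff above; the field name `pipeline_parallel_microbatches` and the exact JobConfig path are hypothetical, not part of this PR:

```python
# Hypothetical config field; fall back to the PP dim, which is what the diff above uses today.
num_microbatches = getattr(
    job_config.training, "pipeline_parallel_microbatches", parallel_dims.pp
)
pipe = pipeline(model, num_microbatches, example_args=(input_ids,))
```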
)

# Get example input
label_shape = input_shape = (8, 2048)  # TODO
hmmm, would PP work for all cases that are not this shape, or does it require this to be the exact input shape at runtime?
need to double check how this works and fix.
# TODO(whc) need to fix PP + FSDP-mixed-precision
# tracer for PP assumes f32 and is caught off guard when runtime FSDP interacts using bf16 inputs
# param_dtype=torch.bfloat16, reduce_dtype=torch.float32
param_dtype=torch.float32,
we shouldn't change this by default; this would make the FSDP and FSDP + TP cases use fp32 instead of bf16
I wonder if supporting bf16 should be a criterion for landing. I would imagine that training with FSDP + PP in fp32 is not really viable efficiency-wise (at least for larger jobs).
I think we should fix this before landing the PP change. I think there was a possible way to fix this in the tracer, but I lost track of it; will dig it up.
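For reference, a sketch of the bf16 default the reviewers want to preserve, using FSDP2's mixed-precision policy; this shows the configuration the PP tracer currently can't handle, not a fix:

```python
import torch
from torch.distributed._composable.fsdp import MixedPrecisionPolicy

# bf16 params with fp32 gradient reduction -- the usual FSDP / FSDP + TP default.
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
)
# fully_shard(transformer_block, mesh=dp_mesh, mp_policy=mp_policy)
```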
train.py
Outdated
# there are virtual stages
if parallel_dims.pp_enabled:
    stage = PipelineStage(
        pipe=pipe_meta,
should this be pipe_meta or model?
it's correct. Ke proposed an alternative, but we'd still have to pass the pipe_info and the model into _PipelineStage in that case. I could make this change.
pipe=pipe_meta,
stage_index=pp_rank,
device=device,
group=pp_mesh.get_group(),
wondering if we should put the stage creation into parallelize_llama; IMO we only need pp_schedule in train.py
yea, I think this question and Ke's suggestion about returning a PipelineStage from parallelize_llama are better taken up in the context of a follow-up PR that also adds support for looped schedules.
Looped schedules further complicate things bc the PP logic first needs to chunk up the model, then apply the DP/TP portion of parallelize_llama on each chunk, and finally pass all the chunks into the schedule.
I think in the end, I might prefer to separate PP out from parallelize_llama, and have a flow where we take the return from the PP apply function and iteratively call parallelize_llama on those chunks.
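A rough sketch of the flow described in the last paragraph, under stated assumptions: `split_model_into_stages` and `build_pipeline_schedule` are hypothetical helper names used only to illustrate the ordering, not this PR's API.

```python
# Hypothetical helpers -- illustrating the proposed ordering only:
# 1) PP splits the model into per-stage chunks,
# 2) parallelize_llama applies the DP/TP portion to each chunk,
# 3) all chunks go into the (possibly looped) schedule.
stage_modules = split_model_into_stages(model, parallel_dims.pp)       # hypothetical
stage_modules = [
    parallelize_llama(m, world_mesh, parallel_dims, job_config) for m in stage_modules
]
pp_schedule = build_pipeline_schedule(stage_modules, n_microbatches)   # hypothetical
```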
loss = (
    torch.mean(torch.stack(losses))
    if is_last_stage
    else torch.Tensor([-1.0])
Why do we need the default -1 value? Is it just for logging purposes?
oh, yea i could make it a 'None' but then i have to update logger to not log at all. maybe that's actually a better way to do it. let me try that.
ok, so what I could do is try to alter the metrics code so that on non-last-stage ranks we either omit printing loss or print "loss: None" instead of -1.
The change will add more lines of code, since I need to deal with several places that expect loss and global_[avg/mean]_loss to be valid numbers:
- avoid writing them into the metrics dict
- replace their format string with a string value instead of a float value in the logger.info
- avoid calling loss.item() in the first place
I agree in principle that's the "right" fix, but I'm not sure it's worth the LOC / complexity. I don't totally hate the -1 thing.
Another option I considered is to skip the whole codeblock of '# log metrics' on non-last-stage ranks. I ruled this out, since it is still useful to log mfu, memory for other ranks.
So let me know what you want to do here @wanchaol
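A small sketch of the "loss: None" alternative being weighed here, assuming the variable names from the diff above; the metrics keys and logging lines are illustrative only:

```python
# Only the last PP stage has real losses; other ranks carry None instead of -1.
loss = torch.mean(torch.stack(losses)) if is_last_stage else None

if loss is not None:
    metrics["loss_metrics/global_avg_loss"] = loss.item()
    logger.info(f"loss: {loss.item():.4f}  mfu: {mfu:.2f}%")
else:
    # still log mfu / memory on non-last-stage ranks, just omit the loss value
    logger.info(f"loss: None  mfu: {mfu:.2f}%")
```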
- uses pipeline tracer frontend to extract a graph and partition it into chunks per stage - hardcodes one schedule (1F1B) for now (need to expose option to switch schedule and test other schedules) - supports 2D parallelism currently, 3D (TP) is work in progress ghstack-source-id: 205f8b08eac15bb7bee66ecdec439b9828b0949c Pull Request resolved: #161
- uses pipeline tracer frontend to extract a graph and partition it into chunks per stage - hardcodes one schedule (1F1B) for now (need to expose option to switch schedule and test other schedules) - supports 2D parallelism currently, 3D (TP) is work in progress ghstack-source-id: cbbb628fd823d579064a8038e6511ec77457ef19 Pull Request resolved: #161
- uses pipeline tracer frontend to extract a graph and partition it into chunks per stage - hardcodes one schedule (1F1B) for now (need to expose option to switch schedule and test other schedules) - supports 2D parallelism currently, 3D (TP) is work in progress ghstack-source-id: 94f89f90787cca27310cb966a7edf7ea9bbc0098 Pull Request resolved: #161
- uses pipeline tracer frontend to extract a graph and partition it into chunks per stage - hardcodes one schedule (1F1B) for now (need to expose option to switch schedule and test other schedules) - supports 2D parallelism currently, 3D (TP) is work in progress ghstack-source-id: ac8c37124f79f8246155e14da23c2f5cfd75c0de Pull Request resolved: #161
- uses pipeline tracer frontend to extract a graph and partition it into chunks per stage - hardcodes one schedule (1F1B) for now (need to expose option to switch schedule and test other schedules) - supports 2D parallelism currently, 3D (TP) is work in progress ghstack-source-id: feb45e115f7bbee37179887bb196c12d21d93b43 Pull Request resolved: #161
    for i in range(1, parallel_dims.pp)
}
# Get example input
label_shape = input_shape = (8, 2048)  # TODO
@kwen2501 any ideas for a clean way to do this in torchtrain? do we expect people to get a batch out of their dataloader and then reset it? or do we expect people to hardcode it?
i think what i might do is directly pass input_shape from train.py, and in train.py i can set input_shape = (job_config.batch_size, job_config.seq_len) or something. is that clean enough?
ok, pushed a variation on this. not sure if it's better to hide this inside parallelize since we already have the job config there, or to make it explicit from train.py that we are passing input_shape in for some reason
Either way sounds okay to me -- eventually, the shape comes from the config.
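For concreteness, the config-driven variant sketched in the comment above; the exact JobConfig field paths are an assumption taken from the comment, not this PR's code:

```python
# In train.py: derive the example input shape from the job config instead of
# hardcoding (8, 2048); field names follow the comment above.
input_shape = (job_config.batch_size, job_config.seq_len)
label_shape = input_shape
# ... then pass input_shape into parallelize_llama / the pipeline tracer.
```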
- uses pipeline tracer frontend to extract a graph and partition it into chunks per stage - hardcodes one schedule (1F1B) for now (need to expose option to switch schedule and test other schedules) - supports 2D parallelism currently, 3D (TP) is work in progress ghstack-source-id: 0616a1c0d40f8e51ddfc1b2d330dbddc491e00e2 Pull Request resolved: #161
layers_per_rank = len(model.layers) // parallel_dims.pp
split_spec = {
    f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING
    for i in range(1, parallel_dims.pp)
I'm new to the PP API and have a question: if layers_per_rank = 5 and parallel_dims.pp = 2, what should the split_spec be? My straightforward thought is that SplitPoint.BEGINNING should contain i = 1, 3, 5, but according to the code it's just i = 1.
parallel_dims.pp refers to the number of pipeline stages we split the model into. For example, if model.layers = 10 and parallel_dims.pp = 2, then 10 // 2 = 5, so we put 5 layers per stage (i.e. layers_per_rank = 5). Hence we make a single cut, at model.layers.5 -- that is, (nRanks - 1) split points.
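To make that concrete, a small worked example under the numbers in the reply (SplitPoint as imported in the diff above):

```python
# model.layers = 10, parallel_dims.pp = 2  ->  layers_per_rank = 5
layers_per_rank = 10 // 2
split_spec = {
    f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING
    for i in range(1, 2)  # range(1, parallel_dims.pp)
}
# split_spec == {"layers.5": SplitPoint.BEGINNING}: one cut, i.e. (nRanks - 1) split points.
```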
squashed
- dcp load seems to work now - need to pull in schedule object ghstack-source-id: cbbb8c9cd3b343952003b6314f1f2cc4a7a9e0cf Pull Request resolved: #161
Stack from ghstack (oldest at bottom):

- uses pipeline tracer frontend to extract a graph and partition it into chunks per stage
- hardcodes one schedule (1F1B) for now (need to expose option to switch schedule and test other schedules)
- supports 2D parallelism currently, 3D (TP) is work in progress