WIP apply PP manually #308
Conversation
ghstack-source-id: d43a299e8427cd333f334f9891e294662295f43b Pull Request resolved: #308
ghstack-source-id: 3f2cc8b539fd88d8ba5099f6bc026ae31ad1005a Pull Request resolved: #308
runs 1D ok now. still broken for DP+PP (haven't tried TP) ghstack-source-id: 52da24a8b23e8631275171cf25b4e62575aced35 Pull Request resolved: #308
I am pretty surprised to see that manual splitting isn't that complicated! Very interesting direction!
Have some comments about code structure inlined. I'm wondering if we can generalize TransformerChunk into something that can be commonly used by other models, but that doesn't need to happen in this PR, I think.
train.py (Outdated)

```diff
-    model, pipe_info = apply_pipeline_parallelism(
-        model, world_mesh, parallel_dims, job_config
+    stage_models = extract_pipeline_stage_models_manual(
```
Wondering if extract_pipeline_stage_models_manual can be aligned with the graph-split approach? I.e. maybe we always have apply_pipeline_parallelism, with a mode=["trace"/"manual"] argument. That way we could simply specify a cmd arg to switch between manual/trace mode.
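A rough sketch of what that unified entry point might look like (the mode argument, the helper names, and their signatures here are assumptions for illustration, not the actual torchtitan API):

```python
from typing import Literal


def apply_pipeline_parallelism(
    model, world_mesh, parallel_dims, job_config,
    mode: Literal["trace", "manual"] = "trace",
):
    # Dispatch between the existing graph-split path and the new manual path,
    # so a single cmd arg / config field can switch between them.
    if mode == "trace":
        return _split_by_tracing(model, world_mesh, parallel_dims, job_config)
    if mode == "manual":
        return _split_manually(model, world_mesh, parallel_dims, job_config)
    raise ValueError(f"unknown pipeline split mode: {mode!r}")


def _split_by_tracing(model, world_mesh, parallel_dims, job_config):
    ...  # the current tracer-based implementation would live here


def _split_manually(model, world_mesh, parallel_dims, job_config):
    ...  # the manual splitting added in this PR would live here
```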
Yea, I plan to align this.
The next step is to align the UX for specifying the split points.
I had the idea that if we are given a dict of split points like "beginning of layer 2", I can use that in either the tracer or the manual version, and we can somehow expose that in the toml or cmdline configs.
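One possible shape for that shared split-point spec, sketched in Python (the "layers.N" naming and the helper below are purely illustrative assumptions about the eventual toml/cmdline UX):

```python
# Split points expressed as "beginning of layer N", consumable by either mode:
# the tracer can translate them into its split spec, and the manual path can
# use them to decide which layers each pipeline stage keeps.
split_points = ["layers.2", "layers.4"]


def stage_layer_ranges(split_points, n_layers):
    """Turn split-point names into [start, end) layer ranges, one per PP stage."""
    starts = [0] + [int(name.split(".")[1]) for name in split_points]
    ends = starts[1:] + [n_layers]
    return list(zip(starts, ends))


print(stage_layer_ranges(split_points, n_layers=8))  # [(0, 2), (2, 4), (4, 8)]
```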
train.py (Outdated)

```python
# Get example input
if pp_rank == 0:
    input_shape = (job_config.training.batch_size, job_config.training.seq_len)
    input_ids = torch.randint(
```
Can this logic be put into parallelize_llama?
Which part?
In the tracing version we return the pipe info and the model chunk. Then in train.py we create the Stage but don't need shape info, since that already comes from the graph via the pipe info.
For the manual version, the shape-inference part happens during stage creation. I was asking @H-Huang and @kwen2501 about making that lazy on first forward, which could let us avoid all of this. That's my preference.
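For reference, a minimal sketch of the kind of shape info the manual path needs at stage creation, assuming per-stage example inputs are built from the job config (the attribute names on job_config/model_args and the helper itself are illustrative):

```python
import torch


def example_stage_input(pp_rank, job_config, model_args):
    # Stage 0 consumes token ids; later stages consume hidden states whose
    # shape must be known (or inferred lazily on first forward) to create
    # the stage without running the graph.
    batch_size = job_config.training.batch_size
    seq_len = job_config.training.seq_len
    if pp_rank == 0:
        return torch.randint(0, model_args.vocab_size, (batch_size, seq_len))
    return torch.empty(batch_size, seq_len, model_args.dim)
```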
```python
    h = self.norm(h)
    output = h

    if self.output:
```
After looking into this, I feel this is some pretty common code; would love to brainstorm this a bit more to see if we could make it general enough.
IIUC, what might make this a bit tricky to generalize is that there's no easy way for us to know what the output of a submodule would look like, or whether it should just return the input directly, without something like tracing?
I have been thinking about this a lot too.
I think one proposal could be to introduce new building blocks that users can build their modules out of. The two that come to mind are OptionalLayer and OptionalModuleList.
I haven't thought about this enough to write down a strong proposal though. I will try to do that.
A downside to having 'OptionalLayer' is that it might "look" like the layer exists on the stage when it's actually disabled. It'd be a shame if we registered a hook on the disabled layer and the hook actually did something, when that would be expected to happen only on some other stage where the layer actually exists.
The upside is that we can write simpler code in the top-level model .forward(), where we don't have to use conditionals.
At the end of the day, the transformer is simple enough that it might be OK to just check the conditionals into the top-level model; then hopefully I could delete the 'TransformerChunk' and 'DummyTransformerLayer' concepts entirely.
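A minimal sketch of that "conditionals in the top-level model" option, assuming each PP stage sets the modules it doesn't own to None (the attribute names mirror the llama model, but the class itself is illustrative):

```python
import torch.nn as nn


class TransformerStage(nn.Module):
    def __init__(self, tok_embeddings, layers, norm, output):
        super().__init__()
        self.tok_embeddings = tok_embeddings  # None except on the first stage
        self.layers = layers                  # only this stage's layers
        self.norm = norm                      # None except on the last stage
        self.output = output                  # None except on the last stage

    def forward(self, tokens_or_hidden, freqs_cis):
        h = tokens_or_hidden
        if self.tok_embeddings is not None:
            h = self.tok_embeddings(h)
        for layer in self.layers:
            h = layer(h, freqs_cis)
        if self.norm is not None:
            h = self.norm(h)
        if self.output is not None:
            h = self.output(h)
        return h
```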
```python
        return input


class TransformerChunk(torch.nn.Module):
```
Wondering if we can put DummyTransformerLayer and TransformerChunk in a separate file in this folder (e.g. pp_llama_utils.py) and import them here?
I want to kill DummyTransformerLayer actually.
I thought about introducing some kind of PPModuleList that can wrap an original ModuleList and delete some layers while preserving the FQNs of the non-deleted layers. The forward code can then still iterate over the remaining layers, but we don't run into silly issues like stage 1's TP trying to apply a PrepareInput hook onto layers.0, which is a "DummyModule" whose hook will nonetheless actually run. (This is a real issue that I hit, but it's fixed now since your change removed the PrepareInput hook from layers.0.)
The same issue was discussed when doing EP. ModuleDict was one potential solution at the time. But that also means the model needs to use ModuleDict.
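A minimal sketch of the ModuleDict idea, assuming layers are keyed by their original index so that deleting layers on a stage preserves the remaining FQNs (the toy model and method names are illustrative):

```python
import torch.nn as nn


class TinyModel(nn.Module):
    def __init__(self, n_layers=4, dim=16):
        super().__init__()
        # Keyed by original layer index, so FQNs stay "layers.<i>..." even
        # after other layers are deleted on this stage.
        self.layers = nn.ModuleDict({str(i): nn.Linear(dim, dim) for i in range(n_layers)})

    def keep_only(self, keep):
        for name in list(self.layers.keys()):
            if int(name) not in keep:
                del self.layers[name]

    def forward(self, x):
        for layer in self.layers.values():  # iterates only the remaining layers
            x = layer(x)
        return x


m = TinyModel()
m.keep_only({2, 3})
print([name for name, _ in m.named_parameters()])
# ['layers.2.weight', 'layers.2.bias', 'layers.3.weight', 'layers.3.bias']
```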
Forgot about this. It solves my problem. Thank you!
runs 1D+2D ok now. Working on TP. TODOs - clean up manualstage creation - config options for configuring stage split - a way to switch between tracer/manual ghstack-source-id: 66614e4fc893a46ae5b35f83e5e7f0bbf4b8624c Pull Request resolved: #308
runs PP+DP and PP+TP without issue, runs PP+TP+DP with decreasing loss, but fails DCP save TODOs - clean up manualstage creation - config options for configuring stage split - a way to switch between tracer/manual ghstack-source-id: eb4157a4534c62a2661a65c4c3b94974c4944e33 Pull Request resolved: #308
runs PP+DP and PP+TP without issue, runs PP+TP+DP with decreasing loss, but fails DCP save TODOs - clean up manualstage creation - config options for configuring stage split - a way to switch between tracer/manual ghstack-source-id: 952f364427a3340c9de81caa8d47a4135a33d511 Pull Request resolved: #308
```python
    # inferring seqlen from forward(input) only works on stage0, bc on later stages
    # the hidden state input may have reduced seqlen due to TP. We need to use the
    # original (full) seqlen for freqs_cis to be correct.
    self.input_seqlen = input_seqlen
```
@wanchaol actually, is this a serious bug? I think seq_len might vary from batch to batch, but I was about to treat it as a static quantity.
A problem is that in the PP case, no tensor visible from the top-level module has the right seq_len: the input tensor has shape [bsz // #chunks, seqlen // TP], and I don't really want my code to know about TP and multiply that back in. It's not a problem in the non-PP case, since you have the full model and can easily store the input seqlen at the top of the forward.
The solution space seems to be either:
1. make seq_len an output/input of the pipeline stages (maybe wrap/unwrap it in a tensor to make PP happy), so later stages can access the real value, or
2. slice the self.freqs_cis buffer deeper inside the layer code, after we gather the sharded sequence-parallel activations and restore the input seqlen.
I think 2 is nicer if we can do it easily, but I didn't try it yet to see how easy it is.
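A rough sketch of option 2, assuming the layer slices freqs_cis only after the sequence-parallel all-gather has restored the full seqlen (the helper name is an assumption, not existing code):

```python
import torch


def rope_freqs_for(freqs_cis: torch.Tensor, h_gathered: torch.Tensor) -> torch.Tensor:
    # h_gathered: (batch, full_seqlen, dim) after gathering the sharded
    # sequence-parallel activations, so its seqlen is the original one
    # regardless of how the stage input was sharded by TP.
    return freqs_cis[: h_gathered.shape[1]]
```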
runs PP+DP and PP+TP without issue, runs PP+TP+DP with decreasing loss, but fails DCP save TODOs - clean up manualstage creation - config options for configuring stage split - a way to switch between tracer/manual ghstack-source-id: 9caf325395e05a26916e8a4f70a4b4cff5f4c052 Pull Request resolved: #308
runs PP+DP and PP+TP without issue, runs PP+TP+DP with decreasing loss, but fails DCP save TODOs - clean up manualstage creation - config options for configuring stage split - a way to switch between tracer/manual ghstack-source-id: 9f7ac5f2c2a787079e20a83e056c509e78cc0b19 Pull Request resolved: #308
runs PP+DP and PP+TP without issue, runs PP+TP+DP with decreasing loss, but fails DCP save TODOs - clean up manualstage creation - config options for configuring stage split - a way to switch between tracer/manual ghstack-source-id: 5a09c62d439c6bd53d614fad9d470e84e404a8c3 Pull Request resolved: #308
squashed
Stack from ghstack (oldest at bottom):