
WIP apply PP manually #308

Closed · wants to merge 9 commits

Conversation

Contributor

@wconstab wconstab commented May 4, 2024

[ghstack-poisoned]
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label May 4, 2024
wconstab added a commit that referenced this pull request May 4, 2024
ghstack-source-id: d43a299e8427cd333f334f9891e294662295f43b
Pull Request resolved: #308
@codecov-commenter

Welcome to Codecov 🎉

Once you merge this PR into your default branch, you're all set! Codecov will compare coverage reports and display results in all future pull requests.

Thanks for integrating Codecov - We've got you covered ☂️

[ghstack-poisoned]
wconstab added a commit that referenced this pull request May 6, 2024
ghstack-source-id: 3f2cc8b539fd88d8ba5099f6bc026ae31ad1005a
Pull Request resolved: #308
[ghstack-poisoned]
wconstab added a commit that referenced this pull request May 7, 2024
runs 1D ok now. still broken for DP+PP (haven't tried TP)

ghstack-source-id: 52da24a8b23e8631275171cf25b4e62575aced35
Pull Request resolved: #308
Contributor

@wanchaol wanchaol left a comment

I am pretty surprised to see that manual splitting isn't that complicated! Very interesting direction!

I left some inline comments about code structure. I'm wondering whether we can generalize TransformerChunk into something that other models can commonly use, but that doesn't need to happen in this PR, I think.

train.py Outdated

model, pipe_info = apply_pipeline_parallelism(
model, world_mesh, parallel_dims, job_config
stage_models = extract_pipeline_stage_models_manual(
Contributor

Wondering if extract_pipeline_stage_models_manual can be aligned with the graph-split approach? I.e. maybe we always have apply_pipeline_parallelism, and add a mode=["trace"/"manual"] argument.

In this way, we could simply specify a cmd arg to switch between manual/trace mode.

Contributor Author

Yea, I plan to align this.

The next step is to align the UX for specifying the split points.

My idea is that if we are given a dict of split points like "beginning of layer2", it can be used by either the tracer or the manual version, and we can expose that in the toml or cmdline configs.
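A minimal sketch of that split-point UX, assuming a comma-separated config string of layer FQNs shared by both PP modes; `parse_split_points` and `stage_layer_ranges` are hypothetical helper names, not existing torchtitan APIs:

```python
# Hypothetical sketch of a shared split-point spec for the tracer and
# manual paths. Function names are illustrative, not the actual API.

def parse_split_points(spec: str) -> list[str]:
    """Parse a config string like "layers.2,layers.4" into the FQNs at
    whose *beginning* the model is split (usable by either PP mode)."""
    return [fqn.strip() for fqn in spec.split(",") if fqn.strip()]

def stage_layer_ranges(split_fqns: list[str], n_layers: int) -> list[range]:
    """Turn split points like ["layers.2"] into per-stage layer ranges."""
    idxs = [int(fqn.split(".")[1]) for fqn in split_fqns]
    bounds = [0, *idxs, n_layers]
    return [range(a, b) for a, b in zip(bounds, bounds[1:])]
```

For example, `stage_layer_ranges(parse_split_points("layers.2,layers.4"), 8)` yields one contiguous layer range per pipeline stage.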

train.py Outdated
# Get example input
if pp_rank == 0:
input_shape = (job_config.training.batch_size, job_config.training.seq_len)
input_ids = torch.randint(
Contributor

Can this logic be put into parallelize_llama?

Contributor Author

Which part?

In the tracing version we return the pipe info and the model chunk. Then in train.py we create the Stage but don't need shape info, since that already came from the graph via the pipe info.

For the manual version, the shape inference happens during stage creation. I was asking @H-Huang and @kwen2501 about making that lazy on the first forward, which could let us avoid all of this; that's my preference.
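The lazy shape-inference idea could be sketched like this; `LazyShapeStage` is a hypothetical plain-Python wrapper, not the actual pipeline Stage API:

```python
# Sketch of "lazy shape inference": instead of requiring example inputs
# at stage-creation time, record input metadata on the first real
# forward. Class and attribute names are illustrative.

class LazyShapeStage:
    def __init__(self, submod):
        self.submod = submod
        self.input_shapes = None  # unknown until the first forward

    def __call__(self, *args):
        if self.input_shapes is None:
            # infer shapes lazily from the first microbatch
            self.input_shapes = [getattr(a, "shape", None) for a in args]
        return self.submod(*args)
```

With this, train.py would no longer need to construct fake example inputs per stage; the first microbatch supplies the shapes.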

torchtitan/parallelisms/parallelize_llama.py (outdated; resolved)
torchtitan/parallelisms/parallelize_llama.py (outdated; resolved)
h = self.norm(h)
output = h

if self.output:
Contributor

After looking into this, I feel this is pretty common code; I'd love to brainstorm a bit more to see whether we could make it general enough.

IIUC, what makes this a bit tricky to generalize is that there's no easy way for us to know what the output of a submodule would look like, or whether it should just return the input directly, without something like tracing?

Contributor Author

I have been thinking about this a lot too.

I think one proposal could be to introduce new building blocks that users can build their modules out of. The two that come to mind are OptionalLayer and OptionalModuleList.

I haven't thought about this enough to write down a strong proposal though. I will try to do that.

Contributor Author

A downside to having 'OptionalLayer' is that it might "look" like the layer exists on the stage when it's actually disabled. It'd be a shame if we registered a hook on the disabled layer and the hook actually did something, when that would be expected to happen only on the stage where the layer actually exists.

The upside is that we can write simpler code in the top-level model .forward(), where we don't have to use conditionals.

At the end of the day, the transformer is simple enough that it might be OK to just check the conditionals into the top-level model; then hopefully I could delete the 'TransformerChunk' and 'DummyTransformerLayer' concepts entirely.
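A plain-Python sketch of what checking the conditionals into the top-level model might look like; real code would subclass torch.nn.Module, and `TransformerStage` with its callables is purely illustrative:

```python
# Sketch: each pipeline stage keeps only its own pieces and sets the
# rest to None, and forward() branches on what exists instead of using
# TransformerChunk / DummyTransformerLayer wrappers.

class TransformerStage:
    def __init__(self, tok_embeddings, layers, norm, output):
        # any of these may be None on a given stage
        self.tok_embeddings = tok_embeddings
        self.layers = layers
        self.norm = norm
        self.output = output

    def forward(self, x):
        # first stage embeds tokens; later stages receive hidden states
        h = self.tok_embeddings(x) if self.tok_embeddings else x
        for layer in self.layers:  # only this stage's layers are present
            h = layer(h)
        if self.norm:              # last stage applies the final norm
            h = self.norm(h)
        if self.output:            # and the output projection
            h = self.output(h)
        return h
```

The conditionals make the PP split explicit in one place, at the cost of a slightly busier top-level forward.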

torchtitan/parallelisms/parallelize_llama.py (outdated; resolved)
return input


class TransformerChunk(torch.nn.Module):
Contributor

Wondering if we can put DummyTransformerLayer and TransformerChunk into a separate file in this folder (e.g. pp_llama_utils.py) and import them here?

Contributor Author

I want to kill DummyTransformerLayer, actually.

I thought about introducing some kind of PPModuleList that can wrap an original ModuleList and delete some layers while preserving the FQNs of the non-deleted layers. Then the forward code can still iterate over its remaining layers, but we don't run into issues like stage 1's TP trying to apply a PrepareInput hook onto layers.0, which is a "DummyModule" but whose hook would actually run. (This is a real issue that I hit, but it's fixed now by your change to remove the PrepareInput hook from layers.0.)

Contributor

The same issue was discussed when doing EP. ModuleDict was one potential solution at the time, but that also means the model needs to use ModuleDict.

Contributor Author

I forgot about this; it solves my problem. Thank you!
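A sketch of the ModuleDict approach, with a plain dict keyed by global layer index standing in for torch.nn.ModuleDict; `build_stage_layers` and `run_layers` are hypothetical names:

```python
# Key layers by their *global* index so a stage that keeps only layers
# 2-3 still exposes them as "layers.2" / "layers.3" (stable FQNs for
# checkpointing and hooks), and iteration simply skips deleted layers.

def build_stage_layers(all_layers, keep: range):
    # keys are global layer indices, so FQNs stay stable across stages
    return {str(i): all_layers[i] for i in keep}

def run_layers(layers: dict, h):
    # iterate in global-index order over whatever layers this stage owns
    for i in sorted(layers, key=int):
        h = layers[i](h)
    return h
```

Because the missing indices simply have no entry, there is no DummyModule for TP hooks to accidentally fire on.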

[ghstack-poisoned]
wconstab added a commit that referenced this pull request May 8, 2024
runs 1D+2D ok now. Working on TP.

TODOs
- clean up manualstage creation
- config options for configuring stage split
- a way to switch between tracer/manual

ghstack-source-id: 66614e4fc893a46ae5b35f83e5e7f0bbf4b8624c
Pull Request resolved: #308
[ghstack-poisoned]
wconstab added a commit that referenced this pull request May 9, 2024
runs PP+DP and PP+TP without issue,
runs PP+TP+DP with decreasing loss, but fails DCP save

TODOs
- clean up manualstage creation
- config options for configuring stage split
- a way to switch between tracer/manual

ghstack-source-id: eb4157a4534c62a2661a65c4c3b94974c4944e33
Pull Request resolved: #308
[ghstack-poisoned]
wconstab added a commit that referenced this pull request May 9, 2024
runs PP+DP and PP+TP without issue,
runs PP+TP+DP with decreasing loss, but fails DCP save

TODOs
- clean up manualstage creation
- config options for configuring stage split
- a way to switch between tracer/manual

ghstack-source-id: 952f364427a3340c9de81caa8d47a4135a33d511
Pull Request resolved: #308
# inferring seqlen from forward(input) only works on stage0, bc on later stages
# the hidden state input may have reduced seqlen due to TP. We need to use the
# original (full) seqlen for freqs_cis to be correct.
self.input_seqlen = input_seqlen
Contributor Author

@wanchaol actually, is this a serious bug? I think seq_len might vary from batch to batch, but I was about to treat it as a static quantity.

A problem is that in the PP case, no tensor visible from the top-level module has the right seq_len: the input tensor has shape [bsz // #chunks, seqlen // TP], and I don't really want my code to know about TP and multiply that back in. It's not a problem in the non-PP case, since you have the full model and can easily store the input seqlen at the top of the forward.

The solution space seems to be either:

  1. make seq_len an output/input of the pipeline stages (maybe wrap/unwrap it in a tensor to make PP happy), so later stages can access the real value

  2. slice the self.freqs_cis buffer deeper inside the layer code, after we gather the sharded sequence-parallel activations and restore the input seqlen

I think 2 is nicer if we can do it easily, but I haven't tried it yet to see whether it's easy.
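Option 2 could be sketched like this, with plain lists standing in for the hidden-state tensor and the precomputed freqs_cis buffer; `apply_rotary` is an illustrative name, not the repo's function:

```python
# Sketch of option 2: slice the rotary-embedding table *inside* the
# layer, after sequence-parallel activations have been gathered back to
# the full sequence length. The gathered activation's own length is then
# the correct slice bound, so no stage needs to know about TP sharding
# or pass seq_len between pipeline stages.

def apply_rotary(h_gathered, freqs_cis):
    seqlen = len(h_gathered)       # full seqlen, post-gather
    sliced = freqs_cis[:seqlen]    # slice the precomputed buffer here
    return [x * f for x, f in zip(h_gathered, sliced)]
```

Since slicing happens after the gather, varying seq_len across batches is handled for free on every stage.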

[ghstack-poisoned]
wconstab added a commit that referenced this pull request May 9, 2024
runs PP+DP and PP+TP without issue,
runs PP+TP+DP with decreasing loss, but fails DCP save

TODOs
- clean up manualstage creation
- config options for configuring stage split
- a way to switch between tracer/manual

ghstack-source-id: 9caf325395e05a26916e8a4f70a4b4cff5f4c052
Pull Request resolved: #308
[ghstack-poisoned]
wconstab added a commit that referenced this pull request May 9, 2024
runs PP+DP and PP+TP without issue,
runs PP+TP+DP with decreasing loss, but fails DCP save

TODOs
- clean up manualstage creation
- config options for configuring stage split
- a way to switch between tracer/manual

ghstack-source-id: 9f7ac5f2c2a787079e20a83e056c509e78cc0b19
Pull Request resolved: #308
[ghstack-poisoned]
wconstab added a commit that referenced this pull request May 10, 2024
runs PP+DP and PP+TP without issue,
runs PP+TP+DP with decreasing loss, but fails DCP save

TODOs
- clean up manualstage creation
- config options for configuring stage split
- a way to switch between tracer/manual

ghstack-source-id: 5a09c62d439c6bd53d614fad9d470e84e404a8c3
Pull Request resolved: #308
@wconstab wconstab closed this May 10, 2024
@wconstab
Contributor Author

squashed

tianyu-l pushed a commit that referenced this pull request Aug 16, 2024
runs PP+DP and PP+TP without issue,
runs PP+TP+DP with decreasing loss, but fails DCP save

TODOs
- clean up manualstage creation
- config options for configuring stage split
- a way to switch between tracer/manual

ghstack-source-id: 5a09c62d439c6bd53d614fad9d470e84e404a8c3
Pull Request resolved: #308