
Enable SequenceParallel in te.LayerNormMLP layer #1

Draft · wants to merge 7 commits into base: cfgte

Conversation


vchiley commented Jul 7, 2023

Given mosaicml#432, we can relatively easily enable SequenceParallel training.

This branch is just for testing for now; if it works well, we can consider merging it.

Note: I'm pretty sure checkpointing a model with TP and then loading it without TP will be a headache and will require tooling to be built.
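
For reference, a minimal sketch of how the TransformerEngine PyTorch API exposes tensor/sequence parallelism on te.LayerNormMLP (this is not the actual wiring in this branch; the argument names should be checked against the installed TE version):

```python
# Hedged sketch: standalone te.LayerNormMLP with TP + SequenceParallel enabled.
# Launch with torchrun so the NCCL process-group env vars are set.
import torch.distributed as dist
import transformer_engine.pytorch as te

dist.init_process_group(backend="nccl")
tp_size = dist.get_world_size()                        # whole node as one TP group (assumption)
tp_group = dist.new_group(ranks=list(range(tp_size)))

mlp = te.LayerNormMLP(
    hidden_size=768,            # MPT-125M-ish sizes, for illustration only
    ffn_hidden_size=4 * 768,
    set_parallel_mode=True,     # fc1 column-parallel, fc2 row-parallel
    sequence_parallel=True,     # gather/scatter activations along the sequence dim
    tp_group=tp_group,
    tp_size=tp_size,
).cuda()
```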

vchiley self-assigned this Jul 7, 2023
vchiley marked this pull request as draft Jul 7, 2023 00:29

vchiley commented Jul 8, 2023

On a single node of 8 GPUs:

MPT-125M training with and without Sequence Parallel is nearly identical for TP sizes in [2, 4, 8].
[screenshot]
BUT, since the 125M model is small, it's always slower to use TP than the standard model, even at mbs=1.
[screenshot]

If we bump the model size up to 7B, at mbs=1, TP world sizes of 2 and 4 are slightly faster than the baseline.
[screenshot]
At mbs=2, the advantage goes away.

If we bump the model size up to 13B, at mbs=1, TP world sizes of 2 and 4 are again slightly faster than the baseline.
[screenshot]
At mbs=2, a TP world size of 2 is almost as fast, but again the advantage goes away.

Note this was all done on a single node and does not factor in inter-node interconnect speed.
wandb results here
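
For context, one way the TP sub-groups for these runs could be constructed on a single 8-GPU node (the contiguous group layout here is an assumption, not lifted from this branch):

```python
# Hedged sketch: partition the 8 ranks of one node into contiguous TP groups.
import torch.distributed as dist

def make_tp_group(tp_size: int):
    world_size = dist.get_world_size()  # 8 on a single node here
    rank = dist.get_rank()
    assert world_size % tp_size == 0
    my_group = None
    # every rank must call new_group() for every group, in the same order
    for start in range(0, world_size, tp_size):
        ranks = list(range(start, start + tp_size))
        group = dist.new_group(ranks=ranks)
        if rank in ranks:
            my_group = group
    return my_group  # e.g. tp_size=2 -> groups [0,1], [2,3], [4,5], [6,7]
```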


vchiley commented Jul 8, 2023

Issue: FSDP operates at the module level. The module has [layer_norm_weight, layer_norm_bias, fc1_weight, fc1_bias, fc2_weight, fc2_bias] parameters, where FSDP needs to treat [layer_norm_weight, layer_norm_bias, fc2_bias] in the standard sharded way, while [fc1_weight, fc1_bias, fc2_weight] need to be treated in the standard TP-sharded way...
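
A minimal sketch of the split described above, using the parameter names from te.LayerNormMLP (how to actually reconcile this with FSDP's flat-parameter sharding is the open question):

```python
# Hedged sketch: partition te.LayerNormMLP parameters into the two groups above.
TP_SHARDED = {"fc1_weight", "fc1_bias", "fc2_weight"}                # already sharded across the TP group
FSDP_SHARDED = {"layer_norm_weight", "layer_norm_bias", "fc2_bias"}  # replicated across TP; shard with FSDP

def split_params(layernorm_mlp):
    tp_params, fsdp_params = [], []
    for name, param in layernorm_mlp.named_parameters():
        (tp_params if name in TP_SHARDED else fsdp_params).append((name, param))
    return tp_params, fsdp_params
```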
