Add Sequence Parallelism to llama #32
Conversation
Somehow torch.compile is not working even though eager sequence parallelism works, so don't turn it on by default for now.
    parallelize_plan=layer_plan,
)

rank0_log("Applied Sequence Parallelism to the model...")
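The core idea the plan applies can be shown without any distributed setup. Below is a minimal torch-free sketch (plain Python lists standing in for activation tensors; `shard_sequence` and `all_gather` are illustrative names, not torchtitan or PyTorch APIs): sequence parallelism shards activations along the sequence dimension across ranks, and an all-gather reassembles the full sequence before any op that needs it.

```python
# Toy illustration of sequence parallelism: shard along the sequence
# dimension, then gather the full sequence back. No torch, no processes;
# each inner list plays the role of one rank's local shard.

def shard_sequence(tokens, world_size):
    """Split per-token activations into contiguous per-rank shards."""
    n = len(tokens)
    assert n % world_size == 0, "sequence length must divide world_size"
    chunk = n // world_size
    return [tokens[r * chunk:(r + 1) * chunk] for r in range(world_size)]

def all_gather(shards):
    """Reassemble the full sequence (stand-in for the collective op)."""
    full = []
    for shard in shards:
        full.extend(shard)
    return full

tokens = list(range(8))                      # pretend activations for 8 positions
shards = shard_sequence(tokens, world_size=4)  # each rank holds 2 positions
restored = all_gather(shards)
```

The point of the real implementation is that between the shard and the gather, each rank only computes on (and stores activations for) its own slice of the sequence.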
I wonder if it's useful to log more info about the SP plan. I was thinking about the same for PP: what info do we want to print? Should each parallelism print its own summary, or should we have one overall function that prints all parallelism info in a unified way?
🤔 That's a good point. I think we should probably log the parallelize plan for SP. This would require some changes in PyTorch to add `__str__` to our ParallelStyles; I can add the log once the PyTorch PR is merged.
Should each parallelism print its own summary, or should we have one overall function that prints overall parallel info in a unified way

My two cents: it's a bit tricky to give an overall summary. I think we should first figure out how to print the intended summary for each parallelism individually. For example, when many TransformerBlocks are stacked, we can't log the parallel plan of every layer, so maybe we print the PP degree of the TransformerBlocks, and we might not want to print the SP plan for each PP TransformerBlock.
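The per-parallelism summary discussed above could be sketched as follows. This is a hypothetical helper (the function name, the tuple layout, and the plan strings are all made up for illustration): each parallelism contributes one compact line with its degree and a single representative plan, rather than dumping the plan of every stacked block.

```python
# Hypothetical unified summary: one compact line per parallelism,
# showing a representative plan once instead of repeating it per layer.

def summarize_parallelisms(parallelisms):
    """parallelisms: list of (name, degree, representative_plan) tuples."""
    lines = []
    for name, degree, plan in parallelisms:
        lines.append(
            f"{name}: degree={degree}, plan={plan} "
            f"(shown once, applied to all TransformerBlocks)"
        )
    return "\n".join(lines)

print(summarize_parallelisms([
    ("SP", 8, "{attention: ColwiseParallel, feed_forward: RowwiseParallel}"),
    ("PP", 2, "TransformerBlock stages"),
]))
```

A rank-0-only logger could call something like this once after all parallelisms are applied, which sidesteps the question of each parallelism printing independently.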
Looks great to me! One inline question:
    distribute_rmsnorm(transformer_block.attention_norm, tp_mesh)
    distribute_rmsnorm(transformer_block.ffn_norm, tp_mesh)
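For context on what is being distributed here, the RMSNorm computation itself is small: the only cross-element reduction is the mean of squares over the hidden dimension, after which each element is scaled independently by its own weight. A minimal pure-Python version (plain lists instead of tensors; not the torchtitan implementation) makes that structure visible:

```python
import math

def rmsnorm(x, weight, eps=1e-6):
    """y_i = x_i / sqrt(mean(x^2) + eps) * w_i over the hidden dim."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * w for v, w in zip(x, weight)]

out = rmsnorm([1.0, 2.0, 2.0], [1.0, 1.0, 1.0])
```

Because each output element depends on the full hidden dimension only through that single scalar `rms`, the per-element weight can be sharded or replicated independently of how the sequence dimension is split, which is what makes distributing the norm cheap.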
Shall we also apply it to the final norm after all the transformer blocks?
Not something currently enabled, but I think we can explore this in real training and see if sharding the final norm gives additional memory/perf benefits :)
Stack from ghstack (oldest at bottom):

Somehow torch.compile is not working even though eager sequence parallelism works, so don't turn it on by default for now.