update performance and loss converging results #800
Open
tianyu-l wants to merge 2 commits into gh/tianyu-l/28/base from gh/tianyu-l/28/head
Changes from 1 commit
@@ -1,39 +1,68 @@
To demonstrate the effectiveness of PyTorch distributed training techniques used in torchtitan, we report both the infra metrics and loss curves of Llama 3 (8B and 70B) training on 64 A100 (80GB memory) GPUs and Llama 3.1 (405B) on 128 H100 (94GB memory).
We report infra metrics achieved by [FSDP2](fsdp.md) (1D parallelism) under various configurations, and loss curves for both 1D parallelism (FSDP2) and 2D parallelism (FSDP2 + Tensor Parallel) training. (We only report 2D results for the 405B model.)
We demonstrate the effectiveness of elastic distributed training using torchtitan, via experiments on Llama 3.1 8B, 70B, and 405B models, from 1D parallelism to 4D parallelism, at scales from 8 to 512 GPUs.

The experiments are conducted on NVIDIA H100 GPUs[^1] with 95 GiB memory, where each host is equipped with 8 GPUs and NVSwitch. Two hosts form a rack connected to a TOR switch. A backend RDMA network connects the TOR switches.
## Llama 3.1 performance numbers
We note that, throughout our experimentation, memory readings are stable across the whole training process, whereas throughput numbers (TPS/GPU) are calculated and logged every 10 iterations, and always read at the (arbitrarily determined) 90th iteration.

Below are the WPS (words per second, or more accurately, tokens per second) and MFU (model FLOPS utilization) results which torchtitan achieves on the 405B model released in [Llama 3.1](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1). The way we compute WPS and MFU can be found in `train.py`. Because the model is now larger, we run on 128 H100 GPUs to test both performance and loss curves. Below is the performance result of the 405B model with the optimizations we have developed. We do see OOM for 1D parallelism (FSDP2), so we only tested 2D parallelism (FSDP2 + Tensor Parallel).
We do not report Model FLOPS Utilization (MFU) because when Float8 is enabled, both BFLOAT16 Tensor Core and FP8 Tensor Core are involved in model training, but they have different peak FLOPS, and the definition of MFU under such a scenario is not well defined. We note that the 1D Llama 3.1 8B model training on 8 or 128 H100 GPUs without Float8 achieves 33% to 39% MFU (with or without torch.compile, respectively).
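For reference, below is a minimal sketch of how token throughput and MFU can be derived from the quantities discussed above. It is illustrative only; the actual computation lives in torchtitan's `train.py`. The function names, the 6 * num_params FLOPs-per-token approximation, and the H100 peak-FLOPS constant are assumptions, and the example numbers are made up.

```python
# Illustrative sketch, not torchtitan's exact code: derive TPS and MFU from a
# step's token count and wall-clock time. Names and constants are assumptions.

H100_BF16_PEAK_FLOPS = 989e12  # dense BF16 peak of an SXM H100; see footnote [^1]

def tokens_per_second(tokens_in_step: int, step_time_s: float) -> float:
    """Tokens processed per second across all GPUs in one training step."""
    return tokens_in_step / step_time_s

def mfu(tps: float, num_params: int, num_gpus: int,
        peak_flops_per_gpu: float = H100_BF16_PEAK_FLOPS) -> float:
    """Model FLOPS utilization = achieved FLOPS / aggregate peak FLOPS.

    Uses the common ~6 * num_params FLOPs-per-token approximation for dense
    decoder training (forward + backward); it ignores attention FLOPs, so it
    slightly underestimates MFU relative to an exact per-layer count.
    """
    achieved_flops = tps * 6 * num_params
    return achieved_flops / (num_gpus * peak_flops_per_gpu)

# Made-up example: 8 GPUs, local batch size 2, sequence length 8192, 2.8 s/step.
tps = tokens_per_second(tokens_in_step=8 * 2 * 8192, step_time_s=2.8)
print(f"TPS/GPU = {tps / 8:.0f}, MFU = {mfu(tps, num_params=8_000_000_000, num_gpus=8):.1%}")
```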
| Model size | Batch size | Activation checkpointing | WPS | MFU | Optimizations |
| ----- | ----- | ----- | ----- | ----- | ----- |
| 405B | 2 | full | 109 | 29.0%[^1] | None |
| 405B | 2 | full | 177 | 23.46%[^2] | Float8 |
| 405B | 2 | full | 185 | 24.52% | Float8 + Async TP |
**Table 1** 1D parallelism (FSDP). Llama 3.1 8B model. 8 GPUs. Local batch size 2, global batch size 16. Selective activation checkpointing.
Here, we use local batch size 2 (global batch size = local batch size 2 * number of FSDP ranks 16 = 32).

| Techniques | TPS/GPU | Memory(GiB) |
| ----- | ----: | ----: |
| FSDP | 5,762 | 82.4 |
| FSDP + torch.compile | 6,667 | 77.0 |
| FSDP + torch.compile + Float8 | 8,532 | 76.8 |
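To make the "Techniques" column concrete, here is a hedged sketch of how FSDP2 sharding, per-block torch.compile, and torchao Float8 conversion might be composed on a Llama-style model. The import paths and the `model.layers` attribute are assumptions (they depend on the PyTorch/torchao versions and the model definition); refer to torchtitan's parallelization code for the actual logic.

```python
# Hedged sketch (not torchtitan's exact code) of composing the Table 1 techniques.
# Assumes a recent PyTorch that exposes FSDP2's fully_shard under
# torch.distributed.fsdp, and a model whose transformer blocks live in `model.layers`.
import torch
from torch.distributed.fsdp import fully_shard              # FSDP2 (per-parameter sharding)
from torchao.float8 import convert_to_float8_training       # torchao Float8 training

def apply_techniques(model: torch.nn.Module, use_compile: bool, use_float8: bool):
    if use_float8:
        # Swap eligible nn.Linear modules to Float8 training linears, before sharding.
        convert_to_float8_training(model)
    for block in model.layers:
        if use_compile:
            # Compile each transformer block separately rather than the whole model.
            block.forward = torch.compile(block.forward)
        # Shard each block with FSDP2 ...
        fully_shard(block)
    # ... and finally the root module.
    fully_shard(model)
    return model
```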
Next, we show the loss curves. All models are trained for 3000 steps on the [C4 dataset](https://huggingface.co/datasets/allenai/c4), with global batch size 32. We have to use full AC to save memory usage. The results are shown in the picture (a TensorBoard screenshot) below.
**Table 2** FSDP + CP. Llama 3.1 8B model. 8 GPUs. Full activation checkpointing. Local batch size 1. torch.compile + Float8.
![image](../assets/images/llama3_1_405B_loss_curves.png)

| Parallelism | Sequence Length | TPS/GPU | Memory(GiB) |
| ----- | ----: | ----: | ----: |
| FSDP 8, CP 1 | 32768 | 3,890 | 83.9 |
| FSDP 4, CP 2 | 65536 | 2,540 | 84.2 |
| FSDP 2, CP 4 | 131072 | 1,071 | 84.0 |
| FSDP 1, CP 8 | 262144 | 548 | 84.5 |
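As a rough illustration of the FSDP x CP layouts in Table 2, the sketch below builds the corresponding 2D device mesh on 8 GPUs. The mesh dimension names are illustrative rather than torchtitan's exact ones, and the snippet assumes it runs under `torchrun` with a process group initialized.

```python
# Hypothetical sketch: build a 2D device mesh matching a row of Table 2
# (8 GPUs split between FSDP and context parallelism). Run under torchrun.
from torch.distributed.device_mesh import init_device_mesh

def build_fsdp_cp_mesh(fsdp_degree: int, cp_degree: int):
    assert fsdp_degree * cp_degree == 8, "Table 2 uses 8 GPUs in total"
    return init_device_mesh(
        "cuda",
        (fsdp_degree, cp_degree),
        mesh_dim_names=("dp_shard", "cp"),
    )

# e.g. the "FSDP 2, CP 4" row, which trains on 131072-token sequences:
mesh = build_fsdp_cp_mesh(fsdp_degree=2, cp_degree=4)
fsdp_mesh, cp_mesh = mesh["dp_shard"], mesh["cp"]
```

Note that the sequence length in Table 2 grows linearly with the CP degree while per-GPU memory stays roughly flat, which is the point of using context parallelism for long sequences.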
## Llama 3 performance numbers
**Table 3** 1D parallelism (FSDP). Llama 3.1 8B model. 128 GPUs. Local batch size 2, global batch size 256. Selective activation checkpointing.
Below are the WPS and MFU results which torchtitan achieves on Llama 3 models with FSDP2 on 64 A100 (80GB) GPUs.

| Techniques | TPS/GPU | Memory(GiB) |
| ----- | ----: | ----: |
| FSDP | 5,605 | 67.0 |
| FSDP + torch.compile | 6,514 | 62.0 |
| FSDP + torch.compile + Float8 | 8,380 | 61.8 |
| Model size | Batch size | Activation checkpointing | WPS | MFU |
| ----- | ----- | ----- | ----- | ----- |
| 8B | 1 | selective layer | 2904 | 56.8% |
| 8B | 1 | selective op | 2973 | 58.2% |
| 70B | 1 | full | 331 | 51.7% |
**Table 4** 2D parallelism (FSDP + TP) + torch.compile + Float8. Llama 3.1 70B model. 256 GPUs (FSDP 32, TP 8). Local batch size 16, global batch size 512. Full activation checkpointing.
We use local batch size 1 (global batch size = local batch size 1 * number of FSDP ranks 64 = 64), because it mimics the small local batch size in large-scale training, and moreover allows us to compare 1D (FSDP) and 2D (FSDP + TP) training under the same global batch size on both the 8B and 70B Llama 3 models, without hitting the out-of-memory (OOM) issue.

| Techniques | TPS/GPU | Memory(GiB) |
| ----- | ----: | ----: |
| 2D | 829 | 71.9 |
| 2D + AsyncTP | 876 | 67.6 |
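For the 2D rows above, the following is a hedged sketch of combining tensor parallelism (along a TP mesh dimension) with FSDP (along a data-parallel dimension) using PyTorch DTensor APIs. The submodule names (`attention.wq`, etc.) and the plan are illustrative of a Llama-style attention block and are not torchtitan's exact parallelization plan.

```python
# Hedged sketch of 2D parallelism (FSDP + TP) as in Table 4. Module names and
# degrees are illustrative; a real TP plan covers more submodules (MLP, embeddings).
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

def apply_fsdp_tp(model, dp_degree: int = 32, tp_degree: int = 8):
    mesh = init_device_mesh("cuda", (dp_degree, tp_degree), mesh_dim_names=("dp", "tp"))
    for block in model.layers:
        # Shard attention projections column-wise and the output projection
        # row-wise across the TP dimension.
        parallelize_module(
            block,
            mesh["tp"],
            {
                "attention.wq": ColwiseParallel(),
                "attention.wk": ColwiseParallel(),
                "attention.wv": ColwiseParallel(),
                "attention.wo": RowwiseParallel(),
            },
        )
        fully_shard(block, mesh=mesh["dp"])  # FSDP2 across the DP dimension
    fully_shard(model, mesh=mesh["dp"])
    return model
```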
Next we show the loss curves for Llama 3 8B and Llama 3 70B training with both 1D parallelism (FSDP2) and 2D parallelism (FSDP2 + Tensor Parallel). All four models are trained the same way as mentioned above, with global batch size 64. In terms of activation checkpointing (AC) configs, the Llama 3 8B training jobs use selective op AC, whereas the Llama 3 70B training jobs use full AC. The results are shown in the picture (a TensorBoard screenshot) below.
**Table 5** 3D parallelism (FSDP + TP + PP) + torch.compile + Float8 + AsyncTP. Llama 3.1 405B model. 512 GPUs (FSDP 8, TP 8, PP 8). Local batch size 32, global batch size 256. Full activation checkpointing.

| Schedule | TPS/GPU | Memory(GiB)[^2] |
| ----- | ----: | ----: |
| 1F1B | 100 | 82.5 |
| Interleaved 1F1B | 128 | 72.7 |
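The two schedules in Table 5 correspond to pipeline schedules available in `torch.distributed.pipelining`. Below is a hedged sketch of selecting between them; building the pipeline stages (i.e., splitting the 405B model across ranks) and the training loop are elided, and the helper name is illustrative.

```python
# Hedged sketch: choosing between the pipeline schedules compared in Table 5.
# PipelineStage construction for the 405B model is elided.
from torch.distributed.pipelining import Schedule1F1B, ScheduleInterleaved1F1B

def build_pp_schedule(stages, n_microbatches: int, loss_fn, interleaved: bool):
    if interleaved:
        # Interleaved 1F1B: each rank holds several model chunks (stages).
        return ScheduleInterleaved1F1B(stages, n_microbatches, loss_fn=loss_fn)
    # Classic 1F1B: a single stage per rank.
    return Schedule1F1B(stages[0], n_microbatches, loss_fn=loss_fn)
```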
**Table 6** 4D parallelism (FSDP + TP + PP + CP) + torch.compile + Float8 + AsyncTP + 1F1B. Llama 3.1 405B model. 512 GPUs (TP 8, PP 8). Local batch size 8. Full activation checkpointing.

| Parallelism | Sequence Length | TPS/GPU | Memory(GiB) |
| ----- | ----: | ----: | ----: |
| FSDP 8, CP 1 | 32768 | 76 | 75.3 |
| FSDP 4, CP 2 | 65536 | 47 | 75.9 |
| FSDP 2, CP 4 | 131072 | 31 | 77.1 |
| FSDP 1, CP 8 | 262144 | 16 | 84.9 |
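Finally, a hypothetical sketch of how the 512 GPUs in Table 6 could be factored into a 4D device mesh (PP x FSDP x CP x TP). The dimension names and ordering are assumptions, not necessarily how torchtitan lays out its mesh.

```python
# Hypothetical sketch of the 4D layout behind Table 6: 512 GPUs factored as
# PP 8 x (FSDP x CP = 8) x TP 8. Dimension names and order are illustrative.
from torch.distributed.device_mesh import init_device_mesh

def build_4d_mesh(pp: int = 8, fsdp: int = 2, cp: int = 4, tp: int = 8):
    assert pp * fsdp * cp * tp == 512, "Table 6 uses 512 GPUs in total"
    return init_device_mesh(
        "cuda",
        (pp, fsdp, cp, tp),
        mesh_dim_names=("pp", "dp_shard", "cp", "tp"),
    )

# e.g. the "FSDP 2, CP 4" row of Table 6 (131072-token sequences):
mesh = build_4d_mesh(pp=8, fsdp=2, cp=4, tp=8)
```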
#### Versions
| repo | commit | date |
| --- | --- | --- |
| torch | [1963fc8](https://github.com/pytorch/pytorch/commit/1963fc83a1c32e162162e2414f78b043f0674bae) | 2024/12/23 |
| torchao | [eab345c](https://github.com/pytorch/ao/commit/eab345c2268a7506355d506ebfc27b5d28e5e7d0) | 2024/12/23 |
| torchtitan | [9dec370](https://github.com/pytorch/torchtitan/commit/9dec370ad26b5f8e9a7333a0e36165018262644b) | 2024/12/26 |
![image](../assets/images/llama3_loss_curves.png)

[^1]: We used HBM2e-based, lower-TDP SXM H100 (95GB) for our test; the actual peak TFLOPS is between that of SXM and NVL, and we don't know its exact value, so this MFU number is lower than the actual MFU because we use the SXM peak number directly.

[^2]: Since we are not converting all the matmuls to Float8 (our fused attention implementation is not done in Float8), this number is lower than expected.
[^2]: Different PP ranks can have different peak memory usages. We take the maximum across all GPUs.
In the recent Meta CP paper (https://arxiv.org/abs/2411.01783), they mentioned:
Would like to hear your thoughts on this @yifuwang