update document for v0.2.0 (#91)
**Description**
Add optimization level O3 to the homepage.
**Major Revision**
- Add O3 description
tocean authored Jul 20, 2023
1 parent ecc8939 commit 716ad89
Showing 2 changed files with 25 additions and 11 deletions.
34 changes: 24 additions & 10 deletions README.md
@@ -6,7 +6,8 @@ Features:

- Support O1 optimization: Apply FP8 to weights and weight gradients and support FP8 in communication.
- Support O2 optimization: Support FP8 for two optimizers (Adam and AdamW).
-- Provide three training examples using FP8: Swin-Transformer, DeiT and RoBERTa.
+- Support O3 optimization: Support FP8 in DeepSpeed ZeRO optimizer.
+- Provide four training examples using FP8: Swin-Transformer, DeiT, RoBERTa and GPT-3.

MS-AMP has the following benefits compared with Transformer Engine:

@@ -28,10 +28,10 @@ MS-AMP has the following benefits compared with Transformer Engine:
- CUDA version 11 or later (which can be checked by running `nvcc --version`).
- PyTorch version 1.13 or later (which can be checked by running `python -c "import torch; print(torch.__version__)"`).

-We strongly recommend using [PyTorch NGC Container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch). For example, to start a PyTorch 1.13 container, run the following command:
+We strongly recommend using [PyTorch NGC Container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch). For example, to start a PyTorch 1.14 container, run the following command:

```
-sudo docker run -it -d --name=msamp --privileged --net=host --ipc=host --gpus=all nvcr.io/nvidia/pytorch:22.09-py3 bash
+sudo docker run -it -d --name=msamp --privileged --net=host --ipc=host --gpus=all nvcr.io/nvidia/pytorch:22.12-py3 bash
sudo docker exec -it msamp bash
```
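
Once inside the container, the requirements above can also be checked from Python (a convenience sketch; note that `torch.version.cuda` reports the CUDA version PyTorch was built against, which may differ from the system `nvcc`):

```python
import torch

print(torch.__version__)    # expect 1.13 or later
print(torch.version.cuda)   # expect 11.x or later
assert torch.cuda.is_available(), "no CUDA device visible"
```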

@@ -121,7 +122,17 @@ for batch_idx, (data, target) in enumerate(train_loader):
scaler.step(optimizer)
```
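
For orientation, here is a minimal end-to-end sketch of the loop excerpted above, using the `msamp.initialize` API whose signature appears in the `msamp/__init__.py` hunk below; the toy model, data, and hyperparameters are illustrative placeholders:

```python
import torch
import msamp

# Toy model and synthetic data, purely illustrative.
model = torch.nn.Linear(784, 10).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Cast model and optimizer for FP8 training; see the optimization-level table below.
model, optimizer = msamp.initialize(model, optimizer, opt_level="O2")

criterion = torch.nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler()
train_loader = [(torch.randn(32, 784), torch.randint(0, 10, (32,))) for _ in range(10)]

for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.cuda(), target.cuda()
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    scaler.scale(loss).backward()   # scaled loss, as in the excerpt above
    scaler.step(optimizer)
    scaler.update()
```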

-A runnable, comprehensive MNIST example demonstrating good practices can be found [here](./examples). For more examples, please go to [MS-AMP-Examples](https://github.com/Azure/MS-AMP-Examples).
+To apply MS-AMP to DeepSpeed ZeRO, add an "msamp" section to the DeepSpeed config file:

+```json
+"msamp": {
+  "enabled": true,
+  "opt_level": "O3"
+}
+```

+Runnable, comprehensive examples demonstrating good practices can be found [here](./examples).
+For more examples, please go to [MS-AMP-Examples](https://github.com/Azure/MS-AMP-Examples).
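
For context, a hedged sketch of how such an "msamp" section might sit inside a fuller config passed to `deepspeed.initialize`; every key other than "msamp" is an illustrative assumption, not a value prescribed by MS-AMP:

```python
import deepspeed
import torch

model = torch.nn.Linear(784, 10)  # toy model, illustrative only

ds_config = {
    "train_batch_size": 32,                                  # assumed value
    "optimizer": {"type": "AdamW", "params": {"lr": 1e-3}},  # assumed optimizer section
    "zero_optimization": {"stage": 2},                       # assumed ZeRO stage
    "msamp": {"enabled": True, "opt_level": "O3"},           # section documented above
}

# DeepSpeed wraps the model and builds the ZeRO optimizer from the config.
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)
```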

### Optimization Level

@@ -131,13 +142,16 @@ Currently MS-AMP supports two optimization levels: O1 and O2. Try both, and see

- O2: From O1 to O2, our main focus is on enabling the use of low-bit data formats for auxiliary tensors in the Adam/AdamW optimizer without any loss in accuracy. Specifically, we are able to maintain accuracy by representing the first-order optimizer state in FP8 and the second-order state in FP16. This optimization has the potential to save up to 62.5% of GPU memory for the optimizer when the model size is particularly large.

+- O3: This optimization level is specifically designed for the ZeRO optimizer in DeepSpeed, an advanced distributed training framework. ZeRO separates model weights into regular weights and master weights: the former are used for the forward/backward pass on each GPU, and the latter are used for model updates in the optimizer. This separation allows us to use 8-bit data precision for regular weights and weight broadcasting, which further reduces GPU memory and bandwidth usage.

Here are details of different MS-AMP optimization levels:
-| Optimization Level | Computation(GEMM) | Comm | Weight | Weight Gradient | Optimizer States |
-| ------------------- | ----------- | ----- | ------ | --------------- | ---------------- |
-| FP16 AMP | FP16 | FP32 | FP32 | FP32 | FP32+FP32 |
-| Nvidia TE | FP8 | FP32 | FP32 | FP32 | FP32+FP32 |
-| MS-AMP O1 | FP8 | FP8 | FP16 | FP8 | FP32+FP32 |
-| MS-AMP O2 | FP8 | FP8 | FP16 | FP8 | FP8+FP16 |
+| Optimization Level | Computation(GEMM) | Comm | Weight | Master Weight | Weight Gradient | Optimizer States |
+| ------------------- | ----------- | ----- | ------ | ------------- | --------------- | ---------------- |
+| FP16 AMP | FP16 | FP32 | FP32 | N/A | FP32 | FP32+FP32 |
+| Nvidia TE | FP8 | FP32 | FP32 | N/A | FP32 | FP32+FP32 |
+| MS-AMP O1 | FP8 | FP8 | FP16 | N/A | FP8 | FP32+FP32 |
+| MS-AMP O2 | FP8 | FP8 | FP16 | N/A | FP8 | FP8+FP16 |
+| MS-AMP O3 | FP8 | FP8 | FP8 | FP16 | FP8 | FP8+FP16 |
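
As a sanity check on the "up to 62.5%" figure quoted for O2 above, a quick back-of-envelope computation assuming Adam/AdamW holds two FP32 states per parameter:

```python
# Bytes of optimizer state per model parameter.
fp32_adam = 4 + 4    # FP32 first moment + FP32 second moment
msamp_o2  = 1 + 2    # FP8 first moment + FP16 second moment (per the table above)
print(f"optimizer-state saving: {1 - msamp_o2 / fp32_adam:.1%}")  # -> 62.5%
```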

## Performance

2 changes: 1 addition & 1 deletion msamp/__init__.py
@@ -94,6 +94,6 @@ def initialize(model, optimizer=None, opt_level='O1'):  # noqa: C901
return cast_model, cast_optimizer


-__version__ = '0.1.0'
+__version__ = '0.2.0'
__author__ = 'Microsoft'
__all__ = ['clip_grad_norm_', 'initialize']
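
After upgrading, the version bump above can be confirmed at runtime:

```python
import msamp

print(msamp.__version__)  # expect '0.2.0' after this commit
```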
