This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit 52aed83 (0 parents)
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fbshipit-source-id: bba680249506372ab6cf3ee6ef9a3988d0f540dc
Showing 33 changed files with 4,183 additions and 0 deletions.

**.github/workflows/ufmt.yml**

```yaml
name: Ufmt

on:
  push:
    branches:
      - main
  pull_request:

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.10"]
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v3
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          pip install ufmt
      - name: Analyzing the code with ufmt
        run: |
          ufmt check .
```

**.gitignore**

```
**/__pycache__/
float8_experimental/__pycache__/*
finetune/__pycache__/*
test/__pycache__/*
tmp/*
benchmarks/data/*

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
```

**CODE_OF_CONDUCT.md**

# Code of Conduct

## Our Pledge

In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to make participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.

## Our Standards

Examples of behavior that contributes to creating a positive environment
include:

* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members

Examples of unacceptable behavior by participants include:

* The use of sexualized language or imagery and unwelcome sexual attention or
  advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
  address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
  professional setting

## Our Responsibilities

Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.

Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.

## Scope

This Code of Conduct applies within all project spaces, and it also applies when
an individual is representing the project or its community in public spaces.
Examples of representing a project or community include using an official
project e-mail address, posting via an official social media account, or acting
as an appointed representative at an online or offline event. Representation of
a project may be further defined and clarified by project maintainers.

This Code of Conduct also applies outside the project spaces when there is a
reasonable belief that an individual's behavior may have a negative impact on
the project or its community.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at <[email protected]>. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq

**CONTRIBUTING.md**

# Contributing to float8_experimental
We want to make contributing to this project as easy and transparent as
possible.

## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Meta's open source projects.

Complete your CLA here: <https://code.facebook.com/cla>

## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.

## License
By contributing to float8_experimental, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.

**LICENSE**

BSD 3-Clause License

Copyright (c) 2023, PyTorch Labs

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
   contributors may be used to endorse or promote products derived from
   this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

**README.md**

# float8_experimental

This is a prototype of a float8 training UX in native PyTorch, with full PT2.0 and distributed support.
The codebase strives to stay small, easily hackable, and debuggable with native PyTorch tooling.

Backwards compatibility is not guaranteed at this point. The codebase is in active development and
will change rapidly.

# installation

```Shell
pip install .
# Optionally, install in editable mode
pip install -e .
```

# User API, subject to change

## single GPU

```python
from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)
from float8_experimental.float8_linear import Float8Linear

# create fp32 model
m = Model(...)

# convert all `torch.nn.Linear` modules to `Float8Linear`
swap_linear_with_float8_linear(m, Float8Linear)

# toy training loop; `x`, `optimizer`, and `N_ITER` are assumed to be defined
for _ in range(N_ITER):
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()

    # specific to float8: separate step to sync scales/amaxes
    # in the future, this may move to a context manager
    sync_float8_amax_and_scale_history(m)

    optimizer.step()
```

## multi GPU

```python
from float8_experimental.tp_linear import swap_tp_linear_with_float8_linear

# swaps the fairscale `ColumnParallelLinear` with `Float8ColumnParallelLinear`,
# and the fairscale `RowParallelLinear` with `Float8RowParallelLinear`
swap_tp_linear_with_float8_linear(model)

# if applicable, enable sequence parallel on the right modules
# TODO make the API for this nicer
model.foo.bar.fc1.sequence_parallel = True
model.foo.bar.fc2.sequence_parallel = True

# the rest of the flow is the same as the single GPU flow
```

# high level technical design

## UX

We are using a module swap UX to keep things simple. If the user model has `torch.nn.Linear` modules or their `fairscale` TP/SP equivalents,
we can convert them to float8. `F.linear`, `torch.mm`, and `torch.matmul` are not supported at the moment.
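
As a rough illustration of what the module swap entails, here is a minimal sketch of a recursive swap. This is not the library's actual `swap_linear_with_float8_linear`; the `from_float` converter on the float8 class is an assumption made for the sketch:

```python
import torch.nn as nn

def swap_linears_sketch(module: nn.Module, float8_cls) -> nn.Module:
    # Recursively replace every nn.Linear with its float8 counterpart.
    # Assumes `float8_cls.from_float(linear)` builds the replacement
    # module from an existing nn.Linear (hypothetical converter).
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, float8_cls.from_float(child))
        else:
            swap_linears_sketch(child, float8_cls)
    return module
```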

The user is responsible for calling the `sync_float8_amax_and_scale_history` function once per forward/backward pass;
this function updates the amax history. If distributed is enabled, this function also syncs amax values across workers.
This is a separate model level function (as opposed to each module owning the syncing of its buffers) to
make it easier to optimize performance (for example, reduce all the amaxes once in a single tensor instead of doing N reductions).
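
To make the fused pattern concrete, here is a minimal sketch of syncing all amaxes with a single all-reduce. The buffer names `fp8_amax` and `fp8_scale` and the scale formula are assumptions for illustration, not the library's actual implementation:

```python
import torch
import torch.distributed as dist

def sync_amaxes_sketch(float8_modules):
    # Stack each module's running amax into one tensor so that a
    # single all-reduce replaces N separate reductions.
    amaxes = torch.stack([m.fp8_amax for m in float8_modules])
    if dist.is_initialized():
        dist.all_reduce(amaxes, op=dist.ReduceOp.MAX)
    # Scatter the synced values back and recompute each scale.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    for m, amax in zip(float8_modules, amaxes.unbind()):
        m.fp8_amax.copy_(amax)
        m.fp8_scale.copy_(fp8_max / amax.clamp(min=1e-12))
```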

Composability with `DTensor` is on our radar and we plan to look into this after the manual flow works e2e.

A user facing tensor subclass UX is not being considered at the moment because delayed scaling requires persistent state for
activations, and there isn't a clean and sound way to implement this with tensor subclasses.

## single GPU

### separation of concerns

1. `Float8Linear` owns casting X, W and dL/dY to float8 and does all the bookkeeping of the amax, amax_history and scale buffers
2. The user is responsible for applying `Float8Linear` to the right parts of their model with module swaps

### Tensor subclasses

We are using tensor subclasses (`Float8Tensor`) to write modular code which satisfies
autograd's restriction that `x.dtype == x.grad.dtype`. The way we achieve this is by
ensuring that instances of `Float8Tensor` set their dtype attribute to the original
dtype (float32/float16/bfloat16) while the underlying data representation is in float8.
If you look in `float8_linear.py` and `te_linear.py`, you will see that we pass instances of `Float8Tensor`
around various `torch.autograd.Function` calls, enabling us to have modular code.
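
The core of the dtype trick can be sketched as follows. This is a simplified illustration, not the library's actual `Float8Tensor`: a real subclass also needs `__torch_dispatch__` handlers for the ops it supports, and `_make_wrapper_subclass` is a private PyTorch helper:

```python
import torch

class Float8TensorSketch(torch.Tensor):
    @staticmethod
    def __new__(cls, data_fp8, scale, orig_dtype):
        # Advertise the ORIGINAL dtype to autograd so that
        # x.dtype == x.grad.dtype holds, while the payload is float8.
        return torch.Tensor._make_wrapper_subclass(
            cls, data_fp8.shape, dtype=orig_dtype, device=data_fp8.device
        )

    def __init__(self, data_fp8, scale, orig_dtype):
        self._data = data_fp8  # raw float8 payload
        self._scale = scale    # scale applied at cast time

    def to_original(self):
        # Dequantize back to the advertised dtype.
        return self._data.to(self.dtype) / self._scale
```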

## multi GPU

### TP/SP

`Float8ColumnParallelLinear` and `Float8RowParallelLinear` are replacements for the non-float8 TP/SP primitives.

### FSDP with fp16 weight all-gather

No change from single GPU code - it just works.

### FSDP with fp8 weight all-gather

FSDP with fp8 weight all-gather is currently under design. The problem can be separated into three parts:

a. separation of concerns between user code and FSDP
b. user code interaction with FSDP
c. FSDP implementation of fp8 all-gather

#### Separation of concerns between user code and FSDP

We have alignment on the separation of concerns that we want (see the sketch after this list):
1. user code is responsible for making the model fp8 aware and adding the right buffers
2. user code is responsible for passing FSDP the information necessary to cast weights to fp8: a way to tell if a weight should be cast to fp8, the weight's scale, and the Float8Tensor constructor
3. FSDP is responsible for performing the fp8 cast and providing the unsharded fp8 weight to each worker
4. user code is responsible for syncing amax metadata across workers and calculating scales

This way, FSDP knows as little as possible about user logic - it just gets a list of weights + amax buffers + scales,
and does the float8 fused cast + amax calculation. User code does everything else.
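
As a concrete, hypothetical illustration of the contract in items 2 and 3, user code might hand FSDP something like the following. All names here are illustrative, since the actual interface is still under design:

```python
from dataclasses import dataclass
from typing import Callable
import torch

@dataclass
class Fp8CastConfig:
    # Item 2: the information user code passes to FSDP.
    should_cast: Callable[[torch.nn.Parameter], bool]        # cast this weight?
    get_scale: Callable[[torch.nn.Parameter], torch.Tensor]  # current scale
    make_float8: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]  # ctor

def cast_before_all_gather(param, cfg: Fp8CastConfig):
    # Item 3: the FSDP-side cast performed before the all-gather.
    if cfg.should_cast(param):
        return cfg.make_float8(param.data, cfg.get_scale(param))
    return param.data
```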

#### User code interaction with FSDP

We expect this to be trivial. First, when initializing FSDP, we will provide the necessary configuration
to it as described above. Second, instead of `w_fp8 = cast_to_fp8(w)`, we will just check if `w` is already in fp8.
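
In code, the second point might look like this (`is_float8` and `cast_to_fp8` are hypothetical helpers):

```python
# before: user code always casts the weight
w_fp8 = cast_to_fp8(w)
# after: FSDP may have already provided `w` in float8
w_fp8 = w if is_float8(w) else cast_to_fp8(w)
```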

#### FSDP implementation of fp8 all-gather

This is in early design. The current `FlatParameter` design does not work cleanly with heterogeneous dtypes,
and heterogeneous dtypes are required for a good UX, since for realistic models not all parameters
(norm parameters, biases, etc.) will be in float8.

We are working on a new FSDP implementation that uses per-parameter sharding, which will allow a flexible fp8 all-gather. This is being prototyped currently.

# code tips

* `float8_experimental/float8_linear.py` - `Float8Linear` (main user facing entry point for delayed scaling)
* `float8_experimental/dynamic_linear/dynamic_linear.py` - `Float8DynamicLinear` (main user facing entry point for dynamic scaling)
* `float8_experimental/float8_tensor.py` - `Float8Tensor`, which allows `Float8Linear` to abide by the `x.dtype == x.grad.dtype` restriction
* `float8_experimental/tp_linear.py` - `Float8ColumnParallelLinear` / `Float8RowParallelLinear` (TP/SP versions of float8 linear)

# testing

```bash
# run single-GPU unit tests
pytest test/test_base.py

# run a single-GPU integration test on SAM
pytest test/test_sam.py

# run single-GPU compile tests
pytest test/test_compile.py

# run a two-GPU integration test on FSDP
./test/test_fsdp.sh

# run integration tests for TP/SP
./test/test_tp.sh

# run all of these tests
./test/run_everything.sh
```

# benchmarking

```bash
# benchmark the torch._scaled_mm function on LLaMa 2 70B shapes
./benchmarks/bench_matmul.py

# benchmark fw/bw of `Linear`, `Float8Linear` and `te.Linear` on LLaMa 2 70B shapes
# make sure to turn on torch.compile to get the best performance
./benchmarks/bench_linear_float8.py -o ../tmp/test.txt --compile
```

# License

float8_experimental has a BSD 3-Clause license, as found in the LICENSE file.