This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Changes to enable fp8 on multi devices #149

Closed
wants to merge 5 commits into from

Conversation

@y-sq (Contributor) commented Nov 20, 2023

  • If the model is cast to bf16 (model = model.to(get_torch_dtype(dtype))), dtype = bf16 is also passed to the scale_a parameter of the float8 tensor, which causes:
        output, output_amax = torch._scaled_mm(
        RuntimeError: scale_a must be float scalar
    (A sketch of this failure mode follows after this list.)
  • The Float8Linear classes used in the multi-GPU case are Float8ColumnParallelLinear and Float8RowParallelLinear, so sync_float8_amax_and_scale_history needs to be able to identify these class types.
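A minimal, self-contained sketch of the first issue (ToyFloat8Linear and fp8_scale_a are made-up names, not this repo's Float8Linear): nn.Module.to(dtype=...) casts floating-point buffers along with the parameters, so a module-wide bf16 cast also turns an fp32 scale buffer into bf16, which torch._scaled_mm then rejects.

import torch

# Hypothetical stand-in for a float8 linear with an fp32 scale buffer.
class ToyFloat8Linear(torch.nn.Linear):
    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features)
        self.register_buffer("fp8_scale_a", torch.tensor(1.0, dtype=torch.float32))

m = ToyFloat8Linear(16, 16)
m = m.to(dtype=torch.bfloat16)                 # module-wide cast hits the buffer too
assert m.fp8_scale_a.dtype == torch.bfloat16   # the state this PR guards against

# One possible remedy: keep (or re-cast) the scale buffer as float32.
m.fp8_scale_a = m.fp8_scale_a.float()
assert m.fp8_scale_a.dtype == torch.float32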

@facebook-github-bot added the CLA Signed label Nov 20, 2023
@drisspg changed the title from "Changes to enable fp8 in xlformers" to "Changes to enable fp8" Nov 20, 2023
@drisspg changed the title from "Changes to enable fp8" to "Changes to enable fp8 on multi devices" Nov 20, 2023
@y-sq marked this pull request as ready for review November 20, 2023 21:35
@facebook-github-bot (Contributor): @y-sq has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.


  for name, child in model.named_modules():
-     if not isinstance(child, (Float8Linear)):
+     if not any(isinstance(child, a) for a in fp8_classes):
Contributor:

Since we have removed the NoTs class, I think this is likely a rebase artifact.

Contributor (Author):

In the multi-GPU case, we have Float8ColumnParallelLinear and Float8RowParallelLinear (which depend on external distributed-training code) as the fp8 classes, so I modified this to pass the class types to sync_float8_amax_and_scale_history.

Contributor:

Ahh, I see, I was misreading this. But if we make fp8_classes a Tuple of types, couldn't we still keep the check as is? This is a nit anyway; both accomplish the same thing.
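A small sketch of the equivalence being discussed (iter_fp8_children is a hypothetical helper, not code from this repo; fp8_classes is assumed to be a tuple of module types):

from typing import Iterable, Tuple, Type

import torch.nn as nn

def iter_fp8_children(model: nn.Module,
                      fp8_classes: Tuple[Type[nn.Module], ...]) -> Iterable[nn.Module]:
    """Yield only the submodules whose type is one of fp8_classes."""
    for _name, child in model.named_modules():
        # Form used in the PR diff:
        if not any(isinstance(child, cls) for cls in fp8_classes):
            continue
        # Equivalent form suggested in review, since isinstance accepts a tuple of types:
        # if not isinstance(child, fp8_classes):
        #     continue
        yield child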

@drisspg (Contributor) commented Nov 21, 2023

Looks great, can we add a test?

@y-sq (Contributor, Author) commented Nov 21, 2023

It's strange that the format check passed on my side but failed in the GitHub checks:

ufmt --version
ufmt, version 2.3.0
ufmt check .
✨ 22 files already formatted ✨

@drisspg (Contributor) commented Nov 21, 2023

Don't worry about the format. There is a deviation between ufmt and the internal formatter that I can't pin down, so those failures were likely not caused by anything you did.

# Cast the module to dtype
m = m.to(dtype=linear_dtype)

# autocast off
Contributor:

Nit: this test does cover it, but could we also assert that the buffer types are still fp32?
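A sketch of what such an assertion could look like (the "scale"/"amax" name filter is a guess at the buffer naming, not the repo's actual convention); it would run right after the m = m.to(dtype=linear_dtype) cast above:

import torch

# The fp8 bookkeeping buffers should remain fp32 even after the module-wide cast.
for name, buf in m.named_buffers():
    if "scale" in name or "amax" in name:
        assert buf.dtype == torch.float32, f"{name} was cast to {buf.dtype}"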

@facebook-github-bot (Contributor): @y-sq has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@drisspg (Contributor) left a comment

Thanks 😄!

@facebook-github-bot (Contributor): @y-sq merged this pull request in 77386ba.
