Changes to enable fp8 on multi devices #149
Conversation
y-sq commented on Nov 20, 2023
- If the model is cast to bf16 (model = model.to(get_torch_dtype(dtype))), dtype = bf16 is also passed to the scale_a parameter of the float8 tensor, which caused errors.
- The Float8Linear classes used in the multi-GPU case are Float8ColumnParallelLinear/Float8RowParallelLinear, so sync_float8_amax_and_scale_history needs to identify these class types (a sketch of the intended usage is shown below).
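For context, here is a minimal sketch of how the updated helper is meant to be called in the multi-GPU setup. The keyword name fp8_classes, the import path, and the parallel Linear class names are taken from this PR's description and discussion, so treat them as assumptions rather than the final API:

```python
# Assumed import path for the sync helper; the parallel fp8 classes come from
# external distributed-training code, so they are only passed in by the caller.
from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history


def training_step(model, optimizer, batch, fp8_classes):
    """One training step; fp8_classes lists every fp8 Linear variant to sync
    (e.g. Float8Linear, Float8ColumnParallelLinear, Float8RowParallelLinear)."""
    optimizer.zero_grad()
    loss = model(batch).mean()
    loss.backward()
    # Tell the helper which module classes hold amax/scale history so it can
    # find them via model.named_modules() (the check that changed in this PR).
    sync_float8_amax_and_scale_history(model, fp8_classes=fp8_classes)
    optimizer.step()
    return loss
```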
@y-sq has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
```diff
 for name, child in model.named_modules():
-    if not isinstance(child, (Float8Linear)):
+    if not any(isinstance(child, a) for a in fp8_classes):
```
Since we have removed the NoTs class, I think this is likely a rebase artifact.
In the multi-GPU case, we have Float8ColumnParallelLinear and Float8RowParallelLinear (which depend on external distributed training code) as the fp8 classes, so I modified this to pass the class types to sync_float8_amax_and_scale_history.
Ahh, I see, I was misreading this. But if we make fp8_classes a Tuple[type], couldn't we still keep the check as is? This is a nit anyway; both accomplish the same thing.
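For what it's worth, the two checks behave identically when fp8_classes is a tuple of types, since isinstance accepts such a tuple directly. A self-contained sketch with stand-in classes (the real parallel Linear classes live in external code):

```python
import torch.nn as nn

# Stand-ins only; the real classes are Float8Linear plus the external
# Float8ColumnParallelLinear / Float8RowParallelLinear.
class Float8Linear(nn.Linear): ...
class Float8ColumnParallelLinear(nn.Linear): ...

fp8_classes = (Float8Linear, Float8ColumnParallelLinear)
model = nn.Sequential(Float8Linear(4, 4), nn.ReLU(), Float8ColumnParallelLinear(4, 4))

for name, child in model.named_modules():
    # isinstance accepts a tuple of types, so both forms give the same answer.
    assert any(isinstance(child, cls) for cls in fp8_classes) == isinstance(child, fp8_classes)
```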
Looks great, can we add a test?
It's strange that the format check passed on my side but failed on the GitHub checks.
Don't worry about the format; there is a deviation between ufmt and internal that I can't pin down, so those failures were likely not caused by anything you did.
```python
# Cast the module to dtype
m = m.to(dtype=linear_dtype)

# autocast off
```
Nit: this test does cover it, but could we also assert that the buffer types are still fp32?
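Something like the following would cover that; the buffer-name filter is an assumption about how the amax/scale history buffers are named, so it would need to be adjusted to the actual Float8Linear buffers:

```python
import torch

def assert_scale_buffers_fp32(m: torch.nn.Module) -> None:
    # The amax/scale bookkeeping has to stay fp32 even after the module's
    # parameters are cast with m.to(dtype=torch.bfloat16); the name filter
    # below is illustrative, not the repo's exact buffer naming.
    for name, buf in m.named_buffers():
        if "amax" in name or "scale" in name:
            assert buf.dtype == torch.float32, f"{name} is {buf.dtype}, expected torch.float32"
```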
@y-sq has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Thanks 😄!