
deprecate LayerNormFp32 #850

Closed · wants to merge 1 commit

Conversation

EIFY (Contributor) commented on Apr 1, 2024

Modern pytorch (1.10+) always performs LN in fp32:

> For example, LayerNorm has to be done in fp32 and recent pytorch (1.10+) has been fixed to do that regardless of the input types, but earlier pytorch versions accumulate in the input type which can be an issue.

So it's no longer necessary to use LayerNormFp32 to explicitly cast to fp32. However, the built-in torch.nn.LayerNorm always returns fp32 output when run under the autocast() context, so we still need the LayerNorm subclass to cast back to the input dtype. See also pytorch/pytorch#66707 (comment).
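
For context, a minimal sketch of the two wrappers being discussed, assuming open_clip-style definitions (the class bodies here are illustrative, not copied verbatim from the repo): LayerNormFp32 upcasts the input and the affine parameters to fp32 before normalizing and then casts back, while the plain LayerNorm subclass leaves the computation to PyTorch and only casts the output back to the input dtype.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNormFp32(nn.LayerNorm):
    """LayerNorm that upcasts input and affine params to fp32, then casts back."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_type = x.dtype
        out = F.layer_norm(
            x.to(torch.float32),
            self.normalized_shape,
            self.weight.to(torch.float32) if self.weight is not None else None,
            self.bias.to(torch.float32) if self.bias is not None else None,
            self.eps,
        )
        return out.to(orig_type)


class LayerNorm(nn.LayerNorm):
    """LayerNorm that only casts the (possibly fp32) output back to the input dtype."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_type = x.dtype
        out = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return out.to(orig_type)
```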

Commit: Modern pytorch always performs LN in fp32.
rwightman (Collaborator) commented:
@EIFY I don't think this is quite the case: in an autocast context it returns float32 because it's upcast to float32 under AMP. But we aren't using this class when AMP is enabled; it's used when pure float16/bfloat16 is enabled, and then it does make a difference. Even if the reduction is done internally in float32, the affine ops will be done in low precision, whereas in LayerNormFp32 everything is done in float32 regardless of the input dtype.
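
A rough way to see the difference rwightman describes (a sketch with arbitrary shapes, not from the PR): with a module converted to pure bfloat16 and no autocast, nn.LayerNorm's affine weight and bias stay in bfloat16, so the scale-and-shift runs in low precision, whereas the explicit-upcast path keeps every step in float32 and only casts back at the end.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Pure bfloat16 module (no autocast): the affine params are low precision.
x = torch.randn(2, 8).to(torch.bfloat16)
ln = nn.LayerNorm(8).to(torch.bfloat16)
print(ln.weight.dtype)  # torch.bfloat16 -> scale/shift runs in bf16
print(ln(x).dtype)      # torch.bfloat16

# What LayerNormFp32 does instead: upcast everything, normalize in fp32, cast back.
y = F.layer_norm(x.float(), (8,), ln.weight.float(), ln.bias.float(), ln.eps)
print(y.dtype)          # torch.float32 (cast back to bf16 at the very end)
```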

rwightman closed this on May 31, 2024.