-
Notifications
You must be signed in to change notification settings - Fork 69
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable torch.compile for loss_fn #597
Enable torch.compile for loss_fn #597
Conversation
MLCommons CLA bot: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for getting this PR in!
It seems like there is a small linting error from one of the lines in submission_runner.py. Could you fix this please? If you have yapf==0.32 installed you can just run yapf -i -r -vv -p submission_runner.py
and commit the changes.
This PR enables
torch.compile
for loss functions, which mitigates the performance gap between PyTorch and JAX for fastmri (11.6% -> 1.83%), librispeech_deepspeech (12.02% -> -15.01%), ogbg (6.06% -> 3.28%), and wmt (19.75% -> 3.52%).Close: #467
Close: #488