Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable torch.compile for loss_fn #597

Merged
merged 16 commits into from
Dec 8, 2023

Conversation

BoyuanFeng
Copy link
Contributor

@BoyuanFeng BoyuanFeng commented Dec 7, 2023

This PR enables torch.compile for loss functions, which mitigates the performance gap between PyTorch and JAX for fastmri (11.6% -> 1.83%), librispeech_deepspeech (12.02% -> -15.01%), ogbg (6.06% -> 3.28%), and wmt (19.75% -> 3.52%).

Close: #467
Close: #488

@BoyuanFeng BoyuanFeng requested a review from a team as a code owner December 7, 2023 22:50
Copy link

github-actions bot commented Dec 7, 2023

MLCommons CLA bot:
Thank you very much for your submission, we really appreciate it. Before we can accept your contribution, we ask that you sign the MLCommons CLA (Apache 2). Please use this [Google form] (https://forms.gle/Ew1KkBVpyeJDuRw67) to initiate authorization. If you are from an MLCommons member organization, we will request that you be added to the CLA. If you are not from a member organization, we will email you a CLA to sign. For any questions, please contact [email protected].
0 out of 1 committers have signed the MLCommons CLA.
@boyuan Feng
Boyuan Feng seems not to be a GitHub user. You need a GitHub account after you become MLCommons member. If you have already a GitHub account, please add the email address used for this commit to your account.
You can retrigger this bot by commenting recheck in this Pull Request

Copy link
Contributor

@priyakasimbeg priyakasimbeg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for getting this PR in!
It seems like there is a small linting error from one of the lines in submission_runner.py. Could you fix this please? If you have yapf==0.32 installed you can just run yapf -i -r -vv -p submission_runner.py and commit the changes.

@priyakasimbeg priyakasimbeg changed the base branch from main to dev December 8, 2023 00:45
@BoyuanFeng BoyuanFeng closed this Dec 8, 2023
@BoyuanFeng BoyuanFeng reopened this Dec 8, 2023
@github-actions github-actions bot locked and limited conversation to collaborators Dec 8, 2023
@msaroufim msaroufim self-requested a review December 8, 2023 23:13
@priyakasimbeg priyakasimbeg merged commit a7b9650 into mlcommons:dev Dec 8, 2023
16 of 17 checks passed
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Deepspeech pytorch timing gap and torch compile issues WMT slower in Pytorch than Jax
3 participants