JAX code is extremely slow on GPUs #24411

Open
AakashKumarNain opened this issue Oct 20, 2024 · 0 comments
Labels: bug (Something isn't working)

AakashKumarNain commented Oct 20, 2024

Description

Last week there was a discussion on Twitter where many people noted that JAX performance on GPUs is still subpar compared to PyTorch. @mattjj and I talked afterwards and agreed that a repo with a minimal codebase that showcases the performance difference between the two would be great.

GPT-2 is a good model to start with, and here is a repo that contains both JAX and PyTorch code for the same model. The instructions in the repo are enough to download and run the code locally (on a GPU machine).

On my side, these are the results I got on an A100 40G machine:

JAX

[screenshot: jax_run]

PyTorch

[screenshot: torch_run]

Compared to PyTorch, JAX is extremely slow here, and I have no idea why. There may be a silent bug somewhere in the JAX code that I have overlooked. Given how many times I have been through it, I think a fresh set of eyes would do better. Please let me know if you need any other information from my side.
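For anyone reproducing the numbers, here is a minimal sketch of a timing helper that keeps warm-up calls (where one-off costs such as compilation land) out of the measurement and waits for each result before stopping the clock. It is not the benchmark code from the repo: `time_fn`, `toy_step`, and the `sync` parameter are hypothetical names used only for illustration, and the workload is a pure-Python stand-in so the sketch runs without a GPU.

```python
import time
from statistics import median


def time_fn(fn, *args, warmup=3, iters=10, sync=lambda out: out):
    """Return per-call wall-clock times in seconds, excluding warm-up calls.

    `sync` is a hypothetical hook: it should block until the result of `fn`
    has actually been computed. The default does nothing extra.
    """
    for _ in range(warmup):
        sync(fn(*args))  # warm-up: absorbs one-off costs such as compilation
    times = []
    for _ in range(iters):
        start = time.perf_counter()
        sync(fn(*args))  # wait for the result before stopping the clock
        times.append(time.perf_counter() - start)
    return times


if __name__ == "__main__":
    # Stand-in workload so the sketch runs without JAX or a GPU.
    def toy_step(n):
        return sum(i * i for i in range(n))

    samples = time_fn(toy_step, 100_000)
    print(f"median: {median(samples) * 1e3:.3f} ms over {len(samples)} runs")
```

With JAX, the assumption would be to pass a `jax.jit`-compiled step as `fn` and `jax.block_until_ready` as `sync`, since JAX dispatches work asynchronously. For PyTorch, a `sync` that calls `torch.cuda.synchronize()` plays the same role.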

System info (python version, jaxlib version, accelerator, etc.)

jax[cuda12]
jaxlib==0.4.34
equinox==0.11.8
optax @ git+https://github.com/google-deepmind/optax.git@85378ad4ce1c19dfd218c65873f8941776c3eaca

Optax is installed from git because it includes a fix for adamw that was not part of the last release.
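To compare environments when reproducing this, a small sketch like the following can print the installed versions of the packages listed above. It uses only the standard library; `installed_versions` and `PACKAGES` are hypothetical names, and it reports distribution metadata only, not which commit a git install came from.

```python
from importlib import metadata

# Packages named in the system info above.
PACKAGES = ("jax", "jaxlib", "equinox", "optax")


def installed_versions(packages=PACKAGES):
    """Return {package: version string, or None if it is not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}=={version}" if version else f"{name}: not installed")
```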

AakashKumarNain added the bug label Oct 20, 2024