Description
Last week, a discussion took place on Twitter in which many people noted that the performance of JAX on GPUs still lags behind equivalent PyTorch code. @mattjj and I talked afterwards and agreed that a repo with a minimal codebase showcasing the performance difference between the two would be valuable.
GPT-2 is a good model to start with, and here is a repo that contains code for the same model in both JAX and PyTorch. The instructions in the repo are enough to download and run the code locally (on a GPU machine).
On my side, these are the results I got on an A100 40GB machine:
JAX
PyTorch
Compared to PyTorch, JAX is extremely slow here, and I have no idea why. There may be a silent bug somewhere in the JAX code that I have overlooked. Given how many times I have already been through this code, I think a fresh set of eyes would have a better chance of spotting it. Please let me know if you need any other information from my side.
System info (python version, jaxlib version, accelerator, etc.)
Optax is installed from git because there was a fix for adamw that was not part of the last release.