Return `jax.Array` #22
philip-paul-mueller added a commit to philip-paul-mueller/jace that referenced this issue on Oct 2, 2024
This addresses [issue#22](GridTools#22).
Solved by PR#26.
philip-paul-mueller added a commit that referenced this issue on Oct 4, 2024
This commit addresses the following issues:

- It adds an implementation of the auto optimizer (#24). The current implementation is unlikely to stay, since it essentially uses DaCe's version, which is known to have problems with JaCe's SDFGs.
- It allows running on GPU (#25). This is now possible, but the user still has to request it explicitly, whereas JAX detects the device automatically.
- Instead of NumPy arrays, JaCe now returns `jax.Array` objects (#22). This goes in tandem with a rework of the type annotation, which was wrong before (and cannot be made fully correct).
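For context on the GPU point, JAX selects its platform on its own. The minimal sketch below only queries what JAX picked, using `jax.default_backend()` and `jax.devices()`. The helper `pick_device_kind` and its `"cpu"`/`"gpu"` strings are hypothetical and are not JaCe's actual option names.

```python
import jax


def pick_device_kind() -> str:
    # jax.default_backend() names the platform JAX chose automatically,
    # e.g. "cpu", "gpu" or "tpu".
    backend = jax.default_backend()
    # Hypothetical mapping onto an explicit device option, which a JaCe user
    # currently has to pass; the real option name is not shown in this thread.
    return "gpu" if backend == "gpu" else "cpu"


print("JAX backend:", jax.default_backend())
print("devices:", [d.platform for d in jax.devices()])
print("explicit choice:", pick_device_kind())
```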
JAX always returns `jax.Array` objects, but we return either NumPy arrays or CuPy arrays (the latter is not implemented yet). This breaks JaCe as a drop-in replacement.
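A minimal sketch of the conversion, assuming the compiled SDFG hands back NumPy arrays: wrapping each output with `jnp.asarray` yields a `jax.Array`, which is the type `jax.jit` returns. `fake_sdfg_call` and `to_jax_outputs` are hypothetical helpers, not JaCe's implementation, and the (unimplemented) CuPy path is ignored.

```python
import jax
import jax.numpy as jnp
import numpy as np


def to_jax_outputs(result):
    # Convert every NumPy array in a (possibly nested) result into a jax.Array;
    # jax.tree_util.tree_map walks tuples, lists and dicts.
    return jax.tree_util.tree_map(
        lambda x: jnp.asarray(x) if isinstance(x, np.ndarray) else x, result
    )


def fake_sdfg_call(a):
    # Hypothetical stand-in for a compiled SDFG that returns NumPy arrays.
    return np.asarray(a) * 2, np.asarray(a) + 1


ref = jax.jit(lambda a: (a * 2, a + 1))(jnp.arange(3.0))
out = to_jax_outputs(fake_sdfg_call(np.arange(3.0)))
print(all(isinstance(x, jax.Array) for x in ref))  # True
print(all(isinstance(x, jax.Array) for x in out))  # True
```

Note that, unless `jax_enable_x64` is set, `jnp.asarray` turns float64 inputs into float32 arrays, so the dtypes then match what `jax.jit` would produce.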
While we are at it, we also have to modify the type annotation of the `jit()` wrapper. Currently the return type is preserved; however, since we change that type, we should set it to `Any`.