Hi, thanks for the request. This issue has been discussed previously in #10065. The short version is that we never meant `__jax_array__` to be a fully supported public API, and it's not clear that it's the best mechanism for downstream libraries to rely on. cc @mattjj, who has more context.
Description
I am using Flax with the NNX API, which requires arrays to be encapsulated within `nnx.Param`. The issue is that NNX uses the `__jax_array__` protocol so that users don't have to type `.value`, but JAX doesn't have full coverage for `__jax_array__`. For example, this code from an NNX model raises the following error:
Error:
Maybe @cgarciae can provide more insights on this if necessary?
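To illustrate the pattern being described, here is a minimal sketch of a wrapper that exposes its payload through `__jax_array__`. The `Param` class below is hypothetical and is not Flax's `nnx.Param` implementation. It shows two ways of getting at the array: `jnp.asarray`, which honors the protocol, and explicit `.value` access, which does not depend on how much of JAX supports `__jax_array__`.

```python
import jax.numpy as jnp


class Param:
    """Hypothetical stand-in for a wrapper like nnx.Param (not Flax's code)."""

    def __init__(self, value):
        self.value = value

    def __jax_array__(self):
        # JAX entry points that honor the protocol call this to unwrap the object.
        return self.value


w = Param(jnp.ones((2, 3)))
x = jnp.ones((3,))

# jnp.asarray checks for __jax_array__ and unwraps the Param.
y_protocol = jnp.asarray(w) @ x

# Explicit unwrapping works regardless of protocol coverage.
y_explicit = w.value @ x
```

Relying on `.value` (or an explicit `jnp.asarray` call) is the conservative choice while coverage of `__jax_array__` across JAX APIs is incomplete.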
System info (python version, jaxlib version, accelerator, etc.)
python version: 3.12.6
jax: 0.4.34
jaxlib: 0.4.34