Full coverage for __jax_array__ protocol #24460

Open
maxencefaldor opened this issue Oct 22, 2024 · 1 comment
Labels
bug Something isn't working


maxencefaldor commented Oct 22, 2024

Description

I am using Flax with the NNX API, which requires arrays to be wrapped in `nnx.Param`.

The problem is that NNX relies on the `__jax_array__` protocol so that users don't have to type `.value`, but JAX does not support `__jax_array__` in all of its APIs.
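For readers unfamiliar with the protocol: an object opts in by defining a `__jax_array__()` method that returns the underlying array, so that JAX can convert it wherever the protocol is supported. The sketch below uses a hypothetical `ParamLike` class (not Flax's `nnx.Param`) and a hypothetical `unwrap` helper that calls the method explicitly, which converts wrapped values up front instead of depending on whether a given JAX function honors the protocol.

```python
import jax
import jax.numpy as jnp


class ParamLike:
    """Hypothetical wrapper, loosely modeled on nnx.Param (not Flax's real class)."""

    def __init__(self, value):
        self.value = value

    def __jax_array__(self):
        # Protocol hook: return the wrapped array.
        return self.value


def unwrap(x):
    """Hypothetical helper: call __jax_array__ explicitly when present, else pass through."""
    return x.__jax_array__() if hasattr(x, "__jax_array__") else x


params = {"w": ParamLike(jnp.eye(3)), "b": jnp.zeros(3)}
# ParamLike is not a registered pytree node, so tree_map treats it as a leaf.
arrays = jax.tree_util.tree_map(unwrap, params)
y = jnp.dot(jnp.ones((2, 3)), arrays["w"]) + arrays["b"]
```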

For example, this code from an NNX model raises the following error:

```python
self.reshape_c_k = nnx.Param(reshape_c_k)
state_fft_k = jnp.dot(state_fft, self.reshape_c_k)
```

Error:

```
TypeError: Argument 'Param(value=Traced<ShapedArray(float32[3,3])>with<DynamicJaxprTrace(level=3/0)>)' of type <class 'flax.nnx.variablelib.Param'> is not a valid JAX type.
```
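Until this case is covered, one workaround is to unwrap explicitly with `.value`, i.e. `jnp.dot(state_fft, self.reshape_c_k.value)`. The sketch below mimics the failing pattern without Flax: `ParamLike` is a hypothetical stand-in for `nnx.Param`, and `forward` is a hypothetical jitted function that re-wraps a traced value, roughly the way a module holds a traced parameter under a transformation.

```python
import jax
import jax.numpy as jnp


class ParamLike:
    """Hypothetical stand-in for nnx.Param, for illustration only."""

    def __init__(self, value):
        self.value = value

    def __jax_array__(self):
        return self.value


@jax.jit
def forward(state_fft, reshape_c_k_value):
    # Re-wrap the traced array, as a module would hold a traced Param under jit.
    reshape_c_k = ParamLike(reshape_c_k_value)
    # Workaround: pass .value explicitly instead of relying on jnp.dot
    # to accept an object that only implements __jax_array__.
    return jnp.dot(state_fft, reshape_c_k.value)


state_fft = jnp.arange(6.0, dtype=jnp.float32).reshape(2, 3)
reshape_c_k = 2.0 * jnp.eye(3, dtype=jnp.float32)
state_fft_k = forward(state_fft, reshape_c_k)
```

The cost of this approach is that users type `.value` again, which is exactly what NNX's use of `__jax_array__` was meant to avoid.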

Maybe @cgarciae can provide more insight on this if needed?

System info (python version, jaxlib version, accelerator, etc.)

python version: 3.12.6
jax: 0.4.34
jaxlib: 0.4.34

maxencefaldor added the bug label on Oct 22, 2024
jakevdp (Collaborator) commented Oct 22, 2024

Hi, thanks for the request. This issue has been discussed previously in #10065, but the short version is that we never meant `__jax_array__` to be a fully supported public API, and it's not clear that it's the best mechanism for downstream libraries to rely on. cc @mattjj, who has more context.
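For comparison, here is a sketch of one mechanism that does not depend on `__jax_array__` at all. This is an assumption on my part, not something proposed in this thread: register the wrapper as a pytree so it can be passed through `jax.jit` and `jax.grad`, and unwrap with `.value` inside the traced function. `ParamLike` and `forward` are again hypothetical and are not Flax's `nnx.Param` or any NNX API.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class ParamLike:
    """Hypothetical wrapper registered as a pytree node (not Flax's nnx.Param)."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        # The wrapped array is the only child; there is no static aux data.
        return (self.value,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.jit
def forward(state, kernel):
    # kernel arrives as a ParamLike whose .value is a tracer.
    return jnp.dot(state, kernel.value)


state = jnp.ones((2, 3), dtype=jnp.float32)
kernel = ParamLike(2.0 * jnp.eye(3, dtype=jnp.float32))
out = forward(state, kernel)
# Gradients come back with the same structure: a ParamLike holding the gradient.
grads = jax.grad(lambda k: forward(state, k).sum())(kernel)
```

Because JAX flattens the wrapper itself, no protocol hook is needed to cross transformation boundaries; the trade-off is that array operations still need the explicit `.value`.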
