BUG: Forward sampling with dims fails when mode="JAX" #7286
Comments
Shapes are being set correctly; it's just that, as you said, they are mutable, and JAX simply doesn't support that, at least not without `static_argnums`. When we write
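A minimal JAX-only illustration of that constraint. The `draw` function here is made up and stands in for a compiled forward-sampling graph. A value used as an array shape must be concrete when JAX traces the function. `static_argnums` provides that, at the cost of one compilation per distinct value.

```python
import jax


def draw(key, n):
    # `n` is used as an array shape, so its value must be known at trace time
    return jax.random.normal(key, shape=(n,))


key = jax.random.PRNGKey(0)

# Without static_argnums, `n` is traced as an abstract integer and the
# shape cannot be built from it, so tracing fails with a TypeError.
try:
    jax.jit(draw)(key, 3)
    traced_ok = True
except TypeError:
    traced_ok = False

# Marking `n` as static makes JAX treat it as a compile-time constant;
# each distinct value of `n` triggers a separate compilation.
draw_static = jax.jit(draw, static_argnums=1)
x3 = draw_static(key, 3)
x5 = draw_static(key, 5)
```

The recompilation per value is what makes mutable dim lengths awkward for JAX: the shape is baked into each compiled function.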
That may be a bit cumbersome? You don't want to raise if the shapes are constant, which happens after
I guess I don't understand why the shared variable dims aren't replaced before forward sampling. At that point the shapes should all be known.
We never did that. We do it for MCMC sampling in the JAX samplers, because it's a specific code path. We could do that, although I think the explicit freeze approach is better. I'm thinking of reintroducing caching, and then it becomes very useful to be able to compile a function that works for multiple dim lengths: #7177
Well, it wasn't necessary before, because shapes induced by coords were fixed by default.
Yes, although even earlier they were also mutable by default. The JAX backend for forward sampling is still a niche use case; I wouldn't say we were officially supporting it yet. I'm also reluctant to put too much work into working around JAX's inflexibility. One idea would be to go to PyTensor and be more clever about `static_argnums`: we could do some cheap checks to see if a variable is used directly as the shape of an Op (RVs, `Alloc`) and mark that variable as a static argnum when compiling the jitted function. This only works for scalar variables, and we usually use 0d arrays, but maybe it's something we could work around. That would be a more general QoL improvement as well?
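The following sketches only the JAX half of that idea. It uses a hypothetical helper, `jit_with_static_shapes`, which is not part of PyTensor or PyMC. It assumes the indices of the shape arguments are already known. Detecting which graph inputs feed an Op's shape is not shown. NumPy 0d arrays are unhashable, so they cannot be passed as static arguments directly. The wrapper therefore converts the chosen arguments to Python ints first.

```python
import jax
import numpy as np


def jit_with_static_shapes(fn, shape_argnums):
    """Hypothetical helper: jit `fn`, treating `shape_argnums` as static.

    0d arrays are converted to plain Python ints, which are hashable and
    can therefore be used as static arguments.
    """
    jitted = jax.jit(fn, static_argnums=shape_argnums)

    def wrapper(*args):
        args = list(args)
        for i in shape_argnums:
            args[i] = int(np.asarray(args[i]))  # 0d array -> hashable int
        return jitted(*args)

    return wrapper


def draw(key, n):
    # `n` is used as a shape, so it has to be static under jit
    return jax.random.normal(key, shape=(n,))


draw_fast = jit_with_static_shapes(draw, (1,))
x = draw_fast(jax.random.PRNGKey(0), np.array(5))  # 0d array as the dim length
```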
Yes, I agree that we should be using `static_argnums`. That would be using JAX's own workaround for the static shape problem, so the functions we produce that way would be no worse than a native JAX solution, which seems fine to me.
Closing in favor of #7348
Describe the issue:
Shapes aren't being correctly set on variables when using `coords` in JAX. I guess this is a consequence of coords being mutable by default, and it could be addressed by using `freeze_dims_and_data` as in #7263. If this is the case, perhaps we should check for `mode='JAX'` in the `compile_kwargs` of the forward samplers and raise early with a more informative error?

Reproducible code example:
Error message:
PyMC version information:
Context for the issue:
No response