BUG: Forward sampling with dims fails when mode="JAX" #7286

Closed
jessegrabowski opened this issue Apr 28, 2024 · 8 comments

@jessegrabowski
Member

jessegrabowski commented Apr 28, 2024

Describe the issue:

Shapes aren't being set correctly on variables when using coords with the JAX backend. I guess this is a consequence of coords being mutable by default, and it could be addressed by using freeze_dims_and_data as in #7263. If this is the case, perhaps we should check for the mode='JAX' compile_kwarg in forward samplers and raise early with a more informative error?

Reproducible code example:

import pymc as pm

# Fails
with pm.Model(coords={'a':['1']}) as m:
    x = pm.Normal('x', dims=['a'])
    pm.sample_prior_predictive(compile_kwargs={'mode':'JAX'})

# Works
with pm.Model() as m:
    x = pm.Normal('x', shape=(1,))
    pm.sample_prior_predictive(compile_kwargs={'mode':'JAX'})

Error message:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/utils.py:196, in streamline.<locals>.streamline_default_f()
    193 for thunk, node, old_storage in zip(
    194     thunks, order, post_thunk_old_storage
    195 ):
--> 196     thunk()
    197     for old_s in old_storage:

File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/basic.py:660, in JITLinker.create_jitable_thunk.<locals>.thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
    654 def thunk(
    655     fgraph=self.fgraph,
    656     fgraph_jit=fgraph_jit,
    657     thunk_inputs=thunk_inputs,
    658     thunk_outputs=thunk_outputs,
    659 ):
--> 660     outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
    662     for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):

    [... skipping hidden 11 frame]

File /var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/tmpjm3tlztx:5, in jax_funcified_fgraph(random_generator_shared_variable, a)
      4 # normal_rv{0, (0, 0), floatX, False}(RandomGeneratorSharedVariable(<Generator(PCG64) at 0x29A781FC0>), JAXShapeTuple.0, 11, 0, 1.0)
----> 5 variable, x = sample_fn(random_generator_shared_variable, tensor_variable, tensor_constant, tensor_constant_1, tensor_constant_2)
      6 return x, variable

File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/jax/dispatch/random.py:106, in jax_funcify_RandomVariable.<locals>.sample_fn(rng, size, dtype, *parameters)
    105 def sample_fn(rng, size, dtype, *parameters):
--> 106     return jax_sample_fn(op)(rng, size, out_dtype, *parameters)

File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/jax/dispatch/random.py:164, in jax_sample_fn_loc_scale.<locals>.sample_fn(rng, size, dtype, *parameters)
    163 loc, scale = parameters
--> 164 sample = loc + jax_op(sampling_key, size, dtype) * scale
    165 rng["jax_state"] = rng_key

File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/jax/_src/random.py:710, in normal(key, shape, dtype)
    709 dtype = dtypes.canonicalize_dtype(dtype)
--> 710 shape = core.as_named_shape(shape)
    711 return _normal(key, shape, dtype)

    [... skipping hidden 2 frame]

File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/jax/_src/core.py:2142, in canonicalize_shape(shape, context)
   2141   pass
-> 2142 raise _invalid_shape_error(shape, context)

TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=1/0)>,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function jax_funcified_fgraph at /var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/tmpjm3tlztx:1 for jit. This concrete value was not available in Python because it depends on the value of the argument a.

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
Cell In[19], line 1
----> 1 pm.draw(x, mode='JAX')

File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pymc/sampling/forward.py:314, in draw(vars, draws, random_seed, **kwargs)
    311 draw_fn = compile_pymc(inputs=[], outputs=vars, random_seed=random_seed, **kwargs)
    313 if draws == 1:
--> 314     return draw_fn()
    316 # Single variable output
    317 if not isinstance(vars, list | tuple):

File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/compile/function/types.py:970, in Function.__call__(self, *args, **kwargs)
    967 t0_fn = time.perf_counter()
    968 try:
    969     outputs = (
--> 970         self.vm()
    971         if output_subset is None
    972         else self.vm(output_subset=output_subset)
    973     )
    974 except Exception:
    975     restore_defaults()

File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/utils.py:200, in streamline.<locals>.streamline_default_f()
    198             old_s[0] = None
    199 except Exception:
--> 200     raise_with_op(fgraph, node, thunk)

File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/utils.py:523, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
    518     warnings.warn(
    519         f"{exc_type} error does not allow us to add an extra error message"
    520     )
    521     # Some exception need extra parameter in inputs. So forget the
    522     # extra long error message in that case.
--> 523 raise exc_value.with_traceback(exc_trace)

File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/utils.py:196, in streamline.<locals>.streamline_default_f()
    192 try:
    193     for thunk, node, old_storage in zip(
    194         thunks, order, post_thunk_old_storage
    195     ):
--> 196         thunk()
    197         for old_s in old_storage:
    198             old_s[0] = None

File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/basic.py:660, in JITLinker.create_jitable_thunk.<locals>.thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
    654 def thunk(
    655     fgraph=self.fgraph,
    656     fgraph_jit=fgraph_jit,
    657     thunk_inputs=thunk_inputs,
    658     thunk_outputs=thunk_outputs,
    659 ):
--> 660     outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
    662     for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
    663         compute_map[o_var][0] = True

    [... skipping hidden 11 frame]

File /var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/tmpjm3tlztx:5, in jax_funcified_fgraph(random_generator_shared_variable, a)
      3 tensor_variable = shape_tuple_fn(a)
      4 # normal_rv{0, (0, 0), floatX, False}(RandomGeneratorSharedVariable(<Generator(PCG64) at 0x29A781FC0>), JAXShapeTuple.0, 11, 0, 1.0)
----> 5 variable, x = sample_fn(random_generator_shared_variable, tensor_variable, tensor_constant, tensor_constant_1, tensor_constant_2)
      6 return x, variable

File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/jax/dispatch/random.py:106, in jax_funcify_RandomVariable.<locals>.sample_fn(rng, size, dtype, *parameters)
    105 def sample_fn(rng, size, dtype, *parameters):
--> 106     return jax_sample_fn(op)(rng, size, out_dtype, *parameters)

File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/jax/dispatch/random.py:164, in jax_sample_fn_loc_scale.<locals>.sample_fn(rng, size, dtype, *parameters)
    162 rng_key, sampling_key = jax.random.split(rng_key, 2)
    163 loc, scale = parameters
--> 164 sample = loc + jax_op(sampling_key, size, dtype) * scale
    165 rng["jax_state"] = rng_key
    166 return (rng, sample)

File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/jax/_src/random.py:710, in normal(key, shape, dtype)
    707   raise ValueError(f"dtype argument to `normal` must be a float or complex dtype, "
    708                    f"got {dtype}")
    709 dtype = dtypes.canonicalize_dtype(dtype)
--> 710 shape = core.as_named_shape(shape)
    711 return _normal(key, shape, dtype)

    [... skipping hidden 2 frame]

File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/jax/_src/core.py:2142, in canonicalize_shape(shape, context)
   2140 except TypeError:
   2141   pass
-> 2142 raise _invalid_shape_error(shape, context)

TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=1/0)>,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function jax_funcified_fgraph at /var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/tmpjm3tlztx:1 for jit. This concrete value was not available in Python because it depends on the value of the argument a.
Apply node that caused the error: normal_rv{0, (0, 0), floatX, False}(RandomGeneratorSharedVariable(<Generator(PCG64) at 0x29A781FC0>), JAXShapeTuple.0, 11, 0, 1.0)
Toposort index: 1
Inputs types: [RandomGeneratorType, TensorType(int64, shape=(1,)), TensorType(int64, shape=()), TensorType(int8, shape=()), TensorType(float32, shape=())]
Inputs shapes: ['No shapes', ()]
Inputs strides: ['No strides', ()]
Inputs values: [{'bit_generator': 1, 'state': {'state': 5504079417979030970, 'inc': 4407794720271215875}, 'has_uint32': 0, 'uinteger': 0, 'jax_state': array([1281518353, 2620247482], dtype=uint32)}, array(1)]
Outputs clients: [['output'], ['output']]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/Users/jessegrabowski/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
    coro.send(None)
  File "/Users/jessegrabowski/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3334, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/Users/jessegrabowski/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3517, in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
  File "/Users/jessegrabowski/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/ipykernel_14410/452414321.py", line 2, in <module>
    x = pm.Normal('x', dims=['a'])
  File "/Users/jessegrabowski/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pymc/distributions/distribution.py", line 554, in __new__
    rv_out = cls.dist(*args, **kwargs)
  File "/Users/jessegrabowski/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pymc/distributions/continuous.py", line 511, in dist
    return super().dist([mu, sigma], **kwargs)
  File "/Users/jessegrabowski/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pymc/distributions/distribution.py", line 633, in dist
    rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)

HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

PyMC version information:

pymc: 5.13.1

Context for the issue:

No response

@ricardoV94
Member

ricardoV94 commented Apr 28, 2024

Shapes are being set correctly. It's just that, as you said, they are mutable, and JAX simply doesn't support that, at least not without static_argnums. When we write x = pm.Normal("x", dims="trial"), PyMC is writing x = pt.random.normal(size=trial_length), where trial_length is a shared scalar variable.
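
As a rough illustration of the point above, here is a minimal sketch in plain JAX (not PyMC or PyTensor internals). The function name draw_normal and the argument n are made up for this example; n stands in for the shared dim length. Called eagerly, n is a concrete Python int; under jit it becomes a tracer, and JAX refuses to use it as a shape.

```python
import jax

def draw_normal(key, n):
    # `n` plays the role of the shared dim length (hypothetical name).
    return jax.random.normal(key, shape=(n,))

key = jax.random.PRNGKey(0)

# Eager call: `n` is a concrete int, so the shape is known.
eager = draw_normal(key, 1)

# Jitted call: `n` is traced, so it cannot be used as a shape.
try:
    jax.jit(draw_normal)(key, 1)
    jit_failed = False
except TypeError as err:
    # Same family of error as in the traceback above:
    # "Shapes must be 1D sequences of concrete values of integer type"
    jit_failed = True
    print(type(err).__name__)
```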

@ricardoV94
Member

> If this is the case, perhaps we should check for the mode='JAX' compile_kwarg in forward samplers and raise early with a more informative error?

That may be a bit cumbersome? You don't want to raise if the shapes are constant, which happens after freeze_dims_and_data or if a user specified shape in addition to dims, and introspecting the graph to assess which case it is could be messy/costly.

@jessegrabowski
Member Author

I guess I don't understand why the shared variable dims aren't replaced before forward sampling. At that point the shapes should all be known.

@ricardoV94
Member

ricardoV94 commented Apr 28, 2024

> I guess I don't understand why the shared variable dims aren't replaced before forward sampling. At that point the shapes should all be known.

We never did that. We do it for MCMC sampling in the JAX samplers, because that's a specific code path.

We could do that, although I think the explicit freeze approach is better. I'm thinking of reintroducing caching and then it becomes very useful being able to compile a function that works for multiple dim lengths: #7177

@jessegrabowski
Member Author

Well it wasn't necessary before, because shapes induced by coords were fixed by default

@ricardoV94
Member

ricardoV94 commented Apr 28, 2024

> Well it wasn't necessary before, because shapes induced by coords were fixed by default

Yes, although even earlier they were also mutable by default.

The JAX backend for forward sampling is still a niche use case; I wouldn't say we were officially supporting it yet.

I'm also reluctant to put too much work into working around JAX's inflexibility.

One idea would be to go to PyTensor and be more clever about static_argnums: we could do some cheap checks to see if a variable is used directly as the shape of an Op (RVs, Alloc) and mark that variable as a static_argnum when compiling the jitted function. This only works for scalar variables, and we usually use 0d arrays, but maybe it's something we could work around.

That is a more general QoL improvement as well?
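
To make the static_argnums idea concrete, here is a small sketch in plain JAX. It is not how PyTensor compiles graphs; draw_normal and dim_length are hypothetical names. It also shows the 0d-array wrinkle mentioned above: static arguments have to be hashable, and a NumPy 0d array is not, so it is converted to a Python int first.

```python
from functools import partial

import jax
import numpy as np

# Mark the second positional argument (the length) as static.
@partial(jax.jit, static_argnums=(1,))
def draw_normal(key, n):
    return jax.random.normal(key, shape=(n,))

key = jax.random.PRNGKey(0)

# A dim length stored as a 0d array, as is common on the PyMC side.
dim_length = np.array(3)

# Convert to a hashable Python int before passing it as a static argument.
sample = draw_normal(key, int(dim_length))

# A different length works too, at the cost of a fresh compilation.
resized = draw_normal(key, 5)
```

The trade-off is that each distinct static value triggers a recompile, which is the same cost a hand-written JAX function would pay.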

@jessegrabowski
Member Author

Yes, I agree that we should be using static_argnums. That would be using JAX's own workaround for the static shape problem, so the functions we produce that way would be no worse than a native JAX solution, which seems fine to me.

@ricardoV94
Member

Closing in favor of #7348
