Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support RandomVariable graphs with scalar shape parameters in JAX backend #1029

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Oct 11, 2024

This should make it possible to do forward sampling in more PyMC models that use dims to define variables shapes

def test_random_scalar_shape_input():
    dim0 = pt.scalar("dim0", dtype=int)
    dim1 = pt.scalar("dim1", dtype=int)

    out = pt.random.normal(0, 1, size=dim0)
    jax_fn = compile_random_function([dim0], out)
    assert jax_fn(np.array(2)).shape == (2,)
    assert jax_fn(np.array(3)).shape == (3,)

    out = pt.random.normal(0, 1, size=[dim0, dim1])
    jax_fn = compile_random_function([dim0, dim1], out)
    assert jax_fn(np.array(2), np.array(3)).shape == (2, 3)
    assert jax_fn(np.array(4), np.array(5)).shape == (4, 5)

These was already a special rewrite to replace make_vector, expand_dims in the shape of RVs, but without handling these inputs from the outside it wouldn't achieve much for PyTensor users:

@node_rewriter([RandomVariable])
def size_parameter_as_tuple(fgraph, node):
"""Replace `MakeVector` and `DimShuffle` (when used to transform a scalar
into a 1d vector) when they are found as the input of a `size` or `shape`
parameter by `JAXShapeTuple` during transpilation.
The JAX implementations of `MakeVector` and `DimShuffle` always return JAX
`TracedArrays`, but JAX only accepts concrete values as inputs for the `size`
or `shape` parameter. When these `Op`s are used to convert scalar or tuple
inputs, however, we can avoid tracing by making them return a tuple of their
inputs instead.
Note that JAX does not accept scalar inputs for the `size` or `shape`
parameters, and this rewrite also ensures that scalar inputs are turned into
tuples during transpilation.
"""
from pytensor.link.jax.dispatch.shape import JAXShapeTuple
size_arg = node.inputs[1]
size_node = size_arg.owner
if size_node is None:
return
if isinstance(size_node.op, JAXShapeTuple):
return
if isinstance(size_node.op, MakeVector) or (
isinstance(size_node.op, DimShuffle)
and size_node.op.input_ndim == 0
and size_node.op.new_order == ("x",)
):
# Here PyTensor converted a tuple or list to a tensor
new_size_args = JAXShapeTuple()(*size_node.inputs)
new_inputs = list(node.inputs)
new_inputs[1] = new_size_args
new_node = node.clone_with_new_inputs(new_inputs)
return new_node.outputs


📚 Documentation preview 📚: https://pytensor--1029.org.readthedocs.build/en/1029/

@ricardoV94 ricardoV94 added enhancement New feature or request jax labels Oct 11, 2024
@ricardoV94 ricardoV94 changed the title Allow running RandomVariable graphs with scalar shape parameters in JAX backend Support RandomVariable graphs with scalar shape parameters in JAX backend Oct 11, 2024
@ricardoV94 ricardoV94 force-pushed the jax_scalar_rv_shapes branch 2 times, most recently from ff899b1 to 9b9bcba Compare October 11, 2024 09:44
Copy link

codecov bot commented Oct 11, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 82.11%. Comparing base (0824dba) to head (1e5c487).

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #1029      +/-   ##
==========================================
- Coverage   82.12%   82.11%   -0.02%     
==========================================
  Files         183      183              
  Lines       48111    48122      +11     
  Branches     8667     8668       +1     
==========================================
+ Hits        39510    39513       +3     
- Misses       6435     6439       +4     
- Partials     2166     2170       +4     
Files with missing lines Coverage Δ
pytensor/link/jax/linker.py 96.22% <100.00%> (+0.98%) ⬆️

... and 1 file with indirect coverage changes

@ricardoV94
Copy link
Member Author

ricardoV94 commented Oct 11, 2024

This solves pymc-devs/pymc#7348

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request jax
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant