diff --git a/pytensor/link/jax/linker.py b/pytensor/link/jax/linker.py index 667806a80f..ef60be1f99 100644 --- a/pytensor/link/jax/linker.py +++ b/pytensor/link/jax/linker.py @@ -3,15 +3,19 @@ from numpy.random import Generator, RandomState from pytensor.compile.sharedvalue import SharedVariable, shared -from pytensor.graph.basic import Constant from pytensor.link.basic import JITLinker class JAXLinker(JITLinker): """A `Linker` that JIT-compiles NumPy-based operations using JAX.""" + def __init__(self, *args, **kwargs): + self.scalar_shape_inputs: tuple[int] = () + super().__init__(*args, **kwargs) + def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs): from pytensor.link.jax.dispatch import jax_funcify + from pytensor.link.jax.dispatch.shape import JAXShapeTuple from pytensor.tensor.random.type import RandomType shared_rng_inputs = [ @@ -63,6 +67,21 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs): fgraph.inputs.remove(new_inp) fgraph.inputs.insert(old_inp_fgrap_index, new_inp) + fgraph_inputs = fgraph.inputs + clients = fgraph.clients + # Detect scalar shape inputs that are used only in JAXShapeTuple nodes + scalar_shape_inputs = [ + inp + for node in fgraph.apply_nodes + if isinstance(node.op, JAXShapeTuple) + for inp in node.inputs + if inp in fgraph_inputs + and all(isinstance(node.op, JAXShapeTuple) for node, _ in clients[inp]) + ] + self.scalar_shape_inputs = tuple( + fgraph_inputs.index(inp) for inp in scalar_shape_inputs + ) + return jax_funcify( fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs ) @@ -70,12 +89,19 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs): def jit_compile(self, fn): import jax - # I suppose we can consider `Constant`s to be "static" according to - # JAX. - static_argnums = [ - n for n, i in enumerate(self.fgraph.inputs) if isinstance(i, Constant) - ] - return jax.jit(fn, static_argnums=static_argnums) + jit_fn = jax.jit(fn, static_argnums=self.scalar_shape_inputs) + + if not self.scalar_shape_inputs: + return jit_fn + + def convert_scalar_shape_inputs(*args): + new_args = [ + int(arg) if i in self.scalar_shape_inputs else arg + for i, arg in enumerate(args) + ] + return jit_fn(*new_args) + + return convert_scalar_shape_inputs def create_thunk_inputs(self, storage_map): from pytensor.link.jax.dispatch import jax_typify diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py index 917c7be763..137b3241f0 100644 --- a/tests/link/jax/test_random.py +++ b/tests/link/jax/test_random.py @@ -863,15 +863,55 @@ def test_random_concrete_shape_subtensor_tuple(self): jax_fn = compile_random_function([x_pt], out) assert jax_fn(np.ones((2, 3))).shape == (2,) + def test_random_scalar_shape_input(self): + dim0 = pt.scalar("dim0", dtype=int) + dim1 = pt.scalar("dim1", dtype=int) + + out = pt.random.normal(0, 1, size=dim0) + jax_fn = compile_random_function([dim0], out) + assert jax_fn(np.array(2)).shape == (2,) + assert jax_fn(np.array(3)).shape == (3,) + + out = pt.random.normal(0, 1, size=[dim0, dim1]) + jax_fn = compile_random_function([dim0, dim1], out) + assert jax_fn(np.array(2), np.array(3)).shape == (2, 3) + assert jax_fn(np.array(4), np.array(5)).shape == (4, 5) + @pytest.mark.xfail( - reason="`size_pt` should be specified as a static argument", strict=True + raises=TypeError, reason="Cannot convert scalar input to integer" ) - def test_random_concrete_shape_graph_input(self): - rng = shared(np.random.default_rng(123)) - size_pt = pt.scalar() - out = pt.random.normal(0, 1, size=size_pt, rng=rng) - jax_fn = compile_random_function([size_pt], out) - assert jax_fn(10).shape == (10,) + def test_random_scalar_shape_input_not_supported(self): + dim = pt.scalar("dim", dtype=int) + out1 = pt.random.normal(0, 1, size=dim) + # An operation that wouldn't work if we replaced 0d array by integer + out2 = dim[...].set(1) + jax_fn = compile_random_function([dim], [out1, out2]) + + res1, res2 = jax_fn(np.array(2)) + assert res1.shape == (2,) + assert res2 == 1 + + @pytest.mark.xfail( + raises=TypeError, reason="Cannot convert scalar input to integer" + ) + def test_random_scalar_shape_input_not_supported2(self): + dim = pt.scalar("dim", dtype=int) + # This could theoretically be supported + # but would require knowing that * 2 is a safe operation for a python integer + out = pt.random.normal(0, 1, size=dim * 2) + jax_fn = compile_random_function([dim], out) + assert jax_fn(np.array(2)).shape == (4,) + + @pytest.mark.xfail( + raises=TypeError, reason="Cannot convert tensor input to shape tuple" + ) + def test_random_vector_shape_graph_input(self): + shape = pt.vector("shape", shape=(2,), dtype=int) + out = pt.random.normal(0, 1, size=shape) + + jax_fn = compile_random_function([shape], out) + assert jax_fn(np.array([2, 3])).shape == (2, 3) + assert jax_fn(np.array([4, 5])).shape == (4, 5) def test_constant_shape_after_graph_rewriting(self): size = pt.vector("size", shape=(2,), dtype=int)