From 9ac810cfd4f85f88fd598dca72bedd8370304b42 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 11 Oct 2024 11:24:00 +0200 Subject: [PATCH 1/3] Default to JAX test mode in random tests --- tests/link/jax/test_random.py | 50 +++++++++++++++++------------------ 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py index a01f5e3f46..6915113c6d 100644 --- a/tests/link/jax/test_random.py +++ b/tests/link/jax/test_random.py @@ -27,7 +27,7 @@ from pytensor.link.jax.dispatch.random import numpyro_available # noqa: E402 -def compile_random_function(*args, mode="JAX", **kwargs): +def compile_random_function(*args, mode=jax_mode, **kwargs): with pytest.warns( UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used" ): @@ -42,7 +42,7 @@ def test_random_RandomStream(): srng = RandomStream(seed=123) out = srng.normal() - srng.normal() - fn = compile_random_function([], out, mode=jax_mode) + fn = compile_random_function([], out) jax_res_1 = fn() jax_res_2 = fn() @@ -55,7 +55,7 @@ def test_random_updates(rng_ctor): rng = shared(original_value, name="original_rng", borrow=False) next_rng, x = pt.random.normal(name="x", rng=rng).owner.outputs - f = compile_random_function([], [x], updates={rng: next_rng}, mode=jax_mode) + f = compile_random_function([], [x], updates={rng: next_rng}) assert f() != f() # Check that original rng variable content was not overwritten when calling jax_typify @@ -479,7 +479,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c """ rng = shared(np.random.default_rng(29403)) g = rv_op(*dist_params, size=(10000, *base_size), rng=rng) - g_fn = compile_random_function(dist_params, g, mode=jax_mode) + g_fn = compile_random_function(dist_params, g) samples = g_fn( *[ i.tag.test_value @@ -521,7 +521,7 @@ def test_size_implied_by_broadcasted_parameters(rv_fn): param_that_implies_size = pt.matrix("param_that_implies_size", shape=(None, None)) rv = rv_fn(param_that_implies_size) - draws = rv.eval({param_that_implies_size: np.zeros((2, 2))}, mode=jax_mode) + draws = rv.eval({param_that_implies_size: np.zeros((2, 2))}) assert draws.shape == (2, 2) assert np.unique(draws).size == 4 @@ -531,7 +531,7 @@ def test_size_implied_by_broadcasted_parameters(rv_fn): def test_random_bernoulli(size): rng = shared(np.random.default_rng(123)) g = pt.random.bernoulli(0.5, size=(1000, *size), rng=rng) - g_fn = compile_random_function([], g, mode=jax_mode) + g_fn = compile_random_function([], g) samples = g_fn() np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1) @@ -542,7 +542,7 @@ def test_random_mvnormal(): mu = np.ones(4) cov = np.eye(4) g = pt.random.multivariate_normal(mu, cov, size=(10000,), rng=rng) - g_fn = compile_random_function([], g, mode=jax_mode) + g_fn = compile_random_function([], g) samples = g_fn() np.testing.assert_allclose(samples.mean(axis=0), mu, atol=0.1) @@ -557,7 +557,7 @@ def test_random_mvnormal(): def test_random_dirichlet(parameter, size): rng = shared(np.random.default_rng(123)) g = pt.random.dirichlet(parameter, size=(1000, *size), rng=rng) - g_fn = compile_random_function([], g, mode=jax_mode) + g_fn = compile_random_function([], g) samples = g_fn() np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1) @@ -566,7 +566,7 @@ def test_random_choice(): # `replace=True` and `p is None` rng = shared(np.random.default_rng(123)) g = pt.random.choice(np.arange(4), size=10_000, rng=rng) - g_fn = compile_random_function([], g, mode=jax_mode) + g_fn = compile_random_function([], g) samples = g_fn() assert samples.shape == (10_000,) # Elements are picked at equal frequency @@ -575,7 +575,7 @@ def test_random_choice(): # `replace=True` and `p is not None` rng = shared(np.random.default_rng(123)) g = pt.random.choice(4, p=np.array([0.0, 0.5, 0.0, 0.5]), size=(5, 2), rng=rng) - g_fn = compile_random_function([], g, mode=jax_mode) + g_fn = compile_random_function([], g) samples = g_fn() assert samples.shape == (5, 2) # Only odd numbers are picked @@ -584,7 +584,7 @@ def test_random_choice(): # `replace=False` and `p is None` rng = shared(np.random.default_rng(123)) g = pt.random.choice(np.arange(100), replace=False, size=(2, 49), rng=rng) - g_fn = compile_random_function([], g, mode=jax_mode) + g_fn = compile_random_function([], g) samples = g_fn() assert samples.shape == (2, 49) # Elements are unique @@ -599,7 +599,7 @@ def test_random_choice(): rng=rng, replace=False, ) - g_fn = compile_random_function([], g, mode=jax_mode) + g_fn = compile_random_function([], g) samples = g_fn() assert samples.shape == (3,) # Elements are unique @@ -611,14 +611,14 @@ def test_random_choice(): def test_random_categorical(): rng = shared(np.random.default_rng(123)) g = pt.random.categorical(0.25 * np.ones(4), size=(10000, 4), rng=rng) - g_fn = compile_random_function([], g, mode=jax_mode) + g_fn = compile_random_function([], g) samples = g_fn() assert samples.shape == (10000, 4) np.testing.assert_allclose(samples.mean(axis=0), 6 / 4, 1) # Test zero probabilities g = pt.random.categorical([0, 0.5, 0, 0.5], size=(1000,), rng=rng) - g_fn = compile_random_function([], g, mode=jax_mode) + g_fn = compile_random_function([], g) samples = g_fn() assert samples.shape == (1000,) assert np.all(samples % 2 == 1) @@ -628,7 +628,7 @@ def test_random_permutation(): array = np.arange(4) rng = shared(np.random.default_rng(123)) g = pt.random.permutation(array, rng=rng) - g_fn = compile_random_function([], g, mode=jax_mode) + g_fn = compile_random_function([], g) permuted = g_fn() with pytest.raises(AssertionError): np.testing.assert_allclose(array, permuted) @@ -651,7 +651,7 @@ def test_random_geometric(): rng = shared(np.random.default_rng(123)) p = np.array([0.3, 0.7]) g = pt.random.geometric(p, size=(10_000, 2), rng=rng) - g_fn = compile_random_function([], g, mode=jax_mode) + g_fn = compile_random_function([], g) samples = g_fn() np.testing.assert_allclose(samples.mean(axis=0), 1 / p, rtol=0.1) np.testing.assert_allclose(samples.std(axis=0), np.sqrt((1 - p) / p**2), rtol=0.1) @@ -662,7 +662,7 @@ def test_negative_binomial(): n = np.array([10, 40]) p = np.array([0.3, 0.7]) g = pt.random.negative_binomial(n, p, size=(10_000, 2), rng=rng) - g_fn = compile_random_function([], g, mode=jax_mode) + g_fn = compile_random_function([], g) samples = g_fn() np.testing.assert_allclose(samples.mean(axis=0), n * (1 - p) / p, rtol=0.1) np.testing.assert_allclose( @@ -676,7 +676,7 @@ def test_binomial(): n = np.array([10, 40]) p = np.array([0.3, 0.7]) g = pt.random.binomial(n, p, size=(10_000, 2), rng=rng) - g_fn = compile_random_function([], g, mode=jax_mode) + g_fn = compile_random_function([], g) samples = g_fn() np.testing.assert_allclose(samples.mean(axis=0), n * p, rtol=0.1) np.testing.assert_allclose(samples.std(axis=0), np.sqrt(n * p * (1 - p)), rtol=0.1) @@ -691,7 +691,7 @@ def test_beta_binomial(): a = np.array([1.5, 13]) b = np.array([0.5, 9]) g = pt.random.betabinom(n, a, b, size=(10_000, 2), rng=rng) - g_fn = compile_random_function([], g, mode=jax_mode) + g_fn = compile_random_function([], g) samples = g_fn() np.testing.assert_allclose(samples.mean(axis=0), n * a / (a + b), rtol=0.1) np.testing.assert_allclose( @@ -725,7 +725,7 @@ def test_vonmises_mu_outside_circle(): mu = np.array([-30, 40]) kappa = np.array([100, 10]) g = pt.random.vonmises(mu, kappa, size=(10_000, 2), rng=rng) - g_fn = compile_random_function([], g, mode=jax_mode) + g_fn = compile_random_function([], g) samples = g_fn() np.testing.assert_allclose( samples.mean(axis=0), (mu + np.pi) % (2.0 * np.pi) - np.pi, rtol=0.1 @@ -823,7 +823,7 @@ def test_random_concrete_shape(): rng = shared(np.random.default_rng(123)) x_pt = pt.dmatrix() out = pt.random.normal(0, 1, size=x_pt.shape, rng=rng) - jax_fn = compile_random_function([x_pt], out, mode=jax_mode) + jax_fn = compile_random_function([x_pt], out) assert jax_fn(np.ones((2, 3))).shape == (2, 3) @@ -831,7 +831,7 @@ def test_random_concrete_shape_from_param(): rng = shared(np.random.default_rng(123)) x_pt = pt.dmatrix() out = pt.random.normal(x_pt, 1, rng=rng) - jax_fn = compile_random_function([x_pt], out, mode=jax_mode) + jax_fn = compile_random_function([x_pt], out) assert jax_fn(np.ones((2, 3))).shape == (2, 3) @@ -850,7 +850,7 @@ def test_random_concrete_shape_subtensor(): rng = shared(np.random.default_rng(123)) x_pt = pt.dmatrix() out = pt.random.normal(0, 1, size=x_pt.shape[1], rng=rng) - jax_fn = compile_random_function([x_pt], out, mode=jax_mode) + jax_fn = compile_random_function([x_pt], out) assert jax_fn(np.ones((2, 3))).shape == (3,) @@ -866,7 +866,7 @@ def test_random_concrete_shape_subtensor_tuple(): rng = shared(np.random.default_rng(123)) x_pt = pt.dmatrix() out = pt.random.normal(0, 1, size=(x_pt.shape[0],), rng=rng) - jax_fn = compile_random_function([x_pt], out, mode=jax_mode) + jax_fn = compile_random_function([x_pt], out) assert jax_fn(np.ones((2, 3))).shape == (2,) @@ -877,7 +877,7 @@ def test_random_concrete_shape_graph_input(): rng = shared(np.random.default_rng(123)) size_pt = pt.scalar() out = pt.random.normal(0, 1, size=size_pt, rng=rng) - jax_fn = compile_random_function([size_pt], out, mode=jax_mode) + jax_fn = compile_random_function([size_pt], out) assert jax_fn(10).shape == (10,) From 194b871900390773c5aefa12cb4ce0ffa71757e7 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 11 Oct 2024 11:39:03 +0200 Subject: [PATCH 2/3] Group JAX random shape input tests --- tests/link/jax/test_random.py | 178 +++++++++++++++++----------------- 1 file changed, 87 insertions(+), 91 deletions(-) diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py index 6915113c6d..f1d5cc269b 100644 --- a/tests/link/jax/test_random.py +++ b/tests/link/jax/test_random.py @@ -809,94 +809,90 @@ def sample_fn(rng, size, dtype, *parameters): compare_jax_and_py(fgraph, []) -def test_random_concrete_shape(): - """JAX should compile when a `RandomVariable` is passed a concrete shape. - - There are three quantities that JAX considers as concrete: - 1. Constants known at compile time; - 2. The shape of an array. - 3. `static_argnums` parameters - This test makes sure that graphs with `RandomVariable`s compile when the - `size` parameter satisfies either of these criteria. - - """ - rng = shared(np.random.default_rng(123)) - x_pt = pt.dmatrix() - out = pt.random.normal(0, 1, size=x_pt.shape, rng=rng) - jax_fn = compile_random_function([x_pt], out) - assert jax_fn(np.ones((2, 3))).shape == (2, 3) - - -def test_random_concrete_shape_from_param(): - rng = shared(np.random.default_rng(123)) - x_pt = pt.dmatrix() - out = pt.random.normal(x_pt, 1, rng=rng) - jax_fn = compile_random_function([x_pt], out) - assert jax_fn(np.ones((2, 3))).shape == (2, 3) - - -def test_random_concrete_shape_subtensor(): - """JAX should compile when a concrete value is passed for the `size` parameter. - - This test ensures that the `DimShuffle` `Op` used by PyTensor to turn scalar - inputs into 1d vectors is replaced by an `Op` that turns concrete scalar - inputs into tuples of concrete values using the `jax_size_parameter_as_tuple` - rewrite. - - JAX does not accept scalars as `size` or `shape` arguments, so this is a - slight improvement over their API. - - """ - rng = shared(np.random.default_rng(123)) - x_pt = pt.dmatrix() - out = pt.random.normal(0, 1, size=x_pt.shape[1], rng=rng) - jax_fn = compile_random_function([x_pt], out) - assert jax_fn(np.ones((2, 3))).shape == (3,) - - -def test_random_concrete_shape_subtensor_tuple(): - """JAX should compile when a tuple of concrete values is passed for the `size` parameter. - - This test ensures that the `MakeVector` `Op` used by PyTensor to turn tuple - inputs into 1d vectors is replaced by an `Op` that turns a tuple of concrete - scalar inputs into tuples of concrete values using the - `jax_size_parameter_as_tuple` rewrite. - - """ - rng = shared(np.random.default_rng(123)) - x_pt = pt.dmatrix() - out = pt.random.normal(0, 1, size=(x_pt.shape[0],), rng=rng) - jax_fn = compile_random_function([x_pt], out) - assert jax_fn(np.ones((2, 3))).shape == (2,) - - -@pytest.mark.xfail( - reason="`size_pt` should be specified as a static argument", strict=True -) -def test_random_concrete_shape_graph_input(): - rng = shared(np.random.default_rng(123)) - size_pt = pt.scalar() - out = pt.random.normal(0, 1, size=size_pt, rng=rng) - jax_fn = compile_random_function([size_pt], out) - assert jax_fn(10).shape == (10,) - - -def test_constant_shape_after_graph_rewriting(): - size = pt.vector("size", shape=(2,), dtype=int) - x = pt.random.normal(size=size) - assert x.type.shape == (None, None) - - with pytest.raises(TypeError): - compile_random_function([size], x)([2, 5]) - - # Rebuild with strict=False so output type is not updated - # This reflects cases where size is constant folded during rewrites but the RV node is not recreated - new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=True) - assert new_x.type.shape == (None, None) - assert compile_random_function([], new_x)().shape == (2, 5) - - # Rebuild with strict=True, so output type is updated - # This uses a different path in the dispatch implementation - new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=False) - assert new_x.type.shape == (2, 5) - assert compile_random_function([], new_x)().shape == (2, 5) +class TestRandomShapeInputs: + def test_random_concrete_shape(self): + """JAX should compile when a `RandomVariable` is passed a concrete shape. + + There are three quantities that JAX considers as concrete: + 1. Constants known at compile time; + 2. The shape of an array. + 3. `static_argnums` parameters + This test makes sure that graphs with `RandomVariable`s compile when the + `size` parameter satisfies either of these criteria. + + """ + rng = shared(np.random.default_rng(123)) + x_pt = pt.dmatrix() + out = pt.random.normal(0, 1, size=x_pt.shape, rng=rng) + jax_fn = compile_random_function([x_pt], out) + assert jax_fn(np.ones((2, 3))).shape == (2, 3) + + def test_random_concrete_shape_from_param(self): + rng = shared(np.random.default_rng(123)) + x_pt = pt.dmatrix() + out = pt.random.normal(x_pt, 1, rng=rng) + jax_fn = compile_random_function([x_pt], out) + assert jax_fn(np.ones((2, 3))).shape == (2, 3) + + def test_random_concrete_shape_subtensor(self): + """JAX should compile when a concrete value is passed for the `size` parameter. + + This test ensures that the `DimShuffle` `Op` used by PyTensor to turn scalar + inputs into 1d vectors is replaced by an `Op` that turns concrete scalar + inputs into tuples of concrete values using the `jax_size_parameter_as_tuple` + rewrite. + + JAX does not accept scalars as `size` or `shape` arguments, so this is a + slight improvement over their API. + + """ + rng = shared(np.random.default_rng(123)) + x_pt = pt.dmatrix() + out = pt.random.normal(0, 1, size=x_pt.shape[1], rng=rng) + jax_fn = compile_random_function([x_pt], out) + assert jax_fn(np.ones((2, 3))).shape == (3,) + + def test_random_concrete_shape_subtensor_tuple(self): + """JAX should compile when a tuple of concrete values is passed for the `size` parameter. + + This test ensures that the `MakeVector` `Op` used by PyTensor to turn tuple + inputs into 1d vectors is replaced by an `Op` that turns a tuple of concrete + scalar inputs into tuples of concrete values using the + `jax_size_parameter_as_tuple` rewrite. + + """ + rng = shared(np.random.default_rng(123)) + x_pt = pt.dmatrix() + out = pt.random.normal(0, 1, size=(x_pt.shape[0],), rng=rng) + jax_fn = compile_random_function([x_pt], out) + assert jax_fn(np.ones((2, 3))).shape == (2,) + + @pytest.mark.xfail( + reason="`size_pt` should be specified as a static argument", strict=True + ) + def test_random_concrete_shape_graph_input(self): + rng = shared(np.random.default_rng(123)) + size_pt = pt.scalar() + out = pt.random.normal(0, 1, size=size_pt, rng=rng) + jax_fn = compile_random_function([size_pt], out) + assert jax_fn(10).shape == (10,) + + def test_constant_shape_after_graph_rewriting(self): + size = pt.vector("size", shape=(2,), dtype=int) + x = pt.random.normal(size=size) + assert x.type.shape == (None, None) + + with pytest.raises(TypeError): + compile_random_function([size], x)([2, 5]) + + # Rebuild with strict=False so output type is not updated + # This reflects cases where size is constant folded during rewrites but the RV node is not recreated + new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=True) + assert new_x.type.shape == (None, None) + assert compile_random_function([], new_x)().shape == (2, 5) + + # Rebuild with strict=True, so output type is updated + # This uses a different path in the dispatch implementation + new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=False) + assert new_x.type.shape == (2, 5) + assert compile_random_function([], new_x)().shape == (2, 5) From 1e5c48774a57b30539b65112831fb20aa39eac29 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 11 Oct 2024 11:22:21 +0200 Subject: [PATCH 3/3] Allow running JAX functions with scalar inputs for RV shapes --- pytensor/link/jax/linker.py | 43 +++++++++++++++++++++++----- tests/link/jax/test_random.py | 54 ++++++++++++++++++++++++++++++----- 2 files changed, 83 insertions(+), 14 deletions(-) diff --git a/pytensor/link/jax/linker.py b/pytensor/link/jax/linker.py index 2450b24150..901074035b 100644 --- a/pytensor/link/jax/linker.py +++ b/pytensor/link/jax/linker.py @@ -3,15 +3,19 @@ from numpy.random import Generator, RandomState from pytensor.compile.sharedvalue import SharedVariable, shared -from pytensor.graph.basic import Constant from pytensor.link.basic import JITLinker class JAXLinker(JITLinker): """A `Linker` that JIT-compiles NumPy-based operations using JAX.""" + def __init__(self, *args, **kwargs): + self.scalar_shape_inputs: tuple[int] = () # type: ignore[annotation-unchecked] + super().__init__(*args, **kwargs) + def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs): from pytensor.link.jax.dispatch import jax_funcify + from pytensor.link.jax.dispatch.shape import JAXShapeTuple from pytensor.tensor.random.type import RandomType shared_rng_inputs = [ @@ -65,6 +69,21 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs): fgraph.inputs.remove(new_inp) fgraph.inputs.insert(old_inp_fgrap_index, new_inp) + fgraph_inputs = fgraph.inputs + clients = fgraph.clients + # Detect scalar shape inputs that are used only in JAXShapeTuple nodes + scalar_shape_inputs = [ + inp + for node in fgraph.apply_nodes + if isinstance(node.op, JAXShapeTuple) + for inp in node.inputs + if inp in fgraph_inputs + and all(isinstance(node.op, JAXShapeTuple) for node, _ in clients[inp]) + ] + self.scalar_shape_inputs = tuple( + fgraph_inputs.index(inp) for inp in scalar_shape_inputs + ) + return jax_funcify( fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs ) @@ -72,12 +91,22 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs): def jit_compile(self, fn): import jax - # I suppose we can consider `Constant`s to be "static" according to - # JAX. - static_argnums = [ - n for n, i in enumerate(self.fgraph.inputs) if isinstance(i, Constant) - ] - return jax.jit(fn, static_argnums=static_argnums) + jit_fn = jax.jit(fn, static_argnums=self.scalar_shape_inputs) + + if not self.scalar_shape_inputs: + return jit_fn + + def convert_scalar_shape_inputs( + *args, scalar_shape_inputs=self.scalar_shape_inputs + ): + return jit_fn( + *( + int(arg) if i in scalar_shape_inputs else arg + for i, arg in enumerate(args) + ) + ) + + return convert_scalar_shape_inputs def create_thunk_inputs(self, storage_map): from pytensor.link.jax.dispatch import jax_typify diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py index f1d5cc269b..9822b3a3a6 100644 --- a/tests/link/jax/test_random.py +++ b/tests/link/jax/test_random.py @@ -867,15 +867,55 @@ def test_random_concrete_shape_subtensor_tuple(self): jax_fn = compile_random_function([x_pt], out) assert jax_fn(np.ones((2, 3))).shape == (2,) + def test_random_scalar_shape_input(self): + dim0 = pt.scalar("dim0", dtype=int) + dim1 = pt.scalar("dim1", dtype=int) + + out = pt.random.normal(0, 1, size=dim0) + jax_fn = compile_random_function([dim0], out) + assert jax_fn(np.array(2)).shape == (2,) + assert jax_fn(np.array(3)).shape == (3,) + + out = pt.random.normal(0, 1, size=[dim0, dim1]) + jax_fn = compile_random_function([dim0, dim1], out) + assert jax_fn(np.array(2), np.array(3)).shape == (2, 3) + assert jax_fn(np.array(4), np.array(5)).shape == (4, 5) + @pytest.mark.xfail( - reason="`size_pt` should be specified as a static argument", strict=True + raises=TypeError, reason="Cannot convert scalar input to integer" ) - def test_random_concrete_shape_graph_input(self): - rng = shared(np.random.default_rng(123)) - size_pt = pt.scalar() - out = pt.random.normal(0, 1, size=size_pt, rng=rng) - jax_fn = compile_random_function([size_pt], out) - assert jax_fn(10).shape == (10,) + def test_random_scalar_shape_input_not_supported(self): + dim = pt.scalar("dim", dtype=int) + out1 = pt.random.normal(0, 1, size=dim) + # An operation that wouldn't work if we replaced 0d array by integer + out2 = dim[...].set(1) + jax_fn = compile_random_function([dim], [out1, out2]) + + res1, res2 = jax_fn(np.array(2)) + assert res1.shape == (2,) + assert res2 == 1 + + @pytest.mark.xfail( + raises=TypeError, reason="Cannot convert scalar input to integer" + ) + def test_random_scalar_shape_input_not_supported2(self): + dim = pt.scalar("dim", dtype=int) + # This could theoretically be supported + # but would require knowing that * 2 is a safe operation for a python integer + out = pt.random.normal(0, 1, size=dim * 2) + jax_fn = compile_random_function([dim], out) + assert jax_fn(np.array(2)).shape == (4,) + + @pytest.mark.xfail( + raises=TypeError, reason="Cannot convert tensor input to shape tuple" + ) + def test_random_vector_shape_graph_input(self): + shape = pt.vector("shape", shape=(2,), dtype=int) + out = pt.random.normal(0, 1, size=shape) + + jax_fn = compile_random_function([shape], out) + assert jax_fn(np.array([2, 3])).shape == (2, 3) + assert jax_fn(np.array([4, 5])).shape == (4, 5) def test_constant_shape_after_graph_rewriting(self): size = pt.vector("size", shape=(2,), dtype=int)