Skip to content

Commit

Permalink
Group JAX random shape input tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Nov 29, 2024
1 parent 9ac810c commit 194b871
Showing 1 changed file with 87 additions and 91 deletions.
178 changes: 87 additions & 91 deletions tests/link/jax/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,94 +809,90 @@ def sample_fn(rng, size, dtype, *parameters):
compare_jax_and_py(fgraph, [])


def test_random_concrete_shape():
"""JAX should compile when a `RandomVariable` is passed a concrete shape.
There are three quantities that JAX considers as concrete:
1. Constants known at compile time;
2. The shape of an array.
3. `static_argnums` parameters
This test makes sure that graphs with `RandomVariable`s compile when the
`size` parameter satisfies either of these criteria.
"""
rng = shared(np.random.default_rng(123))
x_pt = pt.dmatrix()
out = pt.random.normal(0, 1, size=x_pt.shape, rng=rng)
jax_fn = compile_random_function([x_pt], out)
assert jax_fn(np.ones((2, 3))).shape == (2, 3)


def test_random_concrete_shape_from_param():
rng = shared(np.random.default_rng(123))
x_pt = pt.dmatrix()
out = pt.random.normal(x_pt, 1, rng=rng)
jax_fn = compile_random_function([x_pt], out)
assert jax_fn(np.ones((2, 3))).shape == (2, 3)


def test_random_concrete_shape_subtensor():
"""JAX should compile when a concrete value is passed for the `size` parameter.
This test ensures that the `DimShuffle` `Op` used by PyTensor to turn scalar
inputs into 1d vectors is replaced by an `Op` that turns concrete scalar
inputs into tuples of concrete values using the `jax_size_parameter_as_tuple`
rewrite.
JAX does not accept scalars as `size` or `shape` arguments, so this is a
slight improvement over their API.
"""
rng = shared(np.random.default_rng(123))
x_pt = pt.dmatrix()
out = pt.random.normal(0, 1, size=x_pt.shape[1], rng=rng)
jax_fn = compile_random_function([x_pt], out)
assert jax_fn(np.ones((2, 3))).shape == (3,)


def test_random_concrete_shape_subtensor_tuple():
"""JAX should compile when a tuple of concrete values is passed for the `size` parameter.
This test ensures that the `MakeVector` `Op` used by PyTensor to turn tuple
inputs into 1d vectors is replaced by an `Op` that turns a tuple of concrete
scalar inputs into tuples of concrete values using the
`jax_size_parameter_as_tuple` rewrite.
"""
rng = shared(np.random.default_rng(123))
x_pt = pt.dmatrix()
out = pt.random.normal(0, 1, size=(x_pt.shape[0],), rng=rng)
jax_fn = compile_random_function([x_pt], out)
assert jax_fn(np.ones((2, 3))).shape == (2,)


@pytest.mark.xfail(
reason="`size_pt` should be specified as a static argument", strict=True
)
def test_random_concrete_shape_graph_input():
rng = shared(np.random.default_rng(123))
size_pt = pt.scalar()
out = pt.random.normal(0, 1, size=size_pt, rng=rng)
jax_fn = compile_random_function([size_pt], out)
assert jax_fn(10).shape == (10,)


def test_constant_shape_after_graph_rewriting():
size = pt.vector("size", shape=(2,), dtype=int)
x = pt.random.normal(size=size)
assert x.type.shape == (None, None)

with pytest.raises(TypeError):
compile_random_function([size], x)([2, 5])

# Rebuild with strict=False so output type is not updated
# This reflects cases where size is constant folded during rewrites but the RV node is not recreated
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=True)
assert new_x.type.shape == (None, None)
assert compile_random_function([], new_x)().shape == (2, 5)

# Rebuild with strict=True, so output type is updated
# This uses a different path in the dispatch implementation
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=False)
assert new_x.type.shape == (2, 5)
assert compile_random_function([], new_x)().shape == (2, 5)
class TestRandomShapeInputs:
def test_random_concrete_shape(self):
"""JAX should compile when a `RandomVariable` is passed a concrete shape.
There are three quantities that JAX considers as concrete:
1. Constants known at compile time;
2. The shape of an array.
3. `static_argnums` parameters
This test makes sure that graphs with `RandomVariable`s compile when the
`size` parameter satisfies either of these criteria.
"""
rng = shared(np.random.default_rng(123))
x_pt = pt.dmatrix()
out = pt.random.normal(0, 1, size=x_pt.shape, rng=rng)
jax_fn = compile_random_function([x_pt], out)
assert jax_fn(np.ones((2, 3))).shape == (2, 3)

def test_random_concrete_shape_from_param(self):
rng = shared(np.random.default_rng(123))
x_pt = pt.dmatrix()
out = pt.random.normal(x_pt, 1, rng=rng)
jax_fn = compile_random_function([x_pt], out)
assert jax_fn(np.ones((2, 3))).shape == (2, 3)

def test_random_concrete_shape_subtensor(self):
"""JAX should compile when a concrete value is passed for the `size` parameter.
This test ensures that the `DimShuffle` `Op` used by PyTensor to turn scalar
inputs into 1d vectors is replaced by an `Op` that turns concrete scalar
inputs into tuples of concrete values using the `jax_size_parameter_as_tuple`
rewrite.
JAX does not accept scalars as `size` or `shape` arguments, so this is a
slight improvement over their API.
"""
rng = shared(np.random.default_rng(123))
x_pt = pt.dmatrix()
out = pt.random.normal(0, 1, size=x_pt.shape[1], rng=rng)
jax_fn = compile_random_function([x_pt], out)
assert jax_fn(np.ones((2, 3))).shape == (3,)

def test_random_concrete_shape_subtensor_tuple(self):
"""JAX should compile when a tuple of concrete values is passed for the `size` parameter.
This test ensures that the `MakeVector` `Op` used by PyTensor to turn tuple
inputs into 1d vectors is replaced by an `Op` that turns a tuple of concrete
scalar inputs into tuples of concrete values using the
`jax_size_parameter_as_tuple` rewrite.
"""
rng = shared(np.random.default_rng(123))
x_pt = pt.dmatrix()
out = pt.random.normal(0, 1, size=(x_pt.shape[0],), rng=rng)
jax_fn = compile_random_function([x_pt], out)
assert jax_fn(np.ones((2, 3))).shape == (2,)

@pytest.mark.xfail(
reason="`size_pt` should be specified as a static argument", strict=True
)
def test_random_concrete_shape_graph_input(self):
rng = shared(np.random.default_rng(123))
size_pt = pt.scalar()
out = pt.random.normal(0, 1, size=size_pt, rng=rng)
jax_fn = compile_random_function([size_pt], out)
assert jax_fn(10).shape == (10,)

def test_constant_shape_after_graph_rewriting(self):
size = pt.vector("size", shape=(2,), dtype=int)
x = pt.random.normal(size=size)
assert x.type.shape == (None, None)

with pytest.raises(TypeError):
compile_random_function([size], x)([2, 5])

# Rebuild with strict=False so output type is not updated
# This reflects cases where size is constant folded during rewrites but the RV node is not recreated
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=True)
assert new_x.type.shape == (None, None)
assert compile_random_function([], new_x)().shape == (2, 5)

# Rebuild with strict=True, so output type is updated
# This uses a different path in the dispatch implementation
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=False)
assert new_x.type.shape == (2, 5)
assert compile_random_function([], new_x)().shape == (2, 5)

0 comments on commit 194b871

Please sign in to comment.