Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added test for grad-of-vmap #102

Merged
merged 1 commit into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ classifiers = [
"Topic :: Scientific/Engineering :: Mathematics",
]
urls = {repository = "https://github.com/google/lineax" }
dependencies = ["jax>=0.4.26", "jaxtyping>=0.2.20", "equinox>=0.11.3", "typing_extensions>=4.5.0"]
dependencies = ["jax>=0.4.26", "jaxtyping>=0.2.20", "equinox>=0.11.5", "typing_extensions>=4.5.0"]

[build-system]
requires = ["hatchling"]
Expand Down
42 changes: 42 additions & 0 deletions tests/test_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import lineax as lx
Expand Down Expand Up @@ -86,3 +87,44 @@ def wrap_solve(matrix, vector):
matrix, vec
)
assert tree_allclose(lx_result, jax_result)


# https://github.com/patrick-kidger/lineax/issues/101
def test_grad_vmap_basic(getkey):
A = jr.normal(getkey(), (16, 8))
B = jr.normal(getkey(), (128, 16))

@jax.jit
@jax.grad
def fn(A):
op = lx.MatrixLinearOperator(A)
return jax.vmap(
lambda b: lx.linear_solve(
op, b, lx.AutoLinearSolver(well_posed=False)
).value
)(B).mean()

fn(A)


def test_grad_vmap_advanced(getkey):
# this is a more complicated version of the above test, in which the batch axes and
# the undefinedprimals do not necessarily line up in the same arguments.
A = jr.normal(getkey(), (2, 8)), jr.normal(getkey(), (3, 8, 128))
B = jr.normal(getkey(), (2, 128)), jr.normal(getkey(), (3,))

output_structure = (
jax.ShapeDtypeStruct((2,), jnp.float64),
jax.ShapeDtypeStruct((3,), jnp.float64),
)

def to_vmap(A, B):
op = lx.PyTreeLinearOperator(A, output_structure)
return lx.linear_solve(op, B, lx.AutoLinearSolver(well_posed=False)).value

@jax.jit
@jax.grad
def fn(A):
return jax.vmap(to_vmap, in_axes=((None, 2), (1, None)))(A, B).mean()

fn(A)
Loading