Forward mode "adjoint" (#537)
* add .venv

* add code, tests and documentation for ForwardAdjoint

* make version of mkdocs-autorefs explicit (patrick-kidger/optimistix#91, but for diffrax)

* rename, add documentation, explicate lack of test coverage for unit-input case.

* rename import of ForwardMode

* fix duplicate

* Make docstring of ForwardMode more precise, add references to it where forward-mode autodiff is mentioned in the other adjoints

---------

Co-authored-by: Johanna Haffner <[email protected]>
johannahaffner and Johanna Haffner authored Dec 22, 2024
1 parent 3553ae1 commit 0beb5ce
Showing 6 changed files with 80 additions and 2 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -7,3 +7,4 @@ site/
.all_objects.cache
.pymon
.idea/
.venv/
1 change: 1 addition & 0 deletions diffrax/__init__.py
@@ -4,6 +4,7 @@
AbstractAdjoint as AbstractAdjoint,
BacksolveAdjoint as BacksolveAdjoint,
DirectAdjoint as DirectAdjoint,
ForwardMode as ForwardMode,
ImplicitAdjoint as ImplicitAdjoint,
RecursiveCheckpointAdjoint as RecursiveCheckpointAdjoint,
)
43 changes: 42 additions & 1 deletion diffrax/_adjoint.py
Expand Up @@ -184,7 +184,9 @@ class RecursiveCheckpointAdjoint(AbstractAdjoint):
!!! info
Note that this cannot be forward-mode autodifferentiated. (E.g. using
`jax.jvp`.) Try using [`diffrax.DirectAdjoint`][] if that is something you need.
`jax.jvp`.) Try using [`diffrax.DirectAdjoint`][] if you need both forward-mode
and reverse-mode autodifferentiation, and [`diffrax.ForwardMode`][] if you need
only forward-mode autodifferentiation.
??? cite "References"
@@ -333,6 +335,8 @@ class DirectAdjoint(AbstractAdjoint):
So unless you need forward-mode autodifferentiation then
[`diffrax.RecursiveCheckpointAdjoint`][] should be preferred.
If you need only forward-mode autodifferentiation, then [`diffrax.ForwardMode`][] is
more efficient.
"""

def loop(
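Both docstring excerpts above point users who need forward-mode autodifferentiation at the new [`diffrax.ForwardMode`][]. As a minimal sketch of what that looks like in practice (not part of this commit: the vector field and numbers are invented, and only the documented `diffrax.diffeqsolve` and `jax.jvp` APIs are assumed), a solve wrapped with `ForwardMode` can be pushed through `jax.jvp` directly:

```python
import diffrax
import jax
import jax.numpy as jnp


def vector_field(t, y, args):
    return -args * y  # made-up linear decay, purely for illustration


def solve(y0):
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(vector_field),
        diffrax.Tsit5(),
        t0=0.0,
        t1=1.0,
        dt0=0.1,
        y0=y0,
        args=2.0,
        adjoint=diffrax.ForwardMode(),  # enables forward-mode autodiff through the solve
    )
    return sol.ys


# Tangent of the solution with respect to the initial condition.
primal_out, tangent_out = jax.jvp(solve, (jnp.array([1.0]),), (jnp.array([1.0]),))
```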
@@ -852,3 +856,40 @@ def loop(
)
final_state = _only_transpose_ys(final_state)
return final_state, aux_stats


class ForwardMode(AbstractAdjoint):
"""Supports forward-mode automatic differentiation through a differential equation
solve. This works by propagating the derivatives during the forward-pass - that is,
during the ODE solve, instead of solving the adjoint equations afterwards.
(So this is really a different way of quantifying the sensitivity of the output to
the input, even if its interface is that of an adjoint for convenience.)
This is useful when we have many more outputs than inputs to a function - for
instance during parameter inference for ODE models with least-squares solvers such
as `optimistix.Levenberg-Marquardt`, that operate on the residuals.
"""

def loop(
self,
*,
solver,
throw,
passed_solver_state,
passed_controller_state,
**kwargs,
):
del throw, passed_solver_state, passed_controller_state
inner_while_loop = eqx.Partial(_inner_loop, kind="lax")
outer_while_loop = eqx.Partial(_outer_loop, kind="lax")
# Support forward-mode autodiff.
# TODO: remove this hack once we can JVP through custom_vjps.
if isinstance(solver, AbstractRungeKutta) and solver.scan_kind is None:
solver = eqx.tree_at(lambda s: s.scan_kind, solver, "lax", is_leaf=_is_none)
final_state = self._loop(
solver=solver,
inner_while_loop=inner_while_loop,
outer_while_loop=outer_while_loop,
**kwargs,
)
return final_state
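The docstring above motivates forward mode by the many-outputs, few-inputs case. The sketch below illustrates that shape of problem with a made-up decay model, assuming only the standard `diffrax.SaveAt` and `jax.jacfwd` APIs (the same combination exercised by the new tests further down):

```python
import diffrax
import jax
import jax.numpy as jnp


def vector_field(t, y, k):
    return -k * y  # made-up decay model


ts = jnp.linspace(0.0, 3.0, 50)  # fifty saved outputs...


def solve(k):  # ...depending on a single scalar parameter
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(vector_field),
        diffrax.Tsit5(),
        t0=0.0,
        t1=3.0,
        dt0=0.05,
        y0=jnp.array(1.0),
        args=k,
        saveat=diffrax.SaveAt(ts=ts),
        adjoint=diffrax.ForwardMode(),
    )
    return sol.ys


# One forward pass gives the whole (50,)-shaped Jacobian; reverse mode would
# need one backward pass per output.
dys_dk = jax.jacfwd(solve)(jnp.array(1.5))
```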
4 changes: 4 additions & 0 deletions docs/api/adjoints.md
@@ -44,6 +44,10 @@ Of the following options, [`diffrax.RecursiveCheckpointAdjoint`][] and [`diffrax
selection:
members: false

::: diffrax.ForwardMode
selection:
members: false

---

::: diffrax.adjoint_rms_seminorm
2 changes: 1 addition & 1 deletion docs/requirements.txt
@@ -14,4 +14,4 @@ mkdocs-autorefs==1.0.1
mkdocs-material-extensions==1.3.1

# Install latest version of our dependencies
jax[cpu]
jax[cpu]
31 changes: 31 additions & 0 deletions test/test_adjoint.py
@@ -74,6 +74,7 @@ def _run_inexact(inexact, saveat, adjoint):
return _run(eqx.combine(inexact, static), saveat, adjoint)

_run_grad = eqx.filter_jit(jax.grad(_run_inexact))
_run_fwd_grad = eqx.filter_jit(jax.jacfwd(_run_inexact))
_run_grad_int = eqx.filter_jit(jax.grad(_run, allow_int=True))

twice_inexact = jtu.tree_map(lambda *x: jnp.stack(x), inexact, inexact)
@@ -83,6 +84,11 @@ def _run_vmap_grad(twice_inexact, saveat, adjoint):
f = jax.vmap(jax.grad(_run_inexact), in_axes=(0, None, None))
return f(twice_inexact, saveat, adjoint)

@eqx.filter_jit
def _run_vmap_fwd_grad(twice_inexact, saveat, adjoint):
f = jax.vmap(jax.jacfwd(_run_inexact), in_axes=(0, None, None))
return f(twice_inexact, saveat, adjoint)

# @eqx.filter_jit
# def _run_vmap_finite_diff(twice_inexact, saveat, adjoint):
# @jax.vmap
@@ -102,6 +108,17 @@ def _run_impl(twice_inexact):

return _run_impl(twice_inexact)

@eqx.filter_jit
def _run_fwd_grad_vmap(twice_inexact, saveat, adjoint):
@jax.jacfwd
def _run_impl(twice_inexact):
f = jax.vmap(_run_inexact, in_axes=(0, None, None))
out = f(twice_inexact, saveat, adjoint)
assert out.shape == (2,)
return jnp.sum(out)

return _run_impl(twice_inexact)

# Yep, test that they're not implemented. We can remove these checks if we ever
# do implement them.
# Until that day comes, it's worth checking that things don't silently break.
@@ -136,10 +153,16 @@ def _convert_float0(x):
inexact, saveat, diffrax.RecursiveCheckpointAdjoint()
)
backsolve_grads = _run_grad(inexact, saveat, diffrax.BacksolveAdjoint())
forward_grads = _run_fwd_grad(inexact, saveat, diffrax.ForwardMode())
assert tree_allclose(fd_grads, direct_grads[0])
assert tree_allclose(direct_grads, recursive_grads, atol=1e-5)
assert tree_allclose(direct_grads, backsolve_grads, atol=1e-5)
assert tree_allclose(direct_grads, forward_grads, atol=1e-5)

# Test support for integer inputs (jax.grad(..., allow_int=True)). There
# is no corresponding option for jax.jacfwd or jax.linearize, and a
# workaround (jvp with custom "unit pytrees" for mixed array and
# non-array inputs?) is not implemented and tested here.
direct_grads = _run_grad_int(
y0__args__term, saveat, diffrax.DirectAdjoint()
)
@@ -166,9 +189,13 @@ def _convert_float0(x):
backsolve_grads = _run_vmap_grad(
twice_inexact, saveat, diffrax.BacksolveAdjoint()
)
forward_grads = _run_vmap_fwd_grad(
twice_inexact, saveat, diffrax.ForwardMode()
)
assert tree_allclose(fd_grads, direct_grads[0])
assert tree_allclose(direct_grads, recursive_grads, atol=1e-5)
assert tree_allclose(direct_grads, backsolve_grads, atol=1e-5)
assert tree_allclose(direct_grads, forward_grads, atol=1e-5)

direct_grads = _run_grad_vmap(
twice_inexact, saveat, diffrax.DirectAdjoint()
@@ -179,9 +206,13 @@ def _convert_float0(x):
backsolve_grads = _run_grad_vmap(
twice_inexact, saveat, diffrax.BacksolveAdjoint()
)
forward_grads = _run_fwd_grad_vmap(
twice_inexact, saveat, diffrax.ForwardMode()
)
assert tree_allclose(fd_grads, direct_grads[0])
assert tree_allclose(direct_grads, recursive_grads, atol=1e-5)
assert tree_allclose(direct_grads, backsolve_grads, atol=1e-5)
assert tree_allclose(direct_grads, forward_grads, atol=1e-5)


def test_adjoint_seminorm():
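The `ForwardMode` docstring names least-squares parameter inference with `optimistix.LevenbergMarquardt` as the motivating use case. The sketch below shows that workflow end to end; it is an illustration rather than part of this commit, with an invented decay model and synthetic data, and it assumes the standard `optimistix.least_squares` API.

```python
import diffrax
import jax.numpy as jnp
import optimistix as optx


def vector_field(t, y, k):
    return -k * y  # made-up decay model; k is the parameter being inferred


ts = jnp.linspace(0.0, 3.0, 20)
observed = jnp.exp(-0.5 * ts)  # synthetic "data", generated with k = 0.5


def residuals(k, _):
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(vector_field),
        diffrax.Tsit5(),
        t0=0.0,
        t1=3.0,
        dt0=0.1,
        y0=jnp.array(1.0),
        args=k,
        saveat=diffrax.SaveAt(ts=ts),
        adjoint=diffrax.ForwardMode(),  # many residuals, one parameter
    )
    return sol.ys - observed


solver = optx.LevenbergMarquardt(rtol=1e-8, atol=1e-8)
fit = optx.least_squares(residuals, solver, jnp.array(0.3))
# fit.value should recover something close to the true decay rate of 0.5.
```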
