Skip to content

Commit

Permalink
Explicitly marked BacksolveAdjoint as incompatible with events.
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Aug 18, 2023
1 parent 7f30854 commit 5f1978d
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 3 deletions.
8 changes: 7 additions & 1 deletion diffrax/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,7 @@ def _loop_backsolve_bwd(
throw,
init_state,
):
assert discrete_terminating_event is None

#
# Unpack our various arguments. Delete a lot of things just to make sure we're not
Expand Down Expand Up @@ -565,7 +566,6 @@ def _loop_backsolve_bwd(
adjoint=self,
solver=solver,
stepsize_controller=stepsize_controller,
discrete_terminating_event=discrete_terminating_event,
terms=adjoint_terms,
dt0=None if dt0 is None else -dt0,
max_steps=max_steps,
Expand Down Expand Up @@ -744,6 +744,7 @@ def loop(
init_state,
passed_solver_state,
passed_controller_state,
discrete_terminating_event,
**kwargs,
):
if jtu.tree_structure(saveat.subs, is_leaf=_is_subsaveat) != jtu.tree_structure(
Expand Down Expand Up @@ -785,6 +786,10 @@ def loop(
"`diffrax.BacksolveAdjoint` is only compatible with solvers that take "
"a single term."
)
if discrete_terminating_event is not None:
raise NotImplementedError(
"`diffrax.BacksolveAdjoint` is not compatible with events."
)

y = init_state.y
init_state = eqx.tree_at(lambda s: s.y, init_state, object())
Expand All @@ -798,6 +803,7 @@ def loop(
saveat=saveat,
init_state=init_state,
solver=solver,
discrete_terminating_event=discrete_terminating_event,
**kwargs,
)
final_state = _only_transpose_ys(final_state)
Expand Down
50 changes: 48 additions & 2 deletions test/test_event.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import diffrax
import jax
import jax.numpy as jnp
import pytest


def test_discrete_terminate1():
Expand All @@ -9,7 +11,12 @@ def test_discrete_terminate1():
t1 = jnp.inf
dt0 = 1
y0 = 1.0
event = diffrax.DiscreteTerminatingEvent(lambda state, **kwargs: state.y > 10)

def event_fn(state, **kwargs):
assert isinstance(state.y, jax.Array)
return state.tprev > 10

event = diffrax.DiscreteTerminatingEvent(event_fn)
sol = diffrax.diffeqsolve(
term, solver, t0, t1, dt0, y0, discrete_terminating_event=event
)
Expand All @@ -23,11 +30,50 @@ def test_discrete_terminate2():
t1 = jnp.inf
dt0 = 1
y0 = 1.0
event = diffrax.DiscreteTerminatingEvent(lambda state, **kwargs: state.tprev > 10)

def event_fn(state, **kwargs):
assert isinstance(state.y, jax.Array)
return state.tprev > 10

event = diffrax.DiscreteTerminatingEvent(event_fn)
sol = diffrax.diffeqsolve(
term, solver, t0, t1, dt0, y0, discrete_terminating_event=event
)
assert jnp.all(sol.ts > 10)


def test_event_backsolve():
term = diffrax.ODETerm(lambda t, y, args: y)
solver = diffrax.Tsit5()
t0 = 0
t1 = jnp.inf
dt0 = 1
y0 = 1.0

def event_fn(state, **kwargs):
assert isinstance(state.y, jax.Array)
return state.tprev > 10

event = diffrax.DiscreteTerminatingEvent(event_fn)

@jax.jit
@jax.grad
def run(y0):
sol = diffrax.diffeqsolve(
term,
solver,
t0,
t1,
dt0,
y0,
discrete_terminating_event=event,
adjoint=diffrax.BacksolveAdjoint(),
)
return jnp.sum(sol.ys)

# And in particular not some other error.
with pytest.raises(NotImplementedError):
run(y0)


# diffrax.SteadyStateEvent tested as part of test_adjoint.py::test_implicit

0 comments on commit 5f1978d

Please sign in to comment.