Skip to content

Commit

Permalink
Merge pull request #220 from joglekara/feature/custom_save_func
Browse files Browse the repository at this point in the history
Add `func` method to `SaveAt`
  • Loading branch information
patrick-kidger authored Feb 21, 2023
2 parents 5c0bfb3 + ef2a103 commit 05d03d8
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 29 deletions.
24 changes: 13 additions & 11 deletions diffrax/integrate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import functools as ft
import warnings
from typing import Optional
from typing import Optional, Callable

import equinox as eqx
import equinox.internal as eqxi
Expand Down Expand Up @@ -62,14 +62,15 @@ class _InnerState(eqx.Module):
save_index: Int


def _save(state: _State, t: Scalar) -> _State:
def _save(state: _State, t: Scalar, save_func: Callable, args: PyTree) -> _State:
ts = state.ts
ys = state.ys
save_index = state.save_index
y = state.y

ts = ts.at[save_index].set(t)
ys = jtu.tree_map(lambda ys_, y_: ys_.at[save_index].set(y_), ys, y)
ys = jtu.tree_map(lambda ys_, y_: ys_.at[save_index].set(y_), ys,
save_func(t, y, args))
save_index = save_index + 1

return eqx.tree_at(
Expand Down Expand Up @@ -109,7 +110,7 @@ def loop(
):

if saveat.t0:
init_state = _save(init_state, t0)
init_state = _save(init_state, t0, save_func=saveat.func, args=args)
if saveat.dense:
dense_ts = init_state.dense_ts
dense_ts = dense_ts.at[0].set(t0)
Expand Down Expand Up @@ -275,15 +276,15 @@ def _body_fun(_state, _inplace):

_inplace.merge(inplace)
_pred = cond_fun(state) & _cond_fun(_state)
_ts = _ts.at[_save_index].set(
jnp.where(_pred, _saveat_t, _ts[_save_index])
)

_tt_ = jnp.where(_pred, _saveat_t, _ts[_save_index])
_ts = _ts.at[_save_index].set(_tt_)
_ys = jtu.tree_map(
lambda __ys, __saveat_y: __ys.at[_save_index].set(
jnp.where(_pred, __saveat_y, __ys[_save_index])
),
_ys,
_saveat_y,
saveat.func(_tt_, _saveat_y, args)
)

# Some immediate questions you might have:
Expand Down Expand Up @@ -341,7 +342,7 @@ def maybe_inplace(i, x, u):
if saveat.steps:
made_inplace_update = True
ts = maybe_inplace(save_index, ts, tprev)
ys = jtu.tree_map(ft.partial(maybe_inplace, save_index), ys, y)
ys = jtu.tree_map(ft.partial(maybe_inplace, save_index), ys, saveat.func(tprev, y, args))
save_index = save_index + keep_step

if saveat.dense:
Expand Down Expand Up @@ -502,7 +503,7 @@ def _cond_fun(state):
# if saveat.steps then the final value is already saved.
# Using `tprev` instead of `t1` in case of an event terminating the solve
# early. (And absent such an event then `tprev == t1`.)
final_state = _save(final_state, final_state.tprev)
final_state = _save(final_state, final_state.tprev, save_func=saveat.func, args=args)
result = jnp.where(
cond_fun(final_state), RESULTS.max_steps_reached, final_state.result
)
Expand Down Expand Up @@ -810,7 +811,8 @@ def _promote(yi):
save_index = 0
made_jump = False if made_jump is None else made_jump
ts = jnp.full(out_size, jnp.inf)
ys = jtu.tree_map(lambda y: jnp.full((out_size,) + jnp.shape(y), jnp.inf), y0)
_y0 = saveat.func(t0, y0, args)
ys = jtu.tree_map(lambda y: jnp.full((out_size,) + jnp.shape(y), jnp.inf), _y0)
result = jnp.array(RESULTS.successful)
if saveat.dense:
t0 = eqxi.error_if(t0, t0 == t1, "Cannot save dense output if t0 == t1")
Expand Down
5 changes: 4 additions & 1 deletion diffrax/saveat.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Sequence, Union
from typing import Optional, Sequence, Union, Callable

import equinox as eqx
import jax
Expand All @@ -22,6 +22,7 @@ class SaveAt(eqx.Module):
solver_state: bool = False
controller_state: bool = False
made_jump: bool = False
func: Callable = lambda t, y, args: y

def __post_init__(self):
with jax.ensure_compile_time_eval():
Expand Down Expand Up @@ -50,6 +51,8 @@ def __post_init__(self):
It is less likely you will need to use these options.
- `func`: A function that returns an arbitrary pytree of values computed
from `t, y, args`. Defaults to return the state
- `solver_state`: If `True`, save the internal state of the numerical solver at
`t1`.
- `controller_state`: If `True`, save the internal state of the step size
Expand Down
46 changes: 29 additions & 17 deletions test/test_saveat_solution.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Dict
import math

import diffrax
Expand Down Expand Up @@ -36,16 +37,21 @@ def _integrate(saveat):
stepsize_controller=stepsize_controller,
)

default_kwargs = {}
custom_func_kwargs = {"func": lambda t, y, args: {"another_y": y, "another_t": t}}

def test_saveat_solution():
saveat = diffrax.SaveAt(t0=True)
@pytest.mark.parametrize("kw_args", [default_kwargs, custom_func_kwargs])
def test_saveat_solution(kw_args):
saveat = diffrax.SaveAt(t0=True, **kw_args)
sol = _integrate(saveat)
assert sol.t0 == _t0
assert sol.t1 == _t1
assert sol.ts.shape == (1,)
assert sol.ys.shape == (1, 1)
assert (sol.ys["another_y"].shape if isinstance(sol.ys, Dict)
else sol.ys.shape) == (1, 1)
assert sol.ts[0] == _t0
assert sol.ys[0, 0] == _y0
assert (sol.ys["another_y"][0, 0] if isinstance(sol.ys, Dict)
else sol.ys) == _y0
assert sol.controller_state is None
assert sol.solver_state is None
with pytest.raises(ValueError):
Expand All @@ -58,15 +64,18 @@ def test_saveat_solution():
for controller_state in (True, False):
for solver_state in (True, False):
saveat = diffrax.SaveAt(
t1=True, solver_state=solver_state, controller_state=controller_state
t1=True, solver_state=solver_state, controller_state=controller_state,
**kw_args
)
sol = _integrate(saveat)
assert sol.t0 == _t0
assert sol.t1 == _t1
assert sol.ts.shape == (1,)
assert sol.ys.shape == (1, 1)
assert (sol.ys["another_y"].shape if isinstance(sol.ys, Dict)
else sol.ys.shape) == (1, 1)
assert sol.ts[0] == _t1
assert shaped_allclose(sol.ys[0], _y0 * math.exp(-0.5))
assert shaped_allclose(sol.ys["another_y"][0] if isinstance(sol.ys, Dict)
else sol.ys[0], _y0 * math.exp(-0.5))
if controller_state:
assert sol.controller_state is not None
else:
Expand All @@ -83,23 +92,26 @@ def test_saveat_solution():
assert sol.result == diffrax.RESULTS.successful

# Outside [t0, t1]
saveat = diffrax.SaveAt(ts=[0])
saveat = diffrax.SaveAt(ts=[0], **kw_args)
with pytest.raises(RuntimeError):
sol = _integrate(saveat)
saveat = diffrax.SaveAt(ts=[3])
saveat = diffrax.SaveAt(ts=[3], **kw_args)
with pytest.raises(RuntimeError):
sol = _integrate(saveat)

saveat = diffrax.SaveAt(ts=[0.5, 0.8])
saveat = diffrax.SaveAt(ts=[0.5, 0.8], **kw_args)
sol = _integrate(saveat)
assert sol.t0 == _t0
assert sol.t1 == _t1
assert sol.ts.shape == (2,)
assert sol.ys.shape == (2, 1)
assert (sol.ys["another_y"].shape if isinstance(sol.ys, Dict)
else sol.ys.shape) == (2, 1)
assert sol.ts[0] == jnp.asarray(0.5)
assert sol.ts[1] == jnp.asarray(0.8)
assert shaped_allclose(sol.ys[0], _y0 * math.exp(-0.2))
assert shaped_allclose(sol.ys[1], _y0 * math.exp(-0.35))
assert shaped_allclose(sol.ys["another_y"][0] if isinstance(sol.ys, Dict)
else sol.ys[0], _y0 * math.exp(-0.2))
assert shaped_allclose(sol.ys["another_y"][1] if isinstance(sol.ys, Dict)
else sol.ys[1], _y0 * math.exp(-0.35))
assert sol.controller_state is None
assert sol.solver_state is None
with pytest.raises(ValueError):
Expand All @@ -109,16 +121,16 @@ def test_saveat_solution():
assert sol.stats["num_steps"] > 0
assert sol.result == diffrax.RESULTS.successful

saveat = diffrax.SaveAt(steps=True)
saveat = diffrax.SaveAt(steps=True, **kw_args)
sol = _integrate(saveat)
assert sol.t0 == _t0
assert sol.t1 == _t1
assert sol.ts.shape == (4096,)
assert sol.ys.shape == (4096, 1)
assert (sol.ys["another_y"].shape if isinstance(sol.ys, Dict) else sol.ys.shape) == (4096, 1)
_ts = jnp.where(sol.ts == jnp.inf, jnp.nan, sol.ts)
_ys = _y0 * jnp.exp(-0.5 * (_ts - _t0))[:, None]
_ys = jnp.where(jnp.isnan(_ys), jnp.inf, _ys)
assert shaped_allclose(sol.ys, _ys)
assert shaped_allclose(sol.ys["another_y"] if isinstance(sol.ys, Dict) else sol.ys, _ys)
assert sol.controller_state is None
assert sol.solver_state is None
with pytest.raises(ValueError):
Expand All @@ -128,7 +140,7 @@ def test_saveat_solution():
assert sol.stats["num_steps"] > 0
assert sol.result == diffrax.RESULTS.successful

saveat = diffrax.SaveAt(dense=True)
saveat = diffrax.SaveAt(dense=True, **kw_args)
sol = _integrate(saveat)
assert sol.t0 == _t0
assert sol.t1 == _t1
Expand Down

0 comments on commit 05d03d8

Please sign in to comment.