From 59a2b26c477eab687679b945fc41b945a542611f Mon Sep 17 00:00:00 2001 From: archis Date: Mon, 30 Jan 2023 13:25:38 -0800 Subject: [PATCH 1/5] Add `func` method to `SaveAt` --- diffrax/integrate.py | 5 +++-- diffrax/saveat.py | 5 ++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/diffrax/integrate.py b/diffrax/integrate.py index 1bf48e00..fb42a7fc 100644 --- a/diffrax/integrate.py +++ b/diffrax/integrate.py @@ -283,7 +283,7 @@ def _body_fun(_state, _inplace): jnp.where(_pred, __saveat_y, __ys[_save_index]) ), _ys, - _saveat_y, + saveat.func(_ts[_save_index], _saveat_y, args) ) # Some immediate questions you might have: @@ -810,7 +810,8 @@ def _promote(yi): save_index = 0 made_jump = False if made_jump is None else made_jump ts = jnp.full(out_size, jnp.inf) - ys = jtu.tree_map(lambda y: jnp.full((out_size,) + jnp.shape(y), jnp.inf), y0) + _y0 = saveat.func(saveat.t0, y0, args) + ys = jtu.tree_map(lambda y: jnp.full((out_size,) + jnp.shape(y), jnp.inf), _y0) result = jnp.array(RESULTS.successful) if saveat.dense: t0 = eqxi.error_if(t0, t0 == t1, "Cannot save dense output if t0 == t1") diff --git a/diffrax/saveat.py b/diffrax/saveat.py index 2eccc883..2c054d76 100644 --- a/diffrax/saveat.py +++ b/diffrax/saveat.py @@ -1,4 +1,4 @@ -from typing import Optional, Sequence, Union +from typing import Optional, Sequence, Union, Callable import equinox as eqx import jax @@ -22,6 +22,7 @@ class SaveAt(eqx.Module): solver_state: bool = False controller_state: bool = False made_jump: bool = False + func: Callable = lambda t, y, args: y def __post_init__(self): with jax.ensure_compile_time_eval(): @@ -50,6 +51,8 @@ def __post_init__(self): It is less likely you will need to use these options. +- `func`: Pass a function that returns an arbitrary pytree of values computed + from `t, y, args`. Defaults to return the state - `solver_state`: If `True`, save the internal state of the numerical solver at `t1`. - `controller_state`: If `True`, save the internal state of the step size From ecd90f8806fda066450ff36239e5944941b49e2b Mon Sep 17 00:00:00 2001 From: archis Date: Mon, 30 Jan 2023 13:28:47 -0800 Subject: [PATCH 2/5] doc edit --- diffrax/saveat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffrax/saveat.py b/diffrax/saveat.py index 2c054d76..4ec63a0d 100644 --- a/diffrax/saveat.py +++ b/diffrax/saveat.py @@ -51,7 +51,7 @@ def __post_init__(self): It is less likely you will need to use these options. -- `func`: Pass a function that returns an arbitrary pytree of values computed +- `func`: A function that returns an arbitrary pytree of values computed from `t, y, args`. Defaults to return the state - `solver_state`: If `True`, save the internal state of the numerical solver at `t1`. From e5e317121ec7085ec7d340f6b512b2ed76cdf33a Mon Sep 17 00:00:00 2001 From: archis Date: Sun, 5 Feb 2023 14:10:10 -0800 Subject: [PATCH 3/5] added test parameterization that repeats existing set of tests for the default `saveat` function and a custom `saveat` function --- diffrax/integrate.py | 15 ++++++------ test/test_saveat_solution.py | 46 +++++++++++++++++++++++------------- 2 files changed, 37 insertions(+), 24 deletions(-) diff --git a/diffrax/integrate.py b/diffrax/integrate.py index fb42a7fc..1a1216cf 100644 --- a/diffrax/integrate.py +++ b/diffrax/integrate.py @@ -1,6 +1,6 @@ import functools as ft import warnings -from typing import Optional +from typing import Optional, Callable import equinox as eqx import equinox.internal as eqxi @@ -62,14 +62,15 @@ class _InnerState(eqx.Module): save_index: Int -def _save(state: _State, t: Scalar) -> _State: +def _save(state: _State, t: Scalar, save_func: Callable, args: PyTree) -> _State: ts = state.ts ys = state.ys save_index = state.save_index y = state.y ts = ts.at[save_index].set(t) - ys = jtu.tree_map(lambda ys_, y_: ys_.at[save_index].set(y_), ys, y) + ys = jtu.tree_map(lambda ys_, y_: ys_.at[save_index].set(y_), ys, + save_func(t, y, args)) save_index = save_index + 1 return eqx.tree_at( @@ -109,7 +110,7 @@ def loop( ): if saveat.t0: - init_state = _save(init_state, t0) + init_state = _save(init_state, t0, save_func=saveat.func, args=args) if saveat.dense: dense_ts = init_state.dense_ts dense_ts = dense_ts.at[0].set(t0) @@ -341,7 +342,7 @@ def maybe_inplace(i, x, u): if saveat.steps: made_inplace_update = True ts = maybe_inplace(save_index, ts, tprev) - ys = jtu.tree_map(ft.partial(maybe_inplace, save_index), ys, y) + ys = jtu.tree_map(ft.partial(maybe_inplace, save_index), ys, saveat.func(tprev, y, args)) save_index = save_index + keep_step if saveat.dense: @@ -502,7 +503,7 @@ def _cond_fun(state): # if saveat.steps then the final value is already saved. # Using `tprev` instead of `t1` in case of an event terminating the solve # early. (And absent such an event then `tprev == t1`.) - final_state = _save(final_state, final_state.tprev) + final_state = _save(final_state, final_state.tprev, save_func=saveat.func, args=args) result = jnp.where( cond_fun(final_state), RESULTS.max_steps_reached, final_state.result ) @@ -810,7 +811,7 @@ def _promote(yi): save_index = 0 made_jump = False if made_jump is None else made_jump ts = jnp.full(out_size, jnp.inf) - _y0 = saveat.func(saveat.t0, y0, args) + _y0 = saveat.func(t0, y0, args) ys = jtu.tree_map(lambda y: jnp.full((out_size,) + jnp.shape(y), jnp.inf), _y0) result = jnp.array(RESULTS.successful) if saveat.dense: diff --git a/test/test_saveat_solution.py b/test/test_saveat_solution.py index 4788cbba..5a96466d 100644 --- a/test/test_saveat_solution.py +++ b/test/test_saveat_solution.py @@ -1,3 +1,4 @@ +from typing import Dict import math import diffrax @@ -36,16 +37,21 @@ def _integrate(saveat): stepsize_controller=stepsize_controller, ) +default_func = lambda t, y, args: y +custom_func = lambda t, y, args: {"another_y": y, "another_t": t} -def test_saveat_solution(): - saveat = diffrax.SaveAt(t0=True) +@pytest.mark.parametrize("save_func", [custom_func, default_func]) +def test_saveat_solution(save_func): + saveat = diffrax.SaveAt(t0=True, func=save_func) sol = _integrate(saveat) assert sol.t0 == _t0 assert sol.t1 == _t1 assert sol.ts.shape == (1,) - assert sol.ys.shape == (1, 1) + assert (sol.ys["another_y"].shape if isinstance(sol.ys, Dict) + else sol.ys.shape) == (1, 1) assert sol.ts[0] == _t0 - assert sol.ys[0, 0] == _y0 + assert (sol.ys["another_y"][0, 0] if isinstance(sol.ys, Dict) + else sol.ys) == _y0 assert sol.controller_state is None assert sol.solver_state is None with pytest.raises(ValueError): @@ -58,15 +64,18 @@ def test_saveat_solution(): for controller_state in (True, False): for solver_state in (True, False): saveat = diffrax.SaveAt( - t1=True, solver_state=solver_state, controller_state=controller_state + t1=True, solver_state=solver_state, controller_state=controller_state, + func=save_func ) sol = _integrate(saveat) assert sol.t0 == _t0 assert sol.t1 == _t1 assert sol.ts.shape == (1,) - assert sol.ys.shape == (1, 1) + assert (sol.ys["another_y"].shape if isinstance(sol.ys, Dict) + else sol.ys.shape) == (1, 1) assert sol.ts[0] == _t1 - assert shaped_allclose(sol.ys[0], _y0 * math.exp(-0.5)) + assert shaped_allclose(sol.ys["another_y"][0] if isinstance(sol.ys, Dict) + else sol.ys[0], _y0 * math.exp(-0.5)) if controller_state: assert sol.controller_state is not None else: @@ -83,23 +92,26 @@ def test_saveat_solution(): assert sol.result == diffrax.RESULTS.successful # Outside [t0, t1] - saveat = diffrax.SaveAt(ts=[0]) + saveat = diffrax.SaveAt(ts=[0], func=save_func) with pytest.raises(RuntimeError): sol = _integrate(saveat) - saveat = diffrax.SaveAt(ts=[3]) + saveat = diffrax.SaveAt(ts=[3], func=save_func) with pytest.raises(RuntimeError): sol = _integrate(saveat) - saveat = diffrax.SaveAt(ts=[0.5, 0.8]) + saveat = diffrax.SaveAt(ts=[0.5, 0.8], func=save_func) sol = _integrate(saveat) assert sol.t0 == _t0 assert sol.t1 == _t1 assert sol.ts.shape == (2,) - assert sol.ys.shape == (2, 1) + assert (sol.ys["another_y"].shape if isinstance(sol.ys, Dict) + else sol.ys.shape) == (2, 1) assert sol.ts[0] == jnp.asarray(0.5) assert sol.ts[1] == jnp.asarray(0.8) - assert shaped_allclose(sol.ys[0], _y0 * math.exp(-0.2)) - assert shaped_allclose(sol.ys[1], _y0 * math.exp(-0.35)) + assert shaped_allclose(sol.ys["another_y"][0] if isinstance(sol.ys, Dict) + else sol.ys[0], _y0 * math.exp(-0.2)) + assert shaped_allclose(sol.ys["another_y"][1] if isinstance(sol.ys, Dict) + else sol.ys[1], _y0 * math.exp(-0.35)) assert sol.controller_state is None assert sol.solver_state is None with pytest.raises(ValueError): @@ -109,16 +121,16 @@ def test_saveat_solution(): assert sol.stats["num_steps"] > 0 assert sol.result == diffrax.RESULTS.successful - saveat = diffrax.SaveAt(steps=True) + saveat = diffrax.SaveAt(steps=True, func=save_func) sol = _integrate(saveat) assert sol.t0 == _t0 assert sol.t1 == _t1 assert sol.ts.shape == (4096,) - assert sol.ys.shape == (4096, 1) + assert (sol.ys["another_y"].shape if isinstance(sol.ys, Dict) else sol.ys.shape) == (4096, 1) _ts = jnp.where(sol.ts == jnp.inf, jnp.nan, sol.ts) _ys = _y0 * jnp.exp(-0.5 * (_ts - _t0))[:, None] _ys = jnp.where(jnp.isnan(_ys), jnp.inf, _ys) - assert shaped_allclose(sol.ys, _ys) + assert shaped_allclose(sol.ys["another_y"] if isinstance(sol.ys, Dict) else sol.ys, _ys) assert sol.controller_state is None assert sol.solver_state is None with pytest.raises(ValueError): @@ -128,7 +140,7 @@ def test_saveat_solution(): assert sol.stats["num_steps"] > 0 assert sol.result == diffrax.RESULTS.successful - saveat = diffrax.SaveAt(dense=True) + saveat = diffrax.SaveAt(dense=True, func=save_func) sol = _integrate(saveat) assert sol.t0 == _t0 assert sol.t1 == _t1 From 83778af21cf0dd185503acf0faae75d547fbb0cc Mon Sep 17 00:00:00 2001 From: archis Date: Sun, 5 Feb 2023 14:13:31 -0800 Subject: [PATCH 4/5] better implementation for default kwargs --- test/test_saveat_solution.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/test/test_saveat_solution.py b/test/test_saveat_solution.py index 5a96466d..b0817f33 100644 --- a/test/test_saveat_solution.py +++ b/test/test_saveat_solution.py @@ -37,12 +37,12 @@ def _integrate(saveat): stepsize_controller=stepsize_controller, ) -default_func = lambda t, y, args: y -custom_func = lambda t, y, args: {"another_y": y, "another_t": t} +default_kwargs = {} +custom_func_kwargs = {"func": lambda t, y, args: {"another_y": y, "another_t": t}} -@pytest.mark.parametrize("save_func", [custom_func, default_func]) -def test_saveat_solution(save_func): - saveat = diffrax.SaveAt(t0=True, func=save_func) +@pytest.mark.parametrize("kw_args", [default_kwargs, custom_func_kwargs]) +def test_saveat_solution(kw_args): + saveat = diffrax.SaveAt(t0=True, **kw_args) sol = _integrate(saveat) assert sol.t0 == _t0 assert sol.t1 == _t1 @@ -65,7 +65,7 @@ def test_saveat_solution(save_func): for solver_state in (True, False): saveat = diffrax.SaveAt( t1=True, solver_state=solver_state, controller_state=controller_state, - func=save_func + **kw_args ) sol = _integrate(saveat) assert sol.t0 == _t0 @@ -92,14 +92,14 @@ def test_saveat_solution(save_func): assert sol.result == diffrax.RESULTS.successful # Outside [t0, t1] - saveat = diffrax.SaveAt(ts=[0], func=save_func) + saveat = diffrax.SaveAt(ts=[0], **kw_args) with pytest.raises(RuntimeError): sol = _integrate(saveat) - saveat = diffrax.SaveAt(ts=[3], func=save_func) + saveat = diffrax.SaveAt(ts=[3], **kw_args) with pytest.raises(RuntimeError): sol = _integrate(saveat) - saveat = diffrax.SaveAt(ts=[0.5, 0.8], func=save_func) + saveat = diffrax.SaveAt(ts=[0.5, 0.8], **kw_args) sol = _integrate(saveat) assert sol.t0 == _t0 assert sol.t1 == _t1 @@ -121,7 +121,7 @@ def test_saveat_solution(save_func): assert sol.stats["num_steps"] > 0 assert sol.result == diffrax.RESULTS.successful - saveat = diffrax.SaveAt(steps=True, func=save_func) + saveat = diffrax.SaveAt(steps=True, **kw_args) sol = _integrate(saveat) assert sol.t0 == _t0 assert sol.t1 == _t1 @@ -140,7 +140,7 @@ def test_saveat_solution(save_func): assert sol.stats["num_steps"] > 0 assert sol.result == diffrax.RESULTS.successful - saveat = diffrax.SaveAt(dense=True, func=save_func) + saveat = diffrax.SaveAt(dense=True, **kw_args) sol = _integrate(saveat) assert sol.t0 == _t0 assert sol.t1 == _t1 From ef2a103e8550fc6072ed39390873a213c3b7df39 Mon Sep 17 00:00:00 2001 From: archis Date: Mon, 6 Feb 2023 15:21:50 -0800 Subject: [PATCH 5/5] reusing t --- diffrax/integrate.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/diffrax/integrate.py b/diffrax/integrate.py index 1a1216cf..086ee088 100644 --- a/diffrax/integrate.py +++ b/diffrax/integrate.py @@ -276,15 +276,15 @@ def _body_fun(_state, _inplace): _inplace.merge(inplace) _pred = cond_fun(state) & _cond_fun(_state) - _ts = _ts.at[_save_index].set( - jnp.where(_pred, _saveat_t, _ts[_save_index]) - ) + + _tt_ = jnp.where(_pred, _saveat_t, _ts[_save_index]) + _ts = _ts.at[_save_index].set(_tt_) _ys = jtu.tree_map( lambda __ys, __saveat_y: __ys.at[_save_index].set( jnp.where(_pred, __saveat_y, __ys[_save_index]) ), _ys, - saveat.func(_ts[_save_index], _saveat_y, args) + saveat.func(_tt_, _saveat_y, args) ) # Some immediate questions you might have: