Skip to content

Commit

Permalink
Merge pull request #125 from patrick-kidger/solver-docs
Browse files Browse the repository at this point in the history
Improved solver documentation
  • Loading branch information
patrick-kidger authored Jul 20, 2022
2 parents aa27945 + a3d827e commit 115997e
Show file tree
Hide file tree
Showing 53 changed files with 1,467 additions and 536 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
with:
python-version: "3.8"
test-script: |
python -m pip install pytest psutil jax jaxlib equinox scipy
python -m pip install pytest psutil jax jaxlib equinox scipy optax
cp -r ${{ github.workspace }}/test ./test
pytest
pypi-token: ${{ secrets.pypi_token }}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install pytest psutil wheel scipy numpy jaxlib
python -m pip install pytest psutil wheel scipy numpy optax jaxlib
- name: Checks with pre-commit
uses: pre-commit/[email protected]
Expand Down
10 changes: 8 additions & 2 deletions diffrax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
from .adjoint import (
AbstractAdjoint,
BacksolveAdjoint,
ImplicitAdjoint,
NoAdjoint,
RecursiveCheckpointAdjoint,
)
from .brownian import AbstractBrownianPath, UnsafeBrownianPath, VirtualBrownianTree
from .event import (
AbstractDiscreteTerminatingEvent,
DiscreteTerminatingEvent,
SteadyStateEvent,
)
from .global_interpolation import (
AbstractGlobalInterpolation,
backward_hermite_coefficients,
Expand All @@ -31,7 +37,6 @@
from .saveat import SaveAt
from .solution import RESULTS, Solution
from .solver import (
AbstractAdaptiveSDESolver,
AbstractAdaptiveSolver,
AbstractDIRK,
AbstractERK,
Expand All @@ -45,6 +50,7 @@
AbstractWrappedSolver,
Bosh3,
ButcherTableau,
CalculateJacobian,
Dopri5,
Dopri8,
Euler,
Expand Down Expand Up @@ -81,4 +87,4 @@
)


__version__ = "0.1.2"
__version__ = "0.2.0"
73 changes: 71 additions & 2 deletions diffrax/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import jax.lax as lax
import jax.numpy as jnp

from .misc import nondifferentiable_output, ω
from .misc import implicit_jvp, nondifferentiable_output, ω
from .saveat import SaveAt
from .term import AbstractTerm, AdjointTerm

Expand All @@ -22,6 +22,7 @@ def loop(
terms,
solver,
stepsize_controller,
discrete_terminating_event,
saveat,
t0,
t1,
Expand Down Expand Up @@ -92,6 +93,73 @@ def loop(self, *, throw, **kwargs):
return final_state, aux_stats


def _vf(ys, residual, args__terms, closure):
state_no_y, _ = residual
t = state_no_y.tprev
(y,) = ys # unpack length-1 dimension
args, terms = args__terms
_, _, solver, _, _ = closure
return solver.func(terms, t, y, args)


def _solve(args__terms, closure):
    """Run the integration loop and split `ys` off from the final state.

    This is handed to `implicit_jvp` together with `_vf` (see
    `ImplicitAdjoint.loop`).

    **Arguments:**

    - `args__terms`: the pair `(args, terms)`.
    - `closure`: the 5-tuple `(adjoint, kwargs, solver, saveat, init_state)`,
        where `adjoint` provides `_loop_fn` and `kwargs` holds the remaining
        keyword arguments forwarded to it.

    **Returns:**

    A pair `(ys, (final_state_without_ys, aux_stats))`.
    """
    args, terms = args__terms
    adjoint, loop_kwargs, solver, saveat, init_state = closure
    final_state, aux_stats = adjoint._loop_fn(
        **loop_kwargs,
        args=args,
        terms=terms,
        solver=solver,
        saveat=saveat,
        init_state=init_state,
        is_bounded=False,
    )
    # Deliberately `.ys` rather than `.y`: `.ys` is what diffeqsolve actually
    # returns, so it is the value the tangent must be attached to.
    ys = final_state.ys
    final_state_no_ys = eqx.tree_at(lambda state: state.ys, final_state, None)
    return ys, (final_state_no_ys, aux_stats)


class ImplicitAdjoint(AbstractAdjoint):
    r"""Backpropagate via the [implicit function theorem](https://en.wikipedia.org/wiki/Implicit_function_theorem#Statement_of_the_theorem).

    Intended for solves that run towards a steady state, typically together
    with [`diffrax.SteadyStateEvent`][]. The solver output is then $y(θ)$
    satisfying $f(t, y(θ), θ) = 0$, where $θ$ denotes all parameters reachable
    through `terms` and `args` (but not `y0`). Rather than backpropagating
    through the solver, the derivative is computed directly as
    $\frac{\mathrm{d}y}{\mathrm{d}θ} = - (\frac{\mathrm{d}f}{\mathrm{d}y})^{-1}\frac{\mathrm{d}f}{\mathrm{d}θ}$
    by the implicit function theorem.
    """  # noqa: E501

    def loop(self, *, args, terms, solver, saveat, throw, init_state, **kwargs):
        """Run the solve with gradients defined through `implicit_jvp`.

        **Arguments:**

        - `args`, `terms`: the quantities that are differentiated with
            respect to.
        - `solver`, `saveat`, `init_state`: forwarded to the integration loop.
        - `throw`: accepted for interface compatibility and discarded.
        - `**kwargs`: remaining loop options, forwarded unchanged.

        **Returns:**

        A pair `(final_state, aux_stats)`.

        **Raises:**

        `ValueError` if `saveat` is anything other than `SaveAt(t1=True)`.
        """
        del throw

        # Compare with `is not True`, not truthiness: for SaveAt(ts=<array>)
        # the comparison may come back as a Tracer instead of a Python bool.
        saves_only_t1 = eqx.tree_equal(saveat, SaveAt(t1=True))
        if saves_only_t1 is not True:
            raise ValueError(
                "Can only use `adjoint=ImplicitAdjoint()` with `SaveAt(t1=True)`."
            )

        def _state_leaves(state):
            return state.y, state.solver_state, state.controller_state

        # Cut gradients through the initial state (`y`, solver state and
        # controller state).
        detached_state = eqx.tree_at(
            _state_leaves, init_state, replace_fn=lax.stop_gradient
        )
        ys, (final_state_no_ys, aux_stats) = implicit_jvp(
            _solve,
            _vf,
            (args, terms),
            (self, kwargs, solver, saveat, detached_state),
        )

        # `_solve` replaced `ys` with None in the state; put the value back.
        final_state = eqx.tree_at(
            lambda state: state.ys,
            final_state_no_ys,
            ys,
            is_leaf=lambda leaf: leaf is None,
        )
        return final_state, aux_stats


# Compute derivatives with respect to the first argument:
# - y, corresponding to the initial state;
# - args, corresponding to explicit parameters;
Expand All @@ -116,7 +184,6 @@ def _loop_backsolve_fwd(y__args__terms, **kwargs):
return (final_state, aux_stats), (ts, ys)


# TODO: implement this as a single diffeqsolve with events, once events are supported.
def _loop_backsolve_bwd(
residuals,
grad_final_state__aux_stats,
Expand All @@ -125,6 +192,7 @@ def _loop_backsolve_bwd(
self,
solver,
stepsize_controller,
discrete_terminating_event,
saveat,
t0,
t1,
Expand Down Expand Up @@ -162,6 +230,7 @@ def _loop_backsolve_bwd(
adjoint=self,
solver=solver,
stepsize_controller=stepsize_controller,
discrete_terminating_event=discrete_terminating_event,
terms=adjoint_terms,
dt0=None if dt0 is None else -dt0,
max_steps=max_steps,
Expand Down
98 changes: 98 additions & 0 deletions diffrax/event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import abc
from typing import Callable, Optional

import equinox as eqx

from .custom_types import Bool, PyTree, Scalar
from .misc import rms_norm
from .step_size_controller import AbstractAdaptiveStepSizeController


class AbstractDiscreteTerminatingEvent(eqx.Module):
    """Interface for events checked at the end of every integration step.

    When the event evaluates to true, the solve is stopped at that time.
    """

    @abc.abstractmethod
    def __call__(self, state, **kwargs):
        """**Arguments:**

        - `state`: a dataclass of the evolving state of the system, including
            in particular the solution `state.y` at time `state.tprev`.
        - `**kwargs`: the integration options held constant throughout the
            solve, passed as keyword arguments: `terms`, `solver`, `args`,
            etc.

        **Returns:**

        A boolean. If true then the solve is terminated.
        """


class DiscreteTerminatingEvent(AbstractDiscreteTerminatingEvent):
    """Terminates the solve if its condition is ever active."""

    # User-supplied predicate, called as `cond_fn(state, **kwargs)`.
    cond_fn: Callable[..., Bool]

    def __call__(self, state, **kwargs):
        """Evaluate the condition by delegating to `cond_fn`.

        **Arguments:**

        - `state`: forwarded to `cond_fn` as its positional argument.
        - `**kwargs`: forwarded to `cond_fn` unchanged.

        **Returns:**

        Whatever `cond_fn` returns.
        """
        should_terminate = self.cond_fn(state, **kwargs)
        return should_terminate


# The class body defines no explicit `__init__` (presumably it is generated by
# `eqx.Module` from the `cond_fn` field -- confirm), so the constructor's
# documentation is attached here, after the class has been created.
DiscreteTerminatingEvent.__init__.__doc__ = """**Arguments:**
- `cond_fn`: A function `(state, **kwargs) -> bool` that is evaluated on every step of
the differential equation solve. If it returns `True` then the solve is finished at
that timestep. `state` is a dataclass of the evolving state of the system,
including in particular the solution `state.y` at time `state.tprev`. Passed as
keyword arguments are the `terms`, `solver`, `args` etc. that are constant
throughout the solve.
"""


class SteadyStateEvent(AbstractDiscreteTerminatingEvent):
    """Terminates the solve once it reaches a steady state."""

    # Tolerances for the convergence test. `None` means "use the value from
    # the adaptive step size controller".
    rtol: Optional[float] = None
    atol: Optional[float] = None
    # Maps a PyTree to a scalar; applied to both the vector field and `y`.
    norm: Callable[[PyTree], Scalar] = rms_norm

    def __call__(self, state, *, terms, args, solver, stepsize_controller, **kwargs):
        """Check whether the vector field is small enough to stop the solve.

        **Arguments:**

        - `state`: the evolving state; `state.y` and `state.tprev` are read.
        - `terms`, `args`, `solver`: used to evaluate the vector field via
            `solver.func(terms, state.tprev, state.y, args)`.
        - `stepsize_controller`: source of `rtol`/`atol` when this event was
            created without them.
        - `**kwargs`: ignored.

        **Returns:**

        The result of `norm(vf) < atol + rtol * norm(state.y)`.

        **Raises:**

        `ValueError` if a tolerance is unspecified and `stepsize_controller`
        is not an adaptive step size controller.
        """
        del kwargs

        controller_has_tols = isinstance(
            stepsize_controller, AbstractAdaptiveStepSizeController
        )
        needs_controller_tols = self.rtol is None or self.atol is None
        if needs_controller_tols and not controller_has_tols:
            raise ValueError(
                "The `rtol` and `atol` tolerances for `SteadyStateEvent` default "
                "to the `rtol` and `atol` used with an adaptive step size "
                "controller (such as `diffrax.PIDController`). Either use an "
                "adaptive step size controller, or specify these tolerances "
                "manually."
            )
        rtol = stepsize_controller.rtol if self.rtol is None else self.rtol
        atol = stepsize_controller.atol if self.atol is None else self.atol

        # TODO: this makes an additional function evaluation that in practice has
        # probably already been made by the solver.
        vector_field = solver.func(terms, state.tprev, state.y, args)
        vector_field_size = self.norm(vector_field)
        threshold = atol + rtol * self.norm(state.y)
        return vector_field_size < threshold


# The class body defines no explicit `__init__` (presumably it is generated by
# `eqx.Module` from the `rtol`/`atol`/`norm` fields -- confirm), so the
# constructor's documentation is attached here, after the class has been created.
SteadyStateEvent.__init__.__doc__ = """**Arguments:**
- `rtol`: The relative tolerance for determining convergence. Defaults to the
same `rtol` as passed to an adaptive step controller if one is used.
- `atol`: The absolute tolerance for determining convergence. Defaults to the
same `atol` as passed to an adaptive step controller if one is used.
- `norm`: A function `PyTree -> Scalar`, which is called to determine whether
the vector field is close to zero.
"""
Loading

0 comments on commit 115997e

Please sign in to comment.