Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Stochastic Theta method #498

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions diffrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@
SPaRK as SPaRK,
SRA1 as SRA1,
StochasticButcherTableau as StochasticButcherTableau,
StochasticTheta as StochasticTheta,
StratonovichMilstein as StratonovichMilstein,
Tsit5 as Tsit5,
)
Expand Down
1 change: 1 addition & 0 deletions diffrax/_solver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,5 @@
AbstractSRK as AbstractSRK,
StochasticButcherTableau as StochasticButcherTableau,
)
from .stochastic_theta import StochasticTheta as StochasticTheta
from .tsit5 import Tsit5 as Tsit5
145 changes: 145 additions & 0 deletions diffrax/_solver/stochastic_theta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from collections.abc import Callable
from typing import ClassVar
from typing_extensions import TypeAlias

import optimistix as optx
from equinox.internal import ω

from .._custom_types import Args, BoolScalarLike, DenseInfo, RealScalarLike, VF, Y
from .._local_interpolation import LocalLinearInterpolation
from .._root_finder import with_stepsize_controller_tols
from .._solution import RESULTS
from .._term import AbstractTerm, MultiTerm, ODETerm
from .base import AbstractAdaptiveSolver, AbstractImplicitSolver, AbstractItoSolver


# The stochastic theta method carries no state between steps.
_SolverState: TypeAlias = None


def _implicit_relation(z1, nonlinear_solve_args):
    """Residual of the implicit theta step, passed to the root finder.

    Finds the increment `z1` satisfying
    `z1 = theta * drift(t1, y0 + z1) * dt + (1 - theta) * k0_drift + k0_diff`,
    i.e. returns the left-hand side minus the right-hand side (a pytree of
    zeros at the solution).
    """
    (
        drift_vf_prod,
        t1,
        y0,
        args,
        drift_control,
        k0_drift,
        k0_diff,
        theta,
    ) = nonlinear_solve_args
    # Candidate end-point of the step.
    y_candidate = (y0**ω + z1**ω).ω
    # Implicitly-weighted drift evaluated at the candidate end-point.
    implicit_part = (drift_vf_prod(t1, y_candidate, args, drift_control) ** ω * theta).ω
    # Explicit (Euler) drift contribution with the complementary weight.
    explicit_part = ((1 - theta) * k0_drift**ω).ω
    rhs = (implicit_part**ω + explicit_part**ω + k0_diff**ω).ω
    return (z1**ω - rhs**ω).ω


class StochasticTheta(
    AbstractImplicitSolver,
    AbstractAdaptiveSolver,
    AbstractItoSolver,
):
    r"""Stochastic theta method.

    An A-stable SDIRK method for Itô SDEs, of strong order 0.5 (weak order
    1.0). Takes an implicit step in the drift, weighted by `theta`, and an
    explicit Euler step in the diffusion. Uses 1 stage. Has an embedded
    comparison against an explicit Euler step for adaptive step sizing. Uses a
    1st order local linear interpolation for dense/ts output.

    !!! warning

        If `theta` is 0, this results in an explicit Euler step, which is also how the
        error estimate is computed (which would result in estimated error being 0).

    ??? cite "Reference"

        ```bibtex
        @article{higham2000mean,
            title={Mean-square and asymptotic stability of the stochastic theta method},
            author={Higham, Desmond J},
            journal={SIAM journal on numerical analysis},
            volume={38},
            number={3},
            pages={753--769},
            year={2000},
            publisher={SIAM}
        }
        ```
    """

    # Implicitness weight on the drift: 0 gives explicit Euler, 1 gives
    # implicit Euler, 0.5 gives the stochastic trapezoidal rule.
    theta: float
    # Expects a MultiTerm of (drift as an ODETerm, diffusion term).
    term_structure: ClassVar = MultiTerm[tuple[ODETerm, AbstractTerm]]
    interpolation_cls: ClassVar[
        Callable[..., LocalLinearInterpolation]
    ] = LocalLinearInterpolation
    # Root finder used for the implicit relation; by default its tolerances
    # are taken from the step size controller.
    root_finder: optx.AbstractRootFinder = with_stepsize_controller_tols(optx.Chord)()
    # Cap on root-finder iterations per step; exceeding it reports failure
    # through the step's RESULTS rather than raising.
    root_find_max_steps: int = 10

    def order(self, terms):
        # Order when used as an ODE solver.
        return 1

    def error_order(self, terms):
        # Order of the embedded error estimate, used by adaptive controllers.
        return 1.0

    def strong_order(self, terms):
        # Strong (pathwise) order of convergence for SDEs.
        return 0.5

    def init(
        self,
        terms: AbstractTerm,
        t0: RealScalarLike,
        t1: RealScalarLike,
        y0: Y,
        args: Args,
    ) -> _SolverState:
        # Stateless solver: nothing to set up.
        return None

    def step(
        self,
        terms: MultiTerm[tuple[ODETerm, AbstractTerm]],
        t0: RealScalarLike,
        t1: RealScalarLike,
        y0: Y,
        args: Args,
        solver_state: _SolverState,
        made_jump: BoolScalarLike,
    ) -> tuple[Y, Y, DenseInfo, _SolverState, RESULTS]:
        """Make a single step of the stochastic theta method over [t0, t1]."""
        del made_jump
        # control[0] is the drift control (the time increment); control[1] is
        # the diffusion control (e.g. the Brownian increment over [t0, t1]).
        control = terms.contr(t0, t1)
        # Explicit (Euler) evaluations of drift and diffusion at (t0, y0).
        k0_drift = terms.terms[0].vf_prod(t0, y0, args, control[0])
        k0_diff = terms.terms[1].vf_prod(t0, y0, args, control[1])
        root_args = (
            terms.terms[0].vf_prod,
            t1,
            y0,
            args,
            control[0],
            k0_drift,
            k0_diff,
            self.theta,
        )
        # Solve the implicit relation for the step increment, using the
        # explicit drift evaluation as the initial guess. `throw=False` so a
        # failed solve is reported via `result` instead of raising.
        nonlinear_sol = optx.root_find(
            _implicit_relation,
            self.root_finder,
            k0_drift,
            root_args,
            throw=False,
            max_steps=self.root_find_max_steps,
        )
        k1 = nonlinear_sol.value
        y1 = (y0**ω + k1**ω).ω
        # Error estimate: half the difference between this step's increment
        # and an explicit Euler increment. (For theta=1 this coincides with
        # the difference against a trapezoidal step; used only for adaptive
        # step sizing, not to correct y1.)
        k0 = (k0_drift**ω + k0_diff**ω).ω
        y_error = (0.5 * (k1**ω - k0**ω)).ω
        dense_info = dict(y0=y0, y1=y1)
        solver_state = None
        # Propagate root-finder success/failure into the integrator's results.
        result = RESULTS.promote(nonlinear_sol.result)
        return y1, y_error, dense_info, solver_state, result

    def func(
        self,
        terms: AbstractTerm,
        t0: RealScalarLike,
        y0: Y,
        args: Args,
    ) -> VF:
        """Evaluate the vector field at (t0, y0)."""
        return terms.vf(t0, y0, args)
10 changes: 10 additions & 0 deletions docs/api/solvers/sde_solvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,16 @@ These solvers can be used to solve SDEs just as well as they can be used to solv
selection:
members: false


### Implicit Runge--Kutta (IRK) methods

These are implicit Runge--Kutta methods designed specifically for SDEs.

::: diffrax.StochasticTheta
selection:
members: false


### Stochastic Runge--Kutta (SRK)

These are a particularly important class of SDE-only solvers.
Expand Down
31 changes: 31 additions & 0 deletions test/test_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,37 @@ def test_implicit_euler_adaptive():
assert out2.result == diffrax.RESULTS.successful


def test_stochastic_theta_adaptive():
    # Integrate dy = -10 y^3 dt + dW over [0, 1] with an (over-large) initial
    # step of 1.
    t0, t1, dt0 = 0, 1, 1
    y0 = 1.0

    drift = diffrax.ODETerm(lambda t, y, args: -10 * y**3)
    brownian = diffrax.VirtualBrownianTree(t0, t1, 1e-5, (1,), key=jax.random.key(0))
    diffusion = diffrax.ControlTerm(lambda t, y, args: jnp.array([1.0]), brownian)
    terms = diffrax.MultiTerm(drift, diffusion)

    # With a fixed step the nonlinear solve should fail to converge...
    fixed_step_solver = diffrax.StochasticTheta(
        1.0, root_finder=diffrax.VeryChord(rtol=1e-5, atol=1e-5)
    )
    out1 = diffrax.diffeqsolve(terms, fixed_step_solver, t0, t1, dt0, y0, throw=False)
    assert out1.result == diffrax.RESULTS.nonlinear_divergence

    # ...whereas adaptive stepping should shrink the step and succeed.
    adaptive_solver = diffrax.StochasticTheta(1.0)
    out2 = diffrax.diffeqsolve(
        terms,
        adaptive_solver,
        t0,
        t1,
        dt0,
        y0,
        stepsize_controller=diffrax.PIDController(rtol=1e-5, atol=1e-5),
        throw=False,
    )
    assert out2.result == diffrax.RESULTS.successful


class _DoubleDopri5(diffrax.AbstractRungeKutta):
tableau: ClassVar[diffrax.MultiButcherTableau] = diffrax.MultiButcherTableau(
diffrax.Dopri5.tableau, diffrax.Dopri5.tableau
Expand Down
Loading