diff --git a/diffrax/__init__.py b/diffrax/__init__.py index d8aca8d4..7e33c4bf 100644 --- a/diffrax/__init__.py +++ b/diffrax/__init__.py @@ -71,6 +71,7 @@ Midpoint, MultiButcherTableau, Ralston, + Rattle, ReversibleHeun, SemiImplicitEuler, Sil3, diff --git a/diffrax/solver/__init__.py b/diffrax/solver/__init__.py index ace213c4..77652d35 100644 --- a/diffrax/solver/__init__.py +++ b/diffrax/solver/__init__.py @@ -24,6 +24,7 @@ from .midpoint import Midpoint from .milstein import ItoMilstein, StratonovichMilstein from .ralston import Ralston +from .rattle import Rattle from .reversible_heun import ReversibleHeun from .runge_kutta import ( AbstractDIRK, diff --git a/diffrax/solver/rattle.py b/diffrax/solver/rattle.py new file mode 100644 index 00000000..5a1e484a --- /dev/null +++ b/diffrax/solver/rattle.py @@ -0,0 +1,146 @@ +from typing import Callable, NamedTuple, Tuple + +import jax +import jax.numpy as jnp +import jax.tree_util as jtu +from equinox.internal import ω + +from ..custom_types import Array, Bool, DenseInfo, PyTree, Scalar +from ..local_interpolation import LocalLinearInterpolation +from ..solution import RESULTS +from ..term import AbstractTerm +from .base import AbstractImplicitSolver + + +_ErrorEstimate = None +_SolverState = None + +ConstrainFn = Callable[[PyTree], Array] + + +class RattleVars(NamedTuple): + p_1_2: PyTree # Midpoint momentum + q_1: PyTree # Midpoint position + p_1: PyTree # final momentum + lam: PyTree # Midpoint Lagrange multiplier (state) + mu: PyTree # final Lagrange multiplier (momentum) + + +class Rattle(AbstractImplicitSolver): + """Rattle method. + + 2nd order symplectic method with constrains `constrain(x)=0`. + + ??? cite "Reference" + + ```bibtex + @article{ANDERSEN198324, + title = {Rattle: A “velocity” version of the shake + algorithm for molecular dynamics calculations}, + journal = {Journal of Computational Physics}, + volume = {52}, + number = {1}, + pages = {24-34}, + year = {1983}, + author = {Hans C Andersen}, + } + ``` + """ + + term_structure = (AbstractTerm, AbstractTerm) + interpolation_cls = LocalLinearInterpolation + # Fix TypeError: non-default argument 'constrain' follows default argument + constrain: ConstrainFn = None + + def order(self, terms): + return 2 + + def init( + self, + terms: Tuple[AbstractTerm, AbstractTerm], + t0: Scalar, + t1: Scalar, + y0: PyTree, + args: PyTree, + ) -> _SolverState: + return None + + def step( + self, + terms: Tuple[AbstractTerm, AbstractTerm], + t0: Scalar, + t1: Scalar, + y0: Tuple[PyTree, PyTree], + args: PyTree, + solver_state: _SolverState, + made_jump: Bool, + ) -> Tuple[Tuple[PyTree, PyTree], _ErrorEstimate, DenseInfo, _SolverState, RESULTS]: + del solver_state, made_jump + + term_1, term_2 = terms + midpoint = (t1 + t0) / 2 + + control1_half_1 = term_1.contr(t0, midpoint) + control1_half_2 = term_1.contr(midpoint, t1) + + control2_half_1 = term_2.contr(t0, midpoint) + control2_half_2 = term_2.contr(midpoint, t1) + + p0, q0 = y0 + + def eq(x: RattleVars, args=None): + _, vjp_fun = jax.vjp(self.constrain, q0) + _, vjp_fun_mu = jax.vjp(self.constrain, x.q_1) + + zero = ( + ( + p0**ω + - control1_half_1 * (vjp_fun(x.lam)[0]) ** ω + + term_1.vf_prod(t0, q0, args, control1_half_1) ** ω + - x.p_1_2**ω + ).ω, + ( + q0**ω + + term_2.vf_prod(t0, x.p_1_2, args, control2_half_1) ** ω + + term_2.vf_prod(midpoint, x.p_1_2, args, control2_half_2) ** ω + - x.q_1**ω + ).ω, + self.constrain(x.q_1), + ( + x.p_1_2**ω + + term_1.vf_prod(midpoint, x.q_1, args, control1_half_2) ** ω + - (control1_half_2 * vjp_fun_mu(x.mu)[0] ** ω) + - x.p_1**ω + ).ω, + jax.jvp(self.constrain, (x.q_1,), (term_2.vf(t1, x.p_1, args),))[1], + ) + return zero + + cs = jax.eval_shape(self.constrain, q0) + + init_vars = RattleVars( + p_1_2=p0, + q_1=(q0**ω * 2).ω, + p_1=p0, + lam=jtu.tree_map(jnp.zeros_like, cs), + mu=jtu.tree_map(jnp.zeros_like, cs), + ) + + sol = self.nonlinear_solver(eq, init_vars, None) + + y1 = (sol.root.p_1, sol.root.q_1) + dense_info = dict(y0=y0, y1=y1) + return y1, None, dense_info, None, RESULTS.successful + + def func( + self, + terms: Tuple[AbstractTerm, AbstractTerm], + t0: Scalar, + y0: Tuple[PyTree, PyTree], + args: PyTree, + ) -> Tuple[PyTree, PyTree]: + term_1, term_2 = terms + y0_1, y0_2 = y0 + f1 = term_1.func(t0, y0_2, args) + f2 = term_2.func(t0, y0_1, args) + return (f1, f2) diff --git a/test/test_solver.py b/test/test_solver.py index 67ed0701..ff51c375 100644 --- a/test/test_solver.py +++ b/test/test_solver.py @@ -470,3 +470,41 @@ def vector_field(t, y, args): return out.ys[0] f(1.0) + + +def test_rattle(): + import numpy as np + + def constrain(q): + return jnp.sqrt(jnp.sum(q**2, keepdims=True)) - 1.0 + + rat = diffrax.Rattle( + nonlinear_solver=diffrax.NewtonNonlinearSolver(rtol=1e-4, atol=1e-6), + constrain=constrain, + ) + + # Potential free movement on a circle + def H(p, q): + del q + return p @ p.T / 2.0 + + # V = p^2/2m m=1, v=1 + + terms = ( + diffrax.ODETerm( + lambda t, q, args: -jax.grad(H, argnums=1)(jnp.zeros_like(q), q) + ), + diffrax.ODETerm( + lambda t, p, args: jax.grad(H, argnums=0)(p, jnp.zeros_like(p)) + ), + ) + # p,q + y0 = (jnp.asarray([1.0, 0.0]), jnp.asarray([0.0, 1.0])) + t1 = 2 * jnp.pi / 4 + n = 2**12 + dt = t1 / n + saveat = diffrax.SaveAt(t1=True) + solution = diffrax.diffeqsolve(terms, rat, 0.0, t1, dt0=dt, y0=y0, saveat=saveat) + p1, q1 = solution.ys + assert np.allclose(p1, jnp.asarray([0.0, -1.0]), rtol=1e-4, atol=1e-4) + assert np.allclose(q1, jnp.asarray([1.0, 0.0]), rtol=1e-4, atol=1e-4)