Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RATTLE solver #328

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions diffrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
Midpoint,
MultiButcherTableau,
Ralston,
Rattle,
ReversibleHeun,
SemiImplicitEuler,
Sil3,
Expand Down
1 change: 1 addition & 0 deletions diffrax/solver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from .midpoint import Midpoint
from .milstein import ItoMilstein, StratonovichMilstein
from .ralston import Ralston
from .rattle import Rattle
from .reversible_heun import ReversibleHeun
from .runge_kutta import (
AbstractDIRK,
Expand Down
146 changes: 146 additions & 0 deletions diffrax/solver/rattle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
from typing import Callable, NamedTuple, Tuple

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from equinox.internal import ω

from ..custom_types import Array, Bool, DenseInfo, PyTree, Scalar
from ..local_interpolation import LocalLinearInterpolation
from ..solution import RESULTS
from ..term import AbstractTerm
from .base import AbstractImplicitSolver


_ErrorEstimate = None
_SolverState = None

ConstrainFn = Callable[[PyTree], Array]


class RattleVars(NamedTuple):
p_1_2: PyTree # Midpoint momentum
q_1: PyTree # Midpoint position
p_1: PyTree # final momentum
lam: PyTree # Midpoint Lagrange multiplier (state)
mu: PyTree # final Lagrange multiplier (momentum)


class Rattle(AbstractImplicitSolver):
"""Rattle method.

2nd order symplectic method with constrains `constrain(x)=0`.

??? cite "Reference"

```bibtex
@article{ANDERSEN198324,
title = {Rattle: A “velocity” version of the shake
algorithm for molecular dynamics calculations},
journal = {Journal of Computational Physics},
volume = {52},
number = {1},
pages = {24-34},
year = {1983},
author = {Hans C Andersen},
}
```
"""

term_structure = (AbstractTerm, AbstractTerm)
interpolation_cls = LocalLinearInterpolation
# Fix TypeError: non-default argument 'constrain' follows default argument
constrain: ConstrainFn = None

def order(self, terms):
return 2

def init(
self,
terms: Tuple[AbstractTerm, AbstractTerm],
t0: Scalar,
t1: Scalar,
y0: PyTree,
args: PyTree,
) -> _SolverState:
return None

def step(
self,
terms: Tuple[AbstractTerm, AbstractTerm],
t0: Scalar,
t1: Scalar,
y0: Tuple[PyTree, PyTree],
args: PyTree,
solver_state: _SolverState,
made_jump: Bool,
) -> Tuple[Tuple[PyTree, PyTree], _ErrorEstimate, DenseInfo, _SolverState, RESULTS]:
del solver_state, made_jump

term_1, term_2 = terms
midpoint = (t1 + t0) / 2

control1_half_1 = term_1.contr(t0, midpoint)
control1_half_2 = term_1.contr(midpoint, t1)

control2_half_1 = term_2.contr(t0, midpoint)
control2_half_2 = term_2.contr(midpoint, t1)

p0, q0 = y0

def eq(x: RattleVars, args=None):
_, vjp_fun = jax.vjp(self.constrain, q0)
_, vjp_fun_mu = jax.vjp(self.constrain, x.q_1)

zero = (
(
p0**ω
- control1_half_1 * (vjp_fun(x.lam)[0]) ** ω
+ term_1.vf_prod(t0, q0, args, control1_half_1) ** ω
- x.p_1_2**ω
).ω,
(
q0**ω
+ term_2.vf_prod(t0, x.p_1_2, args, control2_half_1) ** ω
+ term_2.vf_prod(midpoint, x.p_1_2, args, control2_half_2) ** ω
- x.q_1**ω
).ω,
self.constrain(x.q_1),
(
x.p_1_2**ω
+ term_1.vf_prod(midpoint, x.q_1, args, control1_half_2) ** ω
- (control1_half_2 * vjp_fun_mu(x.mu)[0] ** ω)
- x.p_1**ω
).ω,
jax.jvp(self.constrain, (x.q_1,), (term_2.vf(t1, x.p_1, args),))[1],
)
return zero

cs = jax.eval_shape(self.constrain, q0)

init_vars = RattleVars(
p_1_2=p0,
q_1=(q0**ω * 2).ω,
p_1=p0,
lam=jtu.tree_map(jnp.zeros_like, cs),
mu=jtu.tree_map(jnp.zeros_like, cs),
)

sol = self.nonlinear_solver(eq, init_vars, None)

y1 = (sol.root.p_1, sol.root.q_1)
dense_info = dict(y0=y0, y1=y1)
return y1, None, dense_info, None, RESULTS.successful

def func(
self,
terms: Tuple[AbstractTerm, AbstractTerm],
t0: Scalar,
y0: Tuple[PyTree, PyTree],
args: PyTree,
) -> Tuple[PyTree, PyTree]:
term_1, term_2 = terms
y0_1, y0_2 = y0
f1 = term_1.func(t0, y0_2, args)
f2 = term_2.func(t0, y0_1, args)
return (f1, f2)
38 changes: 38 additions & 0 deletions test/test_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,3 +470,41 @@ def vector_field(t, y, args):
return out.ys[0]

f(1.0)


def test_rattle():
import numpy as np

def constrain(q):
return jnp.sqrt(jnp.sum(q**2, keepdims=True)) - 1.0

rat = diffrax.Rattle(
nonlinear_solver=diffrax.NewtonNonlinearSolver(rtol=1e-4, atol=1e-6),
constrain=constrain,
)

# Potential free movement on a circle
def H(p, q):
del q
return p @ p.T / 2.0

# V = p^2/2m m=1, v=1

terms = (
diffrax.ODETerm(
lambda t, q, args: -jax.grad(H, argnums=1)(jnp.zeros_like(q), q)
),
diffrax.ODETerm(
lambda t, p, args: jax.grad(H, argnums=0)(p, jnp.zeros_like(p))
),
)
# p,q
y0 = (jnp.asarray([1.0, 0.0]), jnp.asarray([0.0, 1.0]))
t1 = 2 * jnp.pi / 4
n = 2**12
dt = t1 / n
saveat = diffrax.SaveAt(t1=True)
solution = diffrax.diffeqsolve(terms, rat, 0.0, t1, dt0=dt, y0=y0, saveat=saveat)
p1, q1 = solution.ys
assert np.allclose(p1, jnp.asarray([0.0, -1.0]), rtol=1e-4, atol=1e-4)
assert np.allclose(q1, jnp.asarray([1.0, 0.0]), rtol=1e-4, atol=1e-4)