Skip to content

Commit

Permalink
PID for complex dtype fixes (#391)
Browse files Browse the repository at this point in the history
* Fix complex casting and types issues

* Dependency version
  • Loading branch information
Randl authored Apr 22, 2024
1 parent 322852a commit 41c58a6
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 30 deletions.
10 changes: 6 additions & 4 deletions diffrax/_solver/tsit5.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections.abc import Callable
from typing import ClassVar, Optional

import jax
import jax.numpy as jnp
import numpy as np
from equinox.internal import ω
Expand Down Expand Up @@ -147,10 +148,11 @@ def evaluate(
)
b6 = -34.87065786149660974 * (t - 1.2) * (t - 0.666666666666666667) * t**2
b7 = 2.5 * (t - 1) * (t - 0.6) * t**2
return (
self.y0**ω
+ vector_tree_dot(jnp.stack([b1, b2, b3, b4, b5, b6, b7]), self.k) ** ω
).ω
with jax.numpy_dtype_promotion("standard"):
return (
self.y0**ω
+ vector_tree_dot(jnp.stack([b1, b2, b3, b4, b5, b6, b7]), self.k) ** ω
).ω


class Tsit5(AbstractERK):
Expand Down
12 changes: 8 additions & 4 deletions diffrax/_step_size_controller/adaptive.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import equinox as eqx
import equinox.internal as eqxi
import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.tree_util as jtu
Expand All @@ -18,6 +19,7 @@
from equinox import AbstractVar
from equinox.internal import ω
from jaxtyping import Array, PyTree
from lineax.internal import complex_to_real_dtype

from .._custom_types import (
Args,
Expand Down Expand Up @@ -460,8 +462,8 @@ def init(
jump_next_step,
at_dtmin,
dt0,
jnp.array(1.0, dtype=y_dtype),
jnp.array(1.0, dtype=y_dtype),
jnp.array(1.0, dtype=complex_to_real_dtype(y_dtype)),
jnp.array(1.0, dtype=complex_to_real_dtype(y_dtype)),
)

def adapt_step_size(
Expand Down Expand Up @@ -569,7 +571,8 @@ def _scale(_y0, _y1_candidate, _y_error):
_nan = jnp.isnan(_y1_candidate).any()
_y1_candidate = jnp.where(_nan, _y0, _y1_candidate)
_y = jnp.maximum(jnp.abs(_y0), jnp.abs(_y1_candidate))
return _y_error / (self.atol + _y * self.rtol)
with jax.numpy_dtype_promotion("standard"):
return _y_error / (self.atol + _y * self.rtol)

scaled_error = self.norm(jtu.tree_map(_scale, y0, y1_candidate, y_error))
keep_step = scaled_error < 1
Expand Down Expand Up @@ -607,7 +610,8 @@ def _scale(_y0, _y1_candidate, _y_error):
# a grad API boundary as part of a larger model.)
factor = lax.stop_gradient(factor)
factor = eqxi.nondifferentiable(factor)
dt = prev_dt * factor
with jax.numpy_dtype_promotion("standard"):
dt = prev_dt * factor.astype(prev_dt)

# E.g. we failed an implicit step, so y_error=inf, so inv_scaled_error=0,
# so factor=factormin, and we shrunk our step.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ classifiers = [
"Topic :: Scientific/Engineering :: Mathematics",
]
urls = {repository = "https://github.com/patrick-kidger/diffrax" }
dependencies = ["jax>=0.4.23", "jaxtyping>=0.2.24", "typing_extensions>=4.5.0", "typeguard==2.13.3", "equinox>=0.11.2", "lineax>=0.0.4", "optimistix>=0.0.6"]
dependencies = ["jax>=0.4.23", "jaxtyping>=0.2.24", "typing_extensions>=4.5.0", "typeguard==2.13.3", "equinox>=0.11.2", "lineax>=0.0.5", "optimistix>=0.0.6"]

[build-system]
requires = ["hatchling"]
Expand Down
40 changes: 19 additions & 21 deletions test/test_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,27 +45,25 @@ def _all_pairs(*args):


@pytest.mark.parametrize(
"solver,t_dtype,y_dtype,treedef,stepsize_controller",
_all_pairs(
dict(
default=diffrax.Euler(),
opts=(
diffrax.LeapfrogMidpoint(),
diffrax.ReversibleHeun(),
diffrax.Tsit5(),
diffrax.ImplicitEuler(
root_finder=diffrax.VeryChord(rtol=1e-3, atol=1e-6)
),
diffrax.Kvaerno3(root_finder=diffrax.VeryChord(rtol=1e-3, atol=1e-6)),
),
),
dict(default=jnp.float32, opts=(int, float, jnp.int32)),
dict(default=jnp.float32, opts=(jnp.complex64,)),
dict(default=treedefs[0], opts=treedefs[1:]),
dict(
default=diffrax.ConstantStepSize(),
opts=(diffrax.PIDController(rtol=1e-5, atol=1e-8),),
),
"solver",
(
diffrax.Euler(),
diffrax.LeapfrogMidpoint(),
diffrax.ReversibleHeun(),
diffrax.Tsit5(),
diffrax.ImplicitEuler(root_finder=diffrax.VeryChord(rtol=1e-3, atol=1e-6)),
diffrax.Kvaerno3(root_finder=diffrax.VeryChord(rtol=1e-3, atol=1e-6)),
),
)
@pytest.mark.parametrize("t_dtype", (jnp.float32, int, float, jnp.int32))
@pytest.mark.parametrize("y_dtype", (jnp.float32, jnp.complex64))
@pytest.mark.parametrize("treedef", treedefs)
@pytest.mark.parametrize(
"stepsize_controller",
(
diffrax.ConstantStepSize(),
diffrax.PIDController(rtol=1e-5, atol=1e-8),
diffrax.PIDController(rtol=1e-5, atol=1e-8, pcoeff=0.3, icoeff=0.3, dcoeff=0.0),
),
)
def test_basic(solver, t_dtype, y_dtype, treedef, stepsize_controller, getkey):
Expand Down

0 comments on commit 41c58a6

Please sign in to comment.