Skip to content

Commit

Permalink
Complex2 fix (#330)
Browse files Browse the repository at this point in the history
* tweaked dtypes such that complex types are supported

* added tests for complex dtype support

* Fixed complex dtype support after merge

* Fixed issue with float64 and weak dtypes, see discussion at: #197 (comment)

* Added warning and error checking for complex dtypes

* Fix minor complex-related problems

* Nit fix

* fixed pre-commit

* Delete .idea directory

---------

Co-authored-by: Timon Hoess <[email protected]>
Co-authored-by: Evgenii Zheltonozhskii <[email protected]>
  • Loading branch information
3 people authored Nov 7, 2023
1 parent 712c208 commit 017758d
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 15 deletions.
25 changes: 22 additions & 3 deletions diffrax/integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .saveat import SaveAt, SubSaveAt
from .solution import is_okay, is_successful, RESULTS, Solution
from .solver import (
AbstractImplicitSolver,
AbstractItoSolver,
AbstractSolver,
AbstractStratonovichSolver,
Expand Down Expand Up @@ -605,6 +606,18 @@ def diffeqsolve(
pred = (t1 - t0) * dt0 < 0
dt0 = eqxi.error_if(jnp.array(dt0), pred, msg)

# Error checking and warning for complex dtypes
if any(jtu.tree_leaves(jtu.tree_map(jnp.iscomplexobj, y0))):
if isinstance(solver, AbstractImplicitSolver):
raise ValueError(
"Implicit solvers in conjunction with complex dtypes is currently not "
"supported."
)
warnings.warn(
"Complex dtype support is work in progress, please read "
"https://github.com/patrick-kidger/diffrax/pull/197 and proceed carefully."
)

# Backward compatibility
if isinstance(
solver, (EulerHeun, ItoMilstein, StratonovichMilstein)
Expand Down Expand Up @@ -664,8 +677,10 @@ def _get_subsaveat_ts(saveat):
)

# Time will affect state, so need to promote the state dtype as well if necessary.
# fixing issue with float64 and weak dtypes, see discussion at:
# https://github.com/patrick-kidger/diffrax/pull/197#discussion_r1130173527
def _promote(yi):
_dtype = jnp.result_type(yi, *timelikes) # noqa: F821
_dtype = jnp.result_type(yi, dtype) # noqa: F821
return jnp.asarray(yi, dtype=_dtype)

y0 = jtu.tree_map(_promote, y0)
Expand Down Expand Up @@ -759,7 +774,9 @@ def _allocate_output(subsaveat: SubSaveAt) -> SaveState:
save_index = 0
ts = jnp.full(out_size, jnp.inf)
struct = eqx.filter_eval_shape(subsaveat.fn, t0, y0, args)
ys = jtu.tree_map(lambda y: jnp.full((out_size,) + y.shape, jnp.inf), struct)
ys = jtu.tree_map(
lambda y: jnp.full((out_size,) + y.shape, jnp.inf, dtype=y.dtype), struct
)
return SaveState(
ts=ts, ys=ys, save_index=save_index, saveat_ts_index=saveat_ts_index
)
Expand All @@ -779,7 +796,9 @@ def _allocate_output(subsaveat: SubSaveAt) -> SaveState:
solver.step, terms, tprev, tnext, y0, args, solver_state, made_jump
)
dense_ts = jnp.full(max_steps + 1, jnp.inf)
_make_full = lambda x: jnp.full((max_steps,) + jnp.shape(x), jnp.inf)
_make_full = lambda x: jnp.full(
(max_steps,) + jnp.shape(x), jnp.inf, dtype=x.dtype
)
dense_infos = jtu.tree_map(_make_full, dense_info)
dense_save_index = 0
else:
Expand Down
4 changes: 2 additions & 2 deletions diffrax/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ def _rms_norm_jvp(x, tx):
pred = (out == 0) | jnp.isinf(out)
numerator = jnp.where(pred, 0, x)
denominator = jnp.where(pred, 1, out * x.size)
t_out = jnp.dot(numerator / denominator, tx)
return out, t_out
t_out = jnp.dot(numerator / denominator, jnp.conj(tx))
return out, jnp.real(t_out)


def adjoint_rms_seminorm(x: Tuple[PyTree, PyTree, PyTree, PyTree]) -> Scalar:
Expand Down
4 changes: 3 additions & 1 deletion diffrax/nonlinear_solver/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,4 +106,6 @@ def jac(fn: Callable, x: PyTree, args: PyTree) -> LU_Jacobian:
if not jnp.issubdtype(flat, jnp.inexact):
# Handle integer arguments
flat = flat.astype(jnp.float32)
return jsp.linalg.lu_factor(jax.jacfwd(curried)(flat))
return jsp.linalg.lu_factor(
jax.jacfwd(curried, holomorphic=jnp.iscomplexobj(flat))(flat)
)
2 changes: 1 addition & 1 deletion diffrax/step_size_controller/adaptive.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ def adapt_step_size(
# ε_n = atol + norm(y) * rtol with y on the nth step
# r_n = norm(y_error) with y_error on the nth step
# δ_{n,m} = norm(y_error / (atol + norm(y) * rtol))^(-1) with y_error on the nth
# step and y on the mth step
# step and y on the mth step
# β_1 = pcoeff + icoeff + dcoeff
# β_2 = -(pcoeff + 2 * dcoeff)
# β_3 = dcoeff
Expand Down
4 changes: 2 additions & 2 deletions test/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,14 @@ def implicit_tol(solver):
return solver


def random_pytree(key, treedef):
def random_pytree(key, treedef, dtype=None):
keys = jrandom.split(key, treedef.num_leaves)
leaves = []
for key in keys:
dimkey, sizekey, valuekey = jrandom.split(key, 3)
num_dims = jrandom.randint(dimkey, (), 0, 5)
dim_sizes = jrandom.randint(sizekey, (num_dims,), 0, 5)
value = jrandom.normal(valuekey, dim_sizes)
value = jrandom.normal(valuekey, dim_sizes, dtype=dtype)
leaves.append(value)
return jtu.tree_unflatten(treedef, leaves)

Expand Down
26 changes: 20 additions & 6 deletions test/test_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def _all_pairs(*args):


@pytest.mark.parametrize(
"solver,t_dtype,treedef,stepsize_controller",
"solver,t_dtype,y_dtype,treedef,stepsize_controller",
_all_pairs(
dict(
default=diffrax.Euler(),
Expand All @@ -58,21 +58,32 @@ def _all_pairs(*args):
),
),
dict(default=jnp.float32, opts=(int, float, jnp.int32)),
dict(default=jnp.float32, opts=(jnp.complex64,)),
dict(default=treedefs[0], opts=treedefs[1:]),
dict(
default=diffrax.ConstantStepSize(),
opts=(diffrax.PIDController(rtol=1e-3, atol=1e-6),),
),
),
)
def test_basic(solver, t_dtype, treedef, stepsize_controller, getkey):
def test_basic(solver, t_dtype, y_dtype, treedef, stepsize_controller, getkey):
if not isinstance(solver, diffrax.AbstractAdaptiveSolver) and isinstance(
stepsize_controller, diffrax.PIDController
):
return

def f(t, y, args):
return jtu.tree_map(operator.neg, y)
if jnp.iscomplexobj(y_dtype):

def f(t, y, args):
return jtu.tree_map(lambda _y: operator.mul(-1j, _y), y)

if isinstance(solver, diffrax.AbstractImplicitSolver):
return

else:

def f(t, y, args):
return jtu.tree_map(operator.neg, y)

if t_dtype is int:
t0 = 0
Expand All @@ -92,7 +103,7 @@ def f(t, y, args):
dt0 = jnp.array(0.01)
else:
raise ValueError
y0 = random_pytree(getkey(), treedef)
y0 = random_pytree(getkey(), treedef, dtype=y_dtype)
try:
sol = diffrax.diffeqsolve(
diffrax.ODETerm(f),
Expand All @@ -113,7 +124,10 @@ def f(t, y, args):
else:
raise
y1 = sol.ys
true_y1 = jtu.tree_map(lambda x: (x * math.exp(-1))[None], y0)
if jnp.iscomplexobj(y_dtype):
true_y1 = jtu.tree_map(lambda x: (x * jnp.exp(-1j))[None], y0)
else:
true_y1 = jtu.tree_map(lambda x: (x * math.exp(-1))[None], y0)
assert shaped_allclose(y1, true_y1, atol=1e-2, rtol=1e-2)


Expand Down

0 comments on commit 017758d

Please sign in to comment.