Skip to content

Commit

Permalink
Fix raveling
Browse files Browse the repository at this point in the history
  • Loading branch information
Randl committed Nov 19, 2024
1 parent b0f0e29 commit b55ec0e
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 116 deletions.
18 changes: 18 additions & 0 deletions lineax/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,13 @@ def structure_equal(x, y) -> bool:
return eqx.tree_equal(x, y) is True


def is_complex_structure(structure):
    """Return whether any leaf of `structure` has a complex floating dtype.

    **Arguments:**

    - `structure`: a PyTree whose leaves are accepted by `jnp.result_type`, e.g.
        `jax.ShapeDtypeStruct`s or arrays.

    **Returns:**

    A Python `bool`. A PyTree without leaves is reported as not complex.
    """
    # Inspect each leaf on its own instead of promoting all leaves together:
    # `jnp.result_type()` raises when called without arguments (PyTree with no
    # leaves), and promoting mixed real/complex leaves raises under strict dtype
    # promotion. The promoted dtype is complex exactly when some leaf is complex,
    # so the answer is unchanged wherever the joint promotion succeeds.
    return any(
        jnp.isdtype(jnp.result_type(leaf), "complex floating")
        for leaf in jax.tree_util.tree_leaves(structure)
    )


def complex_to_real_structure(in_structure):
return jtu.tree_map(
lambda x: ShapeDtypeStruct(
Expand All @@ -124,6 +131,17 @@ def complex_to_real_structure(in_structure):
)


def complex_to_real_tree(x, in_structure):
    """Split the complex leaves of `x` into real and imaginary parts.

    **Arguments:**

    - `x`: a PyTree of arrays.
    - `in_structure`: a PyTree with the same tree structure as `x` whose leaves
        expose a `dtype` attribute. A leaf of `x` is converted exactly when the
        matching leaf here has a complex floating dtype.

    **Returns:**

    A PyTree with the same tree structure as `x`. Each converted leaf is replaced
    by `jnp.stack([leaf.real, leaf.imag], axis=-1)`, i.e. it gains a trailing axis
    of size 2; every other leaf is returned unchanged.
    """

    def _split(leaf, struct):
        # The decision is driven by the declared structure, not by `leaf.dtype`.
        if jnp.isdtype(struct.dtype, "complex floating"):
            return jnp.stack([leaf.real, leaf.imag], axis=-1)
        return leaf

    # Kept from the original: run the conversion under standard dtype promotion.
    with jax.numpy_dtype_promotion("standard"):
        return jtu.tree_map(_split, x, in_structure)


def real_to_complex_tree(x, in_structure):
with jax.numpy_dtype_promotion("standard"):
return jtu.tree_map(
Expand Down
13 changes: 4 additions & 9 deletions lineax/_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
complex_to_real_structure,
default_floating_dtype,
inexact_asarray,
is_complex_structure,
jacobian,
NoneAux,
real_to_complex_tree,
Expand Down Expand Up @@ -1324,15 +1325,9 @@ def _(operator):

@materialise.register(FunctionLinearOperator)
def _(operator):
complex_input = jnp.isdtype(
jnp.result_type(*(jax.tree.flatten(operator.in_structure())[0])),
"complex floating",
)
real_output = not jnp.isdtype(
jnp.result_type(*(jax.tree.flatten(operator.out_structure())[0])),
"complex floating",
)
if complex_input and real_output:
if is_complex_structure(operator.in_structure()) and not is_complex_structure(
operator.out_structure()
):
# We'll use R^2->R representation for C->R function.
in_structure = complex_to_real_structure(operator.in_structure())

Expand Down
22 changes: 18 additions & 4 deletions lineax/_solver/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@

from .._misc import (
complex_to_real_structure,
complex_to_real_tree,
is_complex_structure,
real_to_complex_tree,
strip_weak_dtype,
structure_equal,
Expand Down Expand Up @@ -86,11 +88,14 @@ def ravel_vector(
pytree: PyTree[Array], packed_structures: PackedStructures
) -> Shaped[Array, " size"]:
leaves, treedef = packed_structures.value
out_structure, _ = jtu.tree_unflatten(treedef, leaves)
out_structure, in_structure = jtu.tree_unflatten(treedef, leaves)
# `structure_equal` compares with `is True`, in case `tree_equal` returns a Tracer.
if not structure_equal(pytree, out_structure):
raise ValueError("pytree does not match out_structure")
# not using `ravel_pytree` as that doesn't come with guarantees about order

if is_complex_structure(out_structure) and not is_complex_structure(in_structure):
pytree = complex_to_real_tree(pytree, out_structure)
leaves = jtu.tree_leaves(pytree)
dtype = jnp.result_type(*leaves)
return jnp.concatenate([x.astype(dtype).reshape(-1) for x in leaves])
Expand All @@ -100,15 +105,24 @@ def unravel_solution(
solution: Shaped[Array, " size"], packed_structures: PackedStructures
) -> PyTree[Array]:
leaves, treedef = packed_structures.value
_, in_structure = jtu.tree_unflatten(treedef, leaves)
leaves, treedef = jtu.tree_flatten(complex_to_real_structure(in_structure))
out_structure, in_structure = jtu.tree_unflatten(treedef, leaves)
complex_real = is_complex_structure(in_structure) and not is_complex_structure(
out_structure
)
if complex_real:
leaves, treedef = jtu.tree_flatten(complex_to_real_structure(in_structure))
else:
leaves, treedef = jtu.tree_flatten(in_structure)
sizes = np.cumsum([math.prod(x.shape) for x in leaves[:-1]])
split = jnp.split(solution, sizes)
assert len(split) == len(leaves)
with warnings.catch_warnings():
warnings.simplefilter("ignore") # ignore complex-to-real cast warning
shaped = [x.reshape(y.shape).astype(y.dtype) for x, y in zip(split, leaves)]
return real_to_complex_tree(jtu.tree_unflatten(treedef, shaped), in_structure)
if complex_real:
return real_to_complex_tree(jtu.tree_unflatten(treedef, shaped), in_structure)
else:
return jtu.tree_unflatten(treedef, shaped)


def transpose_packed_structures(
Expand Down
192 changes: 89 additions & 103 deletions tests/test_jvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,10 @@ def test_jvp(
assert tree_allclose(t_op_vec_out, t_expected_op_vec_out, rtol=1e-3)


@pytest.mark.parametrize("solver, tags, pseudoinverse", solvers_tags_pseudoinverse)
@pytest.mark.parametrize(
"solver, tags, pseudoinverse",
[stp for stp in solvers_tags_pseudoinverse if stp[-1]],
) # only pseudoinverse
@pytest.mark.parametrize("use_state", (True, False))
@pytest.mark.parametrize(
"make_matrix",
Expand All @@ -133,125 +136,108 @@ def test_jvp(
),
)
def test_jvp_c_to_r(getkey, solver, tags, pseudoinverse, use_state, make_matrix):
dtype = jnp.complex128
make_operator = make_real_function_operator
t_tags = (None,) * len(tags) if isinstance(tags, tuple) else None
if solver not in (
lx.QR(),
lx.SVD(),
):
print(solver)
return
pytest.skip("Real function operators are only supported for QR and SVD.")
if (make_matrix is construct_matrix) or pseudoinverse:
matrix, t_matrix = make_matrix(getkey, solver, tags, num=2, dtype=dtype)

out_size, _ = matrix.shape
out_dtype = (
complex_to_real_dtype(matrix.dtype)
if make_operator == make_real_function_operator
else matrix.dtype
)
vec = jr.normal(getkey(), (out_size,), dtype=out_dtype)
t_vec = jr.normal(getkey(), (out_size,), dtype=out_dtype)

if has_tag(tags, lx.unit_diagonal_tag):
# For all the other tags, A + εB with A, B \in {matrices satisfying the tag}
# still satisfies the tag itself.
# This is the exception.
t_matrix.at[jnp.arange(3), jnp.arange(3)].set(0)

make_op = ft.partial(make_operator, getkey)
operator, t_operator = eqx.filter_jvp(
make_op, (matrix, tags), (t_matrix, t_tags)
)

if use_state:
state = solver.init(operator, options={})
linear_solve = ft.partial(lx.linear_solve, state=state)
else:
linear_solve = lx.linear_solve

solve_vec_only = lambda v: linear_solve(operator, v, solver).value
vec_out, t_vec_out = eqx.filter_jvp(solve_vec_only, (vec,), (t_vec,))

solve_op_only = lambda op: linear_solve(op, vec, solver).value
solve_op_vec = lambda op, v: linear_solve(op, v, solver).value

op_out, t_op_out = eqx.filter_jvp(solve_op_only, (operator,), (t_operator,))
op_vec_out, t_op_vec_out = eqx.filter_jvp(
solve_op_vec,
(operator, vec),
(t_operator, t_vec),
)
(expected_op_out, *_), (t_expected_op_out, *_) = eqx.filter_jvp(
matrix, t_matrix = make_matrix(getkey, solver, tags, num=2, dtype=jnp.complex128)

out_size, in_size = matrix.shape
out_dtype = complex_to_real_dtype(matrix.dtype)
vec = jr.normal(getkey(), (out_size,), dtype=out_dtype)
t_vec = jr.normal(getkey(), (out_size,), dtype=out_dtype)

if has_tag(tags, lx.unit_diagonal_tag):
# For all the other tags, A + εB with A, B \in {matrices satisfying the tag}
# still satisfies the tag itself.
# This is the exception.
t_matrix.at[jnp.arange(3), jnp.arange(3)].set(0)

make_op = ft.partial(make_real_function_operator, getkey)
operator, t_operator = eqx.filter_jvp(make_op, (matrix, tags), (t_matrix, t_tags))

if use_state:
state = solver.init(operator, options={})
linear_solve = ft.partial(lx.linear_solve, state=state)
else:
linear_solve = lx.linear_solve

solve_vec_only = lambda v: linear_solve(operator, v, solver).value
vec_out, t_vec_out = eqx.filter_jvp(solve_vec_only, (vec,), (t_vec,))

solve_op_only = lambda op: linear_solve(op, vec, solver).value
solve_op_vec = lambda op, v: linear_solve(op, v, solver).value

op_out, t_op_out = eqx.filter_jvp(solve_op_only, (operator,), (t_operator,))
op_vec_out, t_op_vec_out = eqx.filter_jvp(
solve_op_vec,
(operator, vec),
(t_operator, t_vec),
)
(expected_op_out, *_), (t_expected_op_out, *_) = eqx.filter_jvp(
lambda op: jnp.linalg.lstsq(
jnp.concatenate([jnp.real(op), -jnp.imag(op)], axis=-1), vec
), # pyright: ignore
(matrix,),
(t_matrix,),
)
(expected_op_vec_out, *_), (t_expected_op_vec_out, *_) = eqx.filter_jvp(
lambda op, v: jnp.linalg.lstsq(
jnp.concatenate([jnp.real(op), -jnp.imag(op)], axis=-1), v
),
(matrix, vec),
(t_matrix, t_vec), # pyright: ignore
)

# Work around JAX issue #14868.
if jnp.any(jnp.isnan(t_expected_op_out)):
_, (t_expected_op_out, *_) = finite_difference_jvp(
lambda op: jnp.linalg.lstsq(
jnp.concatenate([jnp.real(op), -jnp.imag(op)], axis=-1), vec
), # pyright: ignore
(matrix,),
(t_matrix,),
)
(expected_op_vec_out, *_), (t_expected_op_vec_out, *_) = eqx.filter_jvp(
if jnp.any(jnp.isnan(t_expected_op_vec_out)):
_, (t_expected_op_vec_out, *_) = finite_difference_jvp(
lambda op, v: jnp.linalg.lstsq(
jnp.concatenate([jnp.real(op), -jnp.imag(op)], axis=-1), v
),
(matrix, vec),
(t_matrix, t_vec), # pyright: ignore
)

# Work around JAX issue #14868.
if jnp.any(jnp.isnan(t_expected_op_out)):
_, (t_expected_op_out, *_) = finite_difference_jvp(
lambda op: jnp.linalg.lstsq(
jnp.concatenate([jnp.real(op), -jnp.imag(op)], axis=-1), vec
), # pyright: ignore
(matrix,),
(t_matrix,),
)
if jnp.any(jnp.isnan(t_expected_op_vec_out)):
_, (t_expected_op_vec_out, *_) = finite_difference_jvp(
lambda op, v: jnp.linalg.lstsq(
jnp.concatenate([jnp.real(op), -jnp.imag(op)], axis=-1), v
),
(matrix, vec),
(t_matrix, t_vec), # pyright: ignore
)
real_mat = jnp.concatenate([jnp.real(matrix), -jnp.imag(matrix)], axis=-1)
pinv_matrix = jnp.linalg.pinv(real_mat) # pyright: ignore
expected_vec_out = pinv_matrix @ vec
with jax.numpy_dtype_promotion("standard"):
expected_complex_vec_out = (
expected_vec_out[:in_size] + 1.0j * expected_vec_out[in_size:]
)
expected_complex_op_out = (
expected_op_out[:in_size] + 1.0j * expected_op_out[in_size:]
)
expected_complex_op_vec_out = (
expected_op_vec_out[:in_size] + 1.0j * expected_op_vec_out[in_size:]
)

real_mat = jnp.concatenate([jnp.real(matrix), -jnp.imag(matrix)], axis=-1)
pinv_matrix = jnp.linalg.pinv(real_mat) # pyright: ignore
expected_vec_out = pinv_matrix @ vec
with jax.numpy_dtype_promotion("standard"):
expected_complex_vec_out = (
expected_vec_out[:out_size] + 1.0j * expected_vec_out[out_size:]
)
assert tree_allclose(vec_out, expected_complex_vec_out)
assert tree_allclose(vec_out, expected_complex_vec_out)
assert tree_allclose(op_out, expected_complex_op_out)
assert tree_allclose(op_vec_out, expected_complex_op_vec_out)

with jax.numpy_dtype_promotion("standard"):
expected_complex_op_out = (
expected_op_out[:out_size] + 1.0j * expected_op_out[out_size:]
)
expected_complex_op_vec_out = (
expected_op_vec_out[:out_size] + 1.0j * expected_op_vec_out[out_size:]
)
t_expected_vec_out = pinv_matrix @ t_vec

assert tree_allclose(op_out, expected_complex_op_out)
assert tree_allclose(op_vec_out, expected_complex_op_vec_out)

t_expected_vec_out = pinv_matrix @ t_vec
with jax.numpy_dtype_promotion("standard"):
t_expected_complex_vec_out = (
t_expected_vec_out[:in_size] + 1.0j * t_expected_vec_out[in_size:]
)
t_expected_complex_op_out = (
t_expected_op_out[:in_size] + 1.0j * t_expected_op_out[in_size:]
)

with jax.numpy_dtype_promotion("standard"):
t_expected_complex_vec_out = (
t_expected_vec_out[:out_size] + 1.0j * t_expected_vec_out[out_size:]
)
t_expected_complex_op_out = (
t_expected_op_out[:out_size] + 1.0j * t_expected_op_out[out_size:]
)
t_expected_complex_op_vec_out = (
t_expected_op_vec_out[:out_size]
+ 1.0j * t_expected_op_vec_out[out_size:]
)
assert tree_allclose(
matrix @ t_vec_out, matrix @ t_expected_complex_vec_out, rtol=1e-3
t_expected_complex_op_vec_out = (
t_expected_op_vec_out[:in_size] + 1.0j * t_expected_op_vec_out[in_size:]
)
assert tree_allclose(t_op_out, t_expected_complex_op_out, rtol=1e-3)
assert tree_allclose(t_op_vec_out, t_expected_complex_op_vec_out, rtol=1e-3)
assert tree_allclose(
matrix @ t_vec_out, matrix @ t_expected_complex_vec_out, rtol=1e-3
)
assert tree_allclose(t_op_out, t_expected_complex_op_out, rtol=1e-3)
assert tree_allclose(t_op_vec_out, t_expected_complex_op_vec_out, rtol=1e-3)

0 comments on commit b55ec0e

Please sign in to comment.