From 5e7929474d79faed5fbc20284829be73215c493e Mon Sep 17 00:00:00 2001 From: Evgenii Zheltonozhskii Date: Sun, 18 Aug 2024 16:58:58 +0300 Subject: [PATCH] Follow conjugation convention in `tree_dot` (#105) * Follow conjugation convention * Fix tree_dot arguments --- lineax/_norm.py | 4 ++-- lineax/_solver/bicgstab.py | 6 +++--- lineax/_solver/cg.py | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/lineax/_norm.py b/lineax/_norm.py index 165fda8..79b6c92 100644 --- a/lineax/_norm.py +++ b/lineax/_norm.py @@ -36,8 +36,8 @@ def tree_dot(tree1: PyTree[ArrayLike], tree2: PyTree[ArrayLike]) -> Inexact[Arra for leaf1, leaf2 in zip(leaves1, leaves2): dots.append( jnp.dot( - jnp.reshape(leaf1, -1), - jnp.conj(leaf2).reshape(-1), + jnp.conj(leaf1).reshape(-1), + jnp.reshape(leaf2, -1), precision=jax.lax.Precision.HIGHEST, # pyright: ignore ) ) diff --git a/lineax/_solver/bicgstab.py b/lineax/_solver/bicgstab.py index 7578490..ab97f16 100644 --- a/lineax/_solver/bicgstab.py +++ b/lineax/_solver/bicgstab.py @@ -136,7 +136,7 @@ def cond_fun(carry): def body_fun(carry): y, r, alpha, omega, rho, p, v, diff, step = carry - rho_new = tree_dot(r, r0) + rho_new = tree_dot(r0, r) beta = (rho_new / rho) * (alpha / omega) p_new = (r**ω + beta * (p**ω - omega * v**ω)).ω @@ -145,13 +145,13 @@ def body_fun(carry): x = preconditioner.mv(p_new) v_new = operator.mv(x) - alpha_new = rho_new / tree_dot(v_new, r0) + alpha_new = rho_new / tree_dot(r0, v_new) s = (r**ω - alpha_new * v_new**ω).ω z = preconditioner.mv(s) t = operator.mv(z) - omega_new = tree_dot(t, s) / tree_dot(t, t) + omega_new = tree_dot(s, t) / tree_dot(t, t) diff = (alpha_new * x**ω + omega_new * z**ω).ω y_new = (y**ω + diff**ω).ω diff --git a/lineax/_solver/cg.py b/lineax/_solver/cg.py index c7891c8..2d7c54e 100644 --- a/lineax/_solver/cg.py +++ b/lineax/_solver/cg.py @@ -134,7 +134,7 @@ def mv(vector: PyTree) -> PyTree: max_steps = self.max_steps r0 = (vector**ω - mv(y0) ** ω).ω p0 = preconditioner.mv(r0) - gamma0 = tree_dot(r0, p0) + gamma0 = tree_dot(p0, r0) rcond = resolve_rcond(None, size, size, jnp.result_type(*leaves)) initial_value = ( ω(y0).call(lambda x: jnp.full_like(x, jnp.inf)).ω, @@ -176,7 +176,7 @@ def cond_fun(value): def body_fun(value): _, y, r, p, gamma, step = value mat_p = mv(p) - inner_prod = tree_dot(p, mat_p) + inner_prod = tree_dot(mat_p, p) alpha = gamma / inner_prod alpha = tree_where( jnp.abs(inner_prod) > 100 * rcond * jnp.abs(gamma), alpha, jnp.nan @@ -206,7 +206,7 @@ def cheap_r(): z = preconditioner.mv(r) gamma_prev = gamma - gamma = tree_dot(r, z) + gamma = tree_dot(z, r) beta = gamma / gamma_prev p = (z**ω + beta * p**ω).ω return diff, y, r, p, gamma, step