From 1d0ac6c1da93d2df2e3c9d0d7461b6a2e0100f7a Mon Sep 17 00:00:00 2001 From: Evgenii Zheltonozhskii Date: Fri, 22 Mar 2024 11:07:31 +0200 Subject: [PATCH] Fix the inner product order in BiCG, add abs to CG --- lineax/_solver/bicgstab.py | 5 +++-- lineax/_solver/cg.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/lineax/_solver/bicgstab.py b/lineax/_solver/bicgstab.py index 2261d4f..aefaaa6 100644 --- a/lineax/_solver/bicgstab.py +++ b/lineax/_solver/bicgstab.py @@ -101,6 +101,7 @@ def compute( # We use the notation found on the wikipedia except with y instead of x: # https://en.wikipedia.org/wiki/ # Biconjugate_gradient_stabilized_method#Preconditioned_BiCGSTAB + # preconditioner in this case is K2^(-1) (i.e., right preconditioning) r0 = (vector**ω - operator.mv(y0) ** ω).ω @@ -134,7 +135,7 @@ def cond_fun(carry): def body_fun(carry): y, r, alpha, omega, rho, p, v, diff, step = carry - rho_new = tree_dot(r0, r) + rho_new = tree_dot(r, r0) beta = (rho_new / rho) * (alpha / omega) p_new = (r**ω + beta * (p**ω - omega * v**ω)).ω @@ -143,7 +144,7 @@ def body_fun(carry): x = preconditioner.mv(p_new) v_new = operator.mv(x) - alpha_new = rho_new / tree_dot(r0, v_new) + alpha_new = rho_new / tree_dot(v_new, r0) s = (r**ω - alpha_new * v_new**ω).ω z = preconditioner.mv(s) diff --git a/lineax/_solver/cg.py b/lineax/_solver/cg.py index f8c673c..62d0bde 100644 --- a/lineax/_solver/cg.py +++ b/lineax/_solver/cg.py @@ -177,7 +177,7 @@ def body_fun(value): inner_prod = tree_dot(p, mat_p) alpha = gamma / inner_prod alpha = tree_where( - jnp.abs(inner_prod) > 100 * rcond * gamma, alpha, jnp.nan + jnp.abs(inner_prod) > 100 * rcond * jnp.abs(gamma), alpha, jnp.nan ) diff = (alpha * p**ω).ω y = (y**ω + diff**ω).ω