Skip to content

Commit

Permalink
Fix the inner product order in BiCG, add abs to CG
Browse files Browse the repository at this point in the history
  • Loading branch information
Randl authored and patrick-kidger committed Mar 23, 2024
1 parent 9b923c8 commit 1d0ac6c
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
5 changes: 3 additions & 2 deletions lineax/_solver/bicgstab.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def compute(
# We use the notation found on the wikipedia except with y instead of x:
# https://en.wikipedia.org/wiki/
# Biconjugate_gradient_stabilized_method#Preconditioned_BiCGSTAB
# preconditioner in this case is K2^(-1) (i.e., right preconditioning)

r0 = (vector**ω - operator.mv(y0) ** ω).ω

Expand Down Expand Up @@ -134,7 +135,7 @@ def cond_fun(carry):
def body_fun(carry):
y, r, alpha, omega, rho, p, v, diff, step = carry

rho_new = tree_dot(r0, r)
rho_new = tree_dot(r, r0)
beta = (rho_new / rho) * (alpha / omega)
p_new = (r**ω + beta * (p**ω - omega * v**ω)).ω

Expand All @@ -143,7 +144,7 @@ def body_fun(carry):
x = preconditioner.mv(p_new)
v_new = operator.mv(x)

alpha_new = rho_new / tree_dot(r0, v_new)
alpha_new = rho_new / tree_dot(v_new, r0)
s = (r**ω - alpha_new * v_new**ω).ω

z = preconditioner.mv(s)
Expand Down
2 changes: 1 addition & 1 deletion lineax/_solver/cg.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def body_fun(value):
inner_prod = tree_dot(p, mat_p)
alpha = gamma / inner_prod
alpha = tree_where(
jnp.abs(inner_prod) > 100 * rcond * gamma, alpha, jnp.nan
jnp.abs(inner_prod) > 100 * rcond * jnp.abs(gamma), alpha, jnp.nan
)
diff = (alpha * p**ω).ω
y = (y**ω + diff**ω).ω
Expand Down

0 comments on commit 1d0ac6c

Please sign in to comment.