Skip to content

Commit

Permalink
Follow conjugation convention in tree_dot (#105)
Browse files Browse the repository at this point in the history
* Follow conjugation convention

* Fix tree_dot arguments
  • Loading branch information
Randl authored Aug 18, 2024
1 parent 4a7b108 commit 5e79294
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
4 changes: 2 additions & 2 deletions lineax/_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def tree_dot(tree1: PyTree[ArrayLike], tree2: PyTree[ArrayLike]) -> Inexact[Arra
for leaf1, leaf2 in zip(leaves1, leaves2):
dots.append(
jnp.dot(
jnp.reshape(leaf1, -1),
jnp.conj(leaf2).reshape(-1),
jnp.conj(leaf1).reshape(-1),
jnp.reshape(leaf2, -1),
precision=jax.lax.Precision.HIGHEST, # pyright: ignore
)
)
Expand Down
6 changes: 3 additions & 3 deletions lineax/_solver/bicgstab.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def cond_fun(carry):
def body_fun(carry):
y, r, alpha, omega, rho, p, v, diff, step = carry

rho_new = tree_dot(r, r0)
rho_new = tree_dot(r0, r)
beta = (rho_new / rho) * (alpha / omega)
p_new = (r**ω + beta * (p**ω - omega * v**ω)).ω

Expand All @@ -145,13 +145,13 @@ def body_fun(carry):
x = preconditioner.mv(p_new)
v_new = operator.mv(x)

alpha_new = rho_new / tree_dot(v_new, r0)
alpha_new = rho_new / tree_dot(r0, v_new)
s = (r**ω - alpha_new * v_new**ω).ω

z = preconditioner.mv(s)
t = operator.mv(z)

omega_new = tree_dot(t, s) / tree_dot(t, t)
omega_new = tree_dot(s, t) / tree_dot(t, t)

diff = (alpha_new * x**ω + omega_new * z**ω).ω
y_new = (y**ω + diff**ω).ω
Expand Down
6 changes: 3 additions & 3 deletions lineax/_solver/cg.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def mv(vector: PyTree) -> PyTree:
max_steps = self.max_steps
r0 = (vector**ω - mv(y0) ** ω).ω
p0 = preconditioner.mv(r0)
gamma0 = tree_dot(r0, p0)
gamma0 = tree_dot(p0, r0)
rcond = resolve_rcond(None, size, size, jnp.result_type(*leaves))
initial_value = (
ω(y0).call(lambda x: jnp.full_like(x, jnp.inf)).ω,
Expand Down Expand Up @@ -176,7 +176,7 @@ def cond_fun(value):
def body_fun(value):
_, y, r, p, gamma, step = value
mat_p = mv(p)
inner_prod = tree_dot(p, mat_p)
inner_prod = tree_dot(mat_p, p)
alpha = gamma / inner_prod
alpha = tree_where(
jnp.abs(inner_prod) > 100 * rcond * jnp.abs(gamma), alpha, jnp.nan
Expand Down Expand Up @@ -206,7 +206,7 @@ def cheap_r():

z = preconditioner.mv(r)
gamma_prev = gamma
gamma = tree_dot(r, z)
gamma = tree_dot(z, r)
beta = gamma / gamma_prev
p = (z**ω + beta * p**ω).ω
return diff, y, r, p, gamma, step
Expand Down

0 comments on commit 5e79294

Please sign in to comment.