New solver based on Woodbury matrix identity #97

Open: wants to merge 4 commits into base: main
1 change: 1 addition & 0 deletions lineax/__init__.py
@@ -45,6 +45,7 @@
    TangentLinearOperator as TangentLinearOperator,
    tridiagonal as tridiagonal,
    TridiagonalLinearOperator as TridiagonalLinearOperator,
    WoodburyLinearOperator as WoodburyLinearOperator,
)
from ._solution import RESULTS as RESULTS, Solution as Solution
from ._solve import (
83 changes: 83 additions & 0 deletions lineax/_operator.py
@@ -906,6 +906,73 @@ def out_structure(self):
        return self.operator.out_structure()


class WoodburyLinearOperator(AbstractLinearOperator, strict=True):
    """As [`lineax.MatrixLinearOperator`][], but for specifically a matrix
    with A + U C V structure, such that the Woodbury identity can be used"""

    A: AbstractLinearOperator
    C: Inexact[Array, " k k"]
    U: Inexact[Array, " n k"]
    V: Inexact[Array, " k n"]

    def __init__(
        self,
        A: AbstractLinearOperator,
        C: Inexact[Array, " k k"],
        U: Inexact[Array, " n k"],
        V: Inexact[Array, " k n"],
    ):
        """**Arguments:**

        Matrix of form A + U C V, such that the inverse can be computed
        using Woodbury matrix identity

        - `A`: Linear operator, in/out shape (n,n)
        - `C`: A rank-two JAX array. Shape (k,k)
        - `U`: A rank-two JAX array. Shape (n,k)
        - `V`: A rank-two JAX array. Shape (k,n)

        """
        self.A = A
        self.C = inexact_asarray(C)
        self.U = inexact_asarray(U)
        self.V = inexact_asarray(V)
        (N, M) = self.A.in_structure(), self.A.out_structure()
        if not eqx.tree_equal(N, M):
            raise ValueError(f"expecting square operator for A, got {N} by {M}")
        (K, L) = self.C.shape
        if K != L:
            raise ValueError(f"expecting square operator for C, got {K} by {L}")
        N = N.shape[0]
Owner

N is an arbitrary PyTree, it doesn't necessarily have a .shape attribute.

Author

Ok, this is mostly my ignorance of how lineax uses PyTrees as linear operators.

I am not sure how allowing arbitrary PyTrees for A meshes with the requirement of a Woodbury structure. In the context of the Woodbury matrix identity, I would think A would have to have a matrix representation as an n by n matrix. For a PyTree representation, would C, U and V then also need to be PyTrees, such that each leaf of the tree can be made to have Woodbury structure?

Owner

@patrick-kidger Jun 10, 2024

Right, so! Basically what's going on here is exactly what jax.jacfwd does as well.
Consider for example:

import jax
import jax.numpy as jnp

def f(x: tuple[jax.Array, jax.Array]):
    return {"a": x[0] * x[1], "b": x[0]}

x = (jnp.arange(2.)[:, None], jnp.arange(3.)[None, :])
jac = jax.jacfwd(f)(x)

What should we get?

Well, a PyTree-of-arrays is basically isomorphic to a vector (flatten every array and concatenate them all together), and the Jacobian of a function f: R^n -> R^m is a matrix of shape (m, n).

Reasoning by analogy, we can see that:

  • given an input PyTree whose leaves are enumerated by i, and for which each leaf has shape a_i (a tuple);
  • given an output PyTree whose leaves are enumerated by j, and for which each leaf has shape b_j (also a tuple);

then the Jacobian should be a PyTree whose leaves are enumerated by (j, i) (a PyTree-of-PyTrees if you will), where each leaf has shape (*b_j, *a_i). (Here unpacking each tuple using Python notation.)

And indeed this is exactly what we see:

import equinox as eqx
eqx.tree_pprint(jac)
# {'a': (f32[2,3,2,1], f32[2,3,1,3]), 'b': (f32[2,1,2,1], f32[2,1,1,3])}

The "outer" PyTree has structure {'a': *, 'b': *} (corresponding to the output of our function), and the "inner" PyTree has structure (*, *) (corresponding to the input of our function).

Meanwhile, each leaf has a shape obtained by concatenating the shapes of the corresponding pair of input and output leaves. For all possible pairs, notably! So in our use case here we wouldn't have a pytree-of-things-with-Woodbury-structure. Rather, we would have a single PyTree, which when thought of as a linear operator (much like the Jacobian), would itself have Woodbury structure!
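
To make the "PyTree as a vector" picture concrete, here is a minimal sketch using `jax.flatten_util.ravel_pytree`. The helper `f_flat` and the block layout are illustrative assumptions, not anything lineax does internally. It checks that the PyTree-of-PyTrees Jacobian above, once each leaf is reshaped to (size of output leaf, size of input leaf), assembles into the ordinary (8, 5) Jacobian of the flattened function.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def f(x):
    return {"a": x[0] * x[1], "b": x[0]}


x = (jnp.arange(2.0)[:, None], jnp.arange(3.0)[None, :])

# Flatten the input PyTree into one vector of length 2 + 3 = 5.
flat_x, unravel_x = ravel_pytree(x)


def f_flat(v):
    # Hypothetical wrapper: vector in, vector out (length 6 + 2 = 8).
    out_flat, _ = ravel_pytree(f(unravel_x(v)))
    return out_flat


# Ordinary dense Jacobian of the flattened function.
dense_jac = jax.jacfwd(f_flat)(flat_x)

# The same matrix, assembled block by block from the PyTree Jacobian.
# Rows follow the output leaves ("a", then "b"), columns the input leaves.
jac = jax.jacfwd(f)(x)
blocks = jnp.block(
    [
        [jac["a"][0].reshape(6, 2), jac["a"][1].reshape(6, 3)],
        [jac["b"][0].reshape(2, 2), jac["b"][1].reshape(2, 3)],
    ]
)
```

Under this view, "Woodbury structure" is a statement about the single assembled matrix, not about any individual leaf.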


Okay, hopefully that makes some kind of sense!

Author

Just thought I'd respond on this since it has been a bit. Thanks for the detailed response, it makes sense. I am (slowly) working on changes that would make the Woodbury implementation PyTree-compatible.

I think the size checking is easy enough, if I have understood you correctly here (a PyTree-of-arrays is basically isomorphic to a vector: flatten every array and concatenate them all together). When checking U and V, we need to use the in_size of A and C for N and K.
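
As a rough sketch of that size check, assuming the structure is a PyTree of `jax.ShapeDtypeStruct` leaves (which is what `in_structure()` returns for lineax operators), the flattened size is the sum of the element counts of the leaves. The helper name `structure_size` is hypothetical; the `in_size()` method already used in the solver below should give the same number.

```python
import math

import jax
import jax.numpy as jnp


def structure_size(structure):
    # Hypothetical helper: total number of scalars across all leaves.
    # math.prod(()) == 1, so scalar leaves are counted correctly.
    leaves = jax.tree_util.tree_leaves(structure)
    return sum(math.prod(leaf.shape) for leaf in leaves)


# Example structure matching the (2, 1) and (1, 3) inputs discussed above.
structure = (
    jax.ShapeDtypeStruct((2, 1), jnp.float32),
    jax.ShapeDtypeStruct((1, 3), jnp.float32),
)
n = structure_size(structure)

# U and V would then be checked against (n, k) and (k, n).
k = 2
U = jnp.zeros((n, k))
V = jnp.zeros((k, n))
shapes_ok = U.shape == (n, k) and V.shape == (k, n)
```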

There will need to be a few tree_unflatten calls to move between the flattened space (where U and V live) and the PyTree input space (where A and C potentially live). This makes the pushthrough operator a bit tricky, but it should be doable with a little time.

I suppose my question would be: is there a nice way to wrap this interlink between the flattened vector space and PyTree space, so that implementing this kind of thing will be easier in the future? Does it already exist somewhere outside (or potentially inside) lineax?
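
One existing utility outside lineax that covers part of this is `jax.flatten_util.ravel_pytree`, which returns both the flattened vector and a function mapping back to the PyTree. The following is only a sketch of how the low-rank term U C V could be applied to a PyTree-valued vector with it; `low_rank_mv` is a hypothetical name, and it assumes the input and output structures are the same (a square operator).

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def low_rank_mv(U, C, V, vector):
    # Move into the flat space where U, C and V live...
    flat, unravel = ravel_pytree(vector)
    # ...apply U C V right to left, so only thin products are formed...
    flat_out = U @ (C @ (V @ flat))
    # ...and move back to the PyTree structure of the input.
    return unravel(flat_out)


vector = {"a": jnp.ones((2,)), "b": jnp.ones((3,))}
U = jnp.ones((5, 1))
C = 2.0 * jnp.ones((1, 1))
V = jnp.ones((1, 5))
out = low_rank_mv(U, C, V, vector)
```

The solver in this PR imports `ravel_vector` and `unravel_solution` from `lineax/_solver/misc.py` for the same purpose inside `compute`.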

        if self.U.shape != (N, K):
            raise ValueError("U does not have consistent shape with A and C")
        if self.V.shape != (K, N):
            raise ValueError("V does not have consistent shape with A and C")

    def mv(self, vector):
        Ax = self.A.mv(vector)
        UCVx = self.U @ (self.C @ (self.V @ vector))
        return Ax + UCVx

    def as_matrix(self):
        matrix = self.A.as_matrix() + self.U @ (self.C @ self.V)
        return matrix

    def transpose(self):
        return WoodburyLinearOperator(
            self.A.transpose(),
            jnp.transpose(self.C),
            jnp.transpose(self.V),
            jnp.transpose(self.U),
        )

    def in_structure(self):
        return self.A.in_structure()

    def out_structure(self):
        return self.A.out_structure()


#
# All operators below here are private to lineax.
#
@@ -1207,6 +1274,7 @@ def linearise(operator: AbstractLinearOperator) -> AbstractLinearOperator:
@linearise.register(IdentityLinearOperator)
@linearise.register(DiagonalLinearOperator)
@linearise.register(TridiagonalLinearOperator)
@linearise.register(WoodburyLinearOperator)
def _(operator):
    return operator

@@ -1283,6 +1351,7 @@ def materialise(operator: AbstractLinearOperator) -> AbstractLinearOperator:
@materialise.register(IdentityLinearOperator)
@materialise.register(DiagonalLinearOperator)
@materialise.register(TridiagonalLinearOperator)
@materialise.register(WoodburyLinearOperator)
def _(operator):
    return operator

@@ -1343,6 +1412,7 @@ def diagonal(operator: AbstractLinearOperator) -> Shaped[Array, " size"]:


@diagonal.register(MatrixLinearOperator)
@diagonal.register(WoodburyLinearOperator)
@diagonal.register(PyTreeLinearOperator)
@diagonal.register(JacobianLinearOperator)
@diagonal.register(FunctionLinearOperator)
@@ -1397,6 +1467,7 @@ def tridiagonal(


@tridiagonal.register(MatrixLinearOperator)
@tridiagonal.register(WoodburyLinearOperator)
@tridiagonal.register(PyTreeLinearOperator)
@tridiagonal.register(JacobianLinearOperator)
@tridiagonal.register(FunctionLinearOperator)
@@ -1477,6 +1548,7 @@ def _(operator):


@is_symmetric.register(TridiagonalLinearOperator)
@is_symmetric.register(WoodburyLinearOperator)
def _(operator):
    return False

@@ -1516,6 +1588,7 @@ def _(operator):
    return True


@is_diagonal.register(WoodburyLinearOperator)
@is_diagonal.register(TridiagonalLinearOperator)
def _(operator):
    return False
@@ -1557,6 +1630,11 @@ def _(operator):
    return True


@is_tridiagonal.register(WoodburyLinearOperator)
def _(operator):
    return False


# has_unit_diagonal


@@ -1591,6 +1669,7 @@ def _(operator):
    return True


@has_unit_diagonal.register(WoodburyLinearOperator)
@has_unit_diagonal.register(DiagonalLinearOperator)
@has_unit_diagonal.register(TridiagonalLinearOperator)
def _(operator):
@@ -1633,6 +1712,7 @@ def _(operator):
    return True


@is_lower_triangular.register(WoodburyLinearOperator)
@is_lower_triangular.register(TridiagonalLinearOperator)
def _(operator):
    return False
@@ -1673,6 +1753,7 @@ def _(operator):
    return True


@is_upper_triangular.register(WoodburyLinearOperator)
@is_upper_triangular.register(TridiagonalLinearOperator)
def _(operator):
    return False
@@ -1712,6 +1793,7 @@ def _(operator):
    return True


@is_positive_semidefinite.register(WoodburyLinearOperator)
@is_positive_semidefinite.register(DiagonalLinearOperator)
@is_positive_semidefinite.register(TridiagonalLinearOperator)
def _(operator):
@@ -1753,6 +1835,7 @@ def _(operator):
    return False


@is_negative_semidefinite.register(WoodburyLinearOperator)
@is_negative_semidefinite.register(DiagonalLinearOperator)
@is_negative_semidefinite.register(TridiagonalLinearOperator)
def _(operator):
9 changes: 8 additions & 1 deletion lineax/_solve.py
@@ -42,6 +42,7 @@
    is_upper_triangular,
    linearise,
    TangentLinearOperator,
    WoodburyLinearOperator,
)
from ._solution import RESULTS, Solution

@@ -268,7 +269,7 @@ def _linear_solve_transpose(inputs, cts_out):
        _assert_defined, (operator, state, options, solver), is_leaf=_is_undefined
    )
    cts_solution = jtu.tree_map(
-       ft.partial(eqxi.materialise_zeros, allow_struct=True),
+       ft.partial(eqxi.materialise_zeros, allow_struct=True),  # pyright: ignore
        operator.in_structure(),
        cts_solution,
    )
@@ -498,6 +499,7 @@ def conj(
_cholesky_token = eqxi.str2jax("cholesky_token")
_lu_token = eqxi.str2jax("lu_token")
_svd_token = eqxi.str2jax("svd_token")
_woodbury_token = eqxi.str2jax("woodbury_token")


# Ugly delayed import because we have the dependency chain
@@ -518,6 +520,7 @@ def _lookup(token) -> AbstractLinearSolver:
        _cholesky_token: _solver.Cholesky(),  # pyright: ignore
        _lu_token: _solver.LU(),  # pyright: ignore
        _svd_token: _solver.SVD(),  # pyright: ignore
        _woodbury_token: _solver.Woodbury(),  # pyright: ignore
    }
    return _lookup_dict[token]

@@ -535,6 +538,7 @@ class AutoLinearSolver(AbstractLinearSolver[_AutoLinearSolverState], strict=True
        - If the operator is triangular, then use [`lineax.Triangular`][].
        - If the matrix is positive or negative definite, then use
            [`lineax.Cholesky`][].
        - If the matrix has structure A + U C V, then use [`lineax.Woodbury`][].
        - Else use [`lineax.LU`][].

    This is a good choice if you want to be certain that an error is raised for
@@ -554,6 +558,7 @@ class AutoLinearSolver(AbstractLinearSolver[_AutoLinearSolverState], strict=True
        - If the operator is triangular, then use [`lineax.Triangular`][].
        - If the matrix is positive or negative definite, then use
            [`lineax.Cholesky`][].
        - If the matrix has structure A + U C V, then use [`lineax.Woodbury`][].
        - Else, use [`lineax.LU`][].

    This is a good choice if your primary concern is computational efficiency. It will
@@ -582,6 +587,8 @@ def _select_solver(self, operator: AbstractLinearOperator):
                operator
            ):
                token = _cholesky_token
            elif isinstance(operator, WoodburyLinearOperator):
                token = _woodbury_token
            else:
                token = _lu_token
        elif self.well_posed is False:
1 change: 1 addition & 0 deletions lineax/_solver/__init__.py
@@ -22,3 +22,4 @@
from .svd import SVD as SVD
from .triangular import Triangular as Triangular
from .tridiagonal import Tridiagonal as Tridiagonal
from .woodbury import Woodbury as Woodbury
143 changes: 143 additions & 0 deletions lineax/_solver/woodbury.py
@@ -0,0 +1,143 @@
from typing import Any
from typing_extensions import TypeAlias

import jax
import jax.numpy as jnp
from jaxtyping import Array, PyTree

from .._operator import (
    AbstractLinearOperator,
    MatrixLinearOperator,
    WoodburyLinearOperator,
)
from .._solution import RESULTS
from .._solve import AbstractLinearSolver, AutoLinearSolver
from .misc import (
    pack_structures,
    PackedStructures,
    ravel_vector,
    transpose_packed_structures,
    unravel_solution,
)


_WoodburyState: TypeAlias = tuple[
    tuple[Array, Array, Array],
    tuple[AbstractLinearSolver, Any, AbstractLinearSolver, Any],
    PackedStructures,
]


def _compute_pushthrough(
    A_solver: AbstractLinearSolver, A_state: Any, C: Array, U: Array, V: Array
) -> tuple[AbstractLinearSolver, Any]:
    # Push through ( C^-1 + V A^-1 U) y = x
    vmapped_solve = jax.vmap(
        lambda x_vec: A_solver.compute(A_state, x_vec, {})[0], in_axes=1, out_axes=1
    )
    pushthrough_mat = jnp.linalg.inv(C) + V @ vmapped_solve(U)
    pushthrough_op = MatrixLinearOperator(pushthrough_mat)
    solver = AutoLinearSolver(well_posed=True).select_solver(pushthrough_op)
    state = solver.init(pushthrough_op, {})
    return solver, state


class Woodbury(AbstractLinearSolver[_WoodburyState]):
    """Solving system using Woodbury matrix identity"""

    def init(
        self,
        operator: AbstractLinearOperator,
        options: dict[str, Any],
        A_solver: AbstractLinearSolver = AutoLinearSolver(well_posed=True),
    ):
        del options
        if not isinstance(operator, WoodburyLinearOperator):
            raise ValueError(
                "`Woodbury` may only be used for linear solves with A + U C V structure"
            )
        else:
            A, C, U, V = operator.A, operator.C, operator.U, operator.V  # pyright: ignore
        if A.in_size() != A.out_size():
            raise ValueError("""A must be square""")
        # Find correct solvers and init for A
        A_state = A_solver.init(A, {})
        # Compute pushthrough operator
        pt_solver, pt_state = _compute_pushthrough(A_solver, A_state, C, U, V)
        return (
            (C, U, V),
            (A_solver, A_state, pt_solver, pt_state),
            pack_structures(A),
        )

    def compute(
        self,
        state: _WoodburyState,
        vector,
        options,
    ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:
        (
            (C, U, V),
            (A_solver, A_state, pt_solver, pt_state),
            A_packed_structures,
        ) = state
        del state, options
        vector = ravel_vector(vector, A_packed_structures)

        # Solution to A x = b
        # [0] selects the solution vector
        x_1 = A_solver.compute(A_state, vector, {})[0]
        # Push through U ( C^-1 + V A^-1 U)^-1 V (A^-1 b)
        # [0] selects the solution vector
        x_pushthrough = U @ pt_solver.compute(pt_state, V @ x_1, {})[0]
        # A^-1 on result of push through
        # [0] selects the solution vector
        x_2 = A_solver.compute(A_state, x_pushthrough, {})[0]
        # See https://en.wikipedia.org/wiki/Woodbury_matrix_identity
        solution = x_1 - x_2

        solution = unravel_solution(solution, A_packed_structures)
        return solution, RESULTS.successful, {}

    def transpose(self, state: _WoodburyState, options: dict[str, Any]):
        (
            (C, U, V),
            (A_solver, A_state, pt_solver, pt_state),
            A_packed_structures,
        ) = state
        transposed_packed_structures = transpose_packed_structures(A_packed_structures)
        C = jnp.transpose(C)
        # (A + U C V)^T = A^T + V^T C^T U^T, so U and V swap roles. Assign both
        # at once: assigning U first would make the new V a transpose of the new U.
        U, V = jnp.transpose(V), jnp.transpose(U)
        A_state, _ = A_solver.transpose(A_state, {})
        pt_solver, pt_state = _compute_pushthrough(A_solver, A_state, C, U, V)
        transpose_state = (
            (C, U, V),
            (A_solver, A_state, pt_solver, pt_state),
            transposed_packed_structures,
        )
        return transpose_state, options

    def conj(self, state: _WoodburyState, options: dict[str, Any]):
        (
            (C, U, V),
            (A_solver, A_state, pt_solver, pt_state),
            packed_structures,
        ) = state
        C = jnp.conj(C)
        U = jnp.conj(U)
        V = jnp.conj(V)
        A_state, _ = A_solver.conj(A_state, {})
        pt_solver, pt_state = _compute_pushthrough(A_solver, A_state, C, U, V)
        conj_state = (
            (C, U, V),
            (A_solver, A_state, pt_solver, pt_state),
            packed_structures,
        )
        return conj_state, options

    def allow_dependent_columns(self, operator):
        return False

    def allow_dependent_rows(self, operator):
        return False
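
For reference, the three solves in `compute` correspond to the identity (A + U C V)^-1 b = A^-1 b - A^-1 U (C^-1 + V A^-1 U)^-1 V A^-1 b. The following is a standalone numerical check of that identity using plain `jax.numpy` dense solves in place of the lineax solvers. The matrices are arbitrary, well-conditioned test values chosen for illustration, not taken from the PR's tests.

```python
import jax.numpy as jnp
import jax.random as jr

key_u, key_v, key_b = jr.split(jr.PRNGKey(0), 3)
n, k = 6, 2

# Illustrative test problem: diagonal A plus a small rank-k update.
A = jnp.diag(jnp.arange(1.0, n + 1) + 5.0)
C = 2.0 * jnp.eye(k)
U = 0.1 * jr.normal(key_u, (n, k))
V = 0.1 * jr.normal(key_v, (k, n))
b = jr.normal(key_b, (n,))

# Step 1: x_1 = A^-1 b
x_1 = jnp.linalg.solve(A, b)
# Step 2: the small (k, k) "pushthrough" system C^-1 + V A^-1 U
pushthrough = jnp.linalg.inv(C) + V @ jnp.linalg.solve(A, U)
# Step 3: x_2 = A^-1 U (C^-1 + V A^-1 U)^-1 V x_1
x_2 = jnp.linalg.solve(A, U @ jnp.linalg.solve(pushthrough, V @ x_1))
x_woodbury = x_1 - x_2

# Reference: solve the assembled (n, n) system directly.
x_direct = jnp.linalg.solve(A + U @ C @ V, b)
```

The benefit shows up when solves with A are cheap (for example diagonal or tridiagonal A) and k is much smaller than n, since the only dense factorisation is of the (k, k) pushthrough matrix.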