diff --git a/lineax/__init__.py b/lineax/__init__.py index f7c5ee9..b804142 100644 --- a/lineax/__init__.py +++ b/lineax/__init__.py @@ -45,6 +45,7 @@ TangentLinearOperator as TangentLinearOperator, tridiagonal as tridiagonal, TridiagonalLinearOperator as TridiagonalLinearOperator, + WoodburyLinearOperator as WoodburyLinearOperator, ) from ._solution import RESULTS as RESULTS, Solution as Solution from ._solve import ( diff --git a/lineax/_operator.py b/lineax/_operator.py index ad1e4f5..585063d 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -906,6 +906,73 @@ def out_structure(self): return self.operator.out_structure() +class WoodburyLinearOperator(AbstractLinearOperator, strict=True): + """As [`lineax.MatrixLinearOperator`][], but for specifically a matrix + with A + U C V structure, such that the Woodbury identity can be used""" + + A: AbstractLinearOperator + C: Inexact[Array, " k k"] + U: Inexact[Array, " n k"] + V: Inexact[Array, " k n"] + + def __init__( + self, + A: AbstractLinearOperator, + C: Inexact[Array, " k k"], + U: Inexact[Array, " n k"], + V: Inexact[Array, " k n"], + ): + """**Arguments:** + + Matrix of form A + U C V, such that the inverse can be computed + using Woodbury matrix identity + + - `A`: Linear operator, in/out shape (n,n) + - `C`: A rank-two JAX array. Shape (k,k) + - `U`: A rank-two JAX array. Shape (n,k) + - `V`: A rank-two JAX array. Shape (k,n) + + """ + self.A = A + self.C = inexact_asarray(C) + self.U = inexact_asarray(U) + self.V = inexact_asarray(V) + (N, M) = self.A.in_structure(), self.A.out_structure() + if not eqx.tree_equal(N, M): + raise ValueError(f"expecting square operator for A, got {N} by {M}") + (K, L) = self.C.shape + if K != L: + raise ValueError(f"expecting square operator for C, got {K} by {L}") + N = N.shape[0] + if self.U.shape != (N, K): + raise ValueError("U does not have consistent shape with A and C") + if self.V.shape != (K, N): + raise ValueError("V does not have consistent shape with A and C") + + def mv(self, vector): + Ax = self.A.mv(vector) + UCVx = self.U @ (self.C @ (self.V @ vector)) + return Ax + UCVx + + def as_matrix(self): + matrix = self.A.as_matrix() + self.U @ (self.C @ self.V) + return matrix + + def transpose(self): + return WoodburyLinearOperator( + self.A.transpose(), + jnp.transpose(self.C), + jnp.transpose(self.V), + jnp.transpose(self.U), + ) + + def in_structure(self): + return self.A.in_structure() + + def out_structure(self): + return self.A.out_structure() + + # # All operators below here are private to lineax. # @@ -1207,6 +1274,7 @@ def linearise(operator: AbstractLinearOperator) -> AbstractLinearOperator: @linearise.register(IdentityLinearOperator) @linearise.register(DiagonalLinearOperator) @linearise.register(TridiagonalLinearOperator) +@linearise.register(WoodburyLinearOperator) def _(operator): return operator @@ -1283,6 +1351,7 @@ def materialise(operator: AbstractLinearOperator) -> AbstractLinearOperator: @materialise.register(IdentityLinearOperator) @materialise.register(DiagonalLinearOperator) @materialise.register(TridiagonalLinearOperator) +@materialise.register(WoodburyLinearOperator) def _(operator): return operator @@ -1343,6 +1412,7 @@ def diagonal(operator: AbstractLinearOperator) -> Shaped[Array, " size"]: @diagonal.register(MatrixLinearOperator) +@diagonal.register(WoodburyLinearOperator) @diagonal.register(PyTreeLinearOperator) @diagonal.register(JacobianLinearOperator) @diagonal.register(FunctionLinearOperator) @@ -1397,6 +1467,7 @@ def tridiagonal( @tridiagonal.register(MatrixLinearOperator) +@tridiagonal.register(WoodburyLinearOperator) @tridiagonal.register(PyTreeLinearOperator) @tridiagonal.register(JacobianLinearOperator) @tridiagonal.register(FunctionLinearOperator) @@ -1477,6 +1548,7 @@ def _(operator): @is_symmetric.register(TridiagonalLinearOperator) +@is_symmetric.register(WoodburyLinearOperator) def _(operator): return False @@ -1516,6 +1588,7 @@ def _(operator): return True +@is_diagonal.register(WoodburyLinearOperator) @is_diagonal.register(TridiagonalLinearOperator) def _(operator): return False @@ -1557,6 +1630,11 @@ def _(operator): return True +@is_tridiagonal.register(WoodburyLinearOperator) +def _(operator): + return False + + # has_unit_diagonal @@ -1591,6 +1669,7 @@ def _(operator): return True +@has_unit_diagonal.register(WoodburyLinearOperator) @has_unit_diagonal.register(DiagonalLinearOperator) @has_unit_diagonal.register(TridiagonalLinearOperator) def _(operator): @@ -1633,6 +1712,7 @@ def _(operator): return True +@is_lower_triangular.register(WoodburyLinearOperator) @is_lower_triangular.register(TridiagonalLinearOperator) def _(operator): return False @@ -1673,6 +1753,7 @@ def _(operator): return True +@is_upper_triangular.register(WoodburyLinearOperator) @is_upper_triangular.register(TridiagonalLinearOperator) def _(operator): return False @@ -1712,6 +1793,7 @@ def _(operator): return True +@is_positive_semidefinite.register(WoodburyLinearOperator) @is_positive_semidefinite.register(DiagonalLinearOperator) @is_positive_semidefinite.register(TridiagonalLinearOperator) def _(operator): @@ -1753,6 +1835,7 @@ def _(operator): return False +@is_negative_semidefinite.register(WoodburyLinearOperator) @is_negative_semidefinite.register(DiagonalLinearOperator) @is_negative_semidefinite.register(TridiagonalLinearOperator) def _(operator): diff --git a/lineax/_solve.py b/lineax/_solve.py index 36fad0b..f2fa197 100644 --- a/lineax/_solve.py +++ b/lineax/_solve.py @@ -42,6 +42,7 @@ is_upper_triangular, linearise, TangentLinearOperator, + WoodburyLinearOperator, ) from ._solution import RESULTS, Solution @@ -268,7 +269,7 @@ def _linear_solve_transpose(inputs, cts_out): _assert_defined, (operator, state, options, solver), is_leaf=_is_undefined ) cts_solution = jtu.tree_map( - ft.partial(eqxi.materialise_zeros, allow_struct=True), + ft.partial(eqxi.materialise_zeros, allow_struct=True), # pyright: ignore operator.in_structure(), cts_solution, ) @@ -498,6 +499,7 @@ def conj( _cholesky_token = eqxi.str2jax("cholesky_token") _lu_token = eqxi.str2jax("lu_token") _svd_token = eqxi.str2jax("svd_token") +_woodbury_token = eqxi.str2jax("woodbury_token") # Ugly delayed import because we have the dependency chain @@ -518,6 +520,7 @@ def _lookup(token) -> AbstractLinearSolver: _cholesky_token: _solver.Cholesky(), # pyright: ignore _lu_token: _solver.LU(), # pyright: ignore _svd_token: _solver.SVD(), # pyright: ignore + _woodbury_token: _solver.Woodbury(), # pyright: ignore } return _lookup_dict[token] @@ -535,6 +538,7 @@ class AutoLinearSolver(AbstractLinearSolver[_AutoLinearSolverState], strict=True - If the operator is triangular, then use [`lineax.Triangular`][]. - If the matrix is positive or negative definite, then use [`lineax.Cholesky`][]. + - If the matrix has structure A + U C V, then use [`lineax.Woodbury`][]. - Else use [`lineax.LU`][]. This is a good choice if you want to be certain that an error is raised for @@ -554,6 +558,7 @@ class AutoLinearSolver(AbstractLinearSolver[_AutoLinearSolverState], strict=True - If the operator is triangular, then use [`lineax.Triangular`][]. - If the matrix is positive or negative definite, then use [`lineax.Cholesky`][]. + - If the matrix has structure A + U C V, then use [`lineax.Woodbury`][]. - Else, use [`lineax.LU`][]. This is a good choice if your primary concern is computational efficiency. It will @@ -582,6 +587,8 @@ def _select_solver(self, operator: AbstractLinearOperator): operator ): token = _cholesky_token + elif isinstance(operator, WoodburyLinearOperator): + token = _woodbury_token else: token = _lu_token elif self.well_posed is False: diff --git a/lineax/_solver/__init__.py b/lineax/_solver/__init__.py index 425fc40..53d400e 100644 --- a/lineax/_solver/__init__.py +++ b/lineax/_solver/__init__.py @@ -22,3 +22,4 @@ from .svd import SVD as SVD from .triangular import Triangular as Triangular from .tridiagonal import Tridiagonal as Tridiagonal +from .woodbury import Woodbury as Woodbury diff --git a/lineax/_solver/woodbury.py b/lineax/_solver/woodbury.py new file mode 100644 index 0000000..e23c26c --- /dev/null +++ b/lineax/_solver/woodbury.py @@ -0,0 +1,143 @@ +from typing import Any +from typing_extensions import TypeAlias + +import jax +import jax.numpy as jnp +from jaxtyping import Array, PyTree + +from .._operator import ( + AbstractLinearOperator, + MatrixLinearOperator, + WoodburyLinearOperator, +) +from .._solution import RESULTS +from .._solve import AbstractLinearSolver, AutoLinearSolver +from .misc import ( + pack_structures, + PackedStructures, + ravel_vector, + transpose_packed_structures, + unravel_solution, +) + + +_WoodburyState: TypeAlias = tuple[ + tuple[Array, Array, Array], + tuple[AbstractLinearSolver, Any, AbstractLinearSolver, Any], + PackedStructures, +] + + +def _compute_pushthrough( + A_solver: AbstractLinearSolver, A_state: Any, C: Array, U: Array, V: Array +) -> tuple[AbstractLinearSolver, Any]: + # Push through ( C^-1 + V A^-1 U) y = x + vmapped_solve = jax.vmap( + lambda x_vec: A_solver.compute(A_state, x_vec, {})[0], in_axes=1, out_axes=1 + ) + pushthrough_mat = jnp.linalg.inv(C) + V @ vmapped_solve(U) + pushthrough_op = MatrixLinearOperator(pushthrough_mat) + solver = AutoLinearSolver(well_posed=True).select_solver(pushthrough_op) + state = solver.init(pushthrough_op, {}) + return solver, state + + +class Woodbury(AbstractLinearSolver[_WoodburyState]): + """Solving system using Woodbury matrix identity""" + + def init( + self, + operator: AbstractLinearOperator, + options: dict[str, Any], + A_solver: AbstractLinearSolver = AutoLinearSolver(well_posed=True), + ): + del options + if not isinstance(operator, WoodburyLinearOperator): + raise ValueError( + "`Woodbury` may only be used for linear solves with A + U C V structure" + ) + else: + A, C, U, V = operator.A, operator.C, operator.U, operator.V # pyright: ignore + if A.in_size() != A.out_size(): + raise ValueError("""A must be square""") + # Find correct solvers and init for A + A_state = A_solver.init(A, {}) + # Compute pushthrough operator + pt_solver, pt_state = _compute_pushthrough(A_solver, A_state, C, U, V) + return ( + (C, U, V), + (A_solver, A_state, pt_solver, pt_state), + pack_structures(A), + ) + + def compute( + self, + state: _WoodburyState, + vector, + options, + ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: + ( + (C, U, V), + (A_solver, A_state, pt_solver, pt_state), + A_packed_structures, + ) = state + del state, options + vector = ravel_vector(vector, A_packed_structures) + + # Solution to A x = b + # [0] selects the solution vector + x_1 = A_solver.compute(A_state, vector, {})[0] + # Push through U ( C^-1 + V A^-1 U)^-1 V (A^-1 b) + # [0] selects the solution vector + x_pushthrough = U @ pt_solver.compute(pt_state, V @ x_1, {})[0] + # A^-1 on result of push through + # [0] selects the solution vector + x_2 = A_solver.compute(A_state, x_pushthrough, {})[0] + # See https://en.wikipedia.org/wiki/Woodbury_matrix_identity + solution = x_1 - x_2 + + solution = unravel_solution(solution, A_packed_structures) + return solution, RESULTS.successful, {} + + def transpose(self, state: _WoodburyState, options: dict[str, Any]): + ( + (C, U, V), + (A_solver, A_state, pt_solver, pt_state), + A_packed_structures, + ) = state + transposed_packed_structures = transpose_packed_structures(A_packed_structures) + C = jnp.transpose(C) + U = jnp.transpose(V) + V = jnp.transpose(U) + A_state, _ = A_solver.transpose(A_state, {}) + pt_solver, pt_state = _compute_pushthrough(A_solver, A_state, C, U, V) + transpose_state = ( + (C, U, V), + (A_solver, A_state, pt_solver, pt_state), + transposed_packed_structures, + ) + return transpose_state, options + + def conj(self, state: _WoodburyState, options: dict[str, Any]): + ( + (C, U, V), + (A_solver, A_state, pt_solver, pt_state), + packed_structures, + ) = state + C = jnp.conj(C) + U = jnp.conj(U) + V = jnp.conj(V) + A_state, _ = A_solver.conj(A_state, {}) + pt_solver, pt_state = _compute_pushthrough(A_solver, A_state, C, U, V) + conj_state = ( + (C, U, V), + (A_solver, A_state, pt_solver, pt_state), + packed_structures, + ) + return conj_state, options + + def allow_dependent_columns(self, operator): + return False + + def allow_dependent_rows(self, operator): + return False diff --git a/tests/test_operator.py b/tests/test_operator.py index e0e1e7a..284007e 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -141,6 +141,71 @@ def test_diagonal(dtype, getkey): assert jnp.allclose(lx.diagonal(operator), matrix_diag) +@pytest.mark.parametrize("dtype", (jnp.float64,)) +def test_Woodbury(dtype, getkey): + tol = 1e-4 + N = 20 + k = 2 + + A = jr.normal(getkey(), (N, N), dtype=dtype) + C = jr.normal(getkey(), (k, k), dtype=dtype) + U = jr.normal(getkey(), (N, k), dtype=dtype) + V = jr.normal(getkey(), (k, N), dtype=dtype) + + # Full matrix A + WB = lx.WoodburyLinearOperator(lx.MatrixLinearOperator(A), C, U, V) + + full_matrix = WB.as_matrix() + + true_x = jr.normal(getkey(), (N,)) + b = full_matrix @ true_x + b_WB = WB.mv(true_x) + + assert tree_allclose(b, b_WB, atol=tol, rtol=tol) + + WB_soln = lx.linear_solve(WB, b) + LU_soln = lx.linear_solve(lx.MatrixLinearOperator(full_matrix), b) + + assert tree_allclose(WB_soln.value, LU_soln.value, atol=tol, rtol=tol) + + # Tridiagonal matrix A + diagonal = jnp.diagonal(A, offset=0) + upper_diagonal = jnp.diagonal(A, offset=1) + lower_diagonal = jnp.diagonal(A, offset=-1) + WB = lx.WoodburyLinearOperator( + lx.TridiagonalLinearOperator(diagonal, lower_diagonal, upper_diagonal), C, U, V + ) + + full_matrix = WB.as_matrix() + + true_x = jr.normal(getkey(), (N,)) + b = full_matrix @ true_x + b_WB = WB.mv(true_x) + + assert tree_allclose(b, b_WB, atol=tol, rtol=tol) + + WB_soln = lx.linear_solve(WB, b) + LU_soln = lx.linear_solve(lx.MatrixLinearOperator(full_matrix), b) + + assert tree_allclose(WB_soln.value, LU_soln.value, atol=tol, rtol=tol) + + # Diagonal matrix A + WB = lx.WoodburyLinearOperator(lx.DiagonalLinearOperator(diagonal), C, U, V) + + full_matrix = WB.as_matrix() + + true_x = jr.normal(getkey(), (N,)) + b = full_matrix @ true_x + b_WB = WB.mv(true_x) + + assert tree_allclose(b, b_WB, atol=tol, rtol=tol) + + WB_soln = lx.linear_solve(WB, b) + LU_soln = lx.linear_solve(lx.MatrixLinearOperator(full_matrix), b) + + assert tree_allclose(WB_soln.value, LU_soln.value, atol=tol, rtol=tol) + + @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_is_symmetric(dtype, getkey): matrix = jr.normal(getkey(), (3, 3), dtype=dtype)