-
Notifications
You must be signed in to change notification settings - Fork 23
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
New solver based on Woodbury matrix identity #97
Open
aidancrilly
wants to merge
4
commits into
patrick-kidger:main
Choose a base branch
from
aidancrilly:Woodbury
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
from typing import Any | ||
from typing_extensions import TypeAlias | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
from jaxtyping import Array, PyTree | ||
|
||
from .._operator import ( | ||
AbstractLinearOperator, | ||
MatrixLinearOperator, | ||
WoodburyLinearOperator, | ||
) | ||
from .._solution import RESULTS | ||
from .._solve import AbstractLinearSolver, AutoLinearSolver | ||
from .misc import ( | ||
pack_structures, | ||
PackedStructures, | ||
ravel_vector, | ||
transpose_packed_structures, | ||
unravel_solution, | ||
) | ||
|
||
|
||
_WoodburyState: TypeAlias = tuple[ | ||
tuple[Array, Array, Array], | ||
tuple[AbstractLinearSolver, Any, AbstractLinearSolver, Any], | ||
PackedStructures, | ||
] | ||
|
||
|
||
def _compute_pushthrough( | ||
A_solver: AbstractLinearSolver, A_state: Any, C: Array, U: Array, V: Array | ||
) -> tuple[AbstractLinearSolver, Any]: | ||
# Push through ( C^-1 + V A^-1 U) y = x | ||
vmapped_solve = jax.vmap( | ||
lambda x_vec: A_solver.compute(A_state, x_vec, {})[0], in_axes=1, out_axes=1 | ||
) | ||
pushthrough_mat = jnp.linalg.inv(C) + V @ vmapped_solve(U) | ||
pushthrough_op = MatrixLinearOperator(pushthrough_mat) | ||
solver = AutoLinearSolver(well_posed=True).select_solver(pushthrough_op) | ||
state = solver.init(pushthrough_op, {}) | ||
return solver, state | ||
|
||
|
||
class Woodbury(AbstractLinearSolver[_WoodburyState]): | ||
"""Solving system using Woodbury matrix identity""" | ||
|
||
def init( | ||
self, | ||
operator: AbstractLinearOperator, | ||
options: dict[str, Any], | ||
A_solver: AbstractLinearSolver = AutoLinearSolver(well_posed=True), | ||
): | ||
del options | ||
if not isinstance(operator, WoodburyLinearOperator): | ||
raise ValueError( | ||
"`Woodbury` may only be used for linear solves with A + U C V structure" | ||
) | ||
else: | ||
A, C, U, V = operator.A, operator.C, operator.U, operator.V # pyright: ignore | ||
if A.in_size() != A.out_size(): | ||
raise ValueError("""A must be square""") | ||
# Find correct solvers and init for A | ||
A_state = A_solver.init(A, {}) | ||
# Compute pushthrough operator | ||
pt_solver, pt_state = _compute_pushthrough(A_solver, A_state, C, U, V) | ||
return ( | ||
(C, U, V), | ||
(A_solver, A_state, pt_solver, pt_state), | ||
pack_structures(A), | ||
) | ||
|
||
def compute( | ||
self, | ||
state: _WoodburyState, | ||
vector, | ||
options, | ||
) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: | ||
( | ||
(C, U, V), | ||
(A_solver, A_state, pt_solver, pt_state), | ||
A_packed_structures, | ||
) = state | ||
del state, options | ||
vector = ravel_vector(vector, A_packed_structures) | ||
|
||
# Solution to A x = b | ||
# [0] selects the solution vector | ||
x_1 = A_solver.compute(A_state, vector, {})[0] | ||
# Push through U ( C^-1 + V A^-1 U)^-1 V (A^-1 b) | ||
# [0] selects the solution vector | ||
x_pushthrough = U @ pt_solver.compute(pt_state, V @ x_1, {})[0] | ||
# A^-1 on result of push through | ||
# [0] selects the solution vector | ||
x_2 = A_solver.compute(A_state, x_pushthrough, {})[0] | ||
# See https://en.wikipedia.org/wiki/Woodbury_matrix_identity | ||
solution = x_1 - x_2 | ||
|
||
solution = unravel_solution(solution, A_packed_structures) | ||
return solution, RESULTS.successful, {} | ||
|
||
def transpose(self, state: _WoodburyState, options: dict[str, Any]): | ||
( | ||
(C, U, V), | ||
(A_solver, A_state, pt_solver, pt_state), | ||
A_packed_structures, | ||
) = state | ||
transposed_packed_structures = transpose_packed_structures(A_packed_structures) | ||
C = jnp.transpose(C) | ||
U = jnp.transpose(V) | ||
V = jnp.transpose(U) | ||
A_state, _ = A_solver.transpose(A_state, {}) | ||
pt_solver, pt_state = _compute_pushthrough(A_solver, A_state, C, U, V) | ||
transpose_state = ( | ||
(C, U, V), | ||
(A_solver, A_state, pt_solver, pt_state), | ||
transposed_packed_structures, | ||
) | ||
return transpose_state, options | ||
|
||
def conj(self, state: _WoodburyState, options: dict[str, Any]): | ||
( | ||
(C, U, V), | ||
(A_solver, A_state, pt_solver, pt_state), | ||
packed_structures, | ||
) = state | ||
C = jnp.conj(C) | ||
U = jnp.conj(U) | ||
V = jnp.conj(V) | ||
A_state, _ = A_solver.conj(A_state, {}) | ||
pt_solver, pt_state = _compute_pushthrough(A_solver, A_state, C, U, V) | ||
conj_state = ( | ||
(C, U, V), | ||
(A_solver, A_state, pt_solver, pt_state), | ||
packed_structures, | ||
) | ||
return conj_state, options | ||
|
||
def allow_dependent_columns(self, operator): | ||
return False | ||
|
||
def allow_dependent_rows(self, operator): | ||
return False |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
N
is an arbitrary PyTree, it doesn't necessarily have a.shape
attribute.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, This is mostly my ignorance of how lineax is using PyTrees as linear operators.
I am not sure how allowing arbitrary PyTrees for A meshes with the requirement of a Woodbury structure. In the context of the Woodbury matrix identity, I would think A would have to have a matrix representation of an n by n matrix. For a PyTree representation then would C, U and V also need to be PyTrees such that each leaf of the tree can be made to have Woodbury structure?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right, so! Basically what's going on here is exactly what
jax.jacfwd
does as well.Consider for example:
What should we get?
Well, a PyTree-of-arrays is basically isomorphic to a vector (flatten every array and concatenate them all together), and the Jacobian of a function
f: R^n -> R^m
is a matrix of shape(m, n)
.Reasoning by analogy, we can see that:
i
, and for which each leaf has shapea_i
(a tuple);j
, and for which each leaf has shapeb_j
(also a tuple);then the Jacobian should be a PyTree whose leaves are numerated by
(j, i)
(a PyTree-of-PyTrees if you will), where each leaf has shape(*b_j, *a_i)
. (Here unpacking each tuple using Python notation.)And indeed this is exactly what we see:
the "outer" PyTree has structure
{'a': *, 'b': *}
(corresponding to the output of our function), the "inner" PyTree has structure(*, *)
(corresponding to the input of our function).Meanwhile, each leaf has a shape obtained by concatenating the shapes of the corresponding pair of input and output leaves. For all possible pairs, notably! So in our use case here we wouldn't have a pytree-of-things-with-Woodbury-structure. Rather, we would have a single PyTree, which when thought of as a linear operator (much like the Jacobian), would itself have Woodbury structure!
Okay, hopefully that makes some kind of sense!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just thought I'd response on this since it has been a bit - thanks for the detailed response, it makes sense, I am (slowly) working on changes that would make the Woodbury implementation pytree compatible.
I think the size checking is easy enough if I have understood you correctly here (PyTree-of-arrays is basically isomorphic to a vector (flatten every array and concatenate them all together). When checking U and V we need to use the in_size of A and C for N and K.
There will need to be a few tree_unflatten's to move between the flattened space (where U and V live) and the pytree input space (where A and C potentially live). This makes the pushthrough operator a bit tricky but should be do-able with a little time.
I suppose my question would be, is there a nice way to wrap this interlink between flattened vector space and pytree space so that implementing this kind of thing will be easier in the future? Does it already exist somewhere outside (or potentially inside) lineax?