
New solver based on Woodbury matrix identity #97

Open: wants to merge 4 commits into main
Conversation

aidancrilly

In issue #3, a solver based on the Woodbury matrix identity was suggested as a new solver feature for lineax.
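
For reference, the Woodbury matrix identity in question is

$$(A + U C V)^{-1} = A^{-1} - A^{-1} U \left(C^{-1} + V A^{-1} U\right)^{-1} V A^{-1},$$

where A is n-by-n, C is k-by-k, U is n-by-k and V is k-by-n, so a solve against A + U C V reduces to solves against A plus a small k-by-k linear solve.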

Here I have made a first attempt at implementing it. It works slightly differently from the other base LinearOperators, as it takes a LinearOperator as an argument (for A), with U, C and V as JAX arrays. This allows the correct specialised solvers to be used for inverse(A) operations, e.g. A can be a DiagonalLinearOperator, TridiagonalLinearOperator, etc. - see test_Woodbury in test_operator.py.
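
For concreteness, a minimal usage sketch of this design. The names WoodburyLinearOperator and Woodbury below are assumptions based on this PR's files (lineax/_solver/Woodbury.py, test_Woodbury) and may not match the final API; DiagonalLinearOperator and linear_solve are existing lineax API.

import jax.numpy as jnp
import jax.random as jr
import lineax as lx

# A + U C V, with A a structured lineax operator and U, C, V dense JAX arrays.
N, K = 10, 2
A = lx.DiagonalLinearOperator(jr.normal(jr.PRNGKey(0), (N,)) + 2.0)
U = jr.normal(jr.PRNGKey(1), (N, K))
C = jnp.eye(K)
V = jr.normal(jr.PRNGKey(2), (K, N))
b = jr.normal(jr.PRNGKey(3), (N,))

# Proposed API from this PR (names assumed, not yet in lineax):
# WB = WoodburyLinearOperator(A, U, C, V)
# x = lx.linear_solve(WB, b, solver=Woodbury()).value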

Things I am not certain of:

  1. Is passing solvers and states around in the SolverState of Woodbury the best approach? It seems to me to minimise re-initialising solvers for multiple RHS, but perhaps there is a neater solution.
  2. I have not included a tag for Woodbury; it seems difficult to identify whether matrices have this structure, so it seems to me it should be left to the user. However, I am not sure if I have handled the operation-type linear operators (Add, Mul, Div, etc.) correctly in this instance.
  3. I had to add a pyright ignore in _solve.py under _linear_solve_transpose for allow_struct, and I am not sure why:
ft.partial(eqxi.materialise_zeros, allow_struct=True), # pyright: ignore

This also made me notice that materialise_zeros is from equinox 0.11.4, but pyproject.toml only requires >=0.11.3. Does this need correcting?

patrick-kidger (Owner) left a comment:

Awesome, thank you for putting this together! I really like the look of this.

I've just left an initial review -- let me know what you think :)

lineax/_operator.py: two outdated review threads (resolved)

(K, L) = self.C.shape
if K != L:
    raise ValueError(f"expecting square operator for C, got {K} by {L}")
N = N.shape[0]

patrick-kidger (Owner):

N is an arbitrary PyTree; it doesn't necessarily have a .shape attribute.

aidancrilly (Author):

Ok, this is mostly my ignorance of how lineax uses PyTrees as linear operators.

I am not sure how allowing arbitrary PyTrees for A meshes with the requirement of Woodbury structure. In the context of the Woodbury matrix identity, I would think A would have to have a representation as an n by n matrix. For a PyTree representation, would C, U and V also need to be PyTrees, such that each leaf of the tree can be made to have Woodbury structure?

patrick-kidger (Owner) commented on Jun 10, 2024:

Right, so! Basically what's going on here is exactly what jax.jacfwd does as well.
Consider for example:

import jax
import jax.numpy as jnp

def f(x: tuple[jax.Array, jax.Array]):
    return {"a": x[0] * x[1], "b": x[0]}

x = (jnp.arange(2.)[:, None], jnp.arange(3.)[None, :])
jac = jax.jacfwd(f)(x)

What should we get?

Well, a PyTree-of-arrays is basically isomorphic to a vector (flatten every array and concatenate them all together), and the Jacobian of a function f: R^n -> R^m is a matrix of shape (m, n).
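
As a minimal illustration of that isomorphism (this uses jax.flatten_util.ravel_pytree purely for demonstration; it is not part of this PR):

import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

x = (jnp.arange(2.0)[:, None], jnp.arange(3.0)[None, :])  # leaves of shape (2, 1) and (1, 3)
flat, unravel = ravel_pytree(x)
print(flat.shape)              # (5,): every leaf flattened and concatenated
print(unravel(flat)[1].shape)  # (1, 3): and we can map back to the original PyTree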

Reasoning by analogy, we can see that:

  • given an input PyTree whose leaves are enumerated by i, and for which each leaf has shape a_i (a tuple);
  • given an output PyTree whose leaves are enumerated by j, and for which each leaf has shape b_j (also a tuple);
then the Jacobian should be a PyTree whose leaves are enumerated by (j, i) (a PyTree-of-PyTrees, if you will), where each leaf has shape (*b_j, *a_i). (Here unpacking each tuple using Python notation.)

And indeed this is exactly what we see:

import equinox as eqx
eqx.tree_pprint(jac)
# {'a': (f32[2,3,2,1], f32[2,3,1,3]), 'b': (f32[2,1,2,1], f32[2,1,1,3])}

The "outer" PyTree has structure {'a': *, 'b': *} (corresponding to the output of our function); the "inner" PyTree has structure (*, *) (corresponding to the input of our function).

Meanwhile, each leaf has a shape obtained by concatenating the shapes of the corresponding pair of input and output leaves. For all possible pairs, notably! So in our use case here we wouldn't have a pytree-of-things-with-Woodbury-structure. Rather, we would have a single PyTree, which when thought of as a linear operator (much like the Jacobian), would itself have Woodbury structure!


Okay, hopefully that makes some kind of sense!

aidancrilly (Author):

Just thought I'd respond on this since it has been a while - thanks for the detailed response, it makes sense. I am (slowly) working on changes that would make the Woodbury implementation PyTree compatible.

I think the size checking is easy enough if I have understood you correctly here (a PyTree-of-arrays is basically isomorphic to a vector: flatten every array and concatenate them all together). When checking U and V, we need to use the in_size of A and of C for N and K.
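
For concreteness, a minimal sketch of that size check, assuming U and V are dense arrays and using the in_size() method that lineax operators expose (check_woodbury_shapes is a hypothetical helper name):

def check_woodbury_shapes(A, C, U, V):
    # A: lineax AbstractLinearOperator; C, U, V: dense JAX arrays.
    N = A.in_size()  # size of A's flattened input space
    (K, L) = C.shape
    if K != L:
        raise ValueError(f"expecting square matrix for C, got {K} by {L}")
    if U.shape != (N, K):
        raise ValueError(f"expecting U of shape ({N}, {K}), got {U.shape}")
    if V.shape != (K, N):
        raise ValueError(f"expecting V of shape ({K}, {N}), got {V.shape}")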

There will need to be a few tree_unflattens to move between the flattened space (where U and V live) and the PyTree input space (where A and C potentially live). This makes the pushthrough operator a bit tricky, but it should be doable with a little time.

I suppose my question would be: is there a nice way to wrap this interlink between the flattened vector space and the PyTree space, so that implementing this kind of thing will be easier in the future? Does it already exist somewhere outside (or potentially inside) lineax?

lineax/_operator.py: two outdated review threads (resolved)
lineax/_solver/Woodbury.py: three outdated review threads (resolved)

vmapped_solve = jax.vmap(
    lambda x_vec: A_solver.compute(A_state, x_vec, {})[0], in_axes=1, out_axes=1
)
pushthrough_mat = jnp.linalg.inv(C) + V @ vmapped_solve(U)
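
For context, a hedged sketch of how the remaining solve might use this pushthrough matrix, following (A + U C V)^{-1} b = A^{-1} b - A^{-1} U (C^{-1} + V A^{-1} U)^{-1} V A^{-1} b. Variable names mirror the snippet above; this is illustrative, not the PR's exact code:

Ainv_b = A_solver.compute(A_state, b, {})[0]               # A^{-1} b, reusing the stored solver/state
small_rhs = V @ Ainv_b                                     # V A^{-1} b, a length-K vector
correction = U @ jnp.linalg.solve(pushthrough_mat, small_rhs)
x = Ainv_b - A_solver.compute(A_state, correction, {})[0]  # subtract A^{-1} of the low-rank correction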

patrick-kidger (Owner):

Should we make C an AbstractLinearOperator as well?

aidancrilly (Author):

I originally had C as an AbstractLinearOperator, but it always had to be converted with as_matrix() so that its inverse could be computed by jnp.linalg. Therefore, it made the most sense for it to be given as a matrix at input, so that the as_matrix operation is not hidden. If some linear operators could have an inverse method, then perhaps it would make sense for C to be an operator?

patrick-kidger (Owner) commented on Jun 10, 2024:

Ah, right!

Yes, you're completely correct. Indeed I think the appropriate thing to do would be to have a built-in notion of the inverse of an operator. (Tagging also @f0uriest as we were discussing this in #96.)

The inverse of an operator is another operator, and it is basically defined by the action of a linear solve. I think that should mean:

class InverseLinearOperator(AbstractLinearOperator):
    operator: AbstractLinearOperator
    solver: AbstractLinearSolver = AutoLinearSolver(well_posed=True)

    def mv(self, vector):
        # linear_solve returns a Solution; take .value to get the output vector.
        return linear_solve(self.operator, vector, self.solver).value

    ...  # other methods here

Which you can then use as InverseLinearOperator(C).
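
As a hypothetical usage example built on the sketch above (MatrixLinearOperator is existing lineax API; InverseLinearOperator is the class sketched above and not yet part of lineax):

import jax.numpy as jnp
import lineax as lx

C = jnp.array([[2.0, 0.0], [0.0, 4.0]])
v = jnp.array([1.0, 1.0])

C_op = lx.MatrixLinearOperator(C)
C_inv = InverseLinearOperator(C_op)  # acts as C^{-1}, defined via a linear solve
y = C_inv.mv(v)                      # equivalent to solving C y = v, i.e. [0.5, 0.25]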

aidancrilly (Author):

Ok, this seems fair. I tried to implement it this way and just wanted to raise two points which you might have comments on.

  1. The InverseLinearOperator has to go in _solve.py (or at least not in _operator.py) to avoid circular imports. This is potentially a little confusing, but maybe it is really the right place for it.
  2. Having no conjugate method for LinearOperators makes for a bit of a headache here. conj is attached to the solver, so you can update the state of the solver for a conjugate inside InverseLinearOperator (and similarly for transpose). However, when you do it this way, the operator itself is not updated. The most obvious way to implement the as_matrix method was to use jnp.linalg.inv(self.operator.as_matrix()) - but this ignores the solver_state, which might have been transposed or conjugated. I suppose my question is: should LinearOperators have a conjugate method, or should LinearSolvers have an as_matrix method for the inverse?

This probably should be its own issue, PR, etc.

lineax/_solver/Woodbury.py: outdated review thread (resolved)

full_matrix = WB.as_matrix()

true_x = jr.normal(getkey(), (N,))

A contributor commented:

Suggested change:
- true_x = jr.normal(getkey(), (N,))
+ true_x = jr.normal(getkey(), (N,), dtype=dtype)

The same applies in two other places too, and then you can add complex128 to the dtype parametrization.
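
For what it's worth, a hedged sketch of what that parametrization could look like (the test name is illustrative; getkey is the fixture already used above, and float64/complex128 assume the test suite enables jax_enable_x64):

import jax.numpy as jnp
import jax.random as jr
import pytest

@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_woodbury_dtypes(getkey, dtype):  # hypothetical test name
    N = 5
    true_x = jr.normal(getkey(), (N,), dtype=dtype)
    assert true_x.dtype == dtype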
