Batched equivariant maps basis expansion (?) #76

Danfoa opened this issue Aug 24, 2023 · 4 comments

Danfoa commented Aug 24, 2023

Hi @Gabri95,

Do you see any easy way to enable a batched basis expansion of equivariant maps? Let me explain what I mean.

The construction of a linear equivariant map T from an array of weights w (whose dimension equals the dimension of the basis of T) seems to be tailored to a single weight vector and a single resulting equivariant linear map. This is perfectly suitable for building equivariant linear layers.

However, it is not suitable for parametrically building several equivariant linear maps from a batched collection of weights of shape (batch, dim(w)), resulting in a batch of equivariant linear operators. I tried my best to understand the code and devise a way to do this, but with the current implementation it seems rather tricky.

With the EMLP library this was possible by computing the nullspace projector matrix Q of shape [n*n, basis_dim], which can be used to map several weight vectors to their corresponding equivariant linear matrices via T = reshape(Q w). That process has an immense memory cost (because of the n*n factor, n being the dimension of T, assuming a square T). I understand your approach elegantly avoids this memory problem. Can you think of a way of making a batched version of your basis expansion?
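
For reference, a minimal sketch of the EMLP-style batched expansion described above (the names, shapes, and random tensors are purely illustrative stand-ins, not EMLP's or escnn's API); materializing the dense Q is exactly where the O(n^2 * basis_dim) memory cost comes from:

```python
import torch

n, d, batch = 32, 10, 256          # map size, basis dimension, batch size (illustrative)

# Dense basis of the space of equivariant maps, one column per basis element.
# Materializing this (n*n, d) matrix is what makes the approach memory-hungry.
Q = torch.randn(n * n, d)          # stand-in for the real nullspace basis

w = torch.randn(batch, d)          # one weight vector per desired map

# Project all weight vectors at once and reshape into a batch of linear maps.
T = (w @ Q.T).reshape(batch, n, n)         # (batch, n, n)

# Apply each map to its corresponding input vector.
x = torch.randn(batch, n)
y = torch.einsum('bij,bj->bi', T, x)       # (batch, n)
```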

Gabri95 commented Aug 24, 2023

Hi @Danfoa,

The single-block basis expansion and sampler classes could be used for that.
That's actually what I also do internally in the BlockBasisSampler class, for example.

The external interface of the library (via the conv layers) does not directly support this, though.
Could you maybe provide a more detailed example of what you'd like to do, so I can suggest something more concrete or try to write some example code?

For instance, do you need to compute a number of convolution kernels for an RdConv, or do you want to run multiple RdPointConv modules in parallel? Or are you only interested in Linear layers?

Best,
Gabriele

Danfoa commented Aug 24, 2023

This sounds amazing, thanks for the help!

Let me describe my application case.

TL;DR: I want to construct multiple equivariant linear maps T of shape [n x n]. We know that the basis of such maps has dimension d. I don't want to learn a single fixed map; instead:

  1. I want to learn a function T(.): X -> R^d that parameterizes the linear maps T(x) \in R^(n x n) as a function of the input x of shape (batch, |x|).
  2. The output of the network, of shape (batch, d), will be used to parameterize batch distinct equivariant maps, resulting in a tensor of shape (batch, n, n).
  3. Then, I would like to apply each linear map to its corresponding input vector.

More details: I am learning equivariant dynamical systems with transition operators. The nice thing about this approach is that if you find the appropriate non-linear change of coordinates x = f(z), the dynamics of your system become linear, dx/dt = T(x) x with T(x) \in R^(n x n), instead of the potentially non-linear dynamics of z. Here, think of z as the state of your dynamical system (e.g., position and momentum) and x as a new "observable" state (e.g., a set of relevant functions of z, such as energy, polynomials, etc.). For equivariant systems, T(x) needs to be an equivariant linear map, and this is where I need to learn the function T(.): X -> R^d, where d is the dimension of the space of equivariant endomorphisms X -> X. That is why your basis expansion has become so useful to me.
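
To make the intended pipeline concrete, here is a minimal sketch of the use case described above (everything here is hypothetical and illustrative: the basis tensor stands in for whatever the library's basis expansion produces, and coeff_net is just a placeholder network):

```python
import torch
import torch.nn as nn

n, d = 32, 10                        # map size and equivariant-basis dimension (illustrative)

# Stand-in for the d basis elements of the space of equivariant maps, each n x n.
# In practice these would come from the library's basis expansion, not from randn.
basis = torch.randn(d, n, n)

# Network predicting the d basis coefficients from the observable state x.
coeff_net = nn.Sequential(nn.Linear(n, 64), nn.ReLU(), nn.Linear(64, d))

x = torch.randn(256, n)                           # batch of observable states

w = coeff_net(x)                                  # (batch, d) coefficients
T = torch.einsum('bd,dij->bij', w, basis)         # (batch, n, n) equivariant maps
dxdt = torch.einsum('bij,bj->bi', T, x)           # apply each T(x) to its own state x
```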

Gabri95 commented Aug 25, 2023

Hi @Danfoa

That sounds like a really cool application!

So, if you know in advance the size of batch, the simplest strategy you can use now is to generate a linear map of shape (batch*n) x n, and then reshape it into (batch, n, n).
You can just use a BlockBasisExpansion for expanding these weights.
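
Roughly, this workaround amounts to the sketch below (hypothetical code that emulates the basis expansion with an explicit basis tensor rather than calling the actual BlockBasisExpansion constructor): since the target space is batch copies of the n-dimensional space, the big map's basis consists of batch copies of the small one, so the flat weight vector has length batch*d and the expanded (batch*n) x n matrix simply reshapes into a batch of maps.

```python
import torch

n, d, batch = 32, 10, 8            # illustrative sizes

# Stand-in for the d basis elements of the small n x n equivariant-map space.
basis = torch.randn(d, n, n)

def expand_big_map(weights_flat):
    # (batch*d,) -> (batch, d) -> one n x n block per copy -> ((batch*n), n)
    w = weights_flat.reshape(batch, d)
    blocks = torch.einsum('bd,dij->bij', w, basis)   # (batch, n, n)
    return blocks.reshape(batch * n, n)

w = torch.randn(batch, d)
big = expand_big_map(w.reshape(-1))        # the (batch*n) x n map suggested above
T = big.reshape(batch, n, n)               # batch of independent equivariant maps

# NOTE: with the real BlockBasisExpansion, the ordering of basis elements across
# the repeated copies would need to be checked so that this reshape is consistent.
```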

I can make something a bit more flexible to achieve exactly what you want by removing this assert and just using the last dimension of weights.
I am not sure I have time to implement it properly right now, but you could try that yourself and open a PR maybe?

Danfoa commented Aug 25, 2023

I can certainly try @Gabri95,

> So, if you know in advance the size of batch, the simplest strategy you can use now is to generate a linear map of shape (batch*n) x n, and then reshape it into (batch, n, n).
> You can just use a BlockBasisExpansion for expanding these weights.

I know the batch dimension, but I am a bit unsure about how to interact with the BlockBasisExpansion. The code is a bit hard to digest without investing a large amount of time in it. Any hints?

> I can make something a bit more flexible to achieve exactly what you want by removing this assert and just using the last dimension of weights.
> I am not sure I have time to implement it properly right now, but you could try that yourself and open a PR maybe?

I will give it a try. I think I already see the problem. It should not be difficult.
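
For what it's worth, the relaxed interface being discussed would amount to something like the sketch below (purely illustrative, not the library's actual code): instead of requiring a flat weight vector, only the last axis has to match the basis dimension, so any leading batch dimensions pass through.

```python
import torch

def expand(weights, basis):
    """Expand basis coefficients into equivariant matrices.

    weights: (..., d)  - any number of leading (e.g. batch) dimensions
    basis:   (d, n, m) - stacked basis elements
    returns: (..., n, m)
    """
    # Relaxed check: only the last dimension must match the basis dimension,
    # rather than asserting that weights is a 1-dimensional vector.
    assert weights.shape[-1] == basis.shape[0]
    return torch.einsum('...d,dnm->...nm', weights, basis)
```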
