add stable forward/inverse memory efficient Wigner transforms #238
This PR extends the Risbo precompute Wigner transform to a setting where we instead ONLY compute the Wigner d coefficients at $\theta = \beta = \pi/2$, which we denote $\Delta^l_{mn} = d^l_{mn}(\pi/2)$. The coefficients for colatitudes other than $\pi/2$ are accessed implicitly at run-time through the Fourier relation outlined in McEwen et al. 2016, with further details in McEwen et al. 2011.
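To make the Fourier relation concrete, here is a small standalone NumPy check (illustrative only, not the PR's code) of the identity $d^l_{mn}(\beta) = i^{n-m} \sum_{m'} \Delta^l_{m'm}\, \Delta^l_{m'n}\, e^{i m' \beta}$, verified against the closed-form $d^1$ matrix; the function `d1` and all shapes here are for this demo only:

```python
# Illustrative check of the pi/2 Fourier relation for l = 1:
#   d^l_{mn}(beta) = i^(n-m) * sum_{m'} Delta_{m'm} Delta_{m'n} e^{i m' beta}
# where Delta = d^l(pi/2) is the only precomputed quantity.
import numpy as np

def d1(beta):
    """Closed-form Wigner d^1 matrix, rows/cols indexed by m, n = -1, 0, 1."""
    c, s = np.cos(beta), np.sin(beta)
    r = 1.0 / np.sqrt(2.0)
    return np.array([
        [(1 + c) / 2,  s * r, (1 - c) / 2],
        [-s * r,       c,      s * r],
        [(1 - c) / 2, -s * r, (1 + c) / 2],
    ])

beta = 0.7
Delta = d1(np.pi / 2)            # the only precomputed quantity
ms = np.arange(-1, 2)            # m' values -1, 0, 1
phase = np.exp(1j * ms * beta)   # e^{i m' beta}

# sum over m' (first axis), then apply the i^(n-m) phase factor
recon = np.einsum("pm,pn,p->mn", Delta, Delta, phase)
recon = recon * 1j ** (ms[None, :] - ms[:, None])

assert np.allclose(recon.real, d1(beta), atol=1e-12)
assert np.allclose(recon.imag, 0.0, atol=1e-12)
```

Evaluating the $e^{i m' \beta}$ sum for all rings at once is where the extra run-time FFT mentioned below comes from.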
Ultimately this algorithm costs an extra FFT at run-time, but reduces the memory overhead from $\mathcal{O}(NL^3)\sim\mathcal{O}(L^4)$ to $\mathcal{O}(L^3)$. A further factor-of-8 reduction in memory can be achieved by exploiting the various symmetries of $\Delta^l_{mn}$; however, this has yet to be implemented.
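For a rough sense of scale, a back-of-the-envelope estimate (an illustration under the assumption of $N \sim L$ rings and float64 storage, not a measured benchmark):

```python
# Rough memory estimate for the two precompute strategies, assuming
# N ~ L colatitude rings and 8-byte (float64) storage throughout.
L = 512
bytes_per_value = 8
full = L * L**3 * bytes_per_value       # O(N L^3) ~ O(L^4): every ring stored
pi_half_only = L**3 * bytes_per_value   # O(L^3): a single Delta^l_{mn} set

print(f"full precompute  : {full / 1e9:.1f} GB")
print(f"pi/2-only variant: {pi_half_only / 1e9:.1f} GB")
# with the yet-to-be-implemented symmetry exploitation, a further /8:
print(f"with symmetries  : {pi_half_only / 8 / 1e9:.2f} GB")
```

At $L = 512$ the switch is roughly the difference between hundreds of gigabytes and about a gigabyte.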
As the Risbo recursion is extremely stable, we can also consider reduced-precision computation (f32, f16, f8), which should recover errors of order $10^{-8}$, $10^{-4}$, and $10^{-2}$ respectively, with a corresponding memory reduction and acceleration by factors of 2, 4, and 8.
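A quick illustration (not a test of the recursion itself) of where those error floors come from: rounding float64 values to a narrower type introduces a relative error on the order of that type's machine epsilon. NumPy has no f8 dtype, so only f32 and f16 are shown:

```python
# Relative rounding error introduced by casting float64 down to narrower
# types; values are kept in [0.5, 2.0] so no subnormals are involved.
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(0.5, 2.0, 10_000)

errs = {}
for dtype in (np.float32, np.float16):
    roundtrip = x.astype(dtype).astype(np.float64)
    errs[dtype] = np.max(np.abs(x - roundtrip) / np.abs(x))
    print(f"{dtype.__name__}: max relative error ~ {errs[dtype]:.1e}")
```

The f32 floor sits near $10^{-7}$–$10^{-8}$ and the f16 floor near $10^{-4}$–$10^{-3}$, broadly consistent with the orders quoted above.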
> [!TIP]
> When building flax layers for $SO(3)$ convolutions with this implementation, the dominant memory overhead is independent of the number of channels, so the memory requirement for the precompute transform is fixed.
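A minimal sketch of why the precompute is channel-independent (hypothetical shapes and names, not the PR's API): the shared $\Delta^l_{mn}$ tensor is contracted against the per-channel signal, so widening the channel axis grows only the activations, never the precomputed array:

```python
# The precomputed `delta` tensor is contracted against every channel at
# once; its size depends only on the band-limit L, not on `channels`.
# Shapes here are purely illustrative.
import numpy as np

L, channels = 8, 16
delta = np.random.default_rng(0).standard_normal((L, 2 * L - 1, 2 * L - 1))
signal = np.random.default_rng(1).standard_normal((channels, L, 2 * L - 1))

out = np.einsum("lpm,clp->clm", delta, signal)
print(out.shape)  # (channels, L, 2L - 1)
```

Doubling `channels` doubles `signal` and `out`, while `delta` (the dominant cost at large $L$) is untouched.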
> [!NOTE]
> With both the memory reductions from the aforementioned symmetry relations and from the switched algorithm, the precompute algorithm should be runnable at $L \sim 2048$ and potentially slightly beyond. However, this has yet to be tested explicitly.