Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Experimental] eSCN representation #60

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 62 additions & 2 deletions cuequivariance/cuequivariance/experimental/escn.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,69 @@ def escn_tp_compact(


class SphericalSignal(cue.Rep):
def __init__(self, mul: int, l_max: int, m_max: int):
"""Representation of a signal on the sphere.

Args:
mul (int): Multiplicity of the signal. The multiplicity is always the innermost dimension (stride 1).
l_max (int): Maximum angular resolution.
m_max (int): Maximum angular resolution around the principal axis.
primary (str): "m" or "l".
If "m", all the components with the same m are contiguous in memory.
If "l", all the components with the same l are contiguous in memory.


primary="m":
l=
0 1 2 3 m
+-+-+ =
|0|1| -2
+-+-+-+
|2|3|4| -1
+-+-+-+-+
|5|6|7|8| 0
+-+-+-+-+
|9|a|b| 1
+-+-+-+
|c|d| 2
+-+-+

primary="l":
l= +-+
0 |0|
+-+-+-+
1 |1|2|3|
+-+-+-+-+-+
2 |4|5|6|7|8|
+-+-+-+-+-+
3 |9|a|b|c|d|
+-+-+-+-+-+
m= -2-1 0 1 2
"""

def __init__(self, mul: int, l_max: int, m_max: int, primary: str):
self.mul = mul
self.l_max = l_max
self.m_max = m_max
self.primary = primary

def _dim(self):
d = 0
for ell in range(self.l_max + 1):
m_max = min(ell, self.m_max)
d += (2 * m_max + 1) * self.mul
return d

def algebra(self=None) -> np.ndarray:
# note: shall we make an SO2 representation only since m_max can be smaller than l_max?
return cue.SO3.algebra()

def continuous_generators(self) -> np.ndarray:
# note: if m_max is smaller than l_max, this is actually not a full representation of SO3
# but it's a representation of the subgroup SO3 along the principal axis
raise NotImplementedError

def discrete_generators(self) -> np.ndarray:
return np.zeros((0, self.dim, self.dim))

# TODO
def trivial(self):
raise NotImplementedError
Loading