From ef4d87a75614cd02f762a28ca06e3fd6a5c3fd0f Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 10 Jan 2025 13:58:01 +0100 Subject: [PATCH 1/2] implement some methods in SphericalSignal (WIP) --- .../cuequivariance/experimental/escn.py | 62 ++++++++++++++++++- 1 file changed, 60 insertions(+), 2 deletions(-) diff --git a/cuequivariance/cuequivariance/experimental/escn.py b/cuequivariance/cuequivariance/experimental/escn.py index c08b0f6..1a0f990 100644 --- a/cuequivariance/cuequivariance/experimental/escn.py +++ b/cuequivariance/cuequivariance/experimental/escn.py @@ -182,9 +182,67 @@ def escn_tp_compact( class SphericalSignal(cue.Rep): - def __init__(self, mul: int, l_max: int, m_max: int): + """Representation of a signal on the sphere. + + Args: + mul (int): Multiplicity of the signal. The multiplicity is always the innermost dimension (stride 1). + l_max (int): Maximum angular resolution. + m_max (int): Maximum angular resolution around the principal axis. + primary (str): "m" or "l". + If "m", all the components with the same m are contiguous in memory. + If "l", all the components with the same l are contiguous in memory. + + + primary="m": + l= + 0 1 2 m + +-+ = + |0| -2 + +-+-+ + |1|2| -1 + +-+-+-+ + |3|4|5| 0 + +-+-+-+ + |6|7| 1 + +-+-+ + |8| 2 + +-+ + + primary="l": + l= +-+ + 0 |0| + +-+-+-+ + 1 |1|2|3| + +-+-+-+-+-+ + 2 |4|5|6|7|8| + +-+-+-+-+-+ + m= -2-1 0 1 2 + """ + + def __init__(self, mul: int, l_max: int, m_max: int, primary: str): self.mul = mul self.l_max = l_max self.m_max = m_max + self.primary = primary + + def _dim(self): + d = 0 + for ell in range(self.l_max + 1): + m_max = min(ell, self.m_max) + d += (2 * m_max + 1) * self.mul + return d + + def algebra(self=None) -> np.ndarray: + # note: shall we make an SO2 representation only since m_max can be smaller than l_max? + return cue.SO3.algebra() + + def continuous_generators(self) -> np.ndarray: + # note: if m_max is smaller than l_max, this is actually not a full representation of SO3 + # but it's a representation of the subgroup SO3 along the principal axis + raise NotImplementedError + + def discrete_generators(self) -> np.ndarray: + return np.zeros((0, self.dim, self.dim)) - # TODO + def trivial(self): + raise NotImplementedError From e0b9ea5e02be931f947fb4950043ce8574bd1135 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 10 Jan 2025 15:12:37 +0100 Subject: [PATCH 2/2] Update SphericalSignal diagram for clarity and accuracy --- .../cuequivariance/experimental/escn.py | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/cuequivariance/cuequivariance/experimental/escn.py b/cuequivariance/cuequivariance/experimental/escn.py index 1a0f990..a623381 100644 --- a/cuequivariance/cuequivariance/experimental/escn.py +++ b/cuequivariance/cuequivariance/experimental/escn.py @@ -195,18 +195,18 @@ class SphericalSignal(cue.Rep): primary="m": l= - 0 1 2 m - +-+ = - |0| -2 - +-+-+ - |1|2| -1 - +-+-+-+ - |3|4|5| 0 - +-+-+-+ - |6|7| 1 - +-+-+ - |8| 2 - +-+ + 0 1 2 3 m + +-+-+ = + |0|1| -2 + +-+-+-+ + |2|3|4| -1 + +-+-+-+-+ + |5|6|7|8| 0 + +-+-+-+-+ + |9|a|b| 1 + +-+-+-+ + |c|d| 2 + +-+-+ primary="l": l= +-+ @@ -216,6 +216,8 @@ class SphericalSignal(cue.Rep): +-+-+-+-+-+ 2 |4|5|6|7|8| +-+-+-+-+-+ + 3 |9|a|b|c|d| + +-+-+-+-+-+ m= -2-1 0 1 2 """