Skip to content

Commit

Permalink
Merge pull request #275 from DedalusProject/cylinder
Browse files Browse the repository at this point in the history
DirectProduct bases
  • Loading branch information
kburns authored Dec 22, 2023
2 parents d9d19dc + f7a9f2f commit 527969c
Show file tree
Hide file tree
Showing 38 changed files with 2,376 additions and 2,203 deletions.
8 changes: 5 additions & 3 deletions dedalus/core/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,9 +569,11 @@ def Gamma(self, A_tensorsig, B_tensorsig, C_tensorsig, A_group, B_group, C_group
G = self.Gamma(A_tensorsig, B_tensorsig, C_tensorsig, A_group, B_group, C_group, axis-1)
# Apply Q
cs = self.dist.get_coordsystem(axis)
QA = cs.backward_intertwiner(axis, len(A_tensorsig), A_group).T
QB = cs.backward_intertwiner(axis, len(B_tensorsig), B_group).T
QC = cs.forward_intertwiner(axis, len(C_tensorsig), C_group)
cs_axis = self.dist.get_axis(cs)
subaxis = axis - cs_axis
QA = cs.backward_intertwiner(subaxis, len(A_tensorsig), A_group[cs_axis:]).T
QB = cs.backward_intertwiner(subaxis, len(B_tensorsig), B_group[cs_axis:]).T
QC = cs.forward_intertwiner(subaxis, len(C_tensorsig), C_group[cs_axis:])
Q = kron(QA, QB, QC)
G = (Q @ G.ravel()).reshape(G.shape)
return G
Expand Down
1,151 changes: 567 additions & 584 deletions dedalus/core/basis.py

Large diffs are not rendered by default.

200 changes: 127 additions & 73 deletions dedalus/core/coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
from ..libraries.dedalus_sphere import jacobi
from ..libraries import dedalus_sphere

from ..tools.array import nkron
from ..tools.cache import CachedMethod
from ..tools.array import nkron, sparse_block_diag
from ..tools.cache import CachedMethod, CachedAttribute

# Public interface
__all__ = ['Coordinate',
'DirectProduct',
'CartesianCoordinates',
'S2Coordinates',
'PolarCoordinates',
Expand All @@ -35,24 +36,34 @@ def __getitem__(self, key):
else:
return self.coords[key]

def set_distributor(self, distributor):
self.dist = distributor
for coord in self.coords:
coord.dist = distributor

def check_bounds(self, coord, bounds):
pass

@property
def first_axis(self):
return self.dist.coords.index(self.coords[0])
def forward_intertwiner(self, subaxis, order, group):
raise NotImplementedError("Subclasses must implement.")

def backward_intertwiner(self, subaxis, order, group):
raise NotImplementedError("Subclasses must implement.")


class SeparableIntertwiners:

def forward_vector_intertwiner(self, subaxis, group):
raise NotImplementedError("Subclasses must implement.")

@property
def axis(self):
return self.dist.coords.index(self.coords[0])
def backward_vector_intertwiner(self, subaxis, group):
raise NotImplementedError("Subclasses must implement.")

def forward_intertwiner(self, subaxis, order, group):
vector = self.forward_vector_intertwiner(subaxis, group)
return nkron(vector, order)

class Coordinate:
def backward_intertwiner(self, subaxis, order, group):
vector = self.backward_vector_intertwiner(subaxis, group)
return nkron(vector, order)


class Coordinate(SeparableIntertwiners):
dim = 1
default_nonconst_groups = (1,)
curvilinear = False
Expand All @@ -66,45 +77,107 @@ def __str__(self):
return self.name

def __eq__(self, other):
if self.name == other.name: return True
else: return False
if type(self) is type(other):
if self.name == other.name:
return True
return False

def __hash__(self):
return id(self)

@property
def axis(self):
return self.dist.coords.index(self)

def check_bounds(self, bounds):
if self.cs == None: return
else: self.cs.check_bounds(self, bounds)

def set_distributor(self, distributor):
self.dist = distributor
if self.cs:
self.cs.dist = distributor
def forward_vector_intertwiner(self, subaxis, group):
return np.array([[1]])

def backward_vector_intertwiner(self, subaxis, group):
return np.array([[1]])


class DirectProduct(SeparableIntertwiners, CoordinateSystem):

def __init__(self, *coordsystems, right_handed=None):
for cs in coordsystems:
if not isinstance(cs, SeparableIntertwiners):
raise NotImplementedError("Direct products only implemented for separable intertwiners.")
self.coordsystems = coordsystems
self.coords = sum((cs.coords for cs in coordsystems), ())
if len(set(self.coords)) < len(self.coords):
raise ValueError("Cannot repeat coordinates in DirectProduct.")
self.dim = sum(cs.dim for cs in coordsystems)
if self.dim == 3:
if self.curvilinear:
if right_handed is None:
right_handed = False
else:
if right_handed is None:
right_handed = True
self.right_handed = right_handed

@CachedAttribute
def subaxis_by_cs(self):
subaxis_dict = {}
subaxis = 0
for cs in self.coordsystems:
subaxis_dict[cs] = subaxis
subaxis += cs.dim
return subaxis_dict

@CachedAttribute
def curvilinear(self):
return any(cs.curvilinear for cs in self.coordsystems)

def forward_vector_intertwiner(self, subaxis, group):
factors = []
start_axis = 0
for cs in self.coordsystems:
if start_axis <= subaxis < start_axis + cs.dim:
factors.append(cs.forward_vector_intertwiner(subaxis-start_axis, group))
else:
factors.append(np.identity(cs.dim))
start_axis += cs.dim
return sparse_block_diag(factors).A

def backward_vector_intertwiner(self, subaxis, group):
factors = []
start_axis = 0
for cs in self.coordsystems:
if start_axis <= subaxis < start_axis + cs.dim:
factors.append(cs.backward_vector_intertwiner(subaxis-start_axis, group))
else:
factors.append(np.identity(cs.dim))
start_axis += cs.dim
return sparse_block_diag(factors).A

class CartesianCoordinates(CoordinateSystem):
@CachedAttribute
def default_nonconst_groups(self):
return sum((cs.default_nonconst_groups for cs in self.coordsystems), ())


class CartesianCoordinates(SeparableIntertwiners, CoordinateSystem):

curvilinear = False

def __init__(self, *names, right_handed=True):
if len(set(names)) < len(names):
raise ValueError("Must specify unique names.")
self.names = names
self.dim = len(names)
self.coords = tuple(Coordinate(name, cs=self) for name in names)
self.right_handed = right_handed
if self.dim == 3:
self.right_handed = right_handed
self.default_nonconst_groups = (1,) * self.dim

def __str__(self):
return '{' + ','.join([c.name for c in self.coords]) + '}'

def forward_intertwiner(self, axis, order, group):
return np.identity(self.dim**order)
def forward_vector_intertwiner(self, subaxis, group):
return np.identity(self.dim)

def backward_intertwiner(self, axis, order, group):
return np.identity(self.dim**order)
def backward_vector_intertwiner(self, subaxis, group):
return np.identity(self.dim)

@CachedMethod
def unit_vector_fields(self, dist):
Expand All @@ -125,7 +198,7 @@ class CurvilinearCoordinateSystem(CoordinateSystem):
curvilinear = True


class S2Coordinates(CurvilinearCoordinateSystem):
class S2Coordinates(SeparableIntertwiners, CurvilinearCoordinateSystem):
"""
S2 coordinate system: (azimuth, colatitude)
Coord component ordering: (azimuth, colatitude)
Expand All @@ -149,41 +222,37 @@ def _U_forward(cls, order):
Ui = {+1: np.array([+1j, 1]) / np.sqrt(2),
-1: np.array([-1j, 1]) / np.sqrt(2)}
U = np.array([Ui[spin] for spin in cls.spin_ordering])
return nkron(U, order)
if order > 1:
U = nkron(U, order)
return U

@classmethod
def _U_backward(cls, order):
"""Unitary transform from spin to coord components."""
return cls._U_forward(order).T.conj()

@property
def axis(self):
return self.azimuth.axis

def forward_intertwiner(self, axis, order, group):
subaxis = axis - self.axis
def forward_vector_intertwiner(self, subaxis, group):
if subaxis == 0:
# Azimuth intertwiner is identity, independent of group
return np.identity(self.dim**order)
return np.identity(self.dim)
elif subaxis == 1:
# Colatitude intertwiner is spin-U, independent of group
return self._U_forward(order)
return self._U_forward(1)
else:
raise ValueError("Invalid axis")

def backward_intertwiner(self, axis, order, group):
subaxis = axis - self.axis
def backward_vector_intertwiner(self, subaxis, group):
if subaxis == 0:
# Azimuth intertwiner is identity, independent of group
return np.identity(self.dim**order)
return np.identity(self.dim)
elif subaxis == 1:
# Colatitude intertwiner is spin-U, independent of group
return self._U_backward(order)
return self._U_backward(1)
else:
raise ValueError("Invalid axis")


class PolarCoordinates(CurvilinearCoordinateSystem):
class PolarCoordinates(SeparableIntertwiners, CurvilinearCoordinateSystem):
"""
Polar coordinate system: (azimuth, radius)
Coord component ordering: (azimuth, radius)
Expand All @@ -207,36 +276,32 @@ def _U_forward(cls, order):
Ui = {+1: np.array([+1j, 1]) / np.sqrt(2),
-1: np.array([-1j, 1]) / np.sqrt(2)}
U = np.array([Ui[spin] for spin in cls.spin_ordering])
return nkron(U, order)
if order > 1:
U = nkron(U, order)
return U

@classmethod
def _U_backward(cls, order):
"""Unitary transform from spin to coord components."""
return cls._U_forward(order).T.conj()

@property
def axis(self):
return self.azimuth.axis

def forward_intertwiner(self, axis, order, group):
subaxis = axis - self.axis
def forward_vector_intertwiner(self, subaxis, group):
if subaxis == 0:
# Azimuth intertwiner is identity, independent of group
return np.identity(self.dim**order)
return np.identity(self.dim)
elif subaxis == 1:
# Radial intertwiner is spin-U, independent of group
return self._U_forward(order)
return self._U_forward(1)
else:
raise ValueError("Invalid axis")

def backward_intertwiner(self, axis, order, group):
subaxis = axis - self.axis
def backward_vector_intertwiner(self, subaxis, group):
if subaxis == 0:
# Azimuth intertwiner is identity, independent of group
return np.identity(self.dim**order)
return np.identity(self.dim)
elif subaxis == 1:
# Radial intertwiner is spin-U, independent of group
return self._U_backward(order)
return self._U_backward(1)
else:
raise ValueError("Invalid axis")

Expand Down Expand Up @@ -296,10 +361,6 @@ def _Q_backward(cls, ell, order):
# This may not rebust to having spin and reg orderings be different?
return dedalus_sphere.spin_operators.Intertwiner(ell, indexing=cls.reg_ordering)(order)

@property
def axis(self):
return self.azimuth.axis

def check_bounds(self, coord, bounds):
if coord == self.radius:
if min(bounds) < 0:
Expand All @@ -316,20 +377,14 @@ def sub_cs(self, other):
else: return False
return False

def set_distributor(self, distributor):
self.dist = distributor
super().set_distributor(distributor)
self.S2coordsys.set_distributor(distributor)

@staticmethod
def cartesian(phi, theta, r):
x = r * np.sin(theta) * np.cos(phi)
y = r * np.sin(theta) * np.sin(phi)
z = r * np.cos(theta)
return x, y, z

def forward_intertwiner(self, axis, order, group):
subaxis = axis - self.axis
def forward_intertwiner(self, subaxis, order, group):
if subaxis == 0:
# Azimuth intertwiner is identity, independent of group
return np.identity(self.dim**order)
Expand All @@ -338,13 +393,12 @@ def forward_intertwiner(self, axis, order, group):
return self._U_forward(order)
elif subaxis == 2:
# Radius intertwiner is reg-Q, dependent on ell
ell = group[axis-1]
ell = group[subaxis-1]
return self._Q_forward(ell, order)
else:
raise ValueError("Invalid axis")

def backward_intertwiner(self, axis, order, group):
subaxis = axis - self.axis
def backward_intertwiner(self, subaxis, order, group):
if subaxis == 0:
# Azimuth intertwiner is identity, independent of group
return np.identity(self.dim**order)
Expand All @@ -353,7 +407,7 @@ def backward_intertwiner(self, axis, order, group):
return self._U_backward(order)
elif subaxis == 2:
# Radius intertwiner is reg-Q, dependent on ell
ell = group[axis-1]
ell = group[subaxis-1]
return self._Q_backward(ell, order)
else:
raise ValueError("Invalid axis")
Loading

0 comments on commit 527969c

Please sign in to comment.