Skip to content

Commit

Permalink
Change planning rigor and dealias_before_converting to transform clas…
Browse files Browse the repository at this point in the history
…s attributes rather than continuous config lookups
  • Loading branch information
kburns committed Dec 22, 2023
1 parent 0c45827 commit 1fc2048
Showing 1 changed file with 42 additions and 23 deletions.
65 changes: 42 additions & 23 deletions dedalus/core/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
logger = logging.getLogger(__name__.split('.')[-1])

from ..tools.config import config
FFTW_RIGOR = lambda: config['transforms-fftw'].get('PLANNING_RIGOR')
DEALIAS_BEFORE_CONVERTING = lambda: config['transforms'].getboolean('DEALIAS_BEFORE_CONVERTING')
GET_FFTW_RIGOR = lambda: config['transforms-fftw'].get('PLANNING_RIGOR')
GET_DEALIAS_BEFORE_CONVERTING = lambda: config['transforms'].getboolean('DEALIAS_BEFORE_CONVERTING')


def register_transform(basis, name):
Expand Down Expand Up @@ -99,13 +99,16 @@ class JacobiTransform(SeparableTransform):
TODO: We need to define the normalization we use here.
"""

def __init__(self, grid_size, coeff_size, a, b, a0, b0):
def __init__(self, grid_size, coeff_size, a, b, a0, b0, dealias_before_converting=None):
self.N = grid_size
self.M = coeff_size
self.a = a
self.b = b
self.a0 = a0
self.b0 = b0
if dealias_before_converting is None:
dealias_before_converting = GET_DEALIAS_BEFORE_CONVERTING()
self.dealias_before_converting = dealias_before_converting


@register_transform(basis.Jacobi, 'matrix')
Expand All @@ -125,7 +128,7 @@ def forward_matrix(self):
base_transform = (base_polynomials * base_weights)
# Zero higher coefficients for transforms with grid_size < coeff_size
base_transform[N:, :] = 0
if DEALIAS_BEFORE_CONVERTING():
if self.dealias_before_converting:
# Truncate to specified coeff_size
base_transform = base_transform[:M, :]
# Spectral conversion
Expand All @@ -134,7 +137,7 @@ def forward_matrix(self):
else:
conversion = jacobi.conversion_matrix(base_transform.shape[0], a0, b0, a, b)
forward_matrix = conversion @ base_transform
if not DEALIAS_BEFORE_CONVERTING():
if not self.dealias_before_converting:
# Truncate to specified coeff_size
forward_matrix = forward_matrix[:M, :]
# Ensure C ordering for fast dot products
Expand Down Expand Up @@ -286,16 +289,26 @@ def backward(self, cdata, gdata, axis):
np.copyto(gdata, temp)


class FFTWBase:
"""Abstract base class for FFTW transforms."""

def __init__(self, *args, rigor=None, **kw):
if rigor is None:
rigor = GET_FFTW_RIGOR()
self.rigor = rigor
super().__init__(*args, **kw)


@register_transform(basis.ComplexFourier, 'fftw')
class FFTWComplexFFT(ComplexFFT):
class FFTWComplexFFT(FFTWBase, ComplexFFT):
"""Complex-to-complex FFT using FFTW."""

@CachedMethod
def _build_fftw_plan(self, gshape, axis):
"""Build FFTW plans and temporary arrays."""
dtype = np.complex128
logger.debug("Building FFTW FFT plan for (dtype, gshape, axis) = (%s, %s, %s)" %(dtype, gshape, axis))
flags = ['FFTW_'+FFTW_RIGOR().upper()]
flags = ['FFTW_'+self.rigor.upper()]
plan = fftw.FourierTransform(dtype, gshape, axis, flags=flags)
temp = fftw.create_array(plan.cshape, np.complex128)
return plan, temp
Expand Down Expand Up @@ -522,15 +535,15 @@ def backward(self, cdata, gdata, axis):


@register_transform(basis.RealFourier, 'fftw')
class FFTWRealFFT(RealFFT):
class FFTWRealFFT(FFTWBase, RealFFT):
"""Real-to-real FFT using FFTW."""

@CachedMethod
def _build_fftw_plan(self, gshape, axis):
"""Build FFTW plans and temporary arrays."""
dtype = np.float64
logger.debug("Building FFTW FFT plan for (dtype, gshape, axis) = (%s, %s, %s)" %(dtype, gshape, axis))
flags = ['FFTW_'+FFTW_RIGOR().upper()]
flags = ['FFTW_'+self.rigor.upper()]
plan = fftw.FourierTransform(dtype, gshape, axis, flags=flags)
temp = fftw.create_array(plan.cshape, np.complex128)
return plan, temp
Expand All @@ -553,14 +566,14 @@ def backward(self, cdata, gdata, axis):


@register_transform(basis.RealFourier, 'fftw_hc')
class FFTWHalfComplexFFT(RealFourierTransform):
class FFTWHalfComplexFFT(FFTWBase, RealFourierTransform):
"""Real-to-real FFT using FFTW half-complex DFT."""

@CachedMethod
def _build_fftw_plan(self, dtype, gshape, axis):
"""Build FFTW plans and temporary arrays."""
logger.debug("Building FFTW R2HC plan for (dtype, gshape, axis) = (%s, %s, %s)" %(dtype, gshape, axis))
flags = ['FFTW_'+FFTW_RIGOR().upper()]
flags = ['FFTW_'+self.rigor.upper()]
plan = fftw.R2HCTransform(dtype, gshape, axis, flags=flags)
temp = fftw.create_array(gshape, dtype)
return plan, temp
Expand Down Expand Up @@ -756,14 +769,14 @@ def backward(self, cdata, gdata, axis):


#@register_transform(basis.Cosine, 'fftw')
class FFTWDCT(FastCosineTransform):
class FFTWDCT(FFTWBase, FastCosineTransform):
"""Fast cosine transform using FFTW."""

@CachedMethod
def _build_fftw_plan(self, dtype, gshape, axis):
"""Build FFTW plans and temporary arrays."""
logger.debug("Building FFTW DCT plan for (dtype, gshape, axis) = (%s, %s, %s)" %(dtype, gshape, axis))
flags = ['FFTW_'+FFTW_RIGOR().upper()]
flags = ['FFTW_'+self.rigor.upper()]
plan = fftw.DiscreteCosineTransform(dtype, gshape, axis, flags=flags)
temp = fftw.create_array(gshape, dtype)
return plan, temp
Expand Down Expand Up @@ -791,11 +804,11 @@ class FastChebyshevTransform(JacobiTransform):
Subclasses should inherit from this class, then a FastCosineTransform subclass.
"""

def __init__(self, grid_size, coeff_size, a, b, a0, b0):
def __init__(self, grid_size, coeff_size, a, b, a0, b0, **kw):
if not a0 == b0 == -1/2:
raise ValueError("Fast Chebshev transform requires a0 == b0 == -1/2.")
# Jacobi initialization
super().__init__(grid_size, coeff_size, a, b, a0, b0)
super().__init__(grid_size, coeff_size, a, b, a0, b0, **kw)
# DCT initialization to set scaling factors
if a != a0 or b != b0:
# Modify coeff_size to avoid truncation before conversion
Expand All @@ -817,7 +830,7 @@ def __init__(self, grid_size, coeff_size, a, b, a0, b0):
self.resize_rescale_backward = self._resize_rescale_backward
else:
# Conversion matrices
if DEALIAS_BEFORE_CONVERTING() and (self.M_orig < self.N): # truncate prior to conversion matrix
if self.dealias_before_converting and (self.M_orig < self.N): # truncate prior to conversion matrix
self.forward_conversion = jacobi.conversion_matrix(self.M_orig, a0, b0, a, b).tocsr()
else: # input to conversion matrix not truncated
self.forward_conversion = jacobi.conversion_matrix(self.N, a0, b0, a, b)
Expand Down Expand Up @@ -855,7 +868,7 @@ def _resize_rescale_forward_convert(self, data_in, data_out, axis, Kmax_DCT):
posfreq_odd = axslice(axis, 1, Kmax_DCT+1, 2)
data_in[posfreq_odd] *= -1
# Ultraspherical conversion
if DEALIAS_BEFORE_CONVERTING() and self.M_orig < self.N: # truncate data
if self.dealias_before_converting and self.M_orig < self.N: # truncate data
goodfreq = axslice(axis, 0, self.M_orig)
data_in = data_in[goodfreq]
apply_sparse(self.forward_conversion, data_in, axis, out=data_out)
Expand Down Expand Up @@ -1333,14 +1346,17 @@ class DiskRadialTransform(NonSeparableTransform):
- Remove dependence on grid_shape?
"""

def __init__(self, grid_shape, basis_shape, axis, m_maps, s, k, alpha, dtype=np.complex128):
def __init__(self, grid_shape, basis_shape, axis, m_maps, s, k, alpha, dtype=np.complex128, dealias_before_converting=None):
self.Nphi = basis_shape[0]
self.Nmax = basis_shape[1] - 1
super().__init__(grid_shape, self.Nmax+1, axis, dtype)
self.m_maps = m_maps
self.s = s
self.k = k
self.alpha = alpha
if dealias_before_converting is None:
dealias_before_converting = GET_DEALIAS_BEFORE_CONVERTING()
self.dealias_before_converting = dealias_before_converting

def forward_reduced(self, gdata, cdata):
# local_m = self.local_m
Expand Down Expand Up @@ -1396,14 +1412,14 @@ def _forward_matrices(self):
# Zero higher coefficients than can be correctly computed with base Gauss quadrature
dN = abs(m + self.s) // 2
W[max(self.N2g-dN,0):] = 0
if DEALIAS_BEFORE_CONVERTING():
if self.dealias_before_converting:
# Truncate to specified coeff_size
W = W[:max(self.N2c-Nmin,0)]
# Spectral conversion
if self.k > 0:
conversion = dedalus_sphere.zernike.operator(2, 'E')(+1)**self.k
W = conversion(W.shape[0], self.alpha, abs(m + self.s)) @ W
if not DEALIAS_BEFORE_CONVERTING():
if not self.dealias_before_converting:
# Truncate to specified coeff_size
W = W[:max(self.N2c-Nmin,0)]
m_matrices[m] = np.asarray(W.astype(np.float64), order='C')
Expand Down Expand Up @@ -1434,7 +1450,7 @@ def _backward_matrices(self):
@register_transform(basis.BallBasis, 'matrix')
class BallRadialTransform(Transform):

def __init__(self, grid_shape, coeff_size, axis, ell_maps, regindex, regtotal, k, alpha, dtype=np.complex128):
def __init__(self, grid_shape, coeff_size, axis, ell_maps, regindex, regtotal, k, alpha, dtype=np.complex128, dealias_before_converting=None):
self.N3g = grid_shape[axis]
self.N3c = coeff_size
self.ell_maps = ell_maps
Expand All @@ -1443,6 +1459,9 @@ def __init__(self, grid_shape, coeff_size, axis, ell_maps, regindex, regtotal, k
self.regtotal = regtotal
self.k = k
self.alpha = alpha
if dealias_before_converting is None:
dealias_before_converting = GET_DEALIAS_BEFORE_CONVERTING()
self.dealias_before_converting = dealias_before_converting

def forward(self, gdata, cdata, axis):
# Make reduced view into input arrays
Expand Down Expand Up @@ -1506,14 +1525,14 @@ def _forward_GSZP_matrix(self):
# Zero higher coefficients than can be correctly computed with base Gauss quadrature
dN = (ell + self.regtotal) // 2
W[max(self.N3g-dN,0):] = 0
if DEALIAS_BEFORE_CONVERTING():
if self.dealias_before_converting:
# Truncate to specified coeff_size
W = W[:max(self.N3c-Nmin,0)]
# Spectral conversion
if self.k > 0:
conversion = dedalus_sphere.zernike.operator(3, 'E')(+1)**self.k
W = conversion(W.shape[0], self.alpha, ell+self.regtotal) @ W
if not DEALIAS_BEFORE_CONVERTING():
if not self.dealias_before_converting:
# Truncate to specified coeff_size
W = W[:max(self.N3c-Nmin,0)]
# Ensure C ordering for fast dot products
Expand Down

0 comments on commit 1fc2048

Please sign in to comment.