diff --git a/dedalus/core/transforms.py b/dedalus/core/transforms.py index 982da237..00758fb2 100644 --- a/dedalus/core/transforms.py +++ b/dedalus/core/transforms.py @@ -20,8 +20,8 @@ logger = logging.getLogger(__name__.split('.')[-1]) from ..tools.config import config -FFTW_RIGOR = lambda: config['transforms-fftw'].get('PLANNING_RIGOR') -DEALIAS_BEFORE_CONVERTING = lambda: config['transforms'].getboolean('DEALIAS_BEFORE_CONVERTING') +GET_FFTW_RIGOR = lambda: config['transforms-fftw'].get('PLANNING_RIGOR') +GET_DEALIAS_BEFORE_CONVERTING = lambda: config['transforms'].getboolean('DEALIAS_BEFORE_CONVERTING') def register_transform(basis, name): @@ -99,13 +99,16 @@ class JacobiTransform(SeparableTransform): TODO: We need to define the normalization we use here. """ - def __init__(self, grid_size, coeff_size, a, b, a0, b0): + def __init__(self, grid_size, coeff_size, a, b, a0, b0, dealias_before_converting=None): self.N = grid_size self.M = coeff_size self.a = a self.b = b self.a0 = a0 self.b0 = b0 + if dealias_before_converting is None: + dealias_before_converting = GET_DEALIAS_BEFORE_CONVERTING() + self.dealias_before_converting = dealias_before_converting @register_transform(basis.Jacobi, 'matrix') @@ -125,7 +128,7 @@ def forward_matrix(self): base_transform = (base_polynomials * base_weights) # Zero higher coefficients for transforms with grid_size < coeff_size base_transform[N:, :] = 0 - if DEALIAS_BEFORE_CONVERTING(): + if self.dealias_before_converting: # Truncate to specified coeff_size base_transform = base_transform[:M, :] # Spectral conversion @@ -134,7 +137,7 @@ def forward_matrix(self): else: conversion = jacobi.conversion_matrix(base_transform.shape[0], a0, b0, a, b) forward_matrix = conversion @ base_transform - if not DEALIAS_BEFORE_CONVERTING(): + if not self.dealias_before_converting: # Truncate to specified coeff_size forward_matrix = forward_matrix[:M, :] # Ensure C ordering for fast dot products @@ -286,8 +289,18 @@ def backward(self, cdata, gdata, axis): np.copyto(gdata, temp) +class FFTWBase: + """Abstract base class for FFTW transforms.""" + + def __init__(self, *args, rigor=None, **kw): + if rigor is None: + rigor = GET_FFTW_RIGOR() + self.rigor = rigor + super().__init__(*args, **kw) + + @register_transform(basis.ComplexFourier, 'fftw') -class FFTWComplexFFT(ComplexFFT): +class FFTWComplexFFT(FFTWBase, ComplexFFT): """Complex-to-complex FFT using FFTW.""" @CachedMethod @@ -295,7 +308,7 @@ def _build_fftw_plan(self, gshape, axis): """Build FFTW plans and temporary arrays.""" dtype = np.complex128 logger.debug("Building FFTW FFT plan for (dtype, gshape, axis) = (%s, %s, %s)" %(dtype, gshape, axis)) - flags = ['FFTW_'+FFTW_RIGOR().upper()] + flags = ['FFTW_'+self.rigor.upper()] plan = fftw.FourierTransform(dtype, gshape, axis, flags=flags) temp = fftw.create_array(plan.cshape, np.complex128) return plan, temp @@ -522,7 +535,7 @@ def backward(self, cdata, gdata, axis): @register_transform(basis.RealFourier, 'fftw') -class FFTWRealFFT(RealFFT): +class FFTWRealFFT(FFTWBase, RealFFT): """Real-to-real FFT using FFTW.""" @CachedMethod @@ -530,7 +543,7 @@ def _build_fftw_plan(self, gshape, axis): """Build FFTW plans and temporary arrays.""" dtype = np.float64 logger.debug("Building FFTW FFT plan for (dtype, gshape, axis) = (%s, %s, %s)" %(dtype, gshape, axis)) - flags = ['FFTW_'+FFTW_RIGOR().upper()] + flags = ['FFTW_'+self.rigor.upper()] plan = fftw.FourierTransform(dtype, gshape, axis, flags=flags) temp = fftw.create_array(plan.cshape, np.complex128) return plan, temp @@ -553,14 +566,14 @@ def backward(self, cdata, gdata, axis): @register_transform(basis.RealFourier, 'fftw_hc') -class FFTWHalfComplexFFT(RealFourierTransform): +class FFTWHalfComplexFFT(FFTWBase, RealFourierTransform): """Real-to-real FFT using FFTW half-complex DFT.""" @CachedMethod def _build_fftw_plan(self, dtype, gshape, axis): """Build FFTW plans and temporary arrays.""" logger.debug("Building FFTW R2HC plan for (dtype, gshape, axis) = (%s, %s, %s)" %(dtype, gshape, axis)) - flags = ['FFTW_'+FFTW_RIGOR().upper()] + flags = ['FFTW_'+self.rigor.upper()] plan = fftw.R2HCTransform(dtype, gshape, axis, flags=flags) temp = fftw.create_array(gshape, dtype) return plan, temp @@ -756,14 +769,14 @@ def backward(self, cdata, gdata, axis): #@register_transform(basis.Cosine, 'fftw') -class FFTWDCT(FastCosineTransform): +class FFTWDCT(FFTWBase, FastCosineTransform): """Fast cosine transform using FFTW.""" @CachedMethod def _build_fftw_plan(self, dtype, gshape, axis): """Build FFTW plans and temporary arrays.""" logger.debug("Building FFTW DCT plan for (dtype, gshape, axis) = (%s, %s, %s)" %(dtype, gshape, axis)) - flags = ['FFTW_'+FFTW_RIGOR().upper()] + flags = ['FFTW_'+self.rigor.upper()] plan = fftw.DiscreteCosineTransform(dtype, gshape, axis, flags=flags) temp = fftw.create_array(gshape, dtype) return plan, temp @@ -791,11 +804,11 @@ class FastChebyshevTransform(JacobiTransform): Subclasses should inherit from this class, then a FastCosineTransform subclass. """ - def __init__(self, grid_size, coeff_size, a, b, a0, b0): + def __init__(self, grid_size, coeff_size, a, b, a0, b0, **kw): if not a0 == b0 == -1/2: raise ValueError("Fast Chebshev transform requires a0 == b0 == -1/2.") # Jacobi initialization - super().__init__(grid_size, coeff_size, a, b, a0, b0) + super().__init__(grid_size, coeff_size, a, b, a0, b0, **kw) # DCT initialization to set scaling factors if a != a0 or b != b0: # Modify coeff_size to avoid truncation before conversion @@ -817,7 +830,7 @@ def __init__(self, grid_size, coeff_size, a, b, a0, b0): self.resize_rescale_backward = self._resize_rescale_backward else: # Conversion matrices - if DEALIAS_BEFORE_CONVERTING() and (self.M_orig < self.N): # truncate prior to conversion matrix + if self.dealias_before_converting and (self.M_orig < self.N): # truncate prior to conversion matrix self.forward_conversion = jacobi.conversion_matrix(self.M_orig, a0, b0, a, b).tocsr() else: # input to conversion matrix not truncated self.forward_conversion = jacobi.conversion_matrix(self.N, a0, b0, a, b) @@ -855,7 +868,7 @@ def _resize_rescale_forward_convert(self, data_in, data_out, axis, Kmax_DCT): posfreq_odd = axslice(axis, 1, Kmax_DCT+1, 2) data_in[posfreq_odd] *= -1 # Ultraspherical conversion - if DEALIAS_BEFORE_CONVERTING() and self.M_orig < self.N: # truncate data + if self.dealias_before_converting and self.M_orig < self.N: # truncate data goodfreq = axslice(axis, 0, self.M_orig) data_in = data_in[goodfreq] apply_sparse(self.forward_conversion, data_in, axis, out=data_out) @@ -1333,7 +1346,7 @@ class DiskRadialTransform(NonSeparableTransform): - Remove dependence on grid_shape? """ - def __init__(self, grid_shape, basis_shape, axis, m_maps, s, k, alpha, dtype=np.complex128): + def __init__(self, grid_shape, basis_shape, axis, m_maps, s, k, alpha, dtype=np.complex128, dealias_before_converting=None): self.Nphi = basis_shape[0] self.Nmax = basis_shape[1] - 1 super().__init__(grid_shape, self.Nmax+1, axis, dtype) @@ -1341,6 +1354,9 @@ def __init__(self, grid_shape, basis_shape, axis, m_maps, s, k, alpha, dtype=np. self.s = s self.k = k self.alpha = alpha + if dealias_before_converting is None: + dealias_before_converting = GET_DEALIAS_BEFORE_CONVERTING() + self.dealias_before_converting = dealias_before_converting def forward_reduced(self, gdata, cdata): # local_m = self.local_m @@ -1396,14 +1412,14 @@ def _forward_matrices(self): # Zero higher coefficients than can be correctly computed with base Gauss quadrature dN = abs(m + self.s) // 2 W[max(self.N2g-dN,0):] = 0 - if DEALIAS_BEFORE_CONVERTING(): + if self.dealias_before_converting: # Truncate to specified coeff_size W = W[:max(self.N2c-Nmin,0)] # Spectral conversion if self.k > 0: conversion = dedalus_sphere.zernike.operator(2, 'E')(+1)**self.k W = conversion(W.shape[0], self.alpha, abs(m + self.s)) @ W - if not DEALIAS_BEFORE_CONVERTING(): + if not self.dealias_before_converting: # Truncate to specified coeff_size W = W[:max(self.N2c-Nmin,0)] m_matrices[m] = np.asarray(W.astype(np.float64), order='C') @@ -1434,7 +1450,7 @@ def _backward_matrices(self): @register_transform(basis.BallBasis, 'matrix') class BallRadialTransform(Transform): - def __init__(self, grid_shape, coeff_size, axis, ell_maps, regindex, regtotal, k, alpha, dtype=np.complex128): + def __init__(self, grid_shape, coeff_size, axis, ell_maps, regindex, regtotal, k, alpha, dtype=np.complex128, dealias_before_converting=None): self.N3g = grid_shape[axis] self.N3c = coeff_size self.ell_maps = ell_maps @@ -1443,6 +1459,9 @@ def __init__(self, grid_shape, coeff_size, axis, ell_maps, regindex, regtotal, k self.regtotal = regtotal self.k = k self.alpha = alpha + if dealias_before_converting is None: + dealias_before_converting = GET_DEALIAS_BEFORE_CONVERTING() + self.dealias_before_converting = dealias_before_converting def forward(self, gdata, cdata, axis): # Make reduced view into input arrays @@ -1506,14 +1525,14 @@ def _forward_GSZP_matrix(self): # Zero higher coefficients than can be correctly computed with base Gauss quadrature dN = (ell + self.regtotal) // 2 W[max(self.N3g-dN,0):] = 0 - if DEALIAS_BEFORE_CONVERTING(): + if self.dealias_before_converting: # Truncate to specified coeff_size W = W[:max(self.N3c-Nmin,0)] # Spectral conversion if self.k > 0: conversion = dedalus_sphere.zernike.operator(3, 'E')(+1)**self.k W = conversion(W.shape[0], self.alpha, ell+self.regtotal) @ W - if not DEALIAS_BEFORE_CONVERTING(): + if not self.dealias_before_converting: # Truncate to specified coeff_size W = W[:max(self.N3c-Nmin,0)] # Ensure C ordering for fast dot products