Skip to content

Commit

Permalink
feat(fftw): implement a few fft functions with pyfftw
Browse files Browse the repository at this point in the history
  • Loading branch information
ljgray committed Nov 7, 2024
1 parent f75db33 commit ba8a8ad
Show file tree
Hide file tree
Showing 2 changed files with 356 additions and 0 deletions.
355 changes: 355 additions & 0 deletions caput/fftw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,355 @@
"""Fast FFT implementation using FFTW.
This module adds some minor abstraction to use pyfftw in a way
which seems to be faster than using the `pyfftw.builders` interface,
and uses the same api as `scipy.fft` and `numpy.fft`.
Only forward and reverse real->real or complex->complex transforms
are currently supported.
Examples
--------
The core of this module is the :class:`FFT`, which essentially just
abstracts the :class:`pyfftw:FFTW` in the simplest way.
>>> import numpy as np
>>> from caput import fftw
>>>
>>> shape = (24, 50)
>>> x = np.random.rand(*shape) + 1j * np.random.rand(*shape)
>>>
>>> fftobj = fftw.FFT(x.shape, x.dtype, axes=-1)
>>>
>>> X = fftobj.fft(x)
>>> xi = fftobj.ifft(X)
>>>
>>> np.allclose(x, xi)
True
The direct API can also be used, although it is slower when doing repeated
transforms for arrays of the same shape and type because a new :class:`FFT`
has to be created each time.
References
----------
.. https://pyfftw.readthedocs.io
.. http://www.fftw.org
Classes
=======
- :py:class:`FFT`
Functions
=========
- :py:meth:`fft`
- :py:meth:`ifft`
- :py:meth:`fftconvolve`
- :py:meth:`fftwindow`
"""

from __future__ import annotations

import pyfftw

from caput import mpiutil


class FFT:
"""Faster FFTs with FFTW."""

def __init__(
self,
shape: tuple,
dtype: type,
axes: None | int | tuple = None,
forward: bool = True,
backward: bool = True,
):
"""Create FFTW objects for repeat use.
This implementation is most efficient when used to repeatedly
apply ffts to arrays with the same shape and dtype, because a
single, highly optimised pathway can be used with a single
initialisation.
Even for a single use this will typically
be faster than the `scipy.fft` or `numpy.fft` implementations,
especially when multiple cores can be used.
Parameters
----------
shape
The shape of the arrays to initialise for
dtype
Datatype to create a pathway for. At the moment, only
complex -> complex or real -> real are supported. The
`pyfftw` implementation of the real -> real backward
transform will destroy the input array
axes
Axes over which to apply the fft. Default is all axes.
forward
If true, initialise the forward fft. Default is True.
backward
If true, initialise the backward fft. Default is True.
"""
ncpu = mpiutil.cpu_count()
self._nsimd = pyfftw.simd_alignment
flags = ("FFTW_MEASURE",)

if axes is None:
axes = tuple(range(len(shape)))
elif isinstance(axes, int):
axes = (axes,)

# Store fft params
self._params = {
"ncpu": ncpu,
"simd_alignment": self._nsimd,
"shape": shape,
"dtype": dtype,
"axes": axes,
"flags": flags,
}

fftargs = {
"input_array": pyfftw.empty_aligned(shape, dtype, n=self._nsimd),
"output_array": pyfftw.empty_aligned(shape, dtype, n=self._nsimd),
"axes": axes,
"flags": flags,
"threads": ncpu,
}

self._single_direction = forward ^ backward

if forward:
self._fft = pyfftw.FFTW(direction="FFTW_FORWARD", **fftargs)

if backward:
self._ifft = pyfftw.FFTW(direction="FFTW_BACKWARD", **fftargs)

@property
def params(self):
"""Display the parameters of this FFT.
Returns
-------
params: dict
ncpu, simd alignment, shape, dtype, axes, and flags
used by this FFT object.
"""
return self._params

def fft(self, x):
"""Perform a forward FFT.
Parameters
----------
x : np.ndarray
Input array, must match the dtype specified
at creation
Returns
-------
fft : np.ndarray
DFT of the input array over specified axes
"""
try:
return self._fft(
input_array=x,
output_array=pyfftw.empty_aligned(x.shape, x.dtype, n=self._nsimd),
)
except AttributeError:
raise NotImplementedError("Forward fft not initialised.")

def ifft(self, x):
"""Perform a backward FFT.
When performing the backward real -> real IFFT,
the input array is destroyed.
Parameters
----------
x : np.ndarray
Input array, must match the dtype specified
at creation
Returns
-------
fft : np.ndarray
IDFT of the input array over specified axes
"""
try:
return self._ifft(
input_array=x,
output_array=pyfftw.empty_aligned(x.shape, x.dtype, n=self._nsimd),
)
except AttributeError:
raise NotImplementedError("Backward fft not initialised.")

def fftconvolve(self, in1, in2):
"""Convolve two arrays by multiplying in the Fourier domain.
`in1` and `in2` must have the same dtype.
Parameters
----------
in1 : np.ndarray
First input array
in2 : np.ndarray
Second input array to by convolved with `x`. Must have
the same dtype as `x`.
Returns
-------
out : np.ndarray
Discrete convolution of `in1` and `in2`
"""
if self._single_direction:
raise NotImplementedError(
"`fftconvolve` is not supported when only one "
"fft direction is initialised"
)

X1 = self.fft(in1)
X2 = self.fft(in2)

X1 *= X2

return self.ifft(X1)

def fftwindow(self, x, window):
"""Apply a window function in Fourier space.
The only difference between this and `fftconvolve` is that
this assumes that `window` is _already_ in the Fourier domain,
and `window` can be real or complex when `x` is complex.
Parameters
----------
x : np.ndarray
Input array
window : np.ndarray
Window to be applied in the Fourier domain.
Returns
-------
out : np.ndarray
Input array `x` with `window` applied in the Fourier domain.
"""
if self._single_direction:
raise NotImplementedError(
"`fftwindow` is not supported when only one "
"fft direction is initialised"
)

X = self.fft(x)
X *= window

return self.ifft(X)


def fft(x, axes=None):
"""Perform a forward discrete Fourer Transform.
If the fourier transform is to be applied repeatedly to
arrays with the same size and dtype, it is faster to use
the `FFT` class directly to avoid creating new `FFT` objects.
Parameters
----------
x : np.ndarray
Input array, real or complex
axes : None | int | tuple
Axes over which to take the fft. Default is all axes.
Returns
-------
fft : np.ndarray
DFT of the input array over specified axes
"""
fftobj = FFT(x.shape, x.dtype, axes, forward=True, backward=False)

return fftobj.fft(x)


def ifft(x, axes=None):
"""Perform an inverse discrete Fourier Transform.
If the fourier transform is to be applied repeatedly to
arrays with the same size and dtype, it is faster to use
the `FFT` class directly to avoid creating new `FFT` objects.
Parameters
----------
x : np.ndarray
Input array, real or complex
axes : None | int | tuple
Axes over which to take the ifft. Default is all axes.
Returns
-------
fft : np.ndarray
IDFT of the input array over specified axes
"""
fftobj = FFT(x.shape, x.dtype, axes, forward=False, backward=True)

return fftobj.ifft(x)


def fftconvolve(in1, in2, axes=None):
"""Convolve two arrays by multiplying in the Fourier domain.
`in1` and `in2` must have the same dtype.
If the convolution is to be applied repeatedly to
arrays with the same size and dtype, it is faster to use
the `FFT` class directly to avoid creating new `FFT` objects.
Parameters
----------
in1 : np.ndarray
First input array
in2 : np.ndarray
Second input array to by convolved with `x`. Must have
the same dtype as `x`.
axes : None | int | tuple
Axes over which to do the convolution. Default is all axes.
Returns
-------
out : np.ndarray
Discrete convolution of `in1` and `in2`
"""
fftobj = FFT(in1.shape, in1.dtype, axes, forward=True, backward=True)

return fftobj.fftconvolve(in1, in2)


def fftwindow(x, window, axes):
"""Apply a window function in Fourier space.
The only difference between this and `fftconvolve` is that
this assumes that `window` is _already_ in the Fourier domain,
and `window` can be real or complex when `x` is complex.
If the window is to be applied repeatedly to
arrays with the same size and dtype, it is faster to use
the `FFT` class directly to avoid creating new `FFT` objects.
Parameters
----------
x : np.ndarray
Input array
window : np.ndarray
Window to be applied in the Fourier domain.
axes : None | int | tuple
Axes over which to apply the window. Default is all axes.
Returns
-------
out : np.ndarray
Input array `x` with `window` applied in the Fourier domain.
"""
fftobj = FFT(x.shape, x.dtype, axes, forward=True, backward=True)

return fftobj.fftwindow(x, window)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ dependencies = [
"h5py",
"numpy>=1.24",
"psutil",
"pyfftw",
"PyYAML",
"scipy",
"skyfield>=1.31",
Expand Down

0 comments on commit ba8a8ad

Please sign in to comment.