-
Notifications
You must be signed in to change notification settings - Fork 109
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added DTCWT
operator
#495
Added DTCWT
operator
#495
Changes from all commits
52e6c12
4b7ec1e
f887031
7435b51
22e6780
60f9ca0
b0e37c9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -102,6 +102,7 @@ Signal processing | |
DWT | ||
DWT2D | ||
DCT | ||
DTCWT | ||
Seislet | ||
Radon2D | ||
Radon3D | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,6 +25,7 @@ dependencies: | |
- isort | ||
- black | ||
- pip: | ||
- dtcwt | ||
- devito | ||
- scikit-fmm | ||
- spgl1 | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
""" | ||
Dual-Tree Complex Wavelet Transform | ||
========================= | ||
This example shows how to use the :py:class:`pylops.signalprocessing.DCT` operator. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should be pylops.signalprocessing.DTCWT |
||
This operator performs the 1D Dual-Tree Complex Wavelet Transform on a (single or multi-dimensional) | ||
input array. | ||
""" | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
import pylops | ||
|
||
plt.close("all") | ||
|
||
############################################################################### | ||
# Let's define a 1D array x of having random values | ||
|
||
n = 50 | ||
x = np.random.rand(n,) | ||
|
||
############################################################################### | ||
# We create the DTCWT operator with shape of our input array. DTCWT transform | ||
# gives a Pyramid object that is flattened out `y`. | ||
|
||
DOp = pylops.signalprocessing.DTCWT(dims=x.shape) | ||
y = DOp @ x | ||
xadj = DOp.H @ y | ||
|
||
plt.figure(figsize=(8, 5)) | ||
plt.plot(x, "k", label="input array") | ||
plt.plot(y, "r", label="transformed array") | ||
plt.plot(xadj, "--b", label="transformed array") | ||
plt.title("Dual-Tree Complex Wavelet Transform 1D") | ||
plt.legend() | ||
plt.tight_layout() | ||
|
||
################################################################################# | ||
# To get the Pyramid object use the `get_pyramid` method. | ||
# We can get the Highpass signal and Lowpass signal from it | ||
|
||
pyr = DOp.get_pyramid(y) | ||
|
||
plt.figure(figsize=(10, 5)) | ||
plt.plot(x, "--b", label="orignal signal") | ||
plt.plot(pyr.lowpass, "k", label="lowpass") | ||
plt.plot(pyr.highpasses[0], "r", label="highpass level 1 signal") | ||
plt.plot(pyr.highpasses[1], "b", label="highpass level 2 signal") | ||
plt.plot(pyr.highpasses[2], "g", label="highpass level 3 signal") | ||
|
||
plt.title("DTCWT Pyramid Object") | ||
plt.legend() | ||
plt.tight_layout() | ||
|
||
################################################################################### | ||
# DTCWT can also be performed on multi-dimension arrays. The number of levels can also | ||
# be defined using the `nlevels` | ||
|
||
n = 10 | ||
m = 2 | ||
|
||
x = np.random.rand(n, m) | ||
|
||
DOp = pylops.signalprocessing.DTCWT(dims=x.shape, nlevels=5) | ||
y = DOp @ x | ||
xadj = DOp.H @ y | ||
|
||
plt.figure(figsize=(8, 5)) | ||
plt.plot(x, "k", label="input array") | ||
plt.plot(y, "r", label="transformed array") | ||
plt.plot(xadj, "--b", label="transformed array") | ||
plt.title("Dual-Tree Complex Wavelet Transform 1D on ND array") | ||
plt.legend() | ||
plt.tight_layout() |
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
@@ -0,0 +1,155 @@ | ||||
__all__ = ["DTCWT"] | ||||
|
||||
from typing import Union | ||||
|
||||
import dtcwt | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this needs to be wrapped in a https://github.com/PyLops/pylops/blob/dev/pylops/_torchoperator.py#LL5-LL10 |
||||
import numpy as np | ||||
|
||||
from pylops import LinearOperator | ||||
from pylops.utils._internal import _value_or_sized_to_tuple | ||||
from pylops.utils.decorators import reshaped | ||||
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray | ||||
|
||||
|
||||
class DTCWT(LinearOperator): | ||||
r"""Dual-Tree Complex Wavelet Transform | ||||
|
||||
Perform 1D Dual-Tree Complex Wavelet Transform along an ``axis`` of a | ||||
multi-dimensional array of size ``dims``. | ||||
|
||||
Note that the DTCWT operator is an overload of the ``dtcwt`` | ||||
implementation of the DT-CWT transform. Refer to | ||||
https://dtcwt.readthedocs.io for a detailed description of the | ||||
input parameters. | ||||
|
||||
Parameters | ||||
---------- | ||||
dims : :obj:`int` or :obj:`tuple` | ||||
Number of samples for each dimension. | ||||
birot : :obj:`str`, optional | ||||
Level 1 wavelets to use. See :py:func:`dtcwt.coeffs.birot`. Default is `"near_sym_a"`. | ||||
qshift : :obj:`str`, optional | ||||
Level >= 2 wavelets to use. See :py:func:`dtcwt.coeffs.qshift`. Default is `"qshift_a"` | ||||
nlevels : :obj:`int`, optional | ||||
Number of levels of wavelet decomposition. Default is 3. | ||||
include_scale : :obj:`bool`, optional | ||||
Include scales in pyramid. See :py:class:`dtcwt.Pyramid`. Default is False. | ||||
axis : :obj:`int`, optional | ||||
Axis on which the transform is performed. | ||||
dtype : :obj:`DTypeLike`, optional | ||||
Type of elements in input array. | ||||
name : :obj:`str`, optional | ||||
Name of operator (to be used by :func:`pylops.utils.describe.describe`) | ||||
|
||||
Notes | ||||
----- | ||||
The DTCWT operator applies the dual-tree complex wavelet transform | ||||
in forward mode and the dual-tree complex inverse wavelet transform in adjoint mode | ||||
from the ``dtcwt`` library. | ||||
|
||||
The ``dtcwt`` library uses a Pyramid object to represent the signal in the transformed domain, | ||||
which is composed of: | ||||
- `lowpass` (coarsest scale lowpass signal); | ||||
- `highpasses` (complex subband coefficients for corresponding scales); | ||||
- `scales` (lowpass signal for corresponding scales finest to coarsest). | ||||
|
||||
To make the dtcwt forward() and inverse() functions compatible with PyLops, in forward model | ||||
the Pyramid object is flattened out and all coefficients (high-pass and low pass coefficients) | ||||
are appended into one array using the `_coeff_to_array` method. | ||||
|
||||
In adjoint mode, the input array is transformed back into a Pyramid object using the `_array_to_coeff` | ||||
method and then the inverse transform is performed. | ||||
|
||||
""" | ||||
|
||||
def __init__( | ||||
self, | ||||
dims: Union[int, InputDimsLike], | ||||
biort: str = "near_sym_a", | ||||
qshift: str = "qshift_a", | ||||
nlevels: int = 3, | ||||
include_scale: bool = False, | ||||
axis: int = -1, | ||||
dtype: DTypeLike = "float64", | ||||
name: str = "C", | ||||
) -> None: | ||||
self.dims = _value_or_sized_to_tuple(dims) | ||||
self.ndim = len(self.dims) | ||||
self.nlevels = nlevels | ||||
self.include_scale = include_scale | ||||
self.axis = axis | ||||
# dry-run of transform to | ||||
self._transform = dtcwt.Transform1d(biort=biort, qshift=qshift) | ||||
self._interpret_coeffs() | ||||
super().__init__( | ||||
dtype=np.dtype(dtype), | ||||
dims=self.dims, | ||||
dimsd=(self.coeff_array_size,), | ||||
name=name, | ||||
) | ||||
|
||||
def _interpret_coeffs(self): | ||||
x = np.ones(self.dims) | ||||
x = x.swapaxes(self.axis, -1) | ||||
self.swapped_dims = x.shape | ||||
x = self._nd_to_2d(x) | ||||
pyr = self._transform.forward(x, nlevels=self.nlevels, include_scale=True) | ||||
self.coeff_array_size = 0 | ||||
self.lowpass_size = len(pyr.lowpass) | ||||
self.slices = [] | ||||
for _h in pyr.highpasses: | ||||
self.slices.append(len(_h)) | ||||
self.coeff_array_size += len(_h) | ||||
self.coeff_array_size += self.lowpass_size | ||||
elements = np.prod(x.shape[1:]) | ||||
self.coeff_array_size *= elements | ||||
self.lowpass_size *= elements | ||||
self.first_dim = elements | ||||
|
||||
def _nd_to_2d(self, arr_nd): | ||||
arr_2d = arr_nd.reshape((self.dims[0], -1)) | ||||
return arr_2d | ||||
|
||||
def _2d_to_nd(self, arr_2d): | ||||
arr_nd = arr_2d.reshape(self.swapped_dims) | ||||
return arr_nd | ||||
|
||||
def _coeff_to_array(self, pyr: dtcwt.Pyramid) -> NDArray: | ||||
coeffs = pyr.highpasses | ||||
flat_coeffs = [] | ||||
for band in coeffs: | ||||
for c in band: | ||||
flat_coeffs.append(c) | ||||
flat_coeffs = np.concatenate((flat_coeffs, pyr.lowpass)) | ||||
return flat_coeffs | ||||
|
||||
def _array_to_coeff(self, X: NDArray) -> dtcwt.Pyramid: | ||||
lowpass = np.array([x.real for x in X[-self.lowpass_size :]]).reshape( | ||||
(-1, self.first_dim) | ||||
) | ||||
_ptr = 0 | ||||
highpasses = () | ||||
for _sl in self.slices: | ||||
_h = X[_ptr : _ptr + (_sl * self.first_dim)] | ||||
_ptr += _sl * self.first_dim | ||||
_h = _h.reshape((-1, self.first_dim)) | ||||
highpasses += (_h,) | ||||
return dtcwt.Pyramid(lowpass, highpasses) | ||||
|
||||
def get_pyramid(self, X: NDArray) -> dtcwt.Pyramid: | ||||
"""Return Pyramid object from transformed array""" | ||||
return self._array_to_coeff(X) | ||||
|
||||
@reshaped | ||||
def _matvec(self, x: NDArray) -> NDArray: | ||||
x = x.swapaxes(self.axis, -1) | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @this is not good, you should use
When something is readily available and used in many other operators we should not deviate from it unless there is a special requirement, I do not see it here :) |
||||
x = self._nd_to_2d(x) | ||||
return self._coeff_to_array( | ||||
self._transform.forward(x, nlevels=self.nlevels, include_scale=False) | ||||
) | ||||
|
||||
@reshaped | ||||
def _rmatvec(self, x: NDArray) -> NDArray: | ||||
Y = self._transform.inverse(self._array_to_coeff(x)) | ||||
Y = self._2d_to_nd(Y) | ||||
return Y.swapaxes(self.axis, -1) | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import numpy as np | ||
import pytest | ||
|
||
from pylops.signalprocessing import DTCWT | ||
|
||
par1 = {"ny": 10, "nx": 10, "dtype": "float64"} | ||
par2 = {"ny": 50, "nx": 50, "dtype": "float64"} | ||
|
||
|
||
def sequential_array(shape): | ||
num_elements = np.prod(shape) | ||
seq_array = np.arange(1, num_elements + 1) | ||
result = seq_array.reshape(shape) | ||
return result | ||
|
||
|
||
@pytest.mark.parametrize("par", [(par1), (par2)]) | ||
def test_dtcwt1D_input1D(par): | ||
"""Test for DTCWT with 1D input""" | ||
|
||
t = sequential_array((par["ny"],)) | ||
|
||
for levels in range(1, 10): | ||
Dtcwt = DTCWT(dims=t.shape, nlevels=levels, dtype=par["dtype"]) | ||
x = Dtcwt @ t | ||
y = Dtcwt.H @ x | ||
|
||
np.testing.assert_allclose(t, y) | ||
|
||
|
||
@pytest.mark.parametrize("par", [(par1), (par2)]) | ||
def test_dtcwt1D_input2D(par): | ||
"""Test for DTCWT with 2D input""" | ||
|
||
t = sequential_array((par["ny"], par["ny"],)) | ||
|
||
for levels in range(1, 10): | ||
Dtcwt = DTCWT(dims=t.shape, nlevels=levels, dtype=par["dtype"]) | ||
x = Dtcwt @ t | ||
y = Dtcwt.H @ x | ||
|
||
np.testing.assert_allclose(t, y) | ||
|
||
|
||
@pytest.mark.parametrize("par", [(par1), (par2)]) | ||
def test_dtcwt1D_input3D(par): | ||
"""Test for DTCWT with 3D input""" | ||
|
||
t = sequential_array((par["ny"], par["ny"], par["ny"])) | ||
|
||
for levels in range(1, 10): | ||
Dtcwt = DTCWT(dims=t.shape, nlevels=levels, dtype=par["dtype"]) | ||
x = Dtcwt @ t | ||
y = Dtcwt.H @ x | ||
|
||
np.testing.assert_allclose(t, y) | ||
|
||
|
||
@pytest.mark.parametrize("par", [(par1), (par2)]) | ||
def test_dtcwt1D_birot(par): | ||
"""Test for DTCWT birot""" | ||
birots = ["antonini", "legall", "near_sym_a", "near_sym_b"] | ||
|
||
t = sequential_array((par["ny"], par["ny"],)) | ||
|
||
for _b in birots: | ||
print(f"birot {_b}") | ||
Dtcwt = DTCWT(dims=t.shape, biort=_b, dtype=par["dtype"]) | ||
x = Dtcwt @ t | ||
y = Dtcwt.H @ x | ||
|
||
np.testing.assert_allclose(t, y) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,7 @@ spgl1 | |
scikit-fmm | ||
sympy | ||
devito | ||
dtcwt | ||
matplotlib | ||
ipython | ||
pytest | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,4 +26,5 @@ isort | |
black | ||
flake8 | ||
mypy | ||
pydata-sphinx-theme | ||
pydata-sphinx-theme | ||
dtcwt |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -39,6 +39,7 @@ def src(pth): | |
"numba", | ||
"pyfftw", | ||
"PyWavelets", | ||
"dtcwt", | ||
"scikit-fmm", | ||
"spgl1", | ||
] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This needs to go all the way to the end of the title