Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added DTCWT operator #495

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ Signal processing
DWT
DWT2D
DCT
DTCWT
Seislet
Radon2D
Radon3D
Expand Down
10 changes: 10 additions & 0 deletions docs/source/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,16 @@ of GPUs should install it prior to installing PyLops as described in :ref:`Optio
In alphabetic order:


dtcwt
-----
`dtcwt <https://dtcwt.readthedocs.io/en/0.12.0/>`_ is used to implement the DT-CWT operators.
Install it via ``pip`` with:

.. code-block:: bash

>> pip install dtcwt


Devito
------
`Devito <https://github.com/devitocodes/devito>`_ is library used to solve PDEs via
Expand Down
1 change: 1 addition & 0 deletions environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ dependencies:
- isort
- black
- pip:
- dtcwt
- devito
- scikit-fmm
- spgl1
Expand Down
74 changes: 74 additions & 0 deletions examples/plot_dtcwt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""
Dual-Tree Complex Wavelet Transform
=========================
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to go all the way to the end of the title

This example shows how to use the :py:class:`pylops.signalprocessing.DCT` operator.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be pylops.signalprocessing.DTCWT

This operator performs the 1D Dual-Tree Complex Wavelet Transform on a (single or multi-dimensional)
input array.
"""

import matplotlib.pyplot as plt
import numpy as np

import pylops

plt.close("all")

###############################################################################
# Let's define a 1D array x of having random values

n = 50
x = np.random.rand(n,)

###############################################################################
# We create the DTCWT operator with shape of our input array. DTCWT transform
# gives a Pyramid object that is flattened out `y`.

DOp = pylops.signalprocessing.DTCWT(dims=x.shape)
y = DOp @ x
xadj = DOp.H @ y

plt.figure(figsize=(8, 5))
plt.plot(x, "k", label="input array")
plt.plot(y, "r", label="transformed array")
plt.plot(xadj, "--b", label="transformed array")
plt.title("Dual-Tree Complex Wavelet Transform 1D")
plt.legend()
plt.tight_layout()

#################################################################################
# To get the Pyramid object use the `get_pyramid` method.
# We can get the Highpass signal and Lowpass signal from it

pyr = DOp.get_pyramid(y)

plt.figure(figsize=(10, 5))
plt.plot(x, "--b", label="orignal signal")
plt.plot(pyr.lowpass, "k", label="lowpass")
plt.plot(pyr.highpasses[0], "r", label="highpass level 1 signal")
plt.plot(pyr.highpasses[1], "b", label="highpass level 2 signal")
plt.plot(pyr.highpasses[2], "g", label="highpass level 3 signal")

plt.title("DTCWT Pyramid Object")
plt.legend()
plt.tight_layout()

###################################################################################
# DTCWT can also be performed on multi-dimension arrays. The number of levels can also
# be defined using the `nlevels`

n = 10
m = 2

x = np.random.rand(n, m)

DOp = pylops.signalprocessing.DTCWT(dims=x.shape, nlevels=5)
y = DOp @ x
xadj = DOp.H @ y

plt.figure(figsize=(8, 5))
plt.plot(x, "k", label="input array")
plt.plot(y, "r", label="transformed array")
plt.plot(xadj, "--b", label="transformed array")
plt.title("Dual-Tree Complex Wavelet Transform 1D on ND array")
plt.legend()
plt.tight_layout()
3 changes: 3 additions & 0 deletions pylops/signalprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
DWT One dimensional Wavelet operator.
DWT2D Two dimensional Wavelet operator.
DCT Discrete Cosine Transform.
DTCWT Dual-Tree Complex Wavelet Transforms
Seislet Two dimensional Seislet operator.
Radon2D Two dimensional Radon transform.
Radon3D Three dimensional Radon transform.
Expand Down Expand Up @@ -60,6 +61,7 @@
from .dwt2d import *
from .seislet import *
from .dct import *
from .dtcwt import *

__all__ = [
"FFT",
Expand Down Expand Up @@ -89,4 +91,5 @@
"DWT2D",
"Seislet",
"DCT",
"DTCWT",
]
2 changes: 1 addition & 1 deletion pylops/signalprocessing/dct.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class DCT(LinearOperator):
axes : :obj:`int` or :obj:`list`, optional
Axes over which the DCT is computed. If ``None``, the transform is applied
over all axes.
workers :obj:`int`, optional
workers : :obj:`int`, optional
Maximum number of workers to use for parallel computation. If negative,
the value wraps around from os.cpu_count().
dtype : :obj:`DTypeLike`, optional
Expand Down
155 changes: 155 additions & 0 deletions pylops/signalprocessing/dtcwt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
__all__ = ["DTCWT"]

from typing import Union

import dtcwt
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this needs to be wrapped in a pylops.utils.deps check like here:

https://github.com/PyLops/pylops/blob/dev/pylops/_torchoperator.py#LL5-LL10

import numpy as np

from pylops import LinearOperator
from pylops.utils._internal import _value_or_sized_to_tuple
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray


class DTCWT(LinearOperator):
r"""Dual-Tree Complex Wavelet Transform

Perform 1D Dual-Tree Complex Wavelet Transform along an ``axis`` of a
multi-dimensional array of size ``dims``.

Note that the DTCWT operator is an overload of the ``dtcwt``
implementation of the DT-CWT transform. Refer to
https://dtcwt.readthedocs.io for a detailed description of the
input parameters.

Parameters
----------
dims : :obj:`int` or :obj:`tuple`
Number of samples for each dimension.
birot : :obj:`str`, optional
Level 1 wavelets to use. See :py:func:`dtcwt.coeffs.birot`. Default is `"near_sym_a"`.
qshift : :obj:`str`, optional
Level >= 2 wavelets to use. See :py:func:`dtcwt.coeffs.qshift`. Default is `"qshift_a"`
nlevels : :obj:`int`, optional
Number of levels of wavelet decomposition. Default is 3.
include_scale : :obj:`bool`, optional
Include scales in pyramid. See :py:class:`dtcwt.Pyramid`. Default is False.
axis : :obj:`int`, optional
Axis on which the transform is performed.
dtype : :obj:`DTypeLike`, optional
Type of elements in input array.
name : :obj:`str`, optional
Name of operator (to be used by :func:`pylops.utils.describe.describe`)

Notes
-----
The DTCWT operator applies the dual-tree complex wavelet transform
in forward mode and the dual-tree complex inverse wavelet transform in adjoint mode
from the ``dtcwt`` library.

The ``dtcwt`` library uses a Pyramid object to represent the signal in the transformed domain,
which is composed of:
- `lowpass` (coarsest scale lowpass signal);
- `highpasses` (complex subband coefficients for corresponding scales);
- `scales` (lowpass signal for corresponding scales finest to coarsest).

To make the dtcwt forward() and inverse() functions compatible with PyLops, in forward model
the Pyramid object is flattened out and all coefficients (high-pass and low pass coefficients)
are appended into one array using the `_coeff_to_array` method.

In adjoint mode, the input array is transformed back into a Pyramid object using the `_array_to_coeff`
method and then the inverse transform is performed.

"""

def __init__(
self,
dims: Union[int, InputDimsLike],
biort: str = "near_sym_a",
qshift: str = "qshift_a",
nlevels: int = 3,
include_scale: bool = False,
axis: int = -1,
dtype: DTypeLike = "float64",
name: str = "C",
) -> None:
self.dims = _value_or_sized_to_tuple(dims)
self.ndim = len(self.dims)
self.nlevels = nlevels
self.include_scale = include_scale
self.axis = axis
# dry-run of transform to
self._transform = dtcwt.Transform1d(biort=biort, qshift=qshift)
self._interpret_coeffs()
super().__init__(
dtype=np.dtype(dtype),
dims=self.dims,
dimsd=(self.coeff_array_size,),
name=name,
)

def _interpret_coeffs(self):
x = np.ones(self.dims)
x = x.swapaxes(self.axis, -1)
self.swapped_dims = x.shape
x = self._nd_to_2d(x)
pyr = self._transform.forward(x, nlevels=self.nlevels, include_scale=True)
self.coeff_array_size = 0
self.lowpass_size = len(pyr.lowpass)
self.slices = []
for _h in pyr.highpasses:
self.slices.append(len(_h))
self.coeff_array_size += len(_h)
self.coeff_array_size += self.lowpass_size
elements = np.prod(x.shape[1:])
self.coeff_array_size *= elements
self.lowpass_size *= elements
self.first_dim = elements

def _nd_to_2d(self, arr_nd):
arr_2d = arr_nd.reshape((self.dims[0], -1))
return arr_2d

def _2d_to_nd(self, arr_2d):
arr_nd = arr_2d.reshape(self.swapped_dims)
return arr_nd

def _coeff_to_array(self, pyr: dtcwt.Pyramid) -> NDArray:
coeffs = pyr.highpasses
flat_coeffs = []
for band in coeffs:
for c in band:
flat_coeffs.append(c)
flat_coeffs = np.concatenate((flat_coeffs, pyr.lowpass))
return flat_coeffs

def _array_to_coeff(self, X: NDArray) -> dtcwt.Pyramid:
lowpass = np.array([x.real for x in X[-self.lowpass_size :]]).reshape(
(-1, self.first_dim)
)
_ptr = 0
highpasses = ()
for _sl in self.slices:
_h = X[_ptr : _ptr + (_sl * self.first_dim)]
_ptr += _sl * self.first_dim
_h = _h.reshape((-1, self.first_dim))
highpasses += (_h,)
return dtcwt.Pyramid(lowpass, highpasses)

def get_pyramid(self, X: NDArray) -> dtcwt.Pyramid:
"""Return Pyramid object from transformed array"""
return self._array_to_coeff(X)

@reshaped
def _matvec(self, x: NDArray) -> NDArray:
x = x.swapaxes(self.axis, -1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@this is not good, you should use swapaxes like here

@reshaped(swapaxis=True)

When something is readily available and used in many other operators we should not deviate from it unless there is a special requirement, I do not see it here :)

x = self._nd_to_2d(x)
return self._coeff_to_array(
self._transform.forward(x, nlevels=self.nlevels, include_scale=False)
)

@reshaped
def _rmatvec(self, x: NDArray) -> NDArray:
Y = self._transform.inverse(self._array_to_coeff(x))
Y = self._2d_to_nd(Y)
return Y.swapaxes(self.axis, -1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same

72 changes: 72 additions & 0 deletions pytests/test_dtcwt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import numpy as np
import pytest

from pylops.signalprocessing import DTCWT

par1 = {"ny": 10, "nx": 10, "dtype": "float64"}
par2 = {"ny": 50, "nx": 50, "dtype": "float64"}


def sequential_array(shape):
num_elements = np.prod(shape)
seq_array = np.arange(1, num_elements + 1)
result = seq_array.reshape(shape)
return result


@pytest.mark.parametrize("par", [(par1), (par2)])
def test_dtcwt1D_input1D(par):
"""Test for DTCWT with 1D input"""

t = sequential_array((par["ny"],))

for levels in range(1, 10):
Dtcwt = DTCWT(dims=t.shape, nlevels=levels, dtype=par["dtype"])
x = Dtcwt @ t
y = Dtcwt.H @ x

np.testing.assert_allclose(t, y)


@pytest.mark.parametrize("par", [(par1), (par2)])
def test_dtcwt1D_input2D(par):
"""Test for DTCWT with 2D input"""

t = sequential_array((par["ny"], par["ny"],))

for levels in range(1, 10):
Dtcwt = DTCWT(dims=t.shape, nlevels=levels, dtype=par["dtype"])
x = Dtcwt @ t
y = Dtcwt.H @ x

np.testing.assert_allclose(t, y)


@pytest.mark.parametrize("par", [(par1), (par2)])
def test_dtcwt1D_input3D(par):
"""Test for DTCWT with 3D input"""

t = sequential_array((par["ny"], par["ny"], par["ny"]))

for levels in range(1, 10):
Dtcwt = DTCWT(dims=t.shape, nlevels=levels, dtype=par["dtype"])
x = Dtcwt @ t
y = Dtcwt.H @ x

np.testing.assert_allclose(t, y)


@pytest.mark.parametrize("par", [(par1), (par2)])
def test_dtcwt1D_birot(par):
"""Test for DTCWT birot"""
birots = ["antonini", "legall", "near_sym_a", "near_sym_b"]

t = sequential_array((par["ny"], par["ny"],))

for _b in birots:
print(f"birot {_b}")
Dtcwt = DTCWT(dims=t.shape, biort=_b, dtype=par["dtype"])
x = Dtcwt @ t
y = Dtcwt.H @ x

np.testing.assert_allclose(t, y)
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ spgl1
scikit-fmm
sympy
devito
dtcwt
matplotlib
ipython
pytest
Expand Down
3 changes: 2 additions & 1 deletion requirements-doc.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,5 @@ isort
black
flake8
mypy
pydata-sphinx-theme
pydata-sphinx-theme
dtcwt
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def src(pth):
"numba",
"pyfftw",
"PyWavelets",
"dtcwt",
"scikit-fmm",
"spgl1",
]
Expand Down