From 52e6c12794e3191c561b249357724663ce419d4d Mon Sep 17 00:00:00 2001 From: aniket singh rawat Date: Fri, 24 Feb 2023 02:09:49 +0530 Subject: [PATCH 1/7] added: DTCWT operator --- pylops/signalprocessing/__init__.py | 3 + pylops/signalprocessing/dtcwt.py | 145 ++++++++++++++++++++++++++++ requirements-dev.txt | 1 + requirements-doc.txt | 3 +- 4 files changed, 151 insertions(+), 1 deletion(-) create mode 100644 pylops/signalprocessing/dtcwt.py diff --git a/pylops/signalprocessing/__init__.py b/pylops/signalprocessing/__init__.py index 2ce1fed1..5f8bbe10 100755 --- a/pylops/signalprocessing/__init__.py +++ b/pylops/signalprocessing/__init__.py @@ -23,6 +23,7 @@ DWT One dimensional Wavelet operator. DWT2D Two dimensional Wavelet operator. DCT Discrete Cosine Transform. + DTCWT Dual-Tree Complex Wavelet Transforms Seislet Two dimensional Seislet operator. Radon2D Two dimensional Radon transform. Radon3D Three dimensional Radon transform. @@ -60,6 +61,7 @@ from .dwt2d import * from .seislet import * from .dct import * +from .dtcwt import * __all__ = [ "FFT", @@ -89,4 +91,5 @@ "DWT2D", "Seislet", "DCT", + "DTCWT", ] diff --git a/pylops/signalprocessing/dtcwt.py b/pylops/signalprocessing/dtcwt.py new file mode 100644 index 00000000..e941c43c --- /dev/null +++ b/pylops/signalprocessing/dtcwt.py @@ -0,0 +1,145 @@ +__all__ = ["DTCWT"] + +from typing import Union + +import dtcwt +import numpy as np + +from pylops import LinearOperator +from pylops.utils._internal import _value_or_sized_to_tuple +from pylops.utils.decorators import reshaped +from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray + + +class DTCWT(LinearOperator): + r""" + Perform Dual-Tree Complex Wavelet Transform on a given array. + + This operator wraps around :py:func:`dtcwt` package. + + Parameters + ---------- + dims: :obj:`int` or :obj:`tuple` + Number of samples for each dimension. + transform: :obj:`int`, optional + Type of transform 1D, 2D or 3D. Default is 1. + birot: :obj:`str`, optional + Level 1 wavelets to use. See :py:func:`dtcwt.coeffs.birot()`. Default is `"near_sym_a"`. + qshift: :obj:`str`, optional + Level >= 2 wavelets to use. See :py:func:`dtcwt.coeffs.qshift()`. Default is `"qshift_a"` + nlevels: :obj:`int`, optional + Number of levels of wavelet decomposition. Default is 3. + include_scale: :obj:`bool`, optional + Include scales in pyramid. See :py:func:`dtcwt.Pyramid`. Default is False. + dtype : :obj:`DTypeLike`, optional + Type of elements in input array. + name : :obj:`str`, optional + Name of operator (to be used by :func:`pylops.utils.describe.describe`) + + Raises + ------ + NotImplementedError + If ``transform`` is 2 or 3. + ValueError + If ``transform`` is anything other than 1, 2 or 3. + + Notes + ----- + The :py:func:`dtcwt` library uses a Pyramid object to represent the transform domain signal. + It has + `lowpass` (coarsest scale lowpass signal) + `highpasses` (complex subband coefficients for corresponding scales) + `scales` (lowpass signal for corresponding scales finest to coarsest) + + To make the dtcwt forward() and inverse() functions compatible with pylops, the Pyramid object is + flattened out and all coeffs(highpasses and low pass signal) are appened into one array using the + `_coeff_to_array` method. + For inverse, the flattened array is used to reconstruct the Pyramid object using the `_array_to_coeff` + method and then inverse is performed. + + """ + + def __init__( + self, + dims: Union[int, InputDimsLike], + transform: int = 1, + biort: str = "near_sym_a", + qshift: str = "qshift_a", + nlevels: int = 3, + include_scale: bool = False, + dtype: DTypeLike = "float64", + name: str = "C", + ) -> None: + self.dims = _value_or_sized_to_tuple(dims) + self.transform = transform + self.ndim = len(self.dims) + self.nlevels = nlevels + self.include_scale = include_scale + + if self.transform == 1: + self._transform = dtcwt.Transform1d(biort=biort, qshift=qshift) + elif self.transform == 2: + raise NotImplementedError("DTCWT is not implmented for 2D") + elif self.transform == 3: + raise NotImplementedError("DTCWT is not implmented for 3D") + else: + raise ValueError("DTCWT only supports 1D, 2D and 3D") + + pyr = self._transform.forward( + np.ones(self.dims), nlevels=self.nlevels, include_scale=True + ) + self.coeff_array_size = 0 + self.lowpass_size = len(pyr.lowpass) + self.slices = [] + for _h in pyr.highpasses: + self.slices.append(len(_h)) + self.coeff_array_size += len(_h) + self.coeff_array_size += self.lowpass_size + self.second_dim = 1 + if len(dims) > 1: + self.coeff_array_size *= self.dims[1] + self.lowpass_size *= self.dims[1] + self.second_dim = self.dims[1] + super().__init__( + dtype=np.dtype(dtype), + dims=self.dims, + dimsd=(self.coeff_array_size,), + name=name, + ) + + def _coeff_to_array(self, pyr: dtcwt.Pyramid) -> NDArray: + print("og lowpass ", pyr.lowpass) + print("og highpasses ", pyr.highpasses) + coeffs = pyr.highpasses + flat_coeffs = [] + for band in coeffs: + for c in band: + flat_coeffs.append(c) + flat_coeffs = np.concatenate((flat_coeffs, pyr.lowpass)) + return flat_coeffs + + def _array_to_coeff(self, X: NDArray) -> dtcwt.Pyramid: + lowpass = np.array([x.real for x in X[-1 * self.lowpass_size :]]).reshape( + (-1, self.second_dim) + ) + _ptr = 0 + highpasses = () + for _sl in self.slices: + _h = X[_ptr : _ptr + (_sl * self.second_dim)] + _ptr += _sl * self.second_dim + _h = _h.reshape((-1, self.second_dim)) + highpasses += (_h,) + return dtcwt.Pyramid(lowpass, highpasses) + + def get_pyramid(self, X: NDArray) -> dtcwt.Pyramid: + return self._array_to_coeff(X) + + @reshaped + def _matvec(self, x: NDArray) -> NDArray: + return self._coeff_to_array( + self._transform.forward(x, nlevels=self.nlevels, include_scale=False) + ) + + @reshaped + def _rmatvec(self, X: NDArray) -> NDArray: + return self._transform.inverse(self._array_to_coeff(X)) diff --git a/requirements-dev.txt b/requirements-dev.txt index 1b22f053..2c18e764 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -26,3 +26,4 @@ isort black flake8 mypy +dtcwt diff --git a/requirements-doc.txt b/requirements-doc.txt index d137b621..b0c0fa98 100644 --- a/requirements-doc.txt +++ b/requirements-doc.txt @@ -26,4 +26,5 @@ isort black flake8 mypy -pydata-sphinx-theme \ No newline at end of file +pydata-sphinx-theme +dtcwt From 4b7ec1e8b48c12d607aebc998190b4986a36ba22 Mon Sep 17 00:00:00 2001 From: aniket singh rawat Date: Fri, 24 Feb 2023 02:12:00 +0530 Subject: [PATCH 2/7] added: dtcwt playground file --- dtcwt_temp_test.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 dtcwt_temp_test.py diff --git a/dtcwt_temp_test.py b/dtcwt_temp_test.py new file mode 100644 index 00000000..72e9657a --- /dev/null +++ b/dtcwt_temp_test.py @@ -0,0 +1,27 @@ +#playground file +#this file will be removed after all changes in DTCWT are done +from pylops.signalprocessing import DTCWT +import numpy as np +import matplotlib.pyplot as plt + +n = 10 +nlevel = 4 + + +x = np.cumsum(np.random.rand(10, ) - 0.5, 0) + + + +DOp = DTCWT(dims=x.shape, nlevels=3) +y = DOp @ x + + + + + +i = DOp.H @ y + + +print("x ", x) +print("y ",y) +print("i ", i) \ No newline at end of file From f887031ff0aed0edac70bed9bf1388e2dbfb4b70 Mon Sep 17 00:00:00 2001 From: aniket singh rawat Date: Mon, 27 Feb 2023 21:04:08 +0530 Subject: [PATCH 3/7] added dtcwt for nd. fixed doc --- docs/source/api/index.rst | 1 + dtcwt_temp_test.py | 27 --------- pylops/signalprocessing/dtcwt.py | 98 ++++++++++++++++---------------- 3 files changed, 51 insertions(+), 75 deletions(-) delete mode 100644 dtcwt_temp_test.py diff --git a/docs/source/api/index.rst b/docs/source/api/index.rst index ea25d5bf..df8e64a8 100755 --- a/docs/source/api/index.rst +++ b/docs/source/api/index.rst @@ -102,6 +102,7 @@ Signal processing DWT DWT2D DCT + DTCWT Seislet Radon2D Radon3D diff --git a/dtcwt_temp_test.py b/dtcwt_temp_test.py deleted file mode 100644 index 72e9657a..00000000 --- a/dtcwt_temp_test.py +++ /dev/null @@ -1,27 +0,0 @@ -#playground file -#this file will be removed after all changes in DTCWT are done -from pylops.signalprocessing import DTCWT -import numpy as np -import matplotlib.pyplot as plt - -n = 10 -nlevel = 4 - - -x = np.cumsum(np.random.rand(10, ) - 0.5, 0) - - - -DOp = DTCWT(dims=x.shape, nlevels=3) -y = DOp @ x - - - - - -i = DOp.H @ y - - -print("x ", x) -print("y ",y) -print("i ", i) \ No newline at end of file diff --git a/pylops/signalprocessing/dtcwt.py b/pylops/signalprocessing/dtcwt.py index e941c43c..b04da564 100644 --- a/pylops/signalprocessing/dtcwt.py +++ b/pylops/signalprocessing/dtcwt.py @@ -12,8 +12,8 @@ class DTCWT(LinearOperator): - r""" - Perform Dual-Tree Complex Wavelet Transform on a given array. + r"""Dual-Tree Complex Wavelet Transform + Perform 1D Dual-Tree Complex Wavelet Transform on a given array. This operator wraps around :py:func:`dtcwt` package. @@ -21,8 +21,6 @@ class DTCWT(LinearOperator): ---------- dims: :obj:`int` or :obj:`tuple` Number of samples for each dimension. - transform: :obj:`int`, optional - Type of transform 1D, 2D or 3D. Default is 1. birot: :obj:`str`, optional Level 1 wavelets to use. See :py:func:`dtcwt.coeffs.birot()`. Default is `"near_sym_a"`. qshift: :obj:`str`, optional @@ -31,62 +29,61 @@ class DTCWT(LinearOperator): Number of levels of wavelet decomposition. Default is 3. include_scale: :obj:`bool`, optional Include scales in pyramid. See :py:func:`dtcwt.Pyramid`. Default is False. + axis: :obj:`int`, optional + Axis on which the transform is performed. Default is -1. dtype : :obj:`DTypeLike`, optional Type of elements in input array. name : :obj:`str`, optional Name of operator (to be used by :func:`pylops.utils.describe.describe`) - Raises - ------ - NotImplementedError - If ``transform`` is 2 or 3. - ValueError - If ``transform`` is anything other than 1, 2 or 3. Notes ----- - The :py:func:`dtcwt` library uses a Pyramid object to represent the transform domain signal. + The :py:func:`dtcwt` library uses a Pyramid object to represent the transformed domain signal. It has - `lowpass` (coarsest scale lowpass signal) - `highpasses` (complex subband coefficients for corresponding scales) - `scales` (lowpass signal for corresponding scales finest to coarsest) + - `lowpass` (coarsest scale lowpass signal) + - `highpasses` (complex subband coefficients for corresponding scales) + - `scales` (lowpass signal for corresponding scales finest to coarsest) To make the dtcwt forward() and inverse() functions compatible with pylops, the Pyramid object is - flattened out and all coeffs(highpasses and low pass signal) are appened into one array using the + flattened out and all coefficents (high-pass and low pass coefficients) are appended into one array using the `_coeff_to_array` method. For inverse, the flattened array is used to reconstruct the Pyramid object using the `_array_to_coeff` method and then inverse is performed. - """ def __init__( self, dims: Union[int, InputDimsLike], - transform: int = 1, biort: str = "near_sym_a", qshift: str = "qshift_a", nlevels: int = 3, include_scale: bool = False, + axis: int = -1, dtype: DTypeLike = "float64", name: str = "C", ) -> None: self.dims = _value_or_sized_to_tuple(dims) - self.transform = transform self.ndim = len(self.dims) self.nlevels = nlevels self.include_scale = include_scale + self.axis = axis + self._transform = dtcwt.Transform1d(biort=biort, qshift=qshift) + self._interpret_coeffs() + super().__init__( + dtype=np.dtype(dtype), + dims=self.dims, + dimsd=(self.coeff_array_size,), + name=name, + ) - if self.transform == 1: - self._transform = dtcwt.Transform1d(biort=biort, qshift=qshift) - elif self.transform == 2: - raise NotImplementedError("DTCWT is not implmented for 2D") - elif self.transform == 3: - raise NotImplementedError("DTCWT is not implmented for 3D") - else: - raise ValueError("DTCWT only supports 1D, 2D and 3D") - + def _interpret_coeffs(self): + T = np.ones(self.dims) + T = T.swapaxes(self.axis, -1) + self.swapped_dims = T.shape + T = self._nd_to_2d(T) pyr = self._transform.forward( - np.ones(self.dims), nlevels=self.nlevels, include_scale=True + T , nlevels=self.nlevels, include_scale=True ) self.coeff_array_size = 0 self.lowpass_size = len(pyr.lowpass) @@ -95,21 +92,20 @@ def __init__( self.slices.append(len(_h)) self.coeff_array_size += len(_h) self.coeff_array_size += self.lowpass_size - self.second_dim = 1 - if len(dims) > 1: - self.coeff_array_size *= self.dims[1] - self.lowpass_size *= self.dims[1] - self.second_dim = self.dims[1] - super().__init__( - dtype=np.dtype(dtype), - dims=self.dims, - dimsd=(self.coeff_array_size,), - name=name, - ) + elements = np.prod(T.shape[1:]) + self.coeff_array_size *= elements + self.lowpass_size *= elements + self.first_dim = elements + + def _nd_to_2d(self, arr_nd): + arr_2d = arr_nd.reshape((self.dims[0], -1)) + return arr_2d + + def _2d_to_nd(self, arr_2d): + arr_nd = arr_2d.reshape(self.swapped_dims) + return arr_nd def _coeff_to_array(self, pyr: dtcwt.Pyramid) -> NDArray: - print("og lowpass ", pyr.lowpass) - print("og highpasses ", pyr.highpasses) coeffs = pyr.highpasses flat_coeffs = [] for band in coeffs: @@ -119,27 +115,33 @@ def _coeff_to_array(self, pyr: dtcwt.Pyramid) -> NDArray: return flat_coeffs def _array_to_coeff(self, X: NDArray) -> dtcwt.Pyramid: - lowpass = np.array([x.real for x in X[-1 * self.lowpass_size :]]).reshape( - (-1, self.second_dim) + lowpass = np.array([x.real for x in X[-self.lowpass_size :]]).reshape( + (-1, self.first_dim) ) _ptr = 0 highpasses = () for _sl in self.slices: - _h = X[_ptr : _ptr + (_sl * self.second_dim)] - _ptr += _sl * self.second_dim - _h = _h.reshape((-1, self.second_dim)) + _h = X[_ptr : _ptr + (_sl * self.first_dim)] + _ptr += _sl * self.first_dim + _h = _h.reshape((-1, self.first_dim)) highpasses += (_h,) return dtcwt.Pyramid(lowpass, highpasses) def get_pyramid(self, X: NDArray) -> dtcwt.Pyramid: + """Return Pyramid object from transformed array + """ return self._array_to_coeff(X) @reshaped def _matvec(self, x: NDArray) -> NDArray: + x = x.swapaxes(self.axis, -1) + x = self._nd_to_2d(x) return self._coeff_to_array( self._transform.forward(x, nlevels=self.nlevels, include_scale=False) ) @reshaped - def _rmatvec(self, X: NDArray) -> NDArray: - return self._transform.inverse(self._array_to_coeff(X)) + def _rmatvec(self, x: NDArray) -> NDArray: + Y = self._transform.inverse(self._array_to_coeff(x)) + Y = self._2d_to_nd(Y) + return Y.swapaxes(self.axis, -1) From 7435b519c83e0e379b45028d14020d084cc7b867 Mon Sep 17 00:00:00 2001 From: aniket singh rawat Date: Mon, 27 Feb 2023 21:52:42 +0530 Subject: [PATCH 4/7] added dtcwt tests --- pytests/test_dtcwt.py | 72 +++++++++++++++++++++++++++++++++++++++++++ setup.cfg | 2 +- 2 files changed, 73 insertions(+), 1 deletion(-) create mode 100644 pytests/test_dtcwt.py diff --git a/pytests/test_dtcwt.py b/pytests/test_dtcwt.py new file mode 100644 index 00000000..e6a0a3fb --- /dev/null +++ b/pytests/test_dtcwt.py @@ -0,0 +1,72 @@ +import numpy as np +import pytest + +from pylops.signalprocessing import DTCWT + +par1 = {"ny": 10, "nx": 10, "dtype": "float64"} +par2 = {"ny": 50, "nx": 50, "dtype": "float64"} + + +def sequential_array(shape): + num_elements = np.prod(shape) + seq_array = np.arange(1, num_elements + 1) + result = seq_array.reshape(shape) + return result + + +@pytest.mark.parametrize("par", [(par1), (par2)]) +def test_dtcwt1D_input1D(par): + """Test for DTCWT with 1D input""" + + t = sequential_array((par["ny"],)) + + for levels in range(1, 10): + Dtcwt = DTCWT(dims=t.shape, nlevels=levels, dtype=par["dtype"]) + x = Dtcwt @ t + y = Dtcwt.H @ x + + np.testing.assert_allclose(t, y) + + +@pytest.mark.parametrize("par", [(par1), (par2)]) +def test_dtcwt1D_input2D(par): + """Test for DTCWT with 2D input""" + + t = sequential_array((par["ny"], par["ny"],)) + + for levels in range(1, 10): + Dtcwt = DTCWT(dims=t.shape, nlevels=levels, dtype=par["dtype"]) + x = Dtcwt @ t + y = Dtcwt.H @ x + + np.testing.assert_allclose(t, y) + + +@pytest.mark.parametrize("par", [(par1), (par2)]) +def test_dtcwt1D_input3D(par): + """Test for DTCWT with 3D input""" + + t = sequential_array((par["ny"], par["ny"], par["ny"])) + + for levels in range(1, 10): + Dtcwt = DTCWT(dims=t.shape, nlevels=levels, dtype=par["dtype"]) + x = Dtcwt @ t + y = Dtcwt.H @ x + + np.testing.assert_allclose(t, y) + + +@pytest.mark.parametrize("par", [(par1), (par2)]) +def test_dtcwt1D_birot(par): + """Test for DTCWT birot""" + birots = ["antonini", "legall", "near_sym_a", "near_sym_b"] + + t = sequential_array((par["ny"], par["ny"],)) + + for _b in birots: + print(f"birot {_b}") + Dtcwt = DTCWT(dims=t.shape, biort=_b, dtype=par["dtype"]) + x = Dtcwt @ t + y = Dtcwt.H @ x + + np.testing.assert_allclose(t, y) diff --git a/setup.cfg b/setup.cfg index 1691b3bc..bd564d6e 100755 --- a/setup.cfg +++ b/setup.cfg @@ -3,7 +3,7 @@ test=pytest [tool:pytest] addopts = --verbose -python_files = pytests/*.py +python_files = pytests/test_dtcwt.py [flake8] ignore = E203, E501, W503, E402 From 22e6780e1318a2d51e5d7c513508d1c4e00ecae5 Mon Sep 17 00:00:00 2001 From: aniket singh rawat Date: Mon, 27 Feb 2023 23:09:24 +0530 Subject: [PATCH 5/7] revert setup.cfg --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index bd564d6e..1691b3bc 100755 --- a/setup.cfg +++ b/setup.cfg @@ -3,7 +3,7 @@ test=pytest [tool:pytest] addopts = --verbose -python_files = pytests/test_dtcwt.py +python_files = pytests/*.py [flake8] ignore = E203, E501, W503, E402 From 60f9ca0a32451933e2f5c1944faa3f6b80e5a0ef Mon Sep 17 00:00:00 2001 From: aniket singh rawat Date: Wed, 1 Mar 2023 22:10:43 +0530 Subject: [PATCH 6/7] added dtcwt example --- examples/plot_dtcwt.py | 74 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) create mode 100644 examples/plot_dtcwt.py diff --git a/examples/plot_dtcwt.py b/examples/plot_dtcwt.py new file mode 100644 index 00000000..100ba899 --- /dev/null +++ b/examples/plot_dtcwt.py @@ -0,0 +1,74 @@ +""" +Dual-Tree Complex Wavelet Transform +========================= +This example shows how to use the :py:class:`pylops.signalprocessing.DCT` operator. +This operator performs the 1D Dual-Tree Complex Wavelet Transform on a (single or multi-dimensional) +input array. +""" + +import matplotlib.pyplot as plt +import numpy as np + +import pylops + +plt.close("all") + +############################################################################### +# Let's define a 1D array x of having random values + +n = 50 +x = np.random.rand(n,) + +############################################################################### +# We create the DTCWT operator with shape of our input array. DTCWT transform +# gives a Pyramid object that is flattened out `y`. + +DOp = pylops.signalprocessing.DTCWT(dims=x.shape) +y = DOp @ x +xadj = DOp.H @ y + +plt.figure(figsize=(8, 5)) +plt.plot(x, "k", label="input array") +plt.plot(y, "r", label="transformed array") +plt.plot(xadj, "--b", label="transformed array") +plt.title("Dual-Tree Complex Wavelet Transform 1D") +plt.legend() +plt.tight_layout() + +################################################################################# +# To get the Pyramid object use the `get_pyramid` method. +# We can get the Highpass signal and Lowpass signal from it + +pyr = DOp.get_pyramid(y) + +plt.figure(figsize=(10, 5)) +plt.plot(x, "--b", label="orignal signal") +plt.plot(pyr.lowpass, "k", label="lowpass") +plt.plot(pyr.highpasses[0], "r", label="highpass level 1 signal") +plt.plot(pyr.highpasses[1], "b", label="highpass level 2 signal") +plt.plot(pyr.highpasses[2], "g", label="highpass level 3 signal") + +plt.title("DTCWT Pyramid Object") +plt.legend() +plt.tight_layout() + +################################################################################### +# DTCWT can also be performed on multi-dimension arrays. The number of levels can also +# be defined using the `nlevels` + +n = 10 +m = 2 + +x = np.random.rand(n, m) + +DOp = pylops.signalprocessing.DTCWT(dims=x.shape, nlevels=5) +y = DOp @ x +xadj = DOp.H @ y + +plt.figure(figsize=(8, 5)) +plt.plot(x, "k", label="input array") +plt.plot(y, "r", label="transformed array") +plt.plot(xadj, "--b", label="transformed array") +plt.title("Dual-Tree Complex Wavelet Transform 1D on ND array") +plt.legend() +plt.tight_layout() From b0e37c99996e93effaa4d42582726b4ba4adba56 Mon Sep 17 00:00:00 2001 From: mrava87 Date: Wed, 8 Mar 2023 22:55:04 +0300 Subject: [PATCH 7/7] minor: added dtcwt to env and setup files and installation instructions --- docs/source/installation.rst | 10 +++++ environment-dev.yml | 1 + pylops/signalprocessing/dct.py | 2 +- pylops/signalprocessing/dtcwt.py | 76 ++++++++++++++++++-------------- requirements-dev.txt | 2 +- setup.py | 1 + 6 files changed, 56 insertions(+), 36 deletions(-) diff --git a/docs/source/installation.rst b/docs/source/installation.rst index 2e22eb2d..16bb7806 100755 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -318,6 +318,16 @@ of GPUs should install it prior to installing PyLops as described in :ref:`Optio In alphabetic order: +dtcwt +----- +`dtcwt `_ is used to implement the DT-CWT operators. +Install it via ``pip`` with: + +.. code-block:: bash + + >> pip install dtcwt + + Devito ------ `Devito `_ is library used to solve PDEs via diff --git a/environment-dev.yml b/environment-dev.yml index f00830c2..21555120 100755 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -25,6 +25,7 @@ dependencies: - isort - black - pip: + - dtcwt - devito - scikit-fmm - spgl1 diff --git a/pylops/signalprocessing/dct.py b/pylops/signalprocessing/dct.py index eb46e872..1a336be6 100644 --- a/pylops/signalprocessing/dct.py +++ b/pylops/signalprocessing/dct.py @@ -29,7 +29,7 @@ class DCT(LinearOperator): axes : :obj:`int` or :obj:`list`, optional Axes over which the DCT is computed. If ``None``, the transform is applied over all axes. - workers :obj:`int`, optional + workers : :obj:`int`, optional Maximum number of workers to use for parallel computation. If negative, the value wraps around from os.cpu_count(). dtype : :obj:`DTypeLike`, optional diff --git a/pylops/signalprocessing/dtcwt.py b/pylops/signalprocessing/dtcwt.py index b04da564..4d227259 100644 --- a/pylops/signalprocessing/dtcwt.py +++ b/pylops/signalprocessing/dtcwt.py @@ -13,43 +13,53 @@ class DTCWT(LinearOperator): r"""Dual-Tree Complex Wavelet Transform - Perform 1D Dual-Tree Complex Wavelet Transform on a given array. - This operator wraps around :py:func:`dtcwt` package. + Perform 1D Dual-Tree Complex Wavelet Transform along an ``axis`` of a + multi-dimensional array of size ``dims``. + + Note that the DTCWT operator is an overload of the ``dtcwt`` + implementation of the DT-CWT transform. Refer to + https://dtcwt.readthedocs.io for a detailed description of the + input parameters. Parameters ---------- - dims: :obj:`int` or :obj:`tuple` + dims : :obj:`int` or :obj:`tuple` Number of samples for each dimension. - birot: :obj:`str`, optional - Level 1 wavelets to use. See :py:func:`dtcwt.coeffs.birot()`. Default is `"near_sym_a"`. - qshift: :obj:`str`, optional - Level >= 2 wavelets to use. See :py:func:`dtcwt.coeffs.qshift()`. Default is `"qshift_a"` - nlevels: :obj:`int`, optional + birot : :obj:`str`, optional + Level 1 wavelets to use. See :py:func:`dtcwt.coeffs.birot`. Default is `"near_sym_a"`. + qshift : :obj:`str`, optional + Level >= 2 wavelets to use. See :py:func:`dtcwt.coeffs.qshift`. Default is `"qshift_a"` + nlevels : :obj:`int`, optional Number of levels of wavelet decomposition. Default is 3. - include_scale: :obj:`bool`, optional - Include scales in pyramid. See :py:func:`dtcwt.Pyramid`. Default is False. - axis: :obj:`int`, optional - Axis on which the transform is performed. Default is -1. + include_scale : :obj:`bool`, optional + Include scales in pyramid. See :py:class:`dtcwt.Pyramid`. Default is False. + axis : :obj:`int`, optional + Axis on which the transform is performed. dtype : :obj:`DTypeLike`, optional Type of elements in input array. name : :obj:`str`, optional Name of operator (to be used by :func:`pylops.utils.describe.describe`) - Notes ----- - The :py:func:`dtcwt` library uses a Pyramid object to represent the transformed domain signal. - It has - - `lowpass` (coarsest scale lowpass signal) - - `highpasses` (complex subband coefficients for corresponding scales) - - `scales` (lowpass signal for corresponding scales finest to coarsest) - - To make the dtcwt forward() and inverse() functions compatible with pylops, the Pyramid object is - flattened out and all coefficents (high-pass and low pass coefficients) are appended into one array using the - `_coeff_to_array` method. - For inverse, the flattened array is used to reconstruct the Pyramid object using the `_array_to_coeff` - method and then inverse is performed. + The DTCWT operator applies the dual-tree complex wavelet transform + in forward mode and the dual-tree complex inverse wavelet transform in adjoint mode + from the ``dtcwt`` library. + + The ``dtcwt`` library uses a Pyramid object to represent the signal in the transformed domain, + which is composed of: + - `lowpass` (coarsest scale lowpass signal); + - `highpasses` (complex subband coefficients for corresponding scales); + - `scales` (lowpass signal for corresponding scales finest to coarsest). + + To make the dtcwt forward() and inverse() functions compatible with PyLops, in forward model + the Pyramid object is flattened out and all coefficients (high-pass and low pass coefficients) + are appended into one array using the `_coeff_to_array` method. + + In adjoint mode, the input array is transformed back into a Pyramid object using the `_array_to_coeff` + method and then the inverse transform is performed. + """ def __init__( @@ -68,6 +78,7 @@ def __init__( self.nlevels = nlevels self.include_scale = include_scale self.axis = axis + # dry-run of transform to self._transform = dtcwt.Transform1d(biort=biort, qshift=qshift) self._interpret_coeffs() super().__init__( @@ -78,13 +89,11 @@ def __init__( ) def _interpret_coeffs(self): - T = np.ones(self.dims) - T = T.swapaxes(self.axis, -1) - self.swapped_dims = T.shape - T = self._nd_to_2d(T) - pyr = self._transform.forward( - T , nlevels=self.nlevels, include_scale=True - ) + x = np.ones(self.dims) + x = x.swapaxes(self.axis, -1) + self.swapped_dims = x.shape + x = self._nd_to_2d(x) + pyr = self._transform.forward(x, nlevels=self.nlevels, include_scale=True) self.coeff_array_size = 0 self.lowpass_size = len(pyr.lowpass) self.slices = [] @@ -92,7 +101,7 @@ def _interpret_coeffs(self): self.slices.append(len(_h)) self.coeff_array_size += len(_h) self.coeff_array_size += self.lowpass_size - elements = np.prod(T.shape[1:]) + elements = np.prod(x.shape[1:]) self.coeff_array_size *= elements self.lowpass_size *= elements self.first_dim = elements @@ -128,8 +137,7 @@ def _array_to_coeff(self, X: NDArray) -> dtcwt.Pyramid: return dtcwt.Pyramid(lowpass, highpasses) def get_pyramid(self, X: NDArray) -> dtcwt.Pyramid: - """Return Pyramid object from transformed array - """ + """Return Pyramid object from transformed array""" return self._array_to_coeff(X) @reshaped diff --git a/requirements-dev.txt b/requirements-dev.txt index 2c18e764..8ba1f8b0 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -8,6 +8,7 @@ spgl1 scikit-fmm sympy devito +dtcwt matplotlib ipython pytest @@ -26,4 +27,3 @@ isort black flake8 mypy -dtcwt diff --git a/setup.py b/setup.py index 8b82afa2..74a58404 100755 --- a/setup.py +++ b/setup.py @@ -39,6 +39,7 @@ def src(pth): "numba", "pyfftw", "PyWavelets", + "dtcwt", "scikit-fmm", "spgl1", ]