Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added DTCWT operator #495

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions dtcwt_temp_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#playground file
#this file will be removed after all changes in DTCWT are done
from pylops.signalprocessing import DTCWT
import numpy as np
import matplotlib.pyplot as plt

n = 10
nlevel = 4


x = np.cumsum(np.random.rand(10, ) - 0.5, 0)



DOp = DTCWT(dims=x.shape, nlevels=3)
y = DOp @ x





i = DOp.H @ y


print("x ", x)
print("y ",y)
print("i ", i)
3 changes: 3 additions & 0 deletions pylops/signalprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
DWT One dimensional Wavelet operator.
DWT2D Two dimensional Wavelet operator.
DCT Discrete Cosine Transform.
DTCWT Dual-Tree Complex Wavelet Transforms
Seislet Two dimensional Seislet operator.
Radon2D Two dimensional Radon transform.
Radon3D Three dimensional Radon transform.
Expand Down Expand Up @@ -60,6 +61,7 @@
from .dwt2d import *
from .seislet import *
from .dct import *
from .dtcwt import *

__all__ = [
"FFT",
Expand Down Expand Up @@ -89,4 +91,5 @@
"DWT2D",
"Seislet",
"DCT",
"DTCWT",
]
145 changes: 145 additions & 0 deletions pylops/signalprocessing/dtcwt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
__all__ = ["DTCWT"]

from typing import Union

import dtcwt
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this needs to be wrapped in a pylops.utils.deps check like here:

https://github.com/PyLops/pylops/blob/dev/pylops/_torchoperator.py#LL5-LL10

import numpy as np

from pylops import LinearOperator
from pylops.utils._internal import _value_or_sized_to_tuple
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray


class DTCWT(LinearOperator):
r"""
cako marked this conversation as resolved.
Show resolved Hide resolved
Perform Dual-Tree Complex Wavelet Transform on a given array.

This operator wraps around :py:func:`dtcwt` package.

Parameters
----------
dims: :obj:`int` or :obj:`tuple`
Number of samples for each dimension.
transform: :obj:`int`, optional
Type of transform 1D, 2D or 3D. Default is 1.
birot: :obj:`str`, optional
Level 1 wavelets to use. See :py:func:`dtcwt.coeffs.birot()`. Default is `"near_sym_a"`.
qshift: :obj:`str`, optional
Level >= 2 wavelets to use. See :py:func:`dtcwt.coeffs.qshift()`. Default is `"qshift_a"`
nlevels: :obj:`int`, optional
Number of levels of wavelet decomposition. Default is 3.
include_scale: :obj:`bool`, optional
Include scales in pyramid. See :py:func:`dtcwt.Pyramid`. Default is False.
dtype : :obj:`DTypeLike`, optional
Type of elements in input array.
name : :obj:`str`, optional
Name of operator (to be used by :func:`pylops.utils.describe.describe`)

Raises
------
NotImplementedError
cako marked this conversation as resolved.
Show resolved Hide resolved
If ``transform`` is 2 or 3.
ValueError
cako marked this conversation as resolved.
Show resolved Hide resolved
If ``transform`` is anything other than 1, 2 or 3.

Notes
-----
The :py:func:`dtcwt` library uses a Pyramid object to represent the transform domain signal.
cako marked this conversation as resolved.
Show resolved Hide resolved
It has
cako marked this conversation as resolved.
Show resolved Hide resolved
`lowpass` (coarsest scale lowpass signal)
`highpasses` (complex subband coefficients for corresponding scales)
`scales` (lowpass signal for corresponding scales finest to coarsest)

To make the dtcwt forward() and inverse() functions compatible with pylops, the Pyramid object is
cako marked this conversation as resolved.
Show resolved Hide resolved
flattened out and all coeffs(highpasses and low pass signal) are appened into one array using the
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

coeffs(highpasses and low pass signal) -> coefficents (high-pass and low pass coefficients)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated f887031

`_coeff_to_array` method.
For inverse, the flattened array is used to reconstruct the Pyramid object using the `_array_to_coeff`
method and then inverse is performed.

"""

def __init__(
self,
dims: Union[int, InputDimsLike],
transform: int = 1,
biort: str = "near_sym_a",
qshift: str = "qshift_a",
nlevels: int = 3,
include_scale: bool = False,
dtype: DTypeLike = "float64",
name: str = "C",
) -> None:
self.dims = _value_or_sized_to_tuple(dims)
self.transform = transform
self.ndim = len(self.dims)
self.nlevels = nlevels
self.include_scale = include_scale

if self.transform == 1:
self._transform = dtcwt.Transform1d(biort=biort, qshift=qshift)
elif self.transform == 2:
raise NotImplementedError("DTCWT is not implmented for 2D")
elif self.transform == 3:
raise NotImplementedError("DTCWT is not implmented for 3D")
else:
raise ValueError("DTCWT only supports 1D, 2D and 3D")

pyr = self._transform.forward(
cako marked this conversation as resolved.
Show resolved Hide resolved
np.ones(self.dims), nlevels=self.nlevels, include_scale=True
)
self.coeff_array_size = 0
self.lowpass_size = len(pyr.lowpass)
self.slices = []
for _h in pyr.highpasses:
self.slices.append(len(_h))
self.coeff_array_size += len(_h)
self.coeff_array_size += self.lowpass_size
self.second_dim = 1
if len(dims) > 1:
cako marked this conversation as resolved.
Show resolved Hide resolved
self.coeff_array_size *= self.dims[1]
self.lowpass_size *= self.dims[1]
self.second_dim = self.dims[1]
super().__init__(
dtype=np.dtype(dtype),
dims=self.dims,
dimsd=(self.coeff_array_size,),
name=name,
)

def _coeff_to_array(self, pyr: dtcwt.Pyramid) -> NDArray:
print("og lowpass ", pyr.lowpass)
cako marked this conversation as resolved.
Show resolved Hide resolved
print("og highpasses ", pyr.highpasses)
coeffs = pyr.highpasses
flat_coeffs = []
for band in coeffs:
for c in band:
flat_coeffs.append(c)
flat_coeffs = np.concatenate((flat_coeffs, pyr.lowpass))
return flat_coeffs

def _array_to_coeff(self, X: NDArray) -> dtcwt.Pyramid:
lowpass = np.array([x.real for x in X[-1 * self.lowpass_size :]]).reshape(
cako marked this conversation as resolved.
Show resolved Hide resolved
(-1, self.second_dim)
)
_ptr = 0
highpasses = ()
for _sl in self.slices:
_h = X[_ptr : _ptr + (_sl * self.second_dim)]
_ptr += _sl * self.second_dim
_h = _h.reshape((-1, self.second_dim))
highpasses += (_h,)
return dtcwt.Pyramid(lowpass, highpasses)

def get_pyramid(self, X: NDArray) -> dtcwt.Pyramid:
return self._array_to_coeff(X)

@reshaped
def _matvec(self, x: NDArray) -> NDArray:
return self._coeff_to_array(
self._transform.forward(x, nlevels=self.nlevels, include_scale=False)
)

@reshaped
def _rmatvec(self, X: NDArray) -> NDArray:
return self._transform.inverse(self._array_to_coeff(X))
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@ isort
black
flake8
mypy
dtcwt
3 changes: 2 additions & 1 deletion requirements-doc.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,5 @@ isort
black
flake8
mypy
pydata-sphinx-theme
pydata-sphinx-theme
dtcwt