-
Notifications
You must be signed in to change notification settings - Fork 108
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added DTCWT
operator
#495
Closed
Closed
Added DTCWT
operator
#495
Changes from 2 commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
52e6c12
added: DTCWT operator
4b7ec1e
added: dtcwt playground file
f887031
added dtcwt for nd. fixed doc
7435b51
added dtcwt tests
22e6780
revert setup.cfg
60f9ca0
added dtcwt example
b0e37c9
minor: added dtcwt to env and setup files and installation instructions
mrava87 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
#playground file | ||
#this file will be removed after all changes in DTCWT are done | ||
from pylops.signalprocessing import DTCWT | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
|
||
n = 10 | ||
nlevel = 4 | ||
|
||
|
||
x = np.cumsum(np.random.rand(10, ) - 0.5, 0) | ||
|
||
|
||
|
||
DOp = DTCWT(dims=x.shape, nlevels=3) | ||
y = DOp @ x | ||
|
||
|
||
|
||
|
||
|
||
i = DOp.H @ y | ||
|
||
|
||
print("x ", x) | ||
print("y ",y) | ||
print("i ", i) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
__all__ = ["DTCWT"] | ||
|
||
from typing import Union | ||
|
||
import dtcwt | ||
import numpy as np | ||
|
||
from pylops import LinearOperator | ||
from pylops.utils._internal import _value_or_sized_to_tuple | ||
from pylops.utils.decorators import reshaped | ||
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray | ||
|
||
|
||
class DTCWT(LinearOperator): | ||
r""" | ||
cako marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Perform Dual-Tree Complex Wavelet Transform on a given array. | ||
|
||
This operator wraps around :py:func:`dtcwt` package. | ||
|
||
Parameters | ||
---------- | ||
dims: :obj:`int` or :obj:`tuple` | ||
Number of samples for each dimension. | ||
transform: :obj:`int`, optional | ||
Type of transform 1D, 2D or 3D. Default is 1. | ||
birot: :obj:`str`, optional | ||
Level 1 wavelets to use. See :py:func:`dtcwt.coeffs.birot()`. Default is `"near_sym_a"`. | ||
qshift: :obj:`str`, optional | ||
Level >= 2 wavelets to use. See :py:func:`dtcwt.coeffs.qshift()`. Default is `"qshift_a"` | ||
nlevels: :obj:`int`, optional | ||
Number of levels of wavelet decomposition. Default is 3. | ||
include_scale: :obj:`bool`, optional | ||
Include scales in pyramid. See :py:func:`dtcwt.Pyramid`. Default is False. | ||
dtype : :obj:`DTypeLike`, optional | ||
Type of elements in input array. | ||
name : :obj:`str`, optional | ||
Name of operator (to be used by :func:`pylops.utils.describe.describe`) | ||
|
||
Raises | ||
------ | ||
NotImplementedError | ||
cako marked this conversation as resolved.
Show resolved
Hide resolved
|
||
If ``transform`` is 2 or 3. | ||
ValueError | ||
cako marked this conversation as resolved.
Show resolved
Hide resolved
|
||
If ``transform`` is anything other than 1, 2 or 3. | ||
|
||
Notes | ||
----- | ||
The :py:func:`dtcwt` library uses a Pyramid object to represent the transform domain signal. | ||
cako marked this conversation as resolved.
Show resolved
Hide resolved
|
||
It has | ||
cako marked this conversation as resolved.
Show resolved
Hide resolved
|
||
`lowpass` (coarsest scale lowpass signal) | ||
`highpasses` (complex subband coefficients for corresponding scales) | ||
`scales` (lowpass signal for corresponding scales finest to coarsest) | ||
|
||
To make the dtcwt forward() and inverse() functions compatible with pylops, the Pyramid object is | ||
cako marked this conversation as resolved.
Show resolved
Hide resolved
|
||
flattened out and all coeffs(highpasses and low pass signal) are appened into one array using the | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. coeffs(highpasses and low pass signal) -> coefficents (high-pass and low pass coefficients) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. updated f887031 |
||
`_coeff_to_array` method. | ||
For inverse, the flattened array is used to reconstruct the Pyramid object using the `_array_to_coeff` | ||
method and then inverse is performed. | ||
|
||
""" | ||
|
||
def __init__( | ||
self, | ||
dims: Union[int, InputDimsLike], | ||
transform: int = 1, | ||
biort: str = "near_sym_a", | ||
qshift: str = "qshift_a", | ||
nlevels: int = 3, | ||
include_scale: bool = False, | ||
dtype: DTypeLike = "float64", | ||
name: str = "C", | ||
) -> None: | ||
self.dims = _value_or_sized_to_tuple(dims) | ||
self.transform = transform | ||
self.ndim = len(self.dims) | ||
self.nlevels = nlevels | ||
self.include_scale = include_scale | ||
|
||
if self.transform == 1: | ||
self._transform = dtcwt.Transform1d(biort=biort, qshift=qshift) | ||
elif self.transform == 2: | ||
raise NotImplementedError("DTCWT is not implmented for 2D") | ||
elif self.transform == 3: | ||
raise NotImplementedError("DTCWT is not implmented for 3D") | ||
else: | ||
raise ValueError("DTCWT only supports 1D, 2D and 3D") | ||
|
||
pyr = self._transform.forward( | ||
cako marked this conversation as resolved.
Show resolved
Hide resolved
|
||
np.ones(self.dims), nlevels=self.nlevels, include_scale=True | ||
) | ||
self.coeff_array_size = 0 | ||
self.lowpass_size = len(pyr.lowpass) | ||
self.slices = [] | ||
for _h in pyr.highpasses: | ||
self.slices.append(len(_h)) | ||
self.coeff_array_size += len(_h) | ||
self.coeff_array_size += self.lowpass_size | ||
self.second_dim = 1 | ||
if len(dims) > 1: | ||
cako marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self.coeff_array_size *= self.dims[1] | ||
self.lowpass_size *= self.dims[1] | ||
self.second_dim = self.dims[1] | ||
super().__init__( | ||
dtype=np.dtype(dtype), | ||
dims=self.dims, | ||
dimsd=(self.coeff_array_size,), | ||
name=name, | ||
) | ||
|
||
def _coeff_to_array(self, pyr: dtcwt.Pyramid) -> NDArray: | ||
print("og lowpass ", pyr.lowpass) | ||
cako marked this conversation as resolved.
Show resolved
Hide resolved
|
||
print("og highpasses ", pyr.highpasses) | ||
coeffs = pyr.highpasses | ||
flat_coeffs = [] | ||
for band in coeffs: | ||
for c in band: | ||
flat_coeffs.append(c) | ||
flat_coeffs = np.concatenate((flat_coeffs, pyr.lowpass)) | ||
return flat_coeffs | ||
|
||
def _array_to_coeff(self, X: NDArray) -> dtcwt.Pyramid: | ||
lowpass = np.array([x.real for x in X[-1 * self.lowpass_size :]]).reshape( | ||
cako marked this conversation as resolved.
Show resolved
Hide resolved
|
||
(-1, self.second_dim) | ||
) | ||
_ptr = 0 | ||
highpasses = () | ||
for _sl in self.slices: | ||
_h = X[_ptr : _ptr + (_sl * self.second_dim)] | ||
_ptr += _sl * self.second_dim | ||
_h = _h.reshape((-1, self.second_dim)) | ||
highpasses += (_h,) | ||
return dtcwt.Pyramid(lowpass, highpasses) | ||
|
||
def get_pyramid(self, X: NDArray) -> dtcwt.Pyramid: | ||
return self._array_to_coeff(X) | ||
|
||
@reshaped | ||
def _matvec(self, x: NDArray) -> NDArray: | ||
return self._coeff_to_array( | ||
self._transform.forward(x, nlevels=self.nlevels, include_scale=False) | ||
) | ||
|
||
@reshaped | ||
def _rmatvec(self, X: NDArray) -> NDArray: | ||
return self._transform.inverse(self._array_to_coeff(X)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,3 +26,4 @@ isort | |
black | ||
flake8 | ||
mypy | ||
dtcwt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,4 +26,5 @@ isort | |
black | ||
flake8 | ||
mypy | ||
pydata-sphinx-theme | ||
pydata-sphinx-theme | ||
dtcwt |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this needs to be wrapped in a
pylops.utils.deps
check like here:https://github.com/PyLops/pylops/blob/dev/pylops/_torchoperator.py#LL5-LL10