diff --git a/docs/source/api/index.rst b/docs/source/api/index.rst
index e86981fe..1d77dc15 100755
--- a/docs/source/api/index.rst
+++ b/docs/source/api/index.rst
@@ -29,6 +29,7 @@ Templates
FunctionOperator
MemoizeOperator
TorchOperator
+ JaxOperator
Basic operators
~~~~~~~~~~~~~~~
diff --git a/docs/source/conf.py b/docs/source/conf.py
index caf745e5..c5e6536d 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -21,6 +21,7 @@
"numpydoc",
"nbsphinx",
"sphinx_gallery.gen_gallery",
+ "sphinxemoji.sphinxemoji",
# 'sphinx.ext.napoleon',
]
@@ -29,6 +30,8 @@
"python": ("https://docs.python.org/3/", None),
"numpy": ("https://docs.scipy.org/doc/numpy/", None),
"scipy": ("https://docs.scipy.org/doc/scipy/reference", None),
+ "cupy": ("https://docs.cupy.dev/en/stable/", None),
+ "jax": ("https://jax.readthedocs.io/en/latest", None),
"sklearn": ("http://scikit-learn.org/stable/", None),
"pandas": ("http://pandas.pydata.org/pandas-docs/stable/", None),
"matplotlib": ("https://matplotlib.org/", None),
diff --git a/docs/source/gpu.rst b/docs/source/gpu.rst
index e861c567..7f74c373 100755
--- a/docs/source/gpu.rst
+++ b/docs/source/gpu.rst
@@ -1,55 +1,404 @@
.. _gpu:
-GPU Support
-===========
+GPU / TPU Support
+=================
Overview
--------
-PyLops supports computations on GPUs powered by `CuPy `_ (``cupy-cudaXX>=v13.0.0``).
+From ``v1.12.0``, PyLops supports computations on GPUs powered by
+`CuPy `_ (``cupy-cudaXX>=8.1.0``).
This library must be installed *before* PyLops is installed.
-.. note::
-
- Set environment variable ``CUPY_PYLOPS=0`` to force PyLops to ignore the ``cupy`` backend.
- This can be also used if a previous (or faulty) version of ``cupy`` is installed in your system,
- otherwise you will get an error when importing PyLops.
+From ``v2.3.0``, PyLops supports also computations on GPUs/TPUs powered by
+`JAX `_.
+This library must be installed *before* PyLops is installed.
+.. note::
+ Set environment variables ``CUPY_PYLOPS=0`` and/or ``JAX_PYLOPS=0`` to force PyLops to ignore
+ ``cupy`` and ``jax`` backends. This can be also used if a previous version of ``cupy``
+ or ``jax`` is installed in your system, otherwise you will get an error when importing PyLops.
Apart from a few exceptions, all operators and solvers in PyLops can
-seamlessly work with ``numpy`` arrays on CPU as well as with ``cupy`` arrays
-on GPU. Users do simply need to consistently create operators and
+seamlessly work with ``numpy`` arrays on CPU as well as with ``cupy/jax`` arrays
+on GPU. For CuPy, users simply need to consistently create operators and
provide data vectors to the solvers, e.g., when using
:class:`pylops.MatrixMult` the input matrix must be a
``cupy`` array if the data provided to a solver is also ``cupy`` array.
+For JAX, apart from following the same procedure described for CuPy, the PyLops operator must
+be also wrapped into a :class:`pylops.JaxOperator`.
-.. warning::
- Some :class:`pylops.LinearOperator` methods are currently on GPU:
+In the following, we provide a list of methods in :class:`pylops.LinearOperator` with their current status (available on CPU,
+GPU with CuPy, and GPU with JAX):
- - :meth:`pylops.LinearOperator.eigs`
- - :meth:`pylops.LinearOperator.cond`
- - :meth:`pylops.LinearOperator.tosparse`
- - :meth:`pylops.LinearOperator.estimate_spectral_norm`
+.. list-table::
+ :widths: 50 25 25 25
+ :header-rows: 1
-.. warning::
+ * - Operator/method
+ - CPU
+ - GPU with CuPy
+ - GPU/TPU with JAX
+ * - :meth:`pylops.LinearOperator.cond`
+ - |:white_check_mark:|
+ - |:red_circle:|
+ - |:red_circle:|
+ * - :meth:`pylops.LinearOperator.conj`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :meth:`pylops.LinearOperator.div`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :meth:`pylops.LinearOperator.eigs`
+ - |:white_check_mark:|
+ - |:red_circle:|
+ - |:red_circle:|
+ * - :meth:`pylops.LinearOperator.todense`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :meth:`pylops.LinearOperator.tosparse`
+ - |:white_check_mark:|
+ - |:red_circle:|
+ - |:red_circle:|
+ * - :meth:`pylops.LinearOperator.trace`
+ - |:white_check_mark:|
+ - |:red_circle:|
+ - |:red_circle:|
+
+Similarly, we provide a list of operators with their current status.
+
+Basic operators:
+
+.. list-table::
+ :widths: 50 25 25 25
+ :header-rows: 1
+
+ * - Operator/method
+ - CPU
+ - GPU with CuPy
+ - GPU/TPU with JAX
+ * - :class:`pylops.basicoperators.MatrixMult`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.Identity`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.Zero`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.Diagonal`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :meth:`pylops.basicoperators.Transpose`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.Flip`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.Roll`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.Pad`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.Sum`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.Symmetrize`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.Restriction`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.Regression`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.LinearRegression`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.CausalIntegration`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.Spread`
+ - |:white_check_mark:|
+ - |:red_circle:|
+ - |:red_circle:|
+ * - :class:`pylops.basicoperators.VStack`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.HStack`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.Block`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.BlockDiag`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+
+
+Smoothing and derivatives:
+
+.. list-table::
+ :widths: 50 25 25 25
+ :header-rows: 1
+
+ * - Operator/method
+ - CPU
+ - GPU with CuPy
+ - GPU/TPU with JAX
+ * - :class:`pylops.basicoperators.FirstDerivative`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.SecondDerivative`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.Laplacian`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.Gradient`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.FirstDirectionalDerivative`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.SecondDirectionalDerivative`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+
+Signal processing:
+
+.. list-table::
+ :widths: 50 25 25 25
+ :header-rows: 1
+
+ * - Operator/method
+ - CPU
+ - GPU with CuPy
+ - GPU/TPU with JAX
+ * - :class:`pylops.signalprocessing.Convolve1D`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:warning:|
+ * - :class:`pylops.signalprocessing.Convolve2D`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.signalprocessing.ConvolveND`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.signalprocessing.NonStationaryConvolve1D`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.signalprocessing.NonStationaryFilters1D`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.signalprocessing.NonStationaryConvolve2D`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:red_circle:|
+ * - :class:`pylops.signalprocessing.NonStationaryFilters2D`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:red_circle:|
+ * - :class:`pylops.signalprocessing.Interp`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.signalprocessing.Bilinear`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:red_circle:|
+ * - :class:`pylops.signalprocessing.FFT`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.signalprocessing.FFT2D`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.signalprocessing.FFTND`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.signalprocessing.Shift`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.signalprocessing.DWT`
+ - |:white_check_mark:|
+ - |:red_circle:|
+ - |:red_circle:|
+ * - :class:`pylops.signalprocessing.DWT2D`
+ - |:white_check_mark:|
+ - |:red_circle:|
+ - |:red_circle:|
+ * - :class:`pylops.signalprocessing.DCT`
+ - |:white_check_mark:|
+ - |:red_circle:|
+ - |:red_circle:|
+ * - :class:`pylops.signalprocessing.Seislet`
+ - |:white_check_mark:|
+ - |:red_circle:|
+ - |:red_circle:|
+ * - :class:`pylops.signalprocessing.Radon2D`
+ - |:white_check_mark:|
+ - |:red_circle:|
+ - |:red_circle:|
+ * - :class:`pylops.signalprocessing.Radon3D`
+ - |:white_check_mark:|
+ - |:red_circle:|
+ - |:red_circle:|
+ * - :class:`pylops.signalprocessing.ChirpRadon2D`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:red_circle:|
+ * - :class:`pylops.signalprocessing.ChirpRadon3D`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:red_circle:|
+ * - :class:`pylops.signalprocessing.Sliding1D`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:red_circle:|
+ * - :class:`pylops.signalprocessing.Sliding2D`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:red_circle:|
+ * - :class:`pylops.signalprocessing.Sliding3D`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:red_circle:|
+ * - :class:`pylops.signalprocessing.Patch2D`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:red_circle:|
+ * - :class:`pylops.signalprocessing.Patch3D`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:red_circle:|
+ * - :class:`pylops.signalprocessing.Fredholm1`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
- Some operators are currently not available on GPU:
+Wave-Equation processing
- - :class:`pylops.Spread`
- - :class:`pylops.signalprocessing.Radon2D`
- - :class:`pylops.signalprocessing.Radon3D`
- - :class:`pylops.signalprocessing.DWT`
- - :class:`pylops.signalprocessing.DWT2D`
- - :class:`pylops.signalprocessing.Seislet`
- - :class:`pylops.waveeqprocessing.Demigration`
- - :class:`pylops.waveeqprocessing.LSM`
+.. list-table::
+ :widths: 50 25 25 25
+ :header-rows: 1
+
+ * - Operator/method
+ - CPU
+ - GPU with CuPy
+ - GPU/TPU with JAX
+ * - :class:`pylops.avo.avo.PressureToVelocity`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.avo.avo.UpDownComposition2D`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.avo.avo.UpDownComposition3D`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.avo.avo.BlendingContinuous`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.avo.avo.BlendingGroup`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.avo.avo.BlendingHalf`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.avo.avo.MDC`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.avo.avo.Kirchhoff`
+ - |:white_check_mark:|
+ - |:red_circle:|
+ - |:red_circle:|
+ * - :class:`pylops.avo.avo.AcousticWave2D`
+ - |:white_check_mark:|
+ - |:red_circle:|
+ - |:red_circle:|
+
+Geophysical subsurface characterization:
+
+.. list-table::
+ :widths: 50 25 25 25
+ :header-rows: 1
+
+ * - Operator/method
+ - CPU
+ - GPU with CuPy
+ - GPU/TPU with JAX
+ * - :class:`pylops.avo.avo.AVOLinearModelling`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.avo.poststack.PoststackLinearModelling`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.avo.prestack.PrestackLinearModelling`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:warning:|
+ * - :class:`pylops.avo.prestack.PrestackWaveletModelling`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:warning:|
.. warning::
- Some solvers are currently not available on GPU:
- - :class:`pylops.optimization.sparsity.SPGL1`
+ 1. The JAX backend of the :class:`pylops.signalprocessing.Convolve1D` operator
+ currently works only with 1d-arrays due to a different behaviour of
+ :meth:`scipy.signal.convolve` and :meth:`jax.scipy.signal.convolve` with
+ nd-arrays.
+
+ 2. The JAX backend of the :class:`pylops.avo.prestack.PrestackLinearModelling`
+ operator currently works only with ``explicit=True`` due to the same issue as
+ in point 1 for the :class:`pylops.signalprocessing.Convolve1D` operator employed
+ when ``explicit=False``.
Example
@@ -68,8 +417,7 @@ Finally, let's briefly look at an example. First we write a code snippet using
y = Gop * x
xest = Gop / y
-
-Now we write a code snippet using ``cupy`` arrays which PyLops will run on
+Now we write a code snippet using ``cupy`` arrays which PyLops will run on
your GPU:
.. code-block:: python
@@ -83,9 +431,28 @@ your GPU:
xest = Gop / y
The code is almost unchanged apart from the fact that we now use ``cupy`` arrays,
-PyLops will figure this out!
+PyLops will figure this out.
+
+Similarly, we write a code snippet using ``jax`` arrays which PyLops will run on
+your GPU/TPU:
+
+.. code-block:: python
+
+ ny, nx = 400, 400
+ G = jnp.array(np.random.normal(0, 1, (ny, nx)).astype(np.float32))
+ x = jnp.ones(nx, dtype=np.float32)
+
+ Gop = JaxOperator(MatrixMult(G, dtype='float32'))
+ y = Gop * x
+ xest = Gop / y
+
+ # Adjoint via AD
+ xadj = Gop.rmatvecad(x, y)
+
+
+Again, the code is almost unchanged apart from the fact that we now use ``jax`` arrays,
.. note::
- The CuPy backend is in active development, with many examples not yet in the docs.
- You can find many `other examples `_ from the `PyLops Notebooks repository `_.
+ More examples for the CuPy and JAX backends be found `here `_
+ and `here `_.
\ No newline at end of file
diff --git a/environment-dev-arm.yml b/environment-dev-arm.yml
index 413a9759..0081a0ad 100755
--- a/environment-dev-arm.yml
+++ b/environment-dev-arm.yml
@@ -11,6 +11,7 @@ dependencies:
- scipy>=1.11.0
- pytorch>=1.2.0
- cpuonly
+ - jax
- pyfftw
- pywavelets
- sympy
@@ -34,6 +35,7 @@ dependencies:
- pydata-sphinx-theme
- sphinx-gallery
- nbsphinx
+ - sphinxemoji
- image
- flake8
- mypy
diff --git a/environment-dev.yml b/environment-dev.yml
index ef51f696..eb51c4dc 100755
--- a/environment-dev.yml
+++ b/environment-dev.yml
@@ -11,6 +11,7 @@ dependencies:
- scipy>=1.11.0
- pytorch>=1.2.0
- cpuonly
+ - jax
- pyfftw
- pywavelets
- sympy
@@ -35,6 +36,7 @@ dependencies:
- pydata-sphinx-theme
- sphinx-gallery
- nbsphinx
+ - sphinxemoji
- image
- flake8
- mypy
diff --git a/pylops/__init__.py b/pylops/__init__.py
index 55d4ce3d..7672fda4 100755
--- a/pylops/__init__.py
+++ b/pylops/__init__.py
@@ -48,6 +48,7 @@
from .config import *
from .linearoperator import *
from .torchoperator import *
+from .jaxoperator import *
from .basicoperators import *
from . import (
avo,
diff --git a/pylops/avo/poststack.py b/pylops/avo/poststack.py
index 7e514707..8a9001ac 100644
--- a/pylops/avo/poststack.py
+++ b/pylops/avo/poststack.py
@@ -27,6 +27,7 @@
get_csc_matrix,
get_lstsq,
get_module_name,
+ inplace_set,
)
from pylops.utils.signalprocessing import convmtx, nonstationary_convmtx
from pylops.utils.typing import NDArray, ShapeLike
@@ -93,12 +94,13 @@ def _PoststackLinearModelling(
D = ncp.diag(0.5 * ncp.ones(nt0 - 1, dtype=dtype), k=1) - ncp.diag(
0.5 * ncp.ones(nt0 - 1, dtype=dtype), -1
)
- D[0] = D[-1] = 0
+ D = inplace_set(ncp.array(0.0), D, 0)
+ D = inplace_set(ncp.array(0.0), D, -1)
else:
D = ncp.diag(ncp.ones(nt0 - 1, dtype=dtype), k=1) - ncp.diag(
ncp.ones(nt0, dtype=dtype), k=0
)
- D[-1] = 0
+ D = inplace_set(ncp.array(0.0), D, -1)
# Create wavelet operator
if len(wav.shape) == 1:
diff --git a/pylops/avo/prestack.py b/pylops/avo/prestack.py
index 8630bc9a..4cf6c4eb 100644
--- a/pylops/avo/prestack.py
+++ b/pylops/avo/prestack.py
@@ -31,6 +31,7 @@
get_block_diag,
get_lstsq,
get_module_name,
+ inplace_set,
)
from pylops.utils.signalprocessing import convmtx
from pylops.utils.typing import NDArray, ShapeLike
@@ -182,12 +183,13 @@ def PrestackLinearModelling(
D = ncp.diag(0.5 * ncp.ones(nt0 - 1, dtype=dtype), k=1) - ncp.diag(
0.5 * ncp.ones(nt0 - 1, dtype=dtype), k=-1
)
- D[0] = D[-1] = 0
+ D = inplace_set(ncp.array(0.0), D, 0)
+ D = inplace_set(ncp.array(0.0), D, -1)
else:
D = ncp.diag(ncp.ones(nt0 - 1, dtype=dtype), k=1) - ncp.diag(
ncp.ones(nt0, dtype=dtype), k=0
)
- D[-1] = 0
+ D = inplace_set(ncp.array(0.0), D, -1)
D = get_block_diag(theta)(*([D] * nG))
# Create wavelet operator
@@ -339,7 +341,8 @@ def PrestackWaveletModelling(
D = ncp.diag(0.5 * np.ones(nt0 - 1, dtype=dtype), k=1) - ncp.diag(
0.5 * np.ones(nt0 - 1, dtype=dtype), k=-1
)
- D[0] = D[-1] = 0
+ D = inplace_set(ncp.array(0.0), D, 0)
+ D = inplace_set(ncp.array(0.0), D, -1)
D = get_block_diag(theta)(*([D] * nG))
# Create infinite-reflectivity data
diff --git a/pylops/basicoperators/blockdiag.py b/pylops/basicoperators/blockdiag.py
index e13ed026..166ae137 100644
--- a/pylops/basicoperators/blockdiag.py
+++ b/pylops/basicoperators/blockdiag.py
@@ -21,7 +21,7 @@
from pylops import LinearOperator
from pylops.basicoperators import MatrixMult
-from pylops.utils.backend import get_array_module
+from pylops.utils.backend import get_array_module, inplace_set
from pylops.utils.typing import DTypeLike, NDArray
@@ -175,18 +175,22 @@ def _matvec_serial(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(self.nops, dtype=self.dtype)
for iop, oper in enumerate(self.ops):
- y[self.nnops[iop] : self.nnops[iop + 1]] = oper.matvec(
- x[self.mmops[iop] : self.mmops[iop + 1]]
- ).squeeze()
+ y = inplace_set(
+ oper.matvec(x[self.mmops[iop] : self.mmops[iop + 1]]).squeeze(),
+ y,
+ slice(self.nnops[iop], self.nnops[iop + 1]),
+ )
return y
def _rmatvec_serial(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(self.mops, dtype=self.dtype)
for iop, oper in enumerate(self.ops):
- y[self.mmops[iop] : self.mmops[iop + 1]] = oper.rmatvec(
- x[self.nnops[iop] : self.nnops[iop + 1]]
- ).squeeze()
+ y = inplace_set(
+ oper.rmatvec(x[self.nnops[iop] : self.nnops[iop + 1]]).squeeze(),
+ y,
+ slice(self.mmops[iop], self.mmops[iop + 1]),
+ )
return y
def _matvec_multiproc(self, x: NDArray) -> NDArray:
diff --git a/pylops/basicoperators/firstderivative.py b/pylops/basicoperators/firstderivative.py
index 58edf17f..f8bd208e 100644
--- a/pylops/basicoperators/firstderivative.py
+++ b/pylops/basicoperators/firstderivative.py
@@ -7,7 +7,7 @@
from pylops import LinearOperator
from pylops.utils._internal import _value_or_sized_to_tuple
-from pylops.utils.backend import get_array_module
+from pylops.utils.backend import get_array_module, inplace_add, inplace_set
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray
@@ -100,6 +100,16 @@ def __init__(
self.kind = kind
self.edge = edge
self.order = order
+ self.slice = {
+ i: {
+ j: tuple([slice(None, None)] * (len(dims) - 1) + [slice(i, j)])
+ for j in (None, -1, -2, -3, -4)
+ }
+ for i in (None, 1, 2, 3, 4)
+ }
+ self.sample = {
+ i: tuple([slice(None, None)] * (len(dims) - 1) + [i]) for i in range(-3, 4)
+ }
self._register_multiplications(self.kind, self.order)
def _register_multiplications(
@@ -140,15 +150,20 @@ def _rmatvec(self, x: NDArray) -> NDArray:
def _matvec_forward(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(x.shape, self.dtype)
- y[..., :-1] = (x[..., 1:] - x[..., :-1]) / self.sampling
+ # y[..., :-1] = (x[..., 1:] - x[..., :-1]) / self.sampling
+ y = inplace_set(
+ (x[..., 1:] - x[..., :-1]) / self.sampling, y, self.slice[None][-1]
+ )
return y
@reshaped(swapaxis=True)
def _rmatvec_forward(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(x.shape, self.dtype)
- y[..., :-1] -= x[..., :-1]
- y[..., 1:] += x[..., :-1]
+ # y[..., :-1] -= x[..., :-1]
+ y = inplace_add(-x[..., :-1], y, self.slice[None][-1])
+ # y[..., 1:] += x[..., :-1]
+ y = inplace_add(x[..., :-1], y, self.slice[1][None])
y /= self.sampling
return y
@@ -156,10 +171,13 @@ def _rmatvec_forward(self, x: NDArray) -> NDArray:
def _matvec_centered3(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(x.shape, self.dtype)
- y[..., 1:-1] = 0.5 * (x[..., 2:] - x[..., :-2])
+ # y[..., 1:-1] = 0.5 * (x[..., 2:] - x[..., :-2])
+ y = inplace_set(0.5 * (x[..., 2:] - x[..., :-2]), y, self.slice[1][-1])
if self.edge:
- y[..., 0] = x[..., 1] - x[..., 0]
- y[..., -1] = x[..., -1] - x[..., -2]
+ # y[..., 0] = x[..., 1] - x[..., 0]
+ y = inplace_set(x[..., 1] - x[..., 0], y, self.sample[0])
+ # y[..., -1] = x[..., -1] - x[..., -2]
+ y = inplace_set(x[..., -1] - x[..., -2], y, self.sample[-1])
y /= self.sampling
return y
@@ -167,13 +185,19 @@ def _matvec_centered3(self, x: NDArray) -> NDArray:
def _rmatvec_centered3(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(x.shape, self.dtype)
- y[..., :-2] -= 0.5 * x[..., 1:-1]
- y[..., 2:] += 0.5 * x[..., 1:-1]
+ # y[..., :-2] -= 0.5 * x[..., 1:-1]
+ y = inplace_add(-0.5 * x[..., 1:-1], y, self.slice[None][-2])
+ # y[..., 2:] += 0.5 * x[..., 1:-1]
+ y = inplace_add(0.5 * x[..., 1:-1], y, self.slice[2][None])
if self.edge:
- y[..., 0] -= x[..., 0]
- y[..., 1] += x[..., 0]
- y[..., -2] -= x[..., -1]
- y[..., -1] += x[..., -1]
+ # y[..., 0] -= x[..., 0]
+ y = inplace_add(-x[..., 0], y, self.sample[0])
+ # y[..., 1] += x[..., 0]
+ y = inplace_add(x[..., 0], y, self.sample[1])
+ # y[..., -2] -= x[..., -1]
+ y = inplace_add(-x[..., -1], y, self.sample[-2])
+ # y[..., -1] += x[..., -1]
+ y = inplace_add(x[..., -1], y, self.sample[-1])
y /= self.sampling
return y
@@ -181,17 +205,31 @@ def _rmatvec_centered3(self, x: NDArray) -> NDArray:
def _matvec_centered5(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(x.shape, self.dtype)
- y[..., 2:-2] = (
- x[..., :-4] / 12.0
- - 2 * x[..., 1:-3] / 3.0
- + 2 * x[..., 3:-1] / 3.0
- - x[..., 4:] / 12.0
+ # y[..., 2:-2] = (
+ # x[..., :-4] / 12.0
+ # - 2 * x[..., 1:-3] / 3.0
+ # + 2 * x[..., 3:-1] / 3.0
+ # - x[..., 4:] / 12.0
+ # )
+ y = inplace_set(
+ (
+ x[..., :-4] / 12.0
+ - 2 * x[..., 1:-3] / 3.0
+ + 2 * x[..., 3:-1] / 3.0
+ - x[..., 4:] / 12.0
+ ),
+ y,
+ self.slice[2][-2],
)
if self.edge:
- y[..., 0] = x[..., 1] - x[..., 0]
- y[..., 1] = 0.5 * (x[..., 2] - x[..., 0])
- y[..., -2] = 0.5 * (x[..., -1] - x[..., -3])
- y[..., -1] = x[..., -1] - x[..., -2]
+ # y[..., 0] = x[..., 1] - x[..., 0]
+ y = inplace_set(x[..., 1] - x[..., 0], y, self.sample[0])
+ # y[..., 1] = 0.5 * (x[..., 2] - x[..., 0])
+ y = inplace_set(0.5 * (x[..., 2] - x[..., 0]), y, self.sample[1])
+ # y[..., -2] = 0.5 * (x[..., -1] - x[..., -3])
+ y = inplace_set(0.5 * (x[..., -1] - x[..., -3]), y, self.sample[-2])
+ # y[..., -1] = x[..., -1] - x[..., -2]
+ y = inplace_set(x[..., -1] - x[..., -2], y, self.sample[-1])
y /= self.sampling
return y
@@ -199,17 +237,27 @@ def _matvec_centered5(self, x: NDArray) -> NDArray:
def _rmatvec_centered5(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(x.shape, self.dtype)
- y[..., :-4] += x[..., 2:-2] / 12.0
- y[..., 1:-3] -= 2.0 * x[..., 2:-2] / 3.0
- y[..., 3:-1] += 2.0 * x[..., 2:-2] / 3.0
- y[..., 4:] -= x[..., 2:-2] / 12.0
+ # y[..., :-4] += x[..., 2:-2] / 12.0
+ y = inplace_add(x[..., 2:-2] / 12.0, y, self.slice[None][-4])
+ # y[..., 1:-3] -= 2.0 * x[..., 2:-2] / 3.0
+ y = inplace_add(-2.0 * x[..., 2:-2] / 3.0, y, self.slice[1][-3])
+ # y[..., 3:-1] += 2.0 * x[..., 2:-2] / 3.0
+ y = inplace_add(2.0 * x[..., 2:-2] / 3.0, y, self.slice[3][-1])
+ # y[..., 4:] -= x[..., 2:-2] / 12.0
+ y = inplace_add(-x[..., 2:-2] / 12.0, y, self.slice[4][None])
if self.edge:
- y[..., 0] -= x[..., 0] + 0.5 * x[..., 1]
- y[..., 1] += x[..., 0]
- y[..., 2] += 0.5 * x[..., 1]
- y[..., -3] -= 0.5 * x[..., -2]
- y[..., -2] -= x[..., -1]
- y[..., -1] += 0.5 * x[..., -2] + x[..., -1]
+ # y[..., 0] -= x[..., 0] + 0.5 * x[..., 1]
+ y = inplace_add(-(x[..., 0] + 0.5 * x[..., 1]), y, self.sample[0])
+ # y[..., 1] += x[..., 0]
+ y = inplace_add(x[..., 0], y, self.sample[1])
+ # y[..., 2] += 0.5 * x[..., 1]
+ y = inplace_add(0.5 * x[..., 1], y, self.sample[2])
+ # y[..., -3] -= 0.5 * x[..., -2]
+ y = inplace_add(-0.5 * x[..., -2], y, self.sample[-3])
+ # y[..., -2] -= x[..., -1]
+ y = inplace_add(-x[..., -1], y, self.sample[-2])
+ # y[..., -1] += 0.5 * x[..., -2] + x[..., -1]
+ y = inplace_add(0.5 * x[..., -2] + x[..., -1], y, self.sample[-1])
y /= self.sampling
return y
@@ -217,14 +265,19 @@ def _rmatvec_centered5(self, x: NDArray) -> NDArray:
def _matvec_backward(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(x.shape, self.dtype)
- y[..., 1:] = (x[..., 1:] - x[..., :-1]) / self.sampling
+ # y[..., 1:] = (x[..., 1:] - x[..., :-1]) / self.sampling
+ y = inplace_set(
+ (x[..., 1:] - x[..., :-1]) / self.sampling, y, self.slice[1][None]
+ )
return y
@reshaped(swapaxis=True)
def _rmatvec_backward(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(x.shape, self.dtype)
- y[..., :-1] -= x[..., 1:]
- y[..., 1:] += x[..., 1:]
+ # y[..., :-1] -= x[..., 1:]
+ y = inplace_add(-x[..., 1:], y, self.slice[None][-1])
+ # y[..., 1:] += x[..., 1:]
+ y = inplace_add(x[..., 1:], y, self.slice[1][None])
y /= self.sampling
return y
diff --git a/pylops/basicoperators/hstack.py b/pylops/basicoperators/hstack.py
index 5cfbbec0..b71e8723 100644
--- a/pylops/basicoperators/hstack.py
+++ b/pylops/basicoperators/hstack.py
@@ -21,7 +21,7 @@
from pylops import LinearOperator
from pylops.basicoperators import MatrixMult
-from pylops.utils.backend import get_array_module
+from pylops.utils.backend import get_array_module, inplace_add, inplace_set
from pylops.utils.typing import NDArray
@@ -165,14 +165,22 @@ def _matvec_serial(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(self.nops, dtype=self.dtype)
for iop, oper in enumerate(self.ops):
- y += oper.matvec(x[self.mmops[iop] : self.mmops[iop + 1]]).squeeze()
+ y = inplace_add(
+ oper.matvec(x[self.mmops[iop] : self.mmops[iop + 1]]).squeeze(),
+ y,
+ slice(None, None),
+ )
return y
def _rmatvec_serial(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(self.mops, dtype=self.dtype)
for iop, oper in enumerate(self.ops):
- y[self.mmops[iop] : self.mmops[iop + 1]] = oper.rmatvec(x).squeeze()
+ y = inplace_set(
+ oper.rmatvec(x).squeeze(),
+ y,
+ slice(self.mmops[iop], self.mmops[iop + 1]),
+ )
return y
def _matvec_multiproc(self, x: NDArray) -> NDArray:
diff --git a/pylops/basicoperators/identity.py b/pylops/basicoperators/identity.py
index c2d05a30..50b76831 100644
--- a/pylops/basicoperators/identity.py
+++ b/pylops/basicoperators/identity.py
@@ -6,7 +6,7 @@
import numpy as np
from pylops import LinearOperator
-from pylops.utils.backend import get_array_module
+from pylops.utils.backend import get_array_module, inplace_set
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray
@@ -181,7 +181,7 @@ def _matvec(self, x: NDArray) -> NDArray:
y = x[self.sliceN]
else:
y = ncp.zeros(self.dimsd, dtype=self.dtype)
- y[self.sliceM] = x
+ y = inplace_set(x, y, self.sliceM)
return y
@reshaped
@@ -193,7 +193,7 @@ def _rmatvec(self, x: NDArray) -> NDArray:
y = x
elif self.mode == "model":
y = ncp.zeros(self.dims, dtype=self.dtype)
- y[self.sliceN] = x
+ y = inplace_set(x, y, self.sliceN)
else:
y = x[self.sliceM]
return y
diff --git a/pylops/basicoperators/pad.py b/pylops/basicoperators/pad.py
index 45b63af8..d98a894c 100644
--- a/pylops/basicoperators/pad.py
+++ b/pylops/basicoperators/pad.py
@@ -6,6 +6,7 @@
from pylops import LinearOperator
from pylops.utils._internal import _value_or_sized_to_tuple
+from pylops.utils.backend import get_array_module
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray
@@ -85,10 +86,12 @@ def __init__(
@reshaped
def _matvec(self, x: NDArray) -> NDArray:
- return np.pad(x, self.pad, mode="constant")
+ ncp = get_array_module(x)
+ return ncp.pad(x, self.pad, mode="constant")
@reshaped
def _rmatvec(self, x: NDArray) -> NDArray:
+ ncp = get_array_module(x)
for ax, (before, _) in enumerate(self.pad):
- x = np.take(x, np.arange(before, before + self.dims[ax]), axis=ax)
+ x = ncp.take(x, ncp.arange(before, before + self.dims[ax]), axis=ax)
return x
diff --git a/pylops/basicoperators/restriction.py b/pylops/basicoperators/restriction.py
index fc81a252..1a745b30 100644
--- a/pylops/basicoperators/restriction.py
+++ b/pylops/basicoperators/restriction.py
@@ -16,7 +16,7 @@
from pylops import LinearOperator
from pylops.utils._internal import _value_or_sized_to_tuple
-from pylops.utils.backend import get_array_module, to_cupy_conditional
+from pylops.utils.backend import get_array_module, inplace_set, to_cupy_conditional
from pylops.utils.typing import DTypeLike, InputDimsLike, IntNDArray, NDArray
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.WARNING)
@@ -26,13 +26,13 @@ def _compute_iavamask(dims, axis, iava, ncp):
"""Compute restriction mask when using cupy arrays"""
otherdims = np.array(dims)
otherdims = np.delete(otherdims, axis)
- iavamask = ncp.zeros(int(dims[axis]), dtype=int)
+ iavamask = np.zeros(int(dims[axis]), dtype=int)
iavamask[iava] = 1
- iavamask = ncp.moveaxis(
- ncp.broadcast_to(iavamask, list(otherdims) + [dims[axis]]), -1, axis
+ iavamask = np.moveaxis(
+ np.broadcast_to(iavamask, list(otherdims) + [dims[axis]]), -1, axis
)
- iavamask = ncp.where(iavamask.ravel() == 1)[0]
- return iavamask
+ iavamask = np.where(iavamask.ravel() == 1)[0]
+ return ncp.asarray(iavamask)
class Restriction(LinearOperator):
@@ -179,7 +179,7 @@ def _rmatvec(self, x: NDArray) -> NDArray:
self.iava = to_cupy_conditional(x, self.iava)
self.iavamask = _compute_iavamask(self.dims, self.axis, self.iava, ncp)
y = ncp.zeros(int(self.shape[-1]), dtype=self.dtype)
- y[self.iavamask] = x.ravel()
+ y = inplace_set(x.ravel(), y, self.iavamask)
y = y.ravel()
return y
diff --git a/pylops/basicoperators/roll.py b/pylops/basicoperators/roll.py
index 8fc27e4d..29e6f613 100644
--- a/pylops/basicoperators/roll.py
+++ b/pylops/basicoperators/roll.py
@@ -6,6 +6,7 @@
from pylops import LinearOperator
from pylops.utils._internal import _value_or_sized_to_tuple
+from pylops.utils.backend import get_array_module
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray
@@ -64,8 +65,10 @@ def __init__(
@reshaped(swapaxis=True)
def _matvec(self, x: NDArray) -> NDArray:
- return np.roll(x, shift=self.shift, axis=-1)
+ ncp = get_array_module(x)
+ return ncp.roll(x, shift=self.shift, axis=-1)
@reshaped(swapaxis=True)
def _rmatvec(self, x: NDArray) -> NDArray:
- return np.roll(x, shift=-self.shift, axis=-1)
+ ncp = get_array_module(x)
+ return ncp.roll(x, shift=-self.shift, axis=-1)
diff --git a/pylops/basicoperators/secondderivative.py b/pylops/basicoperators/secondderivative.py
index 744d067a..8433987d 100644
--- a/pylops/basicoperators/secondderivative.py
+++ b/pylops/basicoperators/secondderivative.py
@@ -7,7 +7,7 @@
from pylops import LinearOperator
from pylops.utils._internal import _value_or_sized_to_tuple
-from pylops.utils.backend import get_array_module
+from pylops.utils.backend import get_array_module, inplace_add, inplace_set
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray
@@ -90,6 +90,16 @@ def __init__(
self.sampling = sampling
self.kind = kind
self.edge = edge
+ self.slice = {
+ i: {
+ j: tuple([slice(None, None)] * (len(dims) - 1) + [slice(i, j)])
+ for j in (None, -1, -2, -3, -4)
+ }
+ for i in (None, 1, 2, 3, 4)
+ }
+ self.sample = {
+ i: tuple([slice(None, None)] * (len(dims) - 1) + [i]) for i in range(-3, 4)
+ }
self._register_multiplications(self.kind)
def _register_multiplications(
@@ -123,7 +133,10 @@ def _rmatvec(self, x: NDArray) -> NDArray:
def _matvec_forward(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(x.shape, self.dtype)
- y[..., :-2] = x[..., 2:] - 2 * x[..., 1:-1] + x[..., :-2]
+ # y[..., :-2] = x[..., 2:] - 2 * x[..., 1:-1] + x[..., :-2]
+ y = inplace_set(
+ x[..., 2:] - 2 * x[..., 1:-1] + x[..., :-2], y, self.slice[None][-2]
+ )
y /= self.sampling**2
return y
@@ -131,9 +144,12 @@ def _matvec_forward(self, x: NDArray) -> NDArray:
def _rmatvec_forward(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(x.shape, self.dtype)
- y[..., :-2] += x[..., :-2]
- y[..., 1:-1] -= 2 * x[..., :-2]
- y[..., 2:] += x[..., :-2]
+ # y[..., :-2] += x[..., :-2]
+ y = inplace_add(x[..., :-2], y, self.slice[None][-2])
+ # y[..., 1:-1] -= 2 * x[..., :-2]
+ y = inplace_add(-2 * x[..., :-2], y, self.slice[1][-1])
+ # y[..., 2:] += x[..., :-2]
+ y = inplace_add(x[..., :-2], y, self.slice[2][None])
y /= self.sampling**2
return y
@@ -141,10 +157,17 @@ def _rmatvec_forward(self, x: NDArray) -> NDArray:
def _matvec_centered(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(x.shape, self.dtype)
- y[..., 1:-1] = x[..., 2:] - 2 * x[..., 1:-1] + x[..., :-2]
+ # y[..., 1:-1] = x[..., 2:] - 2 * x[..., 1:-1] + x[..., :-2]
+ y = inplace_set(
+ x[..., 2:] - 2 * x[..., 1:-1] + x[..., :-2], y, self.slice[1][-1]
+ )
if self.edge:
- y[..., 0] = x[..., 0] - 2 * x[..., 1] + x[..., 2]
- y[..., -1] = x[..., -3] - 2 * x[..., -2] + x[..., -1]
+ # y[..., 0] = x[..., 0] - 2 * x[..., 1] + x[..., 2]
+ y = inplace_set(x[..., 0] - 2 * x[..., 1] + x[..., 2], y, self.sample[0])
+ # y[..., -1] = x[..., -3] - 2 * x[..., -2] + x[..., -1]
+ y = inplace_set(
+ x[..., -3] - 2 * x[..., -2] + x[..., -1], y, self.sample[-1]
+ )
y /= self.sampling**2
return y
@@ -152,16 +175,25 @@ def _matvec_centered(self, x: NDArray) -> NDArray:
def _rmatvec_centered(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(x.shape, self.dtype)
- y[..., :-2] += x[..., 1:-1]
- y[..., 1:-1] -= 2 * x[..., 1:-1]
- y[..., 2:] += x[..., 1:-1]
+ # y[..., :-2] += x[..., 1:-1]
+ y = inplace_add(x[..., 1:-1], y, self.slice[None][-2])
+ # y[..., 1:-1] -= 2 * x[..., 1:-1]
+ y = inplace_add(-2 * x[..., 1:-1], y, self.slice[1][-1])
+ # y[..., 2:] += x[..., 1:-1]
+ y = inplace_add(x[..., 1:-1], y, self.slice[2][None])
if self.edge:
- y[..., 0] += x[..., 0]
- y[..., 1] -= 2 * x[..., 0]
- y[..., 2] += x[..., 0]
- y[..., -3] += x[..., -1]
- y[..., -2] -= 2 * x[..., -1]
- y[..., -1] += x[..., -1]
+ # y[..., 0] += x[..., 0]
+ y = inplace_add(x[..., 0], y, self.sample[0])
+ # y[..., 1] -= 2 * x[..., 0]
+ y = inplace_add(-2 * x[..., 0], y, self.sample[1])
+ # y[..., 2] += x[..., 0]
+ y = inplace_add(x[..., 0], y, self.sample[2])
+ # y[..., -3] += x[..., -1]
+ y = inplace_add(x[..., -1], y, self.sample[-3])
+ # y[..., -2] -= 2 * x[..., -1]
+ y = inplace_add(-2 * x[..., -1], y, self.sample[-2])
+ # y[..., -1] += x[..., -1]
+ y = inplace_add(x[..., -1], y, self.sample[-1])
y /= self.sampling**2
return y
@@ -169,7 +201,10 @@ def _rmatvec_centered(self, x: NDArray) -> NDArray:
def _matvec_backward(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(x.shape, self.dtype)
- y[..., 2:] = x[..., 2:] - 2 * x[..., 1:-1] + x[..., :-2]
+ # y[..., 2:] = x[..., 2:] - 2 * x[..., 1:-1] + x[..., :-2]
+ y = inplace_set(
+ x[..., 2:] - 2 * x[..., 1:-1] + x[..., :-2], y, self.slice[2][None]
+ )
y /= self.sampling**2
return y
@@ -177,8 +212,11 @@ def _matvec_backward(self, x: NDArray) -> NDArray:
def _rmatvec_backward(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(x.shape, self.dtype)
- y[..., :-2] += x[..., 2:]
- y[..., 1:-1] -= 2 * x[..., 2:]
- y[..., 2:] += x[..., 2:]
+ # y[..., :-2] += x[..., 2:]
+ y = inplace_add(x[..., 2:], y, self.slice[None][-2])
+ # y[..., 1:-1] -= 2 * x[..., 2:]
+ y = inplace_add(-2 * x[..., 2:], y, self.slice[1][-1])
+ # y[..., 2:] += x[..., 2:]
+ y = inplace_add(x[..., 2:], y, self.slice[2][None])
y /= self.sampling**2
return y
diff --git a/pylops/basicoperators/symmetrize.py b/pylops/basicoperators/symmetrize.py
index 47814154..41ca122b 100644
--- a/pylops/basicoperators/symmetrize.py
+++ b/pylops/basicoperators/symmetrize.py
@@ -6,7 +6,7 @@
from pylops import LinearOperator
from pylops.utils._internal import _value_or_sized_to_tuple
-from pylops.utils.backend import get_array_module
+from pylops.utils.backend import get_array_module, inplace_add, inplace_set
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray
@@ -80,6 +80,13 @@ def __init__(
self.nsym = dims[self.axis]
dimsd = list(dims)
dimsd[self.axis] = 2 * dims[self.axis] - 1
+ self.slice1 = tuple([slice(None, None)] * (len(dims) - 1) + [slice(1, None)])
+ self.slicensym_1 = tuple(
+ [slice(None, None)] * (len(dims) - 1) + [slice(self.nsym - 1, None)]
+ )
+ self.slice_nsym_1 = tuple(
+ [slice(None, None)] * (len(dims) - 1) + [slice(None, self.nsym - 1)]
+ )
super().__init__(dtype=np.dtype(dtype), dims=dims, dimsd=dimsd, name=name)
@@ -88,12 +95,12 @@ def _matvec(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(self.dimsd, dtype=self.dtype)
y = y.swapaxes(self.axis, -1)
- y[..., self.nsym - 1 :] = x
- y[..., : self.nsym - 1] = x[..., -1:0:-1]
+ y = inplace_set(x, y, self.slicensym_1)
+ y = inplace_set(x[..., -1:0:-1], y, self.slice_nsym_1)
return y
@reshaped(swapaxis=True)
def _rmatvec(self, x: NDArray) -> NDArray:
y = x[..., self.nsym - 1 :].copy()
- y[..., 1:] += x[..., self.nsym - 2 :: -1]
+ y = inplace_add(x[..., self.nsym - 2 :: -1], y, self.slice1)
return y
diff --git a/pylops/basicoperators/vstack.py b/pylops/basicoperators/vstack.py
index 812b1a7e..0d66642e 100644
--- a/pylops/basicoperators/vstack.py
+++ b/pylops/basicoperators/vstack.py
@@ -12,16 +12,16 @@
from scipy.sparse.linalg.interface import LinearOperator as spLinearOperator
from scipy.sparse.linalg.interface import _get_dtype
else:
- from scipy.sparse.linalg._interface import _get_dtype
from scipy.sparse.linalg._interface import (
LinearOperator as spLinearOperator,
)
+ from scipy.sparse.linalg._interface import _get_dtype
from typing import Callable, Optional, Sequence
from pylops import LinearOperator
from pylops.basicoperators import MatrixMult
-from pylops.utils.backend import get_array_module
+from pylops.utils.backend import get_array_module, inplace_add, inplace_set
from pylops.utils.typing import DTypeLike, NDArray
@@ -165,14 +165,20 @@ def _matvec_serial(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(self.nops, dtype=self.dtype)
for iop, oper in enumerate(self.ops):
- y[self.nnops[iop] : self.nnops[iop + 1]] = oper.matvec(x).squeeze()
+ y = inplace_set(
+ oper.matvec(x).squeeze(), y, slice(self.nnops[iop], self.nnops[iop + 1])
+ )
return y
def _rmatvec_serial(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(self.mops, dtype=self.dtype)
for iop, oper in enumerate(self.ops):
- y += oper.rmatvec(x[self.nnops[iop] : self.nnops[iop + 1]]).squeeze()
+ y = inplace_add(
+ oper.rmatvec(x[self.nnops[iop] : self.nnops[iop + 1]]).squeeze(),
+ y,
+ slice(None, None),
+ )
return y
def _matvec_multiproc(self, x: NDArray) -> NDArray:
diff --git a/pylops/jaxoperator.py b/pylops/jaxoperator.py
new file mode 100644
index 00000000..5d5c40ed
--- /dev/null
+++ b/pylops/jaxoperator.py
@@ -0,0 +1,104 @@
+__all__ = [
+ "JaxOperator",
+]
+
+from typing import Any, NewType
+
+from pylops import LinearOperator
+from pylops.utils import deps
+
+if deps.jax_enabled:
+ import jax
+
+ jaxarrayin_type = jax.typing.ArrayLike
+ jaxarrayout_type = jax.Array
+else:
+ jax_message = (
+ "JAX package not installed. In order to be able to use"
+ 'the jaxoperator module run "pip install jax" or'
+ '"conda install -c conda-forge jax".'
+ )
+ jaxarrayin_type = Any
+ jaxarrayout_type = Any
+
+JaxTypeIn = NewType("JaxTypeIn", jaxarrayin_type)
+JaxTypeOut = NewType("JaxTypeOut", jaxarrayout_type)
+
+
+class JaxOperator(LinearOperator):
+ """Enable JAX backend for PyLops operator.
+
+ This class can be used to wrap a pylops operator to enable the JAX
+ backend. Doing so, users can run all of the methods of a pylops
+ operator with JAX arrays. Moreover, the forward and adjoint
+ are internally just-in-time compiled, and other JAX functionalities
+ such as automatic differentiation and automatic vectorization
+ are enabled.
+
+ Parameters
+ ----------
+ Op : :obj:`pylops.LinearOperator`
+ PyLops operator
+
+ """
+
+ def __init__(self, Op: LinearOperator) -> None:
+ if not deps.jax_enabled:
+ raise NotImplementedError(jax_message)
+ super().__init__(
+ dtype=Op.dtype,
+ dims=Op.dims,
+ dimsd=Op.dimsd,
+ clinear=Op.clinear,
+ explicit=False,
+ forceflat=Op.forceflat,
+ name=Op.name,
+ )
+ self._matvec = jax.jit(Op._matvec)
+ self._rmatvec = jax.jit(Op._rmatvec)
+
+ def __call__(self, x, *args, **kwargs):
+ return self._matvec(x)
+
+ def _rmatvecad(self, x: JaxTypeIn, y: JaxTypeIn) -> JaxTypeOut:
+ _, f_vjp = jax.vjp(self._matvec, x)
+ xadj = jax.jit(f_vjp)(y)[0]
+ return xadj
+
+ def rmatvecad(self, x: JaxTypeIn, y: JaxTypeIn) -> JaxTypeOut:
+ """Vector-Jacobian product
+
+ JIT-compiled Vector-Jacobian product
+
+ Parameters
+ ----------
+ x : :obj:`jax.Array`
+ Input array for forward
+ y : :obj:`jax.Array`
+ Input array for adjoint
+
+ Returns
+ -------
+ xadj : :obj:`jax.typing.ArrayLike`
+ Output array
+
+ """
+ M, N = self.shape
+
+ if x.shape != (M,) and x.shape != (M, 1):
+ raise ValueError(
+ f"Dimension mismatch. Got {x.shape}, but expected ({M},) or ({M}, 1)."
+ )
+
+ y = self._rmatvecad(x, y)
+
+ if x.ndim == 1:
+ y = y.reshape(N)
+ elif x.ndim == 2:
+ y = y.reshape(N, 1)
+ else:
+ raise ValueError(
+ f"Invalid shape returned by user-defined rmatvecad(). "
+ f"Expected 2-d ndarray or matrix, not {x.ndim}-d ndarray"
+ )
+ return y
diff --git a/pylops/linearoperator.py b/pylops/linearoperator.py
index 44e561cd..661178f5 100644
--- a/pylops/linearoperator.py
+++ b/pylops/linearoperator.py
@@ -442,10 +442,11 @@ def _matmat(self, X: NDArray) -> NDArray:
Modified version of scipy _matmat to avoid having trailing dimension
in col when provided to matvec
"""
+ ncp = get_array_module(X)
if sp.sparse.issparse(X):
- y = np.vstack([self.matvec(col.toarray().reshape(-1)) for col in X.T]).T
+ y = ncp.vstack([self.matvec(col.toarray().reshape(-1)) for col in X.T]).T
else:
- y = np.vstack([self.matvec(col.reshape(-1)) for col in X.T]).T
+ y = ncp.vstack([self.matvec(col.reshape(-1)) for col in X.T]).T
return y
def _rmatmat(self, X: NDArray) -> NDArray:
@@ -454,10 +455,11 @@ def _rmatmat(self, X: NDArray) -> NDArray:
Modified version of scipy _rmatmat to avoid having trailing dimension
in col when provided to rmatvec
"""
+ ncp = get_array_module(X)
if sp.sparse.issparse(X):
- y = np.vstack([self.rmatvec(col.toarray().reshape(-1)) for col in X.T]).T
+ y = ncp.vstack([self.rmatvec(col.toarray().reshape(-1)) for col in X.T]).T
else:
- y = np.vstack([self.rmatvec(col.reshape(-1)) for col in X.T]).T
+ y = ncp.vstack([self.rmatvec(col.reshape(-1)) for col in X.T]).T
return y
def _adjoint(self) -> LinearOperator:
@@ -508,7 +510,9 @@ def matvec(self, x: NDArray) -> NDArray:
M, N = self.shape
if x.shape != (N,) and x.shape != (N, 1):
- raise ValueError("dimension mismatch")
+ raise ValueError(
+ f"Dimension mismatch. Got {x.shape}, but expected ({N},) or ({N}, 1)."
+ )
y = self._matvec(x)
@@ -517,7 +521,7 @@ def matvec(self, x: NDArray) -> NDArray:
elif x.ndim == 2:
y = y.reshape(M, 1)
else:
- raise ValueError("invalid shape returned by user-defined matvec()")
+ raise ValueError("Invalid shape returned by user-defined matvec()")
return y
@count(forward=False)
@@ -542,7 +546,9 @@ def rmatvec(self, x: NDArray) -> NDArray:
M, N = self.shape
if x.shape != (M,) and x.shape != (M, 1):
- raise ValueError("dimension mismatch")
+ raise ValueError(
+ f"Dimension mismatch. Got {x.shape}, but expected ({M},) or ({M}, 1)."
+ )
y = self._rmatvec(x)
@@ -551,7 +557,7 @@ def rmatvec(self, x: NDArray) -> NDArray:
elif x.ndim == 2:
y = y.reshape(N, 1)
else:
- raise ValueError("invalid shape returned by user-defined rmatvec()")
+ raise ValueError("Invalid shape returned by user-defined rmatvec()")
return y
@count(forward=True, matmat=True)
@@ -574,9 +580,9 @@ def matmat(self, X: NDArray) -> NDArray:
"""
if X.ndim != 2:
- raise ValueError("expected 2-d ndarray or matrix, " "not %d-d" % X.ndim)
+ raise ValueError(f"Expected 2-d ndarray or matrix, not {X.ndim}-d ndarray")
if X.shape[0] != self.shape[1]:
- raise ValueError("dimension mismatch: %r, %r" % (self.shape, X.shape))
+ raise ValueError(f"Dimension mismatch: {self.shape}, {X.shape}")
Y = self._matmat(X)
return Y
@@ -600,9 +606,9 @@ def rmatmat(self, X: NDArray) -> NDArray:
"""
if X.ndim != 2:
- raise ValueError("expected 2-d ndarray or matrix, " "not %d-d" % X.ndim)
+ raise ValueError(f"Expected 2-d ndarray or matrix, not {X.ndim}-d ndarray")
if X.shape[0] != self.shape[0]:
- raise ValueError("dimension mismatch: %r, %r" % (self.shape, X.shape))
+ raise ValueError(f"Dimension mismatch: {self.shape}, {X.shape}")
Y = self._rmatmat(X)
return Y
@@ -791,7 +797,7 @@ def todense(
Parameters
----------
backend : :obj:`str`, optional
- Backend used to densify matrix (``numpy`` or ``cupy``). Note that
+ Backend used to densify matrix (``numpy`` or ``cupy`` or ``jax``). Note that
this must be consistent with how the operator has been created.
Returns
@@ -816,7 +822,7 @@ def todense(
if Op.shape[1] == shapemin:
matrix = Op.matmat(identity)
else:
- matrix = np.conj(Op.rmatmat(identity)).T
+ matrix = ncp.conj(Op.rmatmat(identity)).T
return matrix
def tosparse(self) -> NDArray:
diff --git a/pylops/signalprocessing/convolve1d.py b/pylops/signalprocessing/convolve1d.py
index fd82eb91..bc154a94 100644
--- a/pylops/signalprocessing/convolve1d.py
+++ b/pylops/signalprocessing/convolve1d.py
@@ -48,10 +48,10 @@ def _choose_convfunc(
def _pad_along_axis(array: np.ndarray, pad_size: tuple, axis: int = 0) -> np.ndarray:
-
+ ncp = get_array_module(array)
npad = [(0, 0)] * array.ndim
npad[axis] = pad_size
- return np.pad(array, pad_width=npad)
+ return ncp.pad(array, pad_width=npad)
class _Convolve1Dshort(LinearOperator):
@@ -67,6 +67,7 @@ def __init__(
dtype: DTypeLike = "float64",
name: str = "C",
) -> None:
+ ncp = get_array_module(h)
dims = _value_or_sized_to_tuple(dims)
super().__init__(dtype=np.dtype(dtype), dims=dims, dimsd=dims, name=name)
self.axis = axis
@@ -83,7 +84,7 @@ def __init__(
(max(self.offset, 0), -min(self.offset, 0)),
axis=-1 if h.ndim == 1 else axis,
)
- self.hstar = np.flip(self.h, axis=-1)
+ self.hstar = ncp.flip(self.h, axis=-1)
# add dimensions to filter to match dimensions of model and data
if self.h.ndim == 1:
@@ -127,6 +128,7 @@ def __init__(
dtype: DTypeLike = "float64",
name: str = "C",
) -> None:
+ ncp = get_array_module(h)
dims = _value_or_sized_to_tuple(dims)
dimsd = h.shape
super().__init__(dtype=np.dtype(dtype), dims=dims, dimsd=dimsd, name=name)
@@ -140,13 +142,13 @@ def __init__(
self.offset = 2 * (self.dims[self.axis] // 2 - int(offset))
if self.dims[self.axis] % 2 == 0:
self.offset -= 1
- self.hstar = np.flip(self.h, axis=-1)
+ self.hstar = ncp.flip(self.h, axis=-1)
- self.pad = np.zeros((len(dims), 2), dtype=int)
+ self.pad = ncp.zeros((len(dims), 2), dtype=int)
self.pad[self.axis, 0] = max(self.offset, 0)
self.pad[self.axis, 1] = -min(self.offset, 0)
- self.padd = np.zeros((len(dims), 2), dtype=int)
+ self.padd = ncp.zeros((len(dims), 2), dtype=int)
self.padd[self.axis, 1] = max(self.offset, 0)
self.padd[self.axis, 0] = -min(self.offset, 0)
@@ -162,12 +164,13 @@ def __init__(
@reshaped
def _matvec(self, x: NDArray) -> NDArray:
+ ncp = get_array_module(x)
if type(self.h) is not type(x):
self.h = to_cupy_conditional(x, self.h)
self.convfunc, self.method = _choose_convfunc(
self.h, self.method, self.dims, self.axis
)
- x = np.pad(x, self.pad)
+ x = ncp.pad(x, self.pad)
y = self.convfunc(self.h, x, mode="same")
return y
@@ -179,7 +182,7 @@ def _rmatvec(self, x: NDArray) -> NDArray:
self.convfunc, self.method = _choose_convfunc(
self.hstar, self.method, self.dims, self.axis
)
- x = np.pad(x, self.padd)
+ x = ncp.pad(x, self.padd)
y = self.convfunc(self.hstar, x)
if self.dims[self.axis] % 2 == 0:
y = ncp.take(
diff --git a/pylops/signalprocessing/fft.py b/pylops/signalprocessing/fft.py
index 64444bcd..6af81a30 100644
--- a/pylops/signalprocessing/fft.py
+++ b/pylops/signalprocessing/fft.py
@@ -11,6 +11,7 @@
from pylops import LinearOperator
from pylops.signalprocessing._baseffts import _BaseFFT, _FFTNorms
from pylops.utils import deps
+from pylops.utils.backend import get_array_module, inplace_divide, inplace_multiply
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray
@@ -60,53 +61,61 @@ def __init__(
self._scale = self.nfft
elif self.norm is _FFTNorms.ONE_OVER_N:
self._scale = 1.0 / self.nfft
+ self.slice = tuple(
+ [slice(None, None)] * (len(self.dims) - 1)
+ + [slice(1, 1 + (self.nfft - 1) // 2)]
+ )
@reshaped
def _matvec(self, x: NDArray) -> NDArray:
+ ncp = get_array_module(x)
if self.ifftshift_before:
- x = np.fft.ifftshift(x, axes=self.axis)
+ x = ncp.fft.ifftshift(x, axes=self.axis)
if not self.clinear:
- x = np.real(x)
+ x = ncp.real(x)
if self.real:
- y = np.fft.rfft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
+ y = ncp.fft.rfft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
# Apply scaling to obtain a correct adjoint for this operator
- y = np.swapaxes(y, -1, self.axis)
- y[..., 1 : 1 + (self.nfft - 1) // 2] *= np.sqrt(2)
- y = np.swapaxes(y, self.axis, -1)
+ y = ncp.swapaxes(y, -1, self.axis)
+ # y[..., 1 : 1 + (self.nfft - 1) // 2] *= ncp.sqrt(2)
+ y = inplace_multiply(ncp.sqrt(2), y, self.slice)
+ y = ncp.swapaxes(y, self.axis, -1)
else:
- y = np.fft.fft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
+ y = ncp.fft.fft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
if self.norm is _FFTNorms.ONE_OVER_N:
y *= self._scale
if self.fftshift_after:
- y = np.fft.fftshift(y, axes=self.axis)
+ y = ncp.fft.fftshift(y, axes=self.axis)
y = y.astype(self.cdtype)
return y
@reshaped
def _rmatvec(self, x: NDArray) -> NDArray:
+ ncp = get_array_module(x)
if self.fftshift_after:
- x = np.fft.ifftshift(x, axes=self.axis)
+ x = ncp.fft.ifftshift(x, axes=self.axis)
if self.real:
# Apply scaling to obtain a correct adjoint for this operator
x = x.copy()
- x = np.swapaxes(x, -1, self.axis)
- x[..., 1 : 1 + (self.nfft - 1) // 2] /= np.sqrt(2)
- x = np.swapaxes(x, self.axis, -1)
- y = np.fft.irfft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
+ x = ncp.swapaxes(x, -1, self.axis)
+ # x[..., 1 : 1 + (self.nfft - 1) // 2] /= ncp.sqrt(2)
+ x = inplace_divide(ncp.sqrt(2), x, self.slice)
+ x = ncp.swapaxes(x, self.axis, -1)
+ y = ncp.fft.irfft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
else:
- y = np.fft.ifft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
+ y = ncp.fft.ifft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
if self.norm is _FFTNorms.NONE:
y *= self._scale
if self.nfft > self.dims[self.axis]:
- y = np.take(y, range(0, self.dims[self.axis]), axis=self.axis)
+ y = ncp.take(y, range(0, self.dims[self.axis]), axis=self.axis)
elif self.nfft < self.dims[self.axis]:
- y = np.pad(y, self.ifftpad)
+ y = ncp.pad(y, self.ifftpad)
if not self.clinear:
- y = np.real(y)
+ y = ncp.real(y)
if self.ifftshift_before:
- y = np.fft.fftshift(y, axes=self.axis)
+ y = ncp.fft.fftshift(y, axes=self.axis)
y = y.astype(self.rdtype)
return y
@@ -453,7 +462,7 @@ def FFT(
Nyquist to the frequency bin before zero.
engine : :obj:`str`, optional
Engine used for fft computation (``numpy``, ``fftw``, or ``scipy``). Choose
- ``numpy`` when working with cupy arrays.
+ ``numpy`` when working with cupy and jax arrays.
.. note:: Since version 1.17.0, accepts "scipy".
diff --git a/pylops/signalprocessing/fft2d.py b/pylops/signalprocessing/fft2d.py
index f54e2972..2f4b5f15 100644
--- a/pylops/signalprocessing/fft2d.py
+++ b/pylops/signalprocessing/fft2d.py
@@ -9,6 +9,7 @@
from pylops import LinearOperator
from pylops.signalprocessing._baseffts import _BaseFFTND, _FFTNorms
+from pylops.utils.backend import get_array_module
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, InputDimsLike
@@ -67,51 +68,53 @@ def __init__(
@reshaped
def _matvec(self, x):
+ ncp = get_array_module(x)
if self.ifftshift_before.any():
- x = np.fft.ifftshift(x, axes=self.axes[self.ifftshift_before])
+ x = ncp.fft.ifftshift(x, axes=self.axes[self.ifftshift_before])
if not self.clinear:
- x = np.real(x)
+ x = ncp.real(x)
if self.real:
- y = np.fft.rfft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
+ y = ncp.fft.rfft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
# Apply scaling to obtain a correct adjoint for this operator
- y = np.swapaxes(y, -1, self.axes[-1])
- y[..., 1 : 1 + (self.nffts[-1] - 1) // 2] *= np.sqrt(2)
- y = np.swapaxes(y, self.axes[-1], -1)
+ y = ncp.swapaxes(y, -1, self.axes[-1])
+ y[..., 1 : 1 + (self.nffts[-1] - 1) // 2] *= ncp.sqrt(2)
+ y = ncp.swapaxes(y, self.axes[-1], -1)
else:
- y = np.fft.fft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
+ y = ncp.fft.fft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
if self.norm is _FFTNorms.ONE_OVER_N:
y *= self._scale
y = y.astype(self.cdtype)
if self.fftshift_after.any():
- y = np.fft.fftshift(y, axes=self.axes[self.fftshift_after])
+ y = ncp.fft.fftshift(y, axes=self.axes[self.fftshift_after])
return y
@reshaped
def _rmatvec(self, x):
+ ncp = get_array_module(x)
if self.fftshift_after.any():
- x = np.fft.ifftshift(x, axes=self.axes[self.fftshift_after])
+ x = ncp.fft.ifftshift(x, axes=self.axes[self.fftshift_after])
if self.real:
# Apply scaling to obtain a correct adjoint for this operator
x = x.copy()
- x = np.swapaxes(x, -1, self.axes[-1])
- x[..., 1 : 1 + (self.nffts[-1] - 1) // 2] /= np.sqrt(2)
- x = np.swapaxes(x, self.axes[-1], -1)
- y = np.fft.irfft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
+ x = ncp.swapaxes(x, -1, self.axes[-1])
+ x[..., 1 : 1 + (self.nffts[-1] - 1) // 2] /= ncp.sqrt(2)
+ x = ncp.swapaxes(x, self.axes[-1], -1)
+ y = ncp.fft.irfft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
else:
- y = np.fft.ifft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
+ y = ncp.fft.ifft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
if self.norm is _FFTNorms.NONE:
y *= self._scale
if self.nffts[0] > self.dims[self.axes[0]]:
- y = np.take(y, range(self.dims[self.axes[0]]), axis=self.axes[0])
+ y = ncp.take(y, ncp.arange(self.dims[self.axes[0]]), axis=self.axes[0])
if self.nffts[1] > self.dims[self.axes[1]]:
- y = np.take(y, range(self.dims[self.axes[1]]), axis=self.axes[1])
+ y = ncp.take(y, ncp.arange(self.dims[self.axes[1]]), axis=self.axes[1])
if self.doifftpad:
- y = np.pad(y, self.ifftpad)
+ y = ncp.pad(y, self.ifftpad)
if not self.clinear:
- y = np.real(y)
+ y = ncp.real(y)
y = y.astype(self.rdtype)
if self.ifftshift_before.any():
- y = np.fft.fftshift(y, axes=self.axes[self.ifftshift_before])
+ y = ncp.fft.fftshift(y, axes=self.axes[self.ifftshift_before])
return y
def __truediv__(self, y):
@@ -310,7 +313,8 @@ def FFT2D(
engine : :obj:`str`, optional
.. versionadded:: 1.17.0
- Engine used for fft computation (``numpy`` or ``scipy``).
+ Engine used for fft computation (``numpy`` or ``scipy``). Choose
+ ``numpy`` when working with cupy and jax arrays.
dtype : :obj:`str`, optional
Type of elements in input array. Note that the ``dtype`` of the operator
is the corresponding complex type even when a real type is provided.
diff --git a/pylops/signalprocessing/fftnd.py b/pylops/signalprocessing/fftnd.py
index d081072b..cf2de78f 100644
--- a/pylops/signalprocessing/fftnd.py
+++ b/pylops/signalprocessing/fftnd.py
@@ -7,8 +7,9 @@
import numpy as np
import numpy.typing as npt
+from pylops import LinearOperator
from pylops.signalprocessing._baseffts import _BaseFFTND, _FFTNorms
-from pylops.utils.backend import get_sp_fft
+from pylops.utils.backend import get_array_module, get_sp_fft
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray
@@ -46,6 +47,7 @@ def __init__(
warnings.warn(
f"numpy backend always returns complex128 dtype. To respect the passed dtype, data will be cast to {self.cdtype}."
)
+
self._kwargs_fft = kwargs_fft
self._norm_kwargs = {"norm": None} # equivalent to "backward" in Numpy/Scipy
if self.norm is _FFTNorms.ORTHO:
@@ -57,58 +59,52 @@ def __init__(
@reshaped
def _matvec(self, x: NDArray) -> NDArray:
+ ncp = get_array_module(x)
if self.ifftshift_before.any():
- x = np.fft.ifftshift(x, axes=self.axes[self.ifftshift_before])
+ x = ncp.fft.ifftshift(x, axes=self.axes[self.ifftshift_before])
if not self.clinear:
- x = np.real(x)
+ x = ncp.real(x)
if self.real:
- y = np.fft.rfftn(
- x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
- )
+ y = ncp.fft.rfftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
# Apply scaling to obtain a correct adjoint for this operator
- y = np.swapaxes(y, -1, self.axes[-1])
- y[..., 1 : 1 + (self.nffts[-1] - 1) // 2] *= np.sqrt(2)
- y = np.swapaxes(y, self.axes[-1], -1)
+ y = ncp.swapaxes(y, -1, self.axes[-1])
+ y[..., 1 : 1 + (self.nffts[-1] - 1) // 2] *= ncp.sqrt(2)
+ y = ncp.swapaxes(y, self.axes[-1], -1)
else:
- y = np.fft.fftn(
- x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
- )
+ y = ncp.fft.fftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
if self.norm is _FFTNorms.ONE_OVER_N:
y *= self._scale
y = y.astype(self.cdtype)
if self.fftshift_after.any():
- y = np.fft.fftshift(y, axes=self.axes[self.fftshift_after])
+ y = ncp.fft.fftshift(y, axes=self.axes[self.fftshift_after])
return y
@reshaped
def _rmatvec(self, x: NDArray) -> NDArray:
+ ncp = get_array_module(x)
if self.fftshift_after.any():
- x = np.fft.ifftshift(x, axes=self.axes[self.fftshift_after])
+ x = ncp.fft.ifftshift(x, axes=self.axes[self.fftshift_after])
if self.real:
# Apply scaling to obtain a correct adjoint for this operator
x = x.copy()
- x = np.swapaxes(x, -1, self.axes[-1])
- x[..., 1 : 1 + (self.nffts[-1] - 1) // 2] /= np.sqrt(2)
- x = np.swapaxes(x, self.axes[-1], -1)
- y = np.fft.irfftn(
- x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
- )
+ x = ncp.swapaxes(x, -1, self.axes[-1])
+ x[..., 1 : 1 + (self.nffts[-1] - 1) // 2] /= ncp.sqrt(2)
+ x = ncp.swapaxes(x, self.axes[-1], -1)
+ y = ncp.fft.irfftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
else:
- y = np.fft.ifftn(
- x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
- )
+ y = ncp.fft.ifftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
if self.norm is _FFTNorms.NONE:
y *= self._scale
for ax, nfft in zip(self.axes, self.nffts):
if nfft > self.dims[ax]:
- y = np.take(y, range(self.dims[ax]), axis=ax)
+ y = ncp.take(y, np.arange(self.dims[ax]), axis=ax)
if self.doifftpad:
- y = np.pad(y, self.ifftpad)
+ y = ncp.pad(y, self.ifftpad)
if not self.clinear:
- y = np.real(y)
+ y = ncp.real(y)
y = y.astype(self.rdtype)
if self.ifftshift_before.any():
- y = np.fft.fftshift(y, axes=self.axes[self.ifftshift_before])
+ y = ncp.fft.fftshift(y, axes=self.axes[self.ifftshift_before])
return y
def __truediv__(self, y: npt.ArrayLike) -> npt.ArrayLike:
@@ -161,17 +157,13 @@ def _matvec(self, x: NDArray) -> NDArray:
if not self.clinear:
x = np.real(x)
if self.real:
- y = sp_fft.rfftn(
- x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
- )
+ y = sp_fft.rfftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
# Apply scaling to obtain a correct adjoint for this operator
y = np.swapaxes(y, -1, self.axes[-1])
y[..., 1 : 1 + (self.nffts[-1] - 1) // 2] *= np.sqrt(2)
y = np.swapaxes(y, self.axes[-1], -1)
else:
- y = sp_fft.fftn(
- x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
- )
+ y = sp_fft.fftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
if self.norm is _FFTNorms.ONE_OVER_N:
y *= self._scale
if self.fftshift_after.any():
@@ -189,13 +181,9 @@ def _rmatvec(self, x: NDArray) -> NDArray:
x = np.swapaxes(x, -1, self.axes[-1])
x[..., 1 : 1 + (self.nffts[-1] - 1) // 2] /= np.sqrt(2)
x = np.swapaxes(x, self.axes[-1], -1)
- y = sp_fft.irfftn(
- x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
- )
+ y = sp_fft.irfftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
else:
- y = sp_fft.ifftn(
- x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
- )
+ y = sp_fft.ifftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
if self.norm is _FFTNorms.NONE:
y *= self._scale
for ax, nfft in zip(self.axes, self.nffts):
@@ -228,7 +216,7 @@ def FFTND(
dtype: DTypeLike = "complex128",
name: str = "F",
**kwargs_fft,
-):
+) -> LinearOperator:
r"""N-dimensional Fast-Fourier Transform.
Apply N-dimensional Fast-Fourier Transform (FFT) to any n ``axes``
@@ -316,7 +304,8 @@ def FFTND(
engine : :obj:`str`, optional
.. versionadded:: 1.17.0
- Engine used for fft computation (``numpy`` or ``scipy``).
+ Engine used for fft computation (``numpy`` or ``scipy``). Choose
+ ``numpy`` when working with cupy and jax arrays.
dtype : :obj:`str`, optional
Type of elements in input array. Note that the ``dtype`` of the operator
is the corresponding complex type even when a real type is provided.
@@ -331,7 +320,9 @@ def FFTND(
Name of operator (to be used by :func:`pylops.utils.describe.describe`)
**kwargs_fft
- Arbitrary keyword arguments to be passed to the selected fft method
+ .. versionadded:: 2.3.0
+
+ Arbitrary keyword arguments to be passed to the selected fft method
Attributes
----------
diff --git a/pylops/signalprocessing/fredholm1.py b/pylops/signalprocessing/fredholm1.py
index 57eec234..feb6c645 100644
--- a/pylops/signalprocessing/fredholm1.py
+++ b/pylops/signalprocessing/fredholm1.py
@@ -3,7 +3,7 @@
import numpy as np
from pylops import LinearOperator
-from pylops.utils.backend import get_array_module
+from pylops.utils.backend import get_array_module, inplace_set
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, NDArray
@@ -118,7 +118,7 @@ def _matvec(self, x: NDArray) -> NDArray:
else:
y = ncp.squeeze(ncp.zeros((self.nsl, self.nx, self.nz), dtype=self.dtype))
for isl in range(self.nsl):
- y[isl] = ncp.dot(self.G[isl], x[isl])
+ y = inplace_set(ncp.dot(self.G[isl], x[isl]), y, isl)
return y
@reshaped
@@ -131,7 +131,6 @@ def _rmatvec(self, x: NDArray) -> NDArray:
if hasattr(self, "GT"):
y = ncp.matmul(self.GT, x)
else:
- # y = ncp.matmul(self.G.transpose((0, 2, 1)).conj(), x)
y = (
ncp.matmul(x.transpose(0, 2, 1).conj(), self.G)
.transpose(0, 2, 1)
@@ -141,9 +140,10 @@ def _rmatvec(self, x: NDArray) -> NDArray:
y = ncp.squeeze(ncp.zeros((self.nsl, self.ny, self.nz), dtype=self.dtype))
if hasattr(self, "GT"):
for isl in range(self.nsl):
- y[isl] = ncp.dot(self.GT[isl], x[isl])
+ y = inplace_set(ncp.dot(self.GT[isl], x[isl]), y, isl)
else:
for isl in range(self.nsl):
- # y[isl] = ncp.dot(self.G[isl].conj().T, x[isl])
- y[isl] = ncp.dot(x[isl].T.conj(), self.G[isl]).T.conj()
- return y
+ y = inplace_set(
+ ncp.dot(x[isl].T.conj(), self.G[isl]).T.conj(), y, isl
+ )
+ return y.ravel()
diff --git a/pylops/signalprocessing/nonstatconvolve1d.py b/pylops/signalprocessing/nonstatconvolve1d.py
index 45daeed5..669898ef 100644
--- a/pylops/signalprocessing/nonstatconvolve1d.py
+++ b/pylops/signalprocessing/nonstatconvolve1d.py
@@ -9,7 +9,7 @@
from pylops import LinearOperator
from pylops.utils._internal import _value_or_sized_to_tuple
-from pylops.utils.backend import get_array_module
+from pylops.utils.backend import get_array_module, inplace_add, inplace_set
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray
@@ -147,7 +147,8 @@ def _interpolate_h(hs, ix, oh, dh, nh):
@reshaped(swapaxis=True)
def _matvec(self, x: NDArray) -> NDArray:
- y = np.zeros_like(x)
+ ncp = get_array_module(x)
+ y = ncp.zeros_like(x)
for ix in range(self.dims[self.axis]):
h = self._interpolate_h(self.hs, ix, self.oh, self.dh, self.nh)
xextremes = (
@@ -158,14 +159,20 @@ def _matvec(self, x: NDArray) -> NDArray:
max(0, -ix + self.hsize // 2),
min(self.hsize, self.hsize // 2 + (self.dims[self.axis] - ix)),
)
- y[..., xextremes[0] : xextremes[1]] += (
- x[..., ix : ix + 1] * h[hextremes[0] : hextremes[1]]
+ # y[..., xextremes[0] : xextremes[1]] += (
+ # x[..., ix : ix + 1] * h[hextremes[0] : hextremes[1]]
+ # )
+ sl = tuple(
+ [slice(None, None)] * (len(self.dimsd) - 1)
+ + [slice(xextremes[0], xextremes[1])]
)
+ y = inplace_add(x[..., ix : ix + 1] * h[hextremes[0] : hextremes[1]], y, sl)
return y
@reshaped(swapaxis=True)
def _rmatvec(self, x: NDArray) -> NDArray:
- y = np.zeros_like(x)
+ ncp = get_array_module(x)
+ y = ncp.zeros_like(x)
for ix in range(self.dims[self.axis]):
h = self._interpolate_h(self.hs, ix, self.oh, self.dh, self.nh)
xextremes = (
@@ -176,17 +183,29 @@ def _rmatvec(self, x: NDArray) -> NDArray:
max(0, -ix + self.hsize // 2),
min(self.hsize, self.hsize // 2 + (self.dims[self.axis] - ix)),
)
- y[..., ix] = np.sum(
- h[hextremes[0] : hextremes[1]] * x[..., xextremes[0] : xextremes[1]],
- axis=-1,
+ # y[..., ix] = ncp.sum(
+ # h[hextremes[0] : hextremes[1]] * x[..., xextremes[0] : xextremes[1]],
+ # axis=-1,
+ # )
+ sl = tuple([slice(None, None)] * (len(self.dimsd) - 1) + [ix])
+ y = inplace_set(
+ ncp.sum(
+ h[hextremes[0] : hextremes[1]]
+ * x[..., xextremes[0] : xextremes[1]],
+ axis=-1,
+ ),
+ y,
+ sl,
)
+
return y
def todense(self):
+ ncp = get_array_module(self.hsinterp[0])
hs = self.hsinterp
- H = np.array(
+ H = ncp.array(
[
- np.roll(np.pad(h, (0, self.dims[self.axis])), ix)
+ ncp.roll(ncp.pad(h, (0, self.dims[self.axis])), ix)
for ix, h in enumerate(hs)
]
)
@@ -317,18 +336,27 @@ def _interpolate_hadj(htmp, hs, hextremes, ix, oh, dh, nh):
"""find closest filters and spread weighted psf"""
ih_closest = int(np.floor((ix - oh) / dh))
if ih_closest < 0:
- hs[0, hextremes[0] : hextremes[1]] += htmp
+ # hs[0, hextremes[0] : hextremes[1]] += htmp
+ sl = tuple([0] + [slice(hextremes[0], hextremes[1])])
+ hs = inplace_add(htmp, hs, sl)
elif ih_closest >= nh - 1:
- hs[nh - 1, hextremes[0] : hextremes[1]] += htmp
+ # hs[nh - 1, hextremes[0] : hextremes[1]] += htmp
+ sl = tuple([nh - 1] + [slice(hextremes[0], hextremes[1])])
+ hs = inplace_add(htmp, hs, sl)
else:
dh_closest = (ix - oh) / dh - ih_closest
- hs[ih_closest, hextremes[0] : hextremes[1]] += (1 - dh_closest) * htmp
- hs[ih_closest + 1, hextremes[0] : hextremes[1]] += dh_closest * htmp
+ # hs[ih_closest, hextremes[0] : hextremes[1]] += (1 - dh_closest) * htmp
+ sl = tuple([ih_closest] + [slice(hextremes[0], hextremes[1])])
+ hs = inplace_add((1 - dh_closest) * htmp, hs, sl)
+ # hs[ih_closest + 1, hextremes[0] : hextremes[1]] += dh_closest * htmp
+ sl = tuple([ih_closest + 1] + [slice(hextremes[0], hextremes[1])])
+ hs = inplace_add(dh_closest * htmp, hs, sl)
return hs
@reshaped
def _matvec(self, x: NDArray) -> NDArray:
- y = np.zeros(self.dimsd, dtype=self.dtype)
+ ncp = get_array_module(x)
+ y = ncp.zeros(self.dimsd, dtype=self.dtype)
for ix in range(self.dimsd[0]):
h = self._interpolate_h(x, ix, self.oh, self.dh, self.nh)
xextremes = (
@@ -339,14 +367,23 @@ def _matvec(self, x: NDArray) -> NDArray:
max(0, -ix + self.hsize // 2),
min(self.hsize, self.hsize // 2 + (self.dimsd[0] - ix)),
)
- y[..., xextremes[0] : xextremes[1]] += (
- self.inp[..., ix : ix + 1] * h[hextremes[0] : hextremes[1]]
+ # y[..., xextremes[0] : xextremes[1]] += (
+ # self.inp[..., ix : ix + 1] * h[hextremes[0] : hextremes[1]]
+ # )
+ sl = tuple(
+ [slice(None, None)] * (len(self.dimsd) - 1)
+ + [slice(xextremes[0], xextremes[1])]
+ )
+ y = inplace_add(
+ self.inp[..., ix : ix + 1] * h[hextremes[0] : hextremes[1]], y, sl
)
+
return y
@reshaped
def _rmatvec(self, x: NDArray) -> NDArray:
- hs = np.zeros(self.dims, dtype=self.dtype)
+ ncp = get_array_module(x)
+ hs = ncp.zeros(self.dims, dtype=self.dtype)
for ix in range(self.dimsd[0]):
xextremes = (
max(0, ix - self.hsize // 2),
diff --git a/pylops/torchoperator.py b/pylops/torchoperator.py
index 5e41f67f..1c4dc2da 100644
--- a/pylops/torchoperator.py
+++ b/pylops/torchoperator.py
@@ -14,7 +14,7 @@
else:
torch_message = (
"Torch package not installed. In order to be able to use"
- 'the twoway module run "pip install torch" or'
+ 'the torchoperator module run "pip install torch" or'
'"conda install -c pytorch torch".'
)
from pylops.utils.typing import TensorTypeLike
diff --git a/pylops/utils/backend.py b/pylops/utils/backend.py
index c56be90d..4b6b506f 100644
--- a/pylops/utils/backend.py
+++ b/pylops/utils/backend.py
@@ -13,10 +13,16 @@
"get_csc_matrix",
"get_sparse_eye",
"get_lstsq",
+ "get_sp_fft",
"get_complex_dtype",
"get_real_dtype",
"to_numpy",
"to_cupy_conditional",
+ "inplace_set",
+ "inplace_add",
+ "inplace_multiply",
+ "inplace_divide",
+ "randn",
]
from types import ModuleType
@@ -45,6 +51,14 @@
from cupyx.scipy.sparse import csc_matrix as cp_csc_matrix
from cupyx.scipy.sparse import eye as cp_eye
+if deps.jax_enabled:
+ import jax
+ import jax.numpy as jnp
+ from jax.scipy.linalg import block_diag as jnp_block_diag
+ from jax.scipy.linalg import toeplitz as jnp_toeplitz
+ from jax.scipy.signal import convolve as j_convolve
+ from jax.scipy.signal import fftconvolve as j_fftconvolve
+
def get_module(backend: str = "numpy") -> ModuleType:
"""Returns correct numerical module based on backend string
@@ -52,21 +66,23 @@ def get_module(backend: str = "numpy") -> ModuleType:
Parameters
----------
backend : :obj:`str`, optional
- Backend used for dot test computations (``numpy`` or ``cupy``). This
+ Backend used for dot test computations (``numpy`` or ``cupy`` or ``jax``). This
parameter will be used to choose how to create the random vectors.
Returns
-------
mod : :obj:`func`
- Module to be used to process array (:mod:`numpy` or :mod:`cupy`)
+ Module to be used to process array (:mod:`numpy` or :mod:`cupy` or :mod:`jax`)
"""
if backend == "numpy":
ncp = np
elif backend == "cupy":
ncp = cp
+ elif backend == "jax":
+ ncp = jnp
else:
- raise ValueError("backend must be numpy or cupy")
+ raise ValueError("backend must be numpy, cupy, or jax")
return ncp
@@ -76,12 +92,12 @@ def get_module_name(mod: ModuleType) -> str:
Parameters
----------
mod : :obj:`func`
- Module to be used to process array (:mod:`numpy` or :mod:`cupy`)
+ Module to be used to process array (:mod:`numpy` or :mod:`cupy` or :mod:`jax`)
Returns
-------
backend : :obj:`str`, optional
- Backend used for dot test computations (``numpy`` or ``cupy``). This
+ Backend used for dot test computations (``numpy`` or ``cupy`` or ``jax``). This
parameter will be used to choose how to create the random vectors.
"""
@@ -89,8 +105,10 @@ def get_module_name(mod: ModuleType) -> str:
backend = "numpy"
elif mod == cp:
backend = "cupy"
+ elif mod == jnp:
+ backend = "jax"
else:
- raise ValueError("module must be numpy or cupy")
+ raise ValueError("module must be numpy, cupy, or jax")
return backend
@@ -99,17 +117,23 @@ def get_array_module(x: npt.ArrayLike) -> ModuleType:
Parameters
----------
- x : :obj:`numpy.ndarray`
+ x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.Array`
Array
Returns
-------
mod : :obj:`func`
- Module to be used to process array (:mod:`numpy` or :mod:`cupy`)
+ Module to be used to process array
+ (:mod:`numpy`, :mod:`cupy`, or , :mod:`jax`)
"""
- if deps.cupy_enabled:
- return cp.get_array_module(x)
+ if deps.cupy_enabled or deps.jax_enabled:
+ if isinstance(x, jnp.ndarray):
+ return jnp
+ elif deps.cupy_enabled:
+ return cp.get_array_module(x)
+ else:
+ return np
else:
return np
@@ -119,22 +143,24 @@ def get_convolve(x: npt.ArrayLike) -> Callable:
Parameters
----------
- x : :obj:`numpy.ndarray`
+ x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.Array`
Array
Returns
-------
- mod : :obj:`func`
- Module to be used to process array (:mod:`numpy` or :mod:`cupy`)
+ f : :obj:`func`
+ Function to be used to process array
"""
- if not deps.cupy_enabled:
- return convolve
-
- if cp.get_array_module(x) == np:
- return convolve
+ if deps.cupy_enabled or deps.jax_enabled:
+ if isinstance(x, jnp.ndarray):
+ return j_convolve
+ elif deps.cupy_enabled and cp.get_array_module(x) == cp:
+ return cp_convolve
+ else:
+ return convolve
else:
- return cp_convolve
+ return convolve
def get_fftconvolve(x: npt.ArrayLike) -> Callable:
@@ -142,22 +168,24 @@ def get_fftconvolve(x: npt.ArrayLike) -> Callable:
Parameters
----------
- x : :obj:`numpy.ndarray`
+ x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.Array`
Array
Returns
-------
- mod : :obj:`func`
- Module to be used to process array (:mod:`numpy` or :mod:`cupy`)
+ f : :obj:`func`
+ Function to be used to process array
"""
- if not deps.cupy_enabled:
- return fftconvolve
-
- if cp.get_array_module(x) == np:
- return fftconvolve
+ if deps.cupy_enabled or deps.jax_enabled:
+ if isinstance(x, jnp.ndarray):
+ return j_fftconvolve
+ elif deps.cupy_enabled and cp.get_array_module(x) == cp:
+ return cp_fftconvolve
+ else:
+ return fftconvolve
else:
- return cp_fftconvolve
+ return fftconvolve
def get_oaconvolve(x: npt.ArrayLike) -> Callable:
@@ -165,22 +193,28 @@ def get_oaconvolve(x: npt.ArrayLike) -> Callable:
Parameters
----------
- x : :obj:`numpy.ndarray`
+ x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.Array`
Array
Returns
-------
- mod : :obj:`func`
- Module to be used to process array (:mod:`numpy` or :mod:`cupy`)
+ f : :obj:`func`
+ Function to be used to process array
"""
- if not deps.cupy_enabled:
- return oaconvolve
-
- if cp.get_array_module(x) == np:
- return oaconvolve
+ if deps.cupy_enabled or deps.jax_enabled:
+ if isinstance(x, jnp.ndarray):
+ raise NotImplementedError(
+ "oaconvolve not implemented in "
+ "jax. Consider using a different"
+ "option..."
+ )
+ elif deps.cupy_enabled and cp.get_array_module(x) == cp:
+ return cp_oaconvolve
+ else:
+ return oaconvolve
else:
- return cp_oaconvolve
+ return oaconvolve
def get_correlate(x: npt.ArrayLike) -> Callable:
@@ -188,22 +222,24 @@ def get_correlate(x: npt.ArrayLike) -> Callable:
Parameters
----------
- x : :obj:`numpy.ndarray`
+ x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.Array`
Array
Returns
-------
- mod : :obj:`func`
- Module to be used to process array (:mod:`numpy` or :mod:`cupy`)
+ f : :obj:`func`
+ Function to be used to process array
"""
- if not deps.cupy_enabled:
- return correlate
-
- if cp.get_array_module(x) == np:
- return correlate
+ if deps.cupy_enabled or deps.jax_enabled:
+ if isinstance(x, jnp.ndarray):
+ return jax.scipy.signal.correlate
+ elif deps.cupy_enabled and cp.get_array_module(x) == cp:
+ return cp_correlate
+ else:
+ return correlate
else:
- return cp_correlate
+ return correlate
def get_add_at(x: npt.ArrayLike) -> Callable:
@@ -211,13 +247,13 @@ def get_add_at(x: npt.ArrayLike) -> Callable:
Parameters
----------
- x : :obj:`numpy.ndarray`
+ x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.Array`
Array
Returns
-------
- mod : :obj:`func`
- Module to be used to process array (:mod:`numpy` or :mod:`cupy`)
+ f : :obj:`func`
+ Function to be used to process array
"""
if not deps.cupy_enabled:
@@ -234,13 +270,13 @@ def get_sliding_window_view(x: npt.ArrayLike) -> Callable:
Parameters
----------
- x : :obj:`numpy.ndarray`
+ x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.Array`
Array
Returns
-------
- mod : :obj:`func`
- Module to be used to process array (:mod:`numpy` or :mod:`cupy`)
+ f : :obj:`func`
+ Function to be used to process array
"""
if not deps.cupy_enabled:
@@ -257,22 +293,24 @@ def get_block_diag(x: npt.ArrayLike) -> Callable:
Parameters
----------
- x : :obj:`numpy.ndarray`
+ x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.Array`
Array
Returns
-------
- mod : :obj:`func`
- Module to be used to process array (:mod:`numpy` or :mod:`cupy`)
+ f : :obj:`func`
+ Function to be used to process array
"""
- if not deps.cupy_enabled:
- return block_diag
-
- if cp.get_array_module(x) == np:
- return block_diag
+ if deps.cupy_enabled or deps.jax_enabled:
+ if isinstance(x, jnp.ndarray):
+ return jnp_block_diag
+ elif deps.cupy_enabled and cp.get_array_module(x) == cp:
+ return cp_block_diag
+ else:
+ return block_diag
else:
- return cp_block_diag
+ return block_diag
def get_toeplitz(x: npt.ArrayLike) -> Callable:
@@ -285,17 +323,19 @@ def get_toeplitz(x: npt.ArrayLike) -> Callable:
Returns
-------
- mod : :obj:`func`
- Module to be used to process array (:mod:`numpy` or :mod:`cupy`)
+ f : :obj:`func`
+ Function to be used to process array
"""
- if not deps.cupy_enabled:
- return toeplitz
-
- if cp.get_array_module(x) == np:
- return toeplitz
+ if deps.cupy_enabled or deps.jax_enabled:
+ if isinstance(x, jnp.ndarray):
+ return jnp_toeplitz
+ elif deps.cupy_enabled and cp.get_array_module(x) == cp:
+ return cp_toeplitz
+ else:
+ return toeplitz
else:
- return cp_toeplitz
+ return toeplitz
def get_csc_matrix(x: npt.ArrayLike) -> Callable:
@@ -308,8 +348,8 @@ def get_csc_matrix(x: npt.ArrayLike) -> Callable:
Returns
-------
- mod : :obj:`func`
- Module to be used to process array (:mod:`numpy` or :mod:`cupy`)
+ f : :obj:`func`
+ Function to be used to process array
"""
if not deps.cupy_enabled:
@@ -331,8 +371,8 @@ def get_sparse_eye(x: npt.ArrayLike) -> Callable:
Returns
-------
- mod : :obj:`func`
- Module to be used to process array (:mod:`numpy` or :mod:`cupy`)
+ f : :obj:`func`
+ Function to be used to process array
"""
if not deps.cupy_enabled:
@@ -354,8 +394,8 @@ def get_lstsq(x: npt.ArrayLike) -> Callable:
Returns
-------
- mod : :obj:`func`
- Module to be used to process array (:mod:`numpy` or :mod:`cupy`)
+ f : :obj:`func`
+ Function to be used to process array
"""
if not deps.cupy_enabled:
@@ -377,8 +417,8 @@ def get_sp_fft(x: npt.ArrayLike) -> Callable:
Returns
-------
- mod : :obj:`func`
- Module to be used to process array (:mod:`numpy` or :mod:`cupy`)
+ f : :obj:`func`
+ Function to be used to process array
"""
if not deps.cupy_enabled:
@@ -433,7 +473,7 @@ def to_numpy(x: NDArray) -> NDArray:
Returns
-------
- x : :obj:`cupy.ndarray`
+ x : :obj:`numpy.ndarray`
Converted array
"""
@@ -464,3 +504,135 @@ def to_cupy_conditional(x: npt.ArrayLike, y: npt.ArrayLike) -> NDArray:
with cp.cuda.Device(x.device):
y = cp.asarray(y)
return y
+
+
+def inplace_set(x: npt.ArrayLike, y: npt.ArrayLike, idx: list) -> NDArray:
+ """Perform inplace set based on input
+
+ Parameters
+ ----------
+ x : :obj:`numpy.ndarray` or :obj:`jax.Array`
+ Array to sum
+ y : :obj:`numpy.ndarray` or :obj:`jax.Array`
+ Output array
+ idx : :obj:`list`
+ Indices to sum at
+
+ Returns
+ -------
+ y : :obj:`numpy.ndarray` or :obj:`jax.Array`
+ Output array
+
+ """
+ if deps.jax_enabled and isinstance(x, jnp.ndarray):
+ y = y.at[idx].set(x)
+ return y
+ else:
+ y[idx] = x
+ return y
+
+
+def inplace_add(x: npt.ArrayLike, y: npt.ArrayLike, idx: list) -> NDArray:
+ """Perform inplace add based on input
+
+ Parameters
+ ----------
+ x : :obj:`numpy.ndarray` or :obj:`jax.Array`
+ Array to sum
+ y : :obj:`numpy.ndarray` or :obj:`jax.Array`
+ Output array
+ idx : :obj:`list`
+ Indices to sum at
+
+ Returns
+ -------
+ y : :obj:`numpy.ndarray` or :obj:`jax.Array`
+ Output array
+
+ """
+ if deps.jax_enabled and isinstance(x, jnp.ndarray):
+ y = y.at[idx].add(x)
+ return y
+ else:
+ y[idx] += x
+ return y
+
+
+def inplace_multiply(x: npt.ArrayLike, y: npt.ArrayLike, idx: list) -> NDArray:
+ """Perform inplace multiplication based on input
+
+ Parameters
+ ----------
+ x : :obj:`numpy.ndarray` or :obj:`jax.Array`
+ Array to sum
+ y : :obj:`numpy.ndarray` or :obj:`jax.Array`
+ Output array
+ idx : :obj:`list`
+ Indices to multiply at
+
+ Returns
+ -------
+ y : :obj:`numpy.ndarray` or :obj:`jax.Array`
+ Output array
+
+ """
+ if deps.jax_enabled and isinstance(x, jnp.ndarray):
+ y = y.at[idx].multiply(x)
+ return y
+ else:
+ y[idx] *= x
+ return y
+
+
+def inplace_divide(x: npt.ArrayLike, y: npt.ArrayLike, idx: list) -> NDArray:
+ """Perform inplace division based on input
+
+ Parameters
+ ----------
+ x : :obj:`numpy.ndarray` or :obj:`jax.Array`
+ Array to sum
+ y : :obj:`numpy.ndarray` or :obj:`jax.Array`
+ Output array
+ idx : :obj:`list`
+ Indices to divide at
+
+ Returns
+ -------
+ y : :obj:`numpy.ndarray` or :obj:`jax.Array`
+ Output array
+
+ """
+ if deps.jax_enabled and isinstance(x, jnp.ndarray):
+ y = y.at[idx].divide(x)
+ return y
+ else:
+ y[idx] /= x
+ return y
+
+
+def randn(*n: int, backend: str = "numpy") -> NDArray:
+ """Returns randomly generated number
+
+ Parameters
+ ----------
+ *n : :obj:`int`
+ Number of samples to generate in each dimension
+ backend : :obj:`str`, optional
+ Backend used for dot test computations (``numpy`` or ``cupy``). This
+ parameter will be used to choose how to create the random vectors.
+
+ Returns
+ -------
+ x : :obj:`numpy.ndarray` or :obj:`jax.Array`
+ Generated array
+
+ """
+ if backend == "numpy":
+ x = np.random.randn(*n)
+ elif backend == "cupy":
+ x = cp.random.randn(*n)
+ elif backend == "jax":
+ x = jnp.array(np.random.randn(*n))
+ else:
+ raise ValueError("backend must be numpy, cupy, or jax")
+ return x
diff --git a/pylops/utils/deps.py b/pylops/utils/deps.py
index 4b2d21e7..ecf69a95 100644
--- a/pylops/utils/deps.py
+++ b/pylops/utils/deps.py
@@ -1,5 +1,6 @@
__all__ = [
"cupy_enabled",
+ "jax_enabled",
"devito_enabled",
"dtcwt_enabled",
"numba_enabled",
@@ -51,6 +52,34 @@ def cupy_import(message: Optional[str] = None) -> str:
return cupy_message
+def jax_import(message: Optional[str] = None) -> str:
+ jax_test = (
+ util.find_spec("jax") is not None and int(os.getenv("JAX_PYLOPS", 1)) == 1
+ )
+ if jax_test:
+ try:
+ import_module("jax") # noqa: F401
+
+ jax_message = None
+ except (ImportError, ModuleNotFoundError) as e:
+ jax_message = (
+ f"Failed to import jax, Falling back to numpy (error: {e}). "
+ "Please ensure your environment is set up correctly "
+ "for more details visit 'https://jax.readthedocs.io/en/latest/installation.html'"
+ )
+ print(UserWarning(jax_message))
+ else:
+ jax_message = (
+ "Jax package not installed or os.getenv('JAX_PYLOPS') == 0. "
+ f"In order to be able to use {message} "
+ "ensure 'os.getenv('JAX_PYLOPS') == 1' and run "
+ "'pip install jax'; "
+ "for more details visit 'https://jax.readthedocs.io/en/latest/installation.html'"
+ )
+
+ return jax_message
+
+
def devito_import(message: Optional[str] = None) -> str:
if devito_enabled:
try:
@@ -195,15 +224,18 @@ def sympy_import(message: Optional[str] = None) -> str:
# Set package availability booleans
-# cupy: the package is imported to check everything is working correctly,
-# if not the package is disabled. We do this here as this library is used as drop-in
-# replacement for many numpy and scipy routines when cupy arrays are provided.
+# cupy and jax: the package is imported to check everything is working correctly,
+# if not the package is disabled. We do this here as these libraries are used as drop-in
+# replacement for many numpy and scipy routines when cupy/jax arrays are provided.
# all other libraries: we simply check if the package is available and postpone its import
# to check everything is working correctly when a user tries to create an operator that requires
# such a package
cupy_enabled: bool = (
True if (cupy_import() is None and int(os.getenv("CUPY_PYLOPS", 1)) == 1) else False
)
+jax_enabled: bool = (
+ True if (jax_import() is None and int(os.getenv("JAX_PYLOPS", 1)) == 1) else False
+)
devito_enabled = util.find_spec("devito") is not None
dtcwt_enabled = util.find_spec("dtcwt") is not None
numba_enabled = util.find_spec("numba") is not None
diff --git a/pylops/utils/dottest.py b/pylops/utils/dottest.py
index ed77b995..c8a198ca 100644
--- a/pylops/utils/dottest.py
+++ b/pylops/utils/dottest.py
@@ -4,7 +4,7 @@
import numpy as np
-from pylops.utils.backend import get_module, to_numpy
+from pylops.utils.backend import get_module, randn, to_numpy
def dottest(
@@ -93,13 +93,13 @@ def dottest(
# make u and v vectors
rdtype = np.ones(1, Op.dtype).real.dtype
- u = ncp.random.randn(nc).astype(rdtype)
+ u = randn(nc, backend=backend).astype(rdtype)
if complexflag not in (0, 2):
- u = u + 1j * ncp.random.randn(nc).astype(rdtype)
+ u = u + 1j * randn(nc, backend=backend).astype(rdtype)
- v = ncp.random.randn(nr).astype(rdtype)
+ v = randn(nr, backend=backend).astype(rdtype)
if complexflag not in (0, 1):
- v = v + 1j * ncp.random.randn(nr).astype(rdtype)
+ v = v + 1j * randn(nr, backend=backend).astype(rdtype)
y = Op.matvec(u) # Op * u
x = Op.rmatvec(v) # Op'* v
diff --git a/pylops/waveeqprocessing/blending.py b/pylops/waveeqprocessing/blending.py
index adca6d93..2bc31c65 100644
--- a/pylops/waveeqprocessing/blending.py
+++ b/pylops/waveeqprocessing/blending.py
@@ -9,7 +9,7 @@
from pylops import LinearOperator
from pylops.basicoperators import BlockDiag, HStack, Pad
from pylops.signalprocessing import Shift
-from pylops.utils.backend import get_array_module
+from pylops.utils.backend import get_array_module, inplace_add, inplace_set
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, NDArray
@@ -76,12 +76,7 @@ def __init__(
self.dt = dt
self.times = times
self.shiftall = shiftall
- if np.max(self.times) // dt == np.max(self.times) / dt:
- # do not add extra sample as no shift will be applied
- self.nttot = int(np.max(self.times) / self.dt + self.nt)
- else:
- # add 1 extra sample at the end
- self.nttot = int(np.max(self.times) / self.dt + self.nt + 1)
+ self.nttot = int(np.max(self.times) / self.dt + self.nt + 1)
if not self.shiftall:
# original implementation, where each source is shifted indipendently
self.PadOp = Pad((self.nr, self.nt), ((0, 0), (0, 1)), dtype=self.dtype)
@@ -143,7 +138,11 @@ def _matvec_smallrecs(self, x: NDArray) -> NDArray:
self.ns, self.nr, self.nt + 1
)
for i, shift_int in enumerate(self.shifts):
- blended_data[:, shift_int : shift_int + self.nt + 1] += shifted_data[i]
+ blended_data = inplace_add(
+ shifted_data[i],
+ blended_data,
+ (slice(None, None), slice(shift_int, shift_int + self.nt + 1)),
+ )
return blended_data
@reshaped
@@ -151,7 +150,11 @@ def _rmatvec_smallrecs(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
shifted_data = ncp.zeros((self.ns, self.nr, self.nt + 1), dtype=self.dtype)
for i, shift_int in enumerate(self.shifts):
- shifted_data[i, :, :] = x[:, shift_int : shift_int + self.nt + 1]
+ shifted_data = inplace_set(
+ x[:, shift_int : shift_int + self.nt + 1],
+ shifted_data,
+ (i, slice(None, None), slice(None, None)),
+ )
deblended_data = self.PadOp._rmatvec(
self.ShiftOp._rmatvec(shifted_data.ravel())
).reshape(self.dims)
@@ -170,7 +173,11 @@ def _matvec_largerecs(self, x: NDArray) -> NDArray:
.matvec(self.PadOp.matvec(x[i, :, :].ravel()))
.reshape(self.ShiftOps[i].dimsd)
)
- blended_data[:, shift_int : shift_int + self.nt + 1] += shifted_data
+ blended_data = inplace_add(
+ shifted_data,
+ blended_data,
+ (slice(None, None), slice(shift_int, shift_int + self.nt + 1)),
+ )
return blended_data
@reshaped
@@ -186,7 +193,11 @@ def _rmatvec_largerecs(self, x: NDArray) -> NDArray:
x[:, shift_int : shift_int + self.nt + 1].ravel()
)
).reshape(self.PadOp.dims)
- deblended_data[i, :, :] = shifted_data
+ deblended_data = inplace_set(
+ shifted_data,
+ deblended_data,
+ (i, slice(None, None), slice(None, None)),
+ )
return deblended_data
def _register_multiplications(self) -> None:
diff --git a/pylops/waveeqprocessing/kirchhoff.py b/pylops/waveeqprocessing/kirchhoff.py
index 06c29251..bdd5be87 100644
--- a/pylops/waveeqprocessing/kirchhoff.py
+++ b/pylops/waveeqprocessing/kirchhoff.py
@@ -288,8 +288,11 @@ def __init__(
)
self.rix = np.tile((recs[0] - x[0]) // dx, (ns, 1)).astype(int).ravel()
elif self.ndims == 3:
- # TODO: 3D normalized distances
- raise NotImplementedError("dynamic=True currently not available in 3D")
+ # TODO: compute 3D indices for aperture filter
+ # currently no aperture filter in 3D... just make indices 0
+ # so check if always passed
+ self.six = np.zeros(nr * ns)
+ self.rix = np.zeros(nr * ns)
# compute traveltime and distances
self.travsrcrec = True # use separate tables for src and rec traveltimes
@@ -362,8 +365,26 @@ def __init__(
trav_recs_grad[0], trav_recs_grad[1]
).reshape(np.prod(dims), nr)
else:
- # TODO: 3D
- raise NotImplementedError("dynamic=True currently not available in 3D")
+ trav_srcs_grad = np.concatenate(
+ [trav_srcs_grad[i][np.newaxis] for i in range(3)]
+ )
+ trav_recs_grad = np.concatenate(
+ [trav_recs_grad[i][np.newaxis] for i in range(3)]
+ )
+ self.angle_srcs = (
+ np.sign(trav_srcs_grad[1])
+ * np.arccos(
+ trav_srcs_grad[-1]
+ / np.sqrt(np.sum(trav_srcs_grad**2, axis=0))
+ )
+ ).reshape(np.prod(dims), ns)
+ self.angle_recs = (
+ np.sign(trav_srcs_grad[1])
+ * np.arccos(
+ trav_recs_grad[-1]
+ / np.sqrt(np.sum(trav_recs_grad**2, axis=0))
+ )
+ ).reshape(np.prod(dims), nr)
# pre-compute traveltime indices if total traveltime is used
if not self.travsrcrec:
@@ -386,6 +407,12 @@ def __init__(
# define aperture
# if aperture=None, we want to ensure the check is always matched (no aperture limits...)
+ # if aperture!=None in 3d, force to None as aperture checks are not yet implemented
+ if aperture is not None and self.ndims == 3:
+ aperture = None
+ warnings.warn(
+ "Aperture is forced to None as currently not implemented in 3D"
+ )
if aperture is not None:
warnings.warn(
"Aperture is currently defined as ratio of offset over depth, "
@@ -608,10 +635,10 @@ def _traveltime_table(
# compute traveltime gradients at image points
trav_srcs_grad = np.gradient(
- trav_srcs.reshape(*dims, ns), axis=np.arange(ndims)
+ trav_srcs.reshape(*dims, ns), *dsamp, axis=np.arange(ndims)
)
trav_recs_grad = np.gradient(
- trav_recs.reshape(*dims, nr), axis=np.arange(ndims)
+ trav_recs.reshape(*dims, nr), *dsamp, axis=np.arange(ndims)
)
return (
diff --git a/pylops/waveeqprocessing/wavedecomposition.py b/pylops/waveeqprocessing/wavedecomposition.py
index 7d926d36..715fb2c1 100644
--- a/pylops/waveeqprocessing/wavedecomposition.py
+++ b/pylops/waveeqprocessing/wavedecomposition.py
@@ -156,6 +156,7 @@ def _obliquity3D(
critical: float = 100.0,
ntaper: int = 10,
composition: bool = True,
+ fftengine: str = "scipy",
backend: str = "numpy",
dtype: DTypeLike = "complex128",
) -> Tuple[LinearOperator, LinearOperator]:
@@ -187,6 +188,9 @@ def _obliquity3D(
composition : :obj:`bool`, optional
Create obliquity factor for composition (``True``) or
decomposition (``False``)
+ fftengine : :obj:`str`, optional
+ Engine used for fft computation (``numpy`` or ``scipy``). Choose
+ ``numpy`` when working with cupy and jax arrays.
backend : :obj:`str`, optional
Backend used for creation of obliquity factor operator
(``numpy`` or ``cupy``)
@@ -203,7 +207,11 @@ def _obliquity3D(
"""
# create Fourier operator
FFTop = FFTND(
- dims=[nr[0], nr[1], nt], nffts=nffts, sampling=[dr[0], dr[1], dt], dtype=dtype
+ dims=[nr[0], nr[1], nt],
+ nffts=nffts,
+ sampling=[dr[0], dr[1], dt],
+ engine=fftengine,
+ dtype=dtype,
)
# create obliquity operator
@@ -547,6 +555,7 @@ def UpDownComposition3D(
critical: float = 100.0,
ntaper: int = 10,
scaling: float = 1.0,
+ fftengine: str = "scipy",
backend: str = "numpy",
dtype: DTypeLike = "complex128",
name: str = "U",
@@ -588,6 +597,11 @@ def UpDownComposition3D(
angle
scaling : :obj:`float`, optional
Scaling to apply to the operator (see Notes for more details)
+ fftengine : :obj:`str`, optional
+ .. versionadded:: 2.3.0
+
+ Engine used for fft computation (``numpy`` or ``scipy``). Choose
+ ``numpy`` when working with cupy and jax arrays.
backend : :obj:`str`, optional
Backend used for creation of obliquity factor operator
(``numpy`` or ``cupy``)
@@ -638,6 +652,7 @@ def UpDownComposition3D(
critical=critical,
ntaper=ntaper,
composition=True,
+ fftengine=fftengine,
backend=backend,
dtype=dtype,
)
diff --git a/pytests/test_jaxoperator.py b/pytests/test_jaxoperator.py
new file mode 100755
index 00000000..86de4e8d
--- /dev/null
+++ b/pytests/test_jaxoperator.py
@@ -0,0 +1,53 @@
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pytest
+from numpy.testing import assert_array_almost_equal, assert_array_equal
+
+from pylops import JaxOperator, MatrixMult
+
+par1 = {"ny": 11, "nx": 11, "dtype": np.float32} # square
+par2 = {"ny": 21, "nx": 11, "dtype": np.float32} # overdetermined
+
+np.random.seed(0)
+
+
+@pytest.mark.parametrize("par", [(par1)])
+def test_JaxOperator(par):
+ """Apply forward and adjoint and compare with native pylops."""
+ M = np.random.normal(0.0, 1.0, (par["ny"], par["nx"])).astype(par["dtype"])
+ Mop = MatrixMult(jnp.array(M), dtype=par["dtype"])
+ Jop = JaxOperator(Mop)
+
+ x = np.random.normal(0.0, 1.0, par["nx"]).astype(par["dtype"])
+ xjnp = jnp.array(x)
+
+ # pylops operator
+ y = Mop * x
+ xadj = Mop.H * y
+
+ # jax operator
+ yjnp = Jop * xjnp
+ xadjnp = Jop.rmatvecad(xjnp, yjnp)
+
+ assert_array_equal(y, np.array(yjnp))
+ assert_array_equal(xadj, np.array(xadjnp))
+
+
+@pytest.mark.parametrize("par", [(par1)])
+def test_TorchOperator_batch(par):
+ """Apply forward for input with multiple samples
+ (= batch) and flattened arrays"""
+
+ M = np.random.normal(0.0, 1.0, (par["ny"], par["nx"])).astype(par["dtype"])
+ Mop = MatrixMult(jnp.array(M), dtype=par["dtype"])
+ Jop = JaxOperator(Mop)
+ auto_batch_matvec = jax.vmap(Jop._matvec)
+
+ x = np.random.normal(0.0, 1.0, (4, par["nx"])).astype(par["dtype"])
+ xjnp = jnp.array(x)
+
+ y = Mop.matmat(x.T).T
+ yjnp = auto_batch_matvec(xjnp)
+
+ assert_array_almost_equal(y, np.array(yjnp), decimal=5)
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 703b377f..6ce1fb00 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -2,6 +2,7 @@ numpy>=1.21.0
scipy>=1.11.0
--extra-index-url https://download.pytorch.org/whl/cpu
torch>=1.2.0
+jax
numba
pyfftw
PyWavelets
@@ -19,6 +20,7 @@ docutils<0.18
Sphinx
pydata-sphinx-theme
sphinx-gallery
+sphinxemoji
numpydoc
nbsphinx
image
diff --git a/tutorials/ilsm.py b/tutorials/ilsm.py
index b4b016f3..1f394bfc 100755
--- a/tutorials/ilsm.py
+++ b/tutorials/ilsm.py
@@ -1,5 +1,5 @@
r"""
-20. Image Domain Least-squares migration
+19. Image Domain Least-squares migration
========================================
Seismic migration is the process by which seismic data are manipulated to create
an image of the subsurface reflectivity.
diff --git a/tutorials/jaxop.py b/tutorials/jaxop.py
new file mode 100755
index 00000000..c7a30d40
--- /dev/null
+++ b/tutorials/jaxop.py
@@ -0,0 +1,103 @@
+r"""
+21. JAX Operator
+================
+This tutorial is aimed at introducing the :class:`pylops.JaxOperator` operator. This
+represents the entry-point to the JAX backend of PyLops.
+
+More specifically, by wrapping any of PyLops' operators into a
+:class:`pylops.JaxOperator` one can:
+
+- apply forward, adjoint and use any of PyLops solver with JAX arrays;
+- enable automatic differentiation;
+- enable automatic vectorization.
+
+Moreover, both the forward and adjoint are internally just-in-time compiled
+to enable any further optimization provided by JAX.
+
+In this example we will consider a :class:`pylops.MatrixMult` operator and
+showcase how to use it in conjunction with :class:`pylops.JaxOperator`
+to enable the different JAX functionalities mentioned above.
+
+"""
+import jax
+import jax.numpy as jnp
+import matplotlib.pyplot as plt
+import numpy as np
+
+import pylops
+
+plt.close("all")
+np.random.seed(10)
+
+###############################################################################
+# Let's start by creating a :class:`pylops.MatrixMult` operator. We will then
+# perform the dot-test as well as apply the forward and adjoint operations to
+# JAX arrays.
+
+n = 4
+G = np.random.normal(0, 1, (n, n)).astype("float32")
+Gopjax = pylops.JaxOperator(pylops.MatrixMult(jnp.array(G), dtype="float32"))
+
+# dottest
+pylops.utils.dottest(Gopjax, n, n, backend="jax", verb=True, atol=1e-3)
+
+# forward
+xjnp = jnp.ones(n, dtype="float32")
+yjnp = Gopjax @ xjnp
+
+# adjoint
+xadjjnp = Gopjax.H @ yjnp
+
+###############################################################################
+# We can now use one of PyLops solvers to invert the operator
+
+xcgls = pylops.optimization.basic.cgls(
+ Gopjax, yjnp, x0=jnp.zeros(n), niter=100, tol=1e-10, show=True
+)[0]
+print("Inverse: ", xcgls)
+
+###############################################################################
+# Let's see how we can empower the automatic differentiation capabilities
+# of JAX to obtain the adjoint of our operator without having to implement it.
+# Although in PyLops the adjoint of any of operators is hand-written (and
+# optimized), it may be useful in some cases to quickly implement the forward
+# pass of a new operator and get the adjoint for free. This could be extremely
+# beneficial during the prototyping stage of an operator before embarking in
+# implementing an efficient hand-written adjoint.
+
+xadjjnpad = Gopjax.rmatvecad(xjnp, yjnp)
+
+print("Hand-written Adjoint: ", xadjjnp)
+print("AD Adjoint: ", xadjjnpad)
+
+###############################################################################
+# And more in general how we can combine any of JAX native operations with a
+# PyLops operator.
+
+
+def fun(x):
+ y = Gopjax(x)
+ loss = jnp.sum(y)
+ return loss
+
+
+xgrad = jax.grad(fun)(xjnp)
+print("Grad: ", xgrad)
+
+###############################################################################
+# We turn now our attention to automatic vectorization, which is very useful
+# if we want to apply the same operator to multiple vectors. In PyLops we can
+# easily do so by using the ``matmat`` and ``rmatmat`` methods, however under
+# the hood what these methods do is to simply run a for...loop and call the
+# corresponding ``matvec`` / ``rmatvec`` methods multiple times. On the other
+# hand, JAX is able to automatically add a batch axis at the beginning of
+# operator. Moreover, this can be seamlessly combined with `jax.jit` to
+# further improve performance.
+
+auto_batch_matvec = jax.jit(jax.vmap(Gopjax._matvec))
+xs = jnp.stack([xjnp, xjnp])
+ys = auto_batch_matvec(xs)
+
+print("Original output: ", yjnp)
+print("AV Output 1: ", ys[0])
+print("AV Output 1: ", ys[1])
diff --git a/tutorials/torchop.py b/tutorials/torchop.py
index c555573d..9e73d7b3 100755
--- a/tutorials/torchop.py
+++ b/tutorials/torchop.py
@@ -1,6 +1,6 @@
r"""
-19. Automatic Differentiation
-=============================
+20. Torch Operator
+==================
This tutorial focuses on the use of :class:`pylops.TorchOperator` to allow performing
Automatic Differentiation (AD) on chains of operators which can be: