diff --git a/docs/source/api/index.rst b/docs/source/api/index.rst index e86981fe..1d77dc15 100755 --- a/docs/source/api/index.rst +++ b/docs/source/api/index.rst @@ -29,6 +29,7 @@ Templates FunctionOperator MemoizeOperator TorchOperator + JaxOperator Basic operators ~~~~~~~~~~~~~~~ diff --git a/docs/source/conf.py b/docs/source/conf.py index caf745e5..c5e6536d 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -21,6 +21,7 @@ "numpydoc", "nbsphinx", "sphinx_gallery.gen_gallery", + "sphinxemoji.sphinxemoji", # 'sphinx.ext.napoleon', ] @@ -29,6 +30,8 @@ "python": ("https://docs.python.org/3/", None), "numpy": ("https://docs.scipy.org/doc/numpy/", None), "scipy": ("https://docs.scipy.org/doc/scipy/reference", None), + "cupy": ("https://docs.cupy.dev/en/stable/", None), + "jax": ("https://jax.readthedocs.io/en/latest", None), "sklearn": ("http://scikit-learn.org/stable/", None), "pandas": ("http://pandas.pydata.org/pandas-docs/stable/", None), "matplotlib": ("https://matplotlib.org/", None), diff --git a/docs/source/gpu.rst b/docs/source/gpu.rst index e861c567..7f74c373 100755 --- a/docs/source/gpu.rst +++ b/docs/source/gpu.rst @@ -1,55 +1,404 @@ .. _gpu: -GPU Support -=========== +GPU / TPU Support +================= Overview -------- -PyLops supports computations on GPUs powered by `CuPy `_ (``cupy-cudaXX>=v13.0.0``). +From ``v1.12.0``, PyLops supports computations on GPUs powered by +`CuPy `_ (``cupy-cudaXX>=8.1.0``). This library must be installed *before* PyLops is installed. -.. note:: - - Set environment variable ``CUPY_PYLOPS=0`` to force PyLops to ignore the ``cupy`` backend. - This can be also used if a previous (or faulty) version of ``cupy`` is installed in your system, - otherwise you will get an error when importing PyLops. +From ``v2.3.0``, PyLops supports also computations on GPUs/TPUs powered by +`JAX `_. +This library must be installed *before* PyLops is installed. +.. note:: + Set environment variables ``CUPY_PYLOPS=0`` and/or ``JAX_PYLOPS=0`` to force PyLops to ignore + ``cupy`` and ``jax`` backends. This can be also used if a previous version of ``cupy`` + or ``jax`` is installed in your system, otherwise you will get an error when importing PyLops. Apart from a few exceptions, all operators and solvers in PyLops can -seamlessly work with ``numpy`` arrays on CPU as well as with ``cupy`` arrays -on GPU. Users do simply need to consistently create operators and +seamlessly work with ``numpy`` arrays on CPU as well as with ``cupy/jax`` arrays +on GPU. For CuPy, users simply need to consistently create operators and provide data vectors to the solvers, e.g., when using :class:`pylops.MatrixMult` the input matrix must be a ``cupy`` array if the data provided to a solver is also ``cupy`` array. +For JAX, apart from following the same procedure described for CuPy, the PyLops operator must +be also wrapped into a :class:`pylops.JaxOperator`. -.. warning:: - Some :class:`pylops.LinearOperator` methods are currently on GPU: +In the following, we provide a list of methods in :class:`pylops.LinearOperator` with their current status (available on CPU, +GPU with CuPy, and GPU with JAX): - - :meth:`pylops.LinearOperator.eigs` - - :meth:`pylops.LinearOperator.cond` - - :meth:`pylops.LinearOperator.tosparse` - - :meth:`pylops.LinearOperator.estimate_spectral_norm` +.. list-table:: + :widths: 50 25 25 25 + :header-rows: 1 -.. warning:: + * - Operator/method + - CPU + - GPU with CuPy + - GPU/TPU with JAX + * - :meth:`pylops.LinearOperator.cond` + - |:white_check_mark:| + - |:red_circle:| + - |:red_circle:| + * - :meth:`pylops.LinearOperator.conj` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :meth:`pylops.LinearOperator.div` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :meth:`pylops.LinearOperator.eigs` + - |:white_check_mark:| + - |:red_circle:| + - |:red_circle:| + * - :meth:`pylops.LinearOperator.todense` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :meth:`pylops.LinearOperator.tosparse` + - |:white_check_mark:| + - |:red_circle:| + - |:red_circle:| + * - :meth:`pylops.LinearOperator.trace` + - |:white_check_mark:| + - |:red_circle:| + - |:red_circle:| + +Similarly, we provide a list of operators with their current status. + +Basic operators: + +.. list-table:: + :widths: 50 25 25 25 + :header-rows: 1 + + * - Operator/method + - CPU + - GPU with CuPy + - GPU/TPU with JAX + * - :class:`pylops.basicoperators.MatrixMult` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.basicoperators.Identity` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.basicoperators.Zero` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.basicoperators.Diagonal` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :meth:`pylops.basicoperators.Transpose` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.basicoperators.Flip` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.basicoperators.Roll` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.basicoperators.Pad` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.basicoperators.Sum` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.basicoperators.Symmetrize` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.basicoperators.Restriction` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.basicoperators.Regression` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.basicoperators.LinearRegression` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.basicoperators.CausalIntegration` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.basicoperators.Spread` + - |:white_check_mark:| + - |:red_circle:| + - |:red_circle:| + * - :class:`pylops.basicoperators.VStack` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.basicoperators.HStack` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.basicoperators.Block` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.basicoperators.BlockDiag` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + + +Smoothing and derivatives: + +.. list-table:: + :widths: 50 25 25 25 + :header-rows: 1 + + * - Operator/method + - CPU + - GPU with CuPy + - GPU/TPU with JAX + * - :class:`pylops.basicoperators.FirstDerivative` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.basicoperators.SecondDerivative` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.basicoperators.Laplacian` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.basicoperators.Gradient` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.basicoperators.FirstDirectionalDerivative` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.basicoperators.SecondDirectionalDerivative` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + +Signal processing: + +.. list-table:: + :widths: 50 25 25 25 + :header-rows: 1 + + * - Operator/method + - CPU + - GPU with CuPy + - GPU/TPU with JAX + * - :class:`pylops.signalprocessing.Convolve1D` + - |:white_check_mark:| + - |:white_check_mark:| + - |:warning:| + * - :class:`pylops.signalprocessing.Convolve2D` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.signalprocessing.ConvolveND` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.signalprocessing.NonStationaryConvolve1D` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.signalprocessing.NonStationaryFilters1D` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.signalprocessing.NonStationaryConvolve2D` + - |:white_check_mark:| + - |:white_check_mark:| + - |:red_circle:| + * - :class:`pylops.signalprocessing.NonStationaryFilters2D` + - |:white_check_mark:| + - |:white_check_mark:| + - |:red_circle:| + * - :class:`pylops.signalprocessing.Interp` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.signalprocessing.Bilinear` + - |:white_check_mark:| + - |:white_check_mark:| + - |:red_circle:| + * - :class:`pylops.signalprocessing.FFT` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.signalprocessing.FFT2D` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.signalprocessing.FFTND` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.signalprocessing.Shift` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.signalprocessing.DWT` + - |:white_check_mark:| + - |:red_circle:| + - |:red_circle:| + * - :class:`pylops.signalprocessing.DWT2D` + - |:white_check_mark:| + - |:red_circle:| + - |:red_circle:| + * - :class:`pylops.signalprocessing.DCT` + - |:white_check_mark:| + - |:red_circle:| + - |:red_circle:| + * - :class:`pylops.signalprocessing.Seislet` + - |:white_check_mark:| + - |:red_circle:| + - |:red_circle:| + * - :class:`pylops.signalprocessing.Radon2D` + - |:white_check_mark:| + - |:red_circle:| + - |:red_circle:| + * - :class:`pylops.signalprocessing.Radon3D` + - |:white_check_mark:| + - |:red_circle:| + - |:red_circle:| + * - :class:`pylops.signalprocessing.ChirpRadon2D` + - |:white_check_mark:| + - |:white_check_mark:| + - |:red_circle:| + * - :class:`pylops.signalprocessing.ChirpRadon3D` + - |:white_check_mark:| + - |:white_check_mark:| + - |:red_circle:| + * - :class:`pylops.signalprocessing.Sliding1D` + - |:white_check_mark:| + - |:white_check_mark:| + - |:red_circle:| + * - :class:`pylops.signalprocessing.Sliding2D` + - |:white_check_mark:| + - |:white_check_mark:| + - |:red_circle:| + * - :class:`pylops.signalprocessing.Sliding3D` + - |:white_check_mark:| + - |:white_check_mark:| + - |:red_circle:| + * - :class:`pylops.signalprocessing.Patch2D` + - |:white_check_mark:| + - |:white_check_mark:| + - |:red_circle:| + * - :class:`pylops.signalprocessing.Patch3D` + - |:white_check_mark:| + - |:white_check_mark:| + - |:red_circle:| + * - :class:`pylops.signalprocessing.Fredholm1` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| - Some operators are currently not available on GPU: +Wave-Equation processing - - :class:`pylops.Spread` - - :class:`pylops.signalprocessing.Radon2D` - - :class:`pylops.signalprocessing.Radon3D` - - :class:`pylops.signalprocessing.DWT` - - :class:`pylops.signalprocessing.DWT2D` - - :class:`pylops.signalprocessing.Seislet` - - :class:`pylops.waveeqprocessing.Demigration` - - :class:`pylops.waveeqprocessing.LSM` +.. list-table:: + :widths: 50 25 25 25 + :header-rows: 1 + + * - Operator/method + - CPU + - GPU with CuPy + - GPU/TPU with JAX + * - :class:`pylops.avo.avo.PressureToVelocity` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.avo.avo.UpDownComposition2D` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.avo.avo.UpDownComposition3D` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.avo.avo.BlendingContinuous` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.avo.avo.BlendingGroup` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.avo.avo.BlendingHalf` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.avo.avo.MDC` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.avo.avo.Kirchhoff` + - |:white_check_mark:| + - |:red_circle:| + - |:red_circle:| + * - :class:`pylops.avo.avo.AcousticWave2D` + - |:white_check_mark:| + - |:red_circle:| + - |:red_circle:| + +Geophysical subsurface characterization: + +.. list-table:: + :widths: 50 25 25 25 + :header-rows: 1 + + * - Operator/method + - CPU + - GPU with CuPy + - GPU/TPU with JAX + * - :class:`pylops.avo.avo.AVOLinearModelling` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.avo.poststack.PoststackLinearModelling` + - |:white_check_mark:| + - |:white_check_mark:| + - |:white_check_mark:| + * - :class:`pylops.avo.prestack.PrestackLinearModelling` + - |:white_check_mark:| + - |:white_check_mark:| + - |:warning:| + * - :class:`pylops.avo.prestack.PrestackWaveletModelling` + - |:white_check_mark:| + - |:white_check_mark:| + - |:warning:| .. warning:: - Some solvers are currently not available on GPU: - - :class:`pylops.optimization.sparsity.SPGL1` + 1. The JAX backend of the :class:`pylops.signalprocessing.Convolve1D` operator + currently works only with 1d-arrays due to a different behaviour of + :meth:`scipy.signal.convolve` and :meth:`jax.scipy.signal.convolve` with + nd-arrays. + + 2. The JAX backend of the :class:`pylops.avo.prestack.PrestackLinearModelling` + operator currently works only with ``explicit=True`` due to the same issue as + in point 1 for the :class:`pylops.signalprocessing.Convolve1D` operator employed + when ``explicit=False``. Example @@ -68,8 +417,7 @@ Finally, let's briefly look at an example. First we write a code snippet using y = Gop * x xest = Gop / y - -Now we write a code snippet using ``cupy`` arrays which PyLops will run on +Now we write a code snippet using ``cupy`` arrays which PyLops will run on your GPU: .. code-block:: python @@ -83,9 +431,28 @@ your GPU: xest = Gop / y The code is almost unchanged apart from the fact that we now use ``cupy`` arrays, -PyLops will figure this out! +PyLops will figure this out. + +Similarly, we write a code snippet using ``jax`` arrays which PyLops will run on +your GPU/TPU: + +.. code-block:: python + + ny, nx = 400, 400 + G = jnp.array(np.random.normal(0, 1, (ny, nx)).astype(np.float32)) + x = jnp.ones(nx, dtype=np.float32) + + Gop = JaxOperator(MatrixMult(G, dtype='float32')) + y = Gop * x + xest = Gop / y + + # Adjoint via AD + xadj = Gop.rmatvecad(x, y) + + +Again, the code is almost unchanged apart from the fact that we now use ``jax`` arrays, .. note:: - The CuPy backend is in active development, with many examples not yet in the docs. - You can find many `other examples `_ from the `PyLops Notebooks repository `_. + More examples for the CuPy and JAX backends be found `here `_ + and `here `_. \ No newline at end of file diff --git a/environment-dev-arm.yml b/environment-dev-arm.yml index 413a9759..0081a0ad 100755 --- a/environment-dev-arm.yml +++ b/environment-dev-arm.yml @@ -11,6 +11,7 @@ dependencies: - scipy>=1.11.0 - pytorch>=1.2.0 - cpuonly + - jax - pyfftw - pywavelets - sympy @@ -34,6 +35,7 @@ dependencies: - pydata-sphinx-theme - sphinx-gallery - nbsphinx + - sphinxemoji - image - flake8 - mypy diff --git a/environment-dev.yml b/environment-dev.yml index ef51f696..eb51c4dc 100755 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -11,6 +11,7 @@ dependencies: - scipy>=1.11.0 - pytorch>=1.2.0 - cpuonly + - jax - pyfftw - pywavelets - sympy @@ -35,6 +36,7 @@ dependencies: - pydata-sphinx-theme - sphinx-gallery - nbsphinx + - sphinxemoji - image - flake8 - mypy diff --git a/pylops/__init__.py b/pylops/__init__.py index 55d4ce3d..7672fda4 100755 --- a/pylops/__init__.py +++ b/pylops/__init__.py @@ -48,6 +48,7 @@ from .config import * from .linearoperator import * from .torchoperator import * +from .jaxoperator import * from .basicoperators import * from . import ( avo, diff --git a/pylops/avo/poststack.py b/pylops/avo/poststack.py index 7e514707..8a9001ac 100644 --- a/pylops/avo/poststack.py +++ b/pylops/avo/poststack.py @@ -27,6 +27,7 @@ get_csc_matrix, get_lstsq, get_module_name, + inplace_set, ) from pylops.utils.signalprocessing import convmtx, nonstationary_convmtx from pylops.utils.typing import NDArray, ShapeLike @@ -93,12 +94,13 @@ def _PoststackLinearModelling( D = ncp.diag(0.5 * ncp.ones(nt0 - 1, dtype=dtype), k=1) - ncp.diag( 0.5 * ncp.ones(nt0 - 1, dtype=dtype), -1 ) - D[0] = D[-1] = 0 + D = inplace_set(ncp.array(0.0), D, 0) + D = inplace_set(ncp.array(0.0), D, -1) else: D = ncp.diag(ncp.ones(nt0 - 1, dtype=dtype), k=1) - ncp.diag( ncp.ones(nt0, dtype=dtype), k=0 ) - D[-1] = 0 + D = inplace_set(ncp.array(0.0), D, -1) # Create wavelet operator if len(wav.shape) == 1: diff --git a/pylops/avo/prestack.py b/pylops/avo/prestack.py index 8630bc9a..4cf6c4eb 100644 --- a/pylops/avo/prestack.py +++ b/pylops/avo/prestack.py @@ -31,6 +31,7 @@ get_block_diag, get_lstsq, get_module_name, + inplace_set, ) from pylops.utils.signalprocessing import convmtx from pylops.utils.typing import NDArray, ShapeLike @@ -182,12 +183,13 @@ def PrestackLinearModelling( D = ncp.diag(0.5 * ncp.ones(nt0 - 1, dtype=dtype), k=1) - ncp.diag( 0.5 * ncp.ones(nt0 - 1, dtype=dtype), k=-1 ) - D[0] = D[-1] = 0 + D = inplace_set(ncp.array(0.0), D, 0) + D = inplace_set(ncp.array(0.0), D, -1) else: D = ncp.diag(ncp.ones(nt0 - 1, dtype=dtype), k=1) - ncp.diag( ncp.ones(nt0, dtype=dtype), k=0 ) - D[-1] = 0 + D = inplace_set(ncp.array(0.0), D, -1) D = get_block_diag(theta)(*([D] * nG)) # Create wavelet operator @@ -339,7 +341,8 @@ def PrestackWaveletModelling( D = ncp.diag(0.5 * np.ones(nt0 - 1, dtype=dtype), k=1) - ncp.diag( 0.5 * np.ones(nt0 - 1, dtype=dtype), k=-1 ) - D[0] = D[-1] = 0 + D = inplace_set(ncp.array(0.0), D, 0) + D = inplace_set(ncp.array(0.0), D, -1) D = get_block_diag(theta)(*([D] * nG)) # Create infinite-reflectivity data diff --git a/pylops/basicoperators/blockdiag.py b/pylops/basicoperators/blockdiag.py index e13ed026..166ae137 100644 --- a/pylops/basicoperators/blockdiag.py +++ b/pylops/basicoperators/blockdiag.py @@ -21,7 +21,7 @@ from pylops import LinearOperator from pylops.basicoperators import MatrixMult -from pylops.utils.backend import get_array_module +from pylops.utils.backend import get_array_module, inplace_set from pylops.utils.typing import DTypeLike, NDArray @@ -175,18 +175,22 @@ def _matvec_serial(self, x: NDArray) -> NDArray: ncp = get_array_module(x) y = ncp.zeros(self.nops, dtype=self.dtype) for iop, oper in enumerate(self.ops): - y[self.nnops[iop] : self.nnops[iop + 1]] = oper.matvec( - x[self.mmops[iop] : self.mmops[iop + 1]] - ).squeeze() + y = inplace_set( + oper.matvec(x[self.mmops[iop] : self.mmops[iop + 1]]).squeeze(), + y, + slice(self.nnops[iop], self.nnops[iop + 1]), + ) return y def _rmatvec_serial(self, x: NDArray) -> NDArray: ncp = get_array_module(x) y = ncp.zeros(self.mops, dtype=self.dtype) for iop, oper in enumerate(self.ops): - y[self.mmops[iop] : self.mmops[iop + 1]] = oper.rmatvec( - x[self.nnops[iop] : self.nnops[iop + 1]] - ).squeeze() + y = inplace_set( + oper.rmatvec(x[self.nnops[iop] : self.nnops[iop + 1]]).squeeze(), + y, + slice(self.mmops[iop], self.mmops[iop + 1]), + ) return y def _matvec_multiproc(self, x: NDArray) -> NDArray: diff --git a/pylops/basicoperators/firstderivative.py b/pylops/basicoperators/firstderivative.py index 58edf17f..f8bd208e 100644 --- a/pylops/basicoperators/firstderivative.py +++ b/pylops/basicoperators/firstderivative.py @@ -7,7 +7,7 @@ from pylops import LinearOperator from pylops.utils._internal import _value_or_sized_to_tuple -from pylops.utils.backend import get_array_module +from pylops.utils.backend import get_array_module, inplace_add, inplace_set from pylops.utils.decorators import reshaped from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray @@ -100,6 +100,16 @@ def __init__( self.kind = kind self.edge = edge self.order = order + self.slice = { + i: { + j: tuple([slice(None, None)] * (len(dims) - 1) + [slice(i, j)]) + for j in (None, -1, -2, -3, -4) + } + for i in (None, 1, 2, 3, 4) + } + self.sample = { + i: tuple([slice(None, None)] * (len(dims) - 1) + [i]) for i in range(-3, 4) + } self._register_multiplications(self.kind, self.order) def _register_multiplications( @@ -140,15 +150,20 @@ def _rmatvec(self, x: NDArray) -> NDArray: def _matvec_forward(self, x: NDArray) -> NDArray: ncp = get_array_module(x) y = ncp.zeros(x.shape, self.dtype) - y[..., :-1] = (x[..., 1:] - x[..., :-1]) / self.sampling + # y[..., :-1] = (x[..., 1:] - x[..., :-1]) / self.sampling + y = inplace_set( + (x[..., 1:] - x[..., :-1]) / self.sampling, y, self.slice[None][-1] + ) return y @reshaped(swapaxis=True) def _rmatvec_forward(self, x: NDArray) -> NDArray: ncp = get_array_module(x) y = ncp.zeros(x.shape, self.dtype) - y[..., :-1] -= x[..., :-1] - y[..., 1:] += x[..., :-1] + # y[..., :-1] -= x[..., :-1] + y = inplace_add(-x[..., :-1], y, self.slice[None][-1]) + # y[..., 1:] += x[..., :-1] + y = inplace_add(x[..., :-1], y, self.slice[1][None]) y /= self.sampling return y @@ -156,10 +171,13 @@ def _rmatvec_forward(self, x: NDArray) -> NDArray: def _matvec_centered3(self, x: NDArray) -> NDArray: ncp = get_array_module(x) y = ncp.zeros(x.shape, self.dtype) - y[..., 1:-1] = 0.5 * (x[..., 2:] - x[..., :-2]) + # y[..., 1:-1] = 0.5 * (x[..., 2:] - x[..., :-2]) + y = inplace_set(0.5 * (x[..., 2:] - x[..., :-2]), y, self.slice[1][-1]) if self.edge: - y[..., 0] = x[..., 1] - x[..., 0] - y[..., -1] = x[..., -1] - x[..., -2] + # y[..., 0] = x[..., 1] - x[..., 0] + y = inplace_set(x[..., 1] - x[..., 0], y, self.sample[0]) + # y[..., -1] = x[..., -1] - x[..., -2] + y = inplace_set(x[..., -1] - x[..., -2], y, self.sample[-1]) y /= self.sampling return y @@ -167,13 +185,19 @@ def _matvec_centered3(self, x: NDArray) -> NDArray: def _rmatvec_centered3(self, x: NDArray) -> NDArray: ncp = get_array_module(x) y = ncp.zeros(x.shape, self.dtype) - y[..., :-2] -= 0.5 * x[..., 1:-1] - y[..., 2:] += 0.5 * x[..., 1:-1] + # y[..., :-2] -= 0.5 * x[..., 1:-1] + y = inplace_add(-0.5 * x[..., 1:-1], y, self.slice[None][-2]) + # y[..., 2:] += 0.5 * x[..., 1:-1] + y = inplace_add(0.5 * x[..., 1:-1], y, self.slice[2][None]) if self.edge: - y[..., 0] -= x[..., 0] - y[..., 1] += x[..., 0] - y[..., -2] -= x[..., -1] - y[..., -1] += x[..., -1] + # y[..., 0] -= x[..., 0] + y = inplace_add(-x[..., 0], y, self.sample[0]) + # y[..., 1] += x[..., 0] + y = inplace_add(x[..., 0], y, self.sample[1]) + # y[..., -2] -= x[..., -1] + y = inplace_add(-x[..., -1], y, self.sample[-2]) + # y[..., -1] += x[..., -1] + y = inplace_add(x[..., -1], y, self.sample[-1]) y /= self.sampling return y @@ -181,17 +205,31 @@ def _rmatvec_centered3(self, x: NDArray) -> NDArray: def _matvec_centered5(self, x: NDArray) -> NDArray: ncp = get_array_module(x) y = ncp.zeros(x.shape, self.dtype) - y[..., 2:-2] = ( - x[..., :-4] / 12.0 - - 2 * x[..., 1:-3] / 3.0 - + 2 * x[..., 3:-1] / 3.0 - - x[..., 4:] / 12.0 + # y[..., 2:-2] = ( + # x[..., :-4] / 12.0 + # - 2 * x[..., 1:-3] / 3.0 + # + 2 * x[..., 3:-1] / 3.0 + # - x[..., 4:] / 12.0 + # ) + y = inplace_set( + ( + x[..., :-4] / 12.0 + - 2 * x[..., 1:-3] / 3.0 + + 2 * x[..., 3:-1] / 3.0 + - x[..., 4:] / 12.0 + ), + y, + self.slice[2][-2], ) if self.edge: - y[..., 0] = x[..., 1] - x[..., 0] - y[..., 1] = 0.5 * (x[..., 2] - x[..., 0]) - y[..., -2] = 0.5 * (x[..., -1] - x[..., -3]) - y[..., -1] = x[..., -1] - x[..., -2] + # y[..., 0] = x[..., 1] - x[..., 0] + y = inplace_set(x[..., 1] - x[..., 0], y, self.sample[0]) + # y[..., 1] = 0.5 * (x[..., 2] - x[..., 0]) + y = inplace_set(0.5 * (x[..., 2] - x[..., 0]), y, self.sample[1]) + # y[..., -2] = 0.5 * (x[..., -1] - x[..., -3]) + y = inplace_set(0.5 * (x[..., -1] - x[..., -3]), y, self.sample[-2]) + # y[..., -1] = x[..., -1] - x[..., -2] + y = inplace_set(x[..., -1] - x[..., -2], y, self.sample[-1]) y /= self.sampling return y @@ -199,17 +237,27 @@ def _matvec_centered5(self, x: NDArray) -> NDArray: def _rmatvec_centered5(self, x: NDArray) -> NDArray: ncp = get_array_module(x) y = ncp.zeros(x.shape, self.dtype) - y[..., :-4] += x[..., 2:-2] / 12.0 - y[..., 1:-3] -= 2.0 * x[..., 2:-2] / 3.0 - y[..., 3:-1] += 2.0 * x[..., 2:-2] / 3.0 - y[..., 4:] -= x[..., 2:-2] / 12.0 + # y[..., :-4] += x[..., 2:-2] / 12.0 + y = inplace_add(x[..., 2:-2] / 12.0, y, self.slice[None][-4]) + # y[..., 1:-3] -= 2.0 * x[..., 2:-2] / 3.0 + y = inplace_add(-2.0 * x[..., 2:-2] / 3.0, y, self.slice[1][-3]) + # y[..., 3:-1] += 2.0 * x[..., 2:-2] / 3.0 + y = inplace_add(2.0 * x[..., 2:-2] / 3.0, y, self.slice[3][-1]) + # y[..., 4:] -= x[..., 2:-2] / 12.0 + y = inplace_add(-x[..., 2:-2] / 12.0, y, self.slice[4][None]) if self.edge: - y[..., 0] -= x[..., 0] + 0.5 * x[..., 1] - y[..., 1] += x[..., 0] - y[..., 2] += 0.5 * x[..., 1] - y[..., -3] -= 0.5 * x[..., -2] - y[..., -2] -= x[..., -1] - y[..., -1] += 0.5 * x[..., -2] + x[..., -1] + # y[..., 0] -= x[..., 0] + 0.5 * x[..., 1] + y = inplace_add(-(x[..., 0] + 0.5 * x[..., 1]), y, self.sample[0]) + # y[..., 1] += x[..., 0] + y = inplace_add(x[..., 0], y, self.sample[1]) + # y[..., 2] += 0.5 * x[..., 1] + y = inplace_add(0.5 * x[..., 1], y, self.sample[2]) + # y[..., -3] -= 0.5 * x[..., -2] + y = inplace_add(-0.5 * x[..., -2], y, self.sample[-3]) + # y[..., -2] -= x[..., -1] + y = inplace_add(-x[..., -1], y, self.sample[-2]) + # y[..., -1] += 0.5 * x[..., -2] + x[..., -1] + y = inplace_add(0.5 * x[..., -2] + x[..., -1], y, self.sample[-1]) y /= self.sampling return y @@ -217,14 +265,19 @@ def _rmatvec_centered5(self, x: NDArray) -> NDArray: def _matvec_backward(self, x: NDArray) -> NDArray: ncp = get_array_module(x) y = ncp.zeros(x.shape, self.dtype) - y[..., 1:] = (x[..., 1:] - x[..., :-1]) / self.sampling + # y[..., 1:] = (x[..., 1:] - x[..., :-1]) / self.sampling + y = inplace_set( + (x[..., 1:] - x[..., :-1]) / self.sampling, y, self.slice[1][None] + ) return y @reshaped(swapaxis=True) def _rmatvec_backward(self, x: NDArray) -> NDArray: ncp = get_array_module(x) y = ncp.zeros(x.shape, self.dtype) - y[..., :-1] -= x[..., 1:] - y[..., 1:] += x[..., 1:] + # y[..., :-1] -= x[..., 1:] + y = inplace_add(-x[..., 1:], y, self.slice[None][-1]) + # y[..., 1:] += x[..., 1:] + y = inplace_add(x[..., 1:], y, self.slice[1][None]) y /= self.sampling return y diff --git a/pylops/basicoperators/hstack.py b/pylops/basicoperators/hstack.py index 5cfbbec0..b71e8723 100644 --- a/pylops/basicoperators/hstack.py +++ b/pylops/basicoperators/hstack.py @@ -21,7 +21,7 @@ from pylops import LinearOperator from pylops.basicoperators import MatrixMult -from pylops.utils.backend import get_array_module +from pylops.utils.backend import get_array_module, inplace_add, inplace_set from pylops.utils.typing import NDArray @@ -165,14 +165,22 @@ def _matvec_serial(self, x: NDArray) -> NDArray: ncp = get_array_module(x) y = ncp.zeros(self.nops, dtype=self.dtype) for iop, oper in enumerate(self.ops): - y += oper.matvec(x[self.mmops[iop] : self.mmops[iop + 1]]).squeeze() + y = inplace_add( + oper.matvec(x[self.mmops[iop] : self.mmops[iop + 1]]).squeeze(), + y, + slice(None, None), + ) return y def _rmatvec_serial(self, x: NDArray) -> NDArray: ncp = get_array_module(x) y = ncp.zeros(self.mops, dtype=self.dtype) for iop, oper in enumerate(self.ops): - y[self.mmops[iop] : self.mmops[iop + 1]] = oper.rmatvec(x).squeeze() + y = inplace_set( + oper.rmatvec(x).squeeze(), + y, + slice(self.mmops[iop], self.mmops[iop + 1]), + ) return y def _matvec_multiproc(self, x: NDArray) -> NDArray: diff --git a/pylops/basicoperators/identity.py b/pylops/basicoperators/identity.py index c2d05a30..50b76831 100644 --- a/pylops/basicoperators/identity.py +++ b/pylops/basicoperators/identity.py @@ -6,7 +6,7 @@ import numpy as np from pylops import LinearOperator -from pylops.utils.backend import get_array_module +from pylops.utils.backend import get_array_module, inplace_set from pylops.utils.decorators import reshaped from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray @@ -181,7 +181,7 @@ def _matvec(self, x: NDArray) -> NDArray: y = x[self.sliceN] else: y = ncp.zeros(self.dimsd, dtype=self.dtype) - y[self.sliceM] = x + y = inplace_set(x, y, self.sliceM) return y @reshaped @@ -193,7 +193,7 @@ def _rmatvec(self, x: NDArray) -> NDArray: y = x elif self.mode == "model": y = ncp.zeros(self.dims, dtype=self.dtype) - y[self.sliceN] = x + y = inplace_set(x, y, self.sliceN) else: y = x[self.sliceM] return y diff --git a/pylops/basicoperators/pad.py b/pylops/basicoperators/pad.py index 45b63af8..d98a894c 100644 --- a/pylops/basicoperators/pad.py +++ b/pylops/basicoperators/pad.py @@ -6,6 +6,7 @@ from pylops import LinearOperator from pylops.utils._internal import _value_or_sized_to_tuple +from pylops.utils.backend import get_array_module from pylops.utils.decorators import reshaped from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray @@ -85,10 +86,12 @@ def __init__( @reshaped def _matvec(self, x: NDArray) -> NDArray: - return np.pad(x, self.pad, mode="constant") + ncp = get_array_module(x) + return ncp.pad(x, self.pad, mode="constant") @reshaped def _rmatvec(self, x: NDArray) -> NDArray: + ncp = get_array_module(x) for ax, (before, _) in enumerate(self.pad): - x = np.take(x, np.arange(before, before + self.dims[ax]), axis=ax) + x = ncp.take(x, ncp.arange(before, before + self.dims[ax]), axis=ax) return x diff --git a/pylops/basicoperators/restriction.py b/pylops/basicoperators/restriction.py index fc81a252..1a745b30 100644 --- a/pylops/basicoperators/restriction.py +++ b/pylops/basicoperators/restriction.py @@ -16,7 +16,7 @@ from pylops import LinearOperator from pylops.utils._internal import _value_or_sized_to_tuple -from pylops.utils.backend import get_array_module, to_cupy_conditional +from pylops.utils.backend import get_array_module, inplace_set, to_cupy_conditional from pylops.utils.typing import DTypeLike, InputDimsLike, IntNDArray, NDArray logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.WARNING) @@ -26,13 +26,13 @@ def _compute_iavamask(dims, axis, iava, ncp): """Compute restriction mask when using cupy arrays""" otherdims = np.array(dims) otherdims = np.delete(otherdims, axis) - iavamask = ncp.zeros(int(dims[axis]), dtype=int) + iavamask = np.zeros(int(dims[axis]), dtype=int) iavamask[iava] = 1 - iavamask = ncp.moveaxis( - ncp.broadcast_to(iavamask, list(otherdims) + [dims[axis]]), -1, axis + iavamask = np.moveaxis( + np.broadcast_to(iavamask, list(otherdims) + [dims[axis]]), -1, axis ) - iavamask = ncp.where(iavamask.ravel() == 1)[0] - return iavamask + iavamask = np.where(iavamask.ravel() == 1)[0] + return ncp.asarray(iavamask) class Restriction(LinearOperator): @@ -179,7 +179,7 @@ def _rmatvec(self, x: NDArray) -> NDArray: self.iava = to_cupy_conditional(x, self.iava) self.iavamask = _compute_iavamask(self.dims, self.axis, self.iava, ncp) y = ncp.zeros(int(self.shape[-1]), dtype=self.dtype) - y[self.iavamask] = x.ravel() + y = inplace_set(x.ravel(), y, self.iavamask) y = y.ravel() return y diff --git a/pylops/basicoperators/roll.py b/pylops/basicoperators/roll.py index 8fc27e4d..29e6f613 100644 --- a/pylops/basicoperators/roll.py +++ b/pylops/basicoperators/roll.py @@ -6,6 +6,7 @@ from pylops import LinearOperator from pylops.utils._internal import _value_or_sized_to_tuple +from pylops.utils.backend import get_array_module from pylops.utils.decorators import reshaped from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray @@ -64,8 +65,10 @@ def __init__( @reshaped(swapaxis=True) def _matvec(self, x: NDArray) -> NDArray: - return np.roll(x, shift=self.shift, axis=-1) + ncp = get_array_module(x) + return ncp.roll(x, shift=self.shift, axis=-1) @reshaped(swapaxis=True) def _rmatvec(self, x: NDArray) -> NDArray: - return np.roll(x, shift=-self.shift, axis=-1) + ncp = get_array_module(x) + return ncp.roll(x, shift=-self.shift, axis=-1) diff --git a/pylops/basicoperators/secondderivative.py b/pylops/basicoperators/secondderivative.py index 744d067a..8433987d 100644 --- a/pylops/basicoperators/secondderivative.py +++ b/pylops/basicoperators/secondderivative.py @@ -7,7 +7,7 @@ from pylops import LinearOperator from pylops.utils._internal import _value_or_sized_to_tuple -from pylops.utils.backend import get_array_module +from pylops.utils.backend import get_array_module, inplace_add, inplace_set from pylops.utils.decorators import reshaped from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray @@ -90,6 +90,16 @@ def __init__( self.sampling = sampling self.kind = kind self.edge = edge + self.slice = { + i: { + j: tuple([slice(None, None)] * (len(dims) - 1) + [slice(i, j)]) + for j in (None, -1, -2, -3, -4) + } + for i in (None, 1, 2, 3, 4) + } + self.sample = { + i: tuple([slice(None, None)] * (len(dims) - 1) + [i]) for i in range(-3, 4) + } self._register_multiplications(self.kind) def _register_multiplications( @@ -123,7 +133,10 @@ def _rmatvec(self, x: NDArray) -> NDArray: def _matvec_forward(self, x: NDArray) -> NDArray: ncp = get_array_module(x) y = ncp.zeros(x.shape, self.dtype) - y[..., :-2] = x[..., 2:] - 2 * x[..., 1:-1] + x[..., :-2] + # y[..., :-2] = x[..., 2:] - 2 * x[..., 1:-1] + x[..., :-2] + y = inplace_set( + x[..., 2:] - 2 * x[..., 1:-1] + x[..., :-2], y, self.slice[None][-2] + ) y /= self.sampling**2 return y @@ -131,9 +144,12 @@ def _matvec_forward(self, x: NDArray) -> NDArray: def _rmatvec_forward(self, x: NDArray) -> NDArray: ncp = get_array_module(x) y = ncp.zeros(x.shape, self.dtype) - y[..., :-2] += x[..., :-2] - y[..., 1:-1] -= 2 * x[..., :-2] - y[..., 2:] += x[..., :-2] + # y[..., :-2] += x[..., :-2] + y = inplace_add(x[..., :-2], y, self.slice[None][-2]) + # y[..., 1:-1] -= 2 * x[..., :-2] + y = inplace_add(-2 * x[..., :-2], y, self.slice[1][-1]) + # y[..., 2:] += x[..., :-2] + y = inplace_add(x[..., :-2], y, self.slice[2][None]) y /= self.sampling**2 return y @@ -141,10 +157,17 @@ def _rmatvec_forward(self, x: NDArray) -> NDArray: def _matvec_centered(self, x: NDArray) -> NDArray: ncp = get_array_module(x) y = ncp.zeros(x.shape, self.dtype) - y[..., 1:-1] = x[..., 2:] - 2 * x[..., 1:-1] + x[..., :-2] + # y[..., 1:-1] = x[..., 2:] - 2 * x[..., 1:-1] + x[..., :-2] + y = inplace_set( + x[..., 2:] - 2 * x[..., 1:-1] + x[..., :-2], y, self.slice[1][-1] + ) if self.edge: - y[..., 0] = x[..., 0] - 2 * x[..., 1] + x[..., 2] - y[..., -1] = x[..., -3] - 2 * x[..., -2] + x[..., -1] + # y[..., 0] = x[..., 0] - 2 * x[..., 1] + x[..., 2] + y = inplace_set(x[..., 0] - 2 * x[..., 1] + x[..., 2], y, self.sample[0]) + # y[..., -1] = x[..., -3] - 2 * x[..., -2] + x[..., -1] + y = inplace_set( + x[..., -3] - 2 * x[..., -2] + x[..., -1], y, self.sample[-1] + ) y /= self.sampling**2 return y @@ -152,16 +175,25 @@ def _matvec_centered(self, x: NDArray) -> NDArray: def _rmatvec_centered(self, x: NDArray) -> NDArray: ncp = get_array_module(x) y = ncp.zeros(x.shape, self.dtype) - y[..., :-2] += x[..., 1:-1] - y[..., 1:-1] -= 2 * x[..., 1:-1] - y[..., 2:] += x[..., 1:-1] + # y[..., :-2] += x[..., 1:-1] + y = inplace_add(x[..., 1:-1], y, self.slice[None][-2]) + # y[..., 1:-1] -= 2 * x[..., 1:-1] + y = inplace_add(-2 * x[..., 1:-1], y, self.slice[1][-1]) + # y[..., 2:] += x[..., 1:-1] + y = inplace_add(x[..., 1:-1], y, self.slice[2][None]) if self.edge: - y[..., 0] += x[..., 0] - y[..., 1] -= 2 * x[..., 0] - y[..., 2] += x[..., 0] - y[..., -3] += x[..., -1] - y[..., -2] -= 2 * x[..., -1] - y[..., -1] += x[..., -1] + # y[..., 0] += x[..., 0] + y = inplace_add(x[..., 0], y, self.sample[0]) + # y[..., 1] -= 2 * x[..., 0] + y = inplace_add(-2 * x[..., 0], y, self.sample[1]) + # y[..., 2] += x[..., 0] + y = inplace_add(x[..., 0], y, self.sample[2]) + # y[..., -3] += x[..., -1] + y = inplace_add(x[..., -1], y, self.sample[-3]) + # y[..., -2] -= 2 * x[..., -1] + y = inplace_add(-2 * x[..., -1], y, self.sample[-2]) + # y[..., -1] += x[..., -1] + y = inplace_add(x[..., -1], y, self.sample[-1]) y /= self.sampling**2 return y @@ -169,7 +201,10 @@ def _rmatvec_centered(self, x: NDArray) -> NDArray: def _matvec_backward(self, x: NDArray) -> NDArray: ncp = get_array_module(x) y = ncp.zeros(x.shape, self.dtype) - y[..., 2:] = x[..., 2:] - 2 * x[..., 1:-1] + x[..., :-2] + # y[..., 2:] = x[..., 2:] - 2 * x[..., 1:-1] + x[..., :-2] + y = inplace_set( + x[..., 2:] - 2 * x[..., 1:-1] + x[..., :-2], y, self.slice[2][None] + ) y /= self.sampling**2 return y @@ -177,8 +212,11 @@ def _matvec_backward(self, x: NDArray) -> NDArray: def _rmatvec_backward(self, x: NDArray) -> NDArray: ncp = get_array_module(x) y = ncp.zeros(x.shape, self.dtype) - y[..., :-2] += x[..., 2:] - y[..., 1:-1] -= 2 * x[..., 2:] - y[..., 2:] += x[..., 2:] + # y[..., :-2] += x[..., 2:] + y = inplace_add(x[..., 2:], y, self.slice[None][-2]) + # y[..., 1:-1] -= 2 * x[..., 2:] + y = inplace_add(-2 * x[..., 2:], y, self.slice[1][-1]) + # y[..., 2:] += x[..., 2:] + y = inplace_add(x[..., 2:], y, self.slice[2][None]) y /= self.sampling**2 return y diff --git a/pylops/basicoperators/symmetrize.py b/pylops/basicoperators/symmetrize.py index 47814154..41ca122b 100644 --- a/pylops/basicoperators/symmetrize.py +++ b/pylops/basicoperators/symmetrize.py @@ -6,7 +6,7 @@ from pylops import LinearOperator from pylops.utils._internal import _value_or_sized_to_tuple -from pylops.utils.backend import get_array_module +from pylops.utils.backend import get_array_module, inplace_add, inplace_set from pylops.utils.decorators import reshaped from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray @@ -80,6 +80,13 @@ def __init__( self.nsym = dims[self.axis] dimsd = list(dims) dimsd[self.axis] = 2 * dims[self.axis] - 1 + self.slice1 = tuple([slice(None, None)] * (len(dims) - 1) + [slice(1, None)]) + self.slicensym_1 = tuple( + [slice(None, None)] * (len(dims) - 1) + [slice(self.nsym - 1, None)] + ) + self.slice_nsym_1 = tuple( + [slice(None, None)] * (len(dims) - 1) + [slice(None, self.nsym - 1)] + ) super().__init__(dtype=np.dtype(dtype), dims=dims, dimsd=dimsd, name=name) @@ -88,12 +95,12 @@ def _matvec(self, x: NDArray) -> NDArray: ncp = get_array_module(x) y = ncp.zeros(self.dimsd, dtype=self.dtype) y = y.swapaxes(self.axis, -1) - y[..., self.nsym - 1 :] = x - y[..., : self.nsym - 1] = x[..., -1:0:-1] + y = inplace_set(x, y, self.slicensym_1) + y = inplace_set(x[..., -1:0:-1], y, self.slice_nsym_1) return y @reshaped(swapaxis=True) def _rmatvec(self, x: NDArray) -> NDArray: y = x[..., self.nsym - 1 :].copy() - y[..., 1:] += x[..., self.nsym - 2 :: -1] + y = inplace_add(x[..., self.nsym - 2 :: -1], y, self.slice1) return y diff --git a/pylops/basicoperators/vstack.py b/pylops/basicoperators/vstack.py index 812b1a7e..0d66642e 100644 --- a/pylops/basicoperators/vstack.py +++ b/pylops/basicoperators/vstack.py @@ -12,16 +12,16 @@ from scipy.sparse.linalg.interface import LinearOperator as spLinearOperator from scipy.sparse.linalg.interface import _get_dtype else: - from scipy.sparse.linalg._interface import _get_dtype from scipy.sparse.linalg._interface import ( LinearOperator as spLinearOperator, ) + from scipy.sparse.linalg._interface import _get_dtype from typing import Callable, Optional, Sequence from pylops import LinearOperator from pylops.basicoperators import MatrixMult -from pylops.utils.backend import get_array_module +from pylops.utils.backend import get_array_module, inplace_add, inplace_set from pylops.utils.typing import DTypeLike, NDArray @@ -165,14 +165,20 @@ def _matvec_serial(self, x: NDArray) -> NDArray: ncp = get_array_module(x) y = ncp.zeros(self.nops, dtype=self.dtype) for iop, oper in enumerate(self.ops): - y[self.nnops[iop] : self.nnops[iop + 1]] = oper.matvec(x).squeeze() + y = inplace_set( + oper.matvec(x).squeeze(), y, slice(self.nnops[iop], self.nnops[iop + 1]) + ) return y def _rmatvec_serial(self, x: NDArray) -> NDArray: ncp = get_array_module(x) y = ncp.zeros(self.mops, dtype=self.dtype) for iop, oper in enumerate(self.ops): - y += oper.rmatvec(x[self.nnops[iop] : self.nnops[iop + 1]]).squeeze() + y = inplace_add( + oper.rmatvec(x[self.nnops[iop] : self.nnops[iop + 1]]).squeeze(), + y, + slice(None, None), + ) return y def _matvec_multiproc(self, x: NDArray) -> NDArray: diff --git a/pylops/jaxoperator.py b/pylops/jaxoperator.py new file mode 100644 index 00000000..5d5c40ed --- /dev/null +++ b/pylops/jaxoperator.py @@ -0,0 +1,104 @@ +__all__ = [ + "JaxOperator", +] + +from typing import Any, NewType + +from pylops import LinearOperator +from pylops.utils import deps + +if deps.jax_enabled: + import jax + + jaxarrayin_type = jax.typing.ArrayLike + jaxarrayout_type = jax.Array +else: + jax_message = ( + "JAX package not installed. In order to be able to use" + 'the jaxoperator module run "pip install jax" or' + '"conda install -c conda-forge jax".' + ) + jaxarrayin_type = Any + jaxarrayout_type = Any + +JaxTypeIn = NewType("JaxTypeIn", jaxarrayin_type) +JaxTypeOut = NewType("JaxTypeOut", jaxarrayout_type) + + +class JaxOperator(LinearOperator): + """Enable JAX backend for PyLops operator. + + This class can be used to wrap a pylops operator to enable the JAX + backend. Doing so, users can run all of the methods of a pylops + operator with JAX arrays. Moreover, the forward and adjoint + are internally just-in-time compiled, and other JAX functionalities + such as automatic differentiation and automatic vectorization + are enabled. + + Parameters + ---------- + Op : :obj:`pylops.LinearOperator` + PyLops operator + + """ + + def __init__(self, Op: LinearOperator) -> None: + if not deps.jax_enabled: + raise NotImplementedError(jax_message) + super().__init__( + dtype=Op.dtype, + dims=Op.dims, + dimsd=Op.dimsd, + clinear=Op.clinear, + explicit=False, + forceflat=Op.forceflat, + name=Op.name, + ) + self._matvec = jax.jit(Op._matvec) + self._rmatvec = jax.jit(Op._rmatvec) + + def __call__(self, x, *args, **kwargs): + return self._matvec(x) + + def _rmatvecad(self, x: JaxTypeIn, y: JaxTypeIn) -> JaxTypeOut: + _, f_vjp = jax.vjp(self._matvec, x) + xadj = jax.jit(f_vjp)(y)[0] + return xadj + + def rmatvecad(self, x: JaxTypeIn, y: JaxTypeIn) -> JaxTypeOut: + """Vector-Jacobian product + + JIT-compiled Vector-Jacobian product + + Parameters + ---------- + x : :obj:`jax.Array` + Input array for forward + y : :obj:`jax.Array` + Input array for adjoint + + Returns + ------- + xadj : :obj:`jax.typing.ArrayLike` + Output array + + """ + M, N = self.shape + + if x.shape != (M,) and x.shape != (M, 1): + raise ValueError( + f"Dimension mismatch. Got {x.shape}, but expected ({M},) or ({M}, 1)." + ) + + y = self._rmatvecad(x, y) + + if x.ndim == 1: + y = y.reshape(N) + elif x.ndim == 2: + y = y.reshape(N, 1) + else: + raise ValueError( + f"Invalid shape returned by user-defined rmatvecad(). " + f"Expected 2-d ndarray or matrix, not {x.ndim}-d ndarray" + ) + return y diff --git a/pylops/linearoperator.py b/pylops/linearoperator.py index 44e561cd..661178f5 100644 --- a/pylops/linearoperator.py +++ b/pylops/linearoperator.py @@ -442,10 +442,11 @@ def _matmat(self, X: NDArray) -> NDArray: Modified version of scipy _matmat to avoid having trailing dimension in col when provided to matvec """ + ncp = get_array_module(X) if sp.sparse.issparse(X): - y = np.vstack([self.matvec(col.toarray().reshape(-1)) for col in X.T]).T + y = ncp.vstack([self.matvec(col.toarray().reshape(-1)) for col in X.T]).T else: - y = np.vstack([self.matvec(col.reshape(-1)) for col in X.T]).T + y = ncp.vstack([self.matvec(col.reshape(-1)) for col in X.T]).T return y def _rmatmat(self, X: NDArray) -> NDArray: @@ -454,10 +455,11 @@ def _rmatmat(self, X: NDArray) -> NDArray: Modified version of scipy _rmatmat to avoid having trailing dimension in col when provided to rmatvec """ + ncp = get_array_module(X) if sp.sparse.issparse(X): - y = np.vstack([self.rmatvec(col.toarray().reshape(-1)) for col in X.T]).T + y = ncp.vstack([self.rmatvec(col.toarray().reshape(-1)) for col in X.T]).T else: - y = np.vstack([self.rmatvec(col.reshape(-1)) for col in X.T]).T + y = ncp.vstack([self.rmatvec(col.reshape(-1)) for col in X.T]).T return y def _adjoint(self) -> LinearOperator: @@ -508,7 +510,9 @@ def matvec(self, x: NDArray) -> NDArray: M, N = self.shape if x.shape != (N,) and x.shape != (N, 1): - raise ValueError("dimension mismatch") + raise ValueError( + f"Dimension mismatch. Got {x.shape}, but expected ({N},) or ({N}, 1)." + ) y = self._matvec(x) @@ -517,7 +521,7 @@ def matvec(self, x: NDArray) -> NDArray: elif x.ndim == 2: y = y.reshape(M, 1) else: - raise ValueError("invalid shape returned by user-defined matvec()") + raise ValueError("Invalid shape returned by user-defined matvec()") return y @count(forward=False) @@ -542,7 +546,9 @@ def rmatvec(self, x: NDArray) -> NDArray: M, N = self.shape if x.shape != (M,) and x.shape != (M, 1): - raise ValueError("dimension mismatch") + raise ValueError( + f"Dimension mismatch. Got {x.shape}, but expected ({M},) or ({M}, 1)." + ) y = self._rmatvec(x) @@ -551,7 +557,7 @@ def rmatvec(self, x: NDArray) -> NDArray: elif x.ndim == 2: y = y.reshape(N, 1) else: - raise ValueError("invalid shape returned by user-defined rmatvec()") + raise ValueError("Invalid shape returned by user-defined rmatvec()") return y @count(forward=True, matmat=True) @@ -574,9 +580,9 @@ def matmat(self, X: NDArray) -> NDArray: """ if X.ndim != 2: - raise ValueError("expected 2-d ndarray or matrix, " "not %d-d" % X.ndim) + raise ValueError(f"Expected 2-d ndarray or matrix, not {X.ndim}-d ndarray") if X.shape[0] != self.shape[1]: - raise ValueError("dimension mismatch: %r, %r" % (self.shape, X.shape)) + raise ValueError(f"Dimension mismatch: {self.shape}, {X.shape}") Y = self._matmat(X) return Y @@ -600,9 +606,9 @@ def rmatmat(self, X: NDArray) -> NDArray: """ if X.ndim != 2: - raise ValueError("expected 2-d ndarray or matrix, " "not %d-d" % X.ndim) + raise ValueError(f"Expected 2-d ndarray or matrix, not {X.ndim}-d ndarray") if X.shape[0] != self.shape[0]: - raise ValueError("dimension mismatch: %r, %r" % (self.shape, X.shape)) + raise ValueError(f"Dimension mismatch: {self.shape}, {X.shape}") Y = self._rmatmat(X) return Y @@ -791,7 +797,7 @@ def todense( Parameters ---------- backend : :obj:`str`, optional - Backend used to densify matrix (``numpy`` or ``cupy``). Note that + Backend used to densify matrix (``numpy`` or ``cupy`` or ``jax``). Note that this must be consistent with how the operator has been created. Returns @@ -816,7 +822,7 @@ def todense( if Op.shape[1] == shapemin: matrix = Op.matmat(identity) else: - matrix = np.conj(Op.rmatmat(identity)).T + matrix = ncp.conj(Op.rmatmat(identity)).T return matrix def tosparse(self) -> NDArray: diff --git a/pylops/signalprocessing/convolve1d.py b/pylops/signalprocessing/convolve1d.py index fd82eb91..bc154a94 100644 --- a/pylops/signalprocessing/convolve1d.py +++ b/pylops/signalprocessing/convolve1d.py @@ -48,10 +48,10 @@ def _choose_convfunc( def _pad_along_axis(array: np.ndarray, pad_size: tuple, axis: int = 0) -> np.ndarray: - + ncp = get_array_module(array) npad = [(0, 0)] * array.ndim npad[axis] = pad_size - return np.pad(array, pad_width=npad) + return ncp.pad(array, pad_width=npad) class _Convolve1Dshort(LinearOperator): @@ -67,6 +67,7 @@ def __init__( dtype: DTypeLike = "float64", name: str = "C", ) -> None: + ncp = get_array_module(h) dims = _value_or_sized_to_tuple(dims) super().__init__(dtype=np.dtype(dtype), dims=dims, dimsd=dims, name=name) self.axis = axis @@ -83,7 +84,7 @@ def __init__( (max(self.offset, 0), -min(self.offset, 0)), axis=-1 if h.ndim == 1 else axis, ) - self.hstar = np.flip(self.h, axis=-1) + self.hstar = ncp.flip(self.h, axis=-1) # add dimensions to filter to match dimensions of model and data if self.h.ndim == 1: @@ -127,6 +128,7 @@ def __init__( dtype: DTypeLike = "float64", name: str = "C", ) -> None: + ncp = get_array_module(h) dims = _value_or_sized_to_tuple(dims) dimsd = h.shape super().__init__(dtype=np.dtype(dtype), dims=dims, dimsd=dimsd, name=name) @@ -140,13 +142,13 @@ def __init__( self.offset = 2 * (self.dims[self.axis] // 2 - int(offset)) if self.dims[self.axis] % 2 == 0: self.offset -= 1 - self.hstar = np.flip(self.h, axis=-1) + self.hstar = ncp.flip(self.h, axis=-1) - self.pad = np.zeros((len(dims), 2), dtype=int) + self.pad = ncp.zeros((len(dims), 2), dtype=int) self.pad[self.axis, 0] = max(self.offset, 0) self.pad[self.axis, 1] = -min(self.offset, 0) - self.padd = np.zeros((len(dims), 2), dtype=int) + self.padd = ncp.zeros((len(dims), 2), dtype=int) self.padd[self.axis, 1] = max(self.offset, 0) self.padd[self.axis, 0] = -min(self.offset, 0) @@ -162,12 +164,13 @@ def __init__( @reshaped def _matvec(self, x: NDArray) -> NDArray: + ncp = get_array_module(x) if type(self.h) is not type(x): self.h = to_cupy_conditional(x, self.h) self.convfunc, self.method = _choose_convfunc( self.h, self.method, self.dims, self.axis ) - x = np.pad(x, self.pad) + x = ncp.pad(x, self.pad) y = self.convfunc(self.h, x, mode="same") return y @@ -179,7 +182,7 @@ def _rmatvec(self, x: NDArray) -> NDArray: self.convfunc, self.method = _choose_convfunc( self.hstar, self.method, self.dims, self.axis ) - x = np.pad(x, self.padd) + x = ncp.pad(x, self.padd) y = self.convfunc(self.hstar, x) if self.dims[self.axis] % 2 == 0: y = ncp.take( diff --git a/pylops/signalprocessing/fft.py b/pylops/signalprocessing/fft.py index 64444bcd..6af81a30 100644 --- a/pylops/signalprocessing/fft.py +++ b/pylops/signalprocessing/fft.py @@ -11,6 +11,7 @@ from pylops import LinearOperator from pylops.signalprocessing._baseffts import _BaseFFT, _FFTNorms from pylops.utils import deps +from pylops.utils.backend import get_array_module, inplace_divide, inplace_multiply from pylops.utils.decorators import reshaped from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray @@ -60,53 +61,61 @@ def __init__( self._scale = self.nfft elif self.norm is _FFTNorms.ONE_OVER_N: self._scale = 1.0 / self.nfft + self.slice = tuple( + [slice(None, None)] * (len(self.dims) - 1) + + [slice(1, 1 + (self.nfft - 1) // 2)] + ) @reshaped def _matvec(self, x: NDArray) -> NDArray: + ncp = get_array_module(x) if self.ifftshift_before: - x = np.fft.ifftshift(x, axes=self.axis) + x = ncp.fft.ifftshift(x, axes=self.axis) if not self.clinear: - x = np.real(x) + x = ncp.real(x) if self.real: - y = np.fft.rfft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs) + y = ncp.fft.rfft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs) # Apply scaling to obtain a correct adjoint for this operator - y = np.swapaxes(y, -1, self.axis) - y[..., 1 : 1 + (self.nfft - 1) // 2] *= np.sqrt(2) - y = np.swapaxes(y, self.axis, -1) + y = ncp.swapaxes(y, -1, self.axis) + # y[..., 1 : 1 + (self.nfft - 1) // 2] *= ncp.sqrt(2) + y = inplace_multiply(ncp.sqrt(2), y, self.slice) + y = ncp.swapaxes(y, self.axis, -1) else: - y = np.fft.fft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs) + y = ncp.fft.fft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs) if self.norm is _FFTNorms.ONE_OVER_N: y *= self._scale if self.fftshift_after: - y = np.fft.fftshift(y, axes=self.axis) + y = ncp.fft.fftshift(y, axes=self.axis) y = y.astype(self.cdtype) return y @reshaped def _rmatvec(self, x: NDArray) -> NDArray: + ncp = get_array_module(x) if self.fftshift_after: - x = np.fft.ifftshift(x, axes=self.axis) + x = ncp.fft.ifftshift(x, axes=self.axis) if self.real: # Apply scaling to obtain a correct adjoint for this operator x = x.copy() - x = np.swapaxes(x, -1, self.axis) - x[..., 1 : 1 + (self.nfft - 1) // 2] /= np.sqrt(2) - x = np.swapaxes(x, self.axis, -1) - y = np.fft.irfft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs) + x = ncp.swapaxes(x, -1, self.axis) + # x[..., 1 : 1 + (self.nfft - 1) // 2] /= ncp.sqrt(2) + x = inplace_divide(ncp.sqrt(2), x, self.slice) + x = ncp.swapaxes(x, self.axis, -1) + y = ncp.fft.irfft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs) else: - y = np.fft.ifft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs) + y = ncp.fft.ifft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs) if self.norm is _FFTNorms.NONE: y *= self._scale if self.nfft > self.dims[self.axis]: - y = np.take(y, range(0, self.dims[self.axis]), axis=self.axis) + y = ncp.take(y, range(0, self.dims[self.axis]), axis=self.axis) elif self.nfft < self.dims[self.axis]: - y = np.pad(y, self.ifftpad) + y = ncp.pad(y, self.ifftpad) if not self.clinear: - y = np.real(y) + y = ncp.real(y) if self.ifftshift_before: - y = np.fft.fftshift(y, axes=self.axis) + y = ncp.fft.fftshift(y, axes=self.axis) y = y.astype(self.rdtype) return y @@ -453,7 +462,7 @@ def FFT( Nyquist to the frequency bin before zero. engine : :obj:`str`, optional Engine used for fft computation (``numpy``, ``fftw``, or ``scipy``). Choose - ``numpy`` when working with cupy arrays. + ``numpy`` when working with cupy and jax arrays. .. note:: Since version 1.17.0, accepts "scipy". diff --git a/pylops/signalprocessing/fft2d.py b/pylops/signalprocessing/fft2d.py index f54e2972..2f4b5f15 100644 --- a/pylops/signalprocessing/fft2d.py +++ b/pylops/signalprocessing/fft2d.py @@ -9,6 +9,7 @@ from pylops import LinearOperator from pylops.signalprocessing._baseffts import _BaseFFTND, _FFTNorms +from pylops.utils.backend import get_array_module from pylops.utils.decorators import reshaped from pylops.utils.typing import DTypeLike, InputDimsLike @@ -67,51 +68,53 @@ def __init__( @reshaped def _matvec(self, x): + ncp = get_array_module(x) if self.ifftshift_before.any(): - x = np.fft.ifftshift(x, axes=self.axes[self.ifftshift_before]) + x = ncp.fft.ifftshift(x, axes=self.axes[self.ifftshift_before]) if not self.clinear: - x = np.real(x) + x = ncp.real(x) if self.real: - y = np.fft.rfft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs) + y = ncp.fft.rfft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs) # Apply scaling to obtain a correct adjoint for this operator - y = np.swapaxes(y, -1, self.axes[-1]) - y[..., 1 : 1 + (self.nffts[-1] - 1) // 2] *= np.sqrt(2) - y = np.swapaxes(y, self.axes[-1], -1) + y = ncp.swapaxes(y, -1, self.axes[-1]) + y[..., 1 : 1 + (self.nffts[-1] - 1) // 2] *= ncp.sqrt(2) + y = ncp.swapaxes(y, self.axes[-1], -1) else: - y = np.fft.fft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs) + y = ncp.fft.fft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs) if self.norm is _FFTNorms.ONE_OVER_N: y *= self._scale y = y.astype(self.cdtype) if self.fftshift_after.any(): - y = np.fft.fftshift(y, axes=self.axes[self.fftshift_after]) + y = ncp.fft.fftshift(y, axes=self.axes[self.fftshift_after]) return y @reshaped def _rmatvec(self, x): + ncp = get_array_module(x) if self.fftshift_after.any(): - x = np.fft.ifftshift(x, axes=self.axes[self.fftshift_after]) + x = ncp.fft.ifftshift(x, axes=self.axes[self.fftshift_after]) if self.real: # Apply scaling to obtain a correct adjoint for this operator x = x.copy() - x = np.swapaxes(x, -1, self.axes[-1]) - x[..., 1 : 1 + (self.nffts[-1] - 1) // 2] /= np.sqrt(2) - x = np.swapaxes(x, self.axes[-1], -1) - y = np.fft.irfft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs) + x = ncp.swapaxes(x, -1, self.axes[-1]) + x[..., 1 : 1 + (self.nffts[-1] - 1) // 2] /= ncp.sqrt(2) + x = ncp.swapaxes(x, self.axes[-1], -1) + y = ncp.fft.irfft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs) else: - y = np.fft.ifft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs) + y = ncp.fft.ifft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs) if self.norm is _FFTNorms.NONE: y *= self._scale if self.nffts[0] > self.dims[self.axes[0]]: - y = np.take(y, range(self.dims[self.axes[0]]), axis=self.axes[0]) + y = ncp.take(y, ncp.arange(self.dims[self.axes[0]]), axis=self.axes[0]) if self.nffts[1] > self.dims[self.axes[1]]: - y = np.take(y, range(self.dims[self.axes[1]]), axis=self.axes[1]) + y = ncp.take(y, ncp.arange(self.dims[self.axes[1]]), axis=self.axes[1]) if self.doifftpad: - y = np.pad(y, self.ifftpad) + y = ncp.pad(y, self.ifftpad) if not self.clinear: - y = np.real(y) + y = ncp.real(y) y = y.astype(self.rdtype) if self.ifftshift_before.any(): - y = np.fft.fftshift(y, axes=self.axes[self.ifftshift_before]) + y = ncp.fft.fftshift(y, axes=self.axes[self.ifftshift_before]) return y def __truediv__(self, y): @@ -310,7 +313,8 @@ def FFT2D( engine : :obj:`str`, optional .. versionadded:: 1.17.0 - Engine used for fft computation (``numpy`` or ``scipy``). + Engine used for fft computation (``numpy`` or ``scipy``). Choose + ``numpy`` when working with cupy and jax arrays. dtype : :obj:`str`, optional Type of elements in input array. Note that the ``dtype`` of the operator is the corresponding complex type even when a real type is provided. diff --git a/pylops/signalprocessing/fftnd.py b/pylops/signalprocessing/fftnd.py index d081072b..cf2de78f 100644 --- a/pylops/signalprocessing/fftnd.py +++ b/pylops/signalprocessing/fftnd.py @@ -7,8 +7,9 @@ import numpy as np import numpy.typing as npt +from pylops import LinearOperator from pylops.signalprocessing._baseffts import _BaseFFTND, _FFTNorms -from pylops.utils.backend import get_sp_fft +from pylops.utils.backend import get_array_module, get_sp_fft from pylops.utils.decorators import reshaped from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray @@ -46,6 +47,7 @@ def __init__( warnings.warn( f"numpy backend always returns complex128 dtype. To respect the passed dtype, data will be cast to {self.cdtype}." ) + self._kwargs_fft = kwargs_fft self._norm_kwargs = {"norm": None} # equivalent to "backward" in Numpy/Scipy if self.norm is _FFTNorms.ORTHO: @@ -57,58 +59,52 @@ def __init__( @reshaped def _matvec(self, x: NDArray) -> NDArray: + ncp = get_array_module(x) if self.ifftshift_before.any(): - x = np.fft.ifftshift(x, axes=self.axes[self.ifftshift_before]) + x = ncp.fft.ifftshift(x, axes=self.axes[self.ifftshift_before]) if not self.clinear: - x = np.real(x) + x = ncp.real(x) if self.real: - y = np.fft.rfftn( - x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft - ) + y = ncp.fft.rfftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs) # Apply scaling to obtain a correct adjoint for this operator - y = np.swapaxes(y, -1, self.axes[-1]) - y[..., 1 : 1 + (self.nffts[-1] - 1) // 2] *= np.sqrt(2) - y = np.swapaxes(y, self.axes[-1], -1) + y = ncp.swapaxes(y, -1, self.axes[-1]) + y[..., 1 : 1 + (self.nffts[-1] - 1) // 2] *= ncp.sqrt(2) + y = ncp.swapaxes(y, self.axes[-1], -1) else: - y = np.fft.fftn( - x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft - ) + y = ncp.fft.fftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs) if self.norm is _FFTNorms.ONE_OVER_N: y *= self._scale y = y.astype(self.cdtype) if self.fftshift_after.any(): - y = np.fft.fftshift(y, axes=self.axes[self.fftshift_after]) + y = ncp.fft.fftshift(y, axes=self.axes[self.fftshift_after]) return y @reshaped def _rmatvec(self, x: NDArray) -> NDArray: + ncp = get_array_module(x) if self.fftshift_after.any(): - x = np.fft.ifftshift(x, axes=self.axes[self.fftshift_after]) + x = ncp.fft.ifftshift(x, axes=self.axes[self.fftshift_after]) if self.real: # Apply scaling to obtain a correct adjoint for this operator x = x.copy() - x = np.swapaxes(x, -1, self.axes[-1]) - x[..., 1 : 1 + (self.nffts[-1] - 1) // 2] /= np.sqrt(2) - x = np.swapaxes(x, self.axes[-1], -1) - y = np.fft.irfftn( - x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft - ) + x = ncp.swapaxes(x, -1, self.axes[-1]) + x[..., 1 : 1 + (self.nffts[-1] - 1) // 2] /= ncp.sqrt(2) + x = ncp.swapaxes(x, self.axes[-1], -1) + y = ncp.fft.irfftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs) else: - y = np.fft.ifftn( - x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft - ) + y = ncp.fft.ifftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs) if self.norm is _FFTNorms.NONE: y *= self._scale for ax, nfft in zip(self.axes, self.nffts): if nfft > self.dims[ax]: - y = np.take(y, range(self.dims[ax]), axis=ax) + y = ncp.take(y, np.arange(self.dims[ax]), axis=ax) if self.doifftpad: - y = np.pad(y, self.ifftpad) + y = ncp.pad(y, self.ifftpad) if not self.clinear: - y = np.real(y) + y = ncp.real(y) y = y.astype(self.rdtype) if self.ifftshift_before.any(): - y = np.fft.fftshift(y, axes=self.axes[self.ifftshift_before]) + y = ncp.fft.fftshift(y, axes=self.axes[self.ifftshift_before]) return y def __truediv__(self, y: npt.ArrayLike) -> npt.ArrayLike: @@ -161,17 +157,13 @@ def _matvec(self, x: NDArray) -> NDArray: if not self.clinear: x = np.real(x) if self.real: - y = sp_fft.rfftn( - x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft - ) + y = sp_fft.rfftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs) # Apply scaling to obtain a correct adjoint for this operator y = np.swapaxes(y, -1, self.axes[-1]) y[..., 1 : 1 + (self.nffts[-1] - 1) // 2] *= np.sqrt(2) y = np.swapaxes(y, self.axes[-1], -1) else: - y = sp_fft.fftn( - x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft - ) + y = sp_fft.fftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs) if self.norm is _FFTNorms.ONE_OVER_N: y *= self._scale if self.fftshift_after.any(): @@ -189,13 +181,9 @@ def _rmatvec(self, x: NDArray) -> NDArray: x = np.swapaxes(x, -1, self.axes[-1]) x[..., 1 : 1 + (self.nffts[-1] - 1) // 2] /= np.sqrt(2) x = np.swapaxes(x, self.axes[-1], -1) - y = sp_fft.irfftn( - x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft - ) + y = sp_fft.irfftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs) else: - y = sp_fft.ifftn( - x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft - ) + y = sp_fft.ifftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs) if self.norm is _FFTNorms.NONE: y *= self._scale for ax, nfft in zip(self.axes, self.nffts): @@ -228,7 +216,7 @@ def FFTND( dtype: DTypeLike = "complex128", name: str = "F", **kwargs_fft, -): +) -> LinearOperator: r"""N-dimensional Fast-Fourier Transform. Apply N-dimensional Fast-Fourier Transform (FFT) to any n ``axes`` @@ -316,7 +304,8 @@ def FFTND( engine : :obj:`str`, optional .. versionadded:: 1.17.0 - Engine used for fft computation (``numpy`` or ``scipy``). + Engine used for fft computation (``numpy`` or ``scipy``). Choose + ``numpy`` when working with cupy and jax arrays. dtype : :obj:`str`, optional Type of elements in input array. Note that the ``dtype`` of the operator is the corresponding complex type even when a real type is provided. @@ -331,7 +320,9 @@ def FFTND( Name of operator (to be used by :func:`pylops.utils.describe.describe`) **kwargs_fft - Arbitrary keyword arguments to be passed to the selected fft method + .. versionadded:: 2.3.0 + + Arbitrary keyword arguments to be passed to the selected fft method Attributes ---------- diff --git a/pylops/signalprocessing/fredholm1.py b/pylops/signalprocessing/fredholm1.py index 57eec234..feb6c645 100644 --- a/pylops/signalprocessing/fredholm1.py +++ b/pylops/signalprocessing/fredholm1.py @@ -3,7 +3,7 @@ import numpy as np from pylops import LinearOperator -from pylops.utils.backend import get_array_module +from pylops.utils.backend import get_array_module, inplace_set from pylops.utils.decorators import reshaped from pylops.utils.typing import DTypeLike, NDArray @@ -118,7 +118,7 @@ def _matvec(self, x: NDArray) -> NDArray: else: y = ncp.squeeze(ncp.zeros((self.nsl, self.nx, self.nz), dtype=self.dtype)) for isl in range(self.nsl): - y[isl] = ncp.dot(self.G[isl], x[isl]) + y = inplace_set(ncp.dot(self.G[isl], x[isl]), y, isl) return y @reshaped @@ -131,7 +131,6 @@ def _rmatvec(self, x: NDArray) -> NDArray: if hasattr(self, "GT"): y = ncp.matmul(self.GT, x) else: - # y = ncp.matmul(self.G.transpose((0, 2, 1)).conj(), x) y = ( ncp.matmul(x.transpose(0, 2, 1).conj(), self.G) .transpose(0, 2, 1) @@ -141,9 +140,10 @@ def _rmatvec(self, x: NDArray) -> NDArray: y = ncp.squeeze(ncp.zeros((self.nsl, self.ny, self.nz), dtype=self.dtype)) if hasattr(self, "GT"): for isl in range(self.nsl): - y[isl] = ncp.dot(self.GT[isl], x[isl]) + y = inplace_set(ncp.dot(self.GT[isl], x[isl]), y, isl) else: for isl in range(self.nsl): - # y[isl] = ncp.dot(self.G[isl].conj().T, x[isl]) - y[isl] = ncp.dot(x[isl].T.conj(), self.G[isl]).T.conj() - return y + y = inplace_set( + ncp.dot(x[isl].T.conj(), self.G[isl]).T.conj(), y, isl + ) + return y.ravel() diff --git a/pylops/signalprocessing/nonstatconvolve1d.py b/pylops/signalprocessing/nonstatconvolve1d.py index 45daeed5..669898ef 100644 --- a/pylops/signalprocessing/nonstatconvolve1d.py +++ b/pylops/signalprocessing/nonstatconvolve1d.py @@ -9,7 +9,7 @@ from pylops import LinearOperator from pylops.utils._internal import _value_or_sized_to_tuple -from pylops.utils.backend import get_array_module +from pylops.utils.backend import get_array_module, inplace_add, inplace_set from pylops.utils.decorators import reshaped from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray @@ -147,7 +147,8 @@ def _interpolate_h(hs, ix, oh, dh, nh): @reshaped(swapaxis=True) def _matvec(self, x: NDArray) -> NDArray: - y = np.zeros_like(x) + ncp = get_array_module(x) + y = ncp.zeros_like(x) for ix in range(self.dims[self.axis]): h = self._interpolate_h(self.hs, ix, self.oh, self.dh, self.nh) xextremes = ( @@ -158,14 +159,20 @@ def _matvec(self, x: NDArray) -> NDArray: max(0, -ix + self.hsize // 2), min(self.hsize, self.hsize // 2 + (self.dims[self.axis] - ix)), ) - y[..., xextremes[0] : xextremes[1]] += ( - x[..., ix : ix + 1] * h[hextremes[0] : hextremes[1]] + # y[..., xextremes[0] : xextremes[1]] += ( + # x[..., ix : ix + 1] * h[hextremes[0] : hextremes[1]] + # ) + sl = tuple( + [slice(None, None)] * (len(self.dimsd) - 1) + + [slice(xextremes[0], xextremes[1])] ) + y = inplace_add(x[..., ix : ix + 1] * h[hextremes[0] : hextremes[1]], y, sl) return y @reshaped(swapaxis=True) def _rmatvec(self, x: NDArray) -> NDArray: - y = np.zeros_like(x) + ncp = get_array_module(x) + y = ncp.zeros_like(x) for ix in range(self.dims[self.axis]): h = self._interpolate_h(self.hs, ix, self.oh, self.dh, self.nh) xextremes = ( @@ -176,17 +183,29 @@ def _rmatvec(self, x: NDArray) -> NDArray: max(0, -ix + self.hsize // 2), min(self.hsize, self.hsize // 2 + (self.dims[self.axis] - ix)), ) - y[..., ix] = np.sum( - h[hextremes[0] : hextremes[1]] * x[..., xextremes[0] : xextremes[1]], - axis=-1, + # y[..., ix] = ncp.sum( + # h[hextremes[0] : hextremes[1]] * x[..., xextremes[0] : xextremes[1]], + # axis=-1, + # ) + sl = tuple([slice(None, None)] * (len(self.dimsd) - 1) + [ix]) + y = inplace_set( + ncp.sum( + h[hextremes[0] : hextremes[1]] + * x[..., xextremes[0] : xextremes[1]], + axis=-1, + ), + y, + sl, ) + return y def todense(self): + ncp = get_array_module(self.hsinterp[0]) hs = self.hsinterp - H = np.array( + H = ncp.array( [ - np.roll(np.pad(h, (0, self.dims[self.axis])), ix) + ncp.roll(ncp.pad(h, (0, self.dims[self.axis])), ix) for ix, h in enumerate(hs) ] ) @@ -317,18 +336,27 @@ def _interpolate_hadj(htmp, hs, hextremes, ix, oh, dh, nh): """find closest filters and spread weighted psf""" ih_closest = int(np.floor((ix - oh) / dh)) if ih_closest < 0: - hs[0, hextremes[0] : hextremes[1]] += htmp + # hs[0, hextremes[0] : hextremes[1]] += htmp + sl = tuple([0] + [slice(hextremes[0], hextremes[1])]) + hs = inplace_add(htmp, hs, sl) elif ih_closest >= nh - 1: - hs[nh - 1, hextremes[0] : hextremes[1]] += htmp + # hs[nh - 1, hextremes[0] : hextremes[1]] += htmp + sl = tuple([nh - 1] + [slice(hextremes[0], hextremes[1])]) + hs = inplace_add(htmp, hs, sl) else: dh_closest = (ix - oh) / dh - ih_closest - hs[ih_closest, hextremes[0] : hextremes[1]] += (1 - dh_closest) * htmp - hs[ih_closest + 1, hextremes[0] : hextremes[1]] += dh_closest * htmp + # hs[ih_closest, hextremes[0] : hextremes[1]] += (1 - dh_closest) * htmp + sl = tuple([ih_closest] + [slice(hextremes[0], hextremes[1])]) + hs = inplace_add((1 - dh_closest) * htmp, hs, sl) + # hs[ih_closest + 1, hextremes[0] : hextremes[1]] += dh_closest * htmp + sl = tuple([ih_closest + 1] + [slice(hextremes[0], hextremes[1])]) + hs = inplace_add(dh_closest * htmp, hs, sl) return hs @reshaped def _matvec(self, x: NDArray) -> NDArray: - y = np.zeros(self.dimsd, dtype=self.dtype) + ncp = get_array_module(x) + y = ncp.zeros(self.dimsd, dtype=self.dtype) for ix in range(self.dimsd[0]): h = self._interpolate_h(x, ix, self.oh, self.dh, self.nh) xextremes = ( @@ -339,14 +367,23 @@ def _matvec(self, x: NDArray) -> NDArray: max(0, -ix + self.hsize // 2), min(self.hsize, self.hsize // 2 + (self.dimsd[0] - ix)), ) - y[..., xextremes[0] : xextremes[1]] += ( - self.inp[..., ix : ix + 1] * h[hextremes[0] : hextremes[1]] + # y[..., xextremes[0] : xextremes[1]] += ( + # self.inp[..., ix : ix + 1] * h[hextremes[0] : hextremes[1]] + # ) + sl = tuple( + [slice(None, None)] * (len(self.dimsd) - 1) + + [slice(xextremes[0], xextremes[1])] + ) + y = inplace_add( + self.inp[..., ix : ix + 1] * h[hextremes[0] : hextremes[1]], y, sl ) + return y @reshaped def _rmatvec(self, x: NDArray) -> NDArray: - hs = np.zeros(self.dims, dtype=self.dtype) + ncp = get_array_module(x) + hs = ncp.zeros(self.dims, dtype=self.dtype) for ix in range(self.dimsd[0]): xextremes = ( max(0, ix - self.hsize // 2), diff --git a/pylops/torchoperator.py b/pylops/torchoperator.py index 5e41f67f..1c4dc2da 100644 --- a/pylops/torchoperator.py +++ b/pylops/torchoperator.py @@ -14,7 +14,7 @@ else: torch_message = ( "Torch package not installed. In order to be able to use" - 'the twoway module run "pip install torch" or' + 'the torchoperator module run "pip install torch" or' '"conda install -c pytorch torch".' ) from pylops.utils.typing import TensorTypeLike diff --git a/pylops/utils/backend.py b/pylops/utils/backend.py index c56be90d..4b6b506f 100644 --- a/pylops/utils/backend.py +++ b/pylops/utils/backend.py @@ -13,10 +13,16 @@ "get_csc_matrix", "get_sparse_eye", "get_lstsq", + "get_sp_fft", "get_complex_dtype", "get_real_dtype", "to_numpy", "to_cupy_conditional", + "inplace_set", + "inplace_add", + "inplace_multiply", + "inplace_divide", + "randn", ] from types import ModuleType @@ -45,6 +51,14 @@ from cupyx.scipy.sparse import csc_matrix as cp_csc_matrix from cupyx.scipy.sparse import eye as cp_eye +if deps.jax_enabled: + import jax + import jax.numpy as jnp + from jax.scipy.linalg import block_diag as jnp_block_diag + from jax.scipy.linalg import toeplitz as jnp_toeplitz + from jax.scipy.signal import convolve as j_convolve + from jax.scipy.signal import fftconvolve as j_fftconvolve + def get_module(backend: str = "numpy") -> ModuleType: """Returns correct numerical module based on backend string @@ -52,21 +66,23 @@ def get_module(backend: str = "numpy") -> ModuleType: Parameters ---------- backend : :obj:`str`, optional - Backend used for dot test computations (``numpy`` or ``cupy``). This + Backend used for dot test computations (``numpy`` or ``cupy`` or ``jax``). This parameter will be used to choose how to create the random vectors. Returns ------- mod : :obj:`func` - Module to be used to process array (:mod:`numpy` or :mod:`cupy`) + Module to be used to process array (:mod:`numpy` or :mod:`cupy` or :mod:`jax`) """ if backend == "numpy": ncp = np elif backend == "cupy": ncp = cp + elif backend == "jax": + ncp = jnp else: - raise ValueError("backend must be numpy or cupy") + raise ValueError("backend must be numpy, cupy, or jax") return ncp @@ -76,12 +92,12 @@ def get_module_name(mod: ModuleType) -> str: Parameters ---------- mod : :obj:`func` - Module to be used to process array (:mod:`numpy` or :mod:`cupy`) + Module to be used to process array (:mod:`numpy` or :mod:`cupy` or :mod:`jax`) Returns ------- backend : :obj:`str`, optional - Backend used for dot test computations (``numpy`` or ``cupy``). This + Backend used for dot test computations (``numpy`` or ``cupy`` or ``jax``). This parameter will be used to choose how to create the random vectors. """ @@ -89,8 +105,10 @@ def get_module_name(mod: ModuleType) -> str: backend = "numpy" elif mod == cp: backend = "cupy" + elif mod == jnp: + backend = "jax" else: - raise ValueError("module must be numpy or cupy") + raise ValueError("module must be numpy, cupy, or jax") return backend @@ -99,17 +117,23 @@ def get_array_module(x: npt.ArrayLike) -> ModuleType: Parameters ---------- - x : :obj:`numpy.ndarray` + x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.Array` Array Returns ------- mod : :obj:`func` - Module to be used to process array (:mod:`numpy` or :mod:`cupy`) + Module to be used to process array + (:mod:`numpy`, :mod:`cupy`, or , :mod:`jax`) """ - if deps.cupy_enabled: - return cp.get_array_module(x) + if deps.cupy_enabled or deps.jax_enabled: + if isinstance(x, jnp.ndarray): + return jnp + elif deps.cupy_enabled: + return cp.get_array_module(x) + else: + return np else: return np @@ -119,22 +143,24 @@ def get_convolve(x: npt.ArrayLike) -> Callable: Parameters ---------- - x : :obj:`numpy.ndarray` + x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.Array` Array Returns ------- - mod : :obj:`func` - Module to be used to process array (:mod:`numpy` or :mod:`cupy`) + f : :obj:`func` + Function to be used to process array """ - if not deps.cupy_enabled: - return convolve - - if cp.get_array_module(x) == np: - return convolve + if deps.cupy_enabled or deps.jax_enabled: + if isinstance(x, jnp.ndarray): + return j_convolve + elif deps.cupy_enabled and cp.get_array_module(x) == cp: + return cp_convolve + else: + return convolve else: - return cp_convolve + return convolve def get_fftconvolve(x: npt.ArrayLike) -> Callable: @@ -142,22 +168,24 @@ def get_fftconvolve(x: npt.ArrayLike) -> Callable: Parameters ---------- - x : :obj:`numpy.ndarray` + x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.Array` Array Returns ------- - mod : :obj:`func` - Module to be used to process array (:mod:`numpy` or :mod:`cupy`) + f : :obj:`func` + Function to be used to process array """ - if not deps.cupy_enabled: - return fftconvolve - - if cp.get_array_module(x) == np: - return fftconvolve + if deps.cupy_enabled or deps.jax_enabled: + if isinstance(x, jnp.ndarray): + return j_fftconvolve + elif deps.cupy_enabled and cp.get_array_module(x) == cp: + return cp_fftconvolve + else: + return fftconvolve else: - return cp_fftconvolve + return fftconvolve def get_oaconvolve(x: npt.ArrayLike) -> Callable: @@ -165,22 +193,28 @@ def get_oaconvolve(x: npt.ArrayLike) -> Callable: Parameters ---------- - x : :obj:`numpy.ndarray` + x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.Array` Array Returns ------- - mod : :obj:`func` - Module to be used to process array (:mod:`numpy` or :mod:`cupy`) + f : :obj:`func` + Function to be used to process array """ - if not deps.cupy_enabled: - return oaconvolve - - if cp.get_array_module(x) == np: - return oaconvolve + if deps.cupy_enabled or deps.jax_enabled: + if isinstance(x, jnp.ndarray): + raise NotImplementedError( + "oaconvolve not implemented in " + "jax. Consider using a different" + "option..." + ) + elif deps.cupy_enabled and cp.get_array_module(x) == cp: + return cp_oaconvolve + else: + return oaconvolve else: - return cp_oaconvolve + return oaconvolve def get_correlate(x: npt.ArrayLike) -> Callable: @@ -188,22 +222,24 @@ def get_correlate(x: npt.ArrayLike) -> Callable: Parameters ---------- - x : :obj:`numpy.ndarray` + x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.Array` Array Returns ------- - mod : :obj:`func` - Module to be used to process array (:mod:`numpy` or :mod:`cupy`) + f : :obj:`func` + Function to be used to process array """ - if not deps.cupy_enabled: - return correlate - - if cp.get_array_module(x) == np: - return correlate + if deps.cupy_enabled or deps.jax_enabled: + if isinstance(x, jnp.ndarray): + return jax.scipy.signal.correlate + elif deps.cupy_enabled and cp.get_array_module(x) == cp: + return cp_correlate + else: + return correlate else: - return cp_correlate + return correlate def get_add_at(x: npt.ArrayLike) -> Callable: @@ -211,13 +247,13 @@ def get_add_at(x: npt.ArrayLike) -> Callable: Parameters ---------- - x : :obj:`numpy.ndarray` + x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.Array` Array Returns ------- - mod : :obj:`func` - Module to be used to process array (:mod:`numpy` or :mod:`cupy`) + f : :obj:`func` + Function to be used to process array """ if not deps.cupy_enabled: @@ -234,13 +270,13 @@ def get_sliding_window_view(x: npt.ArrayLike) -> Callable: Parameters ---------- - x : :obj:`numpy.ndarray` + x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.Array` Array Returns ------- - mod : :obj:`func` - Module to be used to process array (:mod:`numpy` or :mod:`cupy`) + f : :obj:`func` + Function to be used to process array """ if not deps.cupy_enabled: @@ -257,22 +293,24 @@ def get_block_diag(x: npt.ArrayLike) -> Callable: Parameters ---------- - x : :obj:`numpy.ndarray` + x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.Array` Array Returns ------- - mod : :obj:`func` - Module to be used to process array (:mod:`numpy` or :mod:`cupy`) + f : :obj:`func` + Function to be used to process array """ - if not deps.cupy_enabled: - return block_diag - - if cp.get_array_module(x) == np: - return block_diag + if deps.cupy_enabled or deps.jax_enabled: + if isinstance(x, jnp.ndarray): + return jnp_block_diag + elif deps.cupy_enabled and cp.get_array_module(x) == cp: + return cp_block_diag + else: + return block_diag else: - return cp_block_diag + return block_diag def get_toeplitz(x: npt.ArrayLike) -> Callable: @@ -285,17 +323,19 @@ def get_toeplitz(x: npt.ArrayLike) -> Callable: Returns ------- - mod : :obj:`func` - Module to be used to process array (:mod:`numpy` or :mod:`cupy`) + f : :obj:`func` + Function to be used to process array """ - if not deps.cupy_enabled: - return toeplitz - - if cp.get_array_module(x) == np: - return toeplitz + if deps.cupy_enabled or deps.jax_enabled: + if isinstance(x, jnp.ndarray): + return jnp_toeplitz + elif deps.cupy_enabled and cp.get_array_module(x) == cp: + return cp_toeplitz + else: + return toeplitz else: - return cp_toeplitz + return toeplitz def get_csc_matrix(x: npt.ArrayLike) -> Callable: @@ -308,8 +348,8 @@ def get_csc_matrix(x: npt.ArrayLike) -> Callable: Returns ------- - mod : :obj:`func` - Module to be used to process array (:mod:`numpy` or :mod:`cupy`) + f : :obj:`func` + Function to be used to process array """ if not deps.cupy_enabled: @@ -331,8 +371,8 @@ def get_sparse_eye(x: npt.ArrayLike) -> Callable: Returns ------- - mod : :obj:`func` - Module to be used to process array (:mod:`numpy` or :mod:`cupy`) + f : :obj:`func` + Function to be used to process array """ if not deps.cupy_enabled: @@ -354,8 +394,8 @@ def get_lstsq(x: npt.ArrayLike) -> Callable: Returns ------- - mod : :obj:`func` - Module to be used to process array (:mod:`numpy` or :mod:`cupy`) + f : :obj:`func` + Function to be used to process array """ if not deps.cupy_enabled: @@ -377,8 +417,8 @@ def get_sp_fft(x: npt.ArrayLike) -> Callable: Returns ------- - mod : :obj:`func` - Module to be used to process array (:mod:`numpy` or :mod:`cupy`) + f : :obj:`func` + Function to be used to process array """ if not deps.cupy_enabled: @@ -433,7 +473,7 @@ def to_numpy(x: NDArray) -> NDArray: Returns ------- - x : :obj:`cupy.ndarray` + x : :obj:`numpy.ndarray` Converted array """ @@ -464,3 +504,135 @@ def to_cupy_conditional(x: npt.ArrayLike, y: npt.ArrayLike) -> NDArray: with cp.cuda.Device(x.device): y = cp.asarray(y) return y + + +def inplace_set(x: npt.ArrayLike, y: npt.ArrayLike, idx: list) -> NDArray: + """Perform inplace set based on input + + Parameters + ---------- + x : :obj:`numpy.ndarray` or :obj:`jax.Array` + Array to sum + y : :obj:`numpy.ndarray` or :obj:`jax.Array` + Output array + idx : :obj:`list` + Indices to sum at + + Returns + ------- + y : :obj:`numpy.ndarray` or :obj:`jax.Array` + Output array + + """ + if deps.jax_enabled and isinstance(x, jnp.ndarray): + y = y.at[idx].set(x) + return y + else: + y[idx] = x + return y + + +def inplace_add(x: npt.ArrayLike, y: npt.ArrayLike, idx: list) -> NDArray: + """Perform inplace add based on input + + Parameters + ---------- + x : :obj:`numpy.ndarray` or :obj:`jax.Array` + Array to sum + y : :obj:`numpy.ndarray` or :obj:`jax.Array` + Output array + idx : :obj:`list` + Indices to sum at + + Returns + ------- + y : :obj:`numpy.ndarray` or :obj:`jax.Array` + Output array + + """ + if deps.jax_enabled and isinstance(x, jnp.ndarray): + y = y.at[idx].add(x) + return y + else: + y[idx] += x + return y + + +def inplace_multiply(x: npt.ArrayLike, y: npt.ArrayLike, idx: list) -> NDArray: + """Perform inplace multiplication based on input + + Parameters + ---------- + x : :obj:`numpy.ndarray` or :obj:`jax.Array` + Array to sum + y : :obj:`numpy.ndarray` or :obj:`jax.Array` + Output array + idx : :obj:`list` + Indices to multiply at + + Returns + ------- + y : :obj:`numpy.ndarray` or :obj:`jax.Array` + Output array + + """ + if deps.jax_enabled and isinstance(x, jnp.ndarray): + y = y.at[idx].multiply(x) + return y + else: + y[idx] *= x + return y + + +def inplace_divide(x: npt.ArrayLike, y: npt.ArrayLike, idx: list) -> NDArray: + """Perform inplace division based on input + + Parameters + ---------- + x : :obj:`numpy.ndarray` or :obj:`jax.Array` + Array to sum + y : :obj:`numpy.ndarray` or :obj:`jax.Array` + Output array + idx : :obj:`list` + Indices to divide at + + Returns + ------- + y : :obj:`numpy.ndarray` or :obj:`jax.Array` + Output array + + """ + if deps.jax_enabled and isinstance(x, jnp.ndarray): + y = y.at[idx].divide(x) + return y + else: + y[idx] /= x + return y + + +def randn(*n: int, backend: str = "numpy") -> NDArray: + """Returns randomly generated number + + Parameters + ---------- + *n : :obj:`int` + Number of samples to generate in each dimension + backend : :obj:`str`, optional + Backend used for dot test computations (``numpy`` or ``cupy``). This + parameter will be used to choose how to create the random vectors. + + Returns + ------- + x : :obj:`numpy.ndarray` or :obj:`jax.Array` + Generated array + + """ + if backend == "numpy": + x = np.random.randn(*n) + elif backend == "cupy": + x = cp.random.randn(*n) + elif backend == "jax": + x = jnp.array(np.random.randn(*n)) + else: + raise ValueError("backend must be numpy, cupy, or jax") + return x diff --git a/pylops/utils/deps.py b/pylops/utils/deps.py index 4b2d21e7..ecf69a95 100644 --- a/pylops/utils/deps.py +++ b/pylops/utils/deps.py @@ -1,5 +1,6 @@ __all__ = [ "cupy_enabled", + "jax_enabled", "devito_enabled", "dtcwt_enabled", "numba_enabled", @@ -51,6 +52,34 @@ def cupy_import(message: Optional[str] = None) -> str: return cupy_message +def jax_import(message: Optional[str] = None) -> str: + jax_test = ( + util.find_spec("jax") is not None and int(os.getenv("JAX_PYLOPS", 1)) == 1 + ) + if jax_test: + try: + import_module("jax") # noqa: F401 + + jax_message = None + except (ImportError, ModuleNotFoundError) as e: + jax_message = ( + f"Failed to import jax, Falling back to numpy (error: {e}). " + "Please ensure your environment is set up correctly " + "for more details visit 'https://jax.readthedocs.io/en/latest/installation.html'" + ) + print(UserWarning(jax_message)) + else: + jax_message = ( + "Jax package not installed or os.getenv('JAX_PYLOPS') == 0. " + f"In order to be able to use {message} " + "ensure 'os.getenv('JAX_PYLOPS') == 1' and run " + "'pip install jax'; " + "for more details visit 'https://jax.readthedocs.io/en/latest/installation.html'" + ) + + return jax_message + + def devito_import(message: Optional[str] = None) -> str: if devito_enabled: try: @@ -195,15 +224,18 @@ def sympy_import(message: Optional[str] = None) -> str: # Set package availability booleans -# cupy: the package is imported to check everything is working correctly, -# if not the package is disabled. We do this here as this library is used as drop-in -# replacement for many numpy and scipy routines when cupy arrays are provided. +# cupy and jax: the package is imported to check everything is working correctly, +# if not the package is disabled. We do this here as these libraries are used as drop-in +# replacement for many numpy and scipy routines when cupy/jax arrays are provided. # all other libraries: we simply check if the package is available and postpone its import # to check everything is working correctly when a user tries to create an operator that requires # such a package cupy_enabled: bool = ( True if (cupy_import() is None and int(os.getenv("CUPY_PYLOPS", 1)) == 1) else False ) +jax_enabled: bool = ( + True if (jax_import() is None and int(os.getenv("JAX_PYLOPS", 1)) == 1) else False +) devito_enabled = util.find_spec("devito") is not None dtcwt_enabled = util.find_spec("dtcwt") is not None numba_enabled = util.find_spec("numba") is not None diff --git a/pylops/utils/dottest.py b/pylops/utils/dottest.py index ed77b995..c8a198ca 100644 --- a/pylops/utils/dottest.py +++ b/pylops/utils/dottest.py @@ -4,7 +4,7 @@ import numpy as np -from pylops.utils.backend import get_module, to_numpy +from pylops.utils.backend import get_module, randn, to_numpy def dottest( @@ -93,13 +93,13 @@ def dottest( # make u and v vectors rdtype = np.ones(1, Op.dtype).real.dtype - u = ncp.random.randn(nc).astype(rdtype) + u = randn(nc, backend=backend).astype(rdtype) if complexflag not in (0, 2): - u = u + 1j * ncp.random.randn(nc).astype(rdtype) + u = u + 1j * randn(nc, backend=backend).astype(rdtype) - v = ncp.random.randn(nr).astype(rdtype) + v = randn(nr, backend=backend).astype(rdtype) if complexflag not in (0, 1): - v = v + 1j * ncp.random.randn(nr).astype(rdtype) + v = v + 1j * randn(nr, backend=backend).astype(rdtype) y = Op.matvec(u) # Op * u x = Op.rmatvec(v) # Op'* v diff --git a/pylops/waveeqprocessing/blending.py b/pylops/waveeqprocessing/blending.py index adca6d93..2bc31c65 100644 --- a/pylops/waveeqprocessing/blending.py +++ b/pylops/waveeqprocessing/blending.py @@ -9,7 +9,7 @@ from pylops import LinearOperator from pylops.basicoperators import BlockDiag, HStack, Pad from pylops.signalprocessing import Shift -from pylops.utils.backend import get_array_module +from pylops.utils.backend import get_array_module, inplace_add, inplace_set from pylops.utils.decorators import reshaped from pylops.utils.typing import DTypeLike, NDArray @@ -76,12 +76,7 @@ def __init__( self.dt = dt self.times = times self.shiftall = shiftall - if np.max(self.times) // dt == np.max(self.times) / dt: - # do not add extra sample as no shift will be applied - self.nttot = int(np.max(self.times) / self.dt + self.nt) - else: - # add 1 extra sample at the end - self.nttot = int(np.max(self.times) / self.dt + self.nt + 1) + self.nttot = int(np.max(self.times) / self.dt + self.nt + 1) if not self.shiftall: # original implementation, where each source is shifted indipendently self.PadOp = Pad((self.nr, self.nt), ((0, 0), (0, 1)), dtype=self.dtype) @@ -143,7 +138,11 @@ def _matvec_smallrecs(self, x: NDArray) -> NDArray: self.ns, self.nr, self.nt + 1 ) for i, shift_int in enumerate(self.shifts): - blended_data[:, shift_int : shift_int + self.nt + 1] += shifted_data[i] + blended_data = inplace_add( + shifted_data[i], + blended_data, + (slice(None, None), slice(shift_int, shift_int + self.nt + 1)), + ) return blended_data @reshaped @@ -151,7 +150,11 @@ def _rmatvec_smallrecs(self, x: NDArray) -> NDArray: ncp = get_array_module(x) shifted_data = ncp.zeros((self.ns, self.nr, self.nt + 1), dtype=self.dtype) for i, shift_int in enumerate(self.shifts): - shifted_data[i, :, :] = x[:, shift_int : shift_int + self.nt + 1] + shifted_data = inplace_set( + x[:, shift_int : shift_int + self.nt + 1], + shifted_data, + (i, slice(None, None), slice(None, None)), + ) deblended_data = self.PadOp._rmatvec( self.ShiftOp._rmatvec(shifted_data.ravel()) ).reshape(self.dims) @@ -170,7 +173,11 @@ def _matvec_largerecs(self, x: NDArray) -> NDArray: .matvec(self.PadOp.matvec(x[i, :, :].ravel())) .reshape(self.ShiftOps[i].dimsd) ) - blended_data[:, shift_int : shift_int + self.nt + 1] += shifted_data + blended_data = inplace_add( + shifted_data, + blended_data, + (slice(None, None), slice(shift_int, shift_int + self.nt + 1)), + ) return blended_data @reshaped @@ -186,7 +193,11 @@ def _rmatvec_largerecs(self, x: NDArray) -> NDArray: x[:, shift_int : shift_int + self.nt + 1].ravel() ) ).reshape(self.PadOp.dims) - deblended_data[i, :, :] = shifted_data + deblended_data = inplace_set( + shifted_data, + deblended_data, + (i, slice(None, None), slice(None, None)), + ) return deblended_data def _register_multiplications(self) -> None: diff --git a/pylops/waveeqprocessing/kirchhoff.py b/pylops/waveeqprocessing/kirchhoff.py index 06c29251..bdd5be87 100644 --- a/pylops/waveeqprocessing/kirchhoff.py +++ b/pylops/waveeqprocessing/kirchhoff.py @@ -288,8 +288,11 @@ def __init__( ) self.rix = np.tile((recs[0] - x[0]) // dx, (ns, 1)).astype(int).ravel() elif self.ndims == 3: - # TODO: 3D normalized distances - raise NotImplementedError("dynamic=True currently not available in 3D") + # TODO: compute 3D indices for aperture filter + # currently no aperture filter in 3D... just make indices 0 + # so check if always passed + self.six = np.zeros(nr * ns) + self.rix = np.zeros(nr * ns) # compute traveltime and distances self.travsrcrec = True # use separate tables for src and rec traveltimes @@ -362,8 +365,26 @@ def __init__( trav_recs_grad[0], trav_recs_grad[1] ).reshape(np.prod(dims), nr) else: - # TODO: 3D - raise NotImplementedError("dynamic=True currently not available in 3D") + trav_srcs_grad = np.concatenate( + [trav_srcs_grad[i][np.newaxis] for i in range(3)] + ) + trav_recs_grad = np.concatenate( + [trav_recs_grad[i][np.newaxis] for i in range(3)] + ) + self.angle_srcs = ( + np.sign(trav_srcs_grad[1]) + * np.arccos( + trav_srcs_grad[-1] + / np.sqrt(np.sum(trav_srcs_grad**2, axis=0)) + ) + ).reshape(np.prod(dims), ns) + self.angle_recs = ( + np.sign(trav_srcs_grad[1]) + * np.arccos( + trav_recs_grad[-1] + / np.sqrt(np.sum(trav_recs_grad**2, axis=0)) + ) + ).reshape(np.prod(dims), nr) # pre-compute traveltime indices if total traveltime is used if not self.travsrcrec: @@ -386,6 +407,12 @@ def __init__( # define aperture # if aperture=None, we want to ensure the check is always matched (no aperture limits...) + # if aperture!=None in 3d, force to None as aperture checks are not yet implemented + if aperture is not None and self.ndims == 3: + aperture = None + warnings.warn( + "Aperture is forced to None as currently not implemented in 3D" + ) if aperture is not None: warnings.warn( "Aperture is currently defined as ratio of offset over depth, " @@ -608,10 +635,10 @@ def _traveltime_table( # compute traveltime gradients at image points trav_srcs_grad = np.gradient( - trav_srcs.reshape(*dims, ns), axis=np.arange(ndims) + trav_srcs.reshape(*dims, ns), *dsamp, axis=np.arange(ndims) ) trav_recs_grad = np.gradient( - trav_recs.reshape(*dims, nr), axis=np.arange(ndims) + trav_recs.reshape(*dims, nr), *dsamp, axis=np.arange(ndims) ) return ( diff --git a/pylops/waveeqprocessing/wavedecomposition.py b/pylops/waveeqprocessing/wavedecomposition.py index 7d926d36..715fb2c1 100644 --- a/pylops/waveeqprocessing/wavedecomposition.py +++ b/pylops/waveeqprocessing/wavedecomposition.py @@ -156,6 +156,7 @@ def _obliquity3D( critical: float = 100.0, ntaper: int = 10, composition: bool = True, + fftengine: str = "scipy", backend: str = "numpy", dtype: DTypeLike = "complex128", ) -> Tuple[LinearOperator, LinearOperator]: @@ -187,6 +188,9 @@ def _obliquity3D( composition : :obj:`bool`, optional Create obliquity factor for composition (``True``) or decomposition (``False``) + fftengine : :obj:`str`, optional + Engine used for fft computation (``numpy`` or ``scipy``). Choose + ``numpy`` when working with cupy and jax arrays. backend : :obj:`str`, optional Backend used for creation of obliquity factor operator (``numpy`` or ``cupy``) @@ -203,7 +207,11 @@ def _obliquity3D( """ # create Fourier operator FFTop = FFTND( - dims=[nr[0], nr[1], nt], nffts=nffts, sampling=[dr[0], dr[1], dt], dtype=dtype + dims=[nr[0], nr[1], nt], + nffts=nffts, + sampling=[dr[0], dr[1], dt], + engine=fftengine, + dtype=dtype, ) # create obliquity operator @@ -547,6 +555,7 @@ def UpDownComposition3D( critical: float = 100.0, ntaper: int = 10, scaling: float = 1.0, + fftengine: str = "scipy", backend: str = "numpy", dtype: DTypeLike = "complex128", name: str = "U", @@ -588,6 +597,11 @@ def UpDownComposition3D( angle scaling : :obj:`float`, optional Scaling to apply to the operator (see Notes for more details) + fftengine : :obj:`str`, optional + .. versionadded:: 2.3.0 + + Engine used for fft computation (``numpy`` or ``scipy``). Choose + ``numpy`` when working with cupy and jax arrays. backend : :obj:`str`, optional Backend used for creation of obliquity factor operator (``numpy`` or ``cupy``) @@ -638,6 +652,7 @@ def UpDownComposition3D( critical=critical, ntaper=ntaper, composition=True, + fftengine=fftengine, backend=backend, dtype=dtype, ) diff --git a/pytests/test_jaxoperator.py b/pytests/test_jaxoperator.py new file mode 100755 index 00000000..86de4e8d --- /dev/null +++ b/pytests/test_jaxoperator.py @@ -0,0 +1,53 @@ +import jax +import jax.numpy as jnp +import numpy as np +import pytest +from numpy.testing import assert_array_almost_equal, assert_array_equal + +from pylops import JaxOperator, MatrixMult + +par1 = {"ny": 11, "nx": 11, "dtype": np.float32} # square +par2 = {"ny": 21, "nx": 11, "dtype": np.float32} # overdetermined + +np.random.seed(0) + + +@pytest.mark.parametrize("par", [(par1)]) +def test_JaxOperator(par): + """Apply forward and adjoint and compare with native pylops.""" + M = np.random.normal(0.0, 1.0, (par["ny"], par["nx"])).astype(par["dtype"]) + Mop = MatrixMult(jnp.array(M), dtype=par["dtype"]) + Jop = JaxOperator(Mop) + + x = np.random.normal(0.0, 1.0, par["nx"]).astype(par["dtype"]) + xjnp = jnp.array(x) + + # pylops operator + y = Mop * x + xadj = Mop.H * y + + # jax operator + yjnp = Jop * xjnp + xadjnp = Jop.rmatvecad(xjnp, yjnp) + + assert_array_equal(y, np.array(yjnp)) + assert_array_equal(xadj, np.array(xadjnp)) + + +@pytest.mark.parametrize("par", [(par1)]) +def test_TorchOperator_batch(par): + """Apply forward for input with multiple samples + (= batch) and flattened arrays""" + + M = np.random.normal(0.0, 1.0, (par["ny"], par["nx"])).astype(par["dtype"]) + Mop = MatrixMult(jnp.array(M), dtype=par["dtype"]) + Jop = JaxOperator(Mop) + auto_batch_matvec = jax.vmap(Jop._matvec) + + x = np.random.normal(0.0, 1.0, (4, par["nx"])).astype(par["dtype"]) + xjnp = jnp.array(x) + + y = Mop.matmat(x.T).T + yjnp = auto_batch_matvec(xjnp) + + assert_array_almost_equal(y, np.array(yjnp), decimal=5) diff --git a/requirements-dev.txt b/requirements-dev.txt index 703b377f..6ce1fb00 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,6 +2,7 @@ numpy>=1.21.0 scipy>=1.11.0 --extra-index-url https://download.pytorch.org/whl/cpu torch>=1.2.0 +jax numba pyfftw PyWavelets @@ -19,6 +20,7 @@ docutils<0.18 Sphinx pydata-sphinx-theme sphinx-gallery +sphinxemoji numpydoc nbsphinx image diff --git a/tutorials/ilsm.py b/tutorials/ilsm.py index b4b016f3..1f394bfc 100755 --- a/tutorials/ilsm.py +++ b/tutorials/ilsm.py @@ -1,5 +1,5 @@ r""" -20. Image Domain Least-squares migration +19. Image Domain Least-squares migration ======================================== Seismic migration is the process by which seismic data are manipulated to create an image of the subsurface reflectivity. diff --git a/tutorials/jaxop.py b/tutorials/jaxop.py new file mode 100755 index 00000000..c7a30d40 --- /dev/null +++ b/tutorials/jaxop.py @@ -0,0 +1,103 @@ +r""" +21. JAX Operator +================ +This tutorial is aimed at introducing the :class:`pylops.JaxOperator` operator. This +represents the entry-point to the JAX backend of PyLops. + +More specifically, by wrapping any of PyLops' operators into a +:class:`pylops.JaxOperator` one can: + +- apply forward, adjoint and use any of PyLops solver with JAX arrays; +- enable automatic differentiation; +- enable automatic vectorization. + +Moreover, both the forward and adjoint are internally just-in-time compiled +to enable any further optimization provided by JAX. + +In this example we will consider a :class:`pylops.MatrixMult` operator and +showcase how to use it in conjunction with :class:`pylops.JaxOperator` +to enable the different JAX functionalities mentioned above. + +""" +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt +import numpy as np + +import pylops + +plt.close("all") +np.random.seed(10) + +############################################################################### +# Let's start by creating a :class:`pylops.MatrixMult` operator. We will then +# perform the dot-test as well as apply the forward and adjoint operations to +# JAX arrays. + +n = 4 +G = np.random.normal(0, 1, (n, n)).astype("float32") +Gopjax = pylops.JaxOperator(pylops.MatrixMult(jnp.array(G), dtype="float32")) + +# dottest +pylops.utils.dottest(Gopjax, n, n, backend="jax", verb=True, atol=1e-3) + +# forward +xjnp = jnp.ones(n, dtype="float32") +yjnp = Gopjax @ xjnp + +# adjoint +xadjjnp = Gopjax.H @ yjnp + +############################################################################### +# We can now use one of PyLops solvers to invert the operator + +xcgls = pylops.optimization.basic.cgls( + Gopjax, yjnp, x0=jnp.zeros(n), niter=100, tol=1e-10, show=True +)[0] +print("Inverse: ", xcgls) + +############################################################################### +# Let's see how we can empower the automatic differentiation capabilities +# of JAX to obtain the adjoint of our operator without having to implement it. +# Although in PyLops the adjoint of any of operators is hand-written (and +# optimized), it may be useful in some cases to quickly implement the forward +# pass of a new operator and get the adjoint for free. This could be extremely +# beneficial during the prototyping stage of an operator before embarking in +# implementing an efficient hand-written adjoint. + +xadjjnpad = Gopjax.rmatvecad(xjnp, yjnp) + +print("Hand-written Adjoint: ", xadjjnp) +print("AD Adjoint: ", xadjjnpad) + +############################################################################### +# And more in general how we can combine any of JAX native operations with a +# PyLops operator. + + +def fun(x): + y = Gopjax(x) + loss = jnp.sum(y) + return loss + + +xgrad = jax.grad(fun)(xjnp) +print("Grad: ", xgrad) + +############################################################################### +# We turn now our attention to automatic vectorization, which is very useful +# if we want to apply the same operator to multiple vectors. In PyLops we can +# easily do so by using the ``matmat`` and ``rmatmat`` methods, however under +# the hood what these methods do is to simply run a for...loop and call the +# corresponding ``matvec`` / ``rmatvec`` methods multiple times. On the other +# hand, JAX is able to automatically add a batch axis at the beginning of +# operator. Moreover, this can be seamlessly combined with `jax.jit` to +# further improve performance. + +auto_batch_matvec = jax.jit(jax.vmap(Gopjax._matvec)) +xs = jnp.stack([xjnp, xjnp]) +ys = auto_batch_matvec(xs) + +print("Original output: ", yjnp) +print("AV Output 1: ", ys[0]) +print("AV Output 1: ", ys[1]) diff --git a/tutorials/torchop.py b/tutorials/torchop.py index c555573d..9e73d7b3 100755 --- a/tutorials/torchop.py +++ b/tutorials/torchop.py @@ -1,6 +1,6 @@ r""" -19. Automatic Differentiation -============================= +20. Torch Operator +================== This tutorial focuses on the use of :class:`pylops.TorchOperator` to allow performing Automatic Differentiation (AD) on chains of operators which can be: