From b8d2c85ab5389a8a5cc8988b0ede776ff3bf5f3a Mon Sep 17 00:00:00 2001 From: mrava87 Date: Sat, 17 Aug 2024 15:23:34 +0300 Subject: [PATCH] bug: protect use of jnp if jax is not installed --- pylops/utils/backend.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pylops/utils/backend.py b/pylops/utils/backend.py index 4b6b506f..50da7faa 100644 --- a/pylops/utils/backend.py +++ b/pylops/utils/backend.py @@ -128,7 +128,7 @@ def get_array_module(x: npt.ArrayLike) -> ModuleType: """ if deps.cupy_enabled or deps.jax_enabled: - if isinstance(x, jnp.ndarray): + if deps.jax_enabled and isinstance(x, jnp.ndarray): return jnp elif deps.cupy_enabled: return cp.get_array_module(x) @@ -153,7 +153,7 @@ def get_convolve(x: npt.ArrayLike) -> Callable: """ if deps.cupy_enabled or deps.jax_enabled: - if isinstance(x, jnp.ndarray): + if deps.jax_enabled and isinstance(x, jnp.ndarray): return j_convolve elif deps.cupy_enabled and cp.get_array_module(x) == cp: return cp_convolve @@ -178,7 +178,7 @@ def get_fftconvolve(x: npt.ArrayLike) -> Callable: """ if deps.cupy_enabled or deps.jax_enabled: - if isinstance(x, jnp.ndarray): + if deps.jax_enabled and isinstance(x, jnp.ndarray): return j_fftconvolve elif deps.cupy_enabled and cp.get_array_module(x) == cp: return cp_fftconvolve @@ -203,7 +203,7 @@ def get_oaconvolve(x: npt.ArrayLike) -> Callable: """ if deps.cupy_enabled or deps.jax_enabled: - if isinstance(x, jnp.ndarray): + if deps.jax_enabled and isinstance(x, jnp.ndarray): raise NotImplementedError( "oaconvolve not implemented in " "jax. Consider using a different" @@ -232,7 +232,7 @@ def get_correlate(x: npt.ArrayLike) -> Callable: """ if deps.cupy_enabled or deps.jax_enabled: - if isinstance(x, jnp.ndarray): + if deps.jax_enabled and isinstance(x, jnp.ndarray): return jax.scipy.signal.correlate elif deps.cupy_enabled and cp.get_array_module(x) == cp: return cp_correlate @@ -303,7 +303,7 @@ def get_block_diag(x: npt.ArrayLike) -> Callable: """ if deps.cupy_enabled or deps.jax_enabled: - if isinstance(x, jnp.ndarray): + if deps.jax_enabled and isinstance(x, jnp.ndarray): return jnp_block_diag elif deps.cupy_enabled and cp.get_array_module(x) == cp: return cp_block_diag @@ -328,7 +328,7 @@ def get_toeplitz(x: npt.ArrayLike) -> Callable: """ if deps.cupy_enabled or deps.jax_enabled: - if isinstance(x, jnp.ndarray): + if deps.jax_enabled and isinstance(x, jnp.ndarray): return jnp_toeplitz elif deps.cupy_enabled and cp.get_array_module(x) == cp: return cp_toeplitz