Skip to content

Commit

Permalink
Merge pull request #607 from mrava87/patch-jaximport
Browse files Browse the repository at this point in the history
bug: protect use of jnp if jax is not installed
  • Loading branch information
mrava87 authored Aug 17, 2024
2 parents e111225 + b8d2c85 commit 8039762
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions pylops/utils/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def get_array_module(x: npt.ArrayLike) -> ModuleType:
"""
if deps.cupy_enabled or deps.jax_enabled:
if isinstance(x, jnp.ndarray):
if deps.jax_enabled and isinstance(x, jnp.ndarray):
return jnp
elif deps.cupy_enabled:
return cp.get_array_module(x)
Expand All @@ -153,7 +153,7 @@ def get_convolve(x: npt.ArrayLike) -> Callable:
"""
if deps.cupy_enabled or deps.jax_enabled:
if isinstance(x, jnp.ndarray):
if deps.jax_enabled and isinstance(x, jnp.ndarray):
return j_convolve
elif deps.cupy_enabled and cp.get_array_module(x) == cp:
return cp_convolve
Expand All @@ -178,7 +178,7 @@ def get_fftconvolve(x: npt.ArrayLike) -> Callable:
"""
if deps.cupy_enabled or deps.jax_enabled:
if isinstance(x, jnp.ndarray):
if deps.jax_enabled and isinstance(x, jnp.ndarray):
return j_fftconvolve
elif deps.cupy_enabled and cp.get_array_module(x) == cp:
return cp_fftconvolve
Expand All @@ -203,7 +203,7 @@ def get_oaconvolve(x: npt.ArrayLike) -> Callable:
"""
if deps.cupy_enabled or deps.jax_enabled:
if isinstance(x, jnp.ndarray):
if deps.jax_enabled and isinstance(x, jnp.ndarray):
raise NotImplementedError(
"oaconvolve not implemented in "
"jax. Consider using a different"
Expand Down Expand Up @@ -232,7 +232,7 @@ def get_correlate(x: npt.ArrayLike) -> Callable:
"""
if deps.cupy_enabled or deps.jax_enabled:
if isinstance(x, jnp.ndarray):
if deps.jax_enabled and isinstance(x, jnp.ndarray):
return jax.scipy.signal.correlate
elif deps.cupy_enabled and cp.get_array_module(x) == cp:
return cp_correlate
Expand Down Expand Up @@ -303,7 +303,7 @@ def get_block_diag(x: npt.ArrayLike) -> Callable:
"""
if deps.cupy_enabled or deps.jax_enabled:
if isinstance(x, jnp.ndarray):
if deps.jax_enabled and isinstance(x, jnp.ndarray):
return jnp_block_diag
elif deps.cupy_enabled and cp.get_array_module(x) == cp:
return cp_block_diag
Expand All @@ -328,7 +328,7 @@ def get_toeplitz(x: npt.ArrayLike) -> Callable:
"""
if deps.cupy_enabled or deps.jax_enabled:
if isinstance(x, jnp.ndarray):
if deps.jax_enabled and isinstance(x, jnp.ndarray):
return jnp_toeplitz
elif deps.cupy_enabled and cp.get_array_module(x) == cp:
return cp_toeplitz
Expand Down

0 comments on commit 8039762

Please sign in to comment.