Skip to content

Commit

Permalink
minor: one more fix of typing of jax
Browse files Browse the repository at this point in the history
  • Loading branch information
mrava87 committed Aug 9, 2024
1 parent 29a4dc2 commit 310b2a9
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions pylops/jaxoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,19 @@
if deps.jax_enabled:
import jax

jaxarray_type = jax.Array
jaxarrayin_type = jax.typing.ArrayLike
jaxarrayout_type = jax.Array
else:
jax_message = (
"JAX package not installed. In order to be able to use"
'the jaxoperator module run "pip install jax" or'
'"conda install -c conda-forge jax".'
)
jaxarray_type = Any
jaxarrayin_type = None
jaxarrayout_type = None

JaxType = NewType("JaxType", jaxarray_type)
JaxTypeIn = NewType("JaxTypeIn", jaxarrayin_type)
JaxTypeOut = NewType("JaxTypeOut", jaxarrayout_type)


class JaxOperator(LinearOperator):
Expand Down Expand Up @@ -57,12 +60,12 @@ def __init__(self, Op: LinearOperator) -> None:
def __call__(self, x, *args, **kwargs):
return self._matvec(x)

def _rmatvecad(self, x: JaxType, y: JaxType) -> JaxType:
def _rmatvecad(self, x: JaxTypeIn, y: JaxTypeIn) -> JaxTypeOut:
_, f_vjp = jax.vjp(self._matvec, x)
xadj = jax.jit(f_vjp)(y)[0]
return xadj

def rmatvecad(self, x: JaxType, y: JaxType) -> JaxType:
def rmatvecad(self, x: JaxTypeIn, y: JaxTypeIn) -> JaxTypeOut:
"""Vector-Jacobian product
JIT-compiled Vector-Jacobian product
Expand All @@ -76,7 +79,7 @@ def rmatvecad(self, x: JaxType, y: JaxType) -> JaxType:
Returns
-------
xadj : :obj:`jax.Array`
xadj : :obj:`jax.typing.ArrayLike`
Output array
"""
Expand Down

0 comments on commit 310b2a9

Please sign in to comment.