diff --git a/pylops/jaxoperator.py b/pylops/jaxoperator.py index bf00a604..5d5c40ed 100644 --- a/pylops/jaxoperator.py +++ b/pylops/jaxoperator.py @@ -18,8 +18,8 @@ 'the jaxoperator module run "pip install jax" or' '"conda install -c conda-forge jax".' ) - jaxarrayin_type = None - jaxarrayout_type = None + jaxarrayin_type = Any + jaxarrayout_type = Any JaxTypeIn = NewType("JaxTypeIn", jaxarrayin_type) JaxTypeOut = NewType("JaxTypeOut", jaxarrayout_type)