diff --git a/pyproject.toml b/pyproject.toml index 3149170..eccef87 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -103,6 +103,7 @@ module = [ "dace.*", "jax.*", "jaxlib.*", + "cupy.", ] # -- pytest -- diff --git a/src/jace/translated_jaxpr_sdfg.py b/src/jace/translated_jaxpr_sdfg.py index 9cb9908..1f22c6c 100644 --- a/src/jace/translated_jaxpr_sdfg.py +++ b/src/jace/translated_jaxpr_sdfg.py @@ -22,7 +22,7 @@ if TYPE_CHECKING: - import numpy as np + import jax from dace.codegen import compiled_sdfg as dace_csdfg @@ -139,7 +139,7 @@ def sdfg(self) -> dace.SDFG: # noqa: D102 [undocumented-public-method] def __call__( self, flat_call_args: Sequence[Any], - ) -> list[np.ndarray]: + ) -> list[jax.Array]: """ Run the compiled SDFG using the flattened input. @@ -178,7 +178,10 @@ def __call__( dace.Config.set("compiler", "allow_view_arguments", value=True) self.compiled_sdfg(**sdfg_call_args) - return [sdfg_call_args[output_name] for output_name in self.output_names] + return [ + util.move_into_jax_array(sdfg_call_args[output_name]) + for output_name in self.output_names + ] def compile_jaxpr_sdfg(tsdfg: TranslatedJaxprSDFG) -> dace_csdfg.CompiledJaxprSDFG: diff --git a/src/jace/util/__init__.py b/src/jace/util/__init__.py index 6d4079d..81d2f99 100644 --- a/src/jace/util/__init__.py +++ b/src/jace/util/__init__.py @@ -17,6 +17,7 @@ get_jax_var_name, get_jax_var_shape, is_tracing_ongoing, + move_into_jax_array, parse_backend_jit_option, propose_jax_name, translate_dtype, @@ -51,6 +52,7 @@ "is_on_device", "is_scalar", "is_tracing_ongoing", + "move_into_jax_array", "parse_backend_jit_option", "propose_jax_name", "translate_dtype", diff --git a/src/jace/util/jax_helper.py b/src/jace/util/jax_helper.py index 2f95709..f02978e 100644 --- a/src/jace/util/jax_helper.py +++ b/src/jace/util/jax_helper.py @@ -21,7 +21,7 @@ import dace import jax -import jax.core as jax_core +from jax import core as jax_core, dlpack as jax_dlpack from jace import util @@ -30,6 +30,12 @@ import numpy as np +try: + import cupy as cp # type: ignore[import-not-found] +except ImportError: + cp = None + + @dataclasses.dataclass(repr=True, frozen=True, eq=False) class JaCeVar: """ @@ -277,3 +283,23 @@ def parse_backend_jit_option( raise NotImplementedError("TPU are not supported.") case _: raise ValueError(f"Could not parse the backend '{backend}'.") + + +def move_into_jax_array( + arr: Any, + copy: bool | None = False, +) -> jax.Array: + """ + Moves `arr` into a JAX array using DLPack format. + + By default `copy` is set to `False`, it is the responsibility of the caller + to ensure that the underlying buffer is not modified later. + + Args: + arr: The array to move into a JAX array. + copy: Should a copy be made; defaults to `False`. + """ + if isinstance(arr, jax.Array): + return arr + # In newer version it is no longer needed to pass a capsule. + return jax_dlpack.from_dlpack(arr, copy=copy) # type: ignore[attr-defined] # `from_dlpack` is not found.