Skip to content

Commit

Permalink
Now JaCe uses the jax.Array as return type.
Browse files Browse the repository at this point in the history
This addresses [issue#22](GridTools#22).
  • Loading branch information
philip-paul-mueller committed Oct 2, 2024
1 parent e2aa76f commit 255267d
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 4 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ module = [
"dace.*",
"jax.*",
"jaxlib.*",
"cupy.",
]

# -- pytest --
Expand Down
9 changes: 6 additions & 3 deletions src/jace/translated_jaxpr_sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@


if TYPE_CHECKING:
import numpy as np
import jax
from dace.codegen import compiled_sdfg as dace_csdfg


Expand Down Expand Up @@ -139,7 +139,7 @@ def sdfg(self) -> dace.SDFG: # noqa: D102 [undocumented-public-method]
def __call__(
self,
flat_call_args: Sequence[Any],
) -> list[np.ndarray]:
) -> list[jax.Array]:
"""
Run the compiled SDFG using the flattened input.
Expand Down Expand Up @@ -178,7 +178,10 @@ def __call__(
dace.Config.set("compiler", "allow_view_arguments", value=True)
self.compiled_sdfg(**sdfg_call_args)

return [sdfg_call_args[output_name] for output_name in self.output_names]
return [
util.move_into_jax_array(sdfg_call_args[output_name])
for output_name in self.output_names
]


def compile_jaxpr_sdfg(tsdfg: TranslatedJaxprSDFG) -> dace_csdfg.CompiledJaxprSDFG:
Expand Down
2 changes: 2 additions & 0 deletions src/jace/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
get_jax_var_name,
get_jax_var_shape,
is_tracing_ongoing,
move_into_jax_array,
parse_backend_jit_option,
propose_jax_name,
translate_dtype,
Expand Down Expand Up @@ -51,6 +52,7 @@
"is_on_device",
"is_scalar",
"is_tracing_ongoing",
"move_into_jax_array",
"parse_backend_jit_option",
"propose_jax_name",
"translate_dtype",
Expand Down
28 changes: 27 additions & 1 deletion src/jace/util/jax_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import dace
import jax
import jax.core as jax_core
from jax import core as jax_core, dlpack as jax_dlpack

from jace import util

Expand All @@ -30,6 +30,12 @@
import numpy as np


try:
import cupy as cp # type: ignore[import-not-found]
except ImportError:
cp = None


@dataclasses.dataclass(repr=True, frozen=True, eq=False)
class JaCeVar:
"""
Expand Down Expand Up @@ -277,3 +283,23 @@ def parse_backend_jit_option(
raise NotImplementedError("TPU are not supported.")
case _:
raise ValueError(f"Could not parse the backend '{backend}'.")


def move_into_jax_array(
arr: Any,
copy: bool | None = False,
) -> jax.Array:
"""
Moves `arr` into a JAX array using DLPack format.
By default `copy` is set to `False`, it is the responsibility of the caller
to ensure that the underlying buffer is not modified later.
Args:
arr: The array to move into a JAX array.
copy: Should a copy be made; defaults to `False`.
"""
if isinstance(arr, jax.Array):
return arr
# In newer version it is no longer needed to pass a capsule.
return jax_dlpack.from_dlpack(arr, copy=copy) # type: ignore[attr-defined] # `from_dlpack` is not found.

0 comments on commit 255267d

Please sign in to comment.