Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Modular Calling #28

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 68 additions & 13 deletions src/jace/translated_jaxpr_sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,15 +136,18 @@ class CompiledJaxprSDFG:
def sdfg(self) -> dace.SDFG: # noqa: D102 [undocumented-public-method]
return self.compiled_sdfg.sdfg

def __call__(
def _construct_csdfg_args(
self,
flat_call_args: Sequence[Any],
) -> list[jax.Array]:
) -> dict[str, Any]:
"""
Run the compiled SDFG using the flattened input.
Create the calling arguments from `flat_call_args`.

The function will not perform flattening of its input nor unflattening of
the output.
The function will collect the already flattened arguments into a `dict`.
Furthermore, it will allocate the buffers that are used for the return values
and add them to the `dict` as well.
The `dict` can then be passed to `self._call_csdfg()` to invoke the compiled
SDFG.

Args:
flat_call_args: Flattened input arguments.
Expand All @@ -154,42 +157,94 @@ def __call__(
f"Expected {len(self.input_names)} flattened arguments, but got {len(flat_call_args)}."
)

sdfg_call_args: dict[str, Any] = {}
csdfg_call_args: dict[str, Any] = {}
for in_name, in_val in zip(self.input_names, flat_call_args):
# TODO(phimuell): Implement a stride matching process.
if util.is_jax_array(in_val):
if not util.is_fully_addressable(in_val):
raise ValueError(f"Passed a not fully addressable JAX array as '{in_name}'")
in_val = in_val.__array__() # noqa: PLW2901 [redefined-loop-name] # JAX arrays do not expose the __array_interface__.
sdfg_call_args[in_name] = in_val
csdfg_call_args[in_name] = in_val

# Allocate the output arrays.
# In DaCe the output arrays are created by the `CompiledSDFG` calls and all
# calls share the same arrays. In JaCe the output arrays are distinct.
arrays = self.sdfg.arrays
for output_name in self.output_names:
sdfg_call_args[output_name] = dace_data.make_array_from_descriptor(arrays[output_name])
csdfg_call_args[output_name] = dace_data.make_array_from_descriptor(arrays[output_name])

assert len(sdfg_call_args) == len(self.compiled_sdfg.argnames), (
assert len(csdfg_call_args) == len(self.compiled_sdfg.argnames), (
"Failed to construct the call arguments,"
f" expected {len(self.compiled_sdfg.argnames)} but got {len(flat_call_args)}."
f"\nExpected: {self.compiled_sdfg.argnames}\nGot: {list(sdfg_call_args.keys())}"
f"\nExpected: {self.compiled_sdfg.argnames}\nGot: {list(csdfg_call_args.keys())}"
)
return csdfg_call_args

def _call_csdfg(
self,
csdfg_call_args: dict[str, Any],
) -> None:
"""
Calls the underlying SDFG with the data in `csdfg_call_args`.

This will forward the arguments directly to the compiled SDFG object.
See `self._construct_csdfg_args()` for how to construct the `dict`
and `self._extract_return_values()` for how to get the output values back.

Args:
csdfg_call_args: The required arguments to call the compiled sdfg.
"""
assert len(csdfg_call_args) == len(self.compiled_sdfg.argnames)

# Calling the SDFG
with dace.config.temporary_config():
dace.Config.set("compiler", "allow_view_arguments", value=True)
self.compiled_sdfg(**sdfg_call_args)
self.compiled_sdfg(**csdfg_call_args)

def _extract_return_values(
self,
csdfg_call_args: dict[str, Any],
) -> list[jax.Array]:
"""
Extract the return values and return the flattened version.

JaCe allocates the buffer for the return value outside the SDFG and passes
them as arguments, see `self._construct_csdfg_args()` and `self._call_csdfg()`.
This function will extract these values and return them in the flattened order.
Furthermore, the buffer will be transferred to a `jax.Array` object.

Args:
csdfg_call_args: Collection of the arguments passed to the compiled SDFG.

Note:
After this function returns accessing any element in `csdfg_call_args`
is undefined behaviour.
"""
# DaCe writes the results either into CuPy or NumPy arrays. For compatibility
# with JAX we will now turn them into `jax.Array`s. Note that this is safe
# because we created these arrays in this function explicitly. Thus when
# this function ends, there is no writable reference to these arrays left.
return [
util.move_into_jax_array(sdfg_call_args[output_name])
util.move_into_jax_array(csdfg_call_args[output_name])
for output_name in self.output_names
]

def __call__(
self,
flat_call_args: Sequence[Any],
) -> list[jax.Array]:
"""
Run the compiled SDFG using the flattened input.

The function will not perform flattening of its input nor unflattening of
the output.

Args:
flat_call_args: Flattened input arguments.
"""
csdfg_call_args = self._construct_csdfg_args(flat_call_args)
self._call_csdfg(csdfg_call_args)
return self._extract_return_values(csdfg_call_args)


def compile_jaxpr_sdfg(tsdfg: TranslatedJaxprSDFG) -> dace_csdfg.CompiledJaxprSDFG:
"""Compile `tsdfg` and return a `CompiledJaxprSDFG` object with the result."""
Expand Down