Skip to content

Commit

Permalink
Addressed Enriques comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
philip-paul-mueller committed Oct 3, 2024
1 parent 255267d commit 29d8452
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 35 deletions.
28 changes: 16 additions & 12 deletions src/jace/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,31 @@

import functools
from collections.abc import Callable, Mapping
from typing import Any, Literal, ParamSpec, TypedDict, overload
from typing import Any, Final, Literal, ParamSpec, TypedDict, overload

from jax import grad, jacfwd, jacrev
from typing_extensions import Unpack

from jace import stages, translator
from jace import stages, translator, util


__all__ = ["JITOptions", "grad", "jacfwd", "jacrev", "jit"]
__all__ = ["DEFAUL_BACKEND", "JITOptions", "grad", "jacfwd", "jacrev", "jit"]

_P = ParamSpec("_P")

DEFAUL_BACKEND: Final[str] = "cpu"


class JITOptions(TypedDict, total=False):
"""
All known options to `jace.jit` that influence tracing.
Not all arguments that are supported by `jax-jit()` are also supported by
`jace.jit`. Furthermore, some additional ones might be supported.
The following arguments are supported:
- `backend`: For which platform DaCe should generate code for. It is a string,
where the following values are supported: `'cpu'` or `'gpu'`.
DaCe's `DeviceType` enum or FPGA are not supported.
Args:
backend: Target platform for which DaCe should generate code. Supported values
are `'cpu'` or `'gpu'`.
"""

backend: str
Expand Down Expand Up @@ -75,17 +77,18 @@ def jit(
fun: Function to wrap.
primitive_translators: Use these primitive translators for the lowering to SDFG.
If not specified the translators in the global registry are used.
kwargs: Jit arguments, see `JITOptions` for more.
kwargs: JIT arguments, see `JITOptions` for more.
Note:
This function is the only valid way to obtain a JaCe computation.
"""
not_supported_jit_keys = kwargs.keys() - {"backend"}
not_supported_jit_keys = kwargs.keys() - JITOptions.__annotations__.keys()
if not_supported_jit_keys:
# TODO(phimuell): Add proper name verification and exception type.
raise NotImplementedError(
f"The following arguments to 'jace.jit' are not yet supported: {', '.join(not_supported_jit_keys)}."
raise ValueError(
f"The following arguments to 'jace.jit' are not supported: {', '.join(not_supported_jit_keys)}."
)
if kwargs.get("backend", DEFAUL_BACKEND).lower() not in {"cpu", "gpu"}:
raise ValueError(f"The backend '{kwargs['backend']}' is not supported.")

def wrapper(f: Callable[_P, Any]) -> stages.JaCeWrapped[_P]:
jace_wrapper = stages.JaCeWrapped(
Expand All @@ -96,6 +99,7 @@ def wrapper(f: Callable[_P, Any]) -> stages.JaCeWrapped[_P]:
else primitive_translators
),
jit_options=kwargs,
device=util.parse_backend_jit_option(kwargs.get("backend", DEFAUL_BACKEND)),
)
functools.update_wrapper(jace_wrapper, f)
return jace_wrapper
Expand Down
20 changes: 8 additions & 12 deletions src/jace/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,13 @@
DEFAULT_OPTIMIZATIONS: Final[CompilerOptions] = {
"auto_optimize": False,
"simplify": True,
"persistent_transients": True,
"validate": True,
"validate_all": False,
}

NO_OPTIMIZATIONS: Final[CompilerOptions] = {
"auto_optimize": False,
"simplify": False,
"persistent_transients": False,
"validate": True,
"validate_all": False,
}
Expand All @@ -55,7 +53,6 @@ class CompilerOptions(TypedDict, total=False):

auto_optimize: bool
simplify: bool
persistent_transients: bool
validate: bool
validate_all: bool

Expand All @@ -77,22 +74,21 @@ def jace_optimize( # noqa: D417 [undocumented-param] # `kwargs` is not documen
device: The device on which the SDFG will run on.
simplify: Run the simplification pipeline.
auto_optimize: Run the auto optimization pipeline.
persistent_transients: Set the allocation lifetime of (non register) transients
in the SDFG to `AllocationLifetime.Persistent`, i.e. keep them allocated
between different invocations.
validate: Perform validation at the end.
validate_all: Perform extensive validation.
validate: Perform validation of the SDFG at the end.
validate_all: Perform validation after each substep.
Note:
Currently DaCe's auto optimization pipeline is used when auto optimize is
enabled. However, it might change in the future. Because DaCe's auto
optimizer is considered unstable it must be explicitly enabled.
"""
assert device in {dace.DeviceType.CPU, dace.DeviceType.GPU}
simplify = kwargs.get("simplify", False)
auto_optimize = kwargs.get("auto_optimize", False)
validate = kwargs.get("validate", DEFAULT_OPTIMIZATIONS["validate"])
validate_all = kwargs.get("validate_all", DEFAULT_OPTIMIZATIONS["validate_all"])
# If an argument is not specified then we consider it disabled.
kwargs = {**NO_OPTIMIZATIONS, **kwargs}
simplify = kwargs["simplify"]
auto_optimize = kwargs["auto_optimize"]
validate = kwargs["validate"]
validate_all = kwargs["validate_all"]

if simplify:
tsdfg.sdfg.simplify(
Expand Down
13 changes: 8 additions & 5 deletions src/jace/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class JaCeWrapped(tcache.CachingStage["JaCeLowered"], Generic[_P]):
fun: The function that is wrapped.
primitive_translators: Primitive translators that that should be used.
jit_options: Options to influence the jit process.
device: The device on which the SDFG will run on.
Todo:
- Support default values of the wrapped function.
Expand All @@ -102,17 +103,20 @@ class JaCeWrapped(tcache.CachingStage["JaCeLowered"], Generic[_P]):
_fun: Callable[_P, Any]
_primitive_translators: dict[str, translator.PrimitiveTranslator]
_jit_options: api.JITOptions
_device: dace.DeviceType

def __init__(
self,
fun: Callable[_P, Any],
primitive_translators: Mapping[str, translator.PrimitiveTranslator],
jit_options: api.JITOptions,
device: dace.DeviceType,
) -> None:
super().__init__()
self._primitive_translators = {**primitive_translators}
self._jit_options = {**jit_options}
self._fun = fun
self._device = device

def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> Any:
"""
Expand Down Expand Up @@ -165,10 +169,9 @@ def lower(self, *args: _P.args, **kwargs: _P.kwargs) -> JaCeLowered:
trans_ctx: translator.TranslationContext = builder.translate_jaxpr(jaxpr)

flat_call_args = jax_tree.tree_leaves((args, kwargs))
device = util.parse_backend_jit_option(self._jit_options.get("backend", "cpu"))
tsdfg: tjsdfg.TranslatedJaxprSDFG = ptranslation.postprocess_jaxpr_sdfg(
trans_ctx=trans_ctx,
device=device,
device=self._device,
fun=self.wrapped_fun,
flat_call_args=flat_call_args,
)
Expand All @@ -178,7 +181,7 @@ def lower(self, *args: _P.args, **kwargs: _P.kwargs) -> JaCeLowered:
tsdfg=tsdfg,
out_tree=out_tree,
jaxpr=trans_ctx.jaxpr,
device=device,
device=self._device,
)

@property
Expand Down Expand Up @@ -241,13 +244,13 @@ def __init__(
tsdfg: tjsdfg.TranslatedJaxprSDFG,
out_tree: jax_tree.PyTreeDef,
jaxpr: jax_core.ClosedJaxpr,
device: str | dace.DeviceType,
device: dace.DeviceType,
) -> None:
super().__init__()
self._translated_sdfg = tsdfg
self._out_tree = out_tree
self._jaxpr = jaxpr
self._device = util.parse_backend_jit_option(device)
self._device = device

@tcache.cached_transition
def compile(self, compiler_options: CompilerOptions | None = None) -> JaCeCompiled:
Expand Down
10 changes: 6 additions & 4 deletions src/jace/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,12 @@ def make_jaxpr(
# TODO(phimuell): Test if this restriction is needed.
assert all(param.default is param.empty for param in inspect.signature(fun).parameters.values())

# NOTE: `jax.make_jaxpr()` to current tracing backend we use, never supported
# all arguments `jax.jit()` supported, which is strange. But in JaCe this
# _will_ not be the case. To make things work, until we have a better working
# backend we have to filter, i.e. clearing the `trace_options`.
# NOTE: In the current implementation we are using `jax.make_jaxpr()`. But this
# is a different implementation than `jax.jit()` uses. The main difference
# between the two, seems to be the set of arguments that are supported. In JaCe,
# however, we want to support all arguments that `jace.jit()` does.
# For establishing compatibility we have to clear the arguments to make them
# compatible, with what `jax.make_jaxpr()` and `jace.jit()` supports.
trace_options = {}

def tracer_impl(
Expand Down
2 changes: 1 addition & 1 deletion src/jace/translator/post_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def _create_input_state(
The function will create a new set of variables that are exposed as inputs. This
variables are based on the example input arguments passed through `flat_call_args`.
This process will hard code the memory location, i.e. if the input is on the GPU,
then the new input will be don the GPU as well and strides into the SDFG.
then the new input will be on the GPU as well and strides into the SDFG.
The assignment is performed inside a new state, which is put at the beginning.
Args:
Expand Down
1 change: 0 additions & 1 deletion src/jace/util/jax_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,5 +301,4 @@ def move_into_jax_array(
"""
if isinstance(arr, jax.Array):
return arr
# In newer version it is no longer needed to pass a capsule.
return jax_dlpack.from_dlpack(arr, copy=copy) # type: ignore[attr-defined] # `from_dlpack` is not found.

0 comments on commit 29d8452

Please sign in to comment.