-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Added Auto Opt, GPU and jax.Array
#26
feat: Added Auto Opt, GPU and jax.Array
#26
Conversation
The auto optimizer is disabled and must be enabled explicitly. This will be the behaviour will be maintained at least until the map fusion PR is merged. Essentially this commit addresses issues: - [issue#24](GridTools#24) - [issue#25](GridTools#25)
As explained elsewhere the annotation was wrong. Simple example: ```python @jace.jit def foo(a: np.ndarray) -> np.float64: return np.float64(a.sum()) ``` The actual return value would be `np.ndarray`, although one dimension. This limitation is also present in JAX.
This addresses [issue#22](GridTools#22).
Codecov ReportAttention: Patch coverage is
❗ Your organization needs to install the Codecov GitHub app to enable full functionality. Additional details and impacted files@@ Coverage Diff @@
## main #26 +/- ##
=======================================
Coverage ? 68.36%
=======================================
Files ? 31
Lines ? 1277
Branches ? 261
=======================================
Hits ? 873
Misses ? 329
Partials ? 75 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, I have some minor comments. (BTW, Iit is not by any means "a super small PR")
src/jace/api.py
Outdated
""" | ||
|
||
backend: str |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
backend: str | |
backend: Literal['cpu', 'gpu'] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I do not think that enforcing this through the annotations and MyPy is the right way.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not? What's the use case for the Literal
type annotation then? Using Literal
here also helps in documenting the supported values in a more prominent way.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
First a practical reason, if you read the backend string say from a file then you will get a str
, so you will have to cast it.
It also does not really helps documenting since in general it does not explain what a particular value mean or does, you still have to read the doc anyway.
src/jace/api.py
Outdated
@@ -72,18 +75,19 @@ def jit( | |||
fun: Function to wrap. | |||
primitive_translators: Use these primitive translators for the lowering to SDFG. | |||
If not specified the translators in the global registry are used. | |||
kwargs: Jit arguments. | |||
kwargs: Jit arguments, see `JITOptions` for more. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
kwargs: Jit arguments, see `JITOptions` for more. | |
kwargs: jit arguments, see `JITOptions` for more. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should be consistent and write JIT capital or Jit
as it is the abbreviation of "Just in time" and we are at the beginning of a sentence.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd rather use JIT
then but it's up to you...
if isinstance(arr, jax.Array): | ||
return arr | ||
# In newer version it is no longer needed to pass a capsule. | ||
return jax_dlpack.from_dlpack(arr, copy=copy) # type: ignore[attr-defined] # `from_dlpack` is not found. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this type ignore needed if jax
is included in our lenient mypy settings (and also the function should actually exist)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have no idea.
src/jace/util/jax_helper.py
Outdated
try: | ||
import cupy as cp # type: ignore[import-not-found] | ||
except ImportError: | ||
cp = None | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this needed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In case you want to use JaCe one a CPU only machine as the CI.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I meant, why is cupy needed here? I can't see where it is used in this module.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Before the update of the newest JAX we had to test for the array type, since for cupy we have to call .toDLPack()
but for numpy it was necessary to call .__dlpack__()
.
I simply forgot that.
src/jace/tracing.py
Outdated
# NOTE: `jax.make_jaxpr()` to current tracing backend we use, never supported | ||
# all arguments `jax.jit()` supported, which is strange. But in JaCe this | ||
# _will_ not be the case. To make things work, until we have a better working | ||
# backend we have to filter, i.e. clearing the `trace_options`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Check and rewrite, I don't understand it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You are right it was written very strangely.
# For some reasons MyPy seems to think that `jax.make_jaxpr()` is the same | ||
# as `jace.make_jaxpr()` so we have to ignore the error. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds suspicious...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I know, but I have no idea why.
src/jace/api.py
Outdated
""" | ||
|
||
backend: str |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not? What's the use case for the Literal
type annotation then? Using Literal
here also helps in documenting the supported values in a more prominent way.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I pressed the send
button by mistake with only half of my comments. Here is the second half.
src/jace/util/jax_helper.py
Outdated
try: | ||
import cupy as cp # type: ignore[import-not-found] | ||
except ImportError: | ||
cp = None | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I meant, why is cupy needed here? I can't see where it is used in this module.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. I have just a comment about an error message but I don't need to review it again so I'm already approving the PR now.
src/jace/util/jax_helper.py
Outdated
case "tpu" | "TPU": | ||
raise NotImplementedError("TPU are not supported.") | ||
case _: | ||
raise ValueError(f"Could not parse the backend '{backend}'.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rewrite the error message to reflect the new name of the function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You are absolutely right.
This commit addresses the following issues:
auto_optimize
to the options #24).The current implementation is not likely to stay, since it essentially uses DaCe's version, which is known to have problems with JaCe's SDFGs.
While it is possible it is still needed that the user explicitly specify it, JAX does an auto detection.
jax.Array
objects (Returnjax.Array
#22).This goes in tandem with a reworking of the type annotation, which was wrong before (and can not be correctly made).