Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added Auto Opt, GPU and jax.Array #26

Merged

Conversation

philip-paul-mueller
Copy link
Contributor

This commit addresses the following issues:

  • It adds an implementation for auto optimizer (Add auto_optimize to the options #24).
    The current implementation is not likely to stay, since it essentially uses DaCe's version, which is known to have problems with JaCe's SDFGs.
  • It allows to run stuff on GPU (Add GPU Capabilities. #25).
    While it is possible it is still needed that the user explicitly specify it, JAX does an auto detection.
  • Instead of returning NumPy arrays JaCe now returns jax.Array objects (Return jax.Array #22).
    This goes in tandem with a reworking of the type annotation, which was wrong before (and can not be correctly made).

The auto optimizer is disabled and must be enabled explicitly.
This will be the behaviour will be maintained at least until the map fusion PR is merged.

Essentially this commit addresses issues:
- [issue#24](GridTools#24)
- [issue#25](GridTools#25)
As explained elsewhere the annotation was wrong.
Simple example:
```python
@jace.jit
def foo(a: np.ndarray) -> np.float64:
	return np.float64(a.sum())
```
The actual return value would be `np.ndarray`, although one dimension.
This limitation is also present in JAX.
@codecov-commenter
Copy link

codecov-commenter commented Oct 2, 2024

⚠️ Please install the 'codecov app svg image' to ensure uploads and comments are reliably processed by Codecov.

Codecov Report

Attention: Patch coverage is 68.91892% with 23 lines in your changes missing coverage. Please review.

Please upload report for BASE (main@0a9f361). Learn more about missing BASE report.

Files with missing lines Patch % Lines
src/jace/util/jax_helper.py 33.33% 9 Missing and 3 partials ⚠️
src/jace/api.py 63.63% 2 Missing and 2 partials ⚠️
src/jace/optimization.py 76.47% 2 Missing and 2 partials ⚠️
src/jace/__init__.py 33.33% 1 Missing and 1 partial ⚠️
src/jace/stages.py 94.11% 1 Missing ⚠️

❗ Your organization needs to install the Codecov GitHub app to enable full functionality.

Additional details and impacted files
@@           Coverage Diff           @@
##             main      #26   +/-   ##
=======================================
  Coverage        ?   68.36%           
=======================================
  Files           ?       31           
  Lines           ?     1277           
  Branches        ?      261           
=======================================
  Hits            ?      873           
  Misses          ?      329           
  Partials        ?       75           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Collaborator

@egparedes egparedes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, I have some minor comments. (BTW, Iit is not by any means "a super small PR")

src/jace/api.py Outdated Show resolved Hide resolved
src/jace/api.py Outdated
"""

backend: str
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
backend: str
backend: Literal['cpu', 'gpu']

Copy link
Contributor Author

@philip-paul-mueller philip-paul-mueller Oct 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not think that enforcing this through the annotations and MyPy is the right way.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not? What's the use case for the Literal type annotation then? Using Literal here also helps in documenting the supported values in a more prominent way.

Copy link
Contributor Author

@philip-paul-mueller philip-paul-mueller Oct 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First a practical reason, if you read the backend string say from a file then you will get a str, so you will have to cast it.
It also does not really helps documenting since in general it does not explain what a particular value mean or does, you still have to read the doc anyway.

src/jace/api.py Outdated
@@ -72,18 +75,19 @@ def jit(
fun: Function to wrap.
primitive_translators: Use these primitive translators for the lowering to SDFG.
If not specified the translators in the global registry are used.
kwargs: Jit arguments.
kwargs: Jit arguments, see `JITOptions` for more.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
kwargs: Jit arguments, see `JITOptions` for more.
kwargs: jit arguments, see `JITOptions` for more.

Copy link
Contributor Author

@philip-paul-mueller philip-paul-mueller Oct 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should be consistent and write JIT capital or Jit as it is the abbreviation of "Just in time" and we are at the beginning of a sentence.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather use JIT then but it's up to you...

src/jace/api.py Outdated Show resolved Hide resolved
src/jace/api.py Outdated Show resolved Hide resolved
if isinstance(arr, jax.Array):
return arr
# In newer version it is no longer needed to pass a capsule.
return jax_dlpack.from_dlpack(arr, copy=copy) # type: ignore[attr-defined] # `from_dlpack` is not found.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this type ignore needed if jax is included in our lenient mypy settings (and also the function should actually exist)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have no idea.

Comment on lines 33 to 38
try:
import cupy as cp # type: ignore[import-not-found]
except ImportError:
cp = None


Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this needed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In case you want to use JaCe one a CPU only machine as the CI.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant, why is cupy needed here? I can't see where it is used in this module.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before the update of the newest JAX we had to test for the array type, since for cupy we have to call .toDLPack() but for numpy it was necessary to call .__dlpack__().
I simply forgot that.

Comment on lines 88 to 91
# NOTE: `jax.make_jaxpr()` to current tracing backend we use, never supported
# all arguments `jax.jit()` supported, which is strange. But in JaCe this
# _will_ not be the case. To make things work, until we have a better working
# backend we have to filter, i.e. clearing the `trace_options`.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check and rewrite, I don't understand it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right it was written very strangely.

Comment on lines +105 to +106
# For some reasons MyPy seems to think that `jax.make_jaxpr()` is the same
# as `jace.make_jaxpr()` so we have to ignore the error.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds suspicious...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know, but I have no idea why.

src/jace/translator/post_translation.py Outdated Show resolved Hide resolved
src/jace/api.py Outdated Show resolved Hide resolved
src/jace/api.py Outdated Show resolved Hide resolved
src/jace/api.py Outdated
"""

backend: str
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not? What's the use case for the Literal type annotation then? Using Literal here also helps in documenting the supported values in a more prominent way.

Copy link
Collaborator

@egparedes egparedes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I pressed the send button by mistake with only half of my comments. Here is the second half.

src/jace/tracing.py Outdated Show resolved Hide resolved
src/jace/tracing.py Outdated Show resolved Hide resolved
Comment on lines 33 to 38
try:
import cupy as cp # type: ignore[import-not-found]
except ImportError:
cp = None


Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant, why is cupy needed here? I can't see where it is used in this module.

src/jace/util/jax_helper.py Show resolved Hide resolved
Copy link
Collaborator

@egparedes egparedes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. I have just a comment about an error message but I don't need to review it again so I'm already approving the PR now.

case "tpu" | "TPU":
raise NotImplementedError("TPU are not supported.")
case _:
raise ValueError(f"Could not parse the backend '{backend}'.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rewrite the error message to reflect the new name of the function.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are absolutely right.

@philip-paul-mueller philip-paul-mueller merged commit 3be9f36 into GridTools:main Oct 4, 2024
4 checks passed
@philip-paul-mueller philip-paul-mueller deleted the auto_opt_and_gpu branch October 7, 2024 13:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants