Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: Added an Extensive Set of Tests #21

Open
wants to merge 512 commits into
base: main
Choose a base branch
from

Conversation

philip-paul-mueller
Copy link
Contributor

@philip-paul-mueller philip-paul-mueller commented Sep 24, 2024

This PR adds a large set of tests for various translators, especially the ones that where added by PR#20 but also other parts of the code are tested.
This is a very basic PR that does not add any functionality, except tests for the one that is already there.

This PR can only be merged after PR#20 has been merged.
Furthermore, for certain reasons this PR contains the whole development history of PR#20.

…o tests for them yet.

`device_put` is actually a very powerfull operation, such as Memlets, where source and destination are on different devices.
However, we do not support something like that yet, so we will keep it down yet.
It is not yet dynamic sclice, but soon.
…n value.

This restirction was removed commit `05e4a885441c0cd`.
However, I realized that it made some higher level translators impossible to write, such as `select_n`.
Thus I removed this restriction again.

Another solution would be to add another layer.
I also observe random failures for the `test_iota_broadcast()` test if I run all tests.
However, if I only run it, then nothing happens, I have no idea why.
This should bring the development branch to the newest stage.
Jax adjust the start indexes if the window overruns, however, this is not done, instead an out of bound error happens.
During that work I also detected some [issue](spcl/dace#1579) in DaCe's simplification pipeline.
The cleaning was not correct.
The function essentially created some deatached caches.
The cleaning was not correct.
The function essentially created some deatached caches.
I just copied them and did not do a merge, which is not so nice.
Furthermore, the tests are not yet there, in my view it makes sense to first have something that can be checked.
Before it was implementewd as a switch, for the case JAX would use the bool overload of XLA.
The check for this was now essentially moved inside the function.
However, a similar issue (#1644) in DaCe is still open.
It is now better confiugured.
Now let's test if it works.
This is basically for testing them.
Before the `order` argument was a `Literal` but this caused more truble now it is a string.
I enabled the simplify pass in commit `411bd7bd` and it worked locally.
However, this was because I was not running it inside nox and using my own version of DaCe.
The bug in simplify was fixed in [PR#1603](spcl/dace#1603) which was merged _after_ 16.1 was released, thus the fix is not avaliable.
It seems JAX has updated the `make_jaxpr()` function and now that thing caches itself.
This is now accounted for.
But the new function really tests if the loweing works, if teh strides are honored and infered.
philip-paul-mueller added a commit that referenced this pull request Sep 26, 2024
This PR introduces a series of primitive translators, most of them are
based on the prototype, with some improvements.
I just copied them over from the development branch, which is not so
nice, but was the simplest thing to to without also introducing the
other stuff.

It is important that the tests from the development branch were not
added, to keep the PR small.
Furthermore, we need something to test, so this PR must go first.

For organizational reasons, the development history of this PR happened
to be contained in [PR#21](#21).

---------

Co-authored-by: Enrique González Paredes <[email protected]>
@codecov-commenter
Copy link

codecov-commenter commented Oct 1, 2024

⚠️ Please install the 'codecov app svg image' to ensure uploads and comments are reliably processed by Codecov.

Codecov Report

All modified and coverable lines are covered by tests ✅

Please upload report for BASE (main@0a9f361). Learn more about missing BASE report.

❗ Your organization needs to install the Codecov GitHub app to enable full functionality.

Additional details and impacted files
@@           Coverage Diff           @@
##             main      #21   +/-   ##
=======================================
  Coverage        ?   88.66%           
=======================================
  Files           ?       31           
  Lines           ?     1235           
  Branches        ?      251           
=======================================
  Hits            ?     1095           
  Misses          ?       82           
  Partials        ?       58           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Collaborator

@egparedes egparedes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First round of review only looking at the test infrastructure.


- JaCe always traces with enabled `x64` mode.
This is a restriction that might be lifted in the future.
- JAX returns scalars as zero-dimensional arrays, JaCe returns them as array with shape `(1, )`.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why?

Copy link
Contributor Author

@philip-paul-mueller philip-paul-mueller Oct 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main issue is that DaCe does not have a real concept of zero dimensional arrays, as far as I know.
Consider the following two functions

@dace.program
def bar(a: dace.float64):
    return a + 1

@dace.program
def baz(a: dace.float64[1]):
    return a + 1

if you pass a zero dimensional array to bar() then it will be casted to an scalar, if you pass it to baz() an error will happen.
Furthermore, the binary interface of the SDFG can not return scalars, return values have to be arrays there is no way around that without patching the code generator and making a lot of changes to handle special cases.
So I decided to follow PEP20 and decided that this case is not special enough to change the rule.
If you want this feature then please open an issue.

- JaCe always traces with enabled `x64` mode.
This is a restriction that might be lifted in the future.
- JAX returns scalars as zero-dimensional arrays, JaCe returns them as array with shape `(1, )`.
- In JAX parts of the computation runs on CPU parts on GPU, in JaCe everything runs (currently) either on CPU or GPU.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean here? Which parts run on CPU/GPU in JAX?

Copy link
Contributor Author

@philip-paul-mueller philip-paul-mueller Oct 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The JAX compiler, i.e. XLA can decide to do this.
The question is, if it really does it.

Comment on lines +15 to +16
- JaCe does not return `jax.Array` instances, but NumPy/CuPy arrays.
- The execution is not asynchronous.
Copy link
Collaborator

@egparedes egparedes Sep 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These two points could be also fixed in the future, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move fixtures from conftest.py here (in which case should be renamed to commom_fixtures.py) or remove this empty file.

from jace import optimization, stages
from jace.util import translation_cache as tcache


Copy link
Collaborator

@egparedes egparedes Sep 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For all fixtures in this module:

Additionally, I would define another fixture requesting all other fixtures expected in the standard case and use this group fixtures in the pytestmark at module level. Example:

# This file
@pytest.fixture
def standard_jace_test_settings(enable_x64_mode_in_jax, disable_jit, ...) -> ...:
     ....

# Other test files
pytestmark = pytest.mark.usefixtures("standard_jace_test_settings")

Finally, I'd create a simpler type alias for the return type of generator fixtures as suggested here :

T = TypeVar("T")
YieldFixture = Generator[T, None, None]

@pytest.fixture
def foo() -> YieldFixture[str]:
    yield "foo"

def make_array(
shape: Sequence[int] | int,
dtype: type = np.float64,
order: str = "C",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
order: str = "C",
order: Literal["C", "F"] = "C",

Copy link
Contributor Author

@philip-paul-mueller philip-paul-mueller Oct 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was the original, however, it has serious problems.
Whenever the order came from an argument (fixture to test if it works in both C and F order) then MyPy complained, and I had to add brainless # type: ignore[call-overload] marks (see commit 090c3a2).
I do not think that cluttering the code with these annotations just to get some tiny bit of type security is not worth it.

Comment on lines +27 to +28
low: Any = None,
high: Any = None,
Copy link
Collaborator

@egparedes egparedes Oct 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not?

Suggested change
low: Any = None,
high: Any = None,
low: int | float | np.number = 0,
high: int | float | np.number = 1,

Copy link
Contributor Author

@philip-paul-mueller philip-paul-mueller Oct 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yours is a bit better.
But it must be still accept None, otherwise you will only generate 0 or 1 in the integer case.
Furthermore, technically it must also accept complex, but they are an edge case that we ignore.



def make_array(
shape: Sequence[int] | int,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
shape: Sequence[int] | int,
shape: Sequence[int],

This is optional, but I would stick to one single way to define the shape to increase the readability of the function and help the type checker to catch errors here. In my opinion, there is not too much value in providing automatic tuple conversion for ints.

Copy link
Contributor Author

@philip-paul-mueller philip-paul-mueller Oct 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do think that this change would increase readability.
For my these ..., shape=(some_scalar_variable,), ... expression are much less readable than ..., shape=some_scalar_variable, ....
This typedef allows you to select the style that is most appropriate to the function.
Thus you do not have to replicate the "logic" that distinguish scalars from tuples everywhere you use this thing, but only at one central place.
Furthermore, it is the NumPy behaviour.

__all__ = ["make_array"]


def make_array(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a question: this utility function is dealing with the generation of random array values with a certain shape, dtype and range, but what about packing them in the the correct ndarray type (numpy, cupy, Jax)? Isn't it needed ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At the end of the day, DaCe generates C code that operates on pointers to memory regions of a certain size.
Thus in the majority of cases you care about the data type and the size of the memory and not about the type of container that is used to store manage said memory it in the Python world.

However, you need to have tests to ensure that you can extract that memory from JAX, NumPy and CuPy arrays.
For JAX array we have:

  • tests/unit_tests/test_caching.py:test_caching_jax_numpy_array()
  • tests/unit_tests/test_jax_api.py:test_jax_array_as_input()

So in the end is composition at work.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants