-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add test functions * Tweak conditional node behaviour * Fix id switch and include const-vars * Don't collapse branches with literal inputs * Cluster scan arguments and outputs * Run test cases * Install jaxlib for tests --------- Co-authored-by: zombie-einstein <[email protected]>
- Loading branch information
1 parent
1c74060
commit 632a05a
Showing
6 changed files
with
191 additions
and
43 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
import jax | ||
import jax.numpy as jnp | ||
import pytest | ||
|
||
import jpviz | ||
|
||
|
||
@jax.jit | ||
def func1(first, second): | ||
temp = first + jnp.sin(second) * 3.0 | ||
return jnp.sum(temp) | ||
|
||
|
||
def func2(inner, first, second): | ||
temp = first + inner(second) * 3.0 | ||
return jnp.sum(temp) | ||
|
||
|
||
def inner_func(second): | ||
if second.shape[0] > 4: | ||
return jnp.sin(second) | ||
else: | ||
assert False | ||
|
||
|
||
@jax.jit | ||
def func3(first, second): | ||
return func2(inner_func, first, second) | ||
|
||
|
||
@jax.jit | ||
def func4(arg): | ||
temp = arg[0] + jnp.sin(arg[1]) * 3.0 | ||
return jnp.sum(temp) | ||
|
||
|
||
@jax.jit | ||
def one_of_three(index, arg): | ||
return jax.lax.switch( | ||
index, [lambda x: x + 1.0, lambda x: x - 2.0, lambda x: x + 3.0], arg | ||
) | ||
|
||
|
||
@jax.jit | ||
def func7(arg): | ||
return jax.lax.cond( | ||
arg >= 0.0, lambda x_true: x_true + 3.0, lambda x_false: x_false - 3.0, arg | ||
) | ||
|
||
|
||
@jax.jit | ||
def func8(arg1, arg2): | ||
return jax.lax.cond( | ||
arg1 >= 0.0, | ||
lambda x_true: x_true[0], | ||
lambda x_false: jnp.array([1]) + x_false[1], | ||
arg2, | ||
) | ||
|
||
|
||
@jax.jit | ||
def func10(arg, n): | ||
ones = jnp.ones(arg.shape) | ||
return jax.lax.fori_loop( | ||
0, n, lambda i, carry: carry + ones * 3.0 + arg, arg + ones | ||
) | ||
|
||
|
||
@jax.jit | ||
def func11(arr, extra): | ||
ones = jnp.ones(arr.shape) | ||
|
||
def body(carry, a_elems): | ||
ae1, ae2 = a_elems | ||
return carry + ae1 * ae2 + extra, carry | ||
|
||
return jax.lax.scan(body, 0.0, (arr, ones)) | ||
|
||
|
||
test_cases = [ | ||
(func1, [jnp.zeros(8), jnp.ones(8)]), | ||
(func3, [jnp.zeros(8), jnp.ones(8)]), | ||
(func4, [(jnp.zeros(8), jnp.ones(8))]), | ||
(one_of_three, [1, 5.0]), | ||
(func7, [5.0]), | ||
(func8, [5.0, (jnp.zeros(1), 2.0)]), | ||
(func10, [jnp.ones(16), 5]), | ||
(func11, [jnp.ones(16), 5.0]), | ||
] | ||
|
||
|
||
@pytest.mark.parametrize("f, args", test_cases) | ||
def test_works(f, args): | ||
_ = jpviz.draw(f, collapse_primitives=True)(*args) | ||
_ = jpviz.draw(f, collapse_primitives=False)(*args) |