Skip to content

Commit

Permalink
Test cases (#14)
Browse files Browse the repository at this point in the history
* Add test functions

* Tweak conditional node behaviour

* Fix id switch and include const-vars

* Don't collapse branches with literal inputs

* Cluster scan arguments and outputs

* Run test cases

* Install jaxlib for tests

---------

Co-authored-by: zombie-einstein <[email protected]>
  • Loading branch information
zombie-einstein and zombie-einstein authored Sep 10, 2023
1 parent 1c74060 commit 632a05a
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 43 deletions.
1 change: 1 addition & 0 deletions .github/workflows/run-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ jobs:
virtualenvs-create: false
installer-parallel: true
- run: poetry install
- run: pip install --upgrade "jax[cpu]"
- run: pytest -vv
79 changes: 40 additions & 39 deletions jpviz/dot/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,75 +96,71 @@ def get_conditional(
for i, branch in enumerate(conditional.params["branches"]):
if len(branch.eqns) == 0:
branch_graph_id = f"{cond_node_id}_branch_{i}"
label = f"Branch {i}"

if collapse_primitives:
cond_graph.add_node(
pydot.Node(
name=branch_graph_id,
label=f"Branch {i}: Id",
label=label,
**styling.FUNCTION_NODE_STYLING,
)
)
for var in branch.jaxpr.invars:

for (var, p_var) in zip(branch.jaxpr.invars, conditional.invars[1:]):
# TODO: What does the underscore mean?
if str(var)[-1] == "_":
continue
cond_graph.add_edge(
pydot.Edge(f"{cond_graph_id}_{var}", branch_graph_id)
pydot.Edge(f"{cond_graph_id}_{p_var}", branch_graph_id)
)
for var in conditional.outvars:
cond_graph.add_edge(
pydot.Edge(branch_graph_id, f"{cond_graph_id}_{var}")
)
else:
branch_graph = graph_utils.get_subgraph(
f"cluster_{branch_graph_id}", f"Branch {i}"
f"cluster_{branch_graph_id}", label
)
for var, c_var in zip(branch.jaxpr.outvars, conditional.outvars):
for (var, p_var) in zip(branch.jaxpr.invars, conditional.invars[1:]):
# TODO: What does the underscore mean?
if str(var)[-1] == "_":
continue
arg_id = f"{branch_graph_id}_{var}"
branch_graph.add_node(
graph_utils.get_var_node(arg_id, var, show_avals)
)
cond_graph.add_edge(pydot.Edge(f"{cond_graph_id}_{var}", arg_id))
cond_graph.add_edge(pydot.Edge(f"{cond_graph_id}_{p_var}", arg_id))
for var, c_var in zip(branch.jaxpr.outvars, conditional.outvars):
arg_id = f"{branch_graph_id}_{var}"
cond_graph.add_edge(pydot.Edge(arg_id, f"{cond_graph_id}_{c_var}"))
cond_graph.add_subgraph(branch_graph)
elif len(branch.eqns) == 1:
(
branch_graph,
branch_in_edges,
branch_out_nodes,
branch_out_edges,
n,
) = get_sub_graph(
branch.eqns[0],
cond_graph_id,
n,
collapse_primitives,
show_avals,
)
branch_graph.set_label(f"Branch {i}: {branch_graph.get_label()}")
if isinstance(branch_graph, pydot.Subgraph):
cond_graph.add_subgraph(branch_graph)
else:
cond_graph.add_node(branch_graph)
for edge in branch_in_edges:
cond_graph.add_edge(edge)
for edge, var in zip(branch_out_edges, conditional.outvars):
cond_graph.add_edge(
pydot.Edge(edge.get_source(), f"{cond_graph_id}_{var}")
)
else:
branch_graph_id = f"{cond_node_id}_branch_{i}"
branch_label = f"Branch {i}: λ"
if utils.contains_non_primitives(branch.eqns) or not collapse_primitives:
if len(branch.eqns) == 1:
eqn = branch.eqns[0]
branch_label = (
eqn.params["name"] if "name" in eqn.params else eqn.primitive.name
)
branch_label = f"Branch {i}: {branch_label}"
no_literal_inputs = any(
[isinstance(a, jax_core.Literal) for a in branch.jaxpr.invars]
)
collapse_branch = no_literal_inputs or collapse_primitives
else:
branch_label = f"Branch {i}"
collapse_branch = collapse_primitives

if utils.contains_non_primitives(branch.eqns) or not collapse_branch:
branch_graph = graph_utils.get_subgraph(
f"cluster_{branch_graph_id}", branch_label
)

branch_args, arg_edges = graph_utils.get_arguments(
branch_graph_id,
cond_graph_id,
branch.jaxpr.constvars,
branch.jaxpr.invars,
branch.jaxpr.invars,
conditional.invars[1:],
show_avals,
)
for edge in arg_edges:
Expand Down Expand Up @@ -226,13 +222,17 @@ def get_conditional(
**styling.FUNCTION_NODE_STYLING,
)
)
for var in branch.jaxpr.invars:
for (var, p_var) in zip(branch.jaxpr.invars, conditional.invars[1:]):
# TODO: What does the underscore mean?

if str(var)[-1] == "_":
continue
cond_graph.add_edge(
pydot.Edge(f"{cond_graph_id}_{var}", branch_graph_id)
)

if not is_literal:
cond_graph.add_edge(
pydot.Edge(f"{cond_graph_id}_{p_var}", branch_graph_id)
)

for var in conditional.outvars:
cond_graph.add_edge(
pydot.Edge(branch_graph_id, f"{cond_graph_id}_{var}")
Expand Down Expand Up @@ -364,6 +364,7 @@ def expand_non_primitive(
argument_nodes, argument_edges = graph_utils.get_arguments(
graph_id,
parent_id,
eqn.params["jaxpr"].jaxpr.constvars,
eqn.params["jaxpr"].jaxpr.invars,
eqn.invars,
show_avals,
Expand Down
51 changes: 49 additions & 2 deletions jpviz/dot/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def get_arg_node(
arg_id: str
Unique ID of the node
var: jax._src.core.Var
JAX variable of literal instance
JAX variable or literal instance
show_avals: bool
If `True` show the type in the node
is_literal: True
Expand All @@ -39,6 +39,34 @@ def get_arg_node(
)


def get_const_node(
arg_id: str,
var: typing.Union[jax_core.Var, jax_core.Literal],
show_avals: bool,
) -> pydot.Node:
"""
Return a pydot node representing a function const arg
Parameters
----------
arg_id: str
Unique ID of the node
var: jax._src.core.Var
JAX variable
show_avals: bool
If `True` show the type in the node
Returns
-------
pydot.Node
"""
return pydot.Node(
name=arg_id,
label=utils.get_node_label(var, show_avals),
**styling.CONST_ARG_STYLING,
)


def get_var_node(var_id: str, var: jax_core.Var, show_avals: bool) -> pydot.Node:
"""
Get a pydot node representing a variable internal to a function
Expand Down Expand Up @@ -113,6 +141,7 @@ def get_subgraph(graph_id: str, label: str) -> pydot.Subgraph:
def get_arguments(
graph_id: str,
parent_id: str,
graph_consts: typing.List[jax_core.Var],
graph_invars: typing.List[jax_core.Var],
parent_invars: typing.List[jax_core.Var],
show_avals: bool,
Expand All @@ -127,6 +156,8 @@ def get_arguments(
ID of the subgraph that owns the arguments
parent_id: str
ID of the parent of the subgraph
graph_consts: List[jax._src.core.Var]
List of graph const-vars
graph_invars: List[jax._src.core.Var]
List of input variables to the subgraph
parent_invars: List[jax._src.core.Var]
Expand All @@ -144,6 +175,10 @@ def get_arguments(
argument_nodes = pydot.Subgraph(f"{graph_id}_args", rank="same")
argument_edges = list()

for var in graph_consts:
arg_id = f"{graph_id}_{var}"
argument_nodes.add_node(get_const_node(arg_id, var, show_avals))

for var, p_var in zip(graph_invars, parent_invars):
# TODO: What does the underscore mean?
if str(var)[-1] == "_":
Expand Down Expand Up @@ -200,6 +235,9 @@ def get_scan_arguments(
carry_nodes = pydot.Subgraph(
f"cluster_{graph_id}_init", rank="same", label="init", style="dotted"
)
iterate_nodes = pydot.Subgraph(
f"cluster_{graph_id}_iter", rank="same", label="iterate", style="dotted"
)
argument_edges = list()

for i, (var, p_var) in enumerate(zip(graph_invars, parent_invars)):
Expand All @@ -219,6 +257,10 @@ def get_scan_arguments(

if n_const <= i < n_carry + n_const:
carry_nodes.add_node(get_arg_node(arg_id, var, show_avals, var_is_literal))
elif i >= n_carry + n_const:
iterate_nodes.add_node(
get_arg_node(arg_id, var, show_avals, var_is_literal)
)
else:
argument_nodes.add_node(
get_arg_node(arg_id, var, show_avals, var_is_literal)
Expand All @@ -227,6 +269,7 @@ def get_scan_arguments(
argument_edges.append(pydot.Edge(f"{parent_id}_{p_var}", arg_id))

argument_nodes.add_subgraph(carry_nodes)
argument_nodes.add_subgraph(iterate_nodes)
return argument_nodes, argument_edges


Expand Down Expand Up @@ -357,6 +400,9 @@ def get_scan_outputs(
carry_nodes = pydot.Subgraph(
f"cluster_{graph_id}_carry", rank="same", label="carry", style="dotted"
)
accumulate_nodes = pydot.Subgraph(
f"cluster_{graph_id}_acc", rank="same", label="Accumulate", style="dotted"
)
out_edges = list()
out_nodes = list()
id_edges = list()
Expand All @@ -371,9 +417,10 @@ def get_scan_outputs(
if i < n_carry:
carry_nodes.add_node(get_out_node(arg_id, var, show_avals))
else:
out_graph.add_node(get_out_node(arg_id, var, show_avals))
accumulate_nodes.add_node(get_out_node(arg_id, var, show_avals))
out_edges.append(pydot.Edge(arg_id, f"{parent_id}_{p_var}"))
out_nodes.append(get_var_node(f"{parent_id}_{p_var}", p_var, show_avals))

out_graph.add_subgraph(carry_nodes)
out_graph.add_subgraph(accumulate_nodes)
return out_graph, out_edges, out_nodes, id_edges
6 changes: 6 additions & 0 deletions jpviz/dot/styling.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@
fontname="Courier",
fontsize="10",
)
CONST_ARG_STYLING = dict(
shape="box",
color="darkgreen",
fontname="Courier",
fontsize="10",
)
OUT_ARG_STYLING = dict(
shape="box",
color="red",
Expand Down
2 changes: 0 additions & 2 deletions tests/dummy_test.py

This file was deleted.

95 changes: 95 additions & 0 deletions tests/test_cases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import jax
import jax.numpy as jnp
import pytest

import jpviz


@jax.jit
def func1(first, second):
temp = first + jnp.sin(second) * 3.0
return jnp.sum(temp)


def func2(inner, first, second):
temp = first + inner(second) * 3.0
return jnp.sum(temp)


def inner_func(second):
if second.shape[0] > 4:
return jnp.sin(second)
else:
assert False


@jax.jit
def func3(first, second):
return func2(inner_func, first, second)


@jax.jit
def func4(arg):
temp = arg[0] + jnp.sin(arg[1]) * 3.0
return jnp.sum(temp)


@jax.jit
def one_of_three(index, arg):
return jax.lax.switch(
index, [lambda x: x + 1.0, lambda x: x - 2.0, lambda x: x + 3.0], arg
)


@jax.jit
def func7(arg):
return jax.lax.cond(
arg >= 0.0, lambda x_true: x_true + 3.0, lambda x_false: x_false - 3.0, arg
)


@jax.jit
def func8(arg1, arg2):
return jax.lax.cond(
arg1 >= 0.0,
lambda x_true: x_true[0],
lambda x_false: jnp.array([1]) + x_false[1],
arg2,
)


@jax.jit
def func10(arg, n):
ones = jnp.ones(arg.shape)
return jax.lax.fori_loop(
0, n, lambda i, carry: carry + ones * 3.0 + arg, arg + ones
)


@jax.jit
def func11(arr, extra):
ones = jnp.ones(arr.shape)

def body(carry, a_elems):
ae1, ae2 = a_elems
return carry + ae1 * ae2 + extra, carry

return jax.lax.scan(body, 0.0, (arr, ones))


test_cases = [
(func1, [jnp.zeros(8), jnp.ones(8)]),
(func3, [jnp.zeros(8), jnp.ones(8)]),
(func4, [(jnp.zeros(8), jnp.ones(8))]),
(one_of_three, [1, 5.0]),
(func7, [5.0]),
(func8, [5.0, (jnp.zeros(1), 2.0)]),
(func10, [jnp.ones(16), 5]),
(func11, [jnp.ones(16), 5.0]),
]


@pytest.mark.parametrize("f, args", test_cases)
def test_works(f, args):
_ = jpviz.draw(f, collapse_primitives=True)(*args)
_ = jpviz.draw(f, collapse_primitives=False)(*args)

0 comments on commit 632a05a

Please sign in to comment.