diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index cb2e6d0..0bb498e 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -15,4 +15,5 @@ jobs: virtualenvs-create: false installer-parallel: true - run: poetry install + - run: pip install --upgrade "jax[cpu]" - run: pytest -vv diff --git a/jpviz/dot/graph.py b/jpviz/dot/graph.py index 0797280..bba7ca9 100644 --- a/jpviz/dot/graph.py +++ b/jpviz/dot/graph.py @@ -96,20 +96,23 @@ def get_conditional( for i, branch in enumerate(conditional.params["branches"]): if len(branch.eqns) == 0: branch_graph_id = f"{cond_node_id}_branch_{i}" + label = f"Branch {i}" + if collapse_primitives: cond_graph.add_node( pydot.Node( name=branch_graph_id, - label=f"Branch {i}: Id", + label=label, **styling.FUNCTION_NODE_STYLING, ) ) - for var in branch.jaxpr.invars: + + for (var, p_var) in zip(branch.jaxpr.invars, conditional.invars[1:]): # TODO: What does the underscore mean? if str(var)[-1] == "_": continue cond_graph.add_edge( - pydot.Edge(f"{cond_graph_id}_{var}", branch_graph_id) + pydot.Edge(f"{cond_graph_id}_{p_var}", branch_graph_id) ) for var in conditional.outvars: cond_graph.add_edge( @@ -117,54 +120,47 @@ def get_conditional( ) else: branch_graph = graph_utils.get_subgraph( - f"cluster_{branch_graph_id}", f"Branch {i}" + f"cluster_{branch_graph_id}", label ) - for var, c_var in zip(branch.jaxpr.outvars, conditional.outvars): + for (var, p_var) in zip(branch.jaxpr.invars, conditional.invars[1:]): + # TODO: What does the underscore mean? + if str(var)[-1] == "_": + continue arg_id = f"{branch_graph_id}_{var}" branch_graph.add_node( graph_utils.get_var_node(arg_id, var, show_avals) ) - cond_graph.add_edge(pydot.Edge(f"{cond_graph_id}_{var}", arg_id)) + cond_graph.add_edge(pydot.Edge(f"{cond_graph_id}_{p_var}", arg_id)) + for var, c_var in zip(branch.jaxpr.outvars, conditional.outvars): + arg_id = f"{branch_graph_id}_{var}" cond_graph.add_edge(pydot.Edge(arg_id, f"{cond_graph_id}_{c_var}")) cond_graph.add_subgraph(branch_graph) - elif len(branch.eqns) == 1: - ( - branch_graph, - branch_in_edges, - branch_out_nodes, - branch_out_edges, - n, - ) = get_sub_graph( - branch.eqns[0], - cond_graph_id, - n, - collapse_primitives, - show_avals, - ) - branch_graph.set_label(f"Branch {i}: {branch_graph.get_label()}") - if isinstance(branch_graph, pydot.Subgraph): - cond_graph.add_subgraph(branch_graph) - else: - cond_graph.add_node(branch_graph) - for edge in branch_in_edges: - cond_graph.add_edge(edge) - for edge, var in zip(branch_out_edges, conditional.outvars): - cond_graph.add_edge( - pydot.Edge(edge.get_source(), f"{cond_graph_id}_{var}") - ) else: branch_graph_id = f"{cond_node_id}_branch_{i}" - branch_label = f"Branch {i}: λ" - if utils.contains_non_primitives(branch.eqns) or not collapse_primitives: + if len(branch.eqns) == 1: + eqn = branch.eqns[0] + branch_label = ( + eqn.params["name"] if "name" in eqn.params else eqn.primitive.name + ) + branch_label = f"Branch {i}: {branch_label}" + no_literal_inputs = any( + [isinstance(a, jax_core.Literal) for a in branch.jaxpr.invars] + ) + collapse_branch = no_literal_inputs or collapse_primitives + else: + branch_label = f"Branch {i}" + collapse_branch = collapse_primitives + + if utils.contains_non_primitives(branch.eqns) or not collapse_branch: branch_graph = graph_utils.get_subgraph( f"cluster_{branch_graph_id}", branch_label ) - branch_args, arg_edges = graph_utils.get_arguments( branch_graph_id, cond_graph_id, + branch.jaxpr.constvars, branch.jaxpr.invars, - branch.jaxpr.invars, + conditional.invars[1:], show_avals, ) for edge in arg_edges: @@ -226,13 +222,17 @@ def get_conditional( **styling.FUNCTION_NODE_STYLING, ) ) - for var in branch.jaxpr.invars: + for (var, p_var) in zip(branch.jaxpr.invars, conditional.invars[1:]): # TODO: What does the underscore mean? + if str(var)[-1] == "_": continue - cond_graph.add_edge( - pydot.Edge(f"{cond_graph_id}_{var}", branch_graph_id) - ) + + if not is_literal: + cond_graph.add_edge( + pydot.Edge(f"{cond_graph_id}_{p_var}", branch_graph_id) + ) + for var in conditional.outvars: cond_graph.add_edge( pydot.Edge(branch_graph_id, f"{cond_graph_id}_{var}") @@ -364,6 +364,7 @@ def expand_non_primitive( argument_nodes, argument_edges = graph_utils.get_arguments( graph_id, parent_id, + eqn.params["jaxpr"].jaxpr.constvars, eqn.params["jaxpr"].jaxpr.invars, eqn.invars, show_avals, diff --git a/jpviz/dot/graph_utils.py b/jpviz/dot/graph_utils.py index 05f4cef..c6b0535 100644 --- a/jpviz/dot/graph_utils.py +++ b/jpviz/dot/graph_utils.py @@ -20,7 +20,7 @@ def get_arg_node( arg_id: str Unique ID of the node var: jax._src.core.Var - JAX variable of literal instance + JAX variable or literal instance show_avals: bool If `True` show the type in the node is_literal: True @@ -39,6 +39,34 @@ def get_arg_node( ) +def get_const_node( + arg_id: str, + var: typing.Union[jax_core.Var, jax_core.Literal], + show_avals: bool, +) -> pydot.Node: + """ + Return a pydot node representing a function const arg + + Parameters + ---------- + arg_id: str + Unique ID of the node + var: jax._src.core.Var + JAX variable + show_avals: bool + If `True` show the type in the node + + Returns + ------- + pydot.Node + """ + return pydot.Node( + name=arg_id, + label=utils.get_node_label(var, show_avals), + **styling.CONST_ARG_STYLING, + ) + + def get_var_node(var_id: str, var: jax_core.Var, show_avals: bool) -> pydot.Node: """ Get a pydot node representing a variable internal to a function @@ -113,6 +141,7 @@ def get_subgraph(graph_id: str, label: str) -> pydot.Subgraph: def get_arguments( graph_id: str, parent_id: str, + graph_consts: typing.List[jax_core.Var], graph_invars: typing.List[jax_core.Var], parent_invars: typing.List[jax_core.Var], show_avals: bool, @@ -127,6 +156,8 @@ def get_arguments( ID of the subgraph that owns the arguments parent_id: str ID of the parent of the subgraph + graph_consts: List[jax._src.core.Var] + List of graph const-vars graph_invars: List[jax._src.core.Var] List of input variables to the subgraph parent_invars: List[jax._src.core.Var] @@ -144,6 +175,10 @@ def get_arguments( argument_nodes = pydot.Subgraph(f"{graph_id}_args", rank="same") argument_edges = list() + for var in graph_consts: + arg_id = f"{graph_id}_{var}" + argument_nodes.add_node(get_const_node(arg_id, var, show_avals)) + for var, p_var in zip(graph_invars, parent_invars): # TODO: What does the underscore mean? if str(var)[-1] == "_": @@ -200,6 +235,9 @@ def get_scan_arguments( carry_nodes = pydot.Subgraph( f"cluster_{graph_id}_init", rank="same", label="init", style="dotted" ) + iterate_nodes = pydot.Subgraph( + f"cluster_{graph_id}_iter", rank="same", label="iterate", style="dotted" + ) argument_edges = list() for i, (var, p_var) in enumerate(zip(graph_invars, parent_invars)): @@ -219,6 +257,10 @@ def get_scan_arguments( if n_const <= i < n_carry + n_const: carry_nodes.add_node(get_arg_node(arg_id, var, show_avals, var_is_literal)) + elif i >= n_carry + n_const: + iterate_nodes.add_node( + get_arg_node(arg_id, var, show_avals, var_is_literal) + ) else: argument_nodes.add_node( get_arg_node(arg_id, var, show_avals, var_is_literal) @@ -227,6 +269,7 @@ def get_scan_arguments( argument_edges.append(pydot.Edge(f"{parent_id}_{p_var}", arg_id)) argument_nodes.add_subgraph(carry_nodes) + argument_nodes.add_subgraph(iterate_nodes) return argument_nodes, argument_edges @@ -357,6 +400,9 @@ def get_scan_outputs( carry_nodes = pydot.Subgraph( f"cluster_{graph_id}_carry", rank="same", label="carry", style="dotted" ) + accumulate_nodes = pydot.Subgraph( + f"cluster_{graph_id}_acc", rank="same", label="Accumulate", style="dotted" + ) out_edges = list() out_nodes = list() id_edges = list() @@ -371,9 +417,10 @@ def get_scan_outputs( if i < n_carry: carry_nodes.add_node(get_out_node(arg_id, var, show_avals)) else: - out_graph.add_node(get_out_node(arg_id, var, show_avals)) + accumulate_nodes.add_node(get_out_node(arg_id, var, show_avals)) out_edges.append(pydot.Edge(arg_id, f"{parent_id}_{p_var}")) out_nodes.append(get_var_node(f"{parent_id}_{p_var}", p_var, show_avals)) out_graph.add_subgraph(carry_nodes) + out_graph.add_subgraph(accumulate_nodes) return out_graph, out_edges, out_nodes, id_edges diff --git a/jpviz/dot/styling.py b/jpviz/dot/styling.py index b0a7b42..6c5ba5a 100644 --- a/jpviz/dot/styling.py +++ b/jpviz/dot/styling.py @@ -16,6 +16,12 @@ fontname="Courier", fontsize="10", ) +CONST_ARG_STYLING = dict( + shape="box", + color="darkgreen", + fontname="Courier", + fontsize="10", +) OUT_ARG_STYLING = dict( shape="box", color="red", diff --git a/tests/dummy_test.py b/tests/dummy_test.py deleted file mode 100644 index f174823..0000000 --- a/tests/dummy_test.py +++ /dev/null @@ -1,2 +0,0 @@ -def test(): - pass diff --git a/tests/test_cases.py b/tests/test_cases.py new file mode 100644 index 0000000..a0eb741 --- /dev/null +++ b/tests/test_cases.py @@ -0,0 +1,95 @@ +import jax +import jax.numpy as jnp +import pytest + +import jpviz + + +@jax.jit +def func1(first, second): + temp = first + jnp.sin(second) * 3.0 + return jnp.sum(temp) + + +def func2(inner, first, second): + temp = first + inner(second) * 3.0 + return jnp.sum(temp) + + +def inner_func(second): + if second.shape[0] > 4: + return jnp.sin(second) + else: + assert False + + +@jax.jit +def func3(first, second): + return func2(inner_func, first, second) + + +@jax.jit +def func4(arg): + temp = arg[0] + jnp.sin(arg[1]) * 3.0 + return jnp.sum(temp) + + +@jax.jit +def one_of_three(index, arg): + return jax.lax.switch( + index, [lambda x: x + 1.0, lambda x: x - 2.0, lambda x: x + 3.0], arg + ) + + +@jax.jit +def func7(arg): + return jax.lax.cond( + arg >= 0.0, lambda x_true: x_true + 3.0, lambda x_false: x_false - 3.0, arg + ) + + +@jax.jit +def func8(arg1, arg2): + return jax.lax.cond( + arg1 >= 0.0, + lambda x_true: x_true[0], + lambda x_false: jnp.array([1]) + x_false[1], + arg2, + ) + + +@jax.jit +def func10(arg, n): + ones = jnp.ones(arg.shape) + return jax.lax.fori_loop( + 0, n, lambda i, carry: carry + ones * 3.0 + arg, arg + ones + ) + + +@jax.jit +def func11(arr, extra): + ones = jnp.ones(arr.shape) + + def body(carry, a_elems): + ae1, ae2 = a_elems + return carry + ae1 * ae2 + extra, carry + + return jax.lax.scan(body, 0.0, (arr, ones)) + + +test_cases = [ + (func1, [jnp.zeros(8), jnp.ones(8)]), + (func3, [jnp.zeros(8), jnp.ones(8)]), + (func4, [(jnp.zeros(8), jnp.ones(8))]), + (one_of_three, [1, 5.0]), + (func7, [5.0]), + (func8, [5.0, (jnp.zeros(1), 2.0)]), + (func10, [jnp.ones(16), 5]), + (func11, [jnp.ones(16), 5.0]), +] + + +@pytest.mark.parametrize("f, args", test_cases) +def test_works(f, args): + _ = jpviz.draw(f, collapse_primitives=True)(*args) + _ = jpviz.draw(f, collapse_primitives=False)(*args)