scan primitive implementation missing #29

Open
colehaus opened this issue Sep 4, 2024 · 4 comments

Comments

@colehaus

colehaus commented Sep 4, 2024

Looking at @register(jax.lax.while_p), I assume scan is supposed to be implemented generically for all Quax objects. I took a quick try at an implementation like this:

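# Imports assumed by this snippet (not shown in the original comment).
from typing import Optional, Union

import jax
import jax.tree_util as jtu
import quax
from jax import lax
from jaxtyping import ArrayLike

@quax.register(lax.scan_p)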
def _(
    *args: Union[quax.ArrayValue, ArrayLike],
    reverse: bool,
    length: int,
    jaxpr,
    num_consts: int,
    num_carry: int,
    linear,
    unroll: int = 1,
    _split_transpose: Optional[bool] = None,
):
    consts = args[:num_consts]
    init = args[num_consts : num_consts + num_carry]
    xs = args[num_consts + num_carry :]

    quax_f = quax.quaxify(jax.core.jaxpr_as_fun(jaxpr))
    quax_jaxpr = jax.make_jaxpr(quax_f)(*consts, *init, *xs)

    const_leaves, _ = jtu.tree_flatten(consts)
    init_leaves, init_treedef = jtu.tree_flatten(init)
    xs_leaves, _ = jtu.tree_flatten(xs)

    out_flat = lax.scan_p.bind(
        *const_leaves,
        *init_leaves,
        *xs_leaves,
        reverse=reverse,
        length=length,
        jaxpr=quax_jaxpr,
        num_consts=num_consts,
        num_carry=num_carry,
        linear=linear,
        unroll=unroll,
        _split_transpose=_split_transpose,
    )

    # _initial_style_jaxpr(quax_f, , , "scan")
    carry_nvals = len(init_leaves)
    carry, ys = out_flat[:carry_nvals], out_flat[carry_nvals:]

    carry_out = jtu.tree_unflatten(init_treedef, carry)

    return carry_out, None

But there are at least two problems:

  • When I actually use it, I get:
File /usr/local/lib/python3.11/dist-packages/quax/_core.py:194, in <listcomp>(.0)
    192         out = method(*values, **params)
    193 if primitive.multiple_results:
--> 194     return [_QuaxTracer(self, _wrap_if_array(x)) for x in out]  # pyright: ignore
    195 else:
    196     return _QuaxTracer(self, _wrap_if_array(out))

File /usr/local/lib/python3.11/dist-packages/quax/_core.py:84, in _QuaxTracer.__init__(self, trace, value)
     83 def __init__(self, trace: "_QuaxTrace", value: "Value"):
---> 84     assert _is_value(value)
     85     self._trace = trace
     86     self.value = value

Having scan available is pretty handy for use with the scan over layers technique.
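
For readers unfamiliar with the technique, here is a minimal sketch of scan over layers in plain JAX, without Quax. The layer function, sizes, and all names (`layer`, `forward_scan`, `forward_loop`) are illustrative assumptions, not code from this issue. The per-layer weights are stacked along a leading axis and `lax.scan` applies the same layer body once per slice, so the layer is traced once rather than once per layer.

```python
import jax
import jax.numpy as jnp
from jax import lax

num_layers, width = 3, 4  # illustrative sizes

# Stack the per-layer weights along a leading "layer" axis.
keys = jax.random.split(jax.random.PRNGKey(0), num_layers)
weights = jax.vmap(
    lambda k: jax.random.normal(k, (width, width)) / jnp.sqrt(width)
)(keys)


def layer(h, w):
    # One layer: the carry is the activation, the scanned input is this layer's weight.
    return jnp.tanh(w @ h), None


def forward_scan(weights, h):
    # The layer body is traced once and applied to each slice of `weights`.
    h, _ = lax.scan(layer, h, weights)
    return h


def forward_loop(weights, h):
    # Reference implementation: an unrolled Python loop over the layers.
    for i in range(num_layers):
        h, _ = layer(h, weights[i])
    return h


h0 = jnp.ones((width,))
out_scan = forward_scan(weights, h0)
out_loop = forward_loop(weights, h0)
```

Both functions compute the same activations. For this pattern to work under quaxify, the scan primitive needs a registered rule.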

@patrick-kidger
Owner

Very happy to take a PR on this. I'm not immediately sure what the issue is here, though, I'm afraid!

@colehaus
Author

colehaus commented Sep 4, 2024

Something like this seems to roughly work:

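# Imports assumed by this snippet (not shown in the original comment).
# The last two are private JAX internals, so their paths may differ between JAX versions.
from typing import Optional, Union

import jax
import jax.tree_util as jtu
import quax
from jax import lax
from jaxtyping import ArrayLike
from jax._src.lax.control_flow.common import _abstractify, _initial_style_jaxpr
from jax._src.util import safe_map

@quax.register(lax.scan_p)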
def _(
    *args: Union[quax.ArrayValue, ArrayLike],
    reverse: bool,
    length: int,
    jaxpr,
    num_consts: int,
    num_carry: int,
    linear,
    unroll: int = 1,
    _split_transpose: Optional[bool] = None,
):
    const = args[:num_consts]
    init = args[num_consts : num_consts + num_carry]
    xs = args[num_consts + num_carry :]

    const_flat, _ = jtu.tree_flatten(const)
    const_avals = tuple(safe_map(_abstractify, const_flat))

    xs_flat, _ = jtu.tree_flatten(xs)
    xs_avals = tuple(safe_map(_abstractify, xs_flat))
    x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]

    quax_f = quax.quaxify(jax.core.jaxpr_as_fun(jaxpr))

    init_flat, init_treedef = jtu.tree_flatten(init)
    carry_avals = tuple(safe_map(_abstractify, init_flat))

    in_flat, in_treedef = jtu.tree_flatten(const + init + xs)
    jaxpr, consts, out_treedef = _initial_style_jaxpr(quax_f, in_treedef, (*const_avals, *carry_avals, *x_avals), "scan")

    out_flat = lax.scan_p.bind(
        *consts,
        *in_flat,
        reverse=reverse,
        length=length,
        jaxpr=jaxpr,
        num_consts=num_consts,
        num_carry=num_carry,
        linear=(False,) * (len(consts) + len(in_flat)),
        unroll=unroll,
        _split_transpose=_split_transpose,
    )
    carry_out = jtu.tree_unflatten(init_treedef, out_flat[: init_treedef.num_leaves])
    num_extra_outs = out_treedef.num_leaves - init_treedef.num_leaves
    flat_structure = jtu.tree_structure((0,) * num_extra_outs)
    extra_out = jtu.tree_unflatten(flat_structure, out_flat[init_treedef.num_leaves :])
    return carry_out[0] if len(init) == 1 else carry_out, extra_out[0] if num_extra_outs == 1 else extra_out

Two problems:

  • If the first output value is a nested pytree and not a simple array, we get an error:
File /usr/local/lib/python3.11/dist-packages/quax/_core.py:194, in <listcomp>(.0)
    192         out = method(*values, **params)
    193 if primitive.multiple_results:
--> 194     return [_QuaxTracer(self, _wrap_if_array(x)) for x in out]  # pyright: ignore
    195 else:
    196     return _QuaxTracer(self, _wrap_if_array(out))

File /usr/local/lib/python3.11/dist-packages/quax/_core.py:84, in _QuaxTracer.__init__(self, trace, value)
     83 def __init__(self, trace: "_QuaxTrace", value: "Value"):
---> 84     assert _is_value(value)
     85     self._trace = trace
     86     self.value = value

I think this might be a mistake in Quax and the wrapping and asserting need to be tree-mapped over x?

  • I don't think we have an easy way to find out the proper PyTreeDef for the full output. We can infer the PyTreeDef for the carry part of the output from the input, but the structure of the extra output (b) comes from the function itself, and the version of the function reconstructed by jaxpr_as_fun doesn't have this structure and only produces flat leaves. This rough implementation therefore doesn't respect the structure passed in by the user and always returns the extra output as a flat tuple with the corresponding number of leaves.
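
To illustrate the second point, here is a small sketch in plain JAX showing that a traced jaxpr carries only flat outputs, while the output PyTreeDef is separate information captured at trace time. The function `step` and its output dictionary are hypothetical examples, not code from this issue.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


def step(carry, x):
    # A scan-style body whose extra output is a nested pytree (a dict).
    new_carry = carry + x
    ys = {"doubled": 2.0 * x, "running": new_carry}
    return new_carry, ys


carry0 = jnp.zeros(())
x0 = jnp.ones(())

# `return_shape=True` also returns the output structure recorded while tracing.
closed_jaxpr, out_shape = jax.make_jaxpr(step, return_shape=True)(carry0, x0)

# The jaxpr itself only knows about a flat list of output variables.
num_flat_outputs = len(closed_jaxpr.jaxpr.outvars)

# The treedef lives outside the jaxpr. Without it, the nesting cannot be rebuilt.
out_treedef = jtu.tree_structure(out_shape)
rebuilt = jtu.tree_unflatten(out_treedef, list(range(num_flat_outputs)))
```

Here `rebuilt` restores the nesting only because `out_treedef` was saved alongside the jaxpr. A rule that receives just the jaxpr has no equivalent to consult.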

@colehaus
Author

colehaus commented Sep 5, 2024

And here's an implementation of remat that seems to work (important for use with scan: https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#practical-notes):

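# Imports assumed by this snippet (not shown in the original comment).
# The `jax._src` paths are private JAX internals and may differ between JAX versions.
import jax
import jax.tree_util as jtu
import quax
from jax._src.interpreters import partial_eval as pe
from jax._src.lax.control_flow.common import _abstractify, _initial_style_jaxpr
from jax._src.util import safe_map

@quax.register(jax._src.ad_checkpoint.remat_p)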
def _(*args, jaxpr, prevent_cse, differentiated, policy):
    del prevent_cse, differentiated, policy
    # `jaxpr_as_fun` expects a closed jaxpr. `scan_p` already gets one but `remat_p` doesn't.
    closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr))
    quax_f = quax.quaxify(jax.core.jaxpr_as_fun(closed_jaxpr))
    in_flat, in_treedef = jtu.tree_flatten(args)
    in_avals = tuple(safe_map(_abstractify, in_flat))
    quax_jaxpr, consts, out_tree = _initial_style_jaxpr(quax_f, in_treedef, in_avals, "remat")
    out_flat = jax.core.eval_jaxpr(quax_jaxpr.jaxpr, (), *consts, *in_flat)
    return jtu.tree_unflatten(out_tree, out_flat)
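
For context, the pattern the linked notes describe is to wrap the scanned body in `jax.checkpoint`. A minimal plain-JAX sketch, independent of Quax, is below. The layer function, weights, and loss (`layer`, `make_loss`) are illustrative assumptions. Rematerialization should change only how much is stored for the backward pass, not the gradient values.

```python
import jax
import jax.numpy as jnp
from jax import lax


def layer(h, w):
    # One layer of a toy network, written as a scan body.
    return jnp.tanh(w @ h), None


def make_loss(use_remat):
    # Optionally wrap the scanned body in jax.checkpoint (remat), so its
    # intermediates are recomputed in the backward pass instead of stored.
    body = jax.checkpoint(layer) if use_remat else layer

    def loss(weights, h):
        h, _ = lax.scan(body, h, weights)
        return jnp.sum(h ** 2)

    return loss


# Three stacked 3x3 layers with different scales (illustrative values).
weights = jnp.stack([jnp.eye(3) * s for s in (0.5, 1.0, 1.5)])
h0 = jnp.array([0.1, -0.2, 0.3])

grad_plain = jax.grad(make_loss(False))(weights, h0)
grad_remat = jax.grad(make_loss(True))(weights, h0)
```

The two gradients agree up to floating-point tolerance. Under quaxify, the checkpointed body shows up as a remat primitive inside the scan, which is why both rules are needed together.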

@patrick-kidger
Owner

> I think this might be a mistake in Quax and the wrapping and asserting need to be tree-mapped over x?

I don't think so. Primitive binds can produce either a single array or a sequence of arrays. By the time we're in the JAX internals like this then pytrees have largely disappeared.
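
A small sketch of this point, using only public JAX APIs; the body function `step` and its carry layout are hypothetical. Even when the carry passed to `lax.scan` is a dict and the per-step output is a tuple, the scan equation in the traced jaxpr has one flat output variable per leaf, and `num_carry` counts leaves rather than pytree nodes.

```python
import jax
import jax.numpy as jnp
from jax import lax


def step(carry, x):
    # The carry is a dict with two leaves and the per-step output is a tuple with two leaves.
    new_carry = {"total": carry["total"] + x, "count": carry["count"] + 1}
    return new_carry, (x + 1.0, 2.0 * x)


init = {"total": jnp.zeros(()), "count": jnp.zeros((), dtype=jnp.int32)}
xs = jnp.arange(4.0)

closed_jaxpr = jax.make_jaxpr(lambda c, xs: lax.scan(step, c, xs))(init, xs)

# Find the scan equation and inspect what the primitive itself sees.
scan_eqns = [eqn for eqn in closed_jaxpr.jaxpr.eqns if eqn.primitive.name == "scan"]
scan_eqn = scan_eqns[0]
num_carry = scan_eqn.params["num_carry"]  # carry leaves, not pytree nodes
num_outputs = len(scan_eqn.outvars)       # flat outputs: carry leaves + ys leaves
```

The dict and tuple structure is reassembled by `lax.scan` after the bind, so a rule registered for the primitive is expected to return a flat sequence of arrays (or Quax values), one per output.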

Anyway, these broadly all look good to me! I'd be happy to take PRs on these, including tests for the kinds of edge cases you're bumping into.
