`scan` primitive implementation missing #29
Very happy to take a PR on this. But I'm afraid I'm not immediately sure what the issue is here!
Something like this seems to roughly work:

```python
from typing import Optional, Union

import jax
import jax.tree_util as jtu
import quax
from jax import lax
from jax.typing import ArrayLike

# Private JAX helpers (the ones `lax.scan` itself uses); they may move or
# change signature between JAX releases.
from jax._src.lax.control_flow.common import _abstractify, _initial_style_jaxpr
from jax._src.util import safe_map


@quax.register(lax.scan_p)
def _(
    *args: Union[quax.ArrayValue, ArrayLike],
    reverse: bool,
    length: int,
    jaxpr,
    num_consts: int,
    num_carry: int,
    linear,
    unroll: int = 1,
    _split_transpose: Optional[bool] = None,
):
    # scan_p operands are laid out as (*consts, *init_carry, *xs).
    const = args[:num_consts]
    init = args[num_consts : num_consts + num_carry]
    xs = args[num_consts + num_carry :]
    const_flat, _ = jtu.tree_flatten(const)
    const_avals = tuple(safe_map(_abstractify, const_flat))
    xs_flat, _ = jtu.tree_flatten(xs)
    xs_avals = tuple(safe_map(_abstractify, xs_flat))
    # The body sees one slice of each xs leaf, so drop the leading scanned axis.
    x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
    quax_f = quax.quaxify(jax.core.jaxpr_as_fun(jaxpr))
    init_flat, init_treedef = jtu.tree_flatten(init)
    carry_avals = tuple(safe_map(_abstractify, init_flat))
    in_flat, in_treedef = jtu.tree_flatten(const + init + xs)
    # Retrace the quaxified body over the flattened (array-level) inputs.
    jaxpr, consts, out_treedef = _initial_style_jaxpr(
        quax_f, in_treedef, (*const_avals, *carry_avals, *x_avals), "scan"
    )
    out_flat = lax.scan_p.bind(
        *consts,
        *in_flat,
        reverse=reverse,
        length=length,
        jaxpr=jaxpr,
        # Counts are in flattened leaves: consts hoisted out of the new jaxpr
        # come first, and one Quax value may flatten to several arrays.
        num_consts=len(consts) + len(const_flat),
        num_carry=len(init_flat),
        linear=(False,) * (len(consts) + len(in_flat)),
        unroll=unroll,
        _split_transpose=_split_transpose,
    )
    carry_out = jtu.tree_unflatten(init_treedef, out_flat[: init_treedef.num_leaves])
    num_extra_outs = out_treedef.num_leaves - init_treedef.num_leaves
    flat_structure = jtu.tree_structure((0,) * num_extra_outs)
    extra_out = jtu.tree_unflatten(flat_structure, out_flat[init_treedef.num_leaves :])
    return carry_out[0] if len(init) == 1 else carry_out, extra_out[0] if num_extra_outs == 1 else extra_out
```

Two problems:
I think this might be a mistake in Quax, and the wrapping and asserting need to be tree-mapped over.
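To see the operand layout the rule above relies on, one can trace a small `lax.scan` with `jax.make_jaxpr` and inspect the resulting `scan_p` equation. This is a minimal sketch using public JAX only; the function `f` and its body are made up for illustration, and it assumes the equation params are still named `num_consts` and `num_carry` as in the rule's signature. It shows that a closed-over value becomes a const operand, that operands are ordered consts, then initial carry, then `xs`, and that `scan_p` is a multiple-results primitive whose outputs are the final carry followed by the stacked `ys`.

```python
import jax
import jax.numpy as jnp
from jax import lax


def f(w, init, xs):
    def body(carry, x):
        # `w` is closed over, so scan hoists it into a const operand.
        return carry + w * x, carry * x  # (next carry, per-step output y)

    return lax.scan(body, init, xs)


traced = jax.make_jaxpr(f)(jnp.float32(2.0), jnp.zeros(()), jnp.arange(3.0))
# Pick out the single scan_p equation from the traced program.
(eqn,) = [e for e in traced.jaxpr.eqns if e.primitive is lax.scan_p]

num_consts = eqn.params["num_consts"]
num_carry = eqn.params["num_carry"]
# Operands: (*consts, *init_carry, *xs), sliced the same way the rule slices `args`.
consts = eqn.invars[:num_consts]
init = eqn.invars[num_consts : num_consts + num_carry]
xs = eqn.invars[num_consts + num_carry :]
# Outputs: the final carry followed by the stacked ys, as one flat list.
print("consts/carry/xs:", len(consts), len(init), len(xs))
print("outputs:", len(eqn.outvars))
# scan_p binds return a sequence of arrays; add_p binds return a single array.
print(lax.scan_p.multiple_results, lax.add_p.multiple_results)
```

Note that these counts are per flat array; once Quax values are involved, each one may contribute several leaves, which is why the counts passed to `bind` need to be recomputed from the flattened inputs.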
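And here's an implementation of `remat_p`:

```python
import jax
import jax.tree_util as jtu
import quax

# Private JAX internals; their locations may change between releases.
from jax._src.interpreters import partial_eval as pe
from jax._src.lax.control_flow.common import _abstractify, _initial_style_jaxpr
from jax._src.util import safe_map


@quax.register(jax._src.ad_checkpoint.remat_p)
def _(*args, jaxpr, prevent_cse, differentiated, policy):
    del prevent_cse, differentiated, policy
    # `jaxpr_as_fun` expects a closed jaxpr. `scan_p` already gets one but `remat_p` doesn't.
    closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr))
    quax_f = quax.quaxify(jax.core.jaxpr_as_fun(closed_jaxpr))
    in_flat, in_treedef = jtu.tree_flatten(args)
    in_avals = tuple(safe_map(_abstractify, in_flat))
    quax_jaxpr, consts, out_tree = _initial_style_jaxpr(quax_f, in_treedef, in_avals, "remat")
    # Evaluating the jaxpr inline means the result is no longer wrapped in a
    # checkpoint, so the rematerialization behavior itself is not preserved.
    out_flat = jax.core.eval_jaxpr(quax_jaxpr.jaxpr, (), *consts, *in_flat)
    return jtu.tree_unflatten(out_tree, out_flat)
```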
I don't think so. Primitive binds can produce either a single array or a sequence of arrays. By the time we're in the JAX internals like this, pytrees have largely disappeared. Anyway, these broadly all look good to me! I'd be happy to take PRs on these, including tests for the kinds of edge cases you're bumping into.
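Because pytrees are gone at the primitive level, the flat outputs only regain their structure if the treedef of the body's whole `(carry, y)` output is recorded when the body is traced. This appears to be what `lax.scan` does in the `loops.py` code linked in the issue, via a private tracing helper. The sketch below mimics that idea with public APIs only, substituting `jax.eval_shape` for the private helper; `body`, `init`, and `xs` are made up for illustration.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from jax import lax


def body(carry, x):
    # carry is a dict and the per-step output y is a tuple: both are pytrees.
    new_carry = {"total": carry["total"] + x}
    y = (2.0 * x, x - 1.0)
    return new_carry, y


init = {"total": jnp.zeros(())}
xs = jnp.arange(4.0)

# Record the treedef of the whole body output, (carry, y), from one abstract step.
out_treedef = jtu.tree_structure(jax.eval_shape(body, init, xs[0]))

# A primitive only hands back flat leaves: the final carry first, then the ys.
out_flat = jtu.tree_leaves(lax.scan(body, init, xs))

# Unflattening against the recorded treedef rebuilds both halves in one step,
# with no separate bookkeeping for the number of ys leaves.
carry, ys = jtu.tree_unflatten(out_treedef, out_flat)
print(carry, ys)
```

The treedef ignores leaf shapes, so the one computed for a single step also describes the stacked `ys`.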
Looking at `quax/quax/_core.py` (line 551 in a9d875e), `scan` is supposed to be implemented generically for all Quax objects. I took a quick try at an implementation like this:

But there are at least two problems:

`out_treedef` for the second element of the return value. I think this would be possible by following the same strategy as what's in https://github.com/google/jax/blob/ebc6c1815297c79bc1c9c907aaf858d70caef5e6/jax/_src/lax/control_flow/loops.py#L123, but I'm not sure if there's a simpler way.

Having `scan` available is pretty handy for use with the scan over layers technique.
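For context, "scan over layers" means stacking the parameters of identical layers along a leading axis and applying them with `lax.scan`, so the layer body is traced and compiled once rather than once per layer. Below is a minimal sketch with plain JAX arrays; the tanh MLP and the helper names (`init_stacked_weights`, `forward_scan`, `forward_loop`) are illustrative assumptions. With Quax, the stacked weights would be Quax arrays instead, which is exactly where a `scan_p` rule is needed.

```python
import jax
import jax.numpy as jnp
from jax import lax


def init_stacked_weights(key, num_layers, width):
    # One (width, width) matrix per layer, stacked along a new leading axis.
    keys = jax.random.split(key, num_layers)
    return jax.vmap(lambda k: jax.random.normal(k, (width, width)) / jnp.sqrt(width))(keys)


def forward_scan(stacked_w, h):
    # Each scan step applies one layer; the body is traced once, not per layer.
    def layer(h, w):
        return jnp.tanh(h @ w), None

    h, _ = lax.scan(layer, h, stacked_w)
    return h


def forward_loop(stacked_w, h):
    # Reference: the same network written as an unrolled Python loop.
    for w in stacked_w:
        h = jnp.tanh(h @ w)
    return h


ws = init_stacked_weights(jax.random.PRNGKey(0), num_layers=4, width=8)
x = jnp.ones((8,))
print(forward_scan(ws, x))
```

Both forwards compute the same function; the scanned version keeps compile time roughly constant as `num_layers` grows.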