Skip to content

Commit

Permalink
fix: sparse broadcasting dispatch
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman authored and patrick-kidger committed Aug 17, 2024
1 parent f31984a commit 81ba207
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions quax/examples/sparse/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,15 @@ def _(value: BCOO, *, broadcast_dimensions, shape) -> BCOO:
raise NotImplementedError(
"BCOO matrices only support broadcasting additional batch dimensions."
)
extra_batch_dims = shape[:n_extra_batch_dims]
data = jnp.broadcast_to(value.data, extra_batch_dims + value.data.shape)
indices = jnp.broadcast_to(value.indices, extra_batch_dims + value.indices.shape)
bdims = shape[:n_extra_batch_dims]
dims = jnp.broadcast_shapes(
(bdims + value.data.shape)[:-1],
(bdims + value.indices.shape)[:-2],
shape[: n_extra_batch_dims + len(value.data.shape) - 1],
)
data = jnp.broadcast_to(value.data, dims + value.data.shape[-1:])
indices = jnp.broadcast_to(value.indices, dims + value.indices.shape[-2:])

return BCOO(data, indices, shape, allow_materialise=value.allow_materialise)


Expand Down

0 comments on commit 81ba207

Please sign in to comment.