Skip to content

Commit

Permalink
fix modifier sf calculation for a single modifier in a group
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed Mar 21, 2024
1 parent f1a7cb5 commit a5166bf
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions src/evermore/modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,9 @@ def scale_factor(self, hist: Array) -> SF:
groups[hash(jtu.tree_structure(mod))].append(mod)
# then do the `jax.lax.scan` loops
for _, group_mods in groups.items():
# skip empty groups
if not group_mods:
continue
# Essentially we are turning an array of modifiers into a single modifier of stacked leaves.
# Then we can use XLA's loop constructs (e.g.: `jax.lax.scan`) to calculate the scale factors
# without having to compile the fully unrolled loop.
Expand All @@ -366,12 +369,18 @@ def calc_sf(_hist, _dynamic_stack, _static_stack):
sf = stack.scale_factor(_hist)
return _hist, sf

_, sf = jax.lax.scan(
partial(calc_sf, _static_stack=static_stack),
hist,
dynamic_stack,
)
multiplicative_sf *= jnp.prod(sf.multiplicative, axis=0)
additive_sf += jnp.sum(sf.additive, axis=0)
# if there is only one modifier in the group, we can skip the scan
if len(group_mods) == 1:
_, sf = calc_sf(hist, dynamic_stack, static_stack)
multiplicative_sf *= sf.multiplicative
additive_sf *= sf.additive
else:
_, sf = jax.lax.scan(
partial(calc_sf, _static_stack=static_stack),
hist,
dynamic_stack,
)
multiplicative_sf *= jnp.prod(sf.multiplicative, axis=0)
additive_sf += jnp.sum(sf.additive, axis=0)

return SF(multiplicative=multiplicative_sf, additive=additive_sf)

0 comments on commit a5166bf

Please sign in to comment.