Skip to content

Commit

Permalink
Fix rank of partition specs for count variable with OptaxOptimizer.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 623276268
  • Loading branch information
laurentes authored and pax authors committed Apr 9, 2024
1 parent bfc7a0e commit ffd4245
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions praxis/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,20 @@ def sharding_function(
return None
return spec

def sharding_for_non_parameter_fields(_: Any) -> base_layer.WeightHParams:
def sharding_for_non_parameter_fields(
params: Any,
) -> base_layer.WeightHParams:
# Typically the non-parameter fields are 'count'.
return base_layer.WeightHParams(
shape=[],
# Preserve the same shape as the original params shape.
# Specifically, the non-parameter `count` must have partition specs
# of the same length as their rank.
# Note that we set the shape rather than `repeat_prefix` as opposed to
# regular Pax optimizers, since there is no easy way to retrieve a
# `repeat_prefix_split_dims_mapping` at that point. This implies that
# prefix vectorization won't be applied to such non parameter fields.
# In practice, those are small arrays, so this shouldn't be a problem.
shape=params.shape,
init=None,
dtype=jnp.int32,
collections=None,
Expand Down

0 comments on commit ffd4245

Please sign in to comment.