Skip to content

Commit

Permalink
Fixing types
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 613808281
  • Loading branch information
bignamehyp authored and pax authors committed Mar 8, 2024
1 parent 493eb5c commit f48a941
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions praxis/layers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1058,12 +1058,12 @@ def sample_insert(
for attr in attrs:
update = getattr(prefix_decode_state, attr)
update = update[prefix_slot]

if (
getattr(decode_state, attr).dtype == jnp.float32
and update.dtype == jnp.bfloat16
):
update = update.astype(jnp.float32)
state_dtype = getattr(decode_state, attr).dtype
if state_dtype != update.dtype:
logging.info(
'%s changed from %s to %s:', attr, state_dtype, update.dtype
)
update = update.astype(state_dtype)

if update.ndim == 0:
update = jnp.expand_dims(update, axis=0)
Expand Down

0 comments on commit f48a941

Please sign in to comment.