Skip to content

Commit

Permalink
Merge pull request #71 from hx89:grok_fp8_dispatch
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 640545759
  • Loading branch information
pax authors committed Jun 5, 2024
2 parents 8b9cec6 + 84fb6ab commit 32b9236
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion praxis/layers/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1031,7 +1031,7 @@ def _dispatch_and_combine_expert_outputs(self, inputs, paddings, segment_ids):
if self.gating_func in ['top2', 'expert_choice_v2']:
combine_tensor = self._split(combine_tensor, ap.gsec)
dispatch_tensor = self._split(dispatch_tensor, ap.gsec)
expert_inputs = jnp.einsum(
expert_inputs = self.einsum(
'gsec,gsm->egcm', dispatch_tensor, reshaped_inputs
)
elif self.gating_func == 'expert_choice':
Expand Down

0 comments on commit 32b9236

Please sign in to comment.