Merge pull request #70 from hx89:grok_fp8
PiperOrigin-RevId: 638339290
pax authors committed May 29, 2024
2 parents c983f07 + 6cd99da commit 69444e8
Showing 2 changed files with 8 additions and 4 deletions.
praxis/layers/BUILD: 1 addition, 0 deletions
@@ -787,6 +787,7 @@ pytype_strict_library(
     deps = [
         ":activations",
         ":attentions",
+        ":base_ops",
         ":checkpoint_policy",
         ":linears",
         ":normalizations",
praxis/layers/transformers.py: 7 additions, 4 deletions
@@ -32,6 +32,7 @@
 from praxis import pytypes
 from praxis.layers import activations as activations_lib
 from praxis.layers import attentions
+from praxis.layers import base_ops
 from praxis.layers import checkpoint_policy
 from praxis.layers import linears
 from praxis.layers import normalizations
@@ -660,6 +661,7 @@ class TransformerFeedForwardMoe(base_layer.BaseLayer):
   gating_logit_cap: float = 0.0
   moe_gating_embedding_level: str = 'token'
   use_gated_activation: bool = False
+  einsum_tpl: LayerTpl = template_field(base_ops.EinsumOp)
 
   # SPMD partition related params.
   # M - model_dim, for both inputs and outputs
@@ -823,6 +825,7 @@ def setup(self) -> None:
     )
     logging.debug('moe wo WeightHParams %s', wo_pc)
     self.create_variable('wo_0', wo_pc)
+    self.create_child('einsum', self.einsum_tpl.clone())
 
   def _split(self, t_in, sharding):
     return base_layer.maybe_shard(t_in, sharding, self.mesh_axis_names)
@@ -1042,8 +1045,8 @@ def _dispatch_and_combine_expert_outputs(self, inputs, paddings, segment_ids):
     expert_inputs = self._split(expert_inputs, ap.egcm)
 
     if self._is_ffn1_gated:
-      hidden0 = jnp.einsum('egcm,emh->egch', expert_inputs, theta_wi)
-      hidden1 = jnp.einsum('egcm,emh->egch', expert_inputs, theta_wi_gated)
+      hidden0 = self.einsum('egcm,emh->egch', expert_inputs, theta_wi)
+      hidden1 = self.einsum('egcm,emh->egch', expert_inputs, theta_wi_gated)
       if self.gating_func in ['top2', 'expert_choice_v2']:
         self._count_dead_neurons(hidden1, dispatch_tensor)
       hidden1 = self.activation(hidden1)
@@ -1058,13 +1061,13 @@ def _dispatch_and_combine_expert_outputs(self, inputs, paddings, segment_ids):
     # Dropout.
     hidden = self.relu_dropout(hidden)
     # Output.
-    expert_output = jnp.einsum('egch,ehm->egcm', hidden, theta_wo)
+    expert_output = self.einsum('egch,ehm->egcm', hidden, theta_wo)
     expert_output = self._split(expert_output, ap.egcm)
     # Now transpose and reshard.
     transposed_expert_output = jnp.einsum('egcm->gecm', expert_output)
     transposed_expert_output = self._split(transposed_expert_output, ap.gecm)
     if self.gating_func in ['top2', 'expert_choice_v2']:
-      combined_output = jnp.einsum(
+      combined_output = self.einsum(
           'gecm,gsec->gsm', transposed_expert_output, combine_tensor
       )
     elif self.gating_func == 'expert_choice':
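
The substantive change is that the MoE expert contractions now go through a child layer built from the new einsum_tpl field (defaulting to base_ops.EinsumOp) instead of calling jnp.einsum directly, so an alternative einsum implementation, for example an FP8 one, can be injected by configuration without modifying this file. Below is a minimal sketch, not part of this commit, of how a caller might override the template. MyQuantizedEinsum is a hypothetical layer, and the sketch assumes EinsumOp's __call__ takes the equation string followed by its operands, as the self.einsum(...) calls above suggest.

# Minimal sketch (not part of this commit). MyQuantizedEinsum is hypothetical;
# it assumes base_ops.EinsumOp is called as (equation, *operands), matching
# the self.einsum(...) calls in the diff above.
import jax.numpy as jnp
from praxis import pax_fiddle
from praxis.layers import base_ops
from praxis.layers import transformers


class MyQuantizedEinsum(base_ops.EinsumOp):
  """Hypothetical einsum op that casts operands before contracting."""

  def __call__(self, equation, *operands):
    # A real FP8 op would quantize and scale here; this just casts to bf16.
    casted = [x.astype(jnp.bfloat16) for x in operands]
    return jnp.einsum(equation, *casted)


# Route the MoE expert einsums through the custom op via the new template.
moe_p = pax_fiddle.Config(transformers.TransformerFeedForwardMoe)
moe_p.einsum_tpl = pax_fiddle.Config(MyQuantizedEinsum)

Note that the pure transpose ('egcm->gecm') is left on jnp.einsum in the diff, presumably because it only rearranges data and has nothing to quantize.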
