diff --git a/praxis/layers/BUILD b/praxis/layers/BUILD
index be546009..2b29b3fd 100644
--- a/praxis/layers/BUILD
+++ b/praxis/layers/BUILD
@@ -787,6 +787,7 @@ pytype_strict_library(
     deps = [
         ":activations",
         ":attentions",
+        ":base_ops",
         ":checkpoint_policy",
         ":linears",
         ":normalizations",
diff --git a/praxis/layers/transformers.py b/praxis/layers/transformers.py
index e5931122..ff1ed63b 100644
--- a/praxis/layers/transformers.py
+++ b/praxis/layers/transformers.py
@@ -32,6 +32,7 @@
 from praxis import pytypes
 from praxis.layers import activations as activations_lib
 from praxis.layers import attentions
+from praxis.layers import base_ops
 from praxis.layers import checkpoint_policy
 from praxis.layers import linears
 from praxis.layers import normalizations
@@ -660,6 +661,7 @@ class TransformerFeedForwardMoe(base_layer.BaseLayer):
   gating_logit_cap: float = 0.0
   moe_gating_embedding_level: str = 'token'
   use_gated_activation: bool = False
+  einsum_tpl: LayerTpl = template_field(base_ops.EinsumOp)
 
   # SPMD partition related params.
   # M - model_dim, for both inputs and outputs
@@ -823,6 +825,7 @@ def setup(self) -> None:
     )
     logging.debug('moe wo WeightHParams %s', wo_pc)
     self.create_variable('wo_0', wo_pc)
+    self.create_child('einsum', self.einsum_tpl.clone())
 
   def _split(self, t_in, sharding):
     return base_layer.maybe_shard(t_in, sharding, self.mesh_axis_names)
@@ -1042,8 +1045,8 @@ def _dispatch_and_combine_expert_outputs(self, inputs, paddings, segment_ids):
     expert_inputs = self._split(expert_inputs, ap.egcm)
 
     if self._is_ffn1_gated:
-      hidden0 = jnp.einsum('egcm,emh->egch', expert_inputs, theta_wi)
-      hidden1 = jnp.einsum('egcm,emh->egch', expert_inputs, theta_wi_gated)
+      hidden0 = self.einsum('egcm,emh->egch', expert_inputs, theta_wi)
+      hidden1 = self.einsum('egcm,emh->egch', expert_inputs, theta_wi_gated)
       if self.gating_func in ['top2', 'expert_choice_v2']:
         self._count_dead_neurons(hidden1, dispatch_tensor)
       hidden1 = self.activation(hidden1)
@@ -1058,13 +1061,13 @@
     # Dropout.
     hidden = self.relu_dropout(hidden)
     # Output.
-    expert_output = jnp.einsum('egch,ehm->egcm', hidden, theta_wo)
+    expert_output = self.einsum('egch,ehm->egcm', hidden, theta_wo)
     expert_output = self._split(expert_output, ap.egcm)
     # Now transpose and reshard.
     transposed_expert_output = jnp.einsum('egcm->gecm', expert_output)
     transposed_expert_output = self._split(transposed_expert_output, ap.gecm)
     if self.gating_func in ['top2', 'expert_choice_v2']:
-      combined_output = jnp.einsum(
+      combined_output = self.einsum(
           'gecm,gsec->gsm', transposed_expert_output, combine_tensor
       )
     elif self.gating_func == 'expert_choice':
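
Note: the net effect of this change is that the expert weight contractions in `TransformerFeedForwardMoe` now dispatch through an injectable `base_ops.EinsumOp` child (created from the new `einsum_tpl` field) instead of calling `jnp.einsum` directly, so the einsum implementation can be swapped out purely through configuration, e.g. for a quantized variant. The pure layout transpose (`'egcm->gecm'`) is left on `jnp.einsum`, as there is no weight contraction to override there. Below is a minimal sketch of how the new hook could be used; `Bf16Einsum` is a hypothetical subclass for illustration, not part of this change:

```python
import jax.numpy as jnp

from praxis import pax_fiddle
from praxis.layers import base_ops
from praxis.layers import transformers


class Bf16Einsum(base_ops.EinsumOp):
  """Hypothetical einsum variant that contracts in bfloat16."""

  def __call__(self, eqn, *args):
    # Cast operands down, run the contraction, cast the result back.
    out = jnp.einsum(eqn, *(x.astype(jnp.bfloat16) for x in args))
    return out.astype(args[0].dtype)


# Route the expert einsums in the MoE feed-forward layer through the
# custom op; the default template is base_ops.EinsumOp (plain jnp.einsum).
moe_p = pax_fiddle.Config(transformers.TransformerFeedForwardMoe)
moe_p.einsum_tpl = pax_fiddle.Config(Bf16Einsum)
```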