diff --git a/praxis/layers/transformer_models.py b/praxis/layers/transformer_models.py
index 8709b270..d4c18b24 100644
--- a/praxis/layers/transformer_models.py
+++ b/praxis/layers/transformer_models.py
@@ -362,7 +362,9 @@ class TransformerLm(base_layer.BaseLayer):
       Used in cases where the Model class does this, to prevent double
       counting. (defaults to False)
     record_activations_in_xent_output: If true, record activations in the
-      compute_loss output, so we have both the activations and logits available.
+      compute_loss output, so we have activations available.
+    record_logits_in_xent_output: If true, record logits in the compute_loss
+      output, so we have logits available.
     entropy_loss_weight: If not None, an entropy loss is added to training.
   """
 
@@ -386,6 +388,7 @@ class TransformerLm(base_layer.BaseLayer):
   skip_compute_loss: bool = False
   skip_aux_loss: bool = False
   record_activations_in_xent_output: bool = False
+  record_logits_in_xent_output: bool = False
   entropy_loss_weight: float | None = None
 
   @classmethod
@@ -920,6 +923,9 @@ def compute_loss(
     # callers).
     if self.record_activations_in_xent_output:
      xent_output.activations = activations
+    if self.record_logits_in_xent_output:
+      logits = self.softmax.get_logits(inputs=activations, input_ids=input_ids)
+      xent_output.logits = logits
     return xent_output
 
   def _prepare_input(