adding a flag to return logits in the outputs
PiperOrigin-RevId: 667700864
The praxis Authors committed Aug 26, 2024
1 parent 31a33e7 commit cd7d76d
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion praxis/layers/transformer_models.py
@@ -362,7 +362,9 @@ class TransformerLm(base_layer.BaseLayer):
       Used in cases where the Model class does this, to prevent double counting.
       (defaults to False)
     record_activations_in_xent_output: If true, record activations in the
-      compute_loss output, so we have both the activations and logits available.
+      compute_loss output, so we have activations available.
+    record_logits_in_xent_output: If true, record logits in the compute_loss
+      output, so we have logits available.
     entropy_loss_weight: If not None, an entropy loss is added to training.
   """

@@ -386,6 +388,7 @@ class TransformerLm(base_layer.BaseLayer):
   skip_compute_loss: bool = False
   skip_aux_loss: bool = False
   record_activations_in_xent_output: bool = False
+  record_logits_in_xent_output: bool = False
   entropy_loss_weight: float | None = None
 
   @classmethod
@@ -920,6 +923,9 @@ def compute_loss(
     # callers).
     if self.record_activations_in_xent_output:
       xent_output.activations = activations
+    if self.record_logits_in_xent_output:
+      logits = self.softmax.get_logits(inputs=activations, input_ids=input_ids)
+      xent_output.logits = logits
     return xent_output
 
   def _prepare_input(
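
For context (not part of the commit), a minimal sketch of how a caller might enable the new flag and read the recorded logits. The configuration shape and the fields omitted here (model dims, vocab size, softmax and stacked-transformer templates, and the exact call used to produce the output) are assumptions, not taken from this diff:

from praxis import pax_fiddle
from praxis.layers import transformer_models

# Sketch only: a usable TransformerLm config also needs model_dims, vocab_size,
# softmax and stacked-transformer templates, etc., which are omitted here.
lm_tpl = pax_fiddle.Config(
    transformer_models.TransformerLm,
    name='lm',
    record_activations_in_xent_output=True,  # existing flag: keep activations
    record_logits_in_xent_output=True,       # new flag from this commit: keep logits
)

# After the forward pass, compute_loss returns a NestedMap; with both flags set
# it would carry both fields (names as in the diff above):
#   xent_output.activations  -> pre-softmax activations
#   xent_output.logits       -> logits from self.softmax.get_logits(...)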
