Commit: Fix CI

Aphoh committed Dec 18, 2024
1 parent 3ae3014 commit 3e60320
Showing 4 changed files with 13 additions and 11 deletions.
src/levanter/main/train_asr.py (4 changes: 1 addition & 3 deletions)

@@ -88,10 +88,8 @@ def compute_loss(
         example: AudioTextExample,
         *,
         key=None,
-        reduction: Optional[hax.ReductionFunction] = hax.mean,
-        reduction_axis: Optional[hax.AxisSelection] = None,
     ) -> jax.numpy.ndarray | hax.NamedArray:
-        return m.compute_loss(example, key=key, reduction=reduction, reduction_axis=reduction_axis)
+        return m.compute_loss(example, key=key)

     # Using the trainer as a context manager does 3 things:
     # 1. Sets the device mesh
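This moves reduction out of the training loss hook: `compute_loss` no longer accepts `reduction`/`reduction_axis`, and callers are expected to reduce the per-token loss themselves. A minimal sketch of that reduce-at-the-call-site pattern, with invented axis names and values (not code from this commit):

    import jax.numpy as jnp
    import haliax as hax

    # Stand-ins for the (per-token loss, mask) pair the new API hands back.
    Pos = hax.Axis("position", 4)
    per_token_loss = hax.named(jnp.array([1.0, 2.0, 3.0, 4.0]), (Pos,))
    where = hax.named(jnp.array([True, True, True, False]), (Pos,))

    # Mean over unmasked positions: sum the kept losses, divide by their count.
    mean_loss = hax.sum(per_token_loss, where=where) / hax.sum(where)  # (1+2+3)/3 = 2.0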
src/levanter/main/viz_logprobs.py (2 changes: 1 addition & 1 deletion)

@@ -74,7 +74,7 @@ def main(config: VizGpt2Config):
     def compute_log_probs(model: LmHeadModel, example: LmExample):
         model = inference_mode(model, True)
         model = mp.cast_to_compute(model)
-        logprobs = compute_next_token_loss(model, example, reduction=None)
+        logprobs, where, _ = compute_next_token_loss(model, example)
         # roll forward to get the loss for each predicted token
         logprobs = hax.roll(logprobs, 1, Pos)
         return logprobs.rearrange((EvalBatch, Pos)).array
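`compute_next_token_loss` now returns a `(loss, where, extras)` triple; the visualization keeps only the per-token values and discards the mask and extras. A toy illustration of the roll on the following line, with invented values:

    import jax.numpy as jnp
    import haliax as hax

    Pos = hax.Axis("position", 4)
    loss = hax.named(jnp.array([0.1, 0.2, 0.3, 0.4]), (Pos,))
    # Rolling by one aligns each position with the loss of the token it
    # predicts, matching the "roll forward" comment above.
    rolled = hax.roll(loss, 1, Pos)  # [0.4, 0.1, 0.2, 0.3]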
src/levanter/models/asr_model.py (12 changes: 7 additions & 5 deletions)

@@ -11,6 +11,7 @@

 from levanter.models.attention import AttentionMask
 from levanter.models.lm_model import LmConfig
+from levanter.utils.types import Extras


 class AudioTextExample(eqx.Module):
@@ -97,9 +98,7 @@ def compute_loss(
         example: AudioTextExample,
         *,
         key=None,
-        reduction: Optional[hax.ReductionFunction] = hax.mean,
-        reduction_axis: Optional[hax.AxisSelection] = None,
-    ) -> jnp.ndarray | NamedArray:
+    ) -> tuple[jnp.ndarray | NamedArray, NamedArray, Extras]:
         """
         Computes the cross-entropy loss for predicted ASR tokens. If reduction is not None, the loss is reduced
         across the reduction axis (with reduction_axis=None meaning all axes). If reduction is None, the loss is not
@@ -110,10 +109,13 @@ def compute_loss(
         targets = hax.roll(example.tokens, -1, axis=self.Pos.name)
         target_y = hax.nn.one_hot(targets, self.Vocab, dtype=logits.dtype)
         loss = cross_entropy_loss(
-            logits, self.Vocab, target_y, reduction, reduction_axis=reduction_axis, where=example.loss_mask
+            logits,
+            self.Vocab,
+            target_y,
+            reduction=None,
         )

-        return loss
+        return loss, example.loss_mask, {}

     @property
     def vocab_size(self) -> int:
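Net effect in this file: `cross_entropy_loss` runs with `reduction=None`, so the loss stays per-token, and `compute_loss` returns it together with the example's loss mask and an empty `Extras` dict. A self-contained sketch of that contract, assuming `cross_entropy_loss` here is haliax's `hax.nn.cross_entropy_loss` (shapes and values invented):

    import jax.numpy as jnp
    import haliax as hax

    Pos = hax.Axis("position", 3)
    Vocab = hax.Axis("vocab", 5)

    logits = hax.zeros((Pos, Vocab))
    targets = hax.named(jnp.array([1, 2, 3]), (Pos,))
    target_y = hax.nn.one_hot(targets, Vocab, dtype=logits.dtype)
    loss_mask = hax.named(jnp.array([True, True, False]), (Pos,))

    # Unreduced loss: one value per position.
    loss = hax.nn.cross_entropy_loss(logits, Vocab, target_y, reduction=None)
    result = (loss, loss_mask, {})  # the shape of what compute_loss now returns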
tests/test_text.py (6 changes: 4 additions & 2 deletions)

@@ -40,12 +40,14 @@ def test_lm_example_handles_ignore_id():
     lm_head = hax.zeros((Embed, Vocab))
     lm_head = lm_head.at[Vocab, ignore_id].set(-100)

-    ignored_loss = maybe_fused_next_token_loss(
+    ignored_loss, ignored_where = maybe_fused_next_token_loss(
         Pos, Embed, Vocab, logits, lm_head, tokens, loss_mask=ex_ignore.loss_mask
     )
-    no_ignore_loss = maybe_fused_next_token_loss(
+    ignored_loss = hax.sum(ignored_loss, where=ignored_where)
+    no_ignore_loss, no_ignore_where = maybe_fused_next_token_loss(
         Pos, Embed, Vocab, logits, lm_head, tokens, loss_mask=ex_no_ignore.loss_mask
     )
+    no_ignore_loss = hax.sum(no_ignore_loss, where=no_ignore_where)

     assert no_ignore_loss.item() >= ignored_loss.item() + 100 / Pos.size
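Since the reductions are now explicit `hax.sum(..., where=...)` calls, the assertion boils down to masked summation: excluding the ignored token's large loss must lower the total. A toy check of that behavior, with invented values:

    import jax.numpy as jnp
    import haliax as hax

    Pos = hax.Axis("position", 4)
    loss = hax.named(jnp.array([1.0, 100.0, 1.0, 1.0]), (Pos,))
    keep_all = hax.named(jnp.array([True, True, True, True]), (Pos,))
    drop_spike = hax.named(jnp.array([True, False, True, True]), (Pos,))

    full = hax.sum(loss, where=keep_all)      # 103.0
    masked = hax.sum(loss, where=drop_spike)  # 3.0
    assert full.item() - masked.item() == 100.0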
