Skip to content

Commit

Permalink
Fix CI
Browse files Browse the repository at this point in the history
  • Loading branch information
Aphoh committed Dec 18, 2024
1 parent 3ae3014 commit d37665a
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/levanter/main/viz_logprobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def main(config: VizGpt2Config):
def compute_log_probs(model: LmHeadModel, example: LmExample):
model = inference_mode(model, True)
model = mp.cast_to_compute(model)
logprobs = compute_next_token_loss(model, example, reduction=None)
logprobs, where, _ = compute_next_token_loss(model, example)
# roll forward to get the loss for each predicted token
logprobs = hax.roll(logprobs, 1, Pos)
return logprobs.rearrange((EvalBatch, Pos)).array
Expand Down

0 comments on commit d37665a

Please sign in to comment.