Skip to content

Commit

Permalink
Port #642's loss changes to estimation.py (#656)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca authored Oct 28, 2024
1 parent eae7f10 commit 603889a
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def estimate_memory(job_config: JobConfig):
# loss fn can be shared by pipeline-parallel or non-pp execution
def loss_fn(pred, labels):
return torch.nn.functional.cross_entropy(
pred.flatten(0, 1), labels.flatten(0, 1)
pred.flatten(0, 1).float(), labels.flatten(0, 1)
)

# build model (using meta init)
Expand Down

0 comments on commit 603889a

Please sign in to comment.