Skip to content

Commit

Permalink
Rename regression_metrics -> metrics
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 698903998
  • Loading branch information
xingyousong authored and copybara-github committed Nov 21, 2024
1 parent 3d18dcd commit 8522551
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
File renamed without changes.
6 changes: 3 additions & 3 deletions optformer/embed_then_regress/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from optformer.embed_then_regress import checkpointing as ckpt_lib
from optformer.embed_then_regress import configs
from optformer.embed_then_regress import icl_transformer
from optformer.embed_then_regress import regression_metrics
from optformer.embed_then_regress import metrics as metrics_lib
import tensorflow as tf


Expand Down Expand Up @@ -109,10 +109,10 @@ def loss_fn(
target_mask = 1 - batch['mask'] # [B, L]
target_nlogprob = nlogprob * target_mask # [B, L]

avg_nlogprob = regression_metrics.masked_mean(target_nlogprob, target_mask)
avg_nlogprob = metrics_lib.masked_mean(target_nlogprob, target_mask)
loss = jnp.mean(avg_nlogprob) # [B] -> Scalar

metrics = regression_metrics.default_metrics(mean, batch['y'], target_mask)
metrics = metrics_lib.default_metrics(mean, batch['y'], target_mask)
metrics['loss'] = loss
return loss, metrics

Expand Down

0 comments on commit 8522551

Please sign in to comment.