Skip to content

Commit

Permalink
Add more regression metrics for logging
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 693561333
  • Loading branch information
xingyousong authored and copybara-github committed Nov 6, 2024
1 parent 840260a commit 86b717f
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 3 deletions.
74 changes: 74 additions & 0 deletions optformer/embed_then_regress/regression_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Compute regression-related metrics."""

import jax
import jax.numpy as jnp
import jaxtyping as jt

# Small positive constant added to denominators in this module to avoid
# dividing by exactly zero.
EPS = 1e-7
# Annotation alias for a floating-point JAX array with an empty shape string
# (i.e. a scalar).
Scalar = jt.Float[jax.Array, '']


def masked_mean(
    values: jt.Float[jax.Array, 'B L'],
    target_mask: jt.Bool[jax.Array, 'B L'],
) -> jt.Float[jax.Array, 'B']:
  """Averages `values` along axis 1, counting only entries selected by the mask.

  Args:
    values: Per-position values, one row per batch element.
    target_mask: Same shape as `values`; entries that are True (or 1)
      contribute to the average, the rest are zeroed out.

  Returns:
    One mean per row. There is no guard for rows whose mask selects nothing:
    the denominator is zero for such rows.
  """
  masked_total = jnp.sum(values * target_mask, axis=1)  # [B]
  num_selected = jnp.sum(target_mask, axis=1)  # [B]
  return masked_total / num_selected  # [B]


def pointwise_mse(
    mu: jt.Float[jax.Array, 'B L'],
    ys: jt.Float[jax.Array, 'B L'],
    target_mask: jt.Bool[jax.Array, 'B L'],
) -> Scalar:
  """Mean squared error between `mu` and `ys` over masked-in positions.

  The squared error is first averaged within each row via `masked_mean`
  (so only positions selected by `target_mask` count), and the resulting
  per-row values are then averaged over the batch.

  Args:
    mu: Predicted values.
    ys: Ground-truth values.
    target_mask: Selects which positions contribute to the error.

  Returns:
    Scalar mean of the per-row masked MSEs.
  """
  per_row_mse = masked_mean(jnp.square(ys - mu), target_mask)  # [B]
  return jnp.mean(per_row_mse)  # [B] -> Scalar


def pointwise_r2(
    mu: jt.Float[jax.Array, 'B L'],
    ys: jt.Float[jax.Array, 'B L'],
    target_mask: jt.Bool[jax.Array, 'B L'],
) -> Scalar:
  """Pointwise R2, computed as the squared Pearson correlation.

  For every row, the correlation coefficient between `mu` and `ys` is
  computed over the positions selected by `target_mask`; it is then squared
  and averaged over the batch.

  Args:
    mu: Predicted values.
    ys: Ground-truth values.
    target_mask: Selects which positions contribute to the statistic.

  Returns:
    Scalar mean of the per-row squared correlation coefficients.
  """
  # `masked_mean` returns one value per row ([B]). A trailing axis is required
  # so each mean is subtracted from its own row; without it, broadcasting
  # would align the [B] means against the L axis of the [B, L] inputs.
  mu_mean = masked_mean(mu, target_mask)[:, None]  # [B, 1]
  ys_mean = masked_mean(ys, target_mask)[:, None]  # [B, 1]

  # Center the values and zero out positions that are masked out.
  mu_centered = (mu - mu_mean) * target_mask  # [B, L]
  ys_centered = (ys - ys_mean) * target_mask  # [B, L]

  # Unnormalized covariance and standard deviations; the shared 1/count factor
  # cancels in the ratio below, so it is not applied.
  covariance = jnp.sum(mu_centered * ys_centered, axis=1)  # [B]
  std_mu = jnp.sqrt(jnp.sum(mu_centered**2, axis=1))  # [B]
  std_ys = jnp.sqrt(jnp.sum(ys_centered**2, axis=1))  # [B]

  # EPS keeps the denominator away from zero when either side is constant.
  corrcoef = covariance / (std_mu * std_ys + EPS)  # [B]
  return jnp.mean(corrcoef**2)  # [B] -> Scalar


def default_metrics(
    mu: jt.Float[jax.Array, 'B L'],
    ys: jt.Float[jax.Array, 'B L'],
    target_mask: jt.Bool[jax.Array, 'B L'],
) -> dict[str, Scalar]:
  """Computes the default set of regression metrics.

  Args:
    mu: Predicted values.
    ys: Ground-truth values.
    target_mask: Selects which positions contribute to each metric.

  Returns:
    Dict mapping metric name to its scalar value, with entries
    'pointwise_mse' and 'pointwise_r2'.
  """
  mse = pointwise_mse(mu, ys, target_mask)
  r2 = pointwise_r2(mu, ys, target_mask)
  return {'pointwise_mse': mse, 'pointwise_r2': r2}
9 changes: 6 additions & 3 deletions optformer/embed_then_regress/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from optformer.embed_then_regress import checkpointing as ckpt_lib
from optformer.embed_then_regress import configs
from optformer.embed_then_regress import icl_transformer
from optformer.embed_then_regress import regression_metrics
import tensorflow as tf


Expand Down Expand Up @@ -104,10 +105,12 @@ def loss_fn(
target_mask = 1 - batch['mask'] # [B, L]
target_nlogprob = nlogprob * target_mask # [B, L]

avg_nlogprob = jnp.sum(target_nlogprob, axis=1) / jnp.sum(target_mask, axis=1)
avg_nlogprob = regression_metrics.masked_mean(target_nlogprob, target_mask)
loss = jnp.mean(avg_nlogprob) # [B] -> Scalar
# TODO: Get more metrics.
return loss, {'loss': loss}

metrics = regression_metrics.default_metrics(mean, batch['y'], target_mask)
metrics['loss'] = loss
return loss, metrics


def train_step(
Expand Down

0 comments on commit 86b717f

Please sign in to comment.