From 99a2f329d48f58d328ad0391bd200ff49a0a3a56 Mon Sep 17 00:00:00 2001 From: Xingyou Song Date: Tue, 5 Nov 2024 19:49:09 -0800 Subject: [PATCH] Add more regression metrics for logging PiperOrigin-RevId: 693561333 --- .../embed_then_regress/regression_metrics.py | 78 +++++++++++++++++++ optformer/embed_then_regress/train.py | 9 ++- .../embed_then_regress/vizier/serializers.py | 2 +- 3 files changed, 85 insertions(+), 4 deletions(-) create mode 100644 optformer/embed_then_regress/regression_metrics.py diff --git a/optformer/embed_then_regress/regression_metrics.py b/optformer/embed_then_regress/regression_metrics.py new file mode 100644 index 0000000..6df6e9d --- /dev/null +++ b/optformer/embed_then_regress/regression_metrics.py @@ -0,0 +1,78 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Compute regression-related metrics.""" + +import jax +import jax.numpy as jnp +import jaxtyping as jt + +EPS = 1e-7 +Scalar = jt.Float[jax.Array, ''] + + +def masked_mean( + values: jt.Float[jax.Array, 'B L'], + target_mask: jt.Bool[jax.Array, 'B L'], +) -> jt.Float[jax.Array, 'B']: + """Calculate means, only considering mask=True values.""" + values = values * target_mask # [B, L] + return jnp.sum(values, axis=1) / jnp.sum(target_mask, axis=1) # [B] + + +def pointwise_mse( + mu: jt.Float[jax.Array, 'B L'], + ys: jt.Float[jax.Array, 'B L'], + target_mask: jt.Bool[jax.Array, 'B L'], +) -> Scalar: + """Pointwise MSE.""" + squared_error = jnp.square(ys - mu) + mse = masked_mean(squared_error, target_mask) + return jnp.mean(mse) # [B] -> Scalar + + +def pointwise_r2( + mu: jt.Float[jax.Array, 'B L'], + ys: jt.Float[jax.Array, 'B L'], + target_mask: jt.Bool[jax.Array, 'B L'], +) -> Scalar: + """Pointwise R2.""" + # Calculate centered values. + + mu_mean = jnp.expand_dims(masked_mean(mu, target_mask), axis=-1) # [B, 1] + ys_mean = jnp.expand_dims(masked_mean(ys, target_mask), axis=-1) # [B, 1] + + mu_centered = (mu - mu_mean) * target_mask # [B, L] + ys_centered = (ys - ys_mean) * target_mask # [B, L] + + # Calculate covariance and standard deviations. + covariance = jnp.sum(mu_centered * ys_centered, axis=1) # [B] + std_mu = jnp.sqrt(jnp.sum(mu_centered**2, axis=1)) # [B] + std_ys = jnp.sqrt(jnp.sum(ys_centered**2, axis=1)) # [B] + + # Calculate correlation coefficient + corrcoef = covariance / (std_mu * std_ys + EPS) # [B] + return jnp.mean(corrcoef**2) # [B] -> Scalar + + +def default_metrics( + mu: jt.Float[jax.Array, 'B L'], + ys: jt.Float[jax.Array, 'B L'], + target_mask: jt.Bool[jax.Array, 'B L'], +) -> dict[str, Scalar]: + """Default metrics.""" + return { + 'pointwise_mse': pointwise_mse(mu, ys, target_mask), + 'pointwise_r2': pointwise_r2(mu, ys, target_mask), + } diff --git a/optformer/embed_then_regress/train.py b/optformer/embed_then_regress/train.py index 75da1e0..5a97b80 100644 --- a/optformer/embed_then_regress/train.py +++ b/optformer/embed_then_regress/train.py @@ -27,6 +27,7 @@ from optformer.embed_then_regress import checkpointing as ckpt_lib from optformer.embed_then_regress import configs from optformer.embed_then_regress import icl_transformer +from optformer.embed_then_regress import regression_metrics import tensorflow as tf @@ -104,10 +105,12 @@ def loss_fn( target_mask = 1 - batch['mask'] # [B, L] target_nlogprob = nlogprob * target_mask # [B, L] - avg_nlogprob = jnp.sum(target_nlogprob, axis=1) / jnp.sum(target_mask, axis=1) + avg_nlogprob = regression_metrics.masked_mean(target_nlogprob, target_mask) loss = jnp.mean(avg_nlogprob) # [B] -> Scalar - # TODO: Get more metrics. - return loss, {'loss': loss} + + metrics = regression_metrics.default_metrics(mean, batch['y'], target_mask) + metrics['loss'] = loss + return loss, metrics def train_step( diff --git a/optformer/embed_then_regress/vizier/serializers.py b/optformer/embed_then_regress/vizier/serializers.py index eee31fc..1a2a423 100644 --- a/optformer/embed_then_regress/vizier/serializers.py +++ b/optformer/embed_then_regress/vizier/serializers.py @@ -45,7 +45,7 @@ def to_str(self, t: vz.TrialSuggestion, /) -> str: for pc in self.search_space.parameters: value = param_dict[pc.name] if isinstance(value, (float, int)): - float_format = '.2e' if self.use_scientific else '.2f' + float_format = '.4e' if self.use_scientific else '.4f' new_param_dict[pc.name] = format(value, float_format) else: new_param_dict[pc.name] = value