From 953a5ab0818f4f1498d5be13afc051d219dbe7e5 Mon Sep 17 00:00:00 2001
From: omiralles
Date: Thu, 9 Jan 2025 11:54:20 +0100
Subject: [PATCH 1/2] Fix combined loss and test

---
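Notes:

Construction sketch for the two explicit forms CombinedLoss now accepts,
mirroring the new test below; a config dict is also accepted and routed
through GraphForecaster.get_loss_function. Class and module names are the
patch's own, the weights are arbitrary:

    import torch

    from anemoi.training.losses.combined import CombinedLoss
    from anemoi.training.losses.mae import WeightedMAELoss
    from anemoi.training.losses.mse import WeightedMSELoss

    # Pre-instantiated losses: each already carries its own node weights.
    cl = CombinedLoss(
        losses=[WeightedMSELoss(torch.ones(1)), WeightedMAELoss(torch.ones(1))],
        loss_weights=(1.0, 0.5),
    )

    # Loss classes: shared kwargs such as node_weights are forwarded to each.
    cl_class = CombinedLoss(
        losses=[WeightedMSELoss, WeightedMAELoss],
        node_weights=torch.ones(1),
        loss_weights=(1.0, 0.5),
    )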
 .../src/anemoi/training/losses/combined.py   | 14 ++++++---
 .../src/anemoi/training/train/forecaster.py  | 31 +++++++++++++------
 training/tests/train/test_loss_function.py   | 14 +++++++++
 3 files changed, 45 insertions(+), 14 deletions(-)

diff --git a/training/src/anemoi/training/losses/combined.py b/training/src/anemoi/training/losses/combined.py
index 11d2b4fe..8832fcbf 100644
--- a/training/src/anemoi/training/losses/combined.py
+++ b/training/src/anemoi/training/losses/combined.py
@@ -81,10 +81,16 @@ def __init__(
         assert len(losses) == len(loss_weights), "Number of losses and weights must match"
         assert len(losses) > 0, "At least one loss must be provided"
 
-        self.losses = [
-            GraphForecaster.get_loss_function(loss, **kwargs) if isinstance(loss, dict) else loss(**kwargs)
-            for loss in losses
-        ]
+        self.losses = []
+        for loss in losses:
+            if isinstance(loss, dict):
+                self.losses.append(GraphForecaster.get_loss_function(loss, **kwargs))
+            elif isinstance(loss, type):
+                self.losses.append(loss(**kwargs))
+            elif callable(loss):
+                self.losses.append(loss)
+            else:
+                raise TypeError(f"Loss {loss} is not a valid loss function")
         self.loss_weights = loss_weights
 
     def forward(
diff --git a/training/src/anemoi/training/train/forecaster.py b/training/src/anemoi/training/train/forecaster.py
index e88db201..21fc784c 100644
--- a/training/src/anemoi/training/train/forecaster.py
+++ b/training/src/anemoi/training/train/forecaster.py
@@ -18,6 +18,9 @@
 import numpy as np
 import pytorch_lightning as pl
 import torch
+from anemoi.models.data_indices.collection import IndexCollection
+from anemoi.models.interface import AnemoiModelInterface
+from anemoi.utils.config import DotDict
 from hydra.utils import instantiate
 from omegaconf import DictConfig
 from omegaconf import OmegaConf
@@ -27,14 +30,11 @@
 from torch.utils.checkpoint import checkpoint
 from torch_geometric.data import HeteroData
 
-from anemoi.models.data_indices.collection import IndexCollection
-from anemoi.models.interface import AnemoiModelInterface
 from anemoi.training.losses.utils import grad_scaler
 from anemoi.training.losses.weightedloss import BaseWeightedLoss
 from anemoi.training.utils.jsonify import map_config_to_primitives
 from anemoi.training.utils.masks import Boolean1DMask
 from anemoi.training.utils.masks import NoOutputMask
-from anemoi.utils.config import DotDict
 
 LOGGER = logging.getLogger(__name__)
 
@@ -121,13 +121,24 @@
             "limited_area_mask": (2, limited_area_mask),
         }
         self.updated_loss_mask = False
-
-        self.loss = self.get_loss_function(config.training.training_loss, scalars=self.scalars, **loss_kwargs)
-
-        assert isinstance(self.loss, BaseWeightedLoss) and not isinstance(
-            self.loss,
-            torch.nn.ModuleList,
-        ), f"Loss function must be a `BaseWeightedLoss`, not a {type(self.loss).__name__!r}"
+        if config.training.training_loss._target_ == "anemoi.training.losses.combined.CombinedLoss":
+            assert "loss_weights" in config.training.training_loss, "Loss weights must be provided for combined loss"
+            losses = []
+            ignore_nans = config.training.training_loss.get("ignore_nans", False)  # no point in doing this per loss: nan + nan is nan
+            for loss in config.training.training_loss.losses:
+                node_weighting = instantiate(loss.node_weights)
+                loss_node_weights = node_weighting.weights(graph_data)
+                loss_node_weights = self.output_mask.apply(loss_node_weights, dim=0, fill_value=0.0)
+                loss_instantiated = self.get_loss_function(loss, scalars=self.scalars, node_weights=loss_node_weights, ignore_nans=ignore_nans)
+                assert isinstance(loss_instantiated, BaseWeightedLoss)
+                losses.append(loss_instantiated)
+            self.loss = instantiate({"_target_": config.training.training_loss._target_}, losses=losses, loss_weights=config.training.training_loss.loss_weights, **loss_kwargs)
+        else:
+            self.loss = self.get_loss_function(config.training.training_loss, scalars=self.scalars, **loss_kwargs)
+            assert isinstance(self.loss, BaseWeightedLoss) and not isinstance(
+                self.loss,
+                torch.nn.ModuleList,
+            ), f"Loss function must be a `BaseWeightedLoss`, not a {type(self.loss).__name__!r}"
 
         self.metrics = self.get_loss_function(config.training.validation_metrics, scalars=self.scalars, **loss_kwargs)
         if not isinstance(self.metrics, torch.nn.ModuleList):
diff --git a/training/tests/train/test_loss_function.py b/training/tests/train/test_loss_function.py
index 73d7f246..d75436f1 100644
--- a/training/tests/train/test_loss_function.py
+++ b/training/tests/train/test_loss_function.py
@@ -12,6 +12,8 @@
 from omegaconf import DictConfig
 
+from anemoi.training.losses.combined import CombinedLoss
+from anemoi.training.losses.mae import WeightedMAELoss
 from anemoi.training.losses.mse import WeightedMSELoss
 from anemoi.training.losses.weightedloss import BaseWeightedLoss
 from anemoi.training.train.forecaster import GraphForecaster
 
@@ -62,3 +64,15 @@ def test_dynamic_init_scalar_not_add() -> None:
     assert isinstance(loss, BaseWeightedLoss)
     torch.testing.assert_close(loss.node_weights, torch.ones(1))
     assert "test" not in loss.scalar
+
+
+def test_combined_loss() -> None:
+    loss1 = WeightedMSELoss(torch.ones(1))
+    loss2 = WeightedMAELoss(torch.ones(1))
+    cl = CombinedLoss(losses=[loss1, loss2], loss_weights=(1.0, 0.5))
+    assert isinstance(cl, CombinedLoss)
+    cl_class = CombinedLoss(losses=[WeightedMSELoss, WeightedMAELoss],
+                            node_weights=torch.ones(1),
+                            loss_weights=(1.0, 0.5))
+    assert isinstance(cl_class, CombinedLoss)
+

From 96ac9ae050aa7cac9ffdfd5a2842a5ef958722da Mon Sep 17 00:00:00 2001
From: anaprietonem
Date: Mon, 6 Jan 2025 10:05:24 +0000
Subject: [PATCH 2/2] Exclude NaNs from error colorbars

---
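Notes:

combined_error.min()/.max() return nan as soon as any grid point is nan,
which left the error colorbar limits undefined; np.nanmin/np.nanmax exclude
the nans. A minimal standalone illustration (synthetic values and generic
numpy/matplotlib, not the repo's plotting code itself):

    import numpy as np
    from matplotlib.colors import TwoSlopeNorm

    error = np.array([-2.0, 0.5, np.nan, 3.0])
    print(error.min(), error.max())            # nan nan  -- nans propagate
    print(np.nanmin(error), np.nanmax(error))  # -2.0 3.0 -- nans excluded

    # Zero-centred, nan-safe colorbar limits, as introduced below:
    norm_error = TwoSlopeNorm(vmin=np.nanmin(error), vcenter=0.0, vmax=np.nanmax(error))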
title=f"{vname} increment [pred - input]", datashader=datashader, ) @@ -699,7 +700,7 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray: lat, truth - input_, cmap="bwr", - norm=TwoSlopeNorm(vmin=combined_error.min(), vcenter=0.0, vmax=combined_error.max()), + norm=norm_error, title=f"{vname} persist err", datashader=datashader, )