Fix/combined loss #70

Draft · wants to merge 3 commits into develop
14 changes: 10 additions & 4 deletions training/src/anemoi/training/losses/combined.py
@@ -81,10 +81,16 @@ def __init__(
        assert len(losses) == len(loss_weights), "Number of losses and weights must match"
        assert len(losses) > 0, "At least one loss must be provided"

-        self.losses = [
-            GraphForecaster.get_loss_function(loss, **kwargs) if isinstance(loss, dict) else loss(**kwargs)
-            for loss in losses
-        ]
+        self.losses = []
+        for loss in losses:
+            if isinstance(loss, dict):
+                self.losses.append(GraphForecaster.get_loss_function(loss, **kwargs))
+            elif isinstance(loss, type):
+                self.losses.append(loss(**kwargs))
+            elif hasattr(loss, "__class__"):
+                self.losses.append(loss)
Comment on lines +90 to +91

Member:

Why are we checking for __class__? If checking for an object, why not isinstance(loss, object)?

Contributor Author:

Because originally it could only take a class (of type "type", not instantiated) in the losses argument. Indeed, loss(**kwargs), called later in the function, expects __init__ arguments for the individual loss object, not forward arguments. As I said, I'll try to commit the recent changes later.
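
For illustration, a minimal sketch of that distinction; the tensor shapes and the (pred, target) forward signature are assumptions for this sketch, not taken from the PR:

    import torch

    from anemoi.training.losses.mse import WeightedMSELoss

    # A class (of type `type`): loss(**kwargs) is a constructor call, so kwargs
    # must be __init__ arguments such as node_weights.
    loss_obj = WeightedMSELoss(node_weights=torch.ones(1))

    # An instance: calling it invokes forward(), which expects prediction and
    # target tensors, not constructor arguments.
    pred = torch.zeros(1, 1, 4, 2)   # assumed (batch, ensemble, grid, variable) layout
    target = torch.ones(1, 1, 4, 2)
    value = loss_obj(pred, target)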

+            else:
+                raise TypeError(f"Loss {loss} is not a valid loss function")
        self.loss_weights = loss_weights

    def forward(
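
With this change, the losses argument accepts three forms. A short usage sketch, consistent with the new test added below; the config-dict schema beyond _target_ is an assumption:

    import torch

    from anemoi.training.losses.combined import CombinedLoss
    from anemoi.training.losses.mse import WeightedMSELoss

    # 1. A config dict, resolved via GraphForecaster.get_loss_function(loss, **kwargs).
    cl_dict = CombinedLoss(
        losses=[{"_target_": "anemoi.training.losses.mse.WeightedMSELoss"}],
        node_weights=torch.ones(1),
        loss_weights=(1.0,),
    )

    # 2. A loss class, instantiated internally as loss(**kwargs).
    cl_class = CombinedLoss(losses=[WeightedMSELoss], node_weights=torch.ones(1), loss_weights=(1.0,))

    # 3. An already-instantiated loss, used as-is (extra kwargs are ignored for it).
    cl_instance = CombinedLoss(losses=[WeightedMSELoss(torch.ones(1))], loss_weights=(1.0,))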
31 changes: 21 additions & 10 deletions training/src/anemoi/training/train/forecaster.py
@@ -18,6 +18,9 @@
import numpy as np
import pytorch_lightning as pl
import torch
+from anemoi.models.data_indices.collection import IndexCollection
+from anemoi.models.interface import AnemoiModelInterface
+from anemoi.utils.config import DotDict
from hydra.utils import instantiate
from omegaconf import DictConfig
from omegaconf import OmegaConf
@@ -27,14 +30,11 @@
from torch.utils.checkpoint import checkpoint
from torch_geometric.data import HeteroData

-from anemoi.models.data_indices.collection import IndexCollection
-from anemoi.models.interface import AnemoiModelInterface
from anemoi.training.losses.utils import grad_scaler
from anemoi.training.losses.weightedloss import BaseWeightedLoss
from anemoi.training.utils.jsonify import map_config_to_primitives
from anemoi.training.utils.masks import Boolean1DMask
from anemoi.training.utils.masks import NoOutputMask
-from anemoi.utils.config import DotDict

LOGGER = logging.getLogger(__name__)

@@ -121,13 +121,24 @@ def __init__(
"limited_area_mask": (2, limited_area_mask),
}
self.updated_loss_mask = False

self.loss = self.get_loss_function(config.training.training_loss, scalars=self.scalars, **loss_kwargs)

assert isinstance(self.loss, BaseWeightedLoss) and not isinstance(
self.loss,
torch.nn.ModuleList,
), f"Loss function must be a `BaseWeightedLoss`, not a {type(self.loss).__name__!r}"
+        if config.training.training_loss._target_ == "anemoi.training.losses.combined.CombinedLoss":
+            assert "loss_weights" in config.training.training_loss, "Loss weights must be provided for combined loss"
+            losses = []
+            ignore_nans = config.training.training_loss.get("ignore_nans", False)  # no point in doing this for each loss: nan + nan is nan
+            for loss in config.training.training_loss.losses:
+                node_weighting = instantiate(loss.node_weights)
+                loss_node_weights = node_weighting.weights(graph_data)
+                loss_node_weights = self.output_mask.apply(loss_node_weights, dim=0, fill_value=0.0)
+                loss_instantiated = self.get_loss_function(loss, scalars=self.scalars, node_weights=loss_node_weights, ignore_nans=ignore_nans)
+                losses.append(loss_instantiated)
+                assert isinstance(loss_instantiated, BaseWeightedLoss)
+            self.loss = instantiate({"_target_": config.training.training_loss._target_}, losses=losses, loss_weights=config.training.training_loss.loss_weights, **loss_kwargs)
+        else:
+            self.loss = self.get_loss_function(config.training.training_loss, scalars=self.scalars, **loss_kwargs)
+            assert isinstance(self.loss, BaseWeightedLoss) and not isinstance(
+                self.loss,
+                torch.nn.ModuleList,
+            ), f"Loss function must be a `BaseWeightedLoss`, not a {type(self.loss).__name__!r}"
Comment on lines +124 to +141

Member:

I think that this is over-specific for this use case, and it instantiates objects unnecessarily.

Contributor Author:

Instantiating node_weights was necessary to call the combined loss, but if you find a way around it, please let me know... I have another version where all of this is implemented in get_loss_function in the forecaster. It is cleaner, so I'll try to commit it soon.

Member:

Hi, yeah. As I wrote the loss-function code originally, I was able to find a way around it and only update the CombinedLoss class.

Member (@HCookie, Jan 16, 2025):

If you'd like, we can work together on https://github.com/ecmwf/anemoi-core/tree/fix/combined_loss_hcookie to make sure your use case is addressed.
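
For context, a hedged sketch of a training-loss config that would take the CombinedLoss branch above; the keys mirror what the new code reads, while the node_weights targets are illustrative placeholders, not verified against this branch:

    from omegaconf import OmegaConf

    # The branch reads _target_, loss_weights, an optional ignore_nans flag, and a
    # per-loss node_weights config that it passes to hydra's instantiate().
    training_loss = OmegaConf.create({
        "_target_": "anemoi.training.losses.combined.CombinedLoss",
        "loss_weights": [1.0, 0.5],
        "ignore_nans": False,
        "losses": [
            {
                "_target_": "anemoi.training.losses.mse.WeightedMSELoss",
                "node_weights": {"_target_": "anemoi.training.losses.nodeweights.GraphNodeAttribute"},  # illustrative
            },
            {
                "_target_": "anemoi.training.losses.mae.WeightedMAELoss",
                "node_weights": {"_target_": "anemoi.training.losses.nodeweights.GraphNodeAttribute"},  # illustrative
            },
        ],
    })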


        self.metrics = self.get_loss_function(config.training.validation_metrics, scalars=self.scalars, **loss_kwargs)
        if not isinstance(self.metrics, torch.nn.ModuleList):
14 changes: 14 additions & 0 deletions training/tests/train/test_loss_function.py
@@ -12,6 +12,8 @@
from omegaconf import DictConfig

from anemoi.training.losses.mse import WeightedMSELoss
+from anemoi.training.losses.mae import WeightedMAELoss
+from anemoi.training.losses.combined import CombinedLoss
from anemoi.training.losses.weightedloss import BaseWeightedLoss
from anemoi.training.train.forecaster import GraphForecaster

@@ -62,3 +64,15 @@ def test_dynamic_init_scalar_not_add() -> None:
    assert isinstance(loss, BaseWeightedLoss)
    torch.testing.assert_close(loss.node_weights, torch.ones(1))
    assert "test" not in loss.scalar


+def test_combined_loss() -> None:
+    loss1 = WeightedMSELoss(torch.ones(1))
+    loss2 = WeightedMAELoss(torch.ones(1))
+    cl = CombinedLoss(losses=[loss1, loss2], loss_weights=(1.0, 0.5))
+    assert isinstance(cl, CombinedLoss)
+    cl_class = CombinedLoss(losses=[WeightedMSELoss, WeightedMAELoss],
+                            node_weights=torch.ones(1),
+                            loss_weights=(1.0, 0.5))
+    assert isinstance(cl_class, CombinedLoss)
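
A possible follow-up test, assuming the weighted losses keep their usual (pred, target) call signature and that CombinedLoss.forward computes a weighted sum; neither detail is visible in this diff:

    def test_combined_loss_forward() -> None:
        cl = CombinedLoss(
            losses=[WeightedMSELoss(torch.ones(1)), WeightedMAELoss(torch.ones(1))],
            loss_weights=(1.0, 0.5),
        )
        pred = torch.zeros(1, 1, 4, 2)
        target = torch.ones(1, 1, 4, 2)
        out = cl(pred, target)  # expected 1.0 * MSE + 0.5 * MAE if weighted-sum semantics hold
        assert torch.is_tensor(out)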

Loading