Fix/combined loss #70

Draft · wants to merge 5 commits into main · changes from all commits
24 changes: 7 additions & 17 deletions in training/src/anemoi/training/losses/combined.py
@@ -10,23 +10,23 @@
 from __future__ import annotations
 
 import functools
+from typing import TYPE_CHECKING
 from typing import Any
-from typing import Callable
 
 import torch
 
-from anemoi.training.train.forecaster import GraphForecaster
+if TYPE_CHECKING:
+    from collections.abc import Sequence
 
 
 class CombinedLoss(torch.nn.Module):
     """Combined Loss function."""
 
     def __init__(
         self,
-        *extra_losses: dict[str, Any] | Callable,
-        losses: tuple[dict[str, Any] | Callable] | None = None,
+        losses: Sequence[torch.nn.Module],
         loss_weights: tuple[int, ...],
         **kwargs,
     ):
         """Combined loss function.
 
@@ -39,15 +39,10 @@ def __init__(
 
         Parameters
         ----------
-        losses: tuple[dict[str, Any]| Callable]
-            Tuple of losses to initialise with `GraphForecaster.get_loss_function`.
-            Allows for kwargs to be passed, and weighings controlled.
-        *extra_losses: dict[str, Any] | Callable
-            Additional arg form of losses to include in the combined loss.
+        losses:
+            Loss objects.
         loss_weights : tuple[int, ...]
             Weights of each loss function in the combined loss.
-        kwargs: Any
-            Additional arguments to pass to the loss functions
 
         Examples
         --------
@@ -76,15 +71,10 @@
         """
         super().__init__()
 
-        losses = (*(losses or []), *extra_losses)
-
         assert len(losses) == len(loss_weights), "Number of losses and weights must match"
         assert len(losses) > 0, "At least one loss must be provided"
 
-        self.losses = [
-            GraphForecaster.get_loss_function(loss, **kwargs) if isinstance(loss, dict) else loss(**kwargs)
-            for loss in losses
-        ]
+        self.losses = losses
         self.loss_weights = loss_weights
 
     def forward(
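For context, the refactored CombinedLoss now expects ready-made loss modules (the new Sequence[torch.nn.Module] signature) rather than config dicts. A minimal usage sketch (loss classes from this repo; the tensor shape is illustrative):

import torch

from anemoi.training.losses.combined import CombinedLoss
from anemoi.training.losses.mae import WeightedMAELoss
from anemoi.training.losses.mse import WeightedMSELoss

# Per-node weights, as the weighted losses expect them.
node_weights = torch.ones(4)

# The component losses are built before CombinedLoss sees them: no dict
# configs and no call back into GraphForecaster, which avoids a circular
# import now that forecaster.py imports CombinedLoss.
combined = CombinedLoss(
    losses=[WeightedMSELoss(node_weights), WeightedMAELoss(node_weights)],
    loss_weights=(1.0, 0.5),
)

Construction from config dicts now lives entirely in GraphForecaster.get_loss_function (see forecaster.py below), which keeps the loss module free of trainer imports.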
78 changes: 64 additions & 14 deletions in training/src/anemoi/training/train/forecaster.py
@@ -29,6 +29,7 @@
 
 from anemoi.models.data_indices.collection import IndexCollection
 from anemoi.models.interface import AnemoiModelInterface
+from anemoi.training.losses.combined import CombinedLoss
 from anemoi.training.losses.utils import grad_scaler
 from anemoi.training.losses.weightedloss import BaseWeightedLoss
 from anemoi.training.utils.jsonify import map_config_to_primitives
@@ -121,15 +122,21 @@ def __init__(
             "limited_area_mask": (2, limited_area_mask),
         }
         self.updated_loss_mask = False
-
-        self.loss = self.get_loss_function(config.training.training_loss, scalars=self.scalars, **loss_kwargs)
-
-        assert isinstance(self.loss, BaseWeightedLoss) and not isinstance(
-            self.loss,
-            torch.nn.ModuleList,
-        ), f"Loss function must be a `BaseWeightedLoss`, not a {type(self.loss).__name__!r}"
+        self.loss = self.get_loss_function(
+            config.training.training_loss,
+            scalars=self.scalars,
+            graph_data=graph_data,
+            output_mask=self.output_mask,
+            **loss_kwargs,
+        )
 
-        self.metrics = self.get_loss_function(config.training.validation_metrics, scalars=self.scalars, **loss_kwargs)
+        self.metrics = self.get_loss_function(
+            config.training.validation_metrics,
+            scalars=self.scalars,
+            graph_data=graph_data,
+            output_mask=self.output_mask,
+            **loss_kwargs,
+        )
         if not isinstance(self.metrics, torch.nn.ModuleList):
             self.metrics = torch.nn.ModuleList([self.metrics])
 
@@ -172,10 +179,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.model(x, self.model_comm_group)
 
     # Future import breaks other type hints TODO Harrison Cook
-    @staticmethod
+    @classmethod
     def get_loss_function(
+        cls,
         config: DictConfig,
-        scalars: Union[dict[str, tuple[Union[int, tuple[int, ...], torch.Tensor]]], None] = None,  # noqa: FA100
+        scalars: dict[str, tuple[Union[int, tuple[int, ...], torch.Tensor]]],  # noqa: FA100
+        graph_data: HeteroData,
+        output_mask: torch.Tensor,
         **kwargs,
     ) -> Union[BaseWeightedLoss, torch.nn.ModuleList]:  # noqa: FA100
         """Get loss functions from config.
@@ -186,12 +196,16 @@
         ----------
         config : DictConfig
             Loss function configuration, should include `scalars` if scalars are to be added to the loss function.
-        scalars : Union[dict[str, tuple[Union[int, tuple[int, ...], torch.Tensor]]], None], optional
-            Scalars which can be added to the loss function. Defaults to None., by default None
+        scalars : dict[str, tuple[Union[int, tuple[int, ...], torch.Tensor]]]
+            Scalars which can be added to the loss function.
             If a scalar is to be added to the loss, ensure it is in `scalars` in the loss config
             E.g.
                 If `scalars: ['variable']` is set in the config, and `variable` in `scalars`
                 `variable` will be added to the scalar of the loss function.
+        graph_data : HeteroData
+            Graph data
+        output_mask : torch.Tensor
+            Output mask
         kwargs : Any
             Additional arguments to pass to the loss function
 
@@ -211,19 +225,55 @@
         if isinstance(config_container, list):
             return torch.nn.ModuleList(
                 [
-                    GraphForecaster.get_loss_function(
+                    cls.get_loss_function(
                         OmegaConf.create(loss_config),
                         scalars=scalars,
+                        graph_data=graph_data,
+                        output_mask=output_mask,
                         **kwargs,
                     )
                     for loss_config in config
                 ],
             )
 
+        OmegaConf.resolve(config)
+
+        # Special case for combined loss
+        # The underlying losses must be instantiated first
+        # with kwargs that don't come from the config,
+        # so we can't use the normal instantiation method.
+        def full_name(type_: type) -> str:
+            return type_.__module__ + "." + type_.__name__
+
+        if config.get("_target_") == full_name(CombinedLoss):
+            assert hasattr(config, "loss_weights"), "Loss weights must be provided for combined loss"
+            loss_kwargs = kwargs.copy()
+            if config.get("ignore_nans", False):
+                loss_kwargs["ignore_nans"] = True
+
+            losses = [
+                cls.get_loss_function(
+                    loss,
+                    scalars=scalars,
+                    graph_data=graph_data,
+                    output_mask=output_mask,
+                    **loss_kwargs,
+                )
+                for loss in config.losses
+            ]
+            return instantiate({"_target_": config._target_}, losses=losses, loss_weights=config.loss_weights, **kwargs)
+
         loss_config = OmegaConf.to_container(config, resolve=True)
         scalars_to_include = loss_config.pop("scalars", [])
-
         # Instantiate the loss function with the loss_init_config
+        if config.get("node_weights", None) is not None:
+            node_weighting = instantiate(config.node_weights)
+            node_weights = node_weighting.weights(graph_data)
+            node_weights = output_mask.apply(node_weights, dim=0, fill_value=0.0)
+            if node_weights.dtype == torch.bool:
+                node_weights = node_weights / node_weights.sum()
+            kwargs["node_weights"] = node_weights
+
         loss_function = instantiate(loss_config, **kwargs)
 
         if not isinstance(loss_function, BaseWeightedLoss):
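To make the new control flow concrete: a combined-loss config is now resolved in two steps. Each entry of config.losses is built first with the runtime kwargs, and only then is CombinedLoss itself instantiated with the resulting modules. A sketch under stated assumptions (the stand-in graph, plain tensor mask, and empty scalar dict mirror the tests below; a real run would pass the trainer's own objects):

import torch
from omegaconf import DictConfig
from torch_geometric.data import HeteroData

from anemoi.training.train.forecaster import GraphForecaster

graph_data = HeteroData()                         # stand-in graph
output_mask = torch.tensor([1.0, 1.0, 1.0, 1.0])  # per-node mask, as in the tests

config = DictConfig(
    {
        "_target_": "anemoi.training.losses.combined.CombinedLoss",
        "losses": [
            {"_target_": "anemoi.training.losses.mse.WeightedMSELoss"},
            {"_target_": "anemoi.training.losses.mae.WeightedMAELoss"},
        ],
        "loss_weights": [1.0, 0.5],
    },
)

# get_loss_function spots the CombinedLoss _target_, recursively builds each
# entry of config.losses (threading scalars, graph_data and output_mask
# through), then instantiates CombinedLoss with the ready-made modules.
loss = GraphForecaster.get_loss_function(
    config,
    scalars={},
    graph_data=graph_data,
    output_mask=output_mask,
    node_weights=torch.ones(1),  # forwarded via **kwargs to each component loss
)

Note also the node_weights branch above: when the loss config itself carries a node_weights entry, the weights are built from the graph, masked through output_mask.apply(...), and boolean masks are normalised to sum to one before being injected into kwargs.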
59 changes: 56 additions & 3 deletions in training/tests/train/test_loss_function.py
@@ -8,29 +8,65 @@
 # nor does it submit to any jurisdiction.
 
 
+import pytest
 import torch
 from omegaconf import DictConfig
+from torch_geometric.data import HeteroData
 
+from anemoi.training.losses.combined import CombinedLoss
+from anemoi.training.losses.mae import WeightedMAELoss
 from anemoi.training.losses.mse import WeightedMSELoss
 from anemoi.training.losses.weightedloss import BaseWeightedLoss
 from anemoi.training.train.forecaster import GraphForecaster
 
 
+@pytest.fixture
+def graph_data() -> HeteroData:
+    hdata = HeteroData()
+    lons = torch.tensor([1.56, 3.12, 4.68, 6.24])
+    lats = torch.tensor([-3.12, -1.56, 1.56, 3.12])
+    cutout_mask = torch.tensor([False, True, False, False]).unsqueeze(1)
+    area_weights = torch.ones(cutout_mask.shape)
+    hdata["data"]["x"] = torch.stack((lats, lons), dim=1)
+    hdata["data"]["cutout"] = cutout_mask
+    hdata["data"]["area_weight"] = area_weights
+    return hdata
+
+
+@pytest.fixture
+def scalars() -> dict[str, tuple]:
+    variable_scaling = torch.tensor([1.0])
+    limited_area_mask = torch.tensor([1.0])
+    return {
+        "variable": (-1, variable_scaling),
+        "loss_weights_mask": ((-2, -1), torch.ones((1, 1))),
+        "limited_area_mask": (2, limited_area_mask),
+    }
+
+
+@pytest.fixture
+def output_mask() -> torch.Tensor:
+    return torch.tensor([1.0, 1.0, 1.0, 1.0])
+
+
 def test_manual_init() -> None:
     loss = WeightedMSELoss(torch.ones(1))
     assert loss.node_weights == torch.ones(1)
 
 
-def test_dynamic_init_include() -> None:
+def test_dynamic_init_include(scalars: dict[str, tuple], graph_data: HeteroData, output_mask: torch.Tensor) -> None:
     loss = GraphForecaster.get_loss_function(
         DictConfig({"_target_": "anemoi.training.losses.mse.WeightedMSELoss"}),
         node_weights=torch.ones(1),
+        scalars=scalars,
+        output_mask=output_mask,
+        graph_data=graph_data,
     )
     assert isinstance(loss, BaseWeightedLoss)
     assert loss.node_weights == torch.ones(1)
 
 
-def test_dynamic_init_scalar() -> None:
+def test_dynamic_init_scalar(graph_data: HeteroData, output_mask: torch.Tensor) -> None:
     loss = GraphForecaster.get_loss_function(
         DictConfig(
             {
@@ -40,6 +76,8 @@ def test_dynamic_init_scalar() -> None:
         ),
         node_weights=torch.ones(1),
         scalars={"test": ((0, 1), torch.ones((1, 2)))},
+        output_mask=output_mask,
+        graph_data=graph_data,
     )
     assert isinstance(loss, BaseWeightedLoss)
 
@@ -48,7 +86,7 @@ def test_dynamic_init_scalar() -> None:
     torch.testing.assert_close(loss.scalar.get_scalar(2), torch.ones((1, 2)))
 
 
-def test_dynamic_init_scalar_not_add() -> None:
+def test_dynamic_init_scalar_not_add(graph_data: HeteroData, output_mask: torch.Tensor) -> None:
     loss = GraphForecaster.get_loss_function(
         DictConfig(
             {
@@ -58,7 +96,22 @@ def test_dynamic_init_scalar_not_add() -> None:
         ),
         node_weights=torch.ones(1),
         scalars={"test": (-1, torch.ones(2))},
+        output_mask=output_mask,
+        graph_data=graph_data,
     )
     assert isinstance(loss, BaseWeightedLoss)
     torch.testing.assert_close(loss.node_weights, torch.ones(1))
     assert "test" not in loss.scalar
+
+
+def test_combined_loss() -> None:
+    loss1 = WeightedMSELoss(torch.ones(1))
+    loss2 = WeightedMAELoss(torch.ones(1))
+    cl = CombinedLoss(losses=[loss1, loss2], loss_weights=(1.0, 0.5))
+    assert isinstance(cl, CombinedLoss)
+    cl_class = CombinedLoss(
+        losses=[WeightedMSELoss, WeightedMAELoss],
+        node_weights=torch.ones(1),
+        loss_weights=(1.0, 0.5),
+    )
+    assert isinstance(cl_class, CombinedLoss)
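Finally, a sketch of the behaviour test_combined_loss implies. The weighted-sum reduction is an assumption read off the loss_weights docstring, since CombinedLoss.forward is not shown in this diff:

import torch

from anemoi.training.losses.combined import CombinedLoss
from anemoi.training.losses.mae import WeightedMAELoss
from anemoi.training.losses.mse import WeightedMSELoss

node_weights = torch.ones(1)
pred = torch.randn(1, 2, 1)
target = torch.zeros(1, 2, 1)

cl = CombinedLoss(
    losses=[WeightedMSELoss(node_weights), WeightedMAELoss(node_weights)],
    loss_weights=(1.0, 0.5),
)

# Assumed semantics: total = 1.0 * MSE(pred, target) + 0.5 * MAE(pred, target)
expected = 1.0 * WeightedMSELoss(node_weights)(pred, target) + 0.5 * WeightedMAELoss(node_weights)(pred, target)
torch.testing.assert_close(cl(pred, target), expected)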