Feature/improve loss functions #70

Open · wants to merge 54 commits into base: develop

Changes from 4 commits (54 commits total)
e205af4
Dynamic loss function intialisation
HCookie Oct 2, 2024
85905d9
Add more loss functions and include ensemble dim
HCookie Oct 2, 2024
f105df6
Update CHANGELOG
HCookie Oct 2, 2024
dba6c01
Reference PR in CHANGELOG
HCookie Oct 2, 2024
91e3ae3
Address PR Comments
HCookie Oct 3, 2024
23c1670
Allow for lists of losses
HCookie Oct 3, 2024
78d6588
Fix ruff complaining
HCookie Oct 3, 2024
5f25370
refactor: Create WeightedLoss
HCookie Oct 4, 2024
1054f7f
fix: Assert and conversion in loss function get
HCookie Oct 4, 2024
19c1799
Fix names to reference weighted
HCookie Oct 5, 2024
ef49cbb
fix: rework init include loop to raise an error if missing key found
HCookie Oct 5, 2024
aa605a6
fix: refine documentation for feature_weights
HCookie Oct 5, 2024
7c4c8cb
Fix: Remove reduction over ensemble
HCookie Oct 5, 2024
eb7bded
Rename WeightedLoss to BaseWeightedLoss
HCookie Oct 9, 2024
4ad033a
Rename `error` to `x`
HCookie Oct 9, 2024
93d1647
Rework feature_scale use
HCookie Oct 9, 2024
17b0940
Change to registering feature_weights as None
HCookie Oct 9, 2024
e60e5b3
Fix: Use numpy docstring
HCookie Oct 9, 2024
c81d4d5
Fix: Improve assert message
HCookie Oct 9, 2024
6a6bcf1
Refactor: Renames
HCookie Oct 9, 2024
4c28d96
Fix: Remove redundant if
HCookie Oct 11, 2024
b4d0068
Simplify val metrics
HCookie Oct 11, 2024
97b3f7c
Clamp logcosh at 710
HCookie Oct 11, 2024
5befe5f
Fix log
HCookie Oct 11, 2024
27b45f2
Rename loss function configuration
HCookie Oct 14, 2024
3b75f0e
split feature weights from val_metrics
mc4117 Oct 14, 2024
1084c71
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 14, 2024
035483d
Merge pull request #85 from ecmwf/feature/improve_feature_loss
mc4117 Oct 14, 2024
92e7587
Merge branch 'develop' into feature/improve_loss_functions
mc4117 Oct 14, 2024
1c36e0f
Remove mention to optional ensemble
HCookie Oct 15, 2024
5b93dda
Merge branch 'develop' into feature/improve_loss_functions
HCookie Oct 15, 2024
d384f85
Merge remote-tracking branch 'origin/develop' into feature/improve_lo…
HCookie Oct 18, 2024
211879d
Check if weighted loss in callback
HCookie Oct 18, 2024
6cd0e70
Improve docs
HCookie Oct 18, 2024
4b170d0
Rename loss_scaling in losses to variable_scaling
HCookie Oct 18, 2024
df1c92a
Add Huber Loss
HCookie Oct 22, 2024
6dcc9b7
Add CombinedLoss
HCookie Oct 22, 2024
202011f
Merge branch 'develop' into feature/improve_loss_functions
HCookie Oct 22, 2024
ec2c204
Add getattr to CombinedLoss
HCookie Oct 22, 2024
da89d99
Add ScaleTensor (#96)
HCookie Oct 23, 2024
8392873
Update copyright notice
HCookie Oct 23, 2024
a628fae
Merge branch 'develop' into feature/improve_loss_functions
HCookie Oct 23, 2024
e43011d
Fix reference to metric_ranges_validation
HCookie Oct 24, 2024
1406d51
Merge branch 'feature/improve_loss_functions' of github.com:ecmwf/ane…
HCookie Oct 24, 2024
11fe213
Upadate copyright notice
HCookie Oct 24, 2024
7694733
Improve config documentation
HCookie Oct 24, 2024
5c7ed01
Fix initalisation bugs
HCookie Oct 24, 2024
07a2a8b
Add tests to ensure correct initalisation
HCookie Oct 24, 2024
fda2085
Add docs to add_scalar
HCookie Oct 24, 2024
1892ad6
Merge branch 'develop' into feature/improve_loss_functions
HCookie Oct 24, 2024
8a17fb1
Refactor of loss functions
HCookie Oct 25, 2024
2e9f721
Update docs
HCookie Oct 25, 2024
22cd154
Reorder docs
HCookie Oct 25, 2024
6d605ce
Merge branch 'develop' into feature/improve_loss_functions
HCookie Oct 25, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -13,6 +13,7 @@ Keep it human-readable, your future self will thank you!
### Added
- Codeowners file (#56)
- Changelog merge strategy (#56)
- Included more loss functions and allowed configuration [#70](https://github.com/ecmwf/anemoi-training/pull/70)

#### Miscellaneous

32 changes: 28 additions & 4 deletions src/anemoi/training/config/training/default.yaml
@@ -33,10 +33,34 @@ swa:
# use ZeroRedundancyOptimizer ; saves memory for larger models
zero_optimizer: False

# dynamic rescaling of the loss gradient
# see https://arxiv.org/pdf/2306.06079.pdf, section 4.3.2
# don't enable this by default until it's been tested and proven beneficial
loss_gradient_scaling: False
# loss functions
loss_functions:

  # dynamic rescaling of the loss gradient
  # see https://arxiv.org/pdf/2306.06079.pdf, section 4.3.2
  # don't enable this by default until it's been tested and proven beneficial
  loss_gradient_scaling: False

  # loss function for the model
  loss:
    # loss class to initialise, can be anything subclassing torch.nn.Module
    _target_: anemoi.training.losses.mse.WeightedMSELoss
    # what to include in the loss class initialisation
    include_node_weights: True
    include_feature_weights: True
    # other kwargs
    ignore_nans: False

  # loss function for metric calculation
  metrics:
    # loss class to initialise, can be anything subclassing torch.nn.Module
    _target_: anemoi.training.losses.mse.WeightedMSELoss
    # what to include in the loss class initialisation
    include_node_weights: True
    include_feature_weights: False
    # other kwargs
    ignore_nans: True

# length of the "rollout" window (see Keisler's paper)
rollout:
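The `_target_` entries in the config above follow the Hydra-style convention: the dotted path is resolved to a class and the remaining keys become constructor keyword arguments. A minimal, stdlib-only sketch of that mechanism (an illustration of the pattern, not the actual anemoi-training or Hydra code; `collections.Counter` is a stand-in target so the snippet is self-contained):

```python
from importlib import import_module


def instantiate(config: dict, **kwargs):
    """Resolve the dotted path in `_target_` to a class and call it with
    the remaining config entries (plus any extra kwargs) as arguments."""
    config = dict(config)  # don't mutate the caller's config
    module_path, _, class_name = config.pop("_target_").rpartition(".")
    cls = getattr(import_module(module_path), class_name)
    return cls(**config, **kwargs)


# Stand-in target; the real config points at
# anemoi.training.losses.mse.WeightedMSELoss.
counter = instantiate({"_target_": "collections.Counter"}, a=2)
```

In the real training code, objects such as `node_weights` that cannot live in YAML are presumably injected at call time, which would be why the configs above only carry flags like `include_node_weights`.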
109 changes: 109 additions & 0 deletions src/anemoi/training/losses/logcosh.py
@@ -0,0 +1,109 @@
# (C) Copyright 2024 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

from __future__ import annotations

import logging
from functools import cached_property

import torch
from torch import nn

LOGGER = logging.getLogger(__name__)


class WeightedLogCoshLoss(nn.Module):
    """Latitude-weighted LogCosh loss."""

    def __init__(
        self,
        node_weights: torch.Tensor,
        feature_weights: torch.Tensor | None = None,
        ignore_nans: bool = False,
    ) -> None:
        """Latitude- and (inverse-)variance-weighted LogCosh Loss.

        Parameters
        ----------
        node_weights : torch.Tensor of shape (N, )
            Weight of each node in the loss function
        feature_weights : Optional[torch.Tensor], optional
            precomputed, per-variable stepwise variance estimate, by default None
        ignore_nans : bool, optional
            Allow nans in the loss and apply methods ignoring nans for measuring the loss, by default False

        """
        super().__init__()

        self.avg_function = torch.nanmean if ignore_nans else torch.mean
        self.sum_function = torch.nansum if ignore_nans else torch.sum

        self.register_buffer("weights", node_weights, persistent=True)
        if feature_weights is not None:
            self.register_buffer("feature_weights", feature_weights, persistent=True)

    def forward(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
        squash: bool = True,
        feature_indices: torch.Tensor | None = None,
        feature_scale: bool = True,
    ) -> torch.Tensor:
        """Calculates the lat-weighted LogCosh loss.

        Parameters
        ----------
        pred : torch.Tensor
            Prediction tensor, shape (bs, lat*lon, n_outputs)
        target : torch.Tensor
            Target tensor, shape (bs, lat*lon, n_outputs)
        squash : bool, optional
            Average last dimension, by default True
        feature_indices:
            feature indices (relative to full model output) of the features passed in pred and target
        feature_scale:
            If True, scale the loss by the feature_weights

        Returns
        -------
        torch.Tensor
            Weighted LogCosh loss

        """
        if pred.ndim == 4:
            pred = pred.mean(dim=1)

        out = torch.log(torch.cosh(pred - target))

        # Scale by feature weights if available
        if feature_scale and hasattr(self, "feature_weights"):
            out = (
                out * self.feature_weights
                if feature_indices is None
                else out * self.feature_weights[..., feature_indices]
            )

        # Squash by last dimension
        if squash:
            out = self.avg_function(out, dim=-1)
            # Weight by area
            out *= self.weights.expand_as(out)
            out /= self.sum_function(self.weights.expand_as(out))
            return self.sum_function(out)

        # Weight by area; due to the weighting construction this is analogous to a mean
        out *= self.weights[..., None].expand_as(out)
        # keep last dimension (variables) when summing weights
        out /= self.sum_function(self.weights[..., None].expand_as(out), axis=(0, 1, 2))
        return self.sum_function(out, axis=(0, 1, 2))

    @cached_property
    def name(self) -> str:
        return "logcosh"
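A note on the `Clamp logcosh at 710` commit in the history above: `cosh` overflows double precision once its argument passes roughly 710, even though `log(cosh(x))` itself stays finite (for large `|x|` it approaches `|x| - log 2`). A small pure-Python sketch of the numerically stable identity (illustration only, not the PR's implementation):

```python
import math


def logcosh_stable(x: float) -> float:
    # log(cosh(x)) = |x| + log(1 + exp(-2|x|)) - log(2); never calls
    # cosh directly, so it cannot overflow for large |x|.
    return abs(x) + math.log1p(math.exp(-2.0 * abs(x))) - math.log(2.0)


print(logcosh_stable(1000.0))  # finite, ≈ 999.307
# math.cosh(1000.0) raises OverflowError: cosh exceeds the double range near 710.
```

Clamping the argument at 710, as the commit does, is the cheaper tensor-level guard against the same overflow.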
108 changes: 108 additions & 0 deletions src/anemoi/training/losses/mae.py
@@ -0,0 +1,108 @@
# (C) Copyright 2024 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

from __future__ import annotations

import logging
from functools import cached_property

import torch
from torch import nn

LOGGER = logging.getLogger(__name__)


class WeightedMAELoss(nn.Module):
    """Latitude-weighted MAE loss."""

    def __init__(
        self,
        node_weights: torch.Tensor,
        feature_weights: torch.Tensor | None = None,
        ignore_nans: bool = False,
    ) -> None:
        """Latitude- and (inverse-)variance-weighted MAE Loss.

        Also known as the Weighted L1 loss.

        Parameters
        ----------
        node_weights : torch.Tensor
            Weights by area
        feature_weights : Optional[torch.Tensor], optional
            precomputed, per-variable stepwise variance estimate, by default None
        ignore_nans : bool, optional
            Allow nans in the loss and apply methods ignoring nans for measuring the loss, by default False
        """
        super().__init__()

        self.avg_function = torch.nanmean if ignore_nans else torch.mean
        self.sum_function = torch.nansum if ignore_nans else torch.sum

        self.register_buffer("node_weights", node_weights, persistent=True)
        if feature_weights is not None:
            self.register_buffer("feature_weights", feature_weights, persistent=True)

    def forward(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
        squash: bool = True,
        feature_indices: torch.Tensor | None = None,
        feature_scale: bool = True,
    ) -> torch.Tensor:
        """Calculates the lat-weighted MAE loss.

        Parameters
        ----------
        pred : torch.Tensor
            Prediction tensor, shape (bs, (optional_ensemble), lat*lon, n_outputs)
        target : torch.Tensor
            Target tensor, shape (bs, (optional_ensemble), lat*lon, n_outputs)
        squash : bool, optional
            Average last dimension, by default True
        feature_indices:
            feature indices (relative to full model output) of the features passed in pred and target
        feature_scale:
            If True, scale the loss by the feature_weights

        Returns
        -------
        torch.Tensor
            Weighted MAE loss
        """
        if pred.ndim == 4:
            pred = pred.mean(dim=1)

        out = torch.abs(pred - target)

        if feature_scale and hasattr(self, "feature_weights"):
            out = (
                out * self.feature_weights
                if feature_indices is None
                else out * self.feature_weights[..., feature_indices]
            )

        # Squash by last dimension
        if squash:
            out = self.avg_function(out, dim=-1)
            # Weight by area
            out *= self.node_weights.expand_as(out)
            out /= self.sum_function(self.node_weights.expand_as(out))
            return self.sum_function(out)

        # Weight by area; due to the weighting construction this is analogous to a mean
        out *= self.node_weights[..., None].expand_as(out)
        # keep last dimension (variables) when summing weights
        out /= self.sum_function(self.node_weights[..., None].expand_as(out), axis=(0, 1, 2))
        return self.sum_function(out, axis=(0, 1, 2))

    @cached_property
    def name(self) -> str:
        return "mae"
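With `squash=True`, the reduction in `forward` is an area-weighted mean of per-node averages: the absolute error is averaged over the variable dimension, scaled by the node weight, and normalised by the total weight. A dependency-free sketch of that arithmetic for a single batch element (a hypothetical helper using plain lists, not the tensor implementation):

```python
def weighted_mae(pred, target, node_weights):
    """Mirror of the squash=True path for one batch element:
    sum_i w_i * mean_vars(|pred_i - target_i|) / sum_i w_i."""
    total, weight_sum = 0.0, 0.0
    for p_node, t_node, w in zip(pred, target, node_weights):
        abs_err = [abs(p - t) for p, t in zip(p_node, t_node)]
        total += w * (sum(abs_err) / len(abs_err))  # average over variables
        weight_sum += w
    return total / weight_sum


# Two nodes, two variables; the second node carries three times the area weight.
loss = weighted_mae([[2.0, 4.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [1.0, 3.0])
# → (1*3 + 3*1) / (1 + 3) = 1.5
```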
59 changes: 39 additions & 20 deletions src/anemoi/training/losses/mse.py
@@ -10,6 +10,7 @@
from __future__ import annotations

import logging
from functools import cached_property

import torch
from torch import nn
@@ -23,69 +24,87 @@ class WeightedMSELoss(nn.Module):
def __init__(
self,
node_weights: torch.Tensor,
data_variances: torch.Tensor | None = None,
ignore_nans: bool | None = False,
feature_weights: torch.Tensor | None = None,
ignore_nans: bool = False,
) -> None:
"""Latitude- and (inverse-)variance-weighted MSE Loss.

Parameters
----------
node_weights : torch.Tensor of shape (N, )
Weight of each node in the loss function
data_variances : Optional[torch.Tensor], optional
node_weights : torch.Tensor
Weights by area
feature_weights : Optional[torch.Tensor], optional
precomputed, per-variable stepwise variance estimate, by default None
ignore_nans : bool, optional
Allow nans in the loss and apply methods ignoring nans for measuring the loss, by default False

"""
super().__init__()

self.avg_function = torch.nanmean if ignore_nans else torch.mean
self.sum_function = torch.nansum if ignore_nans else torch.sum

self.register_buffer("weights", node_weights, persistent=True)
if data_variances is not None:
self.register_buffer("ivar", data_variances, persistent=True)
self.register_buffer("node_weights", node_weights, persistent=True)
if feature_weights is not None:
self.register_buffer("feature_weights", feature_weights, persistent=True)

def forward(
self,
pred: torch.Tensor,
target: torch.Tensor,
squash: bool = True,
feature_indices: torch.Tensor | None = None,
feature_scale: bool = True,
) -> torch.Tensor:
"""Calculates the lat-weighted MSE loss.

Parameters
----------
pred : torch.Tensor
Prediction tensor, shape (bs, lat*lon, n_outputs)
Prediction tensor, shape (bs, (optional_ensemble), lat*lon, n_outputs)
target : torch.Tensor
Target tensor, shape (bs, lat*lon, n_outputs)
Target tensor, shape (bs, (optional_ensemble), lat*lon, n_outputs)
squash : bool, optional
Average last dimension, by default True
feature_indices:
feature indices (relative to full model output) of the features passed in pred and target
feature_scale:
If True, scale the loss by the feature_weights

Returns
-------
torch.Tensor
Weighted MSE loss

"""
if pred.ndim == 4:
pred = pred.mean(dim=1)


out = torch.square(pred - target)

# Use variances if available
if hasattr(self, "ivar"):
out *= self.ivar
if feature_scale and hasattr(self, "feature_weights"):
out = (
out * self.feature_weights
if feature_indices is None
else out * self.feature_weights[..., feature_indices]
)

# Squash by last dimension
if squash:
out = self.avg_function(out, dim=-1)
# Weight by area
out *= self.weights.expand_as(out)
out /= self.sum_function(self.weights.expand_as(out))
out *= self.node_weights.expand_as(out)
out /= self.sum_function(self.node_weights.expand_as(out))
return self.sum_function(out)

# Weight by area
out *= self.weights[..., None].expand_as(out)
# Weight by area; due to the weighting construction this is analogous to a mean
out *= self.node_weights[..., None].expand_as(out)
# keep last dimension (variables) when summing weights
out /= self.sum_function(self.weights[..., None].expand_as(out), axis=(0, 1, 2))
out /= self.sum_function(self.node_weights[..., None].expand_as(out), axis=(0, 1, 2))
return self.sum_function(out, axis=(0, 1, 2))

@cached_property
def name(self) -> str:
return "mse"
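Across all three losses, the `feature_indices` argument lets a caller evaluate the loss on a subset of the model's output variables while indexing into feature weights defined over the full output — that is what the `self.feature_weights[..., feature_indices]` branch does. A tiny pure-Python sketch of that selection (a hypothetical helper, not the tensor code):

```python
def apply_feature_scale(errors, feature_weights, feature_indices=None):
    """Scale per-variable errors by their feature weights; when
    feature_indices is given, pick those weights out of the full set."""
    if feature_indices is not None:
        feature_weights = [feature_weights[i] for i in feature_indices]
    return [e * w for e, w in zip(errors, feature_weights)]


# Full model has three variables, but this loss only sees variables 0 and 2.
scaled = apply_feature_scale([1.0, 2.0], [0.5, 2.0, 3.0], feature_indices=[0, 2])
# → [0.5, 6.0]
```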