Skip to content

Commit

Permalink
More code quality improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
Jasper Wijnands committed Nov 5, 2024
1 parent bb0b7d2 commit 6b5401a
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 15 deletions.
1 change: 1 addition & 0 deletions src/anemoi/training/__version__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
if TYPE_CHECKING:
from typing import Tuple
from typing import Union

VERSION_TUPLE = Tuple[Union[int, str], ...]
else:
VERSION_TUPLE = object
Expand Down
1 change: 0 additions & 1 deletion src/anemoi/training/config/debug_dowa.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -197,4 +197,3 @@ model:
hidden2data: 0
hidden2hidden: 0
num_channels: 256

1 change: 0 additions & 1 deletion src/anemoi/training/config/graph/stretched_grid.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,3 @@ edges:
edge_dirs:
_target_: anemoi.graphs.edges.attributes.EdgeDirection
norm: unit-std

24 changes: 15 additions & 9 deletions src/anemoi/training/losses/mse.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def forward(

class WeightedMSELossStretchedGrid(nn.Module):
"""Latitude-weighted MSE loss, calculated only within or outside the limited area.
Further, the loss can be computed for the specified region (default),
or as the contribution to the overall loss.
"""
Expand All @@ -101,7 +102,7 @@ def __init__(
self,
node_weights: torch.Tensor,
mask: torch.Tensor,
inside_LAM: bool = True,
inside_lam: bool = True,
wmse_contribution: bool = False,
data_variances: torch.Tensor | None = None,
ignore_nans: bool | None = False,
Expand All @@ -114,7 +115,7 @@ def __init__(
Weight of each node in the loss function
mask: torch.Tensor
the mask marking the indices of the regional data points (bool)
inside_LAM: bool
inside_lam: bool
compute the loss inside or outside the limited area, by default inside (True)
wmse_contribution: bool
compute loss as the contribution to the overall MSE, by default False
Expand All @@ -128,16 +129,21 @@ def __init__(
self.avg_function = torch.nanmean if ignore_nans else torch.mean
self.sum_function = torch.nansum if ignore_nans else torch.sum

self.inside_LAM = inside_LAM
self.inside_lam = inside_lam
self.wmse_contribution = wmse_contribution
self.register_buffer("weights", node_weights, persistent=True)
self.register_buffer("weights_inside_LAM", node_weights[mask], persistent=True)
self.register_buffer("weights_outside_LAM", node_weights[~mask], persistent=True)
self.register_buffer("weights_inside_lam", node_weights[mask], persistent=True)
self.register_buffer("weights_outside_lam", node_weights[~mask], persistent=True)
self.register_buffer("mask", mask, persistent=True)
if data_variances is not None:
self.register_buffer("ivar", data_variances, persistent=True)

def forward(self, pred: torch.Tensor, target: torch.Tensor, squash=True) -> torch.Tensor:
def forward(
self,
pred: torch.Tensor,
target: torch.Tensor,
squash: bool = True,
) -> torch.Tensor:
"""Calculates the lat-weighted MSE loss.
Parameters
Expand All @@ -156,14 +162,14 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor, squash=True) -> torc
"""
full_out_dims = pred[:, :, :, 0]

if self.inside_LAM:
if self.inside_lam:
pred = pred[:, :, self.mask]
target = target[:, :, self.mask]
weights_selected = self.weights_inside_LAM
weights_selected = self.weights_inside_lam
else:
pred = pred[:, :, ~self.mask]
target = target[:, :, ~self.mask]
weights_selected = self.weights_outside_LAM
weights_selected = self.weights_outside_lam

out = torch.square(pred - target)

Expand Down
8 changes: 4 additions & 4 deletions src/anemoi/training/train/forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,29 +113,29 @@ def __init__(
self.sg_metrics["wmse_inside_lam_epoch"] = WeightedMSELossStretchedGrid(
node_weights=self.loss_weights,
mask=self.mask,
inside_LAM=True,
inside_lam=True,
wmse_contribution=False,
data_variances=loss_scaling,
)
self.sg_metrics["wmse_outside_lam_epoch"] = WeightedMSELossStretchedGrid(
node_weights=self.loss_weights,
mask=self.mask,
inside_LAM=False,
inside_lam=False,
wmse_contribution=False,
data_variances=loss_scaling,
)
if config.diagnostics.sg_metrics.wmse_contributions:
self.sg_metrics["wmse_inside_lam_contribution_epoch"] = WeightedMSELossStretchedGrid(
node_weights=self.loss_weights,
mask=self.mask,
inside_LAM=True,
inside_lam=True,
wmse_contribution=True,
data_variances=loss_scaling,
)
self.sg_metrics["wmse_outside_lam_contribution_epoch"] = WeightedMSELossStretchedGrid(
node_weights=self.loss_weights,
mask=self.mask,
inside_LAM=False,
inside_lam=False,
wmse_contribution=True,
data_variances=loss_scaling,
)
Expand Down
8 changes: 8 additions & 0 deletions src/lumi_train.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
# (C) Copyright 2024 Anemoi contributors
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from hydra import compose
from hydra import initialize

from anemoi.training.train.train import AnemoiTrainer

with initialize(version_base=None, config_path="anemoi/training/config"):
Expand Down

0 comments on commit 6b5401a

Please sign in to comment.