diff --git a/src/anemoi/training/losses/mse.py b/src/anemoi/training/losses/mse.py index d0e2113e..9395f3d5 100644 --- a/src/anemoi/training/losses/mse.py +++ b/src/anemoi/training/losses/mse.py @@ -136,12 +136,7 @@ def __init__( if data_variances is not None: self.register_buffer("ivar", data_variances, persistent=True) - def forward( - self, - pred: torch.Tensor, - target: torch.Tensor, - squash=True - ) -> torch.Tensor: + def forward(self, pred: torch.Tensor, target: torch.Tensor, squash=True) -> torch.Tensor: """Calculates the lat-weighted MSE loss. Parameters diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index ff5fd9d6..f5e67915 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -110,14 +110,14 @@ def __init__( # Define stretched grid metrics according to options for stretched grid loss logging self.sg_metrics = torch.nn.ModuleDict() if config.diagnostics.sg_metrics.wmse_per_region: - self.sg_metrics['wmse_inside_lam_epoch'] = WeightedMSELossStretchedGrid( + self.sg_metrics["wmse_inside_lam_epoch"] = WeightedMSELossStretchedGrid( node_weights=self.loss_weights, mask=self.mask, inside_LAM=True, wmse_contribution=False, data_variances=loss_scaling, ) - self.sg_metrics['wmse_outside_lam_epoch'] = WeightedMSELossStretchedGrid( + self.sg_metrics["wmse_outside_lam_epoch"] = WeightedMSELossStretchedGrid( node_weights=self.loss_weights, mask=self.mask, inside_LAM=False, @@ -125,14 +125,14 @@ def __init__( data_variances=loss_scaling, ) if config.diagnostics.sg_metrics.wmse_contributions: - self.sg_metrics['wmse_inside_lam_contribution_epoch'] = WeightedMSELossStretchedGrid( + self.sg_metrics["wmse_inside_lam_contribution_epoch"] = WeightedMSELossStretchedGrid( node_weights=self.loss_weights, mask=self.mask, inside_LAM=True, wmse_contribution=True, data_variances=loss_scaling, ) - self.sg_metrics['wmse_outside_lam_contribution_epoch'] = WeightedMSELossStretchedGrid( + self.sg_metrics["wmse_outside_lam_contribution_epoch"] = WeightedMSELossStretchedGrid( node_weights=self.loss_weights, mask=self.mask, inside_LAM=False,