Skip to content

Commit

Permalink
Further code quality improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
Jasper Wijnands committed Nov 5, 2024
1 parent 2ee5af4 commit ee0d3c9
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 10 deletions.
7 changes: 1 addition & 6 deletions src/anemoi/training/losses/mse.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,7 @@ def __init__(
if data_variances is not None:
self.register_buffer("ivar", data_variances, persistent=True)

def forward(
self,
pred: torch.Tensor,
target: torch.Tensor,
squash=True
) -> torch.Tensor:
def forward(self, pred: torch.Tensor, target: torch.Tensor, squash=True) -> torch.Tensor:
"""Calculates the lat-weighted MSE loss.
Parameters
Expand Down
8 changes: 4 additions & 4 deletions src/anemoi/training/train/forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,29 +110,29 @@ def __init__(
# Define stretched grid metrics according to options for stretched grid loss logging
self.sg_metrics = torch.nn.ModuleDict()
if config.diagnostics.sg_metrics.wmse_per_region:
self.sg_metrics['wmse_inside_lam_epoch'] = WeightedMSELossStretchedGrid(
self.sg_metrics["wmse_inside_lam_epoch"] = WeightedMSELossStretchedGrid(
node_weights=self.loss_weights,
mask=self.mask,
inside_LAM=True,
wmse_contribution=False,
data_variances=loss_scaling,
)
self.sg_metrics['wmse_outside_lam_epoch'] = WeightedMSELossStretchedGrid(
self.sg_metrics["wmse_outside_lam_epoch"] = WeightedMSELossStretchedGrid(
node_weights=self.loss_weights,
mask=self.mask,
inside_LAM=False,
wmse_contribution=False,
data_variances=loss_scaling,
)
if config.diagnostics.sg_metrics.wmse_contributions:
self.sg_metrics['wmse_inside_lam_contribution_epoch'] = WeightedMSELossStretchedGrid(
self.sg_metrics["wmse_inside_lam_contribution_epoch"] = WeightedMSELossStretchedGrid(
node_weights=self.loss_weights,
mask=self.mask,
inside_LAM=True,
wmse_contribution=True,
data_variances=loss_scaling,
)
self.sg_metrics['wmse_outside_lam_contribution_epoch'] = WeightedMSELossStretchedGrid(
self.sg_metrics["wmse_outside_lam_contribution_epoch"] = WeightedMSELossStretchedGrid(
node_weights=self.loss_weights,
mask=self.mask,
inside_LAM=False,
Expand Down

0 comments on commit ee0d3c9

Please sign in to comment.