Skip to content

Commit

Permalink
Code improvements based on code quality check output
Browse files Browse the repository at this point in the history
  • Loading branch information
Jasper Wijnands committed Nov 5, 2024
1 parent 6a9cf38 commit 2ee5af4
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 24 deletions.
2 changes: 1 addition & 1 deletion src/anemoi/training/config/stretched_grid_cerra.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -112,4 +112,4 @@ training:
lr:
rate: 5e-4 #8 * 0.625e-4
min: 2.4e-6 #8 * 3e-7
iterations: 150000
iterations: 150000
22 changes: 11 additions & 11 deletions src/anemoi/training/diagnostics/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def __init__(self, config: OmegaConf) -> None:
)
self.rollout = config.diagnostics.eval.rollout
self.frequency = config.diagnostics.eval.frequency
self.config=config
self.config = config

def _eval(
self,
Expand Down Expand Up @@ -252,7 +252,7 @@ def _log(self, pl_module: pl.LightningModule, loss: torch.Tensor, metrics: dict,
rank_zero_only=True,
)

#check if stretched grid
# check if stretched grid
if self.config.graph.nodes.hidden.node_builder.lam_resolution:
for str_area in ["inside", "contribution_inside", "outside", "contribution_outside"]:
pl_module.log(
Expand Down Expand Up @@ -667,7 +667,7 @@ def __init__(self, config: OmegaConf) -> None:
"""
super().__init__(config)
self.sample_idx = self.config.diagnostics.plot.sample_idx
self.config=config
self.config = config
self.precip_and_related_fields = self.config.diagnostics.plot.precip_and_related_fields
LOGGER.info(f"Using defined accumulation colormap for fields: {self.precip_and_related_fields}")

Expand Down Expand Up @@ -739,24 +739,24 @@ def _plot(
exp_log_tag=f"val_pred_sample_rstep{rollout_step:02d}_rank{local_rank:01d}",
)

#check if stretched grid
if self.config.graph.nodes.hidden.node_builder.lam_resolution:
# check if stretched grid
if "lam_resolution" in getattr(self.config.graph.nodes.hidden, "node_builder", []):
fig_lam_inside = plot_predicted_multilevel_flat_sample(
plot_parameters_dict,
self.config.diagnostics.plot.per_sample,
self.latlons[pl_module.mask],
self.config.diagnostics.plot.accumulation_levels_plot,
self.config.diagnostics.plot.cmap_accumulation,
data[0, :, pl_module.mask, :].squeeze(),
data[rollout_step + 1, :, pl_module.mask, :].squeeze(),
data[rollout_step + 1, :, pl_module.mask, :].squeeze(),
output_tensor[rollout_step, :, pl_module.mask, :],
)
self._output_figure(
logger,
fig_lam_inside,
epoch=epoch,
tag=f"lam_inside_pred_val_sample_rstep{rollout_step:02d}_batch{batch_idx:04d}_rank0",
exp_log_tag=f"lam_inside_val_pred_sample_rstep{rollout_step:02d}_rank{local_rank:01d}",
tag=f"pred_val_sample_inside_lam_rstep{rollout_step:02d}_batch{batch_idx:04d}_rank0",
exp_log_tag=f"val_pred_sample_inside_lam_rstep{rollout_step:02d}_rank{local_rank:01d}",
)
fig_lam_outside = plot_predicted_multilevel_flat_sample(
plot_parameters_dict,
Expand All @@ -765,15 +765,15 @@ def _plot(
self.config.diagnostics.plot.accumulation_levels_plot,
self.config.diagnostics.plot.cmap_accumulation,
data[0, :, ~pl_module.mask, :].squeeze(),
data[rollout_step + 1, :, ~pl_module.mask, :].squeeze(),
data[rollout_step + 1, :, ~pl_module.mask, :].squeeze(),
output_tensor[rollout_step, :, ~pl_module.mask, :],
)
self._output_figure(
logger,
fig_lam_outside,
epoch=epoch,
tag=f"lam_outside_pred_val_sample_rstep{rollout_step:02d}_batch{batch_idx:04d}_rank0",
exp_log_tag=f"lam_outside_val_pred_sample_rstep{rollout_step:02d}_rank{local_rank:01d}",
tag=f"pred_val_sample_outside_lam_rstep{rollout_step:02d}_batch{batch_idx:04d}_rank0",
exp_log_tag=f"val_pred_sample_outside_lam_rstep{rollout_step:02d}_rank{local_rank:01d}",
)

def on_validation_batch_end(
Expand Down
4 changes: 2 additions & 2 deletions src/anemoi/training/losses/mse.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,11 @@ def forward(
full_out_dims = pred[:, :, :, 0]

if self.inside_LAM:
pred = pred[ :, :, self.mask]
pred = pred[:, :, self.mask]
target = target[:, :, self.mask]
weights_selected = self.weights_inside_LAM
else:
pred = pred[ :, :, ~self.mask]
pred = pred[:, :, ~self.mask]
target = target[:, :, ~self.mask]
weights_selected = self.weights_outside_LAM

Expand Down
41 changes: 31 additions & 10 deletions src/anemoi/training/train/forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
from torch.utils.checkpoint import checkpoint
from torch_geometric.data import HeteroData

from anemoi.training.losses.mse import WeightedMSELoss, WeightedMSELossStretchedGrid
from anemoi.training.losses.mse import WeightedMSELoss
from anemoi.training.losses.mse import WeightedMSELossStretchedGrid
from anemoi.training.losses.utils import grad_scaler
from anemoi.training.utils.jsonify import map_config_to_primitives
from anemoi.training.utils.masks import Boolean1DMask
Expand Down Expand Up @@ -109,15 +110,35 @@ def __init__(
# Define stretched grid metrics according to options for stretched grid loss logging
self.sg_metrics = torch.nn.ModuleDict()
if config.diagnostics.sg_metrics.wmse_per_region:
self.sg_metrics['wmse_inside_lam_epoch'] = WeightedMSELossStretchedGrid(node_weights=self.loss_weights, mask=self.mask,
inside_LAM=True, wmse_contribution=False, data_variances=loss_scaling)
self.sg_metrics['wmse_outside_lam_epoch'] = WeightedMSELossStretchedGrid(node_weights=self.loss_weights, mask=self.mask,
inside_LAM=False, wmse_contribution=False, data_variances=loss_scaling)
self.sg_metrics['wmse_inside_lam_epoch'] = WeightedMSELossStretchedGrid(
node_weights=self.loss_weights,
mask=self.mask,
inside_LAM=True,
wmse_contribution=False,
data_variances=loss_scaling,
)
self.sg_metrics['wmse_outside_lam_epoch'] = WeightedMSELossStretchedGrid(
node_weights=self.loss_weights,
mask=self.mask,
inside_LAM=False,
wmse_contribution=False,
data_variances=loss_scaling,
)
if config.diagnostics.sg_metrics.wmse_contributions:
self.sg_metrics['wmse_inside_lam_contribution_epoch'] = WeightedMSELossStretchedGrid(node_weights=self.loss_weights, mask=self.mask,
inside_LAM=True, wmse_contribution=True, data_variances=loss_scaling)
self.sg_metrics['wmse_outside_lam_contribution_epoch'] = WeightedMSELossStretchedGrid(node_weights=self.loss_weights, mask=self.mask,
inside_LAM=False, wmse_contribution=True, data_variances=loss_scaling)
self.sg_metrics['wmse_inside_lam_contribution_epoch'] = WeightedMSELossStretchedGrid(
node_weights=self.loss_weights,
mask=self.mask,
inside_LAM=True,
wmse_contribution=True,
data_variances=loss_scaling,
)
self.sg_metrics['wmse_outside_lam_contribution_epoch'] = WeightedMSELossStretchedGrid(
node_weights=self.loss_weights,
mask=self.mask,
inside_LAM=False,
wmse_contribution=True,
data_variances=loss_scaling,
)
else:
self.stretched_grid = False

Expand Down Expand Up @@ -307,7 +328,7 @@ def calculate_val_metrics(
)
if self.stretched_grid:
for name, metric in self.sg_metrics.items():
metrics["{}".format(name)] = metric(y_pred, y)
metrics[f"{name}"] = metric(y_pred, y)
if enable_plot:
y_preds.append(y_pred)
return metrics, y_preds
Expand Down

0 comments on commit 2ee5af4

Please sign in to comment.