From 2ee5af4479bbf4a4099d908393baf83663cc0034 Mon Sep 17 00:00:00 2001 From: Jasper Wijnands Date: Tue, 5 Nov 2024 15:32:08 +0000 Subject: [PATCH] Code improvements based on code quality check output --- .../training/config/stretched_grid_cerra.yaml | 2 +- .../diagnostics/callbacks/__init__.py | 22 +++++----- src/anemoi/training/losses/mse.py | 4 +- src/anemoi/training/train/forecaster.py | 41 ++++++++++++++----- 4 files changed, 45 insertions(+), 24 deletions(-) diff --git a/src/anemoi/training/config/stretched_grid_cerra.yaml b/src/anemoi/training/config/stretched_grid_cerra.yaml index 39571148..8a053a96 100644 --- a/src/anemoi/training/config/stretched_grid_cerra.yaml +++ b/src/anemoi/training/config/stretched_grid_cerra.yaml @@ -112,4 +112,4 @@ training: lr: rate: 5e-4 #8 * 0.625e-4 min: 2.4e-6 #8 * 3e-7 - iterations: 150000 \ No newline at end of file + iterations: 150000 diff --git a/src/anemoi/training/diagnostics/callbacks/__init__.py b/src/anemoi/training/diagnostics/callbacks/__init__.py index ea9fd2c1..8fa3ee5a 100644 --- a/src/anemoi/training/diagnostics/callbacks/__init__.py +++ b/src/anemoi/training/diagnostics/callbacks/__init__.py @@ -196,7 +196,7 @@ def __init__(self, config: OmegaConf) -> None: ) self.rollout = config.diagnostics.eval.rollout self.frequency = config.diagnostics.eval.frequency - self.config=config + self.config = config def _eval( self, @@ -252,7 +252,7 @@ def _log(self, pl_module: pl.LightningModule, loss: torch.Tensor, metrics: dict, rank_zero_only=True, ) - #check if stretched grid + # check if stretched grid if self.config.graph.nodes.hidden.node_builder.lam_resolution: for str_area in ["inside", "contribution_inside", "outside", "contribution_outside"]: pl_module.log( @@ -667,7 +667,7 @@ def __init__(self, config: OmegaConf) -> None: """ super().__init__(config) self.sample_idx = self.config.diagnostics.plot.sample_idx - self.config=config + self.config = config self.precip_and_related_fields = self.config.diagnostics.plot.precip_and_related_fields LOGGER.info(f"Using defined accumulation colormap for fields: {self.precip_and_related_fields}") @@ -739,8 +739,8 @@ def _plot( exp_log_tag=f"val_pred_sample_rstep{rollout_step:02d}_rank{local_rank:01d}", ) - #check if stretched grid - if self.config.graph.nodes.hidden.node_builder.lam_resolution: + # check if stretched grid + if "lam_resolution" in getattr(self.config.graph.nodes.hidden, "node_builder", []): fig_lam_inside = plot_predicted_multilevel_flat_sample( plot_parameters_dict, self.config.diagnostics.plot.per_sample, @@ -748,15 +748,15 @@ def _plot( self.config.diagnostics.plot.accumulation_levels_plot, self.config.diagnostics.plot.cmap_accumulation, data[0, :, pl_module.mask, :].squeeze(), - data[rollout_step + 1, :, pl_module.mask, :].squeeze(), + data[rollout_step + 1, :, pl_module.mask, :].squeeze(), output_tensor[rollout_step, :, pl_module.mask, :], ) self._output_figure( logger, fig_lam_inside, epoch=epoch, - tag=f"lam_inside_pred_val_sample_rstep{rollout_step:02d}_batch{batch_idx:04d}_rank0", - exp_log_tag=f"lam_inside_val_pred_sample_rstep{rollout_step:02d}_rank{local_rank:01d}", + tag=f"pred_val_sample_inside_lam_rstep{rollout_step:02d}_batch{batch_idx:04d}_rank0", + exp_log_tag=f"val_pred_sample_inside_lam_rstep{rollout_step:02d}_rank{local_rank:01d}", ) fig_lam_outside = plot_predicted_multilevel_flat_sample( plot_parameters_dict, @@ -765,15 +765,15 @@ def _plot( self.config.diagnostics.plot.accumulation_levels_plot, self.config.diagnostics.plot.cmap_accumulation, data[0, :, ~pl_module.mask, :].squeeze(), - data[rollout_step + 1, :, ~pl_module.mask, :].squeeze(), + data[rollout_step + 1, :, ~pl_module.mask, :].squeeze(), output_tensor[rollout_step, :, ~pl_module.mask, :], ) self._output_figure( logger, fig_lam_outside, epoch=epoch, - tag=f"lam_outside_pred_val_sample_rstep{rollout_step:02d}_batch{batch_idx:04d}_rank0", - exp_log_tag=f"lam_outside_val_pred_sample_rstep{rollout_step:02d}_rank{local_rank:01d}", + tag=f"pred_val_sample_outside_lam_rstep{rollout_step:02d}_batch{batch_idx:04d}_rank0", + exp_log_tag=f"val_pred_sample_outside_lam_rstep{rollout_step:02d}_rank{local_rank:01d}", ) def on_validation_batch_end( diff --git a/src/anemoi/training/losses/mse.py b/src/anemoi/training/losses/mse.py index a7549a9c..d0e2113e 100644 --- a/src/anemoi/training/losses/mse.py +++ b/src/anemoi/training/losses/mse.py @@ -162,11 +162,11 @@ def forward( full_out_dims = pred[:, :, :, 0] if self.inside_LAM: - pred = pred[ :, :, self.mask] + pred = pred[:, :, self.mask] target = target[:, :, self.mask] weights_selected = self.weights_inside_LAM else: - pred = pred[ :, :, ~self.mask] + pred = pred[:, :, ~self.mask] target = target[:, :, ~self.mask] weights_selected = self.weights_outside_LAM diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index 7a92e24c..ff5fd9d6 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -28,7 +28,8 @@ from torch.utils.checkpoint import checkpoint from torch_geometric.data import HeteroData -from anemoi.training.losses.mse import WeightedMSELoss, WeightedMSELossStretchedGrid +from anemoi.training.losses.mse import WeightedMSELoss +from anemoi.training.losses.mse import WeightedMSELossStretchedGrid from anemoi.training.losses.utils import grad_scaler from anemoi.training.utils.jsonify import map_config_to_primitives from anemoi.training.utils.masks import Boolean1DMask @@ -109,15 +110,35 @@ def __init__( # Define stretched grid metrics according to options for stretched grid loss logging self.sg_metrics = torch.nn.ModuleDict() if config.diagnostics.sg_metrics.wmse_per_region: - self.sg_metrics['wmse_inside_lam_epoch'] = WeightedMSELossStretchedGrid(node_weights=self.loss_weights, mask=self.mask, - inside_LAM=True, wmse_contribution=False, data_variances=loss_scaling) - self.sg_metrics['wmse_outside_lam_epoch'] = WeightedMSELossStretchedGrid(node_weights=self.loss_weights, mask=self.mask, - inside_LAM=False, wmse_contribution=False, data_variances=loss_scaling) + self.sg_metrics['wmse_inside_lam_epoch'] = WeightedMSELossStretchedGrid( + node_weights=self.loss_weights, + mask=self.mask, + inside_LAM=True, + wmse_contribution=False, + data_variances=loss_scaling, + ) + self.sg_metrics['wmse_outside_lam_epoch'] = WeightedMSELossStretchedGrid( + node_weights=self.loss_weights, + mask=self.mask, + inside_LAM=False, + wmse_contribution=False, + data_variances=loss_scaling, + ) if config.diagnostics.sg_metrics.wmse_contributions: - self.sg_metrics['wmse_inside_lam_contribution_epoch'] = WeightedMSELossStretchedGrid(node_weights=self.loss_weights, mask=self.mask, - inside_LAM=True, wmse_contribution=True, data_variances=loss_scaling) - self.sg_metrics['wmse_outside_lam_contribution_epoch'] = WeightedMSELossStretchedGrid(node_weights=self.loss_weights, mask=self.mask, - inside_LAM=False, wmse_contribution=True, data_variances=loss_scaling) + self.sg_metrics['wmse_inside_lam_contribution_epoch'] = WeightedMSELossStretchedGrid( + node_weights=self.loss_weights, + mask=self.mask, + inside_LAM=True, + wmse_contribution=True, + data_variances=loss_scaling, + ) + self.sg_metrics['wmse_outside_lam_contribution_epoch'] = WeightedMSELossStretchedGrid( + node_weights=self.loss_weights, + mask=self.mask, + inside_LAM=False, + wmse_contribution=True, + data_variances=loss_scaling, + ) else: self.stretched_grid = False @@ -307,7 +328,7 @@ def calculate_val_metrics( ) if self.stretched_grid: for name, metric in self.sg_metrics.items(): - metrics["{}".format(name)] = metric(y_pred, y) + metrics[f"{name}"] = metric(y_pred, y) if enable_plot: y_preds.append(y_pred) return metrics, y_preds