fix!: Rework Loss Scalings to provide better modularity #52

Status: Open. Wants to merge 58 commits into base: main.

Changes shown from 15 commits (58 commits total).

Commits
511ed18
first version of refactor of variable scaling
sahahner Dec 27, 2024
7ddf6d6
config training changes
sahahner Dec 27, 2024
3ddeccc
avoid multiple scaling
sahahner Dec 27, 2024
be4602c
docstring and explain variable reference
sahahner Dec 31, 2024
195af07
fix to config for pressure level scaler
mc4117 Dec 31, 2024
2644c18
instantiating scalars as a list
mc4117 Dec 31, 2024
718fc57
preparing for tendency losses
mc4117 Dec 31, 2024
a34ac02
Merge branch '7-pressure-level-scalings-only-applied-in-specific-circ…
mc4117 Dec 31, 2024
b91af11
log the variable level scaling information as before
sahahner Jan 2, 2025
c22c50b
adding tendency scaler to additional scalers
pinnstorm Jan 8, 2025
1f4a532
reformatting
pinnstorm Jan 8, 2025
2843d98
updating description in configs
pinnstorm Jan 8, 2025
c978871
updating var-tendency-scaler spec
pinnstorm Jan 12, 2025
f56f9b2
updating training/default config
pinnstorm Jan 12, 2025
be90000
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 12, 2025
e474ae9
updating training/default.yaml
pinnstorm Jan 13, 2025
f005f84
updating training/default.yaml
pinnstorm Jan 13, 2025
7cdccc5
first try at tests
mc4117 Jan 17, 2025
61e7933
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 17, 2025
462bb34
variable name and level from mars metadata
sahahner Jan 17, 2025
960a602
Merge branch '7-pressure-level-scalings-only-applied-in-specific-circ…
sahahner Jan 17, 2025
af10173
get variable group and level in utils file
sahahner Jan 17, 2025
395cd6f
empty line
sahahner Jan 17, 2025
1f53a82
convert test for new structure. pressure level and general variable s…
sahahner Jan 17, 2025
3747959
more plausible check for availability of mars metadata
sahahner Jan 17, 2025
68cd6e3
update to tendency tests (still not working)
mc4117 Jan 17, 2025
d3a7c29
Merge branch '7-pressure-level-scalings-only-applied-in-specific-circ…
mc4117 Jan 17, 2025
d6e127a
tendency scaler tests now working
mc4117 Jan 20, 2025
fd29cbc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 20, 2025
8bff68b
change function into class, extracting variable group and name
sahahner Jan 22, 2025
4c7cbc1
Merge branch '7-pressure-level-scalings-only-applied-in-specific-circ…
sahahner Jan 22, 2025
7d8c76d
correct function call
sahahner Jan 22, 2025
d928b30
correct typo in test
sahahner Jan 22, 2025
bb054ce
incorporate comments
sahahner Jan 22, 2025
d0046fa
introduce base class for all loss scalings
sahahner Jan 22, 2025
a03d6ba
type checking check after all imports
sahahner Jan 22, 2025
aa7f558
comment: explanation about variable groups in config file
sahahner Jan 22, 2025
9a8a4b9
rm if statement for tendency scaler
mc4117 Jan 22, 2025
66d66ed
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 22, 2025
db05ce5
use utils function to retrieve variable group and reference for valid…
sahahner Jan 22, 2025
61766cd
Merge branch '7-pressure-level-scalings-only-applied-in-specific-circ…
sahahner Jan 22, 2025
3adf924
comment in config file that scaler name needs to be added to loss as w…
sahahner Jan 22, 2025
f19d69d
fix pre-commit hooks
mc4117 Jan 22, 2025
c26d744
Merge branch '7-pressure-level-scalings-only-applied-in-specific-circ…
mc4117 Jan 22, 2025
00439cb
Update description in training/default
mc4117 Jan 24, 2025
6c857a6
refactor into training/scaling both the code and the config files, re…
sahahner Jan 27, 2025
a2f2728
more scalar renaming to scaler
sahahner Jan 27, 2025
b5f6b5f
fix tendency loss
mc4117 Jan 27, 2025
b5fa55b
fix merge conflict
mc4117 Jan 27, 2025
cdb9e19
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 27, 2025
963c543
Add '*' to scaler selection.
HCookie Jan 27, 2025
4f1566b
Add exclusion of scalers
HCookie Jan 27, 2025
e4ceb8e
Fix scalar reference in tests
HCookie Jan 27, 2025
7178074
Add all and exclude tests
HCookie Jan 27, 2025
08b4cb3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 27, 2025
0dbf0b8
fix: update all tests, move scaling module into losses
sahahner Jan 28, 2025
2dccbd2
print final variable scaling in debug mode
sahahner Jan 28, 2025
72793f0
Training reorder parameter names for plot (#55)
sahahner Jan 30, 2025
57 changes: 37 additions & 20 deletions training/src/anemoi/training/config/training/default.yaml
@@ -50,7 +50,8 @@ training_loss:
# Available scalars include:
# - 'variable': See `variable_loss_scaling` for more information
# - 'loss_weights_mask': Giving imputed NaNs a zero weight in the loss function
scalars: ['variable', 'loss_weights_mask']
# - 'tendency': See `additional_scalars` for more information
scalars: ['variable', 'variable_pressure_level', 'loss_weights_mask']

ignore_nans: False

@@ -109,33 +110,49 @@ lr:
# Variable loss scaling
# 'variable' must be included in `scalars` in the losses for this to be applied.
variable_loss_scaling:
variable_groups:
default: sfc
pl: [q, t, u, v, w, z]
default: 1
pl:
q: 0.6 #1
t: 6 #1
u: 0.8 #0.5
v: 0.5 #0.33
w: 0.001
z: 12 #1
sfc:
sp: 10
10u: 0.1
10v: 0.1
2d: 0.5
tp: 0.025
cp: 0.0025
q: 0.6 #1
t: 6 #1
u: 0.8 #0.5
v: 0.5 #0.33
w: 0.001
z: 12 #1
sp: 10
10u: 0.1
10v: 0.1
2d: 0.5
tp: 0.025
cp: 0.0025
additional_scalars:
# pressure level scalar
- _target_: anemoi.training.train.scaling.ReluVariableLevelScaler
group: pl
y_intercept: 0.2
slope: 0.001
scale_dim: -1 # dimension on which scaling applied
name: "variable_pressure_level"
Member:
I would suggest that, instead of providing name as an attribute, additional_scalars should be a dict[str, dict], e.g.:

additional_scalars:
   variable_pressure_level:
       _target_: ------
       ....

The key would then name the scalar, which the loss function uses to choose between scalars.

Member:

This would be consistent with the way we create the nodes in anemoi-graphs.

Member (Author):

This makes sense. I can have a look at it.
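For illustration only, a minimal sketch of the dict-keyed layout suggested above, reusing the ReluVariableLevelScaler entry from this diff; the key taking over the role of the name attribute is an assumption about the proposal, not part of the PR as committed:

additional_scalars:
  variable_pressure_level:
    _target_: anemoi.training.train.scaling.ReluVariableLevelScaler
    group: pl
    y_intercept: 0.2
    slope: 0.001
    scale_dim: -1 # dimension on which scaling applied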

# tendency scalers
# scale the prognostic losses by the stdev of the variable tendencies (e.g. the 6-hourly differences of the data)
# useful if including slow vs fast evolving variables in the training (e.g. Land/Ocean vs Atmosphere)
# if using this option, 'variable_loss_scaling' values should all be set close to 1.0 for prognostic variables
# stdev tendency scaler
# - _target_: anemoi.training.data.scaling.StdevTendencyScaler
# scale_dim: -1 # dimension on which scaling applied
# name: "tendency"
# var tendency scaler (this should be default!?)
# - _target_: anemoi.training.data.scaling.VarTendencyScaler
# scale_dim: -1 # dimension on which scaling applied
# name: "tendency"

metrics:
- z_500
- t_850
- u_850
- v_850

pressure_level_scaler:
_target_: anemoi.training.data.scaling.ReluPressureLevelScaler
minimum: 0.2
slope: 0.001

node_loss_weights:
_target_: anemoi.training.losses.nodeweights.GraphNodeAttribute
target_nodes: ${graph.data}
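For reference, a sketch of how the commented-out tendency scaler in the additional_scalars section above could be enabled, assuming the _target_ path and parameters given in those comments are correct; as the config comments note, the scaler name must also be listed under training_loss.scalars for it to be applied:

additional_scalars:
  # scale prognostic losses by the stdev of the variable tendencies
  - _target_: anemoi.training.data.scaling.StdevTendencyScaler
    scale_dim: -1 # dimension on which scaling applied
    name: "tendency"

training_loss:
  # 'tendency' must appear here as well
  scalars: ['variable', 'variable_pressure_level', 'loss_weights_mask', 'tendency']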
5 changes: 5 additions & 0 deletions training/src/anemoi/training/data/datamodule.py
@@ -73,6 +73,10 @@ def __init__(self, config: DictConfig, graph_data: HeteroData) -> None:
def statistics(self) -> dict:
return self.ds_train.statistics

@cached_property
def statistics_tendencies(self) -> dict:
return self.ds_train.statistics_tendencies

@cached_property
def metadata(self) -> dict:
return self.ds_train.metadata
@@ -183,6 +187,7 @@ def _get_dataset(
rollout=r,
multistep=self.config.training.multistep_input,
timeincrement=self.timeincrement,
timestep=self.config.data.timestep,
shuffle=shuffle,
grid_indices=self.grid_indices,
label=label,
12 changes: 12 additions & 0 deletions training/src/anemoi/training/data/dataset.py
@@ -41,6 +41,7 @@ def __init__(
rollout: int = 1,
multistep: int = 1,
timeincrement: int = 1,
timestep: str = "6h",
shuffle: bool = True,
label: str = "generic",
effective_bs: int = 1,
@@ -57,6 +58,8 @@
length of rollout window, by default 12
timeincrement : int, optional
time increment between samples, by default 1
timestep : str, optional
the time frequency of the samples, by default '6h'
multistep : int, optional
collate (t-1, ... t - multistep) into the input state vector, by default 1
shuffle : bool, optional
@@ -73,6 +76,7 @@

self.rollout = rollout
self.timeincrement = timeincrement
self.timestep = timestep
self.grid_indices = grid_indices

# lazy init
@@ -104,6 +108,14 @@ def statistics(self) -> dict:
"""Return dataset statistics."""
return self.data.statistics

@cached_property
def statistics_tendencies(self) -> dict:
"""Return dataset tendency statistics."""
try:
return self.data.statistics_tendencies(self.timestep)
except (KeyError, AttributeError):
return None

@cached_property
def metadata(self) -> dict:
"""Return dataset metadata."""
79 changes: 0 additions & 79 deletions training/src/anemoi/training/data/scaling.py

This file was deleted.

73 changes: 32 additions & 41 deletions training/src/anemoi/training/train/forecaster.py
@@ -15,7 +15,6 @@
from typing import Optional
from typing import Union

import numpy as np
import pytorch_lightning as pl
import torch
from hydra.utils import instantiate
@@ -31,6 +30,7 @@
from anemoi.models.interface import AnemoiModelInterface
from anemoi.training.losses.utils import grad_scaler
from anemoi.training.losses.weightedloss import BaseWeightedLoss
from anemoi.training.train.scaling import GeneralVariableLossScaler
from anemoi.training.utils.jsonify import map_config_to_primitives
from anemoi.training.utils.masks import Boolean1DMask
from anemoi.training.utils.masks import NoOutputMask
@@ -48,6 +48,7 @@ def __init__(
config: DictConfig,
graph_data: HeteroData,
statistics: dict,
statistics_tendencies: dict,
data_indices: IndexCollection,
metadata: dict,
supporting_arrays: dict,
@@ -95,10 +96,36 @@ def __init__(
self.latlons_data = graph_data[config.graph.data].x
self.node_weights = self.get_node_weights(config, graph_data)
self.node_weights = self.output_mask.apply(self.node_weights, dim=0, fill_value=0.0)
self.statistics_tendencies = statistics_tendencies

self.logger_enabled = config.diagnostics.log.wandb.enabled or config.diagnostics.log.mlflow.enabled

variable_scaling = self.get_variable_scaling(config, data_indices)
variable_scaling = GeneralVariableLossScaler(
config.training.variable_loss_scaling,
data_indices,
).get_variable_scaling()

# Instantiate the pressure level scaling class with the training configuration
config_container = OmegaConf.to_container(config.training.additional_scalars, resolve=False)
if isinstance(config_container, list):
scalar = [
(
instantiate(
scalar_config,
scaling_config=config.training.variable_loss_scaling,
data_indices=data_indices,
statistics=statistics,
statistics_tendencies=statistics_tendencies,
)
if scalar_config["name"] == "tendency"
else instantiate(
scalar_config,
scaling_config=config.training.variable_loss_scaling,
data_indices=data_indices,
)
)
for scalar_config in config_container
]

self.internal_metric_ranges, self.val_metric_ranges = self.get_val_metric_ranges(config, data_indices)

@@ -120,6 +147,9 @@ def __init__(
"loss_weights_mask": ((-2, -1), torch.ones((1, 1))),
"limited_area_mask": (2, limited_area_mask),
}
# add additional user-defined scalars
[self.scalars.update({scale.name: (scale.scale_dim, scale.get_variable_scaling())}) for scale in scalar]

self.updated_loss_mask = False

self.loss = self.get_loss_function(config.training.training_loss, scalars=self.scalars, **loss_kwargs)
@@ -299,45 +329,6 @@ def get_val_metric_ranges(config: DictConfig, data_indices: IndexCollection) ->

return metric_ranges, metric_ranges_validation

@staticmethod
def get_variable_scaling(
config: DictConfig,
data_indices: IndexCollection,
) -> torch.Tensor:
variable_loss_scaling = (
np.ones((len(data_indices.internal_data.output.full),), dtype=np.float32)
* config.training.variable_loss_scaling.default
)
pressure_level = instantiate(config.training.pressure_level_scaler)

LOGGER.info(
"Pressure level scaling: use scaler %s with slope %.4f and minimum %.2f",
type(pressure_level).__name__,
pressure_level.slope,
pressure_level.minimum,
)

for key, idx in data_indices.internal_model.output.name_to_index.items():
split = key.split("_")
if len(split) > 1 and split[-1].isdigit():
# Apply pressure level scaling
if split[0] in config.training.variable_loss_scaling.pl:
variable_loss_scaling[idx] = config.training.variable_loss_scaling.pl[
split[0]
] * pressure_level.scaler(
int(split[-1]),
)
else:
LOGGER.debug("Parameter %s was not scaled.", key)
else:
# Apply surface variable scaling
if key in config.training.variable_loss_scaling.sfc:
variable_loss_scaling[idx] = config.training.variable_loss_scaling.sfc[key]
else:
LOGGER.debug("Parameter %s was not scaled.", key)

return torch.from_numpy(variable_loss_scaling)

@staticmethod
def get_node_weights(config: DictConfig, graph_data: HeteroData) -> torch.Tensor:
node_weighting = instantiate(config.training.node_loss_weights)
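To illustrate the modularity this refactor is aiming for, below is a hypothetical user-defined scaler following the pattern visible in the forecaster diff: scalers receive the scaling config and data indices at instantiation and expose name, scale_dim, and get_variable_scaling(). The class name and defaults are assumptions for this sketch, not part of the anemoi-training API.

import numpy as np
import torch


class ConstantVariableScaler:
    """Toy scaler assigning the same weight to every output variable (illustrative only)."""

    def __init__(self, scaling_config, data_indices, weight: float = 1.0,
                 scale_dim: int = -1, name: str = "constant") -> None:
        self.scaling_config = scaling_config
        self.data_indices = data_indices
        self.weight = weight
        self.scale_dim = scale_dim  # dimension on which the scaling is applied
        self.name = name  # must also be listed in training_loss.scalars

    def get_variable_scaling(self) -> torch.Tensor:
        # one factor per output variable, matching the shape used by the removed
        # get_variable_scaling staticmethod above
        n_vars = len(self.data_indices.internal_data.output.full)
        return torch.from_numpy(np.full(n_vars, self.weight, dtype=np.float32))

With the config layout in this PR, such a scaler would be registered as another entry under additional_scalars and its name added to training_loss.scalars.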