Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Callbacks #60

Open
wants to merge 47 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
b12fac8
Refactor Callbacks
HCookie Sep 24, 2024
29a8477
Update changelog
HCookie Sep 24, 2024
15824be
Fix TypeError
HCookie Sep 24, 2024
4077bf4
Move to hydra.instantiate
HCookie Sep 25, 2024
494d39d
Merge remote-tracking branch 'origin/develop' into fix/refactor_callb…
HCookie Sep 25, 2024
fe37c02
Add __all__
HCookie Sep 25, 2024
2d8275c
Add to base config
HCookie Sep 25, 2024
230eb0e
Fix nested list
HCookie Sep 25, 2024
5547b20
Fix nested get issue
HCookie Sep 26, 2024
1d80cfb
Fix type checking
HCookie Sep 27, 2024
e79dfc7
Merge branch 'develop' into fxi/refactor_callbacks
HCookie Oct 1, 2024
96ab74c
feat: edge plot in callbacks
JPXKQX Oct 1, 2024
4aeb1a5
feat: set default extra callbacks
JPXKQX Oct 1, 2024
816b3af
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 2, 2024
644038f
fix: typing & refactoring
JPXKQX Oct 2, 2024
8356cd4
fix: remove list comprehension
JPXKQX Oct 2, 2024
930e4d2
Refactor according to PR
HCookie Oct 2, 2024
52ea91f
Update deprecation warning
HCookie Oct 4, 2024
0dd81b7
Merge branch 'fxi/refactor_callbacks' into feature/graph-features-cal…
JPXKQX Oct 4, 2024
332f746
Merge pull request #71 from ecmwf/feature/graph-features-callback
HCookie Oct 4, 2024
bb8b9bb
Refactor: Remove backwards compatability,
HCookie Oct 10, 2024
0349be2
Fix tests
HCookie Oct 10, 2024
1e97ff1
PR Fixes
HCookie Oct 15, 2024
d7f713e
Merge branch 'develop' into fix/refactor_callbacks
HCookie Oct 18, 2024
ebfaf90
Merge remote-tracking branch 'origin/develop' into fix/refactor_callb…
HCookie Oct 18, 2024
460c8ba
Update Changelog
HCookie Oct 18, 2024
5671c7e
Merge branch 'develop' into fix/refactor_callbacks
HCookie Oct 21, 2024
21c05de
Refactor rollout (#87)
HCookie Oct 21, 2024
3c5e144
Remove batch frequency from LongRolloutPlots
HCookie Oct 21, 2024
5742754
Merge remote-tracking branch 'origin/develop' into fix/refactor_callb…
HCookie Oct 21, 2024
8671543
Merge branch 'develop' into fxi/refactor_callbacks
HCookie Oct 22, 2024
382728c
Remove TP reference
HCookie Oct 22, 2024
6fa66cc
Remove missing config reference
HCookie Oct 23, 2024
110fb64
Swapped histogram and spectrum
HCookie Oct 23, 2024
23cc785
Update copyright notice
HCookie Oct 23, 2024
bfe76f3
Merge branch 'develop' into fxi/refactor_callbacks
HCookie Oct 23, 2024
5a6880e
Merge branch 'develop' into fxi/refactor_callbacks
HCookie Oct 24, 2024
51a455d
Fix issues with split of PlotAdditionalMetrics
HCookie Oct 24, 2024
3318675
Merge branch 'fxi/refactor_callbacks' of github.com:ecmwf/anemoi-trai…
HCookie Oct 24, 2024
77bd65d
Merge remote-tracking branch 'origin/develop' into fix/refactor_callb…
HCookie Oct 24, 2024
3c6e1af
Fix CHANGELOG
HCookie Oct 25, 2024
86059a9
Fix documentation for callbacks
HCookie Oct 25, 2024
0bce490
Add all callback submodules to docs
HCookie Oct 25, 2024
f5057c6
Merge branch 'develop' into fxi/refactor_callbacks
HCookie Oct 25, 2024
d6e1d9c
Apply suggestions from code review
HCookie Oct 25, 2024
6073d84
Fix init args issue in RolloutPlots
HCookie Oct 25, 2024
f1d883f
Add rollout_eval config
HCookie Oct 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ Keep it human-readable, your future self will thank you!
- Feature: New `Boolean1DMask` class. Enables rollout training for limited area models. [#79](https://github.com/ecmwf/anemoi-training/pulls/79)

### Fixed
- Refactored callbacks. [#60](https://github.com/ecmwf/anemoi-training/pulls/60)
HCookie marked this conversation as resolved.
Show resolved Hide resolved
- Refactored rollout [#87](https://github.com/ecmwf/anemoi-training/pulls/87)
- Enable longer validation rollout than training
- Mlflow-sync to handle creation of new experiments in the remote server [#83] (https://github.com/ecmwf/anemoi-training/pull/83)
- ci: fix pyshtools install error (#100) https://github.com/ecmwf/anemoi-training/pull/100

Expand Down
13 changes: 6 additions & 7 deletions docs/modules/diagnostics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,17 @@ functionality to use both Weights & Biases and Tensorboard.

The callbacks can also be used to evaluate forecasts over longer
rollouts beyond the forecast time that the model is trained on. The
number of rollout steps (or forecast iteration steps) is set using
``config.eval.rollout = *num_of_rollout_steps*``.
number of rollout steps for verification (or forecast iteration steps)
is set using ``config.dataloader.validation_rollout =
*num_of_rollout_steps*``.

Note the user has the option to evaluate the callbacks asynchronously
(using the following config option
``config.diagnostics.plot.asynchronous``, which means that the model
training doesn't stop whilst the callbacks are being evaluated).
However, note that callbacks can still be slow, and therefore the
plotting callbacks can be switched off by setting
``config.diagnostics.plot.enabled`` to ``False`` or all the callbacks
can be completely switched off by setting
``config.diagnostics.eval.enabled`` to ``False``.
Callbacks are configured in the config file under the
``config.diagnostics.callbacks`` key, and plotting callbacks under the
``config.diagnostics.plot`` key.

Below is the documentation for the default callbacks provided, but it is
also possible for users to add callbacks using the same structure:
Expand Down
2 changes: 1 addition & 1 deletion src/anemoi/training/config/config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
defaults:
- data: zarr
- dataloader: native_grid
- diagnostics: eval_rollout
- diagnostics: evaluation
- hardware: example
- graph: multi_scale
- model: gnn
Expand Down
2 changes: 2 additions & 0 deletions src/anemoi/training/config/dataloader/native_grid.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ training:
frequency: ${data.frequency}
drop: []

validation_rollout: 1 # number of rollouts to use for validation, must be equal or greater than rollout expected by callbacks

validation:
dataset: ${dataloader.dataset}
start: 2021
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Add callbacks here
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Add callbacks here
- _target_: anemoi.training.diagnostics.callbacks.evaluation.RolloutEval
rollout: ${dataloader.validation_rollout}
frequency: 20
Original file line number Diff line number Diff line change
@@ -1,53 +1,8 @@
---
eval:
enabled: False
# use this to evaluate the model over longer rollouts, every so many validation batches
rollout: 12
frequency: 20
plot:
enabled: True
asynchronous: True
frequency: 750
sample_idx: 0
per_sample: 6
parameters:
- z_500
- t_850
- u_850
- v_850
- 2t
- 10u
- 10v
- sp
- tp
- cp
#Defining the accumulation levels for precipitation related fields and the colormap
accumulation_levels_plot: [0, 0.05, 0.1, 0.25, 0.5, 1, 1.5, 2, 3, 4, 5, 6, 7, 100] # in mm
cmap_accumulation: ["#ffffff", "#04e9e7", "#019ff4", "#0300f4", "#02fd02", "#01c501", "#008e00", "#fdf802", "#e5bc00", "#fd9500", "#fd0000", "#d40000", "#bc0000", "#f800fd"]
precip_and_related_fields: [tp, cp]
# Histogram and Spectrum plots
parameters_histogram:
- z_500
- tp
- 2t
- 10u
- 10v
parameters_spectrum:
- z_500
- tp
- 2t
- 10u
- 10v
# group parameters by categories when visualizing contributions to the loss
# one-parameter groups are possible to highlight individual parameters
parameter_groups:
moisture: [tp, cp, tcw]
sfc_wind: [10u, 10v]
learned_features: False
longrollout:
enabled: False
rollout: [60]
frequency: 20 # every X epochs
defaults:
mc4117 marked this conversation as resolved.
Show resolved Hide resolved
- plot: detailed
- callbacks: pretraining


debug:
# this will detect and trace back NaNs / Infs etc. but will slow down training
Expand All @@ -57,6 +12,7 @@ debug:
# remember to also activate the tensorboard logger (below)
profiler: False

enable_checkpointing: True
checkpoint:
every_n_minutes:
save_frequency: 30 # Approximate, as this is checked at the end of training steps
Expand Down
67 changes: 67 additions & 0 deletions src/anemoi/training/config/diagnostics/plot/detailed.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
asynchronous: True # Whether to plot asynchronously
frequency: # Frequency of the plotting
batch: 750
epoch: 5

# Parameters to plot
parameters:
- z_500
- t_850
- u_850
- v_850
- 2t
- 10u
- 10v
- sp
- tp
- cp

# Sample index
sample_idx: 0

# Precipitation and related fields
precip_and_related_fields: [tp, cp]

callbacks:
# Add plot callbacks here
- _target_: anemoi.training.diagnostics.callbacks.plot.GraphNodeTrainableFeaturesPlot
- _target_: anemoi.training.diagnostics.callbacks.plot.GraphEdgeTrainableFeaturesPlot
epoch_frequency: 5
- _target_: anemoi.training.diagnostics.callbacks.plot.PlotLoss
# group parameters by categories when visualizing contributions to the loss
# one-parameter groups are possible to highlight individual parameters
parameter_groups:
moisture: [tp, cp, tcw]
sfc_wind: [10u, 10v]
- _target_: anemoi.training.diagnostics.callbacks.plot.PlotSample
sample_idx: ${diagnostics.plot.sample_idx}
per_sample : 6
HCookie marked this conversation as resolved.
Show resolved Hide resolved
parameters: ${diagnostics.plot.parameters}
#Defining the accumulation levels for precipitation related fields and the colormap
accumulation_levels_plot: [0, 0.05, 0.1, 0.25, 0.5, 1, 1.5, 2, 3, 4, 5, 6, 7, 100] # in mm
cmap_accumulation: ["#ffffff", "#04e9e7", "#019ff4", "#0300f4", "#02fd02", "#01c501", "#008e00", "#fdf802", "#e5bc00", "#fd9500", "#fd0000", "#d40000", "#bc0000", "#f800fd"]
precip_and_related_fields: ${diagnostics.plot.precip_and_related_fields}

- _target_: anemoi.training.diagnostics.callbacks.plot.PlotSpectrum
# batch_frequency: 100 # Override for batch frequency
sample_idx: ${diagnostics.plot.sample_idx}
precip_and_related_fields: ${diagnostics.plot.precip_and_related_fields}
parameters:
- z_500
- tp
- 2t
- 10u
- 10v
- _target_: anemoi.training.diagnostics.callbacks.plot.PlotHistogram
sample_idx: ${diagnostics.plot.sample_idx}
precip_and_related_fields: ${diagnostics.plot.precip_and_related_fields}
parameters:
- z_500
- tp
- 2t
- 10u
- 10v
- _target_: anemoi.training.diagnostics.callbacks.plot.LongRolloutPlots
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As long as dataloader.validation_rollout=1, which is the default, this callback only increases the runtime without providing any additional plots. Should we move it into rollout_eval.yaml?

Copy link
Member Author

@HCookie HCookie Oct 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could provide a rollout plots configuration?
Addressed in f1d883f

rollout:
- ${dataloader.validation_rollout}
epoch_frequency: 20
1 change: 1 addition & 0 deletions src/anemoi/training/config/diagnostics/plot/none.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
callbacks: []
40 changes: 40 additions & 0 deletions src/anemoi/training/config/diagnostics/plot/simple.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
asynchronous: True # Whether to plot asynchronously
frequency: # Frequency of the plotting
batch: 750
epoch: 10

# Parameters to plot
parameters:
- z_500
- t_850
- u_850
- v_850
- 2t
- 10u
- 10v
- sp
- tp
- cp

# Sample index
sample_idx: 0

# Precipitation and related fields
precip_and_related_fields: [tp, cp]

callbacks:
# Add plot callbacks here
- _target_: anemoi.training.diagnostics.callbacks.plot.PlotLoss
# group parameters by categories when visualizing contributions to the loss
# one-parameter groups are possible to highlight individual parameters
parameter_groups:
moisture: [tp, cp, tcw]
sfc_wind: [10u, 10v]
- _target_: anemoi.training.diagnostics.callbacks.plot.PlotSample
sample_idx: ${diagnostics.plot.sample_idx}
per_sample : 6
parameters: ${diagnostics.plot.parameters}
#Defining the accumulation levels for precipitation related fields and the colormap
accumulation_levels_plot: [0, 0.05, 0.1, 0.25, 0.5, 1, 1.5, 2, 3, 4, 5, 6, 7, 100] # in mm
cmap_accumulation: ["#ffffff", "#04e9e7", "#019ff4", "#0300f4", "#02fd02", "#01c501", "#008e00", "#fdf802", "#e5bc00", "#fd9500", "#fd0000", "#d40000", "#bc0000", "#f800fd"]
precip_and_related_fields: ${diagnostics.plot.precip_and_related_fields}
6 changes: 2 additions & 4 deletions src/anemoi/training/data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,8 @@ def ds_train(self) -> NativeGridDataset:
@cached_property
def ds_valid(self) -> NativeGridDataset:
r = self.rollout
if self.config.diagnostics.eval.enabled:
r = max(r, self.config.diagnostics.eval.rollout)
if self.config.diagnostics.plot.get("longrollout") and self.config.diagnostics.plot.longrollout.enabled:
r = max(r, max(self.config.diagnostics.plot.longrollout.rollout))
r = max(r, self.config.dataloader.get("validation_rollout", 1))

assert self.config.dataloader.training.end < self.config.dataloader.validation.start, (
f"Training end date {self.config.dataloader.training.end} is not before"
f"validation start date {self.config.dataloader.validation.start}"
Expand Down
Loading
Loading