This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Feature/flexible remappers #88

Merged
merged 22 commits on Dec 19, 2024
Changes from all commits
22 commits
c8b07e6
Add log1p imputer
OpheliaMiralles Nov 19, 2024
fb18028
Less aggressive refactoring
OpheliaMiralles Nov 22, 2024
2838355
Add checks for inference/training indices, add base class
OpheliaMiralles Dec 13, 2024
1388f89
Pre-commits
OpheliaMiralles Dec 13, 2024
6e6b0e5
Add changelog
OpheliaMiralles Dec 13, 2024
b857c65
Tests+add train_loss_mask multimapper
OpheliaMiralles Dec 13, 2024
c65e9e4
Reformat
OpheliaMiralles Dec 13, 2024
ce791c0
Removed unnecessary brackets
OpheliaMiralles Dec 13, 2024
09bf216
Add test and case where mapped variable is a str not a list/dict
OpheliaMiralles Dec 16, 2024
ba61dc4
Pre-commits
OpheliaMiralles Dec 16, 2024
4d37d31
Merge branch 'develop' into feature/flexible-remappers
OpheliaMiralles Dec 16, 2024
7954ea1
Add forcings to test
OpheliaMiralles Dec 16, 2024
b9ed6e4
Pre-commits
OpheliaMiralles Dec 16, 2024
8041c70
Update src/anemoi/models/preprocessing/monomapper.py
OpheliaMiralles Dec 17, 2024
e96e8df
Update tests/preprocessing/test_preprocessor_remapper.py
OpheliaMiralles Dec 17, 2024
ce322f2
Update tests/preprocessing/test_preprocessor_remapper.py
OpheliaMiralles Dec 17, 2024
c0bfbb6
Refactor / add noop
OpheliaMiralles Dec 17, 2024
f551ed0
Remove bin
OpheliaMiralles Dec 17, 2024
e297b02
Pre commits
OpheliaMiralles Dec 17, 2024
f7c385b
Merge branch 'develop' into feature/flexible-remappers
OpheliaMiralles Dec 17, 2024
7a80df1
Update tests/preprocessing/test_preprocessor_remapper.py
OpheliaMiralles Dec 17, 2024
5782378
Add tests and refactor accordingly
OpheliaMiralles Dec 17, 2024
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -33,6 +33,8 @@ Keep it human-readable, your future self will thank you!
- GraphTransformerMapperBlock chunking to reduce memory usage during inference [#46](https://github.com/ecmwf/anemoi-models/pull/46)
- New `NamedNodesAttributes` class to handle node attributes in a more flexible way [#64](https://github.com/ecmwf/anemoi-models/pull/64)
- Contributors file [#69](https://github.com/ecmwf/anemoi-models/pull/69)
- Add remappers, e.g. link functions to apply during training to facilitate learning of variables with a difficult distribution [#88](https://github.com/ecmwf/anemoi-models/pull/88)
- Added `supporting_arrays` argument, which contains arrays to store in checkpoints. [#97](https://github.com/ecmwf/anemoi-models/pull/97)

### Changed
- Bugfixes for CI
19 changes: 13 additions & 6 deletions src/anemoi/models/preprocessing/__init__.py
@@ -57,25 +57,32 @@ def __init__(

super().__init__()

self.default, self.method_config = self._process_config(config)
self.default, self.remap, self.method_config = self._process_config(config)
self.methods = self._invert_key_value_list(self.method_config)

self.data_indices = data_indices

def _process_config(self, config):
@classmethod
def _process_config(cls, config):
_special_keys = ["default", "remap"] # Keys that do not contain a list of variables in a preprocessing method.
default = config.get("default", "none")
self.remap = config.get("remap", {})
remap = config.get("remap", {})
method_config = {k: v for k, v in config.items() if k not in _special_keys and v is not None and v != "none"}

if not method_config:
LOGGER.warning(
f"{self.__class__.__name__}: Using default method {default} for all variables not specified in the config.",
f"{cls.__name__}: Using default method {default} for all variables not specified in the config.",
)
for m in method_config:
if isinstance(method_config[m], str):
method_config[m] = {method_config[m]: f"{m}_{method_config[m]}"}
elif isinstance(method_config[m], list):
method_config[m] = {method: f"{m}_{method}" for method in method_config[m]}

return default, method_config
return default, remap, method_config

def _invert_key_value_list(self, method_config: dict[str, list[str]]) -> dict[str, str]:
@staticmethod
def _invert_key_value_list(method_config: dict[str, list[str]]) -> dict[str, str]:
"""Invert a dictionary of methods with lists of variables.

Parameters
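For context on the refactored `_process_config`, here is a minimal standalone sketch (not part of the diff) of the normalisation it applies to the method entries of a remapper config. The config dict and the variable names `tp`, `swl1`, `swl2` are hypothetical examples.

```python
# Minimal sketch, assuming a hypothetical remapper config; variable names are made up.
def normalise_methods(config: dict) -> dict:
    special = ("default", "remap")  # keys that do not list variables
    methods = {k: v for k, v in config.items() if k not in special and v is not None and v != "none"}
    for method, variables in methods.items():
        if isinstance(variables, str):  # a single variable given as a plain string
            methods[method] = {variables: f"{method}_{variables}"}
        elif isinstance(variables, list):  # several variables given as a list
            methods[method] = {v: f"{method}_{v}" for v in variables}
    return methods


print(normalise_methods({"default": "none", "log1p": "tp", "sqrt": ["swl1", "swl2"]}))
# {'log1p': {'tp': 'log1p_tp'}, 'sqrt': {'swl1': 'sqrt_swl1', 'swl2': 'sqrt_swl2'}}
```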
75 changes: 75 additions & 0 deletions src/anemoi/models/preprocessing/mappings.py
@@ -0,0 +1,75 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import torch


def noop(x):
"""No operation."""
return x


def cos_converter(x):
"""Convert angle in degree to cos."""
return torch.cos(x / 180 * torch.pi)


def sin_converter(x):
"""Convert angle in degree to sin."""
return torch.sin(x / 180 * torch.pi)


def atan2_converter(x):
"""Convert cos and sin to angle in degree.

Input:
x[..., 0]: cos
x[..., 1]: sin
"""
return torch.remainder(torch.atan2(x[..., 1], x[..., 0]) * 180 / torch.pi, 360)


def log1p_converter(x):
"""Convert positive var in to log(1+var)."""
return torch.log1p(x)


def boxcox_converter(x, lambd=0.5):
"""Convert positive var in to boxcox(var)."""
pos_lam = (torch.pow(x, lambd) - 1) / lambd
null_lam = torch.log(x)
if lambd == 0:
return null_lam
else:
return pos_lam


def sqrt_converter(x):
"""Convert positive var in to sqrt(var)."""
return torch.sqrt(x)


def expm1_converter(x):
"""Convert back log(1+var) to var."""
return torch.expm1(x)


def square_converter(x):
"""Convert back sqrt(var) to var."""
return x**2


def inverse_boxcox_converter(x, lambd=0.5):
"""Convert back boxcox(var) to var."""
pos_lam = torch.pow(x * lambd + 1, 1 / lambd)
null_lam = torch.exp(x)
if lambd == 0:
return null_lam
else:
return pos_lam
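As a quick sanity check (not part of the diff), the forward/backward pairs above should invert each other. A minimal sketch, assuming the module is importable under the path introduced by this PR:

```python
import torch

from anemoi.models.preprocessing.mappings import boxcox_converter
from anemoi.models.preprocessing.mappings import expm1_converter
from anemoi.models.preprocessing.mappings import inverse_boxcox_converter
from anemoi.models.preprocessing.mappings import log1p_converter
from anemoi.models.preprocessing.mappings import sqrt_converter
from anemoi.models.preprocessing.mappings import square_converter

x = torch.tensor([0.0, 0.5, 2.0, 10.0])  # positive values, as the converters expect

# Each backward converter should undo its forward counterpart.
assert torch.allclose(expm1_converter(log1p_converter(x)), x)
assert torch.allclose(square_converter(sqrt_converter(x)), x)
assert torch.allclose(inverse_boxcox_converter(boxcox_converter(x)), x)
```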
150 changes: 150 additions & 0 deletions src/anemoi/models/preprocessing/monomapper.py
@@ -0,0 +1,150 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


import logging
from abc import ABC
from typing import Optional

import torch

from anemoi.models.data_indices.collection import IndexCollection
from anemoi.models.preprocessing import BasePreprocessor
from anemoi.models.preprocessing.mappings import boxcox_converter
from anemoi.models.preprocessing.mappings import expm1_converter
from anemoi.models.preprocessing.mappings import inverse_boxcox_converter
from anemoi.models.preprocessing.mappings import log1p_converter
from anemoi.models.preprocessing.mappings import noop
from anemoi.models.preprocessing.mappings import sqrt_converter
from anemoi.models.preprocessing.mappings import square_converter

LOGGER = logging.getLogger(__name__)


class Monomapper(BasePreprocessor, ABC):
"""Remap and convert variables for single variables."""

supported_methods = {
method: [f, inv]
for method, f, inv in zip(
["log1p", "sqrt", "boxcox", "none"],
[log1p_converter, sqrt_converter, boxcox_converter, noop],
[expm1_converter, square_converter, inverse_boxcox_converter, noop],
)
}

def __init__(
self,
config=None,
data_indices: Optional[IndexCollection] = None,
statistics: Optional[dict] = None,
) -> None:
super().__init__(config, data_indices, statistics)
self._create_remapping_indices(statistics)
self._validate_indices()

def _validate_indices(self):
assert (
len(self.index_training_input)
== len(self.index_inference_input)
== len(self.index_inference_output)
== len(self.index_training_out)
== len(self.remappers)
), (
f"Error creating conversion indices {len(self.index_training_input)}, "
f"{len(self.index_inference_input)}, {len(self.index_training_input)}, {len(self.index_training_out)}, {len(self.remappers)}"
)

def _create_remapping_indices(
self,
statistics=None,
):
"""Create the parameter indices for remapping."""
# list for training and inference mode as position of parameters can change
name_to_index_training_input = self.data_indices.data.input.name_to_index
name_to_index_inference_input = self.data_indices.model.input.name_to_index
name_to_index_training_output = self.data_indices.data.output.name_to_index
name_to_index_inference_output = self.data_indices.model.output.name_to_index
self.num_training_input_vars = len(name_to_index_training_input)
self.num_inference_input_vars = len(name_to_index_inference_input)
self.num_training_output_vars = len(name_to_index_training_output)
self.num_inference_output_vars = len(name_to_index_inference_output)

(
self.remappers,
self.backmappers,
self.index_training_input,
self.index_training_out,
self.index_inference_input,
self.index_inference_output,
) = (
[],
[],
[],
[],
[],
[],
)

# Create parameter indices for remapping variables
for name in name_to_index_training_input:
method = self.methods.get(name, self.default)
if method in self.supported_methods:
self.remappers.append(self.supported_methods[method][0])
self.backmappers.append(self.supported_methods[method][1])
self.index_training_input.append(name_to_index_training_input[name])
if name in name_to_index_training_output:
self.index_training_out.append(name_to_index_training_output[name])
else:
self.index_training_out.append(None)
if name in name_to_index_inference_input:
self.index_inference_input.append(name_to_index_inference_input[name])
else:
self.index_inference_input.append(None)
if name in name_to_index_inference_output:
self.index_inference_output.append(name_to_index_inference_output[name])
else:
# this is a forcing variable. It is not in the inference output.
self.index_inference_output.append(None)
else:
                raise KeyError(f"Unknown remapping method for {name}: {method}")

def transform(self, x, in_place: bool = True) -> torch.Tensor:
if not in_place:
x = x.clone()
if x.shape[-1] == self.num_training_input_vars:
idx = self.index_training_input
elif x.shape[-1] == self.num_inference_input_vars:
idx = self.index_inference_input
else:
raise ValueError(
f"Input tensor ({x.shape[-1]}) does not match the training "
f"({self.num_training_input_vars}) or inference shape ({self.num_inference_input_vars})",
)
for i, remapper in zip(idx, self.remappers):
if i is not None:
x[..., i] = remapper(x[..., i])
return x

def inverse_transform(self, x, in_place: bool = True) -> torch.Tensor:
if not in_place:
x = x.clone()
if x.shape[-1] == self.num_training_output_vars:
idx = self.index_training_out
elif x.shape[-1] == self.num_inference_output_vars:
idx = self.index_inference_output
else:
raise ValueError(
f"Input tensor ({x.shape[-1]}) does not match the training "
f"({self.num_training_output_vars}) or inference shape ({self.num_inference_output_vars})",
)
for i, backmapper in zip(idx, self.backmappers):
if i is not None:
x[..., i] = backmapper(x[..., i])
return x
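To illustrate the per-channel application in `transform` and `inverse_transform`, here is a toy sketch (made-up indices and tensor layout, not the real `IndexCollection`): each named variable gets its own converter, and channels whose index is `None` (e.g. forcings absent from a given layout) are skipped.

```python
import torch

from anemoi.models.preprocessing.mappings import log1p_converter
from anemoi.models.preprocessing.mappings import noop
from anemoi.models.preprocessing.mappings import sqrt_converter

x = torch.rand(4, 2)  # 4 samples, 2 variables present in this hypothetical tensor layout
remappers = [log1p_converter, noop, sqrt_converter]  # one converter per named variable
index = [0, None, 1]  # the second variable is absent from this layout, so it is skipped

for i, remapper in zip(index, remappers):
    if i is not None:
        x[..., i] = remapper(x[..., i])
```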