Pre-commits
OpheliaMiralles committed Dec 13, 2024
1 parent 2838355 commit 1388f89
Showing 5 changed files with 66 additions and 94 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -27,7 +27,8 @@ Keep it human-readable, your future self will thank you!
- Added dynamic NaN masking for the imputer class with two new classes: DynamicInputImputer, DynamicConstantImputer [#89](https://github.com/ecmwf/anemoi-models/pull/89)
- New `NamedNodesAttributes` class to handle node attributes in a more flexible way [#64](https://github.com/ecmwf/anemoi-models/pull/64)
- Contributors file [#69](https://github.com/ecmwf/anemoi-models/pull/69)

+- Added remappers, i.e. link functions to apply during training to facilitate learning of variables with a difficult distribution [#88](https://github.com/ecmwf/anemoi-models/pull/88)

### Changed
- Bugfixes for CI
- Change Changelog CI to run after successful publish
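For context, the new remapper entry is configuration-driven. Based on the Multimapper docstring and the assertion messages later in this diff (which mention `config.data.remapped`), a one-to-many remapping might be declared roughly as follows; only the `cos_sin` block is taken from the docstring, the surrounding nesting is an assumption:

```yaml
# Hypothetical config sketch -- the nesting under data.remapped is assumed;
# the cos_sin mapping itself mirrors the Multimapper docstring in this diff.
data:
  remapped:
    cos_sin:
      "mwd": ["cos_mwd", "sin_mwd"]  # remap wave direction to its cos/sin pair
```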
15 changes: 11 additions & 4 deletions src/anemoi/models/preprocessing/mappings.py
@@ -9,6 +9,7 @@

import torch


def cos_converter(x):
    """Convert angle in degrees to cos."""
    return torch.cos(x / 180 * torch.pi)
@@ -28,36 +29,42 @@ def atan2_converter(x):
"""
return torch.remainder(torch.atan2(x[..., 1], x[..., 0]) * 180 / torch.pi, 360)


def log1p_converter(x):
    """Convert positive var into log(1+var)."""
    return torch.log1p(x)


def boxcox_converter(x, lambd=0.5):
    """Convert positive var into boxcox(var)."""
-    pos_lam = (torch.pow(x, lambd)-1)/lambd
+    pos_lam = (torch.pow(x, lambd) - 1) / lambd
    null_lam = torch.log(x)
    if lambd == 0:
        return null_lam
    else:
        return pos_lam


def sqrt_converter(x):
    """Convert positive var into sqrt(var)."""
    return torch.sqrt(x)


def expm1_converter(x):
    """Convert back log(1+var) to var."""
    return torch.expm1(x)


def square_converter(x):
    """Convert back sqrt(var) to var."""
-    return x ** 2
+    return x**2


def inverse_boxcox_converter(x, lambd=0.5):
    """Convert back boxcox(var) to var."""
-    pos_lam = torch.pow(x*lambd+1, 1/lambd)
+    pos_lam = torch.pow(x * lambd + 1, 1 / lambd)
    null_lam = torch.exp(x)
    if lambd == 0:
        return null_lam
    else:
        return pos_lam
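As a sanity check on the converters above, each forward mapping is undone by its companion inverse. A minimal sketch (the sample values are arbitrary):

```python
import torch

from anemoi.models.preprocessing.mappings import boxcox_converter
from anemoi.models.preprocessing.mappings import expm1_converter
from anemoi.models.preprocessing.mappings import inverse_boxcox_converter
from anemoi.models.preprocessing.mappings import log1p_converter
from anemoi.models.preprocessing.mappings import sqrt_converter
from anemoi.models.preprocessing.mappings import square_converter

x = torch.tensor([0.1, 1.0, 10.0])

# log1p/expm1 and sqrt/square are exact inverse pairs
assert torch.allclose(expm1_converter(log1p_converter(x)), x)
assert torch.allclose(square_converter(sqrt_converter(x)), x)

# Box-Cox with the default lambd=0.5 maps x to 2 * (sqrt(x) - 1) and back
assert torch.allclose(inverse_boxcox_converter(boxcox_converter(x)), x)
```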
22 changes: 8 additions & 14 deletions src/anemoi/models/preprocessing/monomapper.py
@@ -16,14 +16,12 @@

from anemoi.models.data_indices.collection import IndexCollection
from anemoi.models.preprocessing import BasePreprocessor
-from anemoi.models.preprocessing.mappings import (
-    boxcox_converter,
-    expm1_converter,
-    inverse_boxcox_converter,
-    log1p_converter,
-    sqrt_converter,
-    square_converter,
-)
+from anemoi.models.preprocessing.mappings import boxcox_converter
+from anemoi.models.preprocessing.mappings import expm1_converter
+from anemoi.models.preprocessing.mappings import inverse_boxcox_converter
+from anemoi.models.preprocessing.mappings import log1p_converter
+from anemoi.models.preprocessing.mappings import sqrt_converter
+from anemoi.models.preprocessing.mappings import square_converter

LOGGER = logging.getLogger(__name__)

@@ -62,9 +60,7 @@ def __init__(
        self._validate_indices()

    def _validate_indices(self):
-        assert (
-            len(self.index_training) == len(self.index_inference) <= len(self.remappers)
-        ), (
+        assert len(self.index_training) == len(self.index_inference) <= len(self.remappers), (
            f"Error creating conversion indices {len(self.index_training)}, "
            f"{len(self.index_inference)}, {len(self.remappers)}"
        )
@@ -98,9 +94,7 @@ def _create_remapping_indices(
                # this is a forcing variable. It is not in the inference output.
                self.index_inference.append(None)
            for name_dst in self.method_config[method][name]:
-                assert (
-                    name_dst in self.data_indices.internal_data.input.name_to_index
-                ), (
+                assert name_dst in self.data_indices.internal_data.input.name_to_index, (
                    f"Trying to remap {name} to {name_dst}, but {name_dst} is not a variable. "
                    f"Remap {name} to {name_dst} in config.data.remapped. "
                )
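The single-line imports above pair each link function with its inverse; the Monomapper applies them column-wise, one variable in, one variable out. A rough sketch of that round trip, assuming a hypothetical positive variable in column 0 (shapes and values are illustrative only):

```python
import torch

from anemoi.models.preprocessing.mappings import expm1_converter
from anemoi.models.preprocessing.mappings import log1p_converter

# hypothetical batch: the last dimension indexes variables; column 0 holds a
# positive, heavy-tailed variable such as total precipitation
x = torch.rand(4, 10, 3) * 20.0

x_remapped = x.clone()
x_remapped[..., 0] = log1p_converter(x[..., 0])  # forward mapping before training

x_restored = x_remapped.clone()
x_restored[..., 0] = expm1_converter(x_remapped[..., 0])  # inverse mapping after inference
assert torch.allclose(x_restored, x)
```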
119 changes: 45 additions & 74 deletions src/anemoi/models/preprocessing/multimapper.py
@@ -15,7 +15,9 @@

from anemoi.models.data_indices.collection import IndexCollection
from anemoi.models.preprocessing import BasePreprocessor
-from anemoi.models.preprocessing.mappings import atan2_converter, cos_converter, sin_converter
+from anemoi.models.preprocessing.mappings import atan2_converter
+from anemoi.models.preprocessing.mappings import cos_converter
+from anemoi.models.preprocessing.mappings import sin_converter

LOGGER = logging.getLogger(__name__)

@@ -32,14 +34,16 @@ class Multimapper(BasePreprocessor, ABC):
"mwd" : ["cos_mwd", "sin_mwd"]
```
"""

    supported_methods = {
        method: [f, inv]
        for method, f, inv in zip(
            ["cos_sin"],
            [[cos_converter, sin_converter]],
            [atan2_converter],
        )
    }

    def __init__(
        self,
        config=None,
@@ -63,19 +67,15 @@ def __init__(
        self._validate_indices()

    def _validate_indices(self):
-        assert len(self.index_training_input) == len(
-            self.index_inference_input
-        ) <= len(self.remappers), (
+        assert len(self.index_training_input) == len(self.index_inference_input) <= len(self.remappers), (
            f"Error creating conversion indices {len(self.index_training_input)}, "
-            f"{len(self.index_inference_input)}, {len(self.remappers)}")
-        assert len(self.index_training_output) == len(
-            self.index_inference_output
-        ) <= len(self.remappers), (
+            f"{len(self.index_inference_input)}, {len(self.remappers)}"
+        )
+        assert len(self.index_training_output) == len(self.index_inference_output) <= len(self.remappers), (
            f"Error creating conversion indices {len(self.index_training_output)}, "
-            f"{len(self.index_inference_output)}, {len(self.remappers)}")
-        assert len(
-            set(self.index_training_input + self.indices_keep_training_input)
-        ) == self.num_training_input_vars, (
+            f"{len(self.index_inference_output)}, {len(self.remappers)}"
+        )
+        assert len(set(self.index_training_input + self.indices_keep_training_input)) == self.num_training_input_vars, (
            "Error creating conversion indices: variables remapped in config.data.remapped "
            "that have no remapping function defined. Preprocessed tensors contain empty columns."
        )
@@ -97,14 +97,10 @@ def _create_remapping_indices(

        self.num_training_input_vars = len(name_to_index_training_input)
        self.num_inference_input_vars = len(name_to_index_inference_input)
-        self.num_remapped_training_input_vars = len(
-            name_to_index_training_remapped_input)
-        self.num_remapped_inference_input_vars = len(
-            name_to_index_inference_remapped_input)
-        self.num_remapped_training_output_vars = len(
-            name_to_index_training_remapped_output)
-        self.num_remapped_inference_output_vars = len(
-            name_to_index_inference_remapped_output)
+        self.num_remapped_training_input_vars = len(name_to_index_training_remapped_input)
+        self.num_remapped_inference_input_vars = len(name_to_index_inference_remapped_input)
+        self.num_remapped_training_output_vars = len(name_to_index_training_remapped_output)
+        self.num_remapped_inference_output_vars = len(name_to_index_inference_remapped_output)
        self.num_training_output_vars = len(name_to_index_training_output)
        self.num_inference_output_vars = len(name_to_index_inference_output)
        self.indices_keep_training_input = []
@@ -146,15 +142,11 @@
                    continue

            if method == "cos_sin":
-                self.index_training_input.append(
-                    name_to_index_training_input[name])
-                self.index_training_output.append(
-                    name_to_index_training_output[name])
-                self.index_inference_input.append(
-                    name_to_index_inference_input[name])
+                self.index_training_input.append(name_to_index_training_input[name])
+                self.index_training_output.append(name_to_index_training_output[name])
+                self.index_inference_input.append(name_to_index_inference_input[name])
                if name in name_to_index_inference_output:
-                    self.index_inference_output.append(
-                        name_to_index_inference_output[name])
+                    self.index_inference_output.append(name_to_index_inference_output[name])
                else:
                    # this is a forcing variable. It is not in the inference output.
                    self.index_inference_output.append(None)
@@ -165,42 +157,29 @@
f"Trying to remap {name} to {name_dst}, but {name_dst} not a variable. "
f"Remap {name} to {name_dst} in config.data.remapped. "
)
multiple_training_input.append(
name_to_index_training_remapped_input[name_dst])
multiple_training_output.append(
name_to_index_training_remapped_output[name_dst])
multiple_inference_input.append(
name_to_index_inference_remapped_input[name_dst])
multiple_training_input.append(name_to_index_training_remapped_input[name_dst])
multiple_training_output.append(name_to_index_training_remapped_output[name_dst])
multiple_inference_input.append(name_to_index_inference_remapped_input[name_dst])
if name_dst in name_to_index_inference_remapped_output:
multiple_inference_output.append(
name_to_index_inference_remapped_output[name_dst])
multiple_inference_output.append(name_to_index_inference_remapped_output[name_dst])
else:
# this is a forcing variable. It is not in the inference output.
multiple_inference_output.append(None)

self.index_training_remapped_input.append(
multiple_training_input)
self.index_inference_remapped_input.append(
multiple_inference_input)
self.index_training_backmapped_output.append(
multiple_training_output)
self.index_inference_backmapped_output.append(
multiple_inference_output)
self.index_training_remapped_input.append(multiple_training_input)
self.index_inference_remapped_input.append(multiple_inference_input)
self.index_training_backmapped_output.append(multiple_training_output)
self.index_inference_backmapped_output.append(multiple_inference_output)

self.remappers.append([cos_converter, sin_converter])
self.backmappers.append(atan2_converter)

LOGGER.info(
f"Map {name} to cosine and sine and save result in {self.method_config[method][name]}."
)
LOGGER.info(f"Map {name} to cosine and sine and save result in {self.method_config[method][name]}.")

else:
raise ValueError[
f"Unknown remapping method for {name}: {method}"]
raise ValueError[f"Unknown remapping method for {name}: {method}"]

-    def transform(self,
-                  x: torch.Tensor,
-                  in_place: bool = True) -> torch.Tensor:
+    def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
"""Remap and convert the input tensor.
```
@@ -231,30 +210,25 @@ def transform(self,
        )

        # create new tensor with target number of columns
-        x_remapped = torch.zeros(x.shape[:-1] + (target_number_columns, ),
-                                 dtype=x.dtype,
-                                 device=x.device)
+        x_remapped = torch.zeros(x.shape[:-1] + (target_number_columns,), dtype=x.dtype, device=x.device)
        if in_place and not self.printed_preprocessor_warning:
            LOGGER.warning(
                "Remapper (preprocessor) called with in_place=True. This preprocessor cannot be applied in_place as new columns are added to the tensors.",
            )
            self.printed_preprocessor_warning = True

        # copy variables that are not remapped
-        x_remapped[..., :len(indices_keep)] = x[..., indices_keep]
+        x_remapped[..., : len(indices_keep)] = x[..., indices_keep]

        # Remap variables
-        for idx_dst, remapper, idx_src in zip(indices_remapped, self.remappers,
-                                              index):
+        for idx_dst, remapper, idx_src in zip(indices_remapped, self.remappers, index):
            if idx_src is not None:
                for jj, ii in enumerate(idx_dst):
                    x_remapped[..., ii] = remapper[jj](x[..., idx_src])

        return x_remapped

-    def inverse_transform(self,
-                          x: torch.Tensor,
-                          in_place: bool = True) -> torch.Tensor:
+    def inverse_transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
"""Convert and remap the output tensor.
```
@@ -285,21 +259,18 @@
        )

        # create new tensor with target number of columns
-        x_remapped = torch.zeros(x.shape[:-1] + (target_number_columns, ),
-                                 dtype=x.dtype,
-                                 device=x.device)
+        x_remapped = torch.zeros(x.shape[:-1] + (target_number_columns,), dtype=x.dtype, device=x.device)
        if in_place and not self.printed_postprocessor_warning:
            LOGGER.warning(
                "Remapper (preprocessor) called with in_place=True. This preprocessor cannot be applied in_place as new columns are added to the tensors.",
            )
            self.printed_postprocessor_warning = True

        # copy variables that are not remapped
-        x_remapped[..., indices_keep] = x[..., :len(indices_keep)]
+        x_remapped[..., indices_keep] = x[..., : len(indices_keep)]

        # Backmap variables
-        for idx_dst, backmapper, idx_src in zip(index, self.backmappers,
-                                                indices_remapped):
+        for idx_dst, backmapper, idx_src in zip(index, self.backmappers, indices_remapped):
            if idx_dst is not None:
                x_remapped[..., idx_dst] = backmapper(x[..., idx_src])

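Putting transform and inverse_transform together, the cos_sin method turns one angular column into a cos/sin pair and recovers the angle with atan2_converter. A standalone sketch of that round trip using only the converter functions (the variable name and values are illustrative):

```python
import torch

from anemoi.models.preprocessing.mappings import atan2_converter
from anemoi.models.preprocessing.mappings import cos_converter
from anemoi.models.preprocessing.mappings import sin_converter

mwd = torch.tensor([0.0, 45.0, 180.0, 359.0])  # direction in degrees

# forward mapping: one input column becomes two remapped columns
cos_mwd = cos_converter(mwd)
sin_mwd = sin_converter(mwd)

# backward mapping: atan2_converter expects (..., 2) with cos first, sin second
recovered = atan2_converter(torch.stack([cos_mwd, sin_mwd], dim=-1))
assert torch.allclose(recovered, mwd, atol=1e-3)
```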
1 change: 0 additions & 1 deletion src/anemoi/models/preprocessing/remapper.py
@@ -17,7 +17,6 @@
from anemoi.models.preprocessing.monomapper import Monomapper
from anemoi.models.preprocessing.multimapper import Multimapper

-
LOGGER = logging.getLogger(__name__)


