diff --git a/CHANGELOG.md b/CHANGELOG.md index a39f43d2..ac7eae89 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,8 @@ Keep it human-readable, your future self will thank you! - GraphTransformerMapperBlock chunking to reduce memory usage during inference [#46](https://github.com/ecmwf/anemoi-models/pull/46) - New `NamedNodesAttributes` class to handle node attributes in a more flexible way [#64](https://github.com/ecmwf/anemoi-models/pull/64) - Contributors file [#69](https://github.com/ecmwf/anemoi-models/pull/69) +- Add remappers, e.g. link functions to apply during training to facilitate learning of variables with a difficult distribution [#88](https://github.com/ecmwf/anemoi-models/pull/88) +- Added `supporting_arrays` argument, which contains arrays to store in checkpoints. [#97](https://github.com/ecmwf/anemoi-models/pull/97) ### Changed - Bugfixes for CI diff --git a/src/anemoi/models/preprocessing/__init__.py b/src/anemoi/models/preprocessing/__init__.py index 4c7f6153..e26505f3 100644 --- a/src/anemoi/models/preprocessing/__init__.py +++ b/src/anemoi/models/preprocessing/__init__.py @@ -57,25 +57,32 @@ def __init__( super().__init__() - self.default, self.method_config = self._process_config(config) + self.default, self.remap, self.method_config = self._process_config(config) self.methods = self._invert_key_value_list(self.method_config) self.data_indices = data_indices - def _process_config(self, config): + @classmethod + def _process_config(cls, config): _special_keys = ["default", "remap"] # Keys that do not contain a list of variables in a preprocessing method. 
default = config.get("default", "none") - self.remap = config.get("remap", {}) + remap = config.get("remap", {}) method_config = {k: v for k, v in config.items() if k not in _special_keys and v is not None and v != "none"} if not method_config: LOGGER.warning( - f"{self.__class__.__name__}: Using default method {default} for all variables not specified in the config.", + f"{cls.__name__}: Using default method {default} for all variables not specified in the config.", ) + for m in method_config: + if isinstance(method_config[m], str): + method_config[m] = {method_config[m]: f"{m}_{method_config[m]}"} + elif isinstance(method_config[m], list): + method_config[m] = {method: f"{m}_{method}" for method in method_config[m]} - return default, method_config + return default, remap, method_config - def _invert_key_value_list(self, method_config: dict[str, list[str]]) -> dict[str, str]: + @staticmethod + def _invert_key_value_list(method_config: dict[str, list[str]]) -> dict[str, str]: """Invert a dictionary of methods with lists of variables. Parameters diff --git a/src/anemoi/models/preprocessing/mappings.py b/src/anemoi/models/preprocessing/mappings.py new file mode 100644 index 00000000..dab46734 --- /dev/null +++ b/src/anemoi/models/preprocessing/mappings.py @@ -0,0 +1,75 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
import torch


def noop(x):
    """Return ``x`` unchanged (identity mapping)."""
    return x


def cos_converter(x):
    """Convert an angle in degrees to its cosine."""
    return torch.cos(x / 180 * torch.pi)


def sin_converter(x):
    """Convert an angle in degrees to its sine."""
    return torch.sin(x / 180 * torch.pi)


def atan2_converter(x):
    """Convert a (cos, sin) pair back to an angle in degrees in [0, 360).

    Input:
        x[..., 0]: cos
        x[..., 1]: sin
    """
    degrees = torch.atan2(x[..., 1], x[..., 0]) * 180 / torch.pi
    # atan2 yields (-180, 180]; wrap negative angles into [0, 360).
    return torch.remainder(degrees, 360)


def log1p_converter(x):
    """Convert a positive variable to log(1 + var)."""
    return torch.log1p(x)


def boxcox_converter(x, lambd=0.5):
    """Apply the Box-Cox transform to a positive variable.

    Returns ``log(x)`` when ``lambd == 0`` and ``(x**lambd - 1) / lambd``
    otherwise.
    """
    # Branch first so the lambd == 0 case never divides by zero.
    if lambd == 0:
        return torch.log(x)
    return (torch.pow(x, lambd) - 1) / lambd


def sqrt_converter(x):
    """Convert a positive variable to sqrt(var)."""
    return torch.sqrt(x)


def expm1_converter(x):
    """Invert ``log1p_converter``: convert log(1 + var) back to var."""
    return torch.expm1(x)


def square_converter(x):
    """Invert ``sqrt_converter``: convert sqrt(var) back to var."""
    return x**2


def inverse_boxcox_converter(x, lambd=0.5):
    """Invert ``boxcox_converter`` for the same ``lambd``.

    Returns ``exp(x)`` when ``lambd == 0`` and
    ``(x * lambd + 1) ** (1 / lambd)`` otherwise.
    """
    # Branch first: ``1 / lambd`` is plain Python division and would raise
    # ZeroDivisionError for lambd == 0 if it were evaluated eagerly.
    if lambd == 0:
        return torch.exp(x)
    return torch.pow(x * lambd + 1, 1 / lambd)
import logging
from abc import ABC
from typing import Optional

import torch

from anemoi.models.data_indices.collection import IndexCollection
from anemoi.models.preprocessing import BasePreprocessor
from anemoi.models.preprocessing.mappings import boxcox_converter
from anemoi.models.preprocessing.mappings import expm1_converter
from anemoi.models.preprocessing.mappings import inverse_boxcox_converter
from anemoi.models.preprocessing.mappings import log1p_converter
from anemoi.models.preprocessing.mappings import noop
from anemoi.models.preprocessing.mappings import sqrt_converter
from anemoi.models.preprocessing.mappings import square_converter

LOGGER = logging.getLogger(__name__)


class Monomapper(BasePreprocessor, ABC):
    """Remap and convert variables for single variables.

    Every variable is mapped one-to-one: ``transform`` applies the forward
    function of the configured method to its column and
    ``inverse_transform`` applies the matching inverse. The number of
    columns is never changed.
    """

    # method name -> [forward function, inverse function]
    supported_methods = {
        "log1p": [log1p_converter, expm1_converter],
        "sqrt": [sqrt_converter, square_converter],
        "boxcox": [boxcox_converter, inverse_boxcox_converter],
        "none": [noop, noop],
    }

    def __init__(
        self,
        config=None,
        data_indices: Optional[IndexCollection] = None,
        statistics: Optional[dict] = None,
    ) -> None:
        """Initialize the remapper.

        Parameters
        ----------
        config : DotDict
            configuration object of the processor
        data_indices : IndexCollection
            Data indices for input and output variables
        statistics : dict
            Data statistics dictionary
        """
        super().__init__(config, data_indices, statistics)
        self._create_remapping_indices(statistics)
        self._validate_indices()

    def _validate_indices(self):
        """Check that all index lists and the remapper list have equal length."""
        lengths = (
            len(self.index_training_input),
            len(self.index_inference_input),
            len(self.index_inference_output),
            len(self.index_training_out),
            len(self.remappers),
        )
        assert len(set(lengths)) == 1, (
            f"Error creating conversion indices {lengths[0]}, "
            f"{lengths[1]}, {lengths[2]}, {lengths[3]}, {lengths[4]}"
        )

    def _create_remapping_indices(
        self,
        statistics=None,
    ):
        """Create the parameter indices for remapping.

        For every training-input variable, record its forward and inverse
        function together with its column index in each of the four tensor
        layouts (``None`` where the variable is absent from a layout).
        ``statistics`` is accepted but not used here.

        Raises
        ------
        KeyError
            If the method resolved for a variable is not in
            ``supported_methods``.
        """

        def index_or_none(name_to_index, key):
            return name_to_index[key] if key in name_to_index else None

        # list for training and inference mode as position of parameters can change
        name_to_index_training_input = self.data_indices.data.input.name_to_index
        name_to_index_inference_input = self.data_indices.model.input.name_to_index
        name_to_index_training_output = self.data_indices.data.output.name_to_index
        name_to_index_inference_output = self.data_indices.model.output.name_to_index
        self.num_training_input_vars = len(name_to_index_training_input)
        self.num_inference_input_vars = len(name_to_index_inference_input)
        self.num_training_output_vars = len(name_to_index_training_output)
        self.num_inference_output_vars = len(name_to_index_inference_output)

        self.remappers = []
        self.backmappers = []
        self.index_training_input = []
        self.index_training_out = []
        self.index_inference_input = []
        self.index_inference_output = []

        # Create parameter indices for remapping variables
        for name in name_to_index_training_input:
            method = self.methods.get(name, self.default)
            if method not in self.supported_methods:
                raise KeyError(f"Unknown remapping method for {name}: {method}")

            forward, inverse = self.supported_methods[method]
            self.remappers.append(forward)
            self.backmappers.append(inverse)
            self.index_training_input.append(name_to_index_training_input[name])
            self.index_training_out.append(index_or_none(name_to_index_training_output, name))
            self.index_inference_input.append(index_or_none(name_to_index_inference_input, name))
            # A forcing variable is not in the inference output -> None.
            self.index_inference_output.append(index_or_none(name_to_index_inference_output, name))

    def transform(self, x, in_place: bool = True) -> torch.Tensor:
        """Apply the forward mapping to each configured column of ``x``.

        The last dimension of ``x`` selects the layout: it must equal the
        number of training-input or inference-input variables.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor
        in_place : bool
            If False, a clone of ``x`` is modified and returned instead.

        Raises
        ------
        ValueError
            If the last dimension matches neither layout.
        """
        if not in_place:
            x = x.clone()
        if x.shape[-1] == self.num_training_input_vars:
            idx = self.index_training_input
        elif x.shape[-1] == self.num_inference_input_vars:
            idx = self.index_inference_input
        else:
            raise ValueError(
                f"Input tensor ({x.shape[-1]}) does not match the training "
                f"({self.num_training_input_vars}) or inference shape ({self.num_inference_input_vars})",
            )
        for i, remapper in zip(idx, self.remappers):
            # None marks a variable that is absent from this layout.
            if i is not None:
                x[..., i] = remapper(x[..., i])
        return x

    def inverse_transform(self, x, in_place: bool = True) -> torch.Tensor:
        """Apply the inverse mapping to each configured column of ``x``.

        The last dimension of ``x`` selects the layout: it must equal the
        number of training-output or inference-output variables.

        Parameters
        ----------
        x : torch.Tensor
            Output tensor
        in_place : bool
            If False, a clone of ``x`` is modified and returned instead.

        Raises
        ------
        ValueError
            If the last dimension matches neither layout.
        """
        if not in_place:
            x = x.clone()
        if x.shape[-1] == self.num_training_output_vars:
            idx = self.index_training_out
        elif x.shape[-1] == self.num_inference_output_vars:
            idx = self.index_inference_output
        else:
            raise ValueError(
                f"Input tensor ({x.shape[-1]}) does not match the training "
                f"({self.num_training_output_vars}) or inference shape ({self.num_inference_output_vars})",
            )
        for i, backmapper in zip(idx, self.backmappers):
            # None marks a variable that is absent from this layout.
            if i is not None:
                x[..., i] = backmapper(x[..., i])
        return x
import logging
from abc import ABC
from typing import Optional

import torch

from anemoi.models.data_indices.collection import IndexCollection
from anemoi.models.preprocessing import BasePreprocessor
from anemoi.models.preprocessing.mappings import atan2_converter
from anemoi.models.preprocessing.mappings import cos_converter
from anemoi.models.preprocessing.mappings import sin_converter

LOGGER = logging.getLogger(__name__)


class Multimapper(BasePreprocessor, ABC):
    """Remap single variable to 2 or more variables, or the other way around.

    cos_sin:
        Remap the variable to cosine and sine.
        Map output back to degrees.

    ```
    cos_sin:
        "mwd" : ["cos_mwd", "sin_mwd"]
    ```
    """

    # method name -> [list of forward functions, inverse function]
    supported_methods = {
        "cos_sin": [[cos_converter, sin_converter], atan2_converter],
    }

    def __init__(
        self,
        config=None,
        data_indices: Optional[IndexCollection] = None,
        statistics: Optional[dict] = None,
    ) -> None:
        """Initialize the remapper.

        Parameters
        ----------
        config : DotDict
            configuration object of the processor
        data_indices : IndexCollection
            Data indices for input and output variables
        statistics : dict
            Data statistics dictionary
        """
        super().__init__(config, data_indices, statistics)
        # Each in_place warning is logged at most once per instance.
        self.printed_preprocessor_warning, self.printed_postprocessor_warning = False, False
        self._create_remapping_indices(statistics)
        self._validate_indices()

    def _validate_indices(self):
        """Check the index lists built by ``_create_remapping_indices``."""
        assert len(self.index_training_input) == len(self.index_inference_input) <= len(self.remappers), (
            f"Error creating conversion indices {len(self.index_training_input)}, "
            f"{len(self.index_inference_input)}, {len(self.remappers)}"
        )
        assert len(self.index_training_output) == len(self.index_inference_output) <= len(self.remappers), (
            f"Error creating conversion indices {len(self.index_training_output)}, "
            f"{len(self.index_inference_output)}, {len(self.remappers)}"
        )
        # Every training-input column must be either remapped or kept.
        assert len(set(self.index_training_input + self.indices_keep_training_input)) == self.num_training_input_vars, (
            "Error creating conversion indices: variables remapped in config.data.remapped "
            "that have no remapping function defined. Preprocessed tensors contains empty columns."
        )

    def _create_remapping_indices(
        self,
        statistics=None,
    ):
        """Create the parameter indices for remapping.

        Builds, for the training and inference layouts, the source column of
        every remapped variable, the destination columns of its remapped
        components, and the columns that are copied through unchanged.
        ``statistics`` is accepted but not used here.

        Raises
        ------
        ValueError
            If the method resolved for a variable is neither ``"none"`` nor
            ``"cos_sin"``.
        """
        # list for training and inference mode as position of parameters can change
        name_to_index_training_input = self.data_indices.data.input.name_to_index
        name_to_index_inference_input = self.data_indices.model.input.name_to_index
        name_to_index_training_remapped_input = self.data_indices.internal_data.input.name_to_index
        name_to_index_inference_remapped_input = self.data_indices.internal_model.input.name_to_index
        name_to_index_training_remapped_output = self.data_indices.internal_data.output.name_to_index
        name_to_index_inference_remapped_output = self.data_indices.internal_model.output.name_to_index
        name_to_index_training_output = self.data_indices.data.output.name_to_index
        name_to_index_inference_output = self.data_indices.model.output.name_to_index

        self.num_training_input_vars = len(name_to_index_training_input)
        self.num_inference_input_vars = len(name_to_index_inference_input)
        self.num_remapped_training_input_vars = len(name_to_index_training_remapped_input)
        self.num_remapped_inference_input_vars = len(name_to_index_inference_remapped_input)
        self.num_remapped_training_output_vars = len(name_to_index_training_remapped_output)
        self.num_remapped_inference_output_vars = len(name_to_index_inference_remapped_output)
        self.num_training_output_vars = len(name_to_index_training_output)
        self.num_inference_output_vars = len(name_to_index_inference_output)

        # Columns whose variable also exists in the remapped layout are kept.
        self.indices_keep_training_input = [
            idx
            for key, idx in name_to_index_training_input.items()
            if key in name_to_index_training_remapped_input
        ]
        self.indices_keep_inference_input = [
            idx
            for key, idx in name_to_index_inference_input.items()
            if key in name_to_index_inference_remapped_input
        ]
        self.indices_keep_training_output = [
            idx
            for key, idx in name_to_index_training_output.items()
            if key in name_to_index_training_remapped_output
        ]
        self.indices_keep_inference_output = [
            idx
            for key, idx in name_to_index_inference_output.items()
            if key in name_to_index_inference_remapped_output
        ]

        self.index_training_input = []
        self.index_training_remapped_input = []
        self.index_inference_input = []
        self.index_inference_remapped_input = []
        self.index_training_output = []
        self.index_training_backmapped_output = []
        self.index_inference_output = []
        self.index_inference_backmapped_output = []
        self.remappers = []
        self.backmappers = []

        # Create parameter indices for remapping variables
        for name in name_to_index_training_input:
            method = self.methods.get(name, self.default)

            if method == "none":
                continue

            if method != "cos_sin":
                raise ValueError(f"Unknown remapping method for {name}: {method}")

            self.index_training_input.append(name_to_index_training_input[name])
            self.index_training_output.append(name_to_index_training_output[name])
            self.index_inference_input.append(name_to_index_inference_input[name])
            if name in name_to_index_inference_output:
                self.index_inference_output.append(name_to_index_inference_output[name])
            else:
                # this is a forcing variable. It is not in the inference output.
                self.index_inference_output.append(None)

            multiple_training_output, multiple_inference_output = [], []
            multiple_training_input, multiple_inference_input = [], []
            for name_dst in self.method_config[method][name]:
                assert name_dst in self.data_indices.internal_data.input.name_to_index, (
                    f"Trying to remap {name} to {name_dst}, but {name_dst} not a variable. "
                    f"Remap {name} to {name_dst} in config.data.remapped. "
                )
                multiple_training_input.append(name_to_index_training_remapped_input[name_dst])
                multiple_training_output.append(name_to_index_training_remapped_output[name_dst])
                multiple_inference_input.append(name_to_index_inference_remapped_input[name_dst])
                if name_dst in name_to_index_inference_remapped_output:
                    multiple_inference_output.append(name_to_index_inference_remapped_output[name_dst])
                else:
                    # this is a forcing variable. It is not in the inference output.
                    multiple_inference_output.append(None)

            self.index_training_remapped_input.append(multiple_training_input)
            self.index_inference_remapped_input.append(multiple_inference_input)
            self.index_training_backmapped_output.append(multiple_training_output)
            self.index_inference_backmapped_output.append(multiple_inference_output)

            self.remappers.append([cos_converter, sin_converter])
            self.backmappers.append(atan2_converter)

            LOGGER.info(f"Map {name} to cosine and sine and save result in {self.method_config[method][name]}.")

    def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
        """Remap and convert the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor
        in_place : bool
            Whether to process the tensor in place.
            in_place is not possible for this preprocessor; a new tensor is
            always returned and a warning is logged once if True is passed.

        Raises
        ------
        ValueError
            If the last dimension of ``x`` matches neither the training nor
            the inference input layout.
        """
        # Choose correct index based on number of variables
        if x.shape[-1] == self.num_training_input_vars:
            index = self.index_training_input
            indices_remapped = self.index_training_remapped_input
            indices_keep = self.indices_keep_training_input
            target_number_columns = self.num_remapped_training_input_vars

        elif x.shape[-1] == self.num_inference_input_vars:
            index = self.index_inference_input
            indices_remapped = self.index_inference_remapped_input
            indices_keep = self.indices_keep_inference_input
            target_number_columns = self.num_remapped_inference_input_vars

        else:
            raise ValueError(
                f"Input tensor ({x.shape[-1]}) does not match the training "
                f"({self.num_training_input_vars}) or inference shape ({self.num_inference_input_vars})",
            )

        # create new tensor with target number of columns
        x_remapped = torch.zeros(x.shape[:-1] + (target_number_columns,), dtype=x.dtype, device=x.device)
        if in_place and not self.printed_preprocessor_warning:
            LOGGER.warning(
                "Remapper (preprocessor) called with in_place=True. This preprocessor cannot be applied in_place as new columns are added to the tensors.",
            )
            self.printed_preprocessor_warning = True

        # copy variables that are not remapped
        # NOTE(review): presumably the kept variables occupy the leading
        # columns of the remapped layout - confirm against IndexCollection.
        x_remapped[..., : len(indices_keep)] = x[..., indices_keep]

        # Remap variables
        for idx_dst, remapper, idx_src in zip(indices_remapped, self.remappers, index):
            if idx_src is not None:
                for jj, ii in enumerate(idx_dst):
                    x_remapped[..., ii] = remapper[jj](x[..., idx_src])

        return x_remapped

    def inverse_transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
        """Convert and remap the output tensor.

        Parameters
        ----------
        x : torch.Tensor
            Output tensor in the remapped layout
        in_place : bool
            Whether to process the tensor in place.
            in_place is not possible for this postprocessor; a new tensor is
            always returned and a warning is logged once if True is passed.

        Raises
        ------
        ValueError
            If the last dimension of ``x`` matches neither the remapped
            training nor the remapped inference output layout.
        """
        # Choose correct index based on number of variables
        if x.shape[-1] == self.num_remapped_training_output_vars:
            index = self.index_training_output
            indices_remapped = self.index_training_backmapped_output
            indices_keep = self.indices_keep_training_output
            target_number_columns = self.num_training_output_vars

        elif x.shape[-1] == self.num_remapped_inference_output_vars:
            index = self.index_inference_output
            indices_remapped = self.index_inference_backmapped_output
            indices_keep = self.indices_keep_inference_output
            target_number_columns = self.num_inference_output_vars

        else:
            raise ValueError(
                f"Input tensor ({x.shape[-1]}) does not match the training "
                f"({self.num_remapped_training_output_vars}) or inference shape ({self.num_remapped_inference_output_vars})",
            )

        # create new tensor with target number of columns
        x_remapped = torch.zeros(x.shape[:-1] + (target_number_columns,), dtype=x.dtype, device=x.device)
        if in_place and not self.printed_postprocessor_warning:
            LOGGER.warning(
                "Remapper (postprocessor) called with in_place=True. This postprocessor cannot be applied in_place as new columns are added to the tensors.",
            )
            self.printed_postprocessor_warning = True

        # copy variables that are not remapped
        x_remapped[..., indices_keep] = x[..., : len(indices_keep)]

        # Backmap variables
        for idx_dst, backmapper, idx_src in zip(index, self.backmappers, indices_remapped):
            if idx_dst is not None:
                x_remapped[..., idx_dst] = backmapper(x[..., idx_src])

        return x_remapped

    def transform_loss_mask(self, mask: torch.Tensor) -> torch.Tensor:
        """Remap the loss mask.

        Each remapped component column receives the mask of the variable it
        was derived from.

        Parameters
        ----------
        mask : torch.Tensor
            Loss mask

        Returns
        -------
        torch.Tensor
            Loss mask in the remapped layout.
        """
        # use indices at model output level
        index = self.index_inference_backmapped_output
        indices_remapped = self.index_inference_output
        indices_keep = self.indices_keep_inference_output

        # create new loss mask with target number of columns
        # NOTE(review): one extra column per remapped variable matches the
        # 1 -> 2 fan-out of cos_sin; confirm before adding other methods.
        mask_remapped = torch.zeros(
            mask.shape[:-1] + (mask.shape[-1] + len(indices_remapped),), dtype=mask.dtype, device=mask.device
        )

        # copy loss mask for variables that are not remapped
        mask_remapped[..., : len(indices_keep)] = mask[..., indices_keep]

        # remap loss mask for rest of variables
        for idx_src, idx_dst in zip(indices_remapped, index):
            if idx_dst is not None:
                for ii in idx_dst:
                    mask_remapped[..., ii] = mask[..., idx_src]

        return mask_remapped
import logging
from abc import ABC
from typing import Optional

from anemoi.models.data_indices.collection import IndexCollection
from anemoi.models.preprocessing import BasePreprocessor
from anemoi.models.preprocessing.monomapper import Monomapper
from anemoi.models.preprocessing.multimapper import Multimapper

LOGGER = logging.getLogger(__name__)


class Remapper(BasePreprocessor, ABC):
    """Factory that dispatches to ``Monomapper`` or ``Multimapper``.

    Instantiating ``Remapper`` never yields a ``Remapper`` object:
    ``__new__`` inspects the methods named in ``config`` and returns an
    instance of the mapper class that supports them.
    """

    def __new__(
        cls,
        config=None,
        data_indices: Optional[IndexCollection] = None,
        statistics: Optional[dict] = None,
    ) -> BasePreprocessor:
        """Build the mapper matching the methods in ``config``.

        Parameters
        ----------
        config : DotDict
            configuration object of the processor
        data_indices : IndexCollection
            Data indices for input and output variables
        statistics : dict
            Data statistics dictionary

        Returns
        -------
        BasePreprocessor
            A ``Monomapper`` if every configured method is a monomapper
            method (this includes an empty method config), otherwise a
            ``Multimapper`` if every method is a multimapper method.

        Raises
        ------
        ValueError
            If none of the configured methods is supported by either mapper.
        NotImplementedError
            If the configured methods are not all supported by one mapper
            but at least one of them is supported by one of the mappers.
        """
        _, _, method_config = cls._process_config(config)
        requested = set(method_config)
        mono_methods = set(Monomapper.supported_methods)
        multi_methods = set(Multimapper.supported_methods)

        if requested <= mono_methods:
            return Monomapper(config, data_indices, statistics)
        if requested <= multi_methods:
            return Multimapper(config, data_indices, statistics)
        if not requested & (mono_methods | multi_methods):
            raise ValueError("No valid remapping method found.")
        raise NotImplementedError(
            f"Not implemented: method_config contains a mix of monomapper and multimapper methods: {list(method_config.keys())}"
        )
- ``` - """ - # Choose correct index based on number of variables - if x.shape[-1] == self.num_remapped_training_output_vars: - index = self.index_training_output - indices_remapped = self.index_training_backmapped_output - indices_keep = self.indices_keep_training_output - target_number_columns = self.num_training_output_vars - - elif x.shape[-1] == self.num_remapped_inference_output_vars: - index = self.index_inference_output - indices_remapped = self.index_inference_backmapped_output - indices_keep = self.indices_keep_inference_output - target_number_columns = self.num_inference_output_vars - - else: - raise ValueError( - f"Input tensor ({x.shape[-1]}) does not match the training " - f"({self.num_remapped_training_output_vars}) or inference shape ({self.num_remapped_inference_output_vars})", - ) - - # create new tensor with target number of columns - x_remapped = torch.zeros(x.shape[:-1] + (target_number_columns,), dtype=x.dtype, device=x.device) - if in_place and not self.printed_postprocessor_warning: - LOGGER.warning( - "Remapper (preprocessor) called with in_place=True. This preprocessor cannot be applied in_place as new columns are added to the tensors.", - ) - self.printed_postprocessor_warning = True - - # copy variables that are not remapped - x_remapped[..., indices_keep] = x[..., : len(indices_keep)] - - # Backmap variables - for idx_dst, backmapper, idx_src in zip(index, self.backmappers, indices_remapped): - if idx_dst is not None: - x_remapped[..., idx_dst] = backmapper(x[..., idx_src]) - - return x_remapped - - -class Remapper(BaseRemapperVariable): - """Remap and convert variables. - - cos_sin: - Remap the variable to cosine and sine. - Map output back to degrees. 
- - ``` - cos_sin: - "mwd" : ["cos_mwd", "sin_mwd"] - ``` - """ - - def __init__( - self, - config=None, - data_indices: Optional[IndexCollection] = None, - statistics: Optional[dict] = None, - ) -> None: - super().__init__(config, data_indices, statistics) - - self.printed_preprocessor_warning, self.printed_postprocessor_warning = False, False - - self._create_remapping_indices(statistics) - - self._validate_indices() diff --git a/tests/preprocessing/test_preprocessor_remapper.py b/tests/preprocessing/test_preprocessor_remapper.py index bbb3a168..6d27f906 100644 --- a/tests/preprocessing/test_preprocessor_remapper.py +++ b/tests/preprocessing/test_preprocessor_remapper.py @@ -43,6 +43,27 @@ def input_remapper(): return Remapper(config=config.data.remapper, data_indices=data_indices, statistics=statistics) +@pytest.fixture() +def input_remapper_1d(): + config = DictConfig( + { + "diagnostics": {"log": {"code": {"level": "DEBUG"}}}, + "data": { + "remapper": { + "log1p": "d", + "sqrt": "q", + }, + "forcing": ["z", "q"], + "diagnostic": ["other"], + }, + }, + ) + statistics = {} + name_to_index = {"x": 0, "y": 1, "z": 2, "q": 3, "d": 4, "other": 5} + data_indices = IndexCollection(config=config, name_to_index=name_to_index) + return Remapper(config=config.data.remapper, data_indices=data_indices, statistics=statistics) + + @pytest.fixture() def input_imputer(): config = DictConfig( @@ -74,19 +95,30 @@ def input_imputer(): def test_remap_not_inplace(input_remapper) -> None: x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]]) input_remapper(x, in_place=False) - assert torch.allclose(x, torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]])) + assert torch.allclose( + x, + torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]]), + ) def test_remap(input_remapper) -> None: x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]]) expected_output = 
torch.Tensor( - [[1.0, 2.0, 3.0, 4.0, 5.0, -0.8660254, 0.5], [6.0, 7.0, 8.0, 9.0, 10.0, -0.93358043, -0.35836795]] + [ + [1.0, 2.0, 3.0, 4.0, 5.0, -0.8660254, 0.5], + [6.0, 7.0, 8.0, 9.0, 10.0, -0.93358043, -0.35836795], + ] ) assert torch.allclose(input_remapper.transform(x), expected_output) def test_inverse_transform(input_remapper) -> None: - x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 5.0, -0.8660254, 0.5], [6.0, 7.0, 8.0, 9.0, 10.0, -0.93358043, -0.35836795]]) + x = torch.Tensor( + [ + [1.0, 2.0, 3.0, 4.0, 5.0, -0.8660254, 0.5], + [6.0, 7.0, 8.0, 9.0, 10.0, -0.93358043, -0.35836795], + ] + ) expected_output = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]]) assert torch.allclose(input_remapper.inverse_transform(x), expected_output) @@ -94,7 +126,8 @@ def test_inverse_transform(input_remapper) -> None: def test_remap_inverse_transform(input_remapper) -> None: x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]]) assert torch.allclose( - input_remapper.inverse_transform(input_remapper.transform(x, in_place=False), in_place=False), x + input_remapper.inverse_transform(input_remapper.transform(x, in_place=False), in_place=False), + x, ) @@ -106,3 +139,64 @@ def test_transform_loss_mask(input_imputer, input_remapper) -> None: loss_mask_training = input_imputer.loss_mask_training loss_mask_training = input_remapper.transform_loss_mask(loss_mask_training) assert torch.allclose(loss_mask_training, expected_output) + + +def test_monomap_transform(input_remapper_1d) -> None: + x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]]) + expected_output = torch.Tensor( + [ + [1.0, 2.0, 3.0, np.sqrt(4.0), np.log1p(150.0), 5.0], + [6.0, 7.0, 8.0, np.sqrt(9.0), np.log1p(201.0), 10.0], + ] + ) + assert torch.allclose(input_remapper_1d.transform(x, in_place=False), expected_output) + # inference mode (without prognostic variables) + assert torch.allclose( + 
input_remapper_1d.transform( + x[..., input_remapper_1d.data_indices.data.todict()["input"]["full"]], in_place=False + ), + expected_output[..., input_remapper_1d.data_indices.data.todict()["input"]["full"]], + ) + # this one actually changes the values in x so need to be last + assert torch.allclose(input_remapper_1d.transform(x), expected_output) + + +def test_monomap_inverse_transform(input_remapper_1d) -> None: + expected_output = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]]) + y = torch.Tensor( + [ + [1.0, 2.0, 3.0, np.sqrt(4.0), np.log1p(150.0), 5.0], + [6.0, 7.0, 8.0, np.sqrt(9.0), np.log1p(201.0), 10.0], + ] + ) + assert torch.allclose(input_remapper_1d.inverse_transform(y, in_place=False), expected_output) + # inference mode (without prognostic variables) + assert torch.allclose( + input_remapper_1d.inverse_transform( + y[..., input_remapper_1d.data_indices.data.todict()["output"]["full"]], in_place=False + ), + expected_output[..., input_remapper_1d.data_indices.data.todict()["output"]["full"]], + ) + + +def test_unsupported_remapper(): + config = DictConfig( + { + "diagnostics": {"log": {"code": {"level": "DEBUG"}}}, + "data": { + "remapper": {"log1p": "q", "cos_sin": "d"}, + "forcing": [], + "diagnostic": [], + }, + } + ) + statistics = {} + name_to_index = {"x": 0, "y": 1, "q": 2, "d": 3} + data_indices = IndexCollection(config=config, name_to_index=name_to_index) + + with pytest.raises(NotImplementedError): + Remapper( + config=config.data.remapper, + data_indices=data_indices, + statistics=statistics, + )