diff --git a/src/anemoi/models/preprocessing/remapper.py b/src/anemoi/models/preprocessing/remapper.py index 7ffee0d..ec315c0 100644 --- a/src/anemoi/models/preprocessing/remapper.py +++ b/src/anemoi/models/preprocessing/remapper.py @@ -177,7 +177,16 @@ def _create_remapping_indices( raise ValueError[f"Unknown remapping method for {name}: {method}"] def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor: - """Remap and convert the input tensor.""" + """Remap and convert the input tensor. + + ``` + x : torch.Tensor + Input tensor + in_place : bool + Whether to process the tensor in place. + in_place is not possible for this preprocessor. + ``` + """ # Choose correct index based on number of variables if x.shape[-1] == self.num_training_input_vars: index = self.index_training_input @@ -199,6 +208,12 @@ def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor: # create new tensor with target number of columns x_remapped = torch.zeros(x.shape[:-1] + (target_number_columns,), dtype=x.dtype, device=x.device) + if in_place and not self.printed_preprocessor_warning: + LOGGER.warning( + "Remapper (preprocessor) called with in_place=True. This preprocessor cannot be applied in_place as new columns are added to the tensors.", + ) + self.printed_preprocessor_warning = True + # copy variables that are not remapped x_remapped[..., : len(indices_keep)] = x[..., indices_keep] @@ -211,7 +226,16 @@ def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor: return x_remapped def inverse_transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor: - """Convert and remap the output tensor.""" + """Convert and remap the output tensor. + + ``` + x : torch.Tensor + Input tensor + in_place : bool + Whether to process the tensor in place. + in_place is not possible for this postprocessor. + ``` + """ # Choose correct index based on number of variables if x.shape[-1] == self.num_remapped_training_output_vars: index = self.index_training_output @@ -233,6 +257,12 @@ def inverse_transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Ten # create new tensor with target number of columns x_remapped = torch.zeros(x.shape[:-1] + (target_number_columns,), dtype=x.dtype, device=x.device) + if in_place and not self.printed_postprocessor_warning: + LOGGER.warning( + "Remapper (preprocessor) called with in_place=True. This preprocessor cannot be applied in_place as new columns are added to the tensors.", + ) + self.printed_postprocessor_warning = True + # copy variables that are not remapped x_remapped[..., indices_keep] = x[..., : len(indices_keep)] @@ -265,6 +295,8 @@ def __init__( ) -> None: super().__init__(config, data_indices, statistics) + self.printed_preprocessor_warning, self.printed_postprocessor_warning = False, False + self._create_remapping_indices(statistics) self._validate_indices()