Skip to content

Commit

Permalink
comment and warning about using in_place=True in remapper as this is …
Browse files Browse the repository at this point in the history
…not possible
  • Loading branch information
sahahner committed Aug 14, 2024
1 parent 33dccb9 commit 9e66b5c
Showing 1 changed file with 34 additions and 2 deletions.
36 changes: 34 additions & 2 deletions src/anemoi/models/preprocessing/remapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,16 @@ def _create_remapping_indices(
raise ValueError[f"Unknown remapping method for {name}: {method}"]

def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
"""Remap and convert the input tensor."""
"""Remap and convert the input tensor.
```
x : torch.Tensor
Input tensor
in_place : bool
Whether to process the tensor in place.
in_place is not possible for this preprocessor.
```
"""
# Choose correct index based on number of variables
if x.shape[-1] == self.num_training_input_vars:
index = self.index_training_input
Expand All @@ -199,6 +208,12 @@ def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:

# create new tensor with target number of columns
x_remapped = torch.zeros(x.shape[:-1] + (target_number_columns,), dtype=x.dtype, device=x.device)
if in_place and not self.printed_preprocessor_warning:
LOGGER.warning(
"Remapper (preprocessor) called with in_place=True. This preprocessor cannot be applied in_place as new columns are added to the tensors.",
)
self.printed_preprocessor_warning = True

# copy variables that are not remapped
x_remapped[..., : len(indices_keep)] = x[..., indices_keep]

Expand All @@ -211,7 +226,16 @@ def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
return x_remapped

def inverse_transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
"""Convert and remap the output tensor."""
"""Convert and remap the output tensor.
```
x : torch.Tensor
Input tensor
in_place : bool
Whether to process the tensor in place.
in_place is not possible for this postprocessor.
```
"""
# Choose correct index based on number of variables
if x.shape[-1] == self.num_remapped_training_output_vars:
index = self.index_training_output
Expand All @@ -233,6 +257,12 @@ def inverse_transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Ten

# create new tensor with target number of columns
x_remapped = torch.zeros(x.shape[:-1] + (target_number_columns,), dtype=x.dtype, device=x.device)
if in_place and not self.printed_postprocessor_warning:
LOGGER.warning(
"Remapper (preprocessor) called with in_place=True. This preprocessor cannot be applied in_place as new columns are added to the tensors.",
)
self.printed_postprocessor_warning = True

# copy variables that are not remapped
x_remapped[..., indices_keep] = x[..., : len(indices_keep)]

Expand Down Expand Up @@ -265,6 +295,8 @@ def __init__(
) -> None:
super().__init__(config, data_indices, statistics)

self.printed_preprocessor_warning, self.printed_postprocessor_warning = False, False

self._create_remapping_indices(statistics)

self._validate_indices()

0 comments on commit 9e66b5c

Please sign in to comment.