diff --git a/src/anemoi/models/preprocessing/remapper.py b/src/anemoi/models/preprocessing/remapper.py index 85f32e8a..243e6b46 100644 --- a/src/anemoi/models/preprocessing/remapper.py +++ b/src/anemoi/models/preprocessing/remapper.py @@ -209,7 +209,10 @@ def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor: x_remapped = x[..., index] # create new tensor with target number of columns - x = torch.nn.functional.pad(x, (0, target_number_columns - x.shape[-1])) + # x = torch.nn.functional.pad(x, (0, target_number_columns - x.shape[-1])) + x = torch.cat( + (x, torch.zeros(x.shape[:-1] + (target_number_columns - x.shape[-1],), dtype=x.dtype, device=x.device)), -1 + ) if in_place and not self.printed_preprocessor_warning: LOGGER.warning( "Remapper (preprocessor) called with in_place=True. This preprocessor cannot be applied in_place as new columns are added to the tensors.",