From d468268be8fc159648f2253bebd27427b1ca3330 Mon Sep 17 00:00:00 2001
From: liellnima
Date: Thu, 5 Sep 2024 22:36:48 +0200
Subject: [PATCH] Handle channels_last option and squeeze dims for losses to fix shape issues

---
 emulator/src/core/models/basemodel.py | 36 ++++++++++++++++++++++++++----------
 1 file changed, 26 insertions(+), 10 deletions(-)

diff --git a/emulator/src/core/models/basemodel.py b/emulator/src/core/models/basemodel.py
index 11c4ac7..6f4a09a 100644
--- a/emulator/src/core/models/basemodel.py
+++ b/emulator/src/core/models/basemodel.py
@@ -109,7 +109,13 @@ def on_train_epoch_start(self) -> None:
         self._start_epoch_time = time.time()
 
     def predict(self, X, idx, *args, **kwargs):
-        # x (batch_size, time, lat, lon, num_features)
+        """
+        X (batch_size, time, lat, lon, num_features)
+
+        Notes:
+            If channels are not last --> permuted in split_vector_by_variable.
+            Always returns a dictionary with variables as keys. Value of each key: [batch, time, lat, lon].
+        """
         # TODO if we want to apply any input normalization or other stuff we should do it here
 
         # if idx is None or if we do not have a decoder
@@ -123,7 +129,7 @@ def predict(self, X, idx, *args, **kwargs):
         # else we will just return raw predictions
 
         # splitting predictions to get dict accessible via target var id
-        preds_dict = self.output_postprocesser.split_vector_by_variable(preds)
+        preds_dict = self.output_postprocesser.split_vector_by_variable(preds, self.channels_last)
 
         return preds_dict
 
@@ -135,9 +141,10 @@ def training_step(self, batch: Any, batch_idx: int):
             idx = None
         preds = self.predict(X, idx)
 
+        # dict with keys being the output var ids
         Y = self.output_postprocesser.split_vector_by_variable(
-            Y
+            Y, self.channels_last
         )  # split per var id #TODO: might need to remove that for other datamodule
 
         train_log = dict()  # everything we want to log to wandb should go in here
@@ -145,15 +152,20 @@ def training_step(self, batch: Any, batch_idx: int):
         loss = 0
 
         # Loop over output variable to compute loss seperateley!!!
-        for out_var in self._out_var_ids:
-            loss_per_var = self.criterion(preds[out_var], Y[out_var])
+        for out_var in self._out_var_ids:
+            # squeeze the last dimension if it has size 1 (no-op otherwise)
+            preds_var = torch.squeeze(preds[out_var], -1)
+            Y_var = torch.squeeze(Y[out_var], -1)
+            loss_per_var = self.criterion(preds_var, Y_var)
+
             if torch.isnan(loss_per_var).sum() > 0:
-                exit(0)
-            loss += loss_per_var
+                raise ValueError("Loss contains NaNs. Analyse the problem in the loss functions.")
+
+            loss += loss_per_var
             train_log[f"train/{out_var}/loss"] = loss_per_var
             # any additional losses can be computed, logged and added to the loss here
 
-        # Average Loss over vars
+        # Average loss over vars
         loss = loss / len(self._out_var_ids)
 
         n_zero_gradients = (
@@ -185,7 +197,7 @@ def on_train_epoch_end(self):
         # if(self.track_emissions):
         #     self.tot_co2_emission += self.tracker.stop()
         #     self.log("co2_emission", self.tot_co2_emission)
-        print("HERE")
+        #print("HERE")
         self.log_dict({"epoch": self.current_epoch, "time/train": train_time})
 
     def _evaluation_step(self, batch: Any, batch_idx: int):
@@ -206,8 +218,12 @@ def _evaluation_get_preds(
         self, outputs: List[Any]
     ) -> (Dict[str, np.ndarray], Dict[str, np.ndarray]):
         for batch in outputs:
+            print("MUST check shape for split_vector_by_variable:")
+            print("Shape should have channels last?", self.channels_last)
+            print("Actual shape:", batch["targets"].shape)
+            exit(0)
             batch["targets"] = self.output_postprocesser.split_vector_by_variable(
-                batch["targets"]
+                batch["targets"], self.channels_last
             )  # TODO: we might want to remove that for the real data module
 
         Y = {