Commit

handle last channel option and squeeze dims for losses to fix shape issues
liellnima committed Sep 5, 2024
1 parent 7477cf0 commit d468268
Showing 1 changed file with 26 additions and 10 deletions.
36 changes: 26 additions & 10 deletions emulator/src/core/models/basemodel.py
@@ -109,7 +109,13 @@ def on_train_epoch_start(self) -> None:
self._start_epoch_time = time.time()

def predict(self, X, idx, *args, **kwargs):
-# x (batch_size, time, lat, lon, num_features)
+"""
+X (batch_size, time, lat, lon, num_features)
+Notes:
+    if channels are not last --> permuted in split_vector_by_variable
+    always returns a dictionary with variables as keys; value of each key: [batch, time, lat, lon]
+"""
# TODO: if we want to apply any input normalization or other preprocessing, we should do it here
# if idx is None or if we do not have a decoder

@@ -123,7 +129,7 @@ def predict(self, X, idx, *args, **kwargs):
# else we will just return raw predictions

# splitting predictions to get dict accessible via target var id
-preds_dict = self.output_postprocesser.split_vector_by_variable(preds)
+preds_dict = self.output_postprocesser.split_vector_by_variable(preds, self.channels_last)

return preds_dict
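
For readers following the diff: the docstring above implies that split_vector_by_variable both permutes a channels-first tensor and slices it into per-variable tensors. Below is a minimal sketch of such a splitter; the class name, constructor arguments, and the (batch, channels, time, lat, lon) channels-first layout are assumptions for illustration, not taken from the repository.

import torch

class OutputPostprocessorSketch:
    """Hypothetical illustration of split_vector_by_variable as described above."""

    def __init__(self, var_names, channels_per_var):
        # e.g. var_names = ["pr", "tas"], channels_per_var = [1, 1]
        self.var_names = var_names
        self.channels_per_var = channels_per_var

    def split_vector_by_variable(self, vector, channels_last=True):
        # vector: (batch, time, lat, lon, channels) when channels_last;
        # assumed (batch, channels, time, lat, lon) otherwise -> move channels to the end
        if not channels_last:
            vector = vector.permute(0, 2, 3, 4, 1)
        out, start = {}, 0
        for name, n in zip(self.var_names, self.channels_per_var):
            out[name] = vector[..., start : start + n]
            start += n
        return out

Note that single-channel variables keep a trailing axis of size 1 in this sketch, which is exactly why the training_step change below squeezes the last dimension before computing the loss.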

@@ -135,25 +141,31 @@ def training_step(self, batch: Any, batch_idx: int):
idx = None

preds = self.predict(X, idx)

# dict with keys being the output var ids
Y = self.output_postprocesser.split_vector_by_variable(
-Y
+Y, self.channels_last
) # split per var id  # TODO: might need to remove this for other datamodules

train_log = dict() # everything we want to log to wandb should go in here

loss = 0

# Loop over output variables to compute the loss separately
-for out_var in self._out_var_ids:
-    loss_per_var = self.criterion(preds[out_var], Y[out_var])
+for out_var in self._out_var_ids:
+    # squeeze last dimension if necessary (nothing happens if it's not there)
+    preds_var = torch.squeeze(preds[out_var], -1)
+    Y_var = torch.squeeze(Y[out_var], -1)
+    loss_per_var = self.criterion(preds_var, Y_var)

if torch.isnan(loss_per_var).sum() > 0:
-    exit(0)
-loss += loss_per_var
+    raise ValueError("Loss contains NaNs. Analyse problem in loss functions.")
+
+loss += loss_per_var
train_log[f"train/{out_var}/loss"] = loss_per_var
# any additional losses can be computed, logged and added to the loss here

-# Average Loss over vars
+# Average loss over vars
loss = loss / len(self._out_var_ids)
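
The squeeze above guards against a classic PyTorch broadcasting pitfall: when predictions keep a trailing singleton channel axis but targets do not, nn.MSELoss broadcasts the two shapes instead of comparing element-wise. A toy example with made-up shapes:

import torch
import torch.nn as nn

criterion = nn.MSELoss()

preds = torch.randn(4, 1)   # trailing singleton axis, as after a per-variable split
target = torch.randn(4)     # target without the channel axis

# (4, 1) vs (4,) broadcasts to (4, 4): every prediction is compared with
# every target, and PyTorch only emits a UserWarning.
wrong = criterion(preds, target)

# Squeezing the trailing axis restores the intended element-wise loss.
right = criterion(preds.squeeze(-1), target)

print(wrong.item(), right.item())  # generally different values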

n_zero_gradients = (
@@ -185,7 +197,7 @@ def on_train_epoch_end(self):
# if(self.track_emissions):
# self.tot_co2_emission += self.tracker.stop()
# self.log("co2_emission", self.tot_co2_emission)
print("HERE")
#print("HERE")
self.log_dict({"epoch": self.current_epoch, "time/train": train_time})

def _evaluation_step(self, batch: Any, batch_idx: int):
@@ -206,8 +218,12 @@ def _evaluation_get_preds(
self, outputs: List[Any]
) -> (Dict[str, np.ndarray], Dict[str, np.ndarray]):
for batch in outputs:
print("MUST check shape for split_vector_by_variable:")
print("Shape should have channels last? ", self.channels_last)
print("Actual shape:", batch["targets"].shape)
exit(0)
batch["targets"] = self.output_postprocesser.split_vector_by_variable(
batch["targets"]
batch["targets"], self.channels_last
) # TODO: we might want to remove that for the real data module
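
The print/exit(0) block above is a temporary debugging aid that aborts evaluation on the first batch. A gentler variant, sketched here with a hypothetical helper name and assuming one channel per output variable, would validate the layout and raise instead:

import torch

def check_targets_shape(targets: torch.Tensor, out_var_ids, channels_last: bool) -> None:
    # Hypothetical helper: verify the layout split_vector_by_variable expects.
    # Assumes one channel per output variable and, for channels-first input,
    # the channel axis sitting right after the batch axis.
    n_channels = len(out_var_ids)
    channel_axis = -1 if channels_last else 1
    if targets.shape[channel_axis] != n_channels:
        raise ValueError(
            f"Unexpected target shape {tuple(targets.shape)}: expected "
            f"{n_channels} channels on axis {channel_axis} "
            f"(channels_last={channels_last})"
        )

# usage sketch: check_targets_shape(batch["targets"], self._out_var_ids, self.channels_last)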

Y = {
