Skip to content

Commit

Permalink
optimizer metrics to tensorboard slows things down considerably, only…
Browse files Browse the repository at this point in the history
… output at end of epoch to history
  • Loading branch information
grantbuster committed Nov 12, 2024
1 parent 5689f64 commit 06c5d99
Showing 1 changed file with 6 additions and 13 deletions.
19 changes: 6 additions & 13 deletions sup3r/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,15 +742,9 @@ def train_epoch(
b_loss_details['gen_trained_frac'] = float(trained_gen)
b_loss_details['disc_trained_frac'] = float(trained_disc)

opt_g = self.get_optimizer_state(self.optimizer)
opt_d = self.get_optimizer_state(self.optimizer_disc)
opt_g = {f'Gen/{key}': val for key, val in opt_g.items()}
opt_d = {f'Disc/{key}': val for key, val in opt_g.items()}
b_loss_details.update(opt_g)
b_loss_details.update(opt_d)

self.dict_to_tensorboard(b_loss_details)
self.dict_to_tensorboard(self.timer.log)

loss_details = self.update_loss_details(
loss_details,
b_loss_details,
Expand Down Expand Up @@ -1003,10 +997,9 @@ def train(
loss_details['train_loss_gen'], loss_details['train_loss_disc']
)

if all(
loss in loss_details
for loss in ('val_loss_gen', 'val_loss_disc')
):
check1 = 'val_loss_gen' in loss_details
check2 = 'val_loss_disc' in loss_details
if check1 and check2:
msg += 'gen/disc val loss: {:.2e}/{:.2e} '.format(
loss_details['val_loss_gen'], loss_details['val_loss_disc']
)
Expand All @@ -1023,8 +1016,8 @@ def train(

opt_g = self.get_optimizer_state(self.optimizer)
opt_d = self.get_optimizer_state(self.optimizer_disc)
opt_g = {f'Gen/{key}': val for key, val in opt_g.items()}
opt_d = {f'Disc/{key}': val for key, val in opt_g.items()}
opt_g = {f'OptmGen/{key}': val for key, val in opt_g.items()}
opt_d = {f'OptmDisc/{key}': val for key, val in opt_d.items()}
extras.update(opt_g)
extras.update(opt_d)

Expand Down

0 comments on commit 06c5d99

Please sign in to comment.