Skip to content

Commit

Permalink
added optimizer state variables to model history output
Browse files Browse the repository at this point in the history
  • Loading branch information
grantbuster committed Nov 12, 2024
1 parent e3c4b45 commit 657310e
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 7 deletions.
24 changes: 23 additions & 1 deletion sup3r/models/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -1037,7 +1037,7 @@ def get_optimizer_config(optimizer):
Parameters
----------
optimizer : tf.keras.optimizers.Optimizer
TF-Keras optimizer object
TF-Keras optimizer object (e.g., Adam)
Returns
-------
Expand All @@ -1053,6 +1053,28 @@ def get_optimizer_config(optimizer):
conf[k] = int(v)
return conf

@classmethod
def get_optimizer_state(cls, optimizer):
"""Get a set of state variables for the optimizer
Parameters
----------
optimizer : tf.keras.optimizers.Optimizer
TF-Keras optimizer object (e.g., Adam)
Returns
-------
state : dict
Optimizer state variables
"""
lr = cls.get_optimizer_config(optimizer)['learning_rate']
state = {'learning_rate': lr}
for idv, var in enumerate(optimizer.variables):
name = var.name
var = var.numpy().flatten()[0] # collapse single value ndarrays
state[name] = float(var)
return state

@staticmethod
def update_loss_details(loss_details, new_data, batch_len, prefix=None):
"""Update a dictionary of loss_details with loss information from a new
Expand Down
13 changes: 7 additions & 6 deletions sup3r/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1011,14 +1011,15 @@ def train(
'weight_gen_advers': weight_gen_advers,
'disc_loss_bound_0': disc_loss_bounds[0],
'disc_loss_bound_1': disc_loss_bounds[1],
'learning_rate_gen': self.get_optimizer_config(self.optimizer)[
'learning_rate'
],
'learning_rate_disc': self.get_optimizer_config(
self.optimizer_disc
)['learning_rate'],
}

opt_g = self.get_optimizer_state(self.optimizer)
opt_d = self.get_optimizer_state(self.optimizer_disc)
opt_g = {f'Gen/{key}': val for key, val in opt_g.items()}
opt_d = {f'Disc/{key}': val for key, val in opt_g.items()}
extras.update(opt_g)
extras.update(opt_d)

weight_gen_advers = self.update_adversarial_weights(
loss_details,
adaptive_update_fraction,
Expand Down

0 comments on commit 657310e

Please sign in to comment.