From 657310e3acd7ec4a47c0cda6165c3a82992ea6b0 Mon Sep 17 00:00:00 2001
From: grantbuster
Date: Tue, 12 Nov 2024 09:57:14 -0700
Subject: [PATCH] added optimizer state variables to model history output

---
 sup3r/models/abstract.py | 24 +++++++++++++++++++++++-
 sup3r/models/base.py     | 13 +++++++------
 2 files changed, 30 insertions(+), 7 deletions(-)

diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py
index 4b48fa89f..31a2b49fc 100644
--- a/sup3r/models/abstract.py
+++ b/sup3r/models/abstract.py
@@ -1037,7 +1037,7 @@ def get_optimizer_config(optimizer):
         Parameters
         ----------
         optimizer : tf.keras.optimizers.Optimizer
-            TF-Keras optimizer object
+            TF-Keras optimizer object (e.g., Adam)

         Returns
         -------
@@ -1053,6 +1053,28 @@ def get_optimizer_config(optimizer):
                 conf[k] = int(v)
         return conf

+    @classmethod
+    def get_optimizer_state(cls, optimizer):
+        """Get a set of state variables for the optimizer
+
+        Parameters
+        ----------
+        optimizer : tf.keras.optimizers.Optimizer
+            TF-Keras optimizer object (e.g., Adam)
+
+        Returns
+        -------
+        state : dict
+            Optimizer state variables
+        """
+        lr = cls.get_optimizer_config(optimizer)['learning_rate']
+        state = {'learning_rate': lr}
+        for idv, var in enumerate(optimizer.variables):
+            name = var.name
+            var = var.numpy().flatten()[0]  # collapse single value ndarrays
+            state[name] = float(var)
+        return state
+
     @staticmethod
     def update_loss_details(loss_details, new_data, batch_len, prefix=None):
         """Update a dictionary of loss_details with loss information from a new
diff --git a/sup3r/models/base.py b/sup3r/models/base.py
index 19b4ce8be..7a01c0c53 100644
--- a/sup3r/models/base.py
+++ b/sup3r/models/base.py
@@ -1011,14 +1011,15 @@ def train(
             'weight_gen_advers': weight_gen_advers,
             'disc_loss_bound_0': disc_loss_bounds[0],
             'disc_loss_bound_1': disc_loss_bounds[1],
-            'learning_rate_gen': self.get_optimizer_config(self.optimizer)[
-                'learning_rate'
-            ],
-            'learning_rate_disc': self.get_optimizer_config(
-                self.optimizer_disc
-            )['learning_rate'],
         }

+        opt_g = self.get_optimizer_state(self.optimizer)
+        opt_d = self.get_optimizer_state(self.optimizer_disc)
+        opt_g = {f'Gen/{key}': val for key, val in opt_g.items()}
+        opt_d = {f'Disc/{key}': val for key, val in opt_d.items()}
+        extras.update(opt_g)
+        extras.update(opt_d)
+
         weight_gen_advers = self.update_adversarial_weights(
             loss_details,
             adaptive_update_fraction,
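
Not part of the patch: a minimal, standalone sketch of what the new
get_optimizer_state() returns and how the Gen/ and Disc/ prefixes are merged
into the per-epoch history extras. The dummy weight, gradient, learning rates,
and the callable check on optimizer.variables are illustrative assumptions,
not code from sup3r; the exact state variable names depend on the installed
TF/Keras version.

import tensorflow as tf


def get_optimizer_state(optimizer):
    """Rough mirror of the patched classmethod: collect scalar optimizer state."""
    lr = optimizer.get_config()['learning_rate']
    state = {'learning_rate': float(lr)}
    # Some TF/Keras versions expose `variables` as a method, others as a list.
    variables = (optimizer.variables()
                 if callable(optimizer.variables) else optimizer.variables)
    for var in variables:
        # Collapse each state variable to a single float (its first element),
        # matching the patch's flatten()[0] behavior.
        state[var.name] = float(var.numpy().flatten()[0])
    return state


opt_g = tf.keras.optimizers.Adam(learning_rate=1e-4)  # generator optimizer stand-in
opt_d = tf.keras.optimizers.Adam(learning_rate=4e-4)  # discriminator optimizer stand-in

# One dummy update step so Adam builds its moment state variables.
w = tf.Variable([1.0, 2.0, 3.0])
grad = tf.constant([0.1, 0.1, 0.1])
opt_g.apply_gradients([(grad, w)])
opt_d.apply_gradients([(grad, w)])

# Prefix and merge the two state dicts, as the patched train() loop does
# before writing them into the model history extras.
extras = {}
extras.update({f'Gen/{k}': v for k, v in get_optimizer_state(opt_g).items()})
extras.update({f'Disc/{k}': v for k, v in get_optimizer_state(opt_d).items()})
print(extras)
# Keys look like 'Gen/learning_rate', 'Gen/iteration', 'Disc/learning_rate', ...
# (exact names vary by TF/Keras version).

This keeps the history record flat (one float per column), so the optimizer
state can be written alongside the existing loss columns without changing the
tabular history format.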