From 97729a3242dbc3b45c9d96dcea002bbc63121bc0 Mon Sep 17 00:00:00 2001 From: Xiangzhong Liu Date: Mon, 4 Dec 2023 23:59:32 +0100 Subject: [PATCH] fix bug use_best_ckpt_folder --- bark_ml/experiment/experiment_runner.py | 9 ++++++--- .../lib_tf_agents/agents/tfa_agent.py | 14 ++++++++++++-- .../lib_tf_agents/runners/tfa_runner.py | 2 +- 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/bark_ml/experiment/experiment_runner.py b/bark_ml/experiment/experiment_runner.py index d0a08896..cee50fdb 100644 --- a/bark_ml/experiment/experiment_runner.py +++ b/bark_ml/experiment/experiment_runner.py @@ -26,6 +26,7 @@ def __init__(self, use_best_ckpt_folder = False): self._logger = logging.getLogger() self._experiment_json = json_file + self._use_best_ckpt_folder = False if not use_best_ckpt_folder else use_best_ckpt_folder if params is not None: self._params = params else: @@ -36,11 +37,12 @@ def __init__(self, self._random_seed = random_seed np.random.seed(random_seed) tf.random.set_seed(random_seed) - self.SetCkptsAndSummaries(use_best_ckpt_folder = use_best_ckpt_folder) + self.SetCkptsAndSummaries() self._experiment = self.BuildExperiment(json_file, mode) self.Visitor(mode) def Visitor(self, mode): + self._experiment.runner._agent.LoadCheckpoint(self._use_best_ckpt_folder) if mode == "train": self._experiment._params.Save(self._runs_folder+"params.json") self.Train() @@ -84,10 +86,11 @@ def CompareHashes(self): if experiment_hash != old_experiment_hash: self._logger.warning("\033[31m Trained experiment hash does not match \033[0m") - def SetCkptsAndSummaries(self, use_best_ckpt_folder = False): + def SetCkptsAndSummaries(self): self._runs_folder = \ str(self._experiment_folder) + "/" + self._json_name + "/" + str(self._random_seed) + "/" - ckpt_folder = self._runs_folder + ("ckpts/" if not use_best_ckpt_folder else "ckpts/best_checkpoint/") + ckpt_folder = self._runs_folder + "ckpts/" + # ckpt_folder = self._runs_folder + ("ckpts/" if not use_best_ckpt_folder else "ckpts/best_checkpoint/") summ_folder = self._runs_folder + "summ/" self._logger.info(f"Run folder of the agent {self._runs_folder}.") self._hash_file_path = self._runs_folder + "hash.txt" diff --git a/bark_ml/library_wrappers/lib_tf_agents/agents/tfa_agent.py b/bark_ml/library_wrappers/lib_tf_agents/agents/tfa_agent.py index 348a73d1..a0b716c5 100644 --- a/bark_ml/library_wrappers/lib_tf_agents/agents/tfa_agent.py +++ b/bark_ml/library_wrappers/lib_tf_agents/agents/tfa_agent.py @@ -42,10 +42,12 @@ def __init__(self, self._ckpt = tf.train.Checkpoint(step=tf.Variable(0, dtype=tf.int64), agent=self._agent) ckpt_path= self._params["ML"]["BehaviorTFAAgents"]["CheckpointPath", "", ""] + self._best_ckpt_manager = self.GetCheckpointer(ckpt_path+"best_checkpoint/",1) self._ckpt_manager = self.GetCheckpointer(ckpt_path,self._params["ML"]["BehaviorTFAAgents"][ "NumCheckpointsToKeep", "", 3]) - self._best_ckpt_manager = self.GetCheckpointer(ckpt_path+"best_checkpoint/",1) self._logger = logging.getLogger() + # restored_step = self._agent.train_step_counter.numpy() + # print(f"Loaded the agent checkpoint at step {restored_step}") # NOTE: by default we do not want the action to be set externally # as this enables the agents to be plug and played in BARK. self._set_action_externally = False @@ -80,12 +82,20 @@ def Save(self): int(self._agent._train_step_counter.numpy()))) - def SaveCheckpoint(self): + def SaveBestCheckpoint(self): self._best_ckpt_manager.save( global_step=self._agent._train_step_counter) self._logger.info( f"Saved best checkpoint for step " f"{int(self._agent._train_step_counter.numpy())} at {self._best_ckpt_manager._manager._directory}.") + + def LoadCheckpoint(self, use_best_ckpt_folder=False): + if use_best_ckpt_folder: + self._best_ckpt_manager.initialize_or_restore() + self._logger.info("Restored agent from best checkpoint!") + else: + self._ckpt_manager.initialize_or_restore() + self._logger.info("Restored agent from latest checkpoint!") def Load(self): try: diff --git a/bark_ml/library_wrappers/lib_tf_agents/runners/tfa_runner.py b/bark_ml/library_wrappers/lib_tf_agents/runners/tfa_runner.py index 341ae7a9..0a86a5d3 100644 --- a/bark_ml/library_wrappers/lib_tf_agents/runners/tfa_runner.py +++ b/bark_ml/library_wrappers/lib_tf_agents/runners/tfa_runner.py @@ -176,7 +176,7 @@ def Run( best_ckpt_folder=self._agent._best_ckpt_manager._manager._directory if success_rate > self._max_success_rate or \ (success_rate == self._max_success_rate and mean_reward > self._max_reward): - self._agent.SaveCheckpoint() + self._agent.SaveBestCheckpoint() with open(best_ckpt_folder + 'info.txt', 'w') as f: f.write(f"Success-rate {success_rate:.3f}, collision-rate: {col_rate:.5f}" f", reward {mean_reward:.3f}, steps: {mean_steps:.3f}.")