Skip to content

Commit

Permalink
fix bug use_best_ckpt_folder
Browse files Browse the repository at this point in the history
  • Loading branch information
Xiangzhong Liu committed Dec 4, 2023
1 parent 5a7c17b commit 97729a3
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 6 deletions.
9 changes: 6 additions & 3 deletions bark_ml/experiment/experiment_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__(self,
use_best_ckpt_folder = False):
self._logger = logging.getLogger()
self._experiment_json = json_file
self._use_best_ckpt_folder = False if not use_best_ckpt_folder else use_best_ckpt_folder
if params is not None:
self._params = params
else:
Expand All @@ -36,11 +37,12 @@ def __init__(self,
self._random_seed = random_seed
np.random.seed(random_seed)
tf.random.set_seed(random_seed)
self.SetCkptsAndSummaries(use_best_ckpt_folder = use_best_ckpt_folder)
self.SetCkptsAndSummaries()
self._experiment = self.BuildExperiment(json_file, mode)
self.Visitor(mode)

def Visitor(self, mode):
self._experiment.runner._agent.LoadCheckpoint(self._use_best_ckpt_folder)
if mode == "train":
self._experiment._params.Save(self._runs_folder+"params.json")
self.Train()
Expand Down Expand Up @@ -84,10 +86,11 @@ def CompareHashes(self):
if experiment_hash != old_experiment_hash:
self._logger.warning("\033[31m Trained experiment hash does not match \033[0m")

def SetCkptsAndSummaries(self, use_best_ckpt_folder = False):
def SetCkptsAndSummaries(self):
self._runs_folder = \
str(self._experiment_folder) + "/" + self._json_name + "/" + str(self._random_seed) + "/"
ckpt_folder = self._runs_folder + ("ckpts/" if not use_best_ckpt_folder else "ckpts/best_checkpoint/")
ckpt_folder = self._runs_folder + "ckpts/"
# ckpt_folder = self._runs_folder + ("ckpts/" if not use_best_ckpt_folder else "ckpts/best_checkpoint/")
summ_folder = self._runs_folder + "summ/"
self._logger.info(f"Run folder of the agent {self._runs_folder}.")
self._hash_file_path = self._runs_folder + "hash.txt"
Expand Down
14 changes: 12 additions & 2 deletions bark_ml/library_wrappers/lib_tf_agents/agents/tfa_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,12 @@ def __init__(self,
self._ckpt = tf.train.Checkpoint(step=tf.Variable(0, dtype=tf.int64),
agent=self._agent)
ckpt_path= self._params["ML"]["BehaviorTFAAgents"]["CheckpointPath", "", ""]
self._best_ckpt_manager = self.GetCheckpointer(ckpt_path+"best_checkpoint/",1)
self._ckpt_manager = self.GetCheckpointer(ckpt_path,self._params["ML"]["BehaviorTFAAgents"][
"NumCheckpointsToKeep", "", 3])
self._best_ckpt_manager = self.GetCheckpointer(ckpt_path+"best_checkpoint/",1)
self._logger = logging.getLogger()
# restored_step = self._agent.train_step_counter.numpy()
# print(f"Loaded the agent checkpoint at step {restored_step}")
# NOTE: by default we do not want the action to be set externally
# as this enables the agents to be plug and played in BARK.
self._set_action_externally = False
Expand Down Expand Up @@ -80,12 +82,20 @@ def Save(self):
int(self._agent._train_step_counter.numpy())))


def SaveCheckpoint(self):
def SaveBestCheckpoint(self):
self._best_ckpt_manager.save(
global_step=self._agent._train_step_counter)
self._logger.info(
f"Saved best checkpoint for step "
f"{int(self._agent._train_step_counter.numpy())} at {self._best_ckpt_manager._manager._directory}.")

def LoadCheckpoint(self, use_best_ckpt_folder=False):
if use_best_ckpt_folder:
self._best_ckpt_manager.initialize_or_restore()
self._logger.info("Restored agent from best checkpoint!")
else:
self._ckpt_manager.initialize_or_restore()
self._logger.info("Restored agent from latest checkpoint!")

def Load(self):
try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def Run(
best_ckpt_folder=self._agent._best_ckpt_manager._manager._directory
if success_rate > self._max_success_rate or \
(success_rate == self._max_success_rate and mean_reward > self._max_reward):
self._agent.SaveCheckpoint()
self._agent.SaveBestCheckpoint()
with open(best_ckpt_folder + 'info.txt', 'w') as f:
f.write(f"Success-rate {success_rate:.3f}, collision-rate: {col_rate:.5f}"
f", reward {mean_reward:.3f}, steps: {mean_steps:.3f}.")
Expand Down

0 comments on commit 97729a3

Please sign in to comment.