Skip to content

Commit

Permalink
work
Browse files (browse the repository at this point in the history)
  • Loading branch information
StoneT2000 committed Jan 29, 2024
1 parent a2aa047 commit eb8aece
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
4 changes: 2 additions & 2 deletions examples/baselines/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def clip_action(action: torch.Tensor):
print(f"eval_episodic_return={episodic_return}")
writer.add_scalar("charts/eval_success_rate", info["success"].float().mean().cpu().numpy(), global_step)
writer.add_scalar("charts/eval_episodic_return", episodic_return, global_step)
- writer.add_scalar("charts/eval_episodic_length", info["elapsed_steps"], global_step)
+ writer.add_scalar("charts/eval_episodic_length", info["elapsed_steps"].float().mean().cpu().numpy(), global_step)
# exit()
if args.save_model and iteration % args.eval_freq == 1:
model_path = f"runs/{run_name}/{args.exp_name}_{iteration}.cleanrl_model"
Expand Down Expand Up @@ -273,7 +273,7 @@ def clip_action(action: torch.Tensor):
print(f"global_step={global_step}, episodic_return={episodic_return}")
writer.add_scalar("charts/success_rate", info["success"].float().mean().cpu().numpy(), global_step)
writer.add_scalar("charts/episodic_return", episodic_return, global_step)
- writer.add_scalar("charts/episodic_length", info["elapsed_steps"], global_step)
+ writer.add_scalar("charts/episodic_length", info["elapsed_steps"].float().mean().cpu().numpy(), global_step)
# bootstrap value if not done
with torch.no_grad():
next_value = agent.get_value(next_obs).reshape(1, -1)
Expand Down
4 changes: 3 additions & 1 deletion mani_skill2/vector/wrappers/gymnasium.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ def step(
self.returns += rew
infos["episode"] = dict(r=self.returns)
terminations = torch.zeros(self.num_envs, device=self.env.device)
- if truncations:
+ if (isinstance(truncations, torch.Tensor) and truncations.any()) or (
+ not isinstance(truncations, torch.Tensor) and truncations
+ ):
infos["episode"]["r"] = self.returns.clone()
final_obs = obs
obs, _ = self.reset()
Expand Down

0 comments on commit eb8aece

Please sign in to comment.