From f67172afb22d6139d508f8362a822f7aa98efa4b Mon Sep 17 00:00:00 2001 From: Bill Huang Date: Sun, 27 Oct 2024 21:34:30 +0800 Subject: [PATCH] fix: brax's visualization logic --- .../reinforcement_learning/brax.py | 40 +++++++++++++++++-- 1 file changed, 36 insertions(+), 4 deletions(-) diff --git a/src/evox/problems/neuroevolution/reinforcement_learning/brax.py b/src/evox/problems/neuroevolution/reinforcement_learning/brax.py index 71b84693..ad7e6b08 100644 --- a/src/evox/problems/neuroevolution/reinforcement_learning/brax.py +++ b/src/evox/problems/neuroevolution/reinforcement_learning/brax.py @@ -1,4 +1,4 @@ -from typing import Callable, Any +from typing import Callable, Any, Optional from brax import envs from brax.io import html, image import jax @@ -129,10 +129,27 @@ def visualize( key, weights, output_type: str = "HTML", - respect_done=False, + respect_done: bool = False, + num_episodes: Optional[int] = None, *args, **kwargs, ): + """Visualize the brax environment with the given policy and weights. + + Parameters + ---------- + key + The random key. + weights + The weights of the policy. + output_type + The output type, either "HTML" or "rgb_array". + respect_done + Whether to respect the done signal. + num_episodes + The number of episodes to visualize, used to override the num_episodes in the constructor. + If None, use the num_episodes in the constructor. + """ assert output_type in [ "HTML", "rgb_array", @@ -143,8 +160,23 @@ def visualize( jit_env_step = jit(env.step) trajectory = [brax_state.pipeline_state] episode_length = 1 - for _ in range(self.cap_episode): - action = self.policy(weights, brax_state.obs) + + if self.stateful_policy: + rollout_state = (self.initial_state, brax_state) + else: + rollout_state = (brax_state,) + + for _ in range(self.num_episodes): + if self.stateful_policy: + state, brax_state = rollout_state + action, state = self.policy(state, weights, brax_state.obs) + rollout_state = (state, brax_state) + else: + (brax_state,) = rollout_state + action = self.policy(weights, brax_state.obs) + rollout_state = (brax_state,) + + trajectory.append(brax_state.pipeline_state) brax_state = jit_env_step(brax_state, action) trajectory.append(brax_state.pipeline_state) episode_length += 1 - brax_state.done