Skip to content

Commit

Permalink
fix: brax's visualization logic
Browse files Browse the repository at this point in the history
  • Loading branch information
BillHuang2001 committed Oct 27, 2024
1 parent abe6799 commit f67172a
Showing 1 changed file with 36 additions and 4 deletions.
40 changes: 36 additions & 4 deletions src/evox/problems/neuroevolution/reinforcement_learning/brax.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Any
from typing import Callable, Any, Optional
from brax import envs
from brax.io import html, image
import jax
Expand Down Expand Up @@ -129,10 +129,27 @@ def visualize(
key,
weights,
output_type: str = "HTML",
respect_done=False,
respect_done: bool = False,
num_episodes: Optional[int] = None,
*args,
**kwargs,
):
"""Visualize the brax environment with the given policy and weights.
Parameters
----------
key
The random key.
weights
The weights of the policy.
output_type
The output type, either "HTML" or "rgb_array".
respect_done
Whether to respect the done signal.
num_episodes
The number of episodes to visualize. Overrides the num_episodes given to the constructor;
if None, the constructor's num_episodes is used.
"""
assert output_type in [
"HTML",
"rgb_array",
Expand All @@ -143,8 +160,23 @@ def visualize(
jit_env_step = jit(env.step)
trajectory = [brax_state.pipeline_state]
episode_length = 1
for _ in range(self.cap_episode):
action = self.policy(weights, brax_state.obs)

if self.stateful_policy:
rollout_state = (self.initial_state, brax_state)
else:
rollout_state = (brax_state,)

for _ in range(self.num_episodes):
if self.stateful_policy:
state, brax_state = rollout_state
action, state = self.policy(state, weights, brax_state.obs)
rollout_state = (state, brax_state)
else:
(brax_state,) = rollout_state
action = self.policy(weights, brax_state.obs)
rollout_state = (brax_state,)

trajectory.append(brax_state.pipeline_state)
brax_state = jit_env_step(brax_state, action)
trajectory.append(brax_state.pipeline_state)
episode_length += 1 - brax_state.done
Expand Down

0 comments on commit f67172a

Please sign in to comment.