Skip to content

Commit

Permalink
fix brax's wrong impl on episode return evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
ZaberKo committed May 14, 2024
1 parent 95142ba commit be13968
Showing 1 changed file with 20 additions and 10 deletions.
30 changes: 20 additions & 10 deletions src/evox/problems/neuroevolution/reinforcement_learning/brax.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
from evox import Problem, State, jit_method


def vmap_rng_split(key: jax.Array, num: int = 2) -> jax.Array:
# batched_key [B, 2] -> batched_keys [num, B, 2]
return jax.vmap(jax.random.split, in_axes=(0, None), out_axes=1)(key, num)


class Brax(Problem):
def __init__(
self,
Expand Down Expand Up @@ -50,27 +55,32 @@ def setup(self, key):

@jit_method
def evaluate(self, state, weights):
batch_size = tree_leaves(weights)[0].shape[0]
brax_state = self.jit_reset(jnp.tile(state.key, (batch_size, 1)))
pop_size = tree_leaves(weights)[0].shape[0]
num_evals = 1
key, eval_key = jax.random.split(state.key)
brax_state = self.jit_reset(
vmap_rng_split(jax.random.split(eval_key, num_evals), pop_size)
)

def cond_func(val):
counter, state, _total_reward = val
return (counter < self.cap_episode) & (~state.done.all())
counter, state, done, _total_reward = val
return (counter < self.cap_episode) & (~done.all())

def body_func(val):
counter, brax_state, total_reward = val
counter, brax_state, done, total_reward = val
action = self.batched_policy(weights, brax_state.obs)
brax_state = self.jit_env_step(brax_state, action)
total_reward += (1 - brax_state.done) * brax_state.reward
return counter + 1, brax_state, total_reward
done = brax_state.done * (1 - done)
total_reward += (1 - done) * brax_state.reward
return counter + 1, brax_state, done, total_reward

init_val = (0, brax_state, jnp.zeros((batch_size,)))
init_val = (0, brax_state, jnp.zeros((pop_size,)), jnp.zeros((pop_size,)))

_counter, _brax_state, total_reward = jax.lax.while_loop(
_counter, _brax_state, done, total_reward = jax.lax.while_loop(
cond_func, body_func, init_val
)

return total_reward, state
return total_reward, state.replace(key=key)

def visualize(
self,
Expand Down

0 comments on commit be13968

Please sign in to comment.