Merge pull request #34 from MichalBortkiewicz/refactor-environments
Environment refactoring
vivekmyers authored Jan 4, 2025
2 parents 8efebda + b0f7969 commit c3ed381
Showing 19 changed files with 268 additions and 296 deletions.
38 changes: 14 additions & 24 deletions envs/ant.py
@@ -26,10 +26,10 @@ def __init__(
         reset_noise_scale=0.1,
         exclude_current_positions_from_observation=False,
         backend="generalized",
-        dense_reward:bool=False,
+        dense_reward: bool = False,
         **kwargs,
     ):
-        path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'assets', "ant.xml")
+        path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "assets", "ant.xml")
         sys = mjcf.load(path)

         n_frames = 5
@@ -50,11 +50,7 @@ def __init__(

         if backend == "positional":
             # TODO: does the same actuator strength work as in spring
-            sys = sys.replace(
-                actuator=sys.actuator.replace(
-                    gear=200 * jnp.ones_like(sys.actuator.gear)
-                )
-            )
+            sys = sys.replace(actuator=sys.actuator.replace(gear=200 * jnp.ones_like(sys.actuator.gear)))

         kwargs["n_frames"] = kwargs.get("n_frames", n_frames)

@@ -68,13 +64,11 @@ def __init__(
         self._healthy_z_range = healthy_z_range
         self._contact_force_range = contact_force_range
         self._reset_noise_scale = reset_noise_scale
-        self._exclude_current_positions_from_observation = (
-            exclude_current_positions_from_observation
-        )
+        self._exclude_current_positions_from_observation = exclude_current_positions_from_observation
         self.dense_reward = dense_reward
         self.state_dim = 29
         self.goal_indices = jnp.array([0, 1])
-        self.goal_dist = 0.5
+        self.goal_reach_thresh = 0.5

         if self._use_contact_forces:
             raise NotImplementedError("use_contact_forces not implemented.")
@@ -85,9 +79,7 @@ def reset(self, rng: jax.Array) -> State:
         rng, rng1, rng2 = jax.random.split(rng, 3)

         low, hi = -self._reset_noise_scale, self._reset_noise_scale
-        q = self.sys.init_q + jax.random.uniform(
-            rng1, (self.sys.q_size(),), minval=low, maxval=hi
-        )
+        q = self.sys.init_q + jax.random.uniform(rng1, (self.sys.q_size(),), minval=low, maxval=hi)
         qd = hi * jax.random.normal(rng2, (self.sys.qd_size(),))

         # set the target q, qd
@@ -112,7 +104,7 @@ def reset(self, rng: jax.Array) -> State:
             "forward_reward": zero,
             "dist": zero,
             "success": zero,
-            "success_easy": zero
+            "success_easy": zero,
         }
         state = State(pipeline_state, obs, reward, done, metrics)
         return state
@@ -139,12 +131,12 @@ def step(self, state: State, action: jax.Array) -> State:
         old_dist = jnp.linalg.norm(old_obs[:2] - old_obs[-2:])
         obs = self._get_obs(pipeline_state)
         dist = jnp.linalg.norm(obs[:2] - obs[-2:])
-        vel_to_target = (old_dist-dist) / self.dt
-        success = jnp.array(dist < self.goal_dist, dtype=float)
-        success_easy = jnp.array(dist < 2., dtype=float)
+        vel_to_target = (old_dist - dist) / self.dt
+        success = jnp.array(dist < self.goal_reach_thresh, dtype=float)
+        success_easy = jnp.array(dist < 2.0, dtype=float)

         if self.dense_reward:
-            reward = 10*vel_to_target + healthy_reward - ctrl_cost - contact_cost
+            reward = 10 * vel_to_target + healthy_reward - ctrl_cost - contact_cost
         else:
             reward = success
         done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0
@@ -162,11 +154,9 @@ def step(self, state: State, action: jax.Array) -> State:
             forward_reward=forward_reward,
             dist=dist,
             success=success,
-            success_easy=success_easy
-        )
-        return state.replace(
-            pipeline_state=pipeline_state, obs=obs, reward=reward, done=done
+            success_easy=success_easy,
         )
+        return state.replace(pipeline_state=pipeline_state, obs=obs, reward=reward, done=done)

     def _get_obs(self, pipeline_state: base.State) -> jax.Array:
         """Observe ant body position and velocities."""
@@ -188,4 +178,4 @@ def _random_target(self, rng: jax.Array) -> Tuple[jax.Array, jax.Array]:
         ang = jnp.pi * 2.0 * jax.random.uniform(rng2)
         target_x = dist * jnp.cos(ang)
         target_y = dist * jnp.sin(ang)
-        return rng, jnp.array([target_x, target_y])
\ No newline at end of file
+        return rng, jnp.array([target_x, target_y])
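
The context lines above show how _random_target places goals: sample an angle uniformly in [0, 2π) and convert polar coordinates to Cartesian. A sketch under the assumption of a fixed radius, since the radius draw lies outside this hunk:

import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
rng, rng2 = jax.random.split(rng)
dist = 5.0  # assumed sampling radius; the env draws its own value
ang = jnp.pi * 2.0 * jax.random.uniform(rng2)
target = jnp.array([dist * jnp.cos(ang), dist * jnp.sin(ang)])  # xy goal position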
6 changes: 3 additions & 3 deletions envs/ant_ball.py
@@ -75,7 +75,7 @@ def __init__(

         self.state_dim = 31
         self.goal_indices = jnp.array([29, 30])
-        self.goal_dist = 0.5
+        self.goal_reach_thresh = 0.5

         if self._use_contact_forces:
             raise NotImplementedError("use_contact_forces not implemented.")
@@ -144,7 +144,7 @@ def step(self, state: State, action: jax.Array) -> State:
         obs = self._get_obs(pipeline_state)
         dist = jnp.linalg.norm(obs[-2:] - obs[-4:-2])
         vel_to_target = (old_dist - dist) / self.dt
-        success = jnp.array(dist < self.goal_dist, dtype=float)
+        success = jnp.array(dist < self.goal_reach_thresh, dtype=float)
         success_easy = jnp.array(dist < 2., dtype=float)

         if self.dense_reward:
@@ -202,4 +202,4 @@ def _random_target(self, rng: jax.Array) -> Tuple[jax.Array, jax.Array]:

         target_pos = jnp.array([target_x, target_y])
         obj_pos = target_pos * 0.2 + jnp.array([obj_x_offset, obj_y_offset])
-        return rng, target_pos, obj_pos
\ No newline at end of file
+        return rng, target_pos, obj_pos
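
Unlike ant.py, this env measures success between the last two observation slices rather than from the ant's own position. A toy check of that indexing with made-up values; which slice holds the ball and which the target is an assumption here:

import jax.numpy as jnp

obs = jnp.concatenate([jnp.zeros(29), jnp.array([1.0, 0.2]), jnp.array([1.2, 0.3])])
dist = jnp.linalg.norm(obs[-2:] - obs[-4:-2])  # object-to-target distance
success = jnp.array(dist < 0.5, dtype=float)   # 0.5 = goal_reach_thresh from the diff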
6 changes: 3 additions & 3 deletions envs/ant_ball_maze.py
@@ -163,7 +163,7 @@ def __init__(

         self.state_dim = 31
         self.goal_indices = jnp.array([28, 29])
-        self.goal_dist = 0.5
+        self.goal_reach_thresh = 0.5

         if self._use_contact_forces:
             raise NotImplementedError("use_contact_forces not implemented.")
@@ -235,7 +235,7 @@ def step(self, state: State, action: jax.Array) -> State:
         obs = self._get_obs(pipeline_state)
         dist = jnp.linalg.norm(obs[-2:] - obs[-4:-2])
         vel_to_target = (old_dist - dist) / self.dt
-        success = jnp.array(dist < self.goal_dist, dtype=float)
+        success = jnp.array(dist < self.goal_reach_thresh, dtype=float)
         success_easy = jnp.array(dist < 2., dtype=float)

         if self.dense_reward:
@@ -290,4 +290,4 @@ def _random_start(self, rng: jax.Array) -> jax.Array:

     def _random_ball(self, rng: jax.Array) -> jax.Array:
         idx = jax.random.randint(rng, (1,), 0, len(self.possible_balls))
-        return jnp.array(self.possible_balls[idx])[0]
\ No newline at end of file
+        return jnp.array(self.possible_balls[idx])[0]
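
The _random_ball context above draws one of a fixed set of candidate ball positions uniformly at random. A sketch with placeholder candidates (possible_balls here is illustrative, not the env's actual list):

import jax
import jax.numpy as jnp

possible_balls = jnp.array([[1.0, 1.0], [2.0, 0.0], [0.0, 2.0]])  # hypothetical candidates
rng = jax.random.PRNGKey(0)
idx = jax.random.randint(rng, (1,), 0, len(possible_balls))  # uniform index, shape (1,)
ball = jnp.array(possible_balls[idx])[0]  # index then squeeze to a (2,) xy position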