-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy patharm_binpick_easy.py
99 lines (81 loc) · 4.65 KB
/
arm_binpick_easy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from brax import base
from brax.envs.base import State
from brax.io import mjcf
import mujoco
import jax
from jax import numpy as jnp
from envs.manipulation.arm_envs import ArmEnvs
"""
Binpick-Easy: Move a cube from a random location in the blue bin to the center of the red bin.
- Observation space: 18-dim obs + 3-dim goal.
- Action space: 5-dim, each element in [-1, 1], corresponding to target angles for joints 1, 2, 4, 6, and finger closedness.
See _get_obs() and ArmEnvs._convert_action() for details.
"""
class ArmBinpickEasy(ArmEnvs):
def _get_xml_path(self):
return "envs/assets/panda_binpick_easy.xml"
@property
def action_size(self) -> int:
return 5 # Override default (actuator count)
# See ArmEnvs._set_environment_attributes for descriptions of attributes
def _set_environment_attributes(self):
self.env_name = "arm_binpick_easy"
self.episode_length = 150
self.goal_indices = jnp.array([0, 1, 2]) # Cube position
self.completion_goal_indices = jnp.array([0, 1, 2]) # Identical
self.state_dim = 18
self.goal_reach_thresh = 0.1
self.arm_noise_scale = 0
self.cube_noise_scale = 0.07
self.goal_noise_scale = 0.005
def _get_initial_state(self, rng):
rng, subkey1, subkey2 = jax.random.split(rng, 3)
cube_q_xy = self.sys.init_q[:2] + self.cube_noise_scale * jax.random.uniform(subkey1, [2], minval=-1)
cube_q_remaining = self.sys.init_q[2:7]
target_q = self.sys.init_q[7:14]
arm_q_default = jnp.array([1.571, 0.742, 0, -1.571, 0, 3.054, 1.449, 0.04, 0.04]) # Start closer to the relevant area
arm_q = arm_q_default + self.arm_noise_scale * jax.random.uniform(subkey2, [self.sys.q_size() - 14], minval=-1)
q = jnp.concatenate([cube_q_xy] + [cube_q_remaining] + [target_q] + [arm_q])
qd = jnp.zeros([self.sys.qd_size()])
return q, qd
def _get_initial_goal(self, pipeline_state: base.State, rng):
rng, subkey = jax.random.split(rng)
cube_goal_pos = jnp.array([0.17, 0.6, 0.03]) + jnp.array([self.goal_noise_scale, self.goal_noise_scale, 0]) * jax.random.uniform(subkey, [3], minval=-1)
return cube_goal_pos
def _compute_goal_completion(self, obs, goal):
# Goal occupancy: is the cube close enough to the goal?
current_cube_pos = obs[self.completion_goal_indices]
goal_pos = goal[:3]
dist = jnp.linalg.norm(current_cube_pos - goal_pos)
success = jnp.array(dist < self.goal_reach_thresh, dtype=float)
success_easy = jnp.array(dist < 0.3, dtype=float)
success_hard = jnp.array(dist < 0.03, dtype=float)
return success, success_easy, success_hard
def _update_goal_visualization(self, pipeline_state: base.State, goal: jax.Array) -> base.State:
updated_q = pipeline_state.q.at[7:10].set(goal[:3]) # Only set the position, not orientation
updated_pipeline_state = pipeline_state.replace(qpos=updated_q)
return updated_pipeline_state
def _get_obs(self, pipeline_state: base.State, goal: jax.Array, timestep) -> jax.Array:
"""
Observation space (18-dim)
- q_subset (10-dim): 3-dim cube position, 7-dim joint angles
- End-effector (6-dim): position and velocity
- Fingers (2-dim): finger distance, gripper force
Note q is 23-dim: 7-dim cube position/angle, 7-dim goal marker position/angle, 7-dim joint angles, 2-dim finger offset
Goal space (3-dim): position of cube
"""
q_indices = jnp.array([0, 1, 2, 14, 15, 16, 17, 18, 19, 20])
q_subset = pipeline_state.q[q_indices]
eef_index = 8 # Cube is 0, goal marker is 1, then links 1-7 are indices 2-8. The end-effector (eef) base is merged with link 7, so we say link 7 index = eef index.
eef_x_pos = pipeline_state.x.pos[eef_index]
eef_xd_vel = pipeline_state.xd.vel[eef_index]
left_finger_index = 9
left_finger_x_pos = pipeline_state.x.pos[left_finger_index]
right_finger_index = 10
right_finger_x_pos = pipeline_state.x.pos[right_finger_index]
finger_distance = jnp.linalg.norm(right_finger_x_pos - left_finger_x_pos)[None] # [None] expands dims from 0 to 1
gripper_force = (pipeline_state.qfrc_actuator[:-2]).mean(keepdims=True) * 0.1 # Normalize it from range [-20, 20] to [-2, 2]
return jnp.concatenate([q_subset] + [eef_x_pos] + [eef_xd_vel] + [finger_distance] + [gripper_force] + [goal])
def _get_arm_angles(self, pipeline_state: base.State) -> jax.Array:
q_indices = jnp.arange(14, 21)
return pipeline_state.q[q_indices]