-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy patharm_grasp.py
131 lines (109 loc) · 6.38 KB
/
arm_grasp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from brax import base
from brax.envs.base import State
from brax.io import mjcf
import mujoco
import jax
from jax import numpy as jnp
from envs.manipulation.arm_envs import ArmEnvs
"""
Grasp: Close fingers on opposite sides of a cube.
- Observation space: 24-dim obs + 7-dim goal.
- Action space: 5-dim, each element in [-1, 1], corresponding to target angles for joints 1, 2, 4, 6, and finger closedness.
See _get_obs() and ArmEnvs._convert_action() for details.
"""
class ArmGrasp(ArmEnvs):
def _get_xml_path(self):
return "envs/assets/panda_grasp.xml"
@property
def action_size(self) -> int:
return 5 # Override default (actuator count)
# See ArmEnvs._set_environment_attributes for descriptions of attributes
def _set_environment_attributes(self):
self.env_name = "arm_grasp"
self.episode_length = 100
self.goal_indices = jnp.array([16, 17, 18, 19, 20, 21, 22]) # Left and right fingertip positions, and fingertip distance
self.completion_goal_indices = jnp.array([16, 17, 18, 19, 20, 21, 22]) # Identical
self.state_dim = 24
self.arm_noise_scale = 0
self.cube_noise_scale = 0.3
def _get_initial_state(self, rng):
rng, subkey1, subkey2 = jax.random.split(rng, 3)
cube_q_xy = self.sys.init_q[:2] + self.cube_noise_scale * jax.random.uniform(subkey1, [2], minval=-1)
cube_q_remaining = self.sys.init_q[2:7]
target_q = self.sys.init_q[7:14]
arm_q_default = jnp.array([1.571, 0.742, 0, -1.571, 0, 3.054, 1.449, 0.04, 0.04, 0, 0]) # Start closer to the relevant area
arm_q = arm_q_default + self.arm_noise_scale * jax.random.uniform(subkey2, [self.sys.q_size() - 14], minval=-1)
q = jnp.concatenate([cube_q_xy] + [cube_q_remaining] + [target_q] + [arm_q])
qd = jnp.zeros([self.sys.qd_size()])
return q, qd
def _get_initial_goal(self, pipeline_state: base.State, rng):
# The fingertip goal positions will constantly adjust to be next to the cube (logic handled in arm_envs:step)
cube_pos = pipeline_state.q[:3]
left_fingertip_goal_pos = cube_pos + jnp.array([0.0375, 0, 0])
right_fingertip_goal_pos = cube_pos + jnp.array([-0.0375, 0, 0])
gripper_openness_goal = jnp.array([0.075]) # The cube itself is 0.06 wide, but we want the centers of the fingertips to be slightly farther apart
goal = jnp.concatenate([left_fingertip_goal_pos] + [right_fingertip_goal_pos] + [gripper_openness_goal])
return goal
def _compute_goal_completion(self, obs, goal):
# Goal occupancy: is the midpoint of the fingertips close enough to the cube, and is the gripper closed enough?
# Technically, only success_hard is properly gripping the cube, but success/success_easy are for signs of life.
cube_pos = obs[:3]
left_fingertip_pos = obs[16:19]
right_fingertip_pos = obs[19:22]
fingertip_midpoint = (left_fingertip_pos + right_fingertip_pos) / 2
cube_to_fingertip_midpoint_dist = jnp.linalg.norm(cube_pos - fingertip_midpoint)
gripper_openness = obs[22]
goal_gripper_openness = goal[9]
gripper_openness_difference = jnp.linalg.norm(gripper_openness - goal_gripper_openness)
success = jnp.array(
jnp.all(jnp.array([
cube_to_fingertip_midpoint_dist < 0.05,
gripper_openness_difference < 0.02
])),
dtype=float
)
success_easy = jnp.array(
jnp.all(jnp.array([
cube_to_fingertip_midpoint_dist < 0.15,
gripper_openness_difference < 0.05
])),
dtype=float
)
success_hard = jnp.array(
jnp.all(jnp.array([
cube_to_fingertip_midpoint_dist < 0.02,
gripper_openness_difference < 0.005
])),
dtype=float
)
return success, success_easy, success_hard
def _update_goal_visualization(self, pipeline_state: base.State, goal: jax.Array) -> base.State:
updated_q = pipeline_state.q.at[7:10].set(goal[:3]) # Only set the position, not orientation (we set the left marker and the right marker is automatically offset)
updated_pipeline_state = pipeline_state.replace(qpos=updated_q)
return updated_pipeline_state
def _get_obs(self, pipeline_state: base.State, goal: jax.Array, timestep) -> jax.Array:
"""
Observation space (24-dim)
- q_subset (10-dim): 3-dim cube position, 7-dim joint angles
- End-effector (6-dim): position and velocity
- Fingertips (7-dim): 6-dim positions of fingertips, 1-dim distance between fingertips
- Gripper (finger) force: 1-dim
Note q is 25-dim: 7-dim cube position/angle, 7-dim goal marker position/angle, 7-dim joint angles, 4-dim finger offsets and dummy fingertip angles
Goal space (7-dim): position of left fingertip, position of right fingertip, distance between fingertips
"""
q_indices = jnp.array([0, 1, 2, 14, 15, 16, 17, 18, 19, 20])
q_subset = pipeline_state.q[q_indices]
eef_index = 8 # Cube is 0, goal marker is 1, then links 1-7 are indices 2-8. The end-effector (eef) base is merged with link 7, so we say link 7 index = eef index.
eef_x_pos = pipeline_state.x.pos[eef_index]
eef_xd_vel = pipeline_state.xd.vel[eef_index]
left_fingertip_index = 10 # Left finger is 9, fingertip is 10
left_fingertip_x_pos = pipeline_state.x.pos[left_fingertip_index]
right_fingertip_index = 12 # Right finger is 11, fingertip is 12
right_fingertip_x_pos = pipeline_state.x.pos[right_fingertip_index]
fingertip_distance = jnp.linalg.norm(right_fingertip_x_pos - left_fingertip_x_pos)[None] # [None] expands dims from 0 to 1
# Index -4 and -2 to get the fingers, rather than fingertips
gripper_force = (pipeline_state.qfrc_actuator[jnp.array([-4, -2])]).mean(keepdims=True) * 0.1 # Normalize it from range [-20, 20] to [-2, 2]
return jnp.concatenate([q_subset] + [eef_x_pos] + [eef_xd_vel] + [left_fingertip_x_pos] + [right_fingertip_x_pos] + [fingertip_distance] + [gripper_force] + [goal])
def _get_arm_angles(self, pipeline_state: base.State) -> jax.Array:
q_indices = jnp.arange(14, 21)
return pipeline_state.q[q_indices]