Update pick_single_ycb.py
StoneT2000 committed Jan 25, 2024
1 parent 69fc586 commit 3585f98
Showing 1 changed file with 97 additions and 8 deletions.
mani_skill2/envs/tasks/pick_single_ycb.py (105 changes: 97 additions & 8 deletions)
@@ -2,6 +2,7 @@
from typing import Any, Dict, List

import numpy as np
import sapien
import torch

from mani_skill2.envs.sapien_env import BaseEnv
@@ -11,6 +12,7 @@
MODEL_DBS,
_load_ycb_dataset,
build_actor_ycb,
build_sphere,
)
from mani_skill2.utils.registration import register_env
from mani_skill2.utils.sapien_utils import look_at
@@ -28,13 +30,19 @@ class PickSingleYCBEnv(BaseEnv):
Randomizations
--------------
    - the object's xy position is randomized on top of a table in the region [-0.1, 0.1] x [-0.1, 0.1]. It is placed flat on the table
- the object's z-axis rotation is randomized
- the object geometry is randomized by randomly sampling any YCB object
    Success Conditions
    ------------------
    - the object position is within goal_thresh (0.025 m) of the goal position and the robot is static
Visualization: link to a video/gif of the task being solved
"""

goal_thresh = 0.025

def __init__(self, *args, robot_uid="panda", robot_init_qpos_noise=0.02, **kwargs):
self.robot_init_qpos_noise = robot_init_qpos_noise
self.model_id = None
@@ -69,6 +77,8 @@ def _load_actors(self):
            print(
                "There are fewer parallel environments than total available models to sample. The environment will run considerably slower"
            )
        # TODO (stao): with fewer envs than models, we should be reconfiguring more often, which is unfortunately also very slow on gpu sim
        # alternatively provide option for user to specify reconfiguration frequency in terms of # of resets?

actors: List[Actor] = []
self.obj_heights = []
@@ -83,28 +93,107 @@ def _load_actors(self):
self.obj_heights.append(obj_height)
self.obj = Actor.merge_actors(actors, name="ycb_object")

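        # create a kinematic, collision-free green sphere that marks the goal position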
self.goal_site = build_sphere(
self._scene,
radius=self.goal_thresh,
color=[0, 1, 0, 1],
name="goal_site",
body_type="kinematic",
add_collision=False,
)

def _initialize_actors(self):
with torch.device(self.device):
self.table_scene.initialize()
-            ps = torch.zeros((self.num_envs, 3))
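            # sample xy positions uniformly in the [-0.1, 0.1] x [-0.1, 0.1] region on the table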
xyz = torch.zeros((self.num_envs, 3))
xyz[:, :2] = torch.rand((self.num_envs, 2)) * 0.2 - 0.1
for i in range(self.num_envs):
# use ycb object bounding box heights to set it properly on the table
-                ps[i, 2] = self.obj_heights[i] / 2
xyz[i, 2] = self.obj_heights[i] / 2

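            # randomize rotation about the z-axis only so the object stays flat on the table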
qs = random_quaternions(self.num_envs, lock_x=True, lock_y=True)
-            self.obj.set_pose(Pose.create_from_pq(p=ps, q=qs))
self.obj.set_pose(Pose.create_from_pq(p=xyz, q=qs))

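            # sample a goal position in the same xy region, up to 0.3 m above the object's spawn height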
goal_xyz = torch.zeros((self.num_envs, 3))
goal_xyz[:, :2] = torch.rand((self.num_envs, 2)) * 0.2 - 0.1
            goal_xyz[:, 2] = torch.rand((self.num_envs,)) * 0.3 + xyz[:, 2]
self.goal_site.set_pose(Pose.create_from_pq(goal_xyz))

# Initialize robot arm to a higher position above the table than the default typically used for other table top tasks
if self.robot_uid == "panda":
# fmt: off
qpos = np.array(
[0.0, 0, 0, -np.pi * 2 / 3, 0, np.pi * 2 / 3, np.pi / 4, 0.04, 0.04]
)
# fmt: on
qpos[:-2] += self._episode_rng.normal(
0, self.robot_init_qpos_noise, len(qpos) - 2
)
self.agent.reset(qpos)
self.agent.robot.set_root_pose(sapien.Pose([-0.615, 0, 0]))
elif self.robot_uid == "xmate3_robotiq":
qpos = np.array([0, 0.6, 0, 1.3, 0, 1.3, -1.57, 0, 0])
qpos[:-2] += self._episode_rng.normal(
0, self.robot_init_qpos_noise, len(qpos) - 2
)
self.agent.reset(qpos)
self.agent.robot.set_root_pose(sapien.Pose([-0.562, 0, 0]))
else:
raise NotImplementedError(self.robot_uid)

def _get_obs_extra(self):
-        return OrderedDict()
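        # always expose the TCP pose and goal position; state-based obs modes also get the object pose and relative positions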
obs = OrderedDict(
tcp_pose=self.agent.tcp.pose.raw_pose,
goal_pos=self.goal_site.pose.p,
)
if "state" in self.obs_mode:
# TODO (stao): previously we used some cmass pose. Why was that?
obs.update(
                tcp_to_goal_pos=self.goal_site.pose.p - self.agent.tcp.pose.p,
obj_pose=self.obj.pose.raw_pose,
tcp_to_obj_pos=self.obj.pose.p - self.agent.tcp.pose.p,
obj_to_goal_pos=self.goal_site.pose.p - self.obj.pose.p,
)
return obs

def evaluate(self, obs: Any):
-        return {"success": torch.zeros(self.num_envs, device=self.device, dtype=bool)}
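        # success: the object center is within goal_thresh of the goal site and the robot is static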
obj_to_goal_pos = self.goal_site.pose.p - self.obj.pose.p
is_obj_placed = torch.linalg.norm(obj_to_goal_pos, axis=1) <= self.goal_thresh
is_robot_static = self.agent.is_static()
return dict(
obj_to_goal_pos=obj_to_goal_pos,
is_obj_placed=is_obj_placed,
is_robot_static=is_robot_static,
success=torch.logical_and(is_obj_placed, is_robot_static),
)

def compute_dense_reward(self, obs: Any, action: torch.Tensor, info: Dict):
-        return torch.zeros(self.num_envs, device=self.device)
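        # stage 1: reward for reaching toward the object with the tool center point (TCP)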
tcp_to_obj_dist = torch.linalg.norm(
self.obj.pose.p - self.agent.tcp.pose.p, axis=1
)
reaching_reward = 1 - torch.tanh(5 * tcp_to_obj_dist)
reward = reaching_reward

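        # stage 2: +1 bonus while the object is grasped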
        is_grasped = self.agent.is_grasping(self.obj)
reward += is_grasped

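        # stage 3: reward for bringing the grasped object close to the goal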
obj_to_goal_dist = torch.linalg.norm(
self.goal_site.pose.p - self.obj.pose.p, axis=1
)
place_reward = 1 - torch.tanh(5 * obj_to_goal_dist)
reward += place_reward * is_grasped

reward += info["is_obj_placed"] * is_grasped

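        # stage 4: once the object is placed, reward the arm for being static (the two gripper joints are excluded)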
static_reward = 1 - torch.tanh(
5 * torch.linalg.norm(self.agent.robot.get_qvel()[..., :-2], axis=1)
)
reward += static_reward * info["is_obj_placed"] * is_grasped

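        # assign the maximum reward of 6 once the task is solved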
reward[info["success"]] = 6
return reward

def compute_normalized_dense_reward(
self, obs: Any, action: torch.Tensor, info: Dict
):
-        max_reward = 1.0
-        return self.compute_dense_reward(obs=obs, action=action, info=info) / max_reward
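        # the dense reward is capped at 6 (the success value), so dividing by 6 normalizes it to [0, 1]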
return self.compute_dense_reward(obs=obs, action=action, info=info) / 6
