batched uniform sampler, stack cube v1 task
StoneT2000 committed Jan 25, 2024
1 parent 6d3a69c commit 4233fe6
Showing 13 changed files with 265 additions and 48 deletions.
19 changes: 16 additions & 3 deletions mani_skill2/agents/base_agent.py
@@ -14,6 +14,7 @@
check_urdf_config,
parse_urdf_config,
)
from mani_skill2.utils.structs.actor import Actor
from mani_skill2.utils.structs.articulation import Articulation

from .controllers.base_controller import (
@@ -236,16 +237,28 @@ def reset(self, init_qpos=None):
# -------------------------------------------------------------------------- #
# Optional per-agent APIs, implemented depending on agent affordances
# -------------------------------------------------------------------------- #
-def is_grasping(self, object: Union[sapien.Entity, None] = None):
def is_grasping(self, object: Union[Actor, None] = None):
"""
Check if this agent is grasping an object or grasping anything at all
Args:
-object (sapien.Entity | None):
-If object is a sapien.Entity, this function checks grasping against that. If it is none, the function checks if the
object (Actor | None):
If object is an Actor, this function checks grasping against that. If it is None, the function checks if the
agent is grasping anything at all.
Returns:
True if agent is grasping object. False otherwise. If object is None, returns True if agent is grasping something, False if agent is grasping nothing.
"""
raise NotImplementedError()

def is_static(self, threshold: float):
"""
Check if this robot is static, i.e. whether its joint (q) velocities are all below the given threshold
Args:
threshold (float): The maximum absolute joint velocity below which this agent is considered static
Returns:
True if agent is static within the threshold. False otherwise
"""
raise NotImplementedError()
4 changes: 4 additions & 0 deletions mani_skill2/agents/robots/panda/panda.py
@@ -292,6 +292,10 @@ def is_grasping(self, object: Actor = None, min_impulse=1e-6, max_angle=85):

return all([lflag, rflag])

def is_static(self, threshold: float = 0.2):
qvel = self.robot.get_qvel()[..., :-2]
return torch.max(torch.abs(qvel), 1)[0] <= threshold

@staticmethod
def build_grasp_pose(approaching, closing, center):
"""Build a grasp pose (panda_hand_tcp)."""
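For reference, the new Panda.is_static above reduces batched joint velocities to one boolean per parallel environment. A self-contained sketch of the same torch logic, with hypothetical shapes (4 envs, 7 arm + 2 gripper joints):

import torch

# Hypothetical example: 4 parallel envs, 9 joints (7 arm + 2 gripper).
qvel = torch.zeros(4, 9)
qvel[1, 3] = 0.5                           # env 1 has one fast arm joint
arm_qvel = qvel[..., :-2]                  # drop the two gripper joints, as is_static does
is_static = torch.max(torch.abs(arm_qvel), 1)[0] <= 0.2
print(is_static)                           # tensor([ True, False,  True,  True])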
8 changes: 5 additions & 3 deletions mani_skill2/envs/minimal_template.py
@@ -10,7 +10,7 @@
from mani_skill2.utils.sapien_utils import look_at


-@register_env(name="CustomEnv-v0", max_episode_steps=200)
@register_env("CustomEnv-v0", max_episode_steps=200)
class CustomEnv(BaseEnv):
"""
Task Description
@@ -52,9 +52,11 @@ def _get_obs_extra(self):
def evaluate(self, obs: Any):
return {"success": torch.zeros(self.num_envs, device=self.device, dtype=bool)}

-def compute_dense_reward(self, obs: Any, action: np.ndarray, info: Dict):
def compute_dense_reward(self, obs: Any, action: torch.Tensor, info: Dict):
return torch.zeros(self.num_envs, device=self.device)

-def compute_normalized_dense_reward(self, obs: Any, action: np.ndarray, info: Dict):
def compute_normalized_dense_reward(
self, obs: Any, action: torch.Tensor, info: Dict
):
max_reward = 1.0
return self.compute_dense_reward(obs=obs, action=action, info=info) / max_reward
13 changes: 8 additions & 5 deletions mani_skill2/envs/sapien_env.py
@@ -415,7 +415,7 @@ def robot_link_ids(self):
def reward_mode(self):
return self._reward_mode

-def get_reward(self, obs: Any, action: Array, info: Dict):
def get_reward(self, obs: Any, action: torch.Tensor, info: Dict):
if self._reward_mode == "sparse":
reward = info["success"]
elif self._reward_mode == "dense":
@@ -428,10 +428,12 @@ def get_reward(self, obs: Any, action: Array, info: Dict):
raise NotImplementedError(self._reward_mode)
return reward

-def compute_dense_reward(self, obs: Any, action: Array, info: Dict):
def compute_dense_reward(self, obs: Any, action: torch.Tensor, info: Dict):
raise NotImplementedError

-def compute_normalized_dense_reward(self, obs: Any, action: Array, info: Dict):
def compute_normalized_dense_reward(
self, obs: Any, action: torch.Tensor, info: Dict
):
raise NotImplementedError

# -------------------------------------------------------------------------- #
@@ -634,7 +636,7 @@ def _clear_sim_state(self):
# -------------------------------------------------------------------------- #

def step(self, action: Union[None, np.ndarray, Dict]):
-self.step_action(action)
action = self.step_action(action)
self._elapsed_steps += 1
obs = self.get_obs()
info = self.get_info(obs=obs)
@@ -652,7 +654,7 @@ def step(self, action: Union[None, np.ndarray, Dict]):
to_numpy(info),
)

-def step_action(self, action):
def step_action(self, action) -> Union[None, torch.Tensor]:
set_action = False
if action is None: # simulation without action
pass
@@ -681,6 +683,7 @@ def step_action(self, action):
self._after_simulation_step()
if physx.is_gpu_enabled():
self._scene._gpu_fetch_all()
return action

def evaluate(self, **kwargs) -> dict:
"""Evaluate whether the environment is currently in a success state."""
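The step_action change above lets step capture the processed action, so reward code downstream can rely on receiving a batched torch.Tensor. A rough, simplified sketch of that flow (not the actual BaseEnv; the class and conversion logic below are illustrative assumptions only):

import torch

class MiniBatchedEnv:
    # Illustrative stand-in for the step -> step_action -> get_reward pattern.
    num_envs = 2

    def step_action(self, action):
        # Normalize whatever the caller passed into a batched torch.Tensor (or None).
        if action is None:
            return None
        action = torch.as_tensor(action, dtype=torch.float32)
        if action.ndim == 1:
            action = action.expand(self.num_envs, -1)
        return action

    def get_reward(self, action):
        # Downstream reward code now always sees a tensor (or None).
        if action is None:
            return torch.zeros(self.num_envs)
        return -torch.linalg.norm(action, dim=-1)

    def step(self, action):
        action = self.step_action(action)  # keep the processed action, as in the diff above
        return self.get_reward(action)

print(MiniBatchedEnv().step([1.0, 0.0, 0.0]))  # tensor([-1., -1.])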
1 change: 1 addition & 0 deletions mani_skill2/envs/tasks/__init__.py
@@ -2,3 +2,4 @@
from .pick_cube import PickCubeEnv
from .push_cube import PushCubeEnv
from .push_object import PushObjectEnv
from .stack_cube import StackCubeEnv
6 changes: 4 additions & 2 deletions mani_skill2/envs/tasks/fmb/fmb.py
@@ -203,12 +203,14 @@ def evaluate(self, obs: Any):
# for the task. You may also include additional keys which will populate the info object returned by self.step
return {"success": [False]}

-def compute_dense_reward(self, obs: Any, action: np.ndarray, info: Dict):
def compute_dense_reward(self, obs: Any, action: torch.Tensor, info: Dict):
# you can optionally provide a dense reward function by returning a scalar value here. This is used when reward_mode="dense"
reward = torch.zeros(self.num_envs, device=self.device)
return reward

-def compute_normalized_dense_reward(self, obs: Any, action: np.ndarray, info: Dict):
def compute_normalized_dense_reward(
self, obs: Any, action: torch.Tensor, info: Dict
):
# this should be equal to compute_dense_reward / max possible reward
max_reward = 1.0
return self.compute_dense_reward(obs=obs, action=action, info=info) / max_reward
9 changes: 5 additions & 4 deletions mani_skill2/envs/tasks/pick_cube.py
@@ -105,15 +105,14 @@ def evaluate(self, obs: Any):
torch.linalg.norm(self.goal_site.pose.p - self.cube.pose.p, axis=1)
<= self.goal_thresh
)
-qvel = self.agent.robot.get_qvel()[..., :-2]
-is_robot_static = torch.max(torch.abs(qvel), 1)[0] <= 0.2
is_robot_static = self.agent.is_static(0.2)
return {
"success": torch.logical_and(is_obj_placed, is_robot_static),
"is_obj_placed": is_obj_placed,
"is_robot_static": is_robot_static,
}

-def compute_dense_reward(self, obs: Any, action: np.ndarray, info: Dict):
def compute_dense_reward(self, obs: Any, action: torch.Tensor, info: Dict):
tcp_to_obj_dist = torch.linalg.norm(
self.cube.pose.p - self.agent.tcp.pose.p, axis=1
)
@@ -137,5 +136,7 @@ def compute_dense_reward(self, obs: Any, action: np.ndarray, info: Dict):
reward[info["success"]] = 5
return reward

-def compute_normalized_dense_reward(self, obs: Any, action: np.ndarray, info: Dict):
def compute_normalized_dense_reward(
self, obs: Any, action: torch.Tensor, info: Dict
):
return self.compute_dense_reward(obs=obs, action=action, info=info) / 5
122 changes: 106 additions & 16 deletions mani_skill2/envs/tasks/stack_cube.py
@@ -6,15 +6,16 @@

from mani_skill2.envs.sapien_env import BaseEnv
from mani_skill2.envs.utils.randomization.pose import random_quaternions
from mani_skill2.envs.utils.randomization.samplers import UniformSampler
from mani_skill2.sensors.camera import CameraConfig
from mani_skill2.utils.building.actors import build_cube
from mani_skill2.utils.registration import register_env
-from mani_skill2.utils.sapien_utils import look_at
from mani_skill2.utils.sapien_utils import look_at, to_tensor
from mani_skill2.utils.scene_builder.table.table_scene_builder import TableSceneBuilder
from mani_skill2.utils.structs.pose import Pose


-@register_env(name="StackCube-v1", max_episode_steps=100)
@register_env("StackCube-v1", max_episode_steps=100)
class StackCubeEnv(BaseEnv):
def __init__(self, *args, robot_uid="panda", robot_init_qpos_noise=0.02, **kwargs):
self.robot_init_qpos_noise = robot_init_qpos_noise
@@ -31,11 +32,11 @@ def _register_render_cameras(self):
return CameraConfig("render_camera", pose.p, pose.q, 512, 512, 1, 0.01, 10)

def _load_actors(self):
self.cube_half_size = to_tensor([0.02] * 3)
self.table_scene = TableSceneBuilder(
env=self, robot_init_qpos_noise=self.robot_init_qpos_noise
)
self.table_scene.build()
-self.box_half_size = torch.Tensor([0.02] * 3, device=self.device)
self.cubeA = build_cube(
self._scene, half_size=0.02, color=[1, 0, 0, 1], name="cubeA"
)
@@ -44,22 +45,111 @@
)

def _initialize_actors(self):
-self.table_scene.initialize()
-qs = random_quaternions(
-self._episode_rng, lock_x=True, lock_y=True, lock_z=False, n=self.num_envs
-)
-ps = [0, 0, 0.02]
-self.cubeA.set_pose(Pose.create_from_pq(p=ps, q=qs))
with torch.device(self.device):
self.table_scene.initialize()

xyz = torch.zeros((self.num_envs, 3))
xyz[:, 2] = 0.02
xy = torch.rand((self.num_envs, 2)) * 0.2 - 0.1
region = [[-0.1, -0.2], [0.1, 0.2]]
sampler = UniformSampler(bounds=region, batch_size=self.num_envs)
radius = (torch.linalg.norm(torch.Tensor([0.02, 0.02])) + 0.001).to(
self.device
)
cubeA_xy = xy + sampler.sample(radius, 100)
cubeB_xy = xy + sampler.sample(radius, 100, verbose=False)
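# cube A and cube B positions come from the same sampler, so they should start at least `radius` apart in each env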

xyz[:, :2] = cubeA_xy
qs = random_quaternions(
self.num_envs,
lock_x=True,
lock_y=True,
lock_z=False,
)
self.cubeA.set_pose(Pose.create_from_pq(p=xyz.clone(), q=qs))

xyz[:, :2] = cubeB_xy
qs = random_quaternions(
self.num_envs,
lock_x=True,
lock_y=True,
lock_z=False,
)
self.cubeB.set_pose(Pose.create_from_pq(p=xyz, q=qs))
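The UniformSampler used above lives in mani_skill2.envs.utils.randomization.samplers, which is not among the files shown on this page. Purely to illustrate the batched rejection-sampling idea the call sites assume, here is a rough sketch with the interface guessed from usage (not the actual implementation):

import torch

class BatchedUniformSampler:
    # Hypothetical sketch, not the real UniformSampler: draws one point per parallel
    # env uniformly inside `bounds`, re-drawing candidates that land within
    # `min_radius` of any point returned by earlier sample() calls for that env.
    def __init__(self, bounds, batch_size: int):
        self.lo = torch.as_tensor(bounds[0], dtype=torch.float32)  # e.g. [-0.1, -0.2]
        self.hi = torch.as_tensor(bounds[1], dtype=torch.float32)  # e.g. [ 0.1,  0.2]
        self.batch_size = batch_size
        self.fixtures = []  # (batch_size, dim) tensors from previous sample() calls

    def sample(self, min_radius, max_trials: int = 100, verbose: bool = True):
        dim = self.lo.shape[0]
        pts = self.lo + torch.rand(self.batch_size, dim) * (self.hi - self.lo)
        for _ in range(max_trials):
            if not self.fixtures:
                break
            # per-env distance from each candidate to every previously placed point
            dists = torch.stack([torch.linalg.norm(pts - f, dim=1) for f in self.fixtures])
            too_close = (dists < min_radius).any(dim=0)
            if not too_close.any():
                break
            pts[too_close] = self.lo + torch.rand(int(too_close.sum()), dim) * (self.hi - self.lo)
        self.fixtures.append(pts)
        return pts

sampler = BatchedUniformSampler(bounds=[[-0.1, -0.2], [0.1, 0.2]], batch_size=8)
first = sampler.sample(0.03)   # (8, 2) points inside the region
second = sampler.sample(0.03)  # (8, 2) points, each kept away from `first` per env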

def _get_obs_extra(self):
-return OrderedDict()
obs = OrderedDict(tcp_pose=self.agent.tcp.pose.raw_pose)
if "state" in self.obs_mode:
obs.update(
cubeA_pose=self.cubeA.pose.raw_pose,
cubeB_pose=self.cubeB.pose.raw_pose,
tcp_to_cubeA_pos=self.cubeA.pose.p - self.agent.tcp.pose.p,
tcp_to_cubeB_pos=self.cubeB.pose.p - self.agent.tcp.pose.p,
cubeA_to_cubeB_pos=self.cubeB.pose.p - self.cubeA.pose.p,
)
return obs

def evaluate(self, obs: Any):
-return {"success": torch.zeros(self.num_envs, device=self.device, dtype=bool)}
pos_A = self.cubeA.pose.p
pos_B = self.cubeB.pose.p
offset = pos_A - pos_B
xy_flag = (
torch.linalg.norm(offset[..., :2], axis=1)
<= torch.linalg.norm(self.cube_half_size[:2]) + 0.005
)
z_flag = torch.abs(offset[..., 2] - self.cube_half_size[..., 2] * 2) <= 0.005
is_cubeA_on_cubeB = torch.logical_and(xy_flag, z_flag)
is_cubeA_static = self.cubeA.is_static()
is_cubeA_grasped = self.agent.is_grasping(self.cubeA)

success = is_cubeA_on_cubeB * is_cubeA_static * ~is_cubeA_grasped

return {
"is_cubeA_grasped": is_cubeA_grasped,
"is_cubeA_on_cubeB": is_cubeA_on_cubeB,
"is_cubeA_static": is_cubeA_static,
"success": success,
}
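For a concrete sense of the tolerances in evaluate (assuming the 0.02 m half-size that build_cube uses above):

import torch

half = torch.tensor([0.02, 0.02, 0.02])
xy_tol = torch.linalg.norm(half[:2]) + 0.005     # allowed lateral offset between cube centers
z_target = half[2] * 2                           # cube A's center should sit one cube height above cube B's
print(round(float(xy_tol), 4), round(float(z_target), 4))  # 0.0333 0.04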

def compute_dense_reward(self, obs: Any, action: torch.Tensor, info: Dict):
# reaching reward
tcp_pose = self.agent.tcp.pose.p
cubeA_pos = self.cubeA.pose.p
cubeA_to_tcp_dist = torch.linalg.norm(tcp_pose - cubeA_pos, axis=1)
reward = 2 * (1 - torch.tanh(5 * cubeA_to_tcp_dist))

# grasp and place reward
cubeA_pos = self.cubeA.pose.p
cubeB_pos = self.cubeB.pose.p
goal_xyz = torch.hstack(
[cubeB_pos[:, 0:2], (cubeB_pos[:, 2] + self.cube_half_size[2] * 2)[:, None]]
)
cubeA_to_goal_dist = torch.linalg.norm(goal_xyz - cubeA_pos, axis=1)
place_reward = 1 - torch.tanh(5.0 * cubeA_to_goal_dist)

reward[info["is_cubeA_grasped"]] = (4 + place_reward)[info["is_cubeA_grasped"]]

# ungrasp and static reward
gripper_width = (self.agent.robot.get_qlimits()[-1, 1] * 2).to(
self.device
) # NOTE: hard-coded with panda
is_cubeA_grasped = info["is_cubeA_grasped"]
ungrasp_reward = (
torch.sum(self.agent.robot.get_qpos()[:, -2:], axis=1) / gripper_width
)
ungrasp_reward[~is_cubeA_grasped] = 1.0
v = torch.linalg.norm(self.cubeA.linear_velocity, axis=1)
av = torch.linalg.norm(self.cubeA.angular_velocity, axis=1)
static_reward = 1 - torch.tanh(v * 10 + av)
reward[info["is_cubeA_on_cubeB"]] = (
6 + (ungrasp_reward + static_reward) / 2.0
)[info["is_cubeA_on_cubeB"]]

reward[info["success"]] = 8
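# the maximum attainable reward is therefore 8, which compute_normalized_dense_reward divides by below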

-def compute_dense_reward(self, obs: Any, action: np.ndarray, info: Dict):
-return torch.zeros(self.num_envs, device=self.device)
return reward

-def compute_normalized_dense_reward(self, obs: Any, action: np.ndarray, info: Dict):
-max_reward = 1.0
-return self.compute_dense_reward(obs=obs, action=action, info=info) / max_reward
def compute_normalized_dense_reward(
self, obs: Any, action: torch.Tensor, info: Dict
):
return self.compute_dense_reward(obs=obs, action=action, info=info) / 8
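Since the task is registered as "StackCube-v1" and exported from mani_skill2.envs.tasks, a quick smoke test might look like the following (the obs_mode/reward_mode kwargs and the gymnasium-style reset/step signatures are assumptions about this revision of the API):

import gymnasium as gym
import mani_skill2.envs.tasks  # noqa: F401  (importing registers StackCube-v1)

env = gym.make("StackCube-v1", obs_mode="state", reward_mode="normalized_dense")
obs, _ = env.reset(seed=0)
for _ in range(10):
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)
env.close()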
8 changes: 5 additions & 3 deletions mani_skill2/envs/template.py
@@ -33,7 +33,7 @@


# register the environment by a unique ID and specify a max time limit. Now once this file is imported you can do gym.make("CustomEnv-v0")
-@register_env(name="CustomEnv-v0", max_episode_steps=200)
@register_env("CustomEnv-v0", max_episode_steps=200)
class CustomEnv(BaseEnv):
"""
Task Description
@@ -138,12 +138,14 @@ def evaluate(self, obs: Any):
# note that as everything is batched, you must return a batched array of self.num_envs booleans (or 0/1 values) as done in the example below
return {"success": torch.zeros(self.num_envs, device=self.device, dtype=bool)}

-def compute_dense_reward(self, obs: Any, action: np.ndarray, info: Dict):
def compute_dense_reward(self, obs: Any, action: torch.Tensor, info: Dict):
# you can optionally provide a dense reward function by returning a scalar value here. This is used when reward_mode="dense"
# note that as everything is batched, you must return a batch of self.num_envs rewards as done in the example below
return torch.zeros(self.num_envs, device=self.device)

-def compute_normalized_dense_reward(self, obs: Any, action: np.ndarray, info: Dict):
def compute_normalized_dense_reward(
self, obs: Any, action: torch.Tensor, info: Dict
):
# this should be equal to compute_dense_reward / max possible reward
max_reward = 1.0
return self.compute_dense_reward(obs=obs, action=action, info=info) / max_reward