From 4233fe651d7efa69058f6eb8781dfd6098a6192c Mon Sep 17 00:00:00 2001
From: Stone Tao
Date: Wed, 24 Jan 2024 16:30:21 -0800
Subject: [PATCH] batched uniform sampler, stack cube v1 task

---
 mani_skill2/agents/base_agent.py          |  19 ++-
 mani_skill2/agents/robots/panda/panda.py  |   4 +
 mani_skill2/envs/minimal_template.py      |   8 +-
 mani_skill2/envs/sapien_env.py            |  13 +-
 mani_skill2/envs/tasks/__init__.py        |   1 +
 mani_skill2/envs/tasks/fmb/fmb.py         |   6 +-
 mani_skill2/envs/tasks/pick_cube.py       |   9 +-
 mani_skill2/envs/tasks/stack_cube.py      | 122 +++++++++++++++---
 mani_skill2/envs/template.py              |   8 +-
 .../envs/utils/randomization/samplers.py  |  87 +++++++++++++
 mani_skill2/utils/common.py               |   5 +-
 mani_skill2/utils/structs/actor.py        |   9 ++
 manualtest/visual_all_envs_cpu.py         |  22 ++--
 13 files changed, 265 insertions(+), 48 deletions(-)

diff --git a/mani_skill2/agents/base_agent.py b/mani_skill2/agents/base_agent.py
index 2b14cc0b4..c090b3458 100644
--- a/mani_skill2/agents/base_agent.py
+++ b/mani_skill2/agents/base_agent.py
@@ -14,6 +14,7 @@
     check_urdf_config,
     parse_urdf_config,
 )
+from mani_skill2.utils.structs.actor import Actor
 from mani_skill2.utils.structs.articulation import Articulation
 
 from .controllers.base_controller import (
@@ -236,16 +237,28 @@ def reset(self, init_qpos=None):
     # -------------------------------------------------------------------------- #
     # Optional per-agent APIs, implemented depending on agent affordances
     # -------------------------------------------------------------------------- #
-    def is_grasping(self, object: Union[sapien.Entity, None] = None):
+    def is_grasping(self, object: Union[Actor, None] = None):
         """
         Check if this agent is grasping an object or grasping anything at all
 
         Args:
-            object (sapien.Entity | None):
-                If object is a sapien.Entity, this function checks grasping against that. If it is none, the function checks if the
+            object (Actor | None):
+                If object is an Actor, this function checks grasping against that object. If it is None, the function checks if the
                 agent is grasping anything at all.
 
         Returns:
             True if agent is grasping object. False otherwise. If object is None, returns True if agent is grasping something, False if agent is grasping nothing.
         """
         raise NotImplementedError()
+
+    def is_static(self, threshold: float):
+        """
+        Check if this robot is static (within the given threshold) in terms of its joint velocities (qvel)
+
+        Args:
+            threshold (float): The velocity threshold below which this agent is considered static
+
+        Returns:
+            True if agent is static within the threshold. False otherwise
+        """
+        raise NotImplementedError()
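For a concrete robot, is_static() is typically implemented by thresholding the simulator's joint velocities. Below is a minimal sketch of such an override (illustrative only, not part of the patch); it assumes the agent's articulation exposes a batched get_qvel() tensor of shape (num_envs, dof). The Panda implementation in the next file follows the same pattern but additionally excludes the two gripper joints.

import torch

# Sketch of an is_static override for a hypothetical BaseAgent subclass.
def is_static(self, threshold: float = 0.2):
    qvel = self.robot.get_qvel()  # (num_envs, dof) joint velocities, assumed batched
    # static when the largest absolute joint velocity stays below the threshold
    return torch.max(torch.abs(qvel), dim=1)[0] <= threshold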
diff --git a/mani_skill2/agents/robots/panda/panda.py b/mani_skill2/agents/robots/panda/panda.py
index 3ac037de4..15291b173 100644
--- a/mani_skill2/agents/robots/panda/panda.py
+++ b/mani_skill2/agents/robots/panda/panda.py
@@ -292,6 +292,10 @@ def is_grasping(self, object: Actor = None, min_impulse=1e-6, max_angle=85):
 
         return all([lflag, rflag])
 
+    def is_static(self, threshold: float = 0.2):
+        qvel = self.robot.get_qvel()[..., :-2]
+        return torch.max(torch.abs(qvel), 1)[0] <= threshold
+
     @staticmethod
     def build_grasp_pose(approaching, closing, center):
         """Build a grasp pose (panda_hand_tcp)."""
diff --git a/mani_skill2/envs/minimal_template.py b/mani_skill2/envs/minimal_template.py
index b59f6ecb2..d3f4a30f3 100644
--- a/mani_skill2/envs/minimal_template.py
+++ b/mani_skill2/envs/minimal_template.py
@@ -10,7 +10,7 @@
 from mani_skill2.utils.sapien_utils import look_at
 
 
-@register_env(name="CustomEnv-v0", max_episode_steps=200)
+@register_env("CustomEnv-v0", max_episode_steps=200)
 class CustomEnv(BaseEnv):
     """
     Task Description
@@ -52,9 +52,11 @@ def _get_obs_extra(self):
     def evaluate(self, obs: Any):
         return {"success": torch.zeros(self.num_envs, device=self.device, dtype=bool)}
 
-    def compute_dense_reward(self, obs: Any, action: np.ndarray, info: Dict):
+    def compute_dense_reward(self, obs: Any, action: torch.Tensor, info: Dict):
         return torch.zeros(self.num_envs, device=self.device)
 
-    def compute_normalized_dense_reward(self, obs: Any, action: np.ndarray, info: Dict):
+    def compute_normalized_dense_reward(
+        self, obs: Any, action: torch.Tensor, info: Dict
+    ):
         max_reward = 1.0
         return self.compute_dense_reward(obs=obs, action=action, info=info) / max_reward
diff --git a/mani_skill2/envs/sapien_env.py b/mani_skill2/envs/sapien_env.py
index 4c525cd52..90591926b 100644
--- a/mani_skill2/envs/sapien_env.py
+++ b/mani_skill2/envs/sapien_env.py
@@ -415,7 +415,7 @@ def robot_link_ids(self):
     def reward_mode(self):
         return self._reward_mode
 
-    def get_reward(self, obs: Any, action: Array, info: Dict):
+    def get_reward(self, obs: Any, action: torch.Tensor, info: Dict):
         if self._reward_mode == "sparse":
             reward = info["success"]
         elif self._reward_mode == "dense":
@@ -428,10 +428,12 @@ def get_reward(self, obs: Any, action: Array, info: Dict):
             raise NotImplementedError(self._reward_mode)
         return reward
 
-    def compute_dense_reward(self, obs: Any, action: Array, info: Dict):
+    def compute_dense_reward(self, obs: Any, action: torch.Tensor, info: Dict):
         raise NotImplementedError
 
-    def compute_normalized_dense_reward(self, obs: Any, action: Array, info: Dict):
+    def compute_normalized_dense_reward(
+        self, obs: Any, action: torch.Tensor, info: Dict
+    ):
         raise NotImplementedError
 
     # -------------------------------------------------------------------------- #
@@ -634,7 +636,7 @@ def _clear_sim_state(self):
     # -------------------------------------------------------------------------- #
 
     def step(self, action: Union[None, np.ndarray, Dict]):
-        self.step_action(action)
+        action = self.step_action(action)
         self._elapsed_steps += 1
         obs = self.get_obs()
         info = self.get_info(obs=obs)
@@ -652,7 +654,7 @@ def step(self, action: Union[None, np.ndarray, Dict]):
             to_numpy(info),
         )
 
-    def step_action(self, action):
+    def step_action(self, action) -> Union[None, torch.Tensor]:
         set_action = False
         if action is None:  # simulation without action
             pass
@@ -681,6 +683,7 @@ def step_action(self, action):
         self._after_simulation_step()
         if physx.is_gpu_enabled():
             self._scene._gpu_fetch_all()
+        return action
 
     def evaluate(self, **kwargs) -> dict:
         """Evaluate whether the environment is currently in a success state."""
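step_action() now returns the processed action so that step() can keep working with the batched, tensor-valued action afterwards, matching get_reward() above, which now expects a torch.Tensor action. A rough sketch of that flow (illustrative, not a verbatim excerpt of step()):

# Sketch: how the value returned by step_action() can be reused inside step().
action = self.step_action(action)  # may come back converted to a batched torch.Tensor
self._elapsed_steps += 1
obs = self.get_obs()
info = self.get_info(obs=obs)
reward = self.get_reward(obs=obs, action=action, info=info)  # reuses the returned action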
diff --git a/mani_skill2/envs/tasks/__init__.py b/mani_skill2/envs/tasks/__init__.py
index ccee7261a..4fa3b6f3f 100644
--- a/mani_skill2/envs/tasks/__init__.py
+++ b/mani_skill2/envs/tasks/__init__.py
@@ -2,3 +2,4 @@
 from .pick_cube import PickCubeEnv
 from .push_cube import PushCubeEnv
 from .push_object import PushObjectEnv
+from .stack_cube import StackCubeEnv
diff --git a/mani_skill2/envs/tasks/fmb/fmb.py b/mani_skill2/envs/tasks/fmb/fmb.py
index 2c8a926e1..30b244f99 100644
--- a/mani_skill2/envs/tasks/fmb/fmb.py
+++ b/mani_skill2/envs/tasks/fmb/fmb.py
@@ -203,12 +203,14 @@ def evaluate(self, obs: Any):
         # for the task. You may also include additional keys which will populate the info object returned by self.step
         return {"success": [False]}
 
-    def compute_dense_reward(self, obs: Any, action: np.ndarray, info: Dict):
+    def compute_dense_reward(self, obs: Any, action: torch.Tensor, info: Dict):
         # you can optionally provide a dense reward function by returning a scalar value here. This is used when reward_mode="dense"
         reward = torch.zeros(self.num_envs, device=self.device)
         return reward
 
-    def compute_normalized_dense_reward(self, obs: Any, action: np.ndarray, info: Dict):
+    def compute_normalized_dense_reward(
+        self, obs: Any, action: torch.Tensor, info: Dict
+    ):
         # this should be equal to compute_dense_reward / max possible reward
         max_reward = 1.0
         return self.compute_dense_reward(obs=obs, action=action, info=info) / max_reward
diff --git a/mani_skill2/envs/tasks/pick_cube.py b/mani_skill2/envs/tasks/pick_cube.py
index 46ce4e854..824c45b8c 100644
--- a/mani_skill2/envs/tasks/pick_cube.py
+++ b/mani_skill2/envs/tasks/pick_cube.py
@@ -105,15 +105,14 @@ def evaluate(self, obs: Any):
             torch.linalg.norm(self.goal_site.pose.p - self.cube.pose.p, axis=1)
             <= self.goal_thresh
         )
-        qvel = self.agent.robot.get_qvel()[..., :-2]
-        is_robot_static = torch.max(torch.abs(qvel), 1)[0] <= 0.2
+        is_robot_static = self.agent.is_static(0.2)
         return {
             "success": torch.logical_and(is_obj_placed, is_robot_static),
             "is_obj_placed": is_obj_placed,
             "is_robot_static": is_robot_static,
         }
 
-    def compute_dense_reward(self, obs: Any, action: np.ndarray, info: Dict):
+    def compute_dense_reward(self, obs: Any, action: torch.Tensor, info: Dict):
         tcp_to_obj_dist = torch.linalg.norm(
             self.cube.pose.p - self.agent.tcp.pose.p, axis=1
         )
@@ -137,5 +136,7 @@ def compute_dense_reward(self, obs: Any, action: np.ndarray, info: Dict):
         reward[info["success"]] = 5
         return reward
 
-    def compute_normalized_dense_reward(self, obs: Any, action: np.ndarray, info: Dict):
+    def compute_normalized_dense_reward(
+        self, obs: Any, action: torch.Tensor, info: Dict
+    ):
         return self.compute_dense_reward(obs=obs, action=action, info=info) / 5
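PickCubeEnv.evaluate() now delegates the robot-static check to the new Agent.is_static() API instead of repeating the qvel test inline. The two are intended to be equivalent; a small numeric illustration with made-up joint velocities (the last two entries stand in for the Panda gripper joints):

import torch

qvel = torch.tensor([[0.05, -0.10, 0.02, 0.00, 0.01, 0.00, 0.03, 0.50, 0.50]])
# Previous inline check in PickCubeEnv.evaluate():
is_robot_static = torch.max(torch.abs(qvel[..., :-2]), 1)[0] <= 0.2
print(is_robot_static)  # tensor([True]); Panda.is_static(0.2) computes the same thing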
diff --git a/mani_skill2/envs/tasks/stack_cube.py b/mani_skill2/envs/tasks/stack_cube.py
index bc937c35b..01e8ec2eb 100644
--- a/mani_skill2/envs/tasks/stack_cube.py
+++ b/mani_skill2/envs/tasks/stack_cube.py
@@ -6,15 +6,16 @@
 
 from mani_skill2.envs.sapien_env import BaseEnv
 from mani_skill2.envs.utils.randomization.pose import random_quaternions
+from mani_skill2.envs.utils.randomization.samplers import UniformSampler
 from mani_skill2.sensors.camera import CameraConfig
 from mani_skill2.utils.building.actors import build_cube
 from mani_skill2.utils.registration import register_env
-from mani_skill2.utils.sapien_utils import look_at
+from mani_skill2.utils.sapien_utils import look_at, to_tensor
 from mani_skill2.utils.scene_builder.table.table_scene_builder import TableSceneBuilder
 from mani_skill2.utils.structs.pose import Pose
 
 
-@register_env(name="StackCube-v1", max_episode_steps=100)
+@register_env("StackCube-v1", max_episode_steps=100)
 class StackCubeEnv(BaseEnv):
     def __init__(self, *args, robot_uid="panda", robot_init_qpos_noise=0.02, **kwargs):
         self.robot_init_qpos_noise = robot_init_qpos_noise
@@ -31,11 +32,11 @@ def _register_render_cameras(self):
         return CameraConfig("render_camera", pose.p, pose.q, 512, 512, 1, 0.01, 10)
 
     def _load_actors(self):
+        self.cube_half_size = to_tensor([0.02] * 3)
         self.table_scene = TableSceneBuilder(
             env=self, robot_init_qpos_noise=self.robot_init_qpos_noise
         )
         self.table_scene.build()
-        self.box_half_size = torch.Tensor([0.02] * 3, device=self.device)
         self.cubeA = build_cube(
             self._scene, half_size=0.02, color=[1, 0, 0, 1], name="cubeA"
         )
@@ -44,22 +45,111 @@ def _load_actors(self):
         )
 
     def _initialize_actors(self):
-        self.table_scene.initialize()
-        qs = random_quaternions(
-            self._episode_rng, lock_x=True, lock_y=True, lock_z=False, n=self.num_envs
-        )
-        ps = [0, 0, 0.02]
-        self.cubeA.set_pose(Pose.create_from_pq(p=ps, q=qs))
+        with torch.device(self.device):
+            self.table_scene.initialize()
+
+            xyz = torch.zeros((self.num_envs, 3))
+            xyz[:, 2] = 0.02
+            xy = torch.rand((self.num_envs, 2)) * 0.2 - 0.1
+            region = [[-0.1, -0.2], [0.1, 0.2]]
+            sampler = UniformSampler(bounds=region, batch_size=self.num_envs)
+            radius = (torch.linalg.norm(torch.Tensor([0.02, 0.02])) + 0.001).to(
+                self.device
+            )
+            cubeA_xy = xy + sampler.sample(radius, 100)
+            cubeB_xy = xy + sampler.sample(radius, 100, verbose=False)
+
+            xyz[:, :2] = cubeA_xy
+            qs = random_quaternions(
+                self.num_envs,
+                lock_x=True,
+                lock_y=True,
+                lock_z=False,
+            )
+            self.cubeA.set_pose(Pose.create_from_pq(p=xyz.clone(), q=qs))
+
+            xyz[:, :2] = cubeB_xy
+            qs = random_quaternions(
+                self.num_envs,
+                lock_x=True,
+                lock_y=True,
+                lock_z=False,
+            )
+            self.cubeB.set_pose(Pose.create_from_pq(p=xyz, q=qs))
 
     def _get_obs_extra(self):
-        return OrderedDict()
+        obs = OrderedDict(tcp_pose=self.agent.tcp.pose.raw_pose)
+        if "state" in self.obs_mode:
+            obs.update(
+                cubeA_pose=self.cubeA.pose.raw_pose,
+                cubeB_pose=self.cubeB.pose.raw_pose,
+                tcp_to_cubeA_pos=self.cubeA.pose.p - self.agent.tcp.pose.p,
+                tcp_to_cubeB_pos=self.cubeB.pose.p - self.agent.tcp.pose.p,
+                cubeA_to_cubeB_pos=self.cubeB.pose.p - self.cubeA.pose.p,
+            )
+        return obs
 
     def evaluate(self, obs: Any):
-        return {"success": torch.zeros(self.num_envs, device=self.device, dtype=bool)}
+        pos_A = self.cubeA.pose.p
+        pos_B = self.cubeB.pose.p
+        offset = pos_A - pos_B
+        xy_flag = (
+            torch.linalg.norm(offset[..., :2], axis=1)
+            <= torch.linalg.norm(self.cube_half_size[:2]) + 0.005
+        )
+        z_flag = torch.abs(offset[..., 2] - self.cube_half_size[..., 2] * 2) <= 0.005
+        is_cubeA_on_cubeB = torch.logical_and(xy_flag, z_flag)
+        is_cubeA_static = self.cubeA.is_static()
+        is_cubeA_grasped = self.agent.is_grasping(self.cubeA)
+
+        success = is_cubeA_on_cubeB * is_cubeA_static * ~is_cubeA_grasped
+
+        return {
+            "is_cubeA_grasped": is_cubeA_grasped,
+            "is_cubeA_on_cubeB": is_cubeA_on_cubeB,
+            "is_cubeA_static": is_cubeA_static,
+            "success": success,
+        }
+
+    def compute_dense_reward(self, obs: Any, action: torch.Tensor, info: Dict):
+        # reaching reward
+        tcp_pose = self.agent.tcp.pose.p
+        cubeA_pos = self.cubeA.pose.p
+        cubeA_to_tcp_dist = torch.linalg.norm(tcp_pose - cubeA_pos, axis=1)
+        reward = 2 * (1 - torch.tanh(5 * cubeA_to_tcp_dist))
+
+        # grasp and place reward
+        cubeA_pos = self.cubeA.pose.p
+        cubeB_pos = self.cubeB.pose.p
+        goal_xyz = torch.hstack(
+            [cubeB_pos[:, 0:2], (cubeB_pos[:, 2] + self.cube_half_size[2] * 2)[:, None]]
+        )
+        cubeA_to_goal_dist = torch.linalg.norm(goal_xyz - cubeA_pos, axis=1)
+        place_reward = 1 - torch.tanh(5.0 * cubeA_to_goal_dist)
+
+        reward[info["is_cubeA_grasped"]] = (4 + place_reward)[info["is_cubeA_grasped"]]
+
+        # ungrasp and static reward
+        gripper_width = (self.agent.robot.get_qlimits()[-1, 1] * 2).to(
+            self.device
+        )  # NOTE: hard-coded with panda
+        is_cubeA_grasped = info["is_cubeA_grasped"]
+        ungrasp_reward = (
+            torch.sum(self.agent.robot.get_qpos()[:, -2:], axis=1) / gripper_width
+        )
+        ungrasp_reward[~is_cubeA_grasped] = 1.0
+        v = torch.linalg.norm(self.cubeA.linear_velocity, axis=1)
+        av = torch.linalg.norm(self.cubeA.angular_velocity, axis=1)
+        static_reward = 1 - torch.tanh(v * 10 + av)
+        reward[info["is_cubeA_on_cubeB"]] = (
+            6 + (ungrasp_reward + static_reward) / 2.0
+        )[info["is_cubeA_on_cubeB"]]
+
+        reward[info["success"]] = 8
 
-    def compute_dense_reward(self, obs: Any, action: np.ndarray, info: Dict):
-        return torch.zeros(self.num_envs, device=self.device)
+        return reward
 
-    def compute_normalized_dense_reward(self, obs: Any, action: np.ndarray, info: Dict):
-        max_reward = 1.0
-        return self.compute_dense_reward(obs=obs, action=action, info=info) / max_reward
+    def compute_normalized_dense_reward(
+        self, obs: Any, action: torch.Tensor, info: Dict
+    ):
+        return self.compute_dense_reward(obs=obs, action=action, info=info) / 8
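Success in StackCube-v1 requires cubeA to rest on top of cubeB within half-size-based tolerances, to be static (via the Actor.is_static helper added later in this patch), and to no longer be grasped. A toy numeric check of just the geometric test, assuming the 0.02 m cube half size used above:

import torch

cube_half_size = torch.tensor([0.02, 0.02, 0.02])
pos_A = torch.tensor([[0.005, 0.000, 0.060]])  # cubeA roughly centered on top of cubeB
pos_B = torch.tensor([[0.000, 0.000, 0.020]])
offset = pos_A - pos_B
xy_flag = torch.linalg.norm(offset[..., :2], axis=1) <= torch.linalg.norm(cube_half_size[:2]) + 0.005
z_flag = torch.abs(offset[..., 2] - cube_half_size[2] * 2) <= 0.005
print(torch.logical_and(xy_flag, z_flag))  # tensor([True])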
diff --git a/mani_skill2/envs/template.py b/mani_skill2/envs/template.py
index 004af09d1..38670d587 100644
--- a/mani_skill2/envs/template.py
+++ b/mani_skill2/envs/template.py
@@ -33,7 +33,7 @@
 
 
 # register the environment by a unique ID and specify a max time limit. Now once this file is imported you can do gym.make("CustomEnv-v0")
-@register_env(name="CustomEnv-v0", max_episode_steps=200)
+@register_env("CustomEnv-v0", max_episode_steps=200)
 class CustomEnv(BaseEnv):
     """
     Task Description
@@ -138,12 +138,14 @@ def evaluate(self, obs: Any):
         # note that as everything is batched, you must return a batched array of self.num_envs booleans (or 0/1 values) as done in the example below
         return {"success": torch.zeros(self.num_envs, device=self.device, dtype=bool)}
 
-    def compute_dense_reward(self, obs: Any, action: np.ndarray, info: Dict):
+    def compute_dense_reward(self, obs: Any, action: torch.Tensor, info: Dict):
         # you can optionally provide a dense reward function by returning a scalar value here. This is used when reward_mode="dense"
         # note that as everything is batched, you must return a batch of of self.num_envs rewards as done in the example below
         return torch.zeros(self.num_envs, device=self.device)
 
-    def compute_normalized_dense_reward(self, obs: Any, action: np.ndarray, info: Dict):
+    def compute_normalized_dense_reward(
+        self, obs: Any, action: torch.Tensor, info: Dict
+    ):
         # this should be equal to compute_dense_reward / max possible reward
         max_reward = 1.0
         return self.compute_dense_reward(obs=obs, action=action, info=info) / max_reward
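Once this module is imported, the registered environment can be constructed through gymnasium as usual; note that register_env now takes the environment ID as a positional argument. A usage sketch (the keyword arguments are illustrative and mirror the ones used in manualtest/visual_all_envs_cpu.py at the end of this patch; the exact supported kwargs depend on BaseEnv):

import gymnasium as gym

import mani_skill2.envs.template  # noqa: F401, runs @register_env("CustomEnv-v0", ...)

env = gym.make("CustomEnv-v0", num_envs=4, obs_mode="state")
obs, info = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())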
diff --git a/mani_skill2/envs/utils/randomization/samplers.py b/mani_skill2/envs/utils/randomization/samplers.py
index e69de29bb..7eda9bbd5 100644
--- a/mani_skill2/envs/utils/randomization/samplers.py
+++ b/mani_skill2/envs/utils/randomization/samplers.py
@@ -0,0 +1,87 @@
+from typing import List, Tuple
+
+import torch
+
+from mani_skill2.utils.sapien_utils import to_tensor
+
+
+class UniformSampler:
+    """Uniform placement sampler that supports batched sampling
+
+    Args:
+        bounds: ((low1, low2, ...), (high1, high2, ...))
+        batch_size (int): The number of points to sample with each call to sample(...)
+    """
+
+    def __init__(
+        self, bounds: Tuple[List[float], List[float]], batch_size: int
+    ) -> None:
+        assert len(bounds) == 2 and len(bounds[0]) == len(bounds[1])
+        self._bounds = to_tensor(bounds)
+        self._ranges = self._bounds[1] - self._bounds[0]
+        self.fixtures_radii = None
+        self.fixture_positions = None
+        self.batch_size = batch_size
+
+    def sample(self, radius, max_trials, append=True, verbose=False):
+        """Sample a batch of positions.
+
+        Args:
+            radius (float): collision radius; samples must be farther than this radius plus each fixture's radius from all previously appended fixtures.
+            max_trials (int): maximum number of sampling attempts.
+            append (bool, optional): whether to append the new samples to the fixtures. Defaults to True.
+            verbose (bool, optional): whether to print sampling progress. Defaults to False.
+
+        Returns:
+            torch.Tensor: sampled positions with shape (batch_size, d).
+        """
+        if self.fixture_positions is None:
+            sampled_pos = (
+                torch.rand((self.batch_size, self._bounds.shape[1])) * self._ranges
+                + self._bounds[0]
+            )
+        else:
+            pass_mask = torch.zeros((self.batch_size), dtype=bool)
+            sampled_pos = torch.zeros((self.batch_size, self._bounds.shape[1]))
+            for i in range(max_trials):
+                pos = (
+                    torch.rand((self.batch_size, self._bounds.shape[1])) * self._ranges
+                    + self._bounds[0]
+                )  # (B, d)
+                dist = torch.linalg.norm(
+                    pos - self.fixture_positions, axis=-1
+                )  # (n, B)
+                radii = self.fixtures_radii + radius  # (n, )
+                mask = torch.all(dist > radii[:, None], axis=0)  # (B, )
+                sampled_pos[mask] = pos[mask]
+                pass_mask[mask] = True
+                if torch.all(pass_mask):
+                    if verbose:
+                        print(
+                            f"Found valid set of {self.batch_size=} samples at {i}-th trial"
+                        )
+                    break
+            else:
+                if verbose:
+                    print("Fail to find a valid sample!")
+        if append:
+            if self.fixture_positions is None:
+                self.fixture_positions = sampled_pos[None, ...]
+            else:
+                self.fixture_positions = torch.concat(
+                    [self.fixture_positions, sampled_pos[None, ...]]
+                )
+            if self.fixtures_radii is None:
+                self.fixtures_radii = to_tensor(radius).reshape(
+                    1,
+                )
+            else:
+                self.fixtures_radii = torch.concat(
+                    [
+                        self.fixtures_radii,
+                        to_tensor(radius).reshape(
+                            1,
+                        ),
+                    ]
+                )
+        return sampled_pos
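Standalone usage of the batched sampler, mirroring how StackCube-v1 keeps the two cubes from spawning on top of each other (the bounds and collision radius follow the task code above; this sketch assumes a CPU-simulation context where to_tensor returns CPU tensors):

import torch

from mani_skill2.envs.utils.randomization.samplers import UniformSampler

sampler = UniformSampler(bounds=[[-0.1, -0.2], [0.1, 0.2]], batch_size=16)
radius = torch.linalg.norm(torch.tensor([0.02, 0.02])) + 0.001
xy_a = sampler.sample(radius, max_trials=100)                 # appended as a fixture for later calls
xy_b = sampler.sample(radius, max_trials=100, verbose=False)  # rejection-sampled away from xy_a
print(xy_a.shape, xy_b.shape)  # torch.Size([16, 2]) torch.Size([16, 2])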
diff --git a/mani_skill2/utils/common.py b/mani_skill2/utils/common.py
index 7be50d984..63156f7d7 100644
--- a/mani_skill2/utils/common.py
+++ b/mani_skill2/utils/common.py
@@ -8,6 +8,7 @@
 from gymnasium import spaces
 
 from mani_skill2.utils.sapien_utils import to_tensor
+from mani_skill2.utils.structs.types import Array
 
 from .logging_utils import logger
 
@@ -171,7 +172,7 @@ def inv_scale_action(action, low, high):
 
 
 # TODO (stao): Clean up this code
-def flatten_state_dict(state_dict: dict, squeeze_dims: bool = False) -> np.ndarray:
+def flatten_state_dict(state_dict: dict, squeeze_dims: bool = False) -> Array:
     """Flatten a dictionary containing states recursively.
 
     Args:
@@ -228,7 +229,7 @@ def flatten_state_dict(state_dict: dict, squeeze_dims: bool = False) -> np.ndarr
 
     if physx.is_gpu_enabled():
         if len(states) == 0:
-            return torch.empty(0)
+            return torch.empty(0, device="cuda")
         else:
             return torch.hstack(states)
     else:
diff --git a/mani_skill2/utils/structs/actor.py b/mani_skill2/utils/structs/actor.py
index 27eb671e0..ffa079bde 100644
--- a/mani_skill2/utils/structs/actor.py
+++ b/mani_skill2/utils/structs/actor.py
@@ -135,6 +135,15 @@ def show_visual(self):
             ).visibility = 1
         self.hidden = False
 
+    def is_static(self, lin_thresh=1e-3, ang_thresh=1e-2):
+        """
+        Checks if this actor is static within the given linear velocity threshold `lin_thresh` and angular velocity threshold `ang_thresh`
+        """
+        return torch.logical_and(
+            torch.linalg.norm(self.linear_velocity, axis=1) <= lin_thresh,
+            torch.linalg.norm(self.angular_velocity, axis=1) <= ang_thresh,
+        )
+
     # -------------------------------------------------------------------------- #
     # Exposed actor properties, getters/setters that automatically handle
     # CPU and GPU based actors
diff --git a/manualtest/visual_all_envs_cpu.py b/manualtest/visual_all_envs_cpu.py
index 14e567466..382012f49 100644
--- a/manualtest/visual_all_envs_cpu.py
+++ b/manualtest/visual_all_envs_cpu.py
@@ -7,8 +7,8 @@
 
 if __name__ == "__main__":
     # , "StackCube-v0", "LiftCube-v0"
-    num_envs = 2
-    for env_id in ["PushCube-v0"]:  # , "StackCube-v0", "LiftCube-v0"]:
+    num_envs = 4
+    for env_id in ["StackCube-v1"]:  # , "StackCube-v0", "LiftCube-v0"]:
         env = gym.make(
             env_id,
             num_envs=num_envs,
@@ -20,15 +20,17 @@
             sim_freq=500,
             control_freq=100,
         )
-        env = RecordEpisode(
-            env,
-            output_dir="videos/manual_test",
-            trajectory_name=f"{env_id}",
-            info_on_video=False,
-            video_fps=30,
-            save_trajectory=False,
-        )
+        # env = RecordEpisode(
+        #     env,
+        #     output_dir="videos/manual_test",
+        #     trajectory_name=f"{env_id}",
+        #     info_on_video=False,
+        #     video_fps=30,
+        #     save_trajectory=False,
+        # )
         env.reset(seed=0)
+        env.reset(seed=1)
+        env.reset(seed=2)
         done = False
         i = 0
         if num_envs == 1: