From 9a6710f45b2eb03df634e1bc979eb0cd74378ccc Mon Sep 17 00:00:00 2001
From: StoneT2000
Date: Thu, 25 Jan 2024 01:02:59 -0800
Subject: [PATCH] WIP work

---
 examples/baselines/ppo/ppo.py                 |   1 +
 .../robots/mobile_panda/base_mobile_agent.py  |  16 +--
 .../mobile_panda/mobile_panda_dual_arm.py     |  48 +++----
 mani_skill2/envs/tasks/__init__.py            |   1 +
 mani_skill2/envs/tasks/open_cabinet_drawer.py | 121 ++++++++++++++++++
 mani_skill2/utils/building/articulations.py   |  55 ++++----
 mani_skill2/utils/building/urdf_loader.py     |  14 +-
 mani_skill2/utils/structs/articulation.py     |  25 +++-
 manualtest/visual_all_envs_cpu.py             |  15 ++-
 9 files changed, 228 insertions(+), 68 deletions(-)
 create mode 100644 mani_skill2/envs/tasks/open_cabinet_drawer.py

diff --git a/examples/baselines/ppo/ppo.py b/examples/baselines/ppo/ppo.py
index e0c7d06bc..69530f46d 100644
--- a/examples/baselines/ppo/ppo.py
+++ b/examples/baselines/ppo/ppo.py
@@ -264,6 +264,7 @@ def clip_action(action: torch.Tensor):
             info = infos["final_info"]
             episodic_return = info['episode']['r'].mean().cpu().numpy()
             print(f"global_step={global_step}, episodic_return={episodic_return}")
+            writer.add_scalar("charts/is_cubeA_on_cubeB", info["is_cubeA_on_cubeB"].float().mean().cpu().numpy(), global_step)
             writer.add_scalar("charts/success_rate", info["success"].float().mean().cpu().numpy(), global_step)
             writer.add_scalar("charts/episodic_return", episodic_return, global_step)
             writer.add_scalar("charts/episodic_length", info["elapsed_steps"], global_step)
diff --git a/mani_skill2/agents/robots/mobile_panda/base_mobile_agent.py b/mani_skill2/agents/robots/mobile_panda/base_mobile_agent.py
index 485a57619..5c41cd610 100644
--- a/mani_skill2/agents/robots/mobile_panda/base_mobile_agent.py
+++ b/mani_skill2/agents/robots/mobile_panda/base_mobile_agent.py
@@ -36,18 +36,18 @@ def _after_init(self):
         self.base_link = self.robot.get_links()[3]

         # Ignore collision between the adjustable body and ground
-        body = get_obj_by_name(self.robot.get_links(), "adjustable_body")
-
-        s = body.get_collision_shapes()[0]
-        gs = s.get_collision_groups()
-        gs[2] = gs[2] | 1 << 30
-        s.set_collision_groups(gs)
+        bodies = get_obj_by_name(self.robot.get_links(), "adjustable_body")
+        for body in bodies._objs:
+            s = body.get_collision_shapes()[0]
+            gs = s.get_collision_groups()
+            gs[2] = gs[2] | 1 << 30
+            s.set_collision_groups(gs)

     def get_proprioception(self):
         state_dict = super().get_proprioception()
         qpos, qvel = state_dict["qpos"], state_dict["qvel"]
-        base_pos, base_orientation, arm_qpos = qpos[:2], qpos[2], qpos[3:]
-        base_vel, base_ang_vel, arm_qvel = qvel[:2], qvel[2], qvel[3:]
+        base_pos, base_orientation, arm_qpos = qpos[:, :2], qpos[:, 2], qpos[:, 3:]
+        base_vel, base_ang_vel, arm_qvel = qvel[:, :2], qvel[:, 2], qvel[:, 3:]

         state_dict["qpos"] = arm_qpos
         state_dict["qvel"] = arm_qvel
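The `_after_init` hunk above now loops over every parallel physx body because `get_obj_by_name` returns a batched view in this refactor. A minimal sketch of that collision-filtering pattern; the helper name `ignore_ground_collisions` is hypothetical, and the `._objs` attribute is taken from the patch itself:

def ignore_ground_collisions(link, bit=30):
    # each parallel sub-scene owns one physx body; mask them all identically
    for body in link._objs:
        shape = body.get_collision_shapes()[0]
        groups = shape.get_collision_groups()
        # sharing a bit in group word 2 with the ground disables that pair's collisions
        groups[2] |= 1 << bit
        shape.set_collision_groups(groups)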
diff --git a/mani_skill2/agents/robots/mobile_panda/mobile_panda_dual_arm.py b/mani_skill2/agents/robots/mobile_panda/mobile_panda_dual_arm.py
index ee31e0d7e..c63120b25 100644
--- a/mani_skill2/agents/robots/mobile_panda/mobile_panda_dual_arm.py
+++ b/mani_skill2/agents/robots/mobile_panda/mobile_panda_dual_arm.py
@@ -290,29 +290,31 @@ def controller_configs(self):
         # Make a deepcopy in case users modify any config
         return deepcopy_dict(controller_configs)

+    sensor_configs = []
+
+    # TODO (stao): Remove this @property and make sensor configs completely statically defined.
+    # A robot's physical configuration is not normally expected to change at runtime. If you do
+    # want to adapt the robot, inherit the robot class and make the appropriate changes there.
-    @property
-    def sensor_configs(self):
-        sensors = []
-        qs = [
-            [0.9238795, 0, 0.3826834, 0],
-            [0.46193977, 0.33141357, 0.19134172, -0.80010315],
-            [-0.46193977, 0.33141357, -0.19134172, -0.80010315],
-        ]
-        for i in range(3):
-            q = qs[i]
-            camera = CameraConfig(
-                f"overhead_camera_{i}",
-                p=[0, 0, self.camera_h],
-                q=q,
-                width=400,
-                height=160,
-                near=0.1,
-                far=10,
-                fov=np.pi / 3,
-                entity_uid="mobile_base",
-            )
-            sensors.append(camera)
-        return sensors
+    # @property
+    # def sensor_configs(self):
+    #     sensors = []
+    #     qs = [
+    #         [0.9238795, 0, 0.3826834, 0],
+    #         [0.46193977, 0.33141357, 0.19134172, -0.80010315],
+    #         [-0.46193977, 0.33141357, -0.19134172, -0.80010315],
+    #     ]
+    #     for i in range(3):
+    #         q = qs[i]
+    #         camera = CameraConfig(
+    #             f"overhead_camera_{i}",
+    #             p=[0, 0, self.camera_h],
+    #             q=q,
+    #             width=400,
+    #             height=160,
+    #             near=0.1,
+    #             far=10,
+    #             fov=np.pi / 3,
+    #             entity_uid="mobile_base",
+    #         )
+    #         sensors.append(camera)
+    #     return sensors
diff --git a/mani_skill2/envs/tasks/__init__.py b/mani_skill2/envs/tasks/__init__.py
index 7fb0c0b1e..2be2ec6c3 100644
--- a/mani_skill2/envs/tasks/__init__.py
+++ b/mani_skill2/envs/tasks/__init__.py
@@ -1,4 +1,5 @@
 from .fmb.fmb import FMBEnv
+from .open_cabinet_drawer import OpenCabinetEnv
 from .pick_cube import PickCubeEnv
 from .pick_single_ycb import PickSingleYCBEnv
 from .push_cube import PushCubeEnv
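With the import added above, the `register_env` decorator in the new file below registers `OpenCabinet-v1`, after which it can be constructed like any other ManiSkill env. A usage sketch mirroring the kwargs exercised in `manualtest/visual_all_envs_cpu.py` at the end of this patch (assumes the package is installed and registration runs on import):

import gymnasium as gym

import mani_skill2.envs.tasks  # noqa: F401 -- runs the register_env decorators

env = gym.make("OpenCabinet-v1", num_envs=4, reward_mode="normalized_dense")
obs, _ = env.reset(seed=0)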
diff --git a/mani_skill2/envs/tasks/open_cabinet_drawer.py b/mani_skill2/envs/tasks/open_cabinet_drawer.py
new file mode 100644
index 000000000..19782379d
--- /dev/null
+++ b/mani_skill2/envs/tasks/open_cabinet_drawer.py
@@ -0,0 +1,121 @@
+from collections import OrderedDict
+from typing import Any, Dict
+
+import numpy as np
+import torch
+
+from mani_skill2.envs.sapien_env import BaseEnv
+from mani_skill2.sensors.camera import CameraConfig
+from mani_skill2.utils.building.articulations import (
+    MODEL_DBS,
+    _load_partnet_mobility_dataset,
+    build_preprocessed_partnet_mobility_articulation,
+)
+from mani_skill2.utils.building.ground import build_tesselated_square_floor
+from mani_skill2.utils.registration import register_env
+from mani_skill2.utils.sapien_utils import look_at
+from mani_skill2.utils.structs.articulation import Articulation
+from mani_skill2.utils.structs.pose import Pose
+
+
+@register_env("OpenCabinet-v1", max_episode_steps=200)
+class OpenCabinetEnv(BaseEnv):
+    """
+    Task Description
+    ----------------
+    Add a task description here
+
+    Randomizations
+    --------------
+
+    Success Conditions
+    ------------------
+
+    Visualization: link to a video/gif of the task being solved
+    """
+
+    def __init__(
+        self,
+        *args,
+        robot_uid="mobile_panda_single_arm",
+        robot_init_qpos_noise=0.02,
+        **kwargs,
+    ):
+        self.robot_init_qpos_noise = robot_init_qpos_noise
+        _load_partnet_mobility_dataset()
+        self.all_model_ids = np.array(
+            list(MODEL_DBS["PartnetMobility"]["model_data"].keys())
+        )
+        super().__init__(*args, robot_uid=robot_uid, **kwargs)
+
+    def _register_sensors(self):
+        pose = look_at(eye=[0.3, 0, 0.6], target=[-0.1, 0, 0.1])
+        return [
+            CameraConfig("base_camera", pose.p, pose.q, 128, 128, np.pi / 2, 0.01, 10)
+        ]
+
+    def _register_render_cameras(self):
+        pose = look_at(eye=[-1.5, -1.5, 1.5], target=[-0.1, 0, 0.1])
+        return CameraConfig("render_camera", pose.p, pose.q, 512, 512, 1, 0.01, 10)
+
+    def _load_actors(self):
+        model_ids = self.all_model_ids[: self.num_envs]
+        self.ground = build_tesselated_square_floor(self._scene)
+
+        cabinets = []
+        for i, model_id in enumerate(model_ids):
+            scene_mask = np.zeros(self.num_envs, dtype=bool)
+            scene_mask[i] = True
+            cabinet, metadata = build_preprocessed_partnet_mobility_articulation(
+                self._scene, model_id, name=f"{model_id}-{i}", scene_mask=scene_mask
+            )
+            cabinets.append(cabinet)
+        self.cabinet = Articulation.merge_articulations(cabinets, name="cabinet")
+        # TODO: all parallel envs currently share the metadata of the last loaded model
+        self.cabinet_metadata = metadata
+
+    def _initialize_actors(self):
+        # lift the cabinet so its (origin-centered) bounding box rests on the ground
+        height = (
+            self.cabinet_metadata.bbox.bounds[1, 2]
+            - self.cabinet_metadata.bbox.bounds[0, 2]
+        )
+        self.cabinet.set_pose(Pose.create_from_pq(p=[0, 0, height / 2]))
+        qlimits = self.cabinet.get_qlimits()  # [N, 2]
+        assert not np.isinf(qlimits).any(), qlimits
+        # NOTE(jigu): must use a contiguous array for `set_qpos`
+        qpos = np.ascontiguousarray(qlimits[:, 0])
+        self.cabinet.set_qpos(qpos)
+
+        # initialize robot
+        if self.robot_uid == "panda":
+            self.agent.robot.set_pose(Pose.create_from_pq(p=[-1, 0, 0]))
+        elif self.robot_uid == "mobile_panda_single_arm":
+            center = np.array([0, 0.8])
+            dist = self._episode_rng.uniform(1.6, 1.8)
+            theta = self._episode_rng.uniform(0.9 * np.pi, 1.1 * np.pi)
+            direction = np.array([np.cos(theta), np.sin(theta)])
+            xy = center + direction * dist
+
+            # Base orientation
+            noise_ori = self._episode_rng.uniform(-0.05 * np.pi, 0.05 * np.pi)
+            ori = (theta - np.pi) + noise_ori
+
+            h = 1e-4
+            arm_qpos = np.array([0, 0, 0, -1.5, 0, 3, 0.78, 0.02, 0.02])
+
+            qpos = np.hstack([xy, ori, h, arm_qpos])
+            self.agent.reset(qpos)
+
+    def _get_obs_extra(self):
+        return OrderedDict()
+
+    def evaluate(self, obs: Any):
+        # TODO: implement a real success check; stubbed to always report failure for now
+        return {"success": torch.zeros(self.num_envs, device=self.device, dtype=bool)}
+
+    def compute_dense_reward(self, obs: Any, action: torch.Tensor, info: Dict):
+        return torch.zeros(self.num_envs, device=self.device)
+
+    def compute_normalized_dense_reward(
+        self, obs: Any, action: torch.Tensor, info: Dict
+    ):
+        max_reward = 1.0
+        return self.compute_dense_reward(obs=obs, action=action, info=info) / max_reward
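`_load_actors` above builds a different cabinet in every parallel sub-scene by handing each builder a one-hot scene mask and merging the results afterwards. The masking step in isolation, as a small self-contained NumPy sketch (the generator name is hypothetical):

import numpy as np

def one_hot_scene_masks(num_envs):
    # mask i selects only sub-scene i, so sub-scene i gets its own cabinet model
    for i in range(num_envs):
        mask = np.zeros(num_envs, dtype=bool)
        mask[i] = True
        yield mask

# e.g. list(one_hot_scene_masks(2)) -> [array([ True, False]), array([False,  True])]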
diff --git a/mani_skill2/utils/building/articulations.py b/mani_skill2/utils/building/articulations.py
index b835d5a1a..5fe0f528d 100644
--- a/mani_skill2/utils/building/articulations.py
+++ b/mani_skill2/utils/building/articulations.py
@@ -9,9 +9,11 @@
 import sapien
 import sapien.physx as physx
 import sapien.render
+import trimesh
 from sapien import Pose

 from mani_skill2 import ASSET_DIR, PACKAGE_ASSET_DIR
+from mani_skill2.envs.scene import ManiSkillScene
 from mani_skill2.utils.geometry.trimesh_utils import (
     get_articulation_meshes,
     merge_meshes,
@@ -41,6 +43,7 @@ class ArticulationMetadata:
     links: Dict[str, LinkMetadata]
     # a list of all movable links
     movable_links: List[str]
+    bbox: trimesh.primitives.Box


 def build_articulation_from_file(
@@ -65,16 +68,17 @@ def build_articulation_from_file(


 # cache model metadata here if needed
-model_dbs: Dict[str, Dict[str, Dict]] = {}
+MODEL_DBS: Dict[str, Dict[str, Dict]] = {}


 ### Build articulations ###


 def build_preprocessed_partnet_mobility_articulation(
-    scene: sapien.Scene,
+    scene: ManiSkillScene,
     model_id: str,
+    name: str,
     fix_root_link=True,
     urdf_config: dict = None,
-    set_object_on_ground=True,
+    scene_mask=None,
 ):
     """
     Builds a physx.PhysxArticulation object into the scene and returns metadata containing annotations of the object's links and joints.
@@ -88,34 +92,38 @@ def build_preprocessed_partnet_mobility_articulation(
-        set_object_on_ground: whether to change the pose of the built articulation such that the object is settled on the ground (at z = 0)
+        name: name to give the built articulation
+        scene_mask: boolean mask over parallel sub-scenes selecting which sub-scenes this articulation is built in
     """
-    if "PartnetMobility" not in model_dbs:
+    if "PartnetMobility" not in MODEL_DBS:
         _load_partnet_mobility_dataset()
-    metadata = model_dbs["PartnetMobility"]["model_data"][model_id]
+    metadata = MODEL_DBS["PartnetMobility"]["model_data"][model_id]

     loader = scene.create_urdf_loader()
     loader.fix_root_link = fix_root_link
     loader.scale = metadata["scale"]
     loader.load_multiple_collisions_from_file = True

-    urdf_path = model_dbs["PartnetMobility"]["model_urdf_paths"][model_id]
+    urdf_path = MODEL_DBS["PartnetMobility"]["model_urdf_paths"][model_id]
     urdf_config = parse_urdf_config(urdf_config or {}, scene)
     apply_urdf_config(loader, urdf_config)
-    articulation: physx.PhysxArticulation = loader.load(str(urdf_path))
+    articulation = loader.load(str(urdf_path), name=name, scene_mask=scene_mask)

-    metadata = ArticulationMetadata(joints=dict(), links=dict(), movable_links=[])
+    metadata = ArticulationMetadata(
+        joints=dict(), links=dict(), movable_links=[], bbox=None
+    )

     # NOTE(jigu): links and their parent joints.
     for link, joint in zip(articulation.get_links(), articulation.get_joints()):
         metadata.joints[joint.name] = JointMetadata(type=joint.type, name="")

-        render_body = link.entity.find_component_by_type(
-            sapien.render.RenderBodyComponent
-        )
-        render_shapes = []
-        if render_body is not None:
-            render_shapes = render_body.render_shapes
+        # render_body = link.entity.find_component_by_type(
+        #     sapien.render.RenderBodyComponent
+        # )
+        # render_shapes = []
+        # if render_body is not None:
+        #     render_shapes = render_body.render_shapes
         metadata.links[link.name] = LinkMetadata(
-            name=None, link=link, render_shapes=render_shapes
+            name=None,
+            link=link,
+            render_shapes=[],  # render_shapes=render_shapes
         )
         if joint.type != "fixed":
             metadata.movable_links.append(link.name)
@@ -129,21 +137,14 @@ def build_preprocessed_partnet_mobility_articulation(
             metadata.links[link_id].name = link_name
             metadata.joints[f"joint_{link_id.split('_')[1]}"].name = joint_type

-    if set_object_on_ground:
-        qlimits = articulation.get_qlimits()  # [N, 2]
-        assert not np.isinf(qlimits).any(), qlimits
-        qpos = np.ascontiguousarray(qlimits[:, 0])
-        articulation.set_qpos(qpos)
-        articulation.set_pose(Pose())
-        bounds = merge_meshes(get_articulation_meshes(articulation)).bounds
-        articulation.set_pose(Pose([0, 0, -bounds[0, 2]]))
-
+    # compute the bounding box from the first managed physx articulation
+    bbox = merge_meshes(get_articulation_meshes(articulation._objs[0])).bounding_box
+    metadata.bbox = bbox
     return articulation, metadata


 def _load_partnet_mobility_dataset():
-    """loads preprocssed partnet mobility metadata"""
-    model_dbs["PartnetMobility"] = {
+    """Loads preprocessed PartNet Mobility metadata."""
+    MODEL_DBS["PartnetMobility"] = {
         "model_data": load_json(
             PACKAGE_ASSET_DIR / "partnet_mobility/meta/info_cabinet_drawer_train.json"
         ),
@@ -157,6 +158,6 @@ def find_urdf_path(model_id):
             if urdf_path.exists():
                 return urdf_path

-    model_dbs["PartnetMobility"]["model_urdf_paths"] = {
-        k: find_urdf_path(k) for k in model_dbs["PartnetMobility"]["model_data"].keys()
+    MODEL_DBS["PartnetMobility"]["model_urdf_paths"] = {
+        k: find_urdf_path(k) for k in MODEL_DBS["PartnetMobility"]["model_data"].keys()
     }
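`MODEL_DBS` above is a module-level cache that `_load_partnet_mobility_dataset` fills exactly once; later calls reuse the parsed JSON. The same pattern, stripped to standard-library Python as a sketch (the path and function name here are stand-ins, not the real API):

import json
from pathlib import Path

_DBS: dict = {}

def get_model_data(model_id: str, meta_path: str = "meta/info_cabinet_drawer_train.json"):
    # parse the metadata JSON on first use; later calls hit the in-memory cache
    if "PartnetMobility" not in _DBS:
        _DBS["PartnetMobility"] = json.loads(Path(meta_path).read_text())
    return _DBS["PartnetMobility"][model_id]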
diff --git a/mani_skill2/utils/building/urdf_loader.py b/mani_skill2/utils/building/urdf_loader.py
index 483e11888..fc8ee467b 100644
--- a/mani_skill2/utils/building/urdf_loader.py
+++ b/mani_skill2/utils/building/urdf_loader.py
@@ -37,16 +37,25 @@ def load_file_as_articulation_builder(
             urdf_file, srdf_file, package_dir
         )

-    def load(self, urdf_file: str, srdf_file=None, package_dir=None) -> Articulation:
+    def load(
+        self,
+        urdf_file: str,
+        srdf_file=None,
+        package_dir=None,
+        name=None,
+        scene_mask=None,
+    ) -> Articulation:
         """
         Args:
             urdf_file: filename of the URDF file
             srdf_file: SRDF for urdf_file. If srdf_file is None, it defaults to the ".srdf" file with the same name as the urdf file
             package_dir: base directory used to resolve asset files in the URDF file. If an asset path starts with "package://", "package://" is simply removed from the file name
+            name (str): name of the created articulation
+            scene_mask: boolean mask over parallel sub-scenes selecting where the articulation is built

         Returns:
             returns a single Articulation loaded from the URDF file. It throws an error if multiple objects exist
         """
-
+        if name is not None:
+            self.name = name
         articulation_builders, actor_builders, cameras = self.parse(
             urdf_file, srdf_file, package_dir
         )
@@ -58,6 +67,7 @@ def load(self, urdf_file: str, srdf_file=None, package_dir=None) -> Articulation

         articulations: List[Articulation] = []
         for b in articulation_builders:
+            b.set_scene_mask(scene_mask)
             articulations.append(b.build())

         actors: List[Actor] = []
diff --git a/mani_skill2/utils/structs/articulation.py b/mani_skill2/utils/structs/articulation.py
index 51aa8158a..88108e81b 100644
--- a/mani_skill2/utils/structs/articulation.py
+++ b/mani_skill2/utils/structs/articulation.py
@@ -61,7 +61,8 @@ def _create_from_physx_articulations(
             name=shared_name,
         )
         # create link and joint structs
-        num_links = len(physx_articulations[0].links)
+        # num_links = len(physx_articulations[0].links)
+        num_links = max([len(x.links) for x in physx_articulations])
         all_links_objs: List[List[physx.PhysxArticulationLinkComponent]] = [
             [] for _ in range(num_links)
         ]
@@ -71,7 +72,9 @@
         ]

         link_map = OrderedDict()
         for articulation in physx_articulations:
-            assert num_links == len(articulation.links) and num_joints == len(
-                articulation.joints
-            )
+            # merged articulations may manage models with differing link/joint counts;
+            # num_links and num_joints are padded maxima, so only check the upper bound
+            assert num_links >= len(articulation.links) and num_joints >= len(
+                articulation.joints
+            )
@@ -114,6 +117,26 @@

         return self

+    @classmethod
+    def merge_articulations(
+        cls, articulations: List["Articulation"], name: str = None
+    ):
+        """Merge a list of articulations (one per sub-scene) into a single batched Articulation."""
+        objs = []
+        scene = articulations[0]._scene
+        merged_scene_mask = articulations[0]._scene_mask.clone()
+        num_objs_per_articulation = articulations[0]._num_objs
+        for articulation in articulations:
+            objs += articulation._objs
+            merged_scene_mask[articulation._scene_mask] = True
+            del scene.articulations[articulation.name]
+            assert (
+                articulation._num_objs == num_objs_per_articulation
+            ), "Each given articulation must have the same number of managed objects"
+        merged_articulation = Articulation._create_from_physx_articulations(
+            objs, scene, merged_scene_mask
+        )
+        merged_articulation.name = name
+        scene.articulations[merged_articulation.name] = merged_articulation
+        return merged_articulation
+
     # -------------------------------------------------------------------------- #
     # Additional useful functions not in SAPIEN original API
     # -------------------------------------------------------------------------- #
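`merge_articulations` above collapses per-scene articulations into one batched wrapper registered under a single name. The intended call pattern, mirroring `OpenCabinetEnv._load_actors`; `scene`, `model_ids`, `masks`, and `qpos` are assumed to exist in the surrounding code:

cabinets = []
for i, model_id in enumerate(model_ids):
    cabinet, _ = build_preprocessed_partnet_mobility_articulation(
        scene, model_id, name=f"{model_id}-{i}", scene_mask=masks[i]
    )
    cabinets.append(cabinet)
merged = Articulation.merge_articulations(cabinets, name="cabinet")
merged.set_qpos(qpos)  # one batched handle now drives every sub-scene's cabinet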
"PickCube-v1", "PushCube-v1", "PickSingleYCB-v1" + num_envs = 4 + for env_id in ["OpenCabinet-v1"]: # , "StackCube-v0", "LiftCube-v0"]: env = gym.make( env_id, num_envs=num_envs, enable_shadow=True, robot_uid="panda", reward_mode="normalized_dense", - render_mode="cameras", - control_mode="pd_ee_delta_pos", + render_mode="rgb_array", + # control_mode="base_pd_joint_vel_arm_pd_joint_delta_pos", + control_mode="pd_joint_delta_pos", sim_freq=500, control_freq=100, ) @@ -38,8 +39,8 @@ env.render_human() while i < 50 or (i < 50000 and num_envs == 1): action = env.action_space.sample() - action[:] * 0 - action[:, 2] = -1 + # action[:] * 0 + # action[:, 2] = -1 obs, rew, terminated, truncated, info = env.step(action) done = np.logical_or(to_numpy(terminated), to_numpy(truncated)) if num_envs == 1: