Commit

WIP work

StoneT2000 committed Jan 25, 2024
1 parent 46c6b03 commit 9a6710f
Showing 9 changed files with 228 additions and 68 deletions.
1 change: 1 addition & 0 deletions examples/baselines/ppo/ppo.py
@@ -264,6 +264,7 @@ def clip_action(action: torch.Tensor):
info = infos["final_info"]
episodic_return = info['episode']['r'].mean().cpu().numpy()
print(f"global_step={global_step}, episodic_return={episodic_return}")
writer.add_scalar("charts/is_cubeA_on_cubeB", info["is_cubeA_on_cubeB"].float().mean().cpu().numpy(), global_step)
writer.add_scalar("charts/success_rate", info["success"].float().mean().cpu().numpy(), global_step)
writer.add_scalar("charts/episodic_return", episodic_return, global_step)
writer.add_scalar("charts/episodic_length", info["elapsed_steps"], global_step)
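The added scalar assumes the environment reports a task-specific is_cubeA_on_cubeB key in infos["final_info"]. A more defensive variant, sketched below, logs only the metric keys that are actually present; the key list is illustrative and not part of the commit:

```python
# Minimal sketch: log whichever boolean metrics this task reports.
# Assumes `info` maps metric names to batched torch tensors, as in
# the final_info dict above; the key list here is an assumption.
for key in ("success", "is_cubeA_on_cubeB"):
    if key in info:
        writer.add_scalar(
            f"charts/{key}",
            info[key].float().mean().cpu().numpy(),
            global_step,
        )
```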
16 changes: 8 additions & 8 deletions mani_skill2/agents/robots/mobile_panda/base_mobile_agent.py
@@ -36,18 +36,18 @@ def _after_init(self):
self.base_link = self.robot.get_links()[3]

# Ignore collision between the adjustable body and ground
body = get_obj_by_name(self.robot.get_links(), "adjustable_body")

s = body.get_collision_shapes()[0]
gs = s.get_collision_groups()
gs[2] = gs[2] | 1 << 30
s.set_collision_groups(gs)
bodies = get_obj_by_name(self.robot.get_links(), "adjustable_body")
for body in bodies._objs:
s = body.get_collision_shapes()[0]
gs = s.get_collision_groups()
gs[2] = gs[2] | 1 << 30
s.set_collision_groups(gs)

def get_proprioception(self):
state_dict = super().get_proprioception()
qpos, qvel = state_dict["qpos"], state_dict["qvel"]
base_pos, base_orientation, arm_qpos = qpos[:2], qpos[2], qpos[3:]
base_vel, base_ang_vel, arm_qvel = qvel[:2], qvel[2], qvel[3:]
base_pos, base_orientation, arm_qpos = qpos[:, :2], qpos[:, 2], qpos[:, 3:]
base_vel, base_ang_vel, arm_qvel = qvel[:, :2], qvel[:, 2], qvel[:, 3:]

state_dict["qpos"] = arm_qpos
state_dict["qvel"] = arm_qvel
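Both hunks above follow the same batching change: get_obj_by_name now yields one physx body per parallel sub-scene (hence the loop over bodies._objs), and proprioception tensors gain a leading environment axis (hence qpos[:, :2] instead of qpos[:2]). A minimal sketch of the indexing change, with assumed shapes:

```python
import torch

num_envs, num_dofs = 4, 12               # illustrative sizes
qpos = torch.zeros(num_envs, num_dofs)   # batched joint positions

# Old, unbatched: qpos[:2] sliced the DOF axis directly.
# New, batched: keep the env axis and slice the DOF axis second.
base_pos = qpos[:, :2]          # [num_envs, 2] planar base position
base_orientation = qpos[:, 2]   # [num_envs] base yaw
arm_qpos = qpos[:, 3:]          # [num_envs, num_dofs - 3] arm joints
```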
48 changes: 25 additions & 23 deletions mani_skill2/agents/robots/mobile_panda/mobile_panda_dual_arm.py
@@ -290,29 +290,31 @@ def controller_configs(self):
# Make a deepcopy in case users modify any config
return deepcopy_dict(controller_configs)

sensor_configs = []

# TODO (stao): Remove this @property and make sensor configs completely statically defined.
# The expectation is that a robot's physical configuration should rarely, if ever, change. If
# you do want to adapt the robot, you should inherit the robot class and make appropriate changes
@property
def sensor_configs(self):
sensors = []
qs = [
[0.9238795, 0, 0.3826834, 0],
[0.46193977, 0.33141357, 0.19134172, -0.80010315],
[-0.46193977, 0.33141357, -0.19134172, -0.80010315],
]
for i in range(3):
q = qs[i]
camera = CameraConfig(
f"overhead_camera_{i}",
p=[0, 0, self.camera_h],
q=q,
width=400,
height=160,
near=0.1,
far=10,
fov=np.pi / 3,
entity_uid="mobile_base",
)
sensors.append(camera)
return sensors
# @property
# def sensor_configs(self):
# sensors = []
# qs = [
# [0.9238795, 0, 0.3826834, 0],
# [0.46193977, 0.33141357, 0.19134172, -0.80010315],
# [-0.46193977, 0.33141357, -0.19134172, -0.80010315],
# ]
# for i in range(3):
# q = qs[i]
# camera = CameraConfig(
# f"overhead_camera_{i}",
# p=[0, 0, self.camera_h],
# q=q,
# width=400,
# height=160,
# near=0.1,
# far=10,
# fov=np.pi / 3,
# entity_uid="mobile_base",
# )
# sensors.append(camera)
# return sensors
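Per the TODO above, cameras removed here would be restored by subclassing the robot rather than editing it in place. A hedged sketch of that pattern — the subclass name is hypothetical, the camera parameters are copied from the commented-out block, and the fixed height stands in for the old self.camera_h:

```python
import numpy as np

from mani_skill2.sensors.camera import CameraConfig


class MobilePandaDualArmOverhead(MobilePandaDualArm):  # hypothetical subclass
    # Statically defined sensors, restoring one of the three overhead
    # cameras from the commented-out property as an example.
    sensor_configs = [
        CameraConfig(
            "overhead_camera_0",
            p=[0, 0, 1.5],  # assumed fixed height; was self.camera_h
            q=[0.9238795, 0, 0.3826834, 0],
            width=400,
            height=160,
            near=0.1,
            far=10,
            fov=np.pi / 3,
            entity_uid="mobile_base",
        )
    ]
```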
1 change: 1 addition & 0 deletions mani_skill2/envs/tasks/__init__.py
@@ -1,4 +1,5 @@
from .fmb.fmb import FMBEnv
from .open_cabinet_drawer import OpenCabinetEnv
from .pick_cube import PickCubeEnv
from .pick_single_ycb import PickSingleYCBEnv
from .push_cube import PushCubeEnv
121 changes: 121 additions & 0 deletions mani_skill2/envs/tasks/open_cabinet_drawer.py
@@ -0,0 +1,121 @@
from collections import OrderedDict
from typing import Any, Dict

import numpy as np
import torch

from mani_skill2.envs.sapien_env import BaseEnv
from mani_skill2.sensors.camera import CameraConfig
from mani_skill2.utils.building.articulations import (
MODEL_DBS,
_load_partnet_mobility_dataset,
build_preprocessed_partnet_mobility_articulation,
)
from mani_skill2.utils.building.ground import build_tesselated_square_floor
from mani_skill2.utils.registration import register_env
from mani_skill2.utils.sapien_utils import look_at
from mani_skill2.utils.structs.articulation import Articulation
from mani_skill2.utils.structs.pose import Pose


@register_env("OpenCabinet-v1", max_episode_steps=200)
class OpenCabinetEnv(BaseEnv):
"""
Task Description
----------------
Add a task description here
Randomizations
--------------
Success Conditions
------------------
Visualization: link to a video/gif of the task being solved
"""

def __init__(
self,
*args,
robot_uid="mobile_panda_single_arm",
robot_init_qpos_noise=0.02,
**kwargs,
):
self.robot_init_qpos_noise = robot_init_qpos_noise
_load_partnet_mobility_dataset()
self.all_model_ids = np.array(
list(MODEL_DBS["PartnetMobility"]["model_data"].keys())
)
super().__init__(*args, robot_uid=robot_uid, **kwargs)

def _register_sensors(self):
pose = look_at(eye=[0.3, 0, 0.6], target=[-0.1, 0, 0.1])
return [
CameraConfig("base_camera", pose.p, pose.q, 128, 128, np.pi / 2, 0.01, 10)
]

def _register_render_cameras(self):
pose = look_at(eye=[-1.5, -1.5, 1.5], target=[-0.1, 0, 0.1])
return CameraConfig("render_camera", pose.p, pose.q, 512, 512, 1, 0.01, 10)

def _load_actors(self):
model_ids = self.all_model_ids[: self.num_envs]
self.ground = build_tesselated_square_floor(self._scene)

cabinets = []
for i, model_id in enumerate(model_ids):
scene_mask = np.zeros(self.num_envs, dtype=bool)
scene_mask[i] = True
cabinet, metadata = build_preprocessed_partnet_mobility_articulation(
self._scene, model_id, name=f"{model_id}-{i}", scene_mask=scene_mask
)
cabinets.append(cabinet)
self.cabinet = Articulation.merge_articulations(cabinets, name="cabinet")
self.cabinet_metadata = metadata

def _initialize_actors(self):
height = (
self.cabinet_metadata.bbox.bounds[0, 2]
- self.cabinet_metadata.bbox.bounds[1, 2]
)
self.cabinet.set_pose(Pose.create_from_pq(p=[0, 0, -height / 2]))
qlimits = self.cabinet.get_qlimits() # [N, 2]
assert not np.isinf(qlimits).any(), qlimits
qpos = np.ascontiguousarray(qlimits[:, 0])
# NOTE(jigu): must use a contiguous array for `set_qpos`
self.cabinet.set_qpos(qpos)

# initialize robot
if self.robot_uid == "panda":
self.agent.robot.set_pose(Pose.create_from_pq(p=[-1, 0, 0]))
elif self.robot_uid == "mobile_panda_single_arm":
center = np.array([0, 0.8])
dist = self._episode_rng.uniform(1.6, 1.8)
theta = self._episode_rng.uniform(0.9 * np.pi, 1.1 * np.pi)
direction = np.array([np.cos(theta), np.sin(theta)])
xy = center + direction * dist

# Base orientation
noise_ori = self._episode_rng.uniform(-0.05 * np.pi, 0.05 * np.pi)
ori = (theta - np.pi) + noise_ori

h = 1e-4
arm_qpos = np.array([0, 0, 0, -1.5, 0, 3, 0.78, 0.02, 0.02])

qpos = np.hstack([xy, ori, h, arm_qpos])
self.agent.reset(qpos)

def _get_obs_extra(self):
return OrderedDict()

def evaluate(self, obs: Any):
return {"success": torch.zeros(self.num_envs, device=self.device, dtype=bool)}

def compute_dense_reward(self, obs: Any, action: torch.Tensor, info: Dict):
return torch.zeros(self.num_envs, device=self.device)

def compute_normalized_dense_reward(
self, obs: Any, action: torch.Tensor, info: Dict
):
max_reward = 1.0
return self.compute_dense_reward(obs=obs, action=action, info=info) / max_reward
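Since evaluate and compute_dense_reward are still stubs, the new task runs but always reports success=False with zero reward. A minimal smoke-test sketch, assuming the standard gymnasium entry point that register_env provides and that a num_envs kwarg is accepted as in other GPU-parallel tasks:

```python
import gymnasium as gym

import mani_skill2.envs.tasks  # noqa: F401  (triggers register_env("OpenCabinet-v1"))

# num_envs kwarg is an assumption based on the parallel-scene design above.
env = gym.make("OpenCabinet-v1", num_envs=4)
obs, info = env.reset(seed=0)
for _ in range(10):
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
env.close()
```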
55 changes: 28 additions & 27 deletions mani_skill2/utils/building/articulations.py
@@ -9,9 +9,11 @@
import sapien
import sapien.physx as physx
import sapien.render
import trimesh
from sapien import Pose

from mani_skill2 import ASSET_DIR, PACKAGE_ASSET_DIR
from mani_skill2.envs.scene import ManiSkillScene
from mani_skill2.utils.geometry.trimesh_utils import (
get_articulation_meshes,
merge_meshes,
@@ -41,6 +43,7 @@ class ArticulationMetadata:
links: Dict[str, LinkMetadata]
# a list of all movable links
movable_links: List[str]
bbox: trimesh.primitives.Box


def build_articulation_from_file(
@@ -65,16 +68,17 @@ def build_articulation_from_file(


# cache model metadata here if needed
model_dbs: Dict[str, Dict[str, Dict]] = {}
MODEL_DBS: Dict[str, Dict[str, Dict]] = {}


### Build articulations ###
def build_preprocessed_partnet_mobility_articulation(
scene: sapien.Scene,
scene: ManiSkillScene,
model_id: str,
name: str,
fix_root_link=True,
urdf_config: dict = None,
set_object_on_ground=True,
scene_mask=None,
):
"""
Builds a physx.PhysxArticulation object into the scene and returns metadata containing annotations of the object's links and joints.
@@ -88,34 +92,38 @@ def build_preprocessed_partnet_mobility_articulation(
set_object_on_ground: whether to change the pose of the built articulation such that the object is settled on the ground (at z = 0)
"""
if "PartnetMobility" not in model_dbs:
if "PartnetMobility" not in MODEL_DBS:
_load_partnet_mobility_dataset()

metadata = model_dbs["PartnetMobility"]["model_data"][model_id]
metadata = MODEL_DBS["PartnetMobility"]["model_data"][model_id]

loader = scene.create_urdf_loader()
loader.fix_root_link = fix_root_link
loader.scale = metadata["scale"]
loader.load_multiple_collisions_from_file = True

urdf_path = model_dbs["PartnetMobility"]["model_urdf_paths"][model_id]
urdf_path = MODEL_DBS["PartnetMobility"]["model_urdf_paths"][model_id]
urdf_config = parse_urdf_config(urdf_config or {}, scene)
apply_urdf_config(loader, urdf_config)
articulation: physx.PhysxArticulation = loader.load(str(urdf_path))
articulation = loader.load(str(urdf_path), name=name, scene_mask=scene_mask)

metadata = ArticulationMetadata(joints=dict(), links=dict(), movable_links=[])
metadata = ArticulationMetadata(
joints=dict(), links=dict(), movable_links=[], bbox=None
)

# NOTE(jigu): links and their parent joints.
for link, joint in zip(articulation.get_links(), articulation.get_joints()):
metadata.joints[joint.name] = JointMetadata(type=joint.type, name="")
render_body = link.entity.find_component_by_type(
sapien.render.RenderBodyComponent
)
render_shapes = []
if render_body is not None:
render_shapes = render_body.render_shapes
# render_body = link.entity.find_component_by_type(
# sapien.render.RenderBodyComponent
# )
# render_shapes = []
# if render_body is not None:
# render_shapes = render_body.render_shapes
metadata.links[link.name] = LinkMetadata(
name=None, link=link, render_shapes=render_shapes
name=None,
link=link,
render_shapes=[], # render_shapes=render_shapes
)
if joint.type != "fixed":
metadata.movable_links.append(link.name)
@@ -129,21 +137,14 @@
metadata.links[link_id].name = link_name
metadata.joints[f"joint_{link_id.split('_')[1]}"].name = joint_type

if set_object_on_ground:
qlimits = articulation.get_qlimits() # [N, 2]
assert not np.isinf(qlimits).any(), qlimits
qpos = np.ascontiguousarray(qlimits[:, 0])
articulation.set_qpos(qpos)
articulation.set_pose(Pose())
bounds = merge_meshes(get_articulation_meshes(articulation)).bounds
articulation.set_pose(Pose([0, 0, -bounds[0, 2]]))

bbox = merge_meshes(get_articulation_meshes(articulation._objs[0])).bounding_box
metadata.bbox = bbox
return articulation, metadata


def _load_partnet_mobility_dataset():
"""loads preprocssed partnet mobility metadata"""
model_dbs["PartnetMobility"] = {
MODEL_DBS["PartnetMobility"] = {
"model_data": load_json(
PACKAGE_ASSET_DIR / "partnet_mobility/meta/info_cabinet_drawer_train.json"
),
@@ -157,6 +158,6 @@ def find_urdf_path(model_id):
if urdf_path.exists():
return urdf_path

model_dbs["PartnetMobility"]["model_urdf_paths"] = {
k: find_urdf_path(k) for k in model_dbs["PartnetMobility"]["model_data"].keys()
MODEL_DBS["PartnetMobility"]["model_urdf_paths"] = {
k: find_urdf_path(k) for k in MODEL_DBS["PartnetMobility"]["model_data"].keys()
}
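With the set_object_on_ground branch removed, placing the articulation is now the caller's job, using the bbox the function attaches to its metadata. A minimal sketch of that placement, assuming a trimesh bounding box whose bounds[0] row is the min corner (OpenCabinetEnv._initialize_actors above computes a half-height offset instead, which matches when the box is centered at the origin):

```python
from mani_skill2.utils.structs.pose import Pose

# metadata.bbox is a trimesh box; bounds is a (2, 3) array of
# (min, max) corners in the articulation's canonical frame.
z_min = metadata.bbox.bounds[0, 2]

# Lift the articulation so its lowest point rests at z = 0, mirroring
# the removed set_object_on_ground logic.
articulation.set_pose(Pose.create_from_pq(p=[0, 0, -z_min]))
```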
14 changes: 12 additions & 2 deletions mani_skill2/utils/building/urdf_loader.py
@@ -37,16 +37,25 @@ def load_file_as_articulation_builder(
urdf_file, srdf_file, package_dir
)

def load(self, urdf_file: str, srdf_file=None, package_dir=None) -> Articulation:
def load(
self,
urdf_file: str,
srdf_file=None,
package_dir=None,
name=None,
scene_mask=None,
) -> Articulation:
"""
Args:
urdf_file: filename for URDF file
srdf_file: SRDF for urdf_file. If srdf_file is None, it defaults to the ".srdf" file with the same name as the urdf file
package_dir: base directory used to resolve asset files in the URDF file. If an asset path starts with "package://", "package://" is simply removed from the file name
name (str): name of the created articulation
Returns:
returns a single Articulation loaded from the URDF file. It throws an error if multiple objects exist
"""

if name is not None:
self.name = name
articulation_builders, actor_builders, cameras = self.parse(
urdf_file, srdf_file, package_dir
)
@@ -58,6 +67,7 @@ def load(self, urdf_file: str, srdf_file=None, package_dir=None) -> Articulation

articulations: List[Articulation] = []
for b in articulation_builders:
b.set_scene_mask(scene_mask)
articulations.append(b.build())

actors: List[Actor] = []
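The new name and scene_mask parameters are what let build_preprocessed_partnet_mobility_articulation above build a different cabinet in each parallel sub-scene. A minimal usage sketch; the URDF path is illustrative and scene is assumed to be a ManiSkillScene as in that caller:

```python
import numpy as np

num_envs = 4
# One-hot mask: build this articulation only in sub-scene 0.
scene_mask = np.zeros(num_envs, dtype=bool)
scene_mask[0] = True

loader = scene.create_urdf_loader()
loader.fix_root_link = True
cabinet = loader.load(
    "path/to/model.urdf",   # illustrative path
    name="cabinet-0",       # names the created Articulation
    scene_mask=scene_mask,  # restrict building to masked sub-scenes
)
```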