refactor some code, cache some properties, batch some properties
StoneT2000 committed Jan 25, 2024
1 parent 14eac02 commit e99464b
Showing 10 changed files with 76 additions and 40 deletions.
4 changes: 2 additions & 2 deletions mani_skill2/agents/controllers/pd_joint_pos.py
@@ -15,7 +15,7 @@ class PDJointPosController(BaseController):
config: "PDJointPosControllerConfig"

def _get_joint_limits(self):
qlimits = self.articulation.get_qlimit()[self.joint_indices]
qlimits = self.articulation.get_qlimits()[0, self.joint_indices].cpu().numpy()
# Override if specified
if self.config.lower is not None:
qlimits[:, 0] = self.config.lower
@@ -24,7 +24,7 @@ def _get_joint_limits(self):
return qlimits

def _initialize_action_space(self):
joint_limits = self._get_joint_limits().cpu().numpy()
joint_limits = self._get_joint_limits()
low, high = joint_limits[:, 0], joint_limits[:, 1]
self.action_space = spaces.Box(low, high, dtype=np.float32)

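As an aside, here is a minimal sketch of the shape convention the updated _get_joint_limits appears to rely on: get_qlimits() returning a batched torch tensor of shape [num_managed_articulations, dof, 2], from which the controller takes the first articulation's limits and converts them to numpy before building the action space. The limit values below are illustrative placeholders, not real robot limits.

import numpy as np
import torch

# Placeholder batched limits: one articulation, three joints, [lower, upper] pairs.
qlimits = torch.tensor([[[-2.9, 2.9],
                         [-1.8, 1.8],
                         [0.0, 0.04]]])  # shape [1, 3, 2]
joint_indices = [0, 1]  # joints this controller drives

# Mirrors the new _get_joint_limits: index articulation 0 and the selected joints,
# then hand numpy arrays to the Box action space.
limits = qlimits[0, joint_indices].cpu().numpy()
low, high = limits[:, 0], limits[:, 1]
print(low, high)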
2 changes: 1 addition & 1 deletion mani_skill2/envs/ms2/assembly/assembling_kits.py
@@ -190,7 +190,7 @@ def compute_dense_reward(self, info, **kwargs):

reward = 0.0
gripper_width = (
self.agent.robot.get_qlimit()[-1, 1] * 2
self.agent.robot.get_qlimits()[-1, 1] * 2
) # NOTE: hard-coded with panda

# reaching reward
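A note on the gripper_width expression used here and in the stack_cube task below (a sketch under the Panda-style gripper assumption that the NOTE in the code acknowledges is hard-coded): the last entry of qlimits is one of the two mirrored prismatic finger joints, so twice its upper limit approximates the full opening width. The numbers below are illustrative, roughly matching the Franka Panda's 0.04 m finger travel; the real values come from the robot's URDF.

import numpy as np

# Hypothetical joint limits: arm joints followed by two mirrored finger joints.
qlimits = np.array([[-2.9, 2.9],
                    [-1.8, 1.8],
                    [0.0, 0.04],   # finger joint 1
                    [0.0, 0.04]])  # finger joint 2
gripper_width = qlimits[-1, 1] * 2  # upper limit of the last finger joint, doubled
print(gripper_width)  # 0.08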
2 changes: 1 addition & 1 deletion mani_skill2/envs/ms2/pick_and_place/stack_cube.py
@@ -137,7 +137,7 @@ def evaluate(self, **kwargs):

def compute_dense_reward(self, info, **kwargs):
gripper_width = (
self.agent.robot.get_qlimit()[-1, 1] * 2
self.agent.robot.get_qlimits()[-1, 1] * 2
) # NOTE: hard-coded with panda
# TODO (stao): rewrite dense reward for this task. TBH it should just be nearly the same as pick cube.
# # grasp pose rotation reward
2 changes: 1 addition & 1 deletion mani_skill2/envs/sapien_env.py
@@ -632,7 +632,7 @@ def _clear_sim_state(self):
actor.set_linear_velocity([0, 0, 0])
actor.set_angular_velocity([0, 0, 0])
for articulation in self._scene.articulations.values():
articulation.set_qvel(np.zeros(articulation.dof))
articulation.set_qvel(np.zeros(articulation.max_dof))
# articulation.set_root_velocity([0, 0, 0])
# articulation.set_root_angular_velocity([0, 0, 0])
if physx.is_gpu_enabled():
4 changes: 3 additions & 1 deletion mani_skill2/envs/scene.py
@@ -473,7 +473,9 @@ def set_sim_state(self, state: Array):
actor.set_state(state[:, start : start + KINEMATIC_DIM])
start += KINEMATIC_DIM
for articulation in self.articulations.values():
ndim = KINEMATIC_DIM + 2 * articulation.dof
# TODO (stao): when multiple articulations are managed by the same object we have to take the max DOF,
# but then restoring state is rather non-trivial; do we need to store the dof as part of the state somewhere?
ndim = KINEMATIC_DIM + 2 * articulation.max_dof
articulation.set_state(state[:, start : start + ndim])
start += ndim

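To make the padded layout concrete, here is a small sketch of how a per-articulation slice of the flattened state is assumed to be laid out after this change (KINEMATIC_DIM = 13, i.e. a 7D root pose plus linear and angular velocity, is an assumption for illustration). Articulations with fewer DOFs than max_dof leave the tail of the qpos/qvel slots unused, which is why restoring state for heterogeneous batches is non-trivial, as the TODO notes.

import numpy as np

KINEMATIC_DIM = 13  # assumed: 7D root pose + 3D linear velocity + 3D angular velocity
max_dof = 5         # hypothetical max DOF across articulations sharing this slot
num_envs = 2

ndim = KINEMATIC_DIM + 2 * max_dof  # matches ndim in set_sim_state above
state = np.zeros((num_envs, ndim), dtype=np.float32)

root_state = state[:, :KINEMATIC_DIM]
qpos = state[:, KINEMATIC_DIM : KINEMATIC_DIM + max_dof]
qvel = state[:, KINEMATIC_DIM + max_dof :]
print(root_state.shape, qpos.shape, qvel.shape)  # (2, 13) (2, 5) (2, 5)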
13 changes: 8 additions & 5 deletions mani_skill2/envs/tasks/open_cabinet_drawer.py
@@ -69,6 +69,8 @@ def _load_actors(self):
cabinet, metadata = build_preprocessed_partnet_mobility_articulation(
self._scene, model_id, name=f"{model_id}-i", scene_mask=scene_mask
)
# self.cabinet = cabinet
# self.cabinet_metadata = metadata
cabinets.append(cabinet)
self.cabinet = Articulation.merge_articulations(cabinets, name="cabinet")
self.cabinet_metadata = metadata
@@ -79,14 +81,15 @@ def _initialize_actors(self):
- self.cabinet_metadata.bbox.bounds[1, 2]
)
self.cabinet.set_pose(Pose.create_from_pq(p=[0, 0, -height / 2]))
qlimits = self.cabinet.get_qlimits() # [N, 2]
assert not np.isinf(qlimits).any(), qlimits
qpos = np.ascontiguousarray(qlimits[:, 0])
# NOTE(jigu): must use a contiguous array for `set_qpos`
self.cabinet.set_qpos(qpos)
qlimits = self.cabinet.get_qlimits() # [N, self.cabinet.max_dof, 2]
qpos = qlimits[:, :, 0]
self.cabinet.set_qpos(
qpos
) # close all the cabinets. We know beforehand that lower qlimit means "closed" for these assets.

# initialize robot
if self.robot_uid == "panda":
self.agent.robot.set_qpos(self.agent.robot.qpos * 0)
self.agent.robot.set_pose(Pose.create_from_pq(p=[-1, 0, 0]))
elif self.robot_uid == "mobile_panda_single_arm":
center = np.array([0, 0.8])
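As a sketch of what the batched reset above does (values are hypothetical): get_qlimits() yields a zero-padded [N, max_dof, 2] array, and taking the lower bound of every joint closes all drawers and doors, since for these PartNet-Mobility assets the lower limit corresponds to the closed configuration.

import numpy as np

# Hypothetical padded limits for two cabinets with max_dof = 3; the second
# cabinet has only 2 joints, so its last row is zero padding.
qlimits = np.array([[[0.0, 0.6], [0.0, 1.4], [0.0, 0.8]],
                    [[0.0, 0.5], [0.0, 1.2], [0.0, 0.0]]])  # [N, max_dof, 2]

qpos = qlimits[:, :, 0]  # lower bound of every joint = "closed"
print(qpos.shape)        # (2, 3); this is what gets passed to cabinet.set_qpos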
2 changes: 1 addition & 1 deletion mani_skill2/envs/tasks/stack_cube.py
@@ -147,7 +147,7 @@ def compute_dense_reward(self, obs: Any, action: torch.Tensor, info: Dict):
reward[info["is_cubeA_grasped"]] = (4 + place_reward)[info["is_cubeA_grasped"]]

# ungrasp and static reward
gripper_width = (self.agent.robot.get_qlimits()[-1, 1] * 2).to(
gripper_width = (self.agent.robot.get_qlimits()[0, -1, 1] * 2).to(
self.device
) # NOTE: hard-coded with panda
is_cubeA_grasped = info["is_cubeA_grasped"]
10 changes: 8 additions & 2 deletions mani_skill2/utils/building/articulations.py
@@ -101,7 +101,7 @@ def build_preprocessed_partnet_mobility_articulation(
loader.fix_root_link = fix_root_link
loader.scale = metadata["scale"]
loader.load_multiple_collisions_from_file = True

# loader.multiple_collisions_decomposition="coacd"
urdf_path = MODEL_DBS["PartnetMobility"]["model_urdf_paths"][model_id]
urdf_config = parse_urdf_config(urdf_config or {}, scene)
apply_urdf_config(loader, urdf_config)
@@ -110,7 +110,13 @@
metadata = ArticulationMetadata(
joints=dict(), links=dict(), movable_links=[], bbox=None
)

# for link in articulation._objs[0].links:
# # rb = link.
# print(link.name)
# tc = 0
# for rs in link.collision_shapes:
# count = rs.get_vertices().shape[0]
# print(count)
# NOTE(jigu): links and their parent joints.
for link, joint in zip(articulation.get_links(), articulation.get_joints()):
metadata.joints[joint.name] = JointMetadata(type=joint.type, name="")
70 changes: 44 additions & 26 deletions mani_skill2/utils/structs/articulation.py
@@ -2,6 +2,7 @@

from collections import OrderedDict
from dataclasses import dataclass
from functools import cached_property
from typing import TYPE_CHECKING, List, Sequence, Union

import numpy as np
@@ -66,19 +67,17 @@ def _create_from_physx_articulations(
all_links_objs: List[List[physx.PhysxArticulationLinkComponent]] = [
[] for _ in range(num_links)
]
num_joints = len(physx_articulations[0].joints)
# num_joints = len(physx_articulations[0].joints)
num_joints = max([len(x.joints) for x in physx_articulations])
all_joint_objs: List[List[physx.PhysxArticulationJoint]] = [
[] for _ in range(num_joints)
]

link_map = OrderedDict()
import ipdb

ipdb.set_trace()
for articulation in physx_articulations:
assert num_links == len(articulation.links) and num_joints == len(
articulation.joints
), "Gave different physx articulations. Each Articulation object can only manage the same articulations, not different ones"
# assert num_links == len(articulation.links) and num_joints == len(
# articulation.joints
# ), "Gave different physx articulations. Each Articulation object can only manage the same articulations, not different ones"
for i, link in enumerate(articulation.links):
all_links_objs[i].append(link)
for i, joint in enumerate(articulation.joints):
@@ -162,6 +161,16 @@ def set_state(self, state: Array):
self.set_qpos(qpos)
self.set_qvel(qvel)

@cached_property
def max_dof(self) -> int:
return max([obj.dof for obj in self._objs])

def bbox(self):
import ipdb

ipdb.set_trace()
self._objs[0]

# -------------------------------------------------------------------------- #
# Functions from physx.PhysxArticulation
# -------------------------------------------------------------------------- #
@@ -203,11 +212,12 @@ def get_pose(self) -> sapien.Pose:
def get_qf(self):
return self.qf

def get_qlimit(self):
"""
same as get_qlimits
"""
return self.qlimits
# def get_qlimit(self):
# removed this function from ManiSkill Articulation wrapper API as it is redundant
# """
# same as get_qlimits
# """
# return self.qlimits

def get_qlimits(self):
return self.qlimits
@@ -257,12 +267,12 @@ def set_root_pose(self, pose: sapien.Pose) -> None:
# def active_joints(self):
# return self._articulations[0].active_joints

@property
@cached_property
def dof(self) -> int:
"""
:type: int
"""
return self._objs[0].dof
return torch.tensor([obj.dof for obj in self._objs])

# @property
# def gpu_index(self) -> int:
@@ -308,53 +318,61 @@ def pose(self, arg1: sapien.Pose) -> None:
@property
def qf(self):
if physx.is_gpu_enabled():
return self.px.cuda_articulation_qf[self._data_index, : self.dof]
return self.px.cuda_articulation_qf[self._data_index, : self.max_dof]
else:
return torch.from_numpy(self._objs[0].qf[None, :])

@qf.setter
def qf(self, arg1):
if physx.is_gpu_enabled():
arg1 = to_tensor(arg1)
self.px.cuda_articulation_qf[self._data_index, : self.dof] = arg1
self.px.cuda_articulation_qf[self._data_index, : self.max_dof] = arg1
else:
self._objs[0].qf = arg1

@property
def qlimit(self):
return torch.from_numpy(self._objs[0].qlimit)

@property
@cached_property
def qlimits(self):
return torch.from_numpy(self._objs[0].qlimits)
padded_qlimits = np.array(
[
np.concatenate([obj.qlimits, np.zeros((self.max_dof - obj.dof, 2))])
for obj in self._objs
]
)
padded_qlimits = torch.from_numpy(padded_qlimits).float()
if physx.is_gpu_enabled():
return padded_qlimits.cuda()
else:
return padded_qlimits

@property
def qpos(self):
if physx.is_gpu_enabled():
return self.px.cuda_articulation_qpos[self._data_index, : self.dof]
# NOTE (stao): cuda_articulation_qpos is of shape (M, N) where M is the total number of articulations in the physx scene,
# N is the max dof of all those articulations.
return self.px.cuda_articulation_qpos[self._data_index, : self.max_dof]
else:
return torch.from_numpy(self._objs[0].qpos[None, :])

@qpos.setter
def qpos(self, arg1):
if physx.is_gpu_enabled():
arg1 = to_tensor(arg1)
self.px.cuda_articulation_qpos[self._data_index, : self.dof] = arg1
self.px.cuda_articulation_qpos[self._data_index, : self.max_dof] = arg1
else:
self._objs[0].qpos = arg1

@property
def qvel(self):
if physx.is_gpu_enabled():
return self.px.cuda_articulation_qvel[self._data_index, : self.dof]
return self.px.cuda_articulation_qvel[self._data_index, : self.max_dof]
else:
return torch.from_numpy(self._objs[0].qvel[None, :])

@qvel.setter
def qvel(self, arg1):
if physx.is_gpu_enabled():
arg1 = to_tensor(arg1)
self.px.cuda_articulation_qvel[self._data_index, : self.dof] = arg1
self.px.cuda_articulation_qvel[self._data_index, : self.max_dof] = arg1
else:
self._objs[0].qvel = arg1

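To summarize the padding pattern that the new cached qlimits property and the max_dof-based slicing rely on, here is a small self-contained sketch (DOF counts and limit values are made up): each managed articulation's limits are zero-padded up to the maximum DOF so the whole batch stacks into a single [N, max_dof, 2] tensor, and per-articulation buffers such as qpos/qvel are then sliced with : max_dof rather than the per-object dof.

import numpy as np
import torch

# Hypothetical per-articulation limits for two articulations with different DOF counts.
per_obj_qlimits = [
    np.array([[0.0, 0.6], [0.0, 1.4], [0.0, 0.8]]),  # 3-DOF articulation
    np.array([[0.0, 0.5], [0.0, 1.2]]),              # 2-DOF articulation
]
max_dof = max(q.shape[0] for q in per_obj_qlimits)

# Zero-pad each articulation to max_dof so the batch stacks cleanly.
padded = np.array([
    np.concatenate([q, np.zeros((max_dof - q.shape[0], 2))])
    for q in per_obj_qlimits
])
qlimits = torch.from_numpy(padded).float()
print(qlimits.shape)  # torch.Size([2, 3, 2]); rows beyond an articulation's dof are padding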
7 changes: 7 additions & 0 deletions manualtest/visual_all_envs_cpu.py
@@ -1,5 +1,6 @@
import gymnasium as gym
import numpy as np
import sapien

import mani_skill2.envs
from mani_skill2.utils.sapien_utils import to_numpy
@@ -8,6 +9,11 @@
if __name__ == "__main__":
# , "StackCube-v1", "PickCube-v1", "PushCube-v1", "PickSingleYCB-v1"
num_envs = 4
sapien.physx.set_gpu_memory_config(
found_lost_pairs_capacity=2**26,
max_rigid_patch_count=2**19,
max_rigid_contact_count=2**21,
)
for env_id in ["OpenCabinet-v1"]: # , "StackCube-v0", "LiftCube-v0"]:
env = gym.make(
env_id,
@@ -38,6 +44,7 @@
viewer.paused = True
env.render_human()
while i < 50 or (i < 50000 and num_envs == 1):
print(i)
action = env.action_space.sample()
# action[:] * 0
# action[:, 2] = -1
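A brief usage note on the GPU memory configuration added to this test (the capacity values are just the ones from the script above, not tuned recommendations, and the gym.make kwargs are assumptions mirroring the truncated call in the test): the PhysX buffer sizes are configured before the environment is created, so the call goes ahead of gym.make.

import gymnasium as gym
import sapien

import mani_skill2.envs  # registers the ManiSkill2 environments

# Configure PhysX GPU buffers before any scene is created.
sapien.physx.set_gpu_memory_config(
    found_lost_pairs_capacity=2**26,
    max_rigid_patch_count=2**19,
    max_rigid_contact_count=2**21,
)
env = gym.make("OpenCabinet-v1", num_envs=4)  # assumed kwargs for illustration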
