Skip to content

Commit

Permalink
code formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
StoneT2000 committed Jan 23, 2024
1 parent 6dde6eb commit 18b176c
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 40 deletions.
4 changes: 3 additions & 1 deletion mani_skill2/agents/controllers/base_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,9 @@ def _initialize_joints(self):

def _assert_fully_actuated(self):
active_joints = self.articulation.get_active_joints()
if len(active_joints) != len(self.joints) or set(active_joints) != set(self.joints):
if len(active_joints) != len(self.joints) or set(active_joints) != set(
self.joints
):
print("active_joints:", [x.name for x in active_joints])
print("controlled_joints:", [x.name for x in self.joints])
raise AssertionError("{} is not fully actuated".format(self.articulation))
Expand Down
2 changes: 1 addition & 1 deletion mani_skill2/agents/robots/fetch/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .fetch import Fetch, FETCH_UNIQUE_COLLISION_BIT
from .fetch import FETCH_UNIQUE_COLLISION_BIT, Fetch
85 changes: 61 additions & 24 deletions mani_skill2/agents/robots/fetch/fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
get_actor_contacts,
get_obj_by_name,
get_pairwise_contact_impulse,
to_tensor,
)
from mani_skill2.utils.structs.actor import Actor
from mani_skill2.utils.sapien_utils import to_tensor
from mani_skill2.utils.structs.base import BaseStruct
from mani_skill2.utils.structs.joint import Joint
from mani_skill2.utils.structs.link import Link
Expand All @@ -27,6 +27,7 @@

FETCH_UNIQUE_COLLISION_BIT = 1 << 30


class Fetch(BaseAgent):
uid = "fetch"
urdf_path = f"{PACKAGE_ASSET_DIR}/robots/fetch/fetch.urdf"
Expand Down Expand Up @@ -221,42 +222,74 @@ def controller_configs(self):
force_limit=500,
)


controller_configs = dict(
pd_joint_delta_pos=dict(
arm=arm_pd_joint_delta_pos, gripper=gripper_pd_joint_pos, body=body_pd_joint_pos, base=base_pd_joint_vel,
arm=arm_pd_joint_delta_pos,
gripper=gripper_pd_joint_pos,
body=body_pd_joint_pos,
base=base_pd_joint_vel,
),
pd_joint_pos=dict(
arm=arm_pd_joint_pos, gripper=gripper_pd_joint_pos, body=body_pd_joint_pos, base=base_pd_joint_vel,
arm=arm_pd_joint_pos,
gripper=gripper_pd_joint_pos,
body=body_pd_joint_pos,
base=base_pd_joint_vel,
),
pd_ee_delta_pos=dict(
arm=arm_pd_ee_delta_pos, gripper=gripper_pd_joint_pos, body=body_pd_joint_pos, base=base_pd_joint_vel,
arm=arm_pd_ee_delta_pos,
gripper=gripper_pd_joint_pos,
body=body_pd_joint_pos,
base=base_pd_joint_vel,
),
pd_ee_delta_pose=dict(
arm=arm_pd_ee_delta_pose, gripper=gripper_pd_joint_pos, body=body_pd_joint_pos, base=base_pd_joint_vel,
arm=arm_pd_ee_delta_pose,
gripper=gripper_pd_joint_pos,
body=body_pd_joint_pos,
base=base_pd_joint_vel,
),
pd_ee_delta_pose_align=dict(
arm=arm_pd_ee_delta_pose_align, gripper=gripper_pd_joint_pos, body=body_pd_joint_pos, base=base_pd_joint_vel,
arm=arm_pd_ee_delta_pose_align,
gripper=gripper_pd_joint_pos,
body=body_pd_joint_pos,
base=base_pd_joint_vel,
),
# TODO(jigu): how to add boundaries for the following controllers
pd_joint_target_delta_pos=dict(
arm=arm_pd_joint_target_delta_pos, gripper=gripper_pd_joint_pos, body=body_pd_joint_pos, base=base_pd_joint_vel,
arm=arm_pd_joint_target_delta_pos,
gripper=gripper_pd_joint_pos,
body=body_pd_joint_pos,
base=base_pd_joint_vel,
),
pd_ee_target_delta_pos=dict(
arm=arm_pd_ee_target_delta_pos, gripper=gripper_pd_joint_pos, body=body_pd_joint_pos, base=base_pd_joint_vel,
arm=arm_pd_ee_target_delta_pos,
gripper=gripper_pd_joint_pos,
body=body_pd_joint_pos,
base=base_pd_joint_vel,
),
pd_ee_target_delta_pose=dict(
arm=arm_pd_ee_target_delta_pose, gripper=gripper_pd_joint_pos, body=body_pd_joint_pos, base=base_pd_joint_vel,
arm=arm_pd_ee_target_delta_pose,
gripper=gripper_pd_joint_pos,
body=body_pd_joint_pos,
base=base_pd_joint_vel,
),
# Caution to use the following controllers
pd_joint_vel=dict(
arm=arm_pd_joint_vel, gripper=gripper_pd_joint_pos, body=body_pd_joint_pos, base=base_pd_joint_vel,
arm=arm_pd_joint_vel,
gripper=gripper_pd_joint_pos,
body=body_pd_joint_pos,
base=base_pd_joint_vel,
),
pd_joint_pos_vel=dict(
arm=arm_pd_joint_pos_vel, gripper=gripper_pd_joint_pos, body=body_pd_joint_pos, base=base_pd_joint_vel,
arm=arm_pd_joint_pos_vel,
gripper=gripper_pd_joint_pos,
body=body_pd_joint_pos,
base=base_pd_joint_vel,
),
pd_joint_delta_pos_vel=dict(
arm=arm_pd_joint_delta_pos_vel, gripper=gripper_pd_joint_pos, body=body_pd_joint_pos, base=base_pd_joint_vel,
arm=arm_pd_joint_delta_pos_vel,
gripper=gripper_pd_joint_pos,
body=body_pd_joint_pos,
base=base_pd_joint_vel,
),
)

Expand All @@ -270,13 +303,9 @@ def _after_init(self):
self.finger2_link: Link = get_obj_by_name(
self.robot.get_links(), "r_gripper_finger_link"
)
self.tcp: Link = get_obj_by_name(
self.robot.get_links(), self.ee_link_name
)
self.tcp: Link = get_obj_by_name(self.robot.get_links(), self.ee_link_name)

self.base_link: Link = get_obj_by_name(
self.robot.get_links(), "base_link"
)
self.base_link: Link = get_obj_by_name(self.robot.get_links(), "base_link")
self.l_wheel_link: Link = get_obj_by_name(
self.robot.get_links(), "l_wheel_link"
)
Expand Down Expand Up @@ -329,8 +358,12 @@ def is_grasping(self, object: Actor = None, min_impulse=1e-6, max_angle=85):
contacts = self.scene.get_contacts()

if object is None:
finger1_contacts = get_actor_contacts(contacts, self.finger1_link._bodies[0].entity)
finger2_contacts = get_actor_contacts(contacts, self.finger2_link._bodies[0].entity)
finger1_contacts = get_actor_contacts(
contacts, self.finger1_link._bodies[0].entity
)
finger2_contacts = get_actor_contacts(
contacts, self.finger2_link._bodies[0].entity
)
return (
np.linalg.norm(compute_total_impulse(finger1_contacts))
>= min_impulse
Expand All @@ -339,10 +372,14 @@ def is_grasping(self, object: Actor = None, min_impulse=1e-6, max_angle=85):
)
else:
limpulse = get_pairwise_contact_impulse(
contacts, self.finger1_link._bodies[0].entity, object._bodies[0].entity
contacts,
self.finger1_link._bodies[0].entity,
object._bodies[0].entity,
)
rimpulse = get_pairwise_contact_impulse(
contacts, self.finger2_link._bodies[0].entity, object._bodies[0].entity
contacts,
self.finger2_link._bodies[0].entity,
object._bodies[0].entity,
)

# direction to open the gripper
Expand Down Expand Up @@ -380,7 +417,7 @@ def build_grasp_pose(approaching, closing, center):
T[:3, :3] = np.stack([ortho, closing, approaching], axis=1)
T[:3, 3] = center
return sapien.Pose(T)

@property
def tcp_pose(self) -> Pose:
p = (self.finger1_link.pose.p + self.finger2_link.pose.p) / 2
Expand Down
18 changes: 17 additions & 1 deletion mani_skill2/envs/pick_and_place/base_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,23 @@ def _initialize_agent(self):
self.agent.robot.set_pose(Pose([-0.562, 0, 0]))
elif self.robot_uid == "fetch":
qpos = np.array(
[0, 0, 0, 0.386, 0, 0, 0, -np.pi / 4, 0, np.pi / 4, 0, np.pi / 3, 0, 0.015, 0.015]
[
0,
0,
0,
0.386,
0,
0,
0,
-np.pi / 4,
0,
np.pi / 4,
0,
np.pi / 3,
0,
0.015,
0.015,
]
)
self.agent.reset(qpos)
self.agent.robot.set_pose(sapien.Pose([-0.82, 0, -0.920]))
Expand Down
2 changes: 1 addition & 1 deletion mani_skill2/envs/sapien_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,7 @@ def step(self, action: Union[None, np.ndarray, Dict]):
if self.num_envs == 1:
terminated = terminated[0]
reward = reward[0]

if physx.is_gpu_enabled():
return obs, reward, terminated, torch.Tensor(False), info
else:
Expand Down
15 changes: 7 additions & 8 deletions mani_skill2/utils/sapien_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def to_tensor(array: Union[torch.Tensor, np.array, Sequence]):
else:
return torch.tensor(array)


def _to_numpy(array: Union[Array, Sequence]) -> np.ndarray:
if isinstance(array, (dict)):
return {k: _to_numpy(v) for k, v in array.items()}
Expand All @@ -56,12 +57,14 @@ def _to_numpy(array: Union[Array, Sequence]) -> np.ndarray:
else:
return np.array(array)


def to_numpy(array: Union[Array, Sequence], dtype=None) -> np.ndarray:
array = _to_numpy(array)
if dtype is not None:
return array.astype(dtype)
return array


def _unbatch(array: Union[Array, Sequence]):
if isinstance(array, (dict)):
return {k: _unbatch(v) for k, v in array.items()}
Expand All @@ -78,9 +81,11 @@ def _unbatch(array: Union[Array, Sequence]):
return array[0]
return array


def unbatch(*args: Tuple[Union[Array, Sequence]]):
return tuple([_unbatch(x) for x in args])


def clone_tensor(array: Array):
if torch is not None and isinstance(array, torch.Tensor):
return array.clone()
Expand Down Expand Up @@ -321,15 +326,9 @@ def get_pairwise_contacts(
"""
pairwise_contacts = []
for contact in contacts:
if (
contact.bodies[0].entity == actor0
and contact.bodies[1].entity == actor1
):
if contact.bodies[0].entity == actor0 and contact.bodies[1].entity == actor1:
pairwise_contacts.append((contact, True))
elif (
contact.bodies[0].entity == actor1
and contact.bodies[1].entity == actor0
):
elif contact.bodies[0].entity == actor1 and contact.bodies[1].entity == actor0:
pairwise_contacts.append((contact, False))
return pairwise_contacts

Expand Down
27 changes: 23 additions & 4 deletions mani_skill2/utils/scene_builder/table/table_scene_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,34 @@ def initialize(self):
self.env.agent.robot.set_pose(sapien.Pose([-0.562, 0, 0]))
elif self.env.robot_uid == "fetch":
qpos = np.array(
[0, 0, 0, 0.386, 0, 0, 0, -np.pi / 4, 0, np.pi / 4, 0, np.pi / 3, 0, 0.015, 0.015]
[
0,
0,
0,
0.386,
0,
0,
0,
-np.pi / 4,
0,
np.pi / 4,
0,
np.pi / 3,
0,
0.015,
0.015,
]
)
self.env.agent.reset(qpos)
self.env.agent.robot.set_pose(sapien.Pose([-0.82, 0, -self.table_height]))

from mani_skill2.agents.robots.fetch import FETCH_UNIQUE_COLLISION_BIT
cs = self.ground._objs[0].find_component_by_type(
sapien.physx.PhysxRigidStaticComponent
).get_collision_shapes()[0]

cs = (
self.ground._objs[0]
.find_component_by_type(sapien.physx.PhysxRigidStaticComponent)
.get_collision_shapes()[0]
)
cg = cs.get_collision_groups()
cg[2] = FETCH_UNIQUE_COLLISION_BIT
cs.set_collision_groups(cg)
Expand Down

0 comments on commit 18b176c

Please sign in to comment.