Fetch mobile #188

Merged
merged 10 commits on Jan 22, 2024
7 changes: 1 addition & 6 deletions mani_skill2/agents/controllers/base_controller.py
@@ -206,12 +206,7 @@ def _initialize_joints(self):

def _assert_fully_actuated(self):
active_joints = self.articulation.get_active_joints()
if len(active_joints) != len(self.joints) or not np.all(
[
active_joint == joint
for active_joint, joint in zip(active_joints, self.joints)
]
):
if len(active_joints) != len(self.joints) or set(active_joints) != set(self.joints):
print("active_joints:", [x.name for x in active_joints])
print("controlled_joints:", [x.name for x in self.joints])
raise AssertionError("{} is not fully actuated".format(self.articulation))
12 changes: 10 additions & 2 deletions mani_skill2/agents/controllers/pd_base_vel.py
@@ -1,4 +1,5 @@
import numpy as np
import torch

from mani_skill2.utils.geometry import rotate_2d_vec_by_angle

@@ -18,14 +19,21 @@ def _initialize_action_space(self):
def set_action(self, action: np.ndarray):
action = self._preprocess_action(action)

# TODO (arth): add support for batched qpos and gpu sim
Member (review comment): good that you added the todos, otherwise we will all forget this XD

if isinstance(self.qpos, torch.Tensor):
qpos = self.qpos.detach().cpu().numpy()
qpos = qpos[0]
if isinstance(action, torch.Tensor):
action = action.detach().cpu().numpy()

# Convert to ego-centric action
# Assume the 3rd DoF stands for orientation
ori = self.qpos[2]
ori = qpos[2]
vel = rotate_2d_vec_by_angle(action[:2], ori)
new_action = np.hstack([vel, action[2:]])

for i, joint in enumerate(self.joints):
joint.set_drive_velocity_target(new_action[i])
joint.set_drive_velocity_target(np.array([new_action[i]]))


class PDBaseVelControllerConfig(PDJointVelControllerConfig):
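For context on the ego-centric conversion in set_action above: the commanded planar velocity is rotated by the base yaw (the third base DoF, qpos[2]) before being written to the x/y joints. Below is a minimal sketch of the rotation helper, assuming rotate_2d_vec_by_angle from mani_skill2.utils.geometry performs a standard counter-clockwise 2D rotation; this is an illustration, not the library's exact implementation.

import numpy as np

def rotate_2d_vec_by_angle(vec, angle):
    # Rotate a 2D vector counter-clockwise by `angle` radians.
    rot = np.array(
        [[np.cos(angle), -np.sin(angle)],
         [np.sin(angle), np.cos(angle)]]
    )
    return rot @ vec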
10 changes: 6 additions & 4 deletions mani_skill2/agents/controllers/pd_ee_pose.py
@@ -8,7 +8,7 @@
from scipy.spatial.transform import Rotation

from mani_skill2.utils.common import clip_and_scale_action
from mani_skill2.utils.sapien_utils import get_obj_by_name
from mani_skill2.utils.sapien_utils import get_obj_by_name, to_numpy, to_tensor
from mani_skill2.utils.structs.pose import vectorize_pose

from .base_controller import BaseController, ControllerConfig
@@ -64,15 +64,17 @@ def reset(self):

def compute_ik(self, target_pose, max_iterations=100):
# Assume the target pose is defined in the base frame
# TODO (arth): currently ik only supports cpu, so input/output is managed as such
# in future, need to change input/output processing per gpu implementation
result, success, error = self.pmodel.compute_inverse_kinematics(
self.ee_link_idx,
target_pose,
initial_qpos=self.articulation.get_qpos(),
target_pose.sp,
initial_qpos=to_numpy(self.articulation.get_qpos()).squeeze(0),
active_qmask=self.qmask,
max_iterations=max_iterations,
)
if success:
return result[self.joint_indices]
return to_tensor([result[self.joint_indices]])
else:
return None

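The TODO above notes that IK currently runs on CPU only, so the new code round-trips through numpy: to_numpy pulls the (possibly GPU-resident) joint positions down for compute_inverse_kinematics, and the result is wrapped back into a tensor with to_tensor. A rough sketch of what such helpers typically look like; the actual signatures in mani_skill2.utils.sapien_utils may differ.

import numpy as np
import torch

def to_numpy(x):
    # Detach and move torch tensors (possibly on GPU) to numpy; pass arrays through.
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)

def to_tensor(x):
    # Wrap array-like data back into a torch tensor for the rest of the pipeline.
    return torch.as_tensor(np.asarray(x))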
2 changes: 1 addition & 1 deletion mani_skill2/agents/robots/fetch/__init__.py
@@ -1 +1 @@
from .fetch import Fetch
from .fetch import Fetch, FETCH_UNIQUE_COLLISION_BIT
235 changes: 152 additions & 83 deletions mani_skill2/agents/robots/fetch/fetch.py
@@ -1,21 +1,31 @@
from copy import deepcopy
from typing import Dict, Tuple

import numpy as np
import sapien
import sapien.physx as physx
import torch

from mani_skill2 import PACKAGE_ASSET_DIR
from mani_skill2.agents.base_agent import BaseAgent
from mani_skill2.agents.controllers import *
from mani_skill2.sensors.camera import CameraConfig
from mani_skill2.utils.common import np_compute_angle_between
from mani_skill2.utils.common import compute_angle_between, np_compute_angle_between
from mani_skill2.utils.sapien_utils import (
compute_total_impulse,
get_actor_contacts,
get_obj_by_name,
get_pairwise_contact_impulse,
)
from mani_skill2.utils.structs.actor import Actor
from mani_skill2.utils.sapien_utils import to_tensor
from mani_skill2.utils.structs.base import BaseStruct
from mani_skill2.utils.structs.joint import Joint
from mani_skill2.utils.structs.link import Link
from mani_skill2.utils.structs.pose import Pose
from mani_skill2.utils.structs.types import Array

FETCH_UNIQUE_COLLISION_BIT = 1 << 30

class Fetch(BaseAgent):
uid = "fetch"
@@ -33,6 +43,19 @@ class Fetch(BaseAgent):
),
),
)
sensor_configs = [
CameraConfig(
uid="fetch_head",
p=[0, 0, 0],
q=[1, 0, 0, 0],
width=128,
height=128,
fov=1.57,
near=0.01,
far=10,
entity_uid="head_camera_link",
)
]

def __init__(self, *args, **kwargs):
self.arm_joint_names = [
@@ -67,6 +90,12 @@ def __init__(self, *args, **kwargs):
self.body_damping = 1e2
self.body_force_limit = 100

self.base_joint_names = [
"root_x_axis_joint",
"root_y_axis_joint",
"root_z_rotation_joint",
]

super().__init__(*args, **kwargs)

@property
@@ -181,110 +210,164 @@ def controller_configs(self):
normalize_action=False,
)

# -------------------------------------------------------------------------- #
# Base
# -------------------------------------------------------------------------- #
base_pd_joint_vel = PDBaseVelControllerConfig(
self.base_joint_names,
lower=[-0.5, -0.5, -3.14],
upper=[0.5, 0.5, 3.14],
damping=1000,
force_limit=500,
)


controller_configs = dict(
pd_joint_delta_pos=dict(
arm=arm_pd_joint_delta_pos,
gripper=gripper_pd_joint_pos,
body=body_pd_joint_pos,
arm=arm_pd_joint_delta_pos, gripper=gripper_pd_joint_pos, body=body_pd_joint_pos, base=base_pd_joint_vel,
),
pd_joint_pos=dict(
arm=arm_pd_joint_pos,
gripper=gripper_pd_joint_pos,
body=body_pd_joint_pos,
arm=arm_pd_joint_pos, gripper=gripper_pd_joint_pos, body=body_pd_joint_pos, base=base_pd_joint_vel,
),
pd_ee_delta_pos=dict(
arm=arm_pd_ee_delta_pos,
gripper=gripper_pd_joint_pos,
body=body_pd_joint_pos,
arm=arm_pd_ee_delta_pos, gripper=gripper_pd_joint_pos, body=body_pd_joint_pos, base=base_pd_joint_vel,
),
pd_ee_delta_pose=dict(
arm=arm_pd_ee_delta_pose,
gripper=gripper_pd_joint_pos,
body=body_pd_joint_pos,
arm=arm_pd_ee_delta_pose, gripper=gripper_pd_joint_pos, body=body_pd_joint_pos, base=base_pd_joint_vel,
),
pd_ee_delta_pose_align=dict(
arm=arm_pd_ee_delta_pose_align,
gripper=gripper_pd_joint_pos,
body=body_pd_joint_pos,
arm=arm_pd_ee_delta_pose_align, gripper=gripper_pd_joint_pos, body=body_pd_joint_pos, base=base_pd_joint_vel,
),
# TODO(jigu): how to add boundaries for the following controllers
pd_joint_target_delta_pos=dict(
arm=arm_pd_joint_target_delta_pos,
gripper=gripper_pd_joint_pos,
body=body_pd_joint_pos,
arm=arm_pd_joint_target_delta_pos, gripper=gripper_pd_joint_pos, body=body_pd_joint_pos, base=base_pd_joint_vel,
),
pd_ee_target_delta_pos=dict(
arm=arm_pd_ee_target_delta_pos,
gripper=gripper_pd_joint_pos,
body=body_pd_joint_pos,
arm=arm_pd_ee_target_delta_pos, gripper=gripper_pd_joint_pos, body=body_pd_joint_pos, base=base_pd_joint_vel,
),
pd_ee_target_delta_pose=dict(
arm=arm_pd_ee_target_delta_pose,
gripper=gripper_pd_joint_pos,
body=body_pd_joint_pos,
arm=arm_pd_ee_target_delta_pose, gripper=gripper_pd_joint_pos, body=body_pd_joint_pos, base=base_pd_joint_vel,
),
# Caution to use the following controllers
pd_joint_vel=dict(
arm=arm_pd_joint_vel,
gripper=gripper_pd_joint_pos,
body=body_pd_joint_pos,
arm=arm_pd_joint_vel, gripper=gripper_pd_joint_pos, body=body_pd_joint_pos, base=base_pd_joint_vel,
),
pd_joint_pos_vel=dict(
arm=arm_pd_joint_pos_vel,
gripper=gripper_pd_joint_pos,
body=body_pd_joint_pos,
arm=arm_pd_joint_pos_vel, gripper=gripper_pd_joint_pos, body=body_pd_joint_pos, base=base_pd_joint_vel,
),
pd_joint_delta_pos_vel=dict(
arm=arm_pd_joint_delta_pos_vel,
gripper=gripper_pd_joint_pos,
body=body_pd_joint_pos,
arm=arm_pd_joint_delta_pos_vel, gripper=gripper_pd_joint_pos, body=body_pd_joint_pos, base=base_pd_joint_vel,
),
)

# Make a deepcopy in case users modify any config
return deepcopy_dict(controller_configs)

def _after_init(self):
self.finger1_link = get_obj_by_name(
self.finger1_link: Link = get_obj_by_name(
self.robot.get_links(), "l_gripper_finger_link"
)
self.finger2_link = get_obj_by_name(
self.finger2_link: Link = get_obj_by_name(
self.robot.get_links(), "r_gripper_finger_link"
)
self.tcp = get_obj_by_name(self.robot.get_links(), self.ee_link_name)

def is_grasping(self, object: sapien.Entity = None, min_impulse=1e-6, max_angle=85):
contacts = self.scene.get_contacts()
if object is None:
finger1_contacts = get_actor_contacts(contacts, self.finger1_link)
finger2_contacts = get_actor_contacts(contacts, self.finger2_link)
return (
np.linalg.norm(compute_total_impulse(finger1_contacts)) >= min_impulse
and np.linalg.norm(compute_total_impulse(finger2_contacts))
>= min_impulse
)
else:
limpulse = get_pairwise_contact_impulse(contacts, self.finger1_link, object)
rimpulse = get_pairwise_contact_impulse(contacts, self.finger2_link, object)
self.tcp: Link = get_obj_by_name(
self.robot.get_links(), self.ee_link_name
)

# direction to open the gripper
ldirection = -self.finger1_link.pose.to_transformation_matrix()[:3, 1]
rdirection = self.finger2_link.pose.to_transformation_matrix()[:3, 1]
self.base_link: Link = get_obj_by_name(
self.robot.get_links(), "base_link"
)
self.l_wheel_link: Link = get_obj_by_name(
self.robot.get_links(), "l_wheel_link"
)
self.r_wheel_link: Link = get_obj_by_name(
self.robot.get_links(), "r_wheel_link"
)
for link in [self.base_link, self.l_wheel_link, self.r_wheel_link]:
cs = link._bodies[0].get_collision_shapes()[0]
cg = cs.get_collision_groups()
cg[2] = FETCH_UNIQUE_COLLISION_BIT
cs.set_collision_groups(cg)
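FETCH_UNIQUE_COLLISION_BIT, now re-exported from the package __init__, tags the base and wheel collision shapes in collision group 2. Under SAPIEN's usual filtering convention, two shapes that share a bit in group 2 are excluded from colliding with each other, so an environment can set the same bit on its ground shape to keep the wheels from fighting the floor geometry. A hedged sketch of that environment-side counterpart follows; the ground actor and its single collision shape are assumptions here, not part of this PR.

from mani_skill2.agents.robots.fetch import FETCH_UNIQUE_COLLISION_BIT

# Hypothetical environment-side counterpart: give the ground the same group-2 bit
# so Fetch's base and wheels ignore it (assumes shapes sharing a group-2 bit do not collide).
for body in ground._bodies:
    cs = body.get_collision_shapes()[0]
    cg = cs.get_collision_groups()
    cg[2] |= FETCH_UNIQUE_COLLISION_BIT
    cs.set_collision_groups(cg)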

# angle between impulse and open direction
langle = np_compute_angle_between(ldirection, limpulse)
rangle = np_compute_angle_between(rdirection, rimpulse)
self.queries: Dict[str, Tuple[physx.PhysxGpuContactQuery, Tuple[int]]] = dict()

lflag = (
np.linalg.norm(limpulse) >= min_impulse
and np.rad2deg(langle) <= max_angle
def is_grasping(self, object: Actor = None, min_impulse=1e-6, max_angle=85):
# TODO (stao): is_grasping code needs to be updated for new GPU sim
if physx.is_gpu_enabled():
if object.name not in self.queries:
body_pairs = list(zip(self.finger1_link._bodies, object._bodies))
body_pairs += list(zip(self.finger2_link._bodies, object._bodies))
self.queries[object.name] = (
self.scene.px.gpu_create_contact_query(body_pairs),
(len(object._bodies), 3),
)
print(f"Create query for Fetch grasp({object.name})")
query, contacts_shape = self.queries[object.name]
self.scene.px.gpu_query_contacts(query)
# query.cuda_contacts # (num_unique_pairs * num_envs, 3)
contacts = query.cuda_contacts.clone().reshape((-1, *contacts_shape))
lforce = torch.linalg.norm(contacts[0], axis=1)
rforce = torch.linalg.norm(contacts[1], axis=1)

# NOTE (stao): 0.5 * time_step is a decent value when tested on a pick cube task.
min_force = 0.5 * self.scene.px.timestep

# direction to open the gripper
ldirection = -self.finger1_link.pose.to_transformation_matrix()[..., :3, 1]
rdirection = self.finger2_link.pose.to_transformation_matrix()[..., :3, 1]
langle = compute_angle_between(ldirection, contacts[0])
rangle = compute_angle_between(rdirection, contacts[1])
lflag = torch.logical_and(
lforce >= min_force, torch.rad2deg(langle) <= max_angle
)
rflag = (
np.linalg.norm(rimpulse) >= min_impulse
and np.rad2deg(rangle) <= max_angle
rflag = torch.logical_and(
rforce >= min_force, torch.rad2deg(rangle) <= max_angle
)

return all([lflag, rflag])
return torch.logical_and(lflag, rflag)
else:
contacts = self.scene.get_contacts()

if object is None:
finger1_contacts = get_actor_contacts(contacts, self.finger1_link._bodies[0].entity)
finger2_contacts = get_actor_contacts(contacts, self.finger2_link._bodies[0].entity)
return (
np.linalg.norm(compute_total_impulse(finger1_contacts))
>= min_impulse
and np.linalg.norm(compute_total_impulse(finger2_contacts))
>= min_impulse
)
else:
limpulse = get_pairwise_contact_impulse(
contacts, self.finger1_link._bodies[0].entity, object._bodies[0].entity
)
rimpulse = get_pairwise_contact_impulse(
contacts, self.finger2_link._bodies[0].entity, object._bodies[0].entity
)

# direction to open the gripper
ldirection = -self.finger1_link.pose.to_transformation_matrix()[
..., :3, 1
]
rdirection = self.finger2_link.pose.to_transformation_matrix()[
..., :3, 1
]

# TODO Convert this to batched code
# angle between impulse and open direction
langle = np_compute_angle_between(ldirection[0], limpulse)
rangle = np_compute_angle_between(rdirection[0], rimpulse)

lflag = (
np.linalg.norm(limpulse) >= min_impulse
and np.rad2deg(langle) <= max_angle
)
rflag = (
np.linalg.norm(rimpulse) >= min_impulse
and np.rad2deg(rangle) <= max_angle
)

return all([lflag, rflag])
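Both branches of is_grasping implement the same heuristic: a grasp is registered only when each finger sees enough contact force (or impulse) and that force points roughly along the finger's opening direction, within max_angle degrees. Below is a minimal batched sketch of the idea, with a hypothetical angle_between standing in for compute_angle_between; names and shapes are illustrative only.

import torch

def angle_between(a, b, eps=1e-8):
    # Batched angle (radians) between vectors along the last dimension.
    a = a / (torch.linalg.norm(a, dim=-1, keepdim=True) + eps)
    b = b / (torch.linalg.norm(b, dim=-1, keepdim=True) + eps)
    return torch.arccos((a * b).sum(dim=-1).clamp(-1.0, 1.0))

def grasp_flags(ldir, rdir, lcontact, rcontact, min_force, max_angle=85):
    # Grasp heuristic: both fingers must feel enough force, roughly aligned
    # with their opening directions (within max_angle degrees).
    lforce = torch.linalg.norm(lcontact, dim=-1)
    rforce = torch.linalg.norm(rcontact, dim=-1)
    langle = torch.rad2deg(angle_between(ldir, lcontact))
    rangle = torch.rad2deg(angle_between(rdir, rcontact))
    lflag = (lforce >= min_force) & (langle <= max_angle)
    rflag = (rforce >= min_force) & (rangle <= max_angle)
    return lflag & rflag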

@staticmethod
def build_grasp_pose(approaching, closing, center):
@@ -297,23 +380,9 @@ def build_grasp_pose(approaching, closing, center):
T[:3, :3] = np.stack([ortho, closing, approaching], axis=1)
T[:3, 3] = center
return sapien.Pose(T)

@property
def sensor_configs(self):
return [
CameraConfig(
uid="fetch_head",
p=[0, 0, 0],
q=[0.9238795, 0, 0.3826834, 0],
width=128,
height=128,
fov=1.57,
near=0.01,
far=10,
entity_uid="head_camera_link",
)
]


@property
def tcp_pose_p(self):
return (self.finger1_link.pose.p + self.finger2_link.pose.p) / 2
def tcp_pose(self) -> Pose:
p = (self.finger1_link.pose.p + self.finger2_link.pose.p) / 2
q = (self.finger1_link.pose.q + self.finger2_link.pose.q) / 2
return Pose.create_from_pq(p=p, q=q)