Skip to content

Commit

Permalink
Merge branch 'dev' into tasks-remake
Browse files Browse the repository at this point in the history
  • Loading branch information
StoneT2000 committed Jan 24, 2024
2 parents 4bae958 + e9717ca commit fba9ee8
Show file tree
Hide file tree
Showing 37 changed files with 291 additions and 362 deletions.
3 changes: 1 addition & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ repos:
- id: check-ast
- id: check-merge-conflict
- id: check-toml
- id: check-yaml
- id: check-yaml
# - id: check-yaml
- id: end-of-file-fixer
files: \.py$
- id: trailing-whitespace
Expand Down
2 changes: 1 addition & 1 deletion examples/benchmarking/benchmark_gpu_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def main(args):
env.close()
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("-e", "--env-id", type=str, default="PickCube-v0")
parser.add_argument("-e", "--env-id", type=str, default="PickCube-v1")
parser.add_argument("-o", "--obs-mode", type=str, default="none")
parser.add_argument("-n", "--num-envs", type=int, default=256)
parser.add_argument(
Expand Down
21 changes: 9 additions & 12 deletions mani_skill2/agents/controllers/base_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,12 @@ def _preprocess_action(self, action: Array):
# TODO(jigu): support discrete action
if self.scene.num_envs > 1:
action_dim = self.action_space.shape[1]
assert action.shape == (self.scene.num_envs, action_dim), (
action.shape,
action_dim,
)
else:
action_dim = self.action_space.shape[0]
assert action.shape == (action_dim,), (action.shape, action_dim)
assert action.shape == (self.scene.num_envs, action_dim), (
action.shape,
action_dim,
)

if self._normalize_action:
action = self._clip_and_scale_action(action)
Expand Down Expand Up @@ -273,17 +272,15 @@ def set_action(self, action: np.ndarray):
# TODO (stao): optimization, do we really need this sanity check? Does gymnasium already do this for us
if self.scene.num_envs > 1:
action_dim = self.action_space.shape[1]
assert action.shape == (self.scene.num_envs, action_dim), (
action.shape,
action_dim,
)
else:
action_dim = self.action_space.shape[0]
assert action.shape == (action_dim,), (action.shape, action_dim)

assert action.shape == (self.scene.num_envs, action_dim), (
action.shape,
action_dim,
)
for uid, controller in self.controllers.items():
start, end = self.action_mapping[uid]
controller.set_action(action[..., start:end])
controller.set_action(action[:, start:end])

def to_action_dict(self, action: np.ndarray):
"""Convert a flat action to a dict of actions."""
Expand Down
28 changes: 11 additions & 17 deletions mani_skill2/agents/controllers/pd_base_vel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch

from mani_skill2.utils.geometry import rotate_2d_vec_by_angle
from mani_skill2.utils.structs.types import Array

from .pd_joint_vel import PDJointVelController, PDJointVelControllerConfig

Expand All @@ -16,25 +17,18 @@ def _initialize_action_space(self):
assert len(self.joints) >= 3, len(self.joints)
super()._initialize_action_space()

def set_action(self, action: np.ndarray):
def set_action(self, action: Array):
action = self._preprocess_action(action)

# TODO (arth): add support for batched qpos and gpu sim
if isinstance(self.qpos, torch.Tensor):
qpos = self.qpos.detach().cpu().numpy()
qpos = qpos[0]
if isinstance(action, torch.Tensor):
action = action.detach().cpu().numpy()

# Convert to ego-centric action
# Assume the 3rd DoF stands for orientation
ori = qpos[2]
vel = rotate_2d_vec_by_angle(action[:2], ori)
new_action = np.hstack([vel, action[2:]])

for i, joint in enumerate(self.joints):
joint.set_drive_velocity_target(np.array([new_action[i]]))


ori = self.qpos[:, 2]
rot_mat = torch.zeros(ori.shape[0], 2, 2, device=action.device)
rot_mat[:, 0, 0] = torch.cos(ori)
rot_mat[:, 0, 1] = -torch.sin(ori)
rot_mat[:, 1, 0] = torch.sin(ori)
rot_mat[:, 1, 1] = torch.cos(ori)
vel = (rot_mat @ action[:, :2].unsqueeze(-1)).squeeze(-1)
new_action = torch.hstack([vel, action[:, 2:]])
self.articulation.set_joint_drive_velocity_targets(new_action, self.joints)
class PDBaseVelControllerConfig(PDJointVelControllerConfig):
controller_cls = PDBaseVelController
113 changes: 76 additions & 37 deletions mani_skill2/agents/controllers/pd_ee_pose.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,27 @@
from contextlib import contextmanager, redirect_stderr, redirect_stdout
from dataclasses import dataclass
from os import devnull
from typing import Sequence, Union

import numpy as np

# TODO (stao): https://github.com/UM-ARM-Lab/pytorch_kinematics/issues/35 pk requires mujoco as it is always imported despite not being used for the code we need
import pytorch_kinematics as pk
import sapien
import sapien.physx as physx
import torch
from gymnasium import spaces
from scipy.spatial.transform import Rotation

from mani_skill2.utils.common import clip_and_scale_action
from mani_skill2.utils.geometry.rotation_conversions import (
euler_angles_to_matrix,
matrix_to_quaternion,
)
from mani_skill2.utils.sapien_utils import get_obj_by_name, to_numpy, to_tensor
from mani_skill2.utils.structs.pose import vectorize_pose
from mani_skill2.utils.structs.pose import Pose, vectorize_pose
from mani_skill2.utils.structs.types import Array

from .base_controller import BaseController, ControllerConfig
from .base_controller import ControllerConfig
from .pd_joint_pos import PDJointPosController


Expand All @@ -21,14 +30,29 @@ class PDEEPosController(PDJointPosController):
config: "PDEEPosControllerConfig"

def _initialize_joints(self):
self.initial_qpos = None
super()._initialize_joints()

# Pinocchio model to compute IK
# TODO (stao): Batched IK? https://curobo.org/source/getting_started/2a_python_examples.html#inverse-kinematics
if physx.is_gpu_enabled():
pass
with open(self.config.urdf_path, "r") as f:
urdf_str = f.read()

# NOTE (stao): it seems that the pk library currently always outputs some complaints if there are unknown attributes in a URDF. Hide it with this contextmanager here
@contextmanager
def suppress_stdout_stderr():
"""A context manager that redirects stdout and stderr to devnull"""
with open(devnull, "w") as fnull:
with redirect_stderr(fnull) as err, redirect_stdout(fnull) as out:
yield (err, out)

with suppress_stdout_stderr():
self.pk_chain = pk.build_serial_chain_from_urdf(
urdf_str,
end_link_name=self.config.ee_link,
).to(device="cuda")
else:
self.pmodel = self.articulation.create_pinocchio_model()
# TODO should we just use jacobian inverse * delta method from pk?
self.pmodel = self.articulation._objs[0].create_pinocchio_model()
self.qmask = np.zeros(self.articulation.dof, dtype=bool)
self.qmask[self.joint_indices] = 1

Expand Down Expand Up @@ -63,17 +87,26 @@ def reset(self):
super().reset()
self._target_pose = self.ee_pose_at_base

def compute_ik(self, target_pose, max_iterations=100):
def compute_ik(self, target_pose: Pose, action: Array, max_iterations=100):
# Assume the target pose is defined in the base frame
# TODO (arth): currently ik only supports cpu, so input/output is managed as such
# in future, need to change input/output processing per gpu implementation
result, success, error = self.pmodel.compute_inverse_kinematics(
self.ee_link_idx,
target_pose.sp,
initial_qpos=to_numpy(self.articulation.get_qpos()).squeeze(0),
active_qmask=self.qmask,
max_iterations=max_iterations,
)
if physx.is_gpu_enabled():
jacobian = self.pk_chain.jacobian(self.articulation.get_qpos())
# NOTE (stao): a bit of a hacky way to check if we want to do IK on position or pose here
if action.shape[1] == 3:
jacobian = jacobian[:, 0:3]

# NOTE (stao): this method of IK is from https://mathweb.ucsd.edu/~sbuss/ResearchWeb/ikmethods/iksurvey.pdf by Samuel R. Buss
delta_joint_pos = torch.linalg.pinv(jacobian) @ action.unsqueeze(-1)
return self.articulation.get_qpos() + delta_joint_pos.squeeze(-1)

else:
result, success, error = self.pmodel.compute_inverse_kinematics(
self.ee_link_idx,
target_pose.sp,
initial_qpos=to_numpy(self.articulation.get_qpos()).squeeze(0),
active_qmask=self.qmask,
max_iterations=max_iterations,
)
if success:
return to_tensor([result[self.joint_indices]])
else:
Expand All @@ -82,7 +115,7 @@ def compute_ik(self, target_pose, max_iterations=100):
def compute_target_pose(self, prev_ee_pose_at_base, action):
# Keep the current rotation and change the position
if self.config.use_delta:
delta_pose = sapien.Pose(action)
delta_pose = Pose.create(action)

if self.config.frame == "base":
target_pose = delta_pose * prev_ee_pose_at_base
Expand All @@ -92,13 +125,12 @@ def compute_target_pose(self, prev_ee_pose_at_base, action):
raise NotImplementedError(self.config.frame)
else:
assert self.config.frame == "base", self.config.frame
target_pose = sapien.Pose(action)
target_pose = Pose.create(action)

return target_pose

def set_action(self, action: np.ndarray):
def set_action(self, action: Array):
action = self._preprocess_action(action)

self._step = 0
self._start_qpos = self.qpos

Expand All @@ -108,10 +140,9 @@ def set_action(self, action: np.ndarray):
prev_ee_pose_at_base = self.ee_pose_at_base

self._target_pose = self.compute_target_pose(prev_ee_pose_at_base, action)
self._target_qpos = self.compute_ik(self._target_pose)
self._target_qpos = self.compute_ik(self._target_pose, action)
if self._target_qpos is None:
self._target_qpos = self._start_qpos

if self.config.interpolate:
self._step_size = (self._target_qpos - self._start_qpos) / self._sim_steps
else:
Expand All @@ -137,6 +168,7 @@ class PDEEPosControllerConfig(ControllerConfig):
force_limit: Union[float, Sequence[float]] = 1e10
friction: Union[float, Sequence[float]] = 0.0
ee_link: str = None
urdf_path: str = None
frame: str = "ee" # [base, ee]
use_delta: bool = True
use_target: bool = False
Expand Down Expand Up @@ -168,23 +200,26 @@ def _initialize_action_space(self):
self.action_space = spaces.Box(low, high, dtype=np.float32)

def _clip_and_scale_action(self, action):
# TODO (stao): support batched actions
# NOTE(xiqiang): rotation should be clipped by norm.
pos_action = clip_and_scale_action(
action[:3], self.action_space_low[:3], self.action_space_high[:3]
action[:, :3], self.action_space_low[:3], self.action_space_high[:3]
)
rot_action = action[3:]
rot_norm = torch.linalg.norm(rot_action)
if rot_norm > 1:
rot_action = rot_action / rot_norm
rot_action = action[:, 3:]

rot_norm = torch.linalg.norm(rot_action, axis=1)
rot_action[rot_norm > 1] = torch.mul(rot_action, 1 / rot_norm[:, None])[
rot_norm > 1
]
rot_action = rot_action * self.config.rot_bound
return np.hstack([pos_action, rot_action])
return torch.hstack([pos_action, rot_action])

def compute_target_pose(self, prev_ee_pose_at_base, action):
def compute_target_pose(self, prev_ee_pose_at_base: Pose, action):
if self.config.use_delta:
delta_pos, delta_rot = action[0:3], action[3:6]
delta_quat = Rotation.from_rotvec(delta_rot).as_quat()[[3, 0, 1, 2]]
delta_pose = sapien.Pose(delta_pos, delta_quat)
delta_pos, delta_rot = action[:, 0:3], action[:, 3:6]
delta_quat = matrix_to_quaternion(euler_angles_to_matrix(delta_rot, "XYZ"))
# TODO (stao): verify correctness, the results of these two delta_rot to quaternion are a little different
# delta_quat = Rotation.from_rotvec(delta_rot).as_quat()[[3, 0, 1, 2]]
delta_pose = Pose.create_from_pq(delta_pos, delta_quat)

if self.config.frame == "base":
target_pose = delta_pose * prev_ee_pose_at_base
Expand All @@ -198,9 +233,12 @@ def compute_target_pose(self, prev_ee_pose_at_base, action):
raise NotImplementedError(self.config.frame)
else:
assert self.config.frame == "base", self.config.frame
target_pos, target_rot = action[0:3], action[3:6]
target_quat = Rotation.from_rotvec(target_rot).as_quat()[[3, 0, 1, 2]]
target_pose = sapien.Pose(target_pos, target_quat)
target_pos, target_rot = action[:, 0:3], action[:, 3:6]
target_quat = matrix_to_quaternion(
euler_angles_to_matrix(target_rot, "XYZ")
)
# target_quat = Rotation.from_rotvec(target_rot).as_quat()[[3, 0, 1, 2]]
target_pose = Pose.create_from_pq(target_pos, target_quat)

return target_pose

Expand All @@ -215,6 +253,7 @@ class PDEEPoseControllerConfig(ControllerConfig):
force_limit: Union[float, Sequence[float]] = 1e10
friction: Union[float, Sequence[float]] = 0.0
ee_link: str = None
urdf_path: str = None
frame: str = "ee" # [base, ee, ee_align]
use_delta: bool = True
use_target: bool = False
Expand Down
16 changes: 9 additions & 7 deletions mani_skill2/agents/controllers/pd_joint_pos_vel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Sequence, Union

import numpy as np
import torch
from gymnasium import spaces

from .base_controller import BaseController, ControllerConfig
Expand All @@ -25,31 +26,32 @@ def reset(self):
self._target_qvel = np.zeros_like(self._target_qpos)

def set_drive_velocity_targets(self, targets):
for i, joint in enumerate(self.joints):
joint.set_drive_velocity_target(targets[i])
self.articulation.set_joint_drive_velocity_targets(targets, self.joints)

def set_action(self, action: np.ndarray):
action = self._preprocess_action(action)
nq = len(action) // 2
nq = len(action[0]) // 2

self._step = 0
self._start_qpos = self.qpos

if self.config.use_delta:
if self.config.use_target:
self._target_qpos = self._target_qpos + action[:nq]
self._target_qpos = self._target_qpos + action[:, :nq]
else:
self._target_qpos = self._start_qpos + action[:nq]
self._target_qpos = self._start_qpos + action[:, :nq]
else:
# Compatible with mimic
self._target_qpos = np.broadcast_to(action[:nq], self._start_qpos.shape)
self._target_qpos = torch.broadcast_to(
action[:, :nq], self._start_qpos.shape
)

if self.config.interpolate:
self._step_size = (self._target_qpos - self._start_qpos) / self._sim_steps
else:
self.set_drive_targets(self._target_qpos)

self._target_qvel = action[nq:]
self._target_qvel = action[:, nq:]
self.set_drive_velocity_targets(self._target_qvel)


Expand Down
2 changes: 2 additions & 0 deletions mani_skill2/agents/robots/fetch/fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def controller_configs(self):
self.arm_damping,
self.arm_force_limit,
ee_link=self.ee_link_name,
urdf_path=self.urdf_path
)
arm_pd_ee_delta_pose = PDEEPoseControllerConfig(
self.arm_joint_names,
Expand All @@ -144,6 +145,7 @@ def controller_configs(self):
self.arm_damping,
self.arm_force_limit,
ee_link=self.ee_link_name,
urdf_path=self.urdf_path
)

arm_pd_ee_target_delta_pos = deepcopy(arm_pd_ee_delta_pos)
Expand Down
2 changes: 2 additions & 0 deletions mani_skill2/agents/robots/panda/panda.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def controller_configs(self):
self.arm_damping,
self.arm_force_limit,
ee_link=self.ee_link_name,
urdf_path=self.urdf_path,
)
arm_pd_ee_delta_pose = PDEEPoseControllerConfig(
self.arm_joint_names,
Expand All @@ -108,6 +109,7 @@ def controller_configs(self):
self.arm_damping,
self.arm_force_limit,
ee_link=self.ee_link_name,
urdf_path=self.urdf_path,
)

arm_pd_ee_target_delta_pos = deepcopy(arm_pd_ee_delta_pos)
Expand Down
Loading

0 comments on commit fba9ee8

Please sign in to comment.