diff --git a/mani_skill2/agents/controllers/pd_ee_pose.py b/mani_skill2/agents/controllers/pd_ee_pose.py index 2d82fe4dd..1998161c6 100644 --- a/mani_skill2/agents/controllers/pd_ee_pose.py +++ b/mani_skill2/agents/controllers/pd_ee_pose.py @@ -6,67 +6,45 @@ import sapien.physx as physx import torch from gymnasium import spaces -from scipy.spatial.transform import Rotation from mani_skill2.utils.common import clip_and_scale_action from mani_skill2.utils.sapien_utils import get_obj_by_name, to_numpy, to_tensor from mani_skill2.utils.structs.pose import Pose, vectorize_pose - -from .base_controller import BaseController, ControllerConfig +from mani_skill2.utils.structs.types import Array +from contextlib import contextmanager,redirect_stderr,redirect_stdout +from os import devnull + +@contextmanager +def suppress_stdout_stderr(): + """A context manager that redirects stdout and stderr to devnull""" + with open(devnull, 'w') as fnull: + with redirect_stderr(fnull) as err, redirect_stdout(fnull) as out: + yield (err, out) +# TODO (stao): https://github.com/UM-ARM-Lab/pytorch_kinematics/issues/35 pk requires mujoco as it is always imported despite not being used for the code we need +import pytorch_kinematics as pk +from .base_controller import ControllerConfig from .pd_joint_pos import PDJointPosController from mani_skill2.utils.geometry.rotation_conversions import ( - quaternion_apply, - quaternion_multiply, - quaternion_to_matrix, euler_angles_to_matrix, matrix_to_quaternion, ) -try: - from curobo.types.base import TensorDeviceType - from curobo.types.math import Pose as Curobo_Pose - from curobo.types.robot import RobotConfig - from curobo.util_file import get_robot_configs_path, join_path, load_yaml - from curobo.wrap.reacher.ik_solver import IKSolver, IKSolverConfig -except: - pass # NOTE(jigu): not necessary to inherit, just for convenience class PDEEPosController(PDJointPosController): config: "PDEEPosControllerConfig" def _initialize_joints(self): + self.initial_qpos = None super()._initialize_joints() - # Pinocchio model to compute IK - # TODO (stao): Batched IK? https://curobo.org/source/getting_started/2a_python_examples.html#inverse-kinematics if physx.is_gpu_enabled(): - - tensor_args = TensorDeviceType() - - config_file = load_yaml( - "/home/stao/work/research/maniskill/ManiSkill2/mani_skill2/assets/robots/panda/panda_v2.yml" - ) - urdf_file = config_file["robot_cfg"]["kinematics"]["urdf_path"] - base_link = "panda_link0" - ee_link = config_file["robot_cfg"]["kinematics"]["ee_link"] - robot_cfg = RobotConfig.from_basic( - urdf_file, base_link, ee_link, tensor_args - ) - - ik_config = IKSolverConfig.load_from_robot_config( - robot_cfg, - None, - rotation_threshold=0.05, - position_threshold=0.005, - num_seeds=20, - self_collision_check=False, - self_collision_opt=False, - tensor_args=tensor_args, - use_cuda_graph=True, - ) - self.curobo_ik_solver = IKSolver(ik_config) + with open(self.config.urdf_path, "r") as f: + urdf_str = f.read() + with suppress_stdout_stderr(): + self.pk_chain = pk.build_serial_chain_from_urdf(urdf_str, end_link_name=self.config.ee_link,).to(device="cuda") else: - self.pmodel = self.articulation.create_pinocchio_model() + # TODO should we just use jacobian inverse * delta method from pk? + self.pmodel = self.articulation._objs[0].create_pinocchio_model() self.qmask = np.zeros(self.articulation.dof, dtype=bool) self.qmask[self.joint_indices] = 1 @@ -101,23 +79,18 @@ def reset(self): super().reset() self._target_pose = self.ee_pose_at_base - def compute_ik(self, target_pose: Pose, max_iterations=100): + def compute_ik(self, target_pose: Pose, action: Array, max_iterations=100): # Assume the target pose is defined in the base frame - # TODO (arth): currently ik only supports cpu, so input/output is managed as such - # in future, need to change input/output processing per gpu implementation - if physx.is_gpu_enabled(): - initial_qpos = self.curobo_ik_solver.sample_configs(self.scene.num_envs) - initial_qpos = self.articulation.get_qpos()[:, self.qmask] - kin_state = self.curobo_ik_solver.fk(initial_qpos) - # goal = Curobo_Pose(kin_state.ee_position, kin_state.ee_quaternion) - goal = Curobo_Pose(target_pose.p, target_pose.q) - # import ipdb;ipdb.set_trace() - result = self.curobo_ik_solver.solve_batch(goal) - q_solution = result.solution[result.success] # (N, 1, dof) - if torch.all(result.success): - return q_solution - return None + jacobian = self.pk_chain.jacobian(self.articulation.get_qpos()[:, self.qmask]) + # NOTE (stao): a bit of a hacky way to check if we want to do IK on position or pose here + if action.shape[1] == 3: + jacobian_pos = jacobian[:, 0:3] + jacobian_pinv = torch.linalg.pinv(jacobian_pos) + delta_joint_pos = 1.0 * jacobian_pinv @ action.unsqueeze(-1) + delta_joint_pos = delta_joint_pos.squeeze(-1) + return self.articulation.get_qpos()[:, self.qmask] + delta_joint_pos + else: result, success, error = self.pmodel.compute_inverse_kinematics( self.ee_link_idx, @@ -148,7 +121,7 @@ def compute_target_pose(self, prev_ee_pose_at_base, action): return target_pose - def set_action(self, action: np.ndarray): + def set_action(self, action: Array): action = self._preprocess_action(action) self._step = 0 self._start_qpos = self.qpos @@ -159,7 +132,7 @@ def set_action(self, action: np.ndarray): prev_ee_pose_at_base = self.ee_pose_at_base self._target_pose = self.compute_target_pose(prev_ee_pose_at_base, action) - self._target_qpos = self.compute_ik(self._target_pose) + self._target_qpos = self.compute_ik(self._target_pose, action) if self._target_qpos is None: self._target_qpos = self._start_qpos if self.config.interpolate: @@ -187,6 +160,7 @@ class PDEEPosControllerConfig(ControllerConfig): force_limit: Union[float, Sequence[float]] = 1e10 friction: Union[float, Sequence[float]] = 0.0 ee_link: str = None + urdf_path: str = None frame: str = "ee" # [base, ee] use_delta: bool = True use_target: bool = False @@ -267,6 +241,7 @@ class PDEEPoseControllerConfig(ControllerConfig): force_limit: Union[float, Sequence[float]] = 1e10 friction: Union[float, Sequence[float]] = 0.0 ee_link: str = None + urdf_path: str = None frame: str = "ee" # [base, ee, ee_align] use_delta: bool = True use_target: bool = False diff --git a/mani_skill2/agents/robots/panda/panda.py b/mani_skill2/agents/robots/panda/panda.py index 6ab3ca5a4..46458509d 100644 --- a/mani_skill2/agents/robots/panda/panda.py +++ b/mani_skill2/agents/robots/panda/panda.py @@ -98,6 +98,7 @@ def controller_configs(self): self.arm_damping, self.arm_force_limit, ee_link=self.ee_link_name, + urdf_path=self.urdf_path ) arm_pd_ee_delta_pose = PDEEPoseControllerConfig( self.arm_joint_names, @@ -108,6 +109,7 @@ def controller_configs(self): self.arm_damping, self.arm_force_limit, ee_link=self.ee_link_name, + urdf_path=self.urdf_path ) arm_pd_ee_target_delta_pos = deepcopy(arm_pd_ee_delta_pos) diff --git a/mani_skill2/assets/robots/panda/panda_v2.yml b/mani_skill2/assets/robots/panda/panda_v2.yml deleted file mode 100644 index 07ae2764f..000000000 --- a/mani_skill2/assets/robots/panda/panda_v2.yml +++ /dev/null @@ -1,114 +0,0 @@ -## -## Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -## -## NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -## property and proprietary rights in and to this material, related -## documentation and any modifications thereto. Any use, reproduction, -## disclosure or distribution of this material and related documentation -## without an express license agreement from NVIDIA CORPORATION or -## its affiliates is strictly prohibited. -## - -robot_cfg: - kinematics: - use_usd_kinematics: False - isaac_usd_path: "/Isaac/Robots/Franka/franka.usd" - usd_path: "robot/non_shipping/franka/franka_panda_meters.usda" - usd_robot_root: "/panda" - usd_flip_joints: ["panda_joint1","panda_joint2","panda_joint3","panda_joint4", "panda_joint5", - "panda_joint6","panda_joint7","panda_finger_joint1", "panda_finger_joint2"] - usd_flip_joints: { - "panda_joint1": "Z", - "panda_joint2": "Z", - "panda_joint3": "Z", - "panda_joint4": "Z", - "panda_joint5": "Z", - "panda_joint6": "Z", - "panda_joint7": "Z", - "panda_finger_joint1": "Y", - "panda_finger_joint2": "Y", - } - - usd_flip_joint_limits: ["panda_finger_joint2"] - urdf_path: "robot/franka_description/franka_panda.urdf" - asset_root_path: "robot/franka_description" - base_link: "base_link" - ee_link: "panda_hand" - collision_link_names: - [ - "panda_link0", - "panda_link1", - "panda_link2", - "panda_link3", - "panda_link4", - "panda_link5", - "panda_link6", - "panda_link7", - "panda_hand", - "panda_leftfinger", - "panda_rightfinger", - "attached_object", - ] - collision_spheres: "spheres/franka_mesh.yml" - collision_sphere_buffer: 0.0025 - extra_collision_spheres: {"attached_object": 4} - use_global_cumul: True - self_collision_ignore: - { - "panda_link0": ["panda_link1", "panda_link2"], - "panda_link1": ["panda_link2", "panda_link3", "panda_link4"], - "panda_link2": ["panda_link3", "panda_link4"], - "panda_link3": ["panda_link4", "panda_link6"], - "panda_link4": - ["panda_link5", "panda_link6", "panda_link7", "panda_link8"], - "panda_link5": ["panda_link6", "panda_link7", "panda_hand","panda_leftfinger", "panda_rightfinger"], - "panda_link6": ["panda_link7", "panda_hand", "attached_object", "panda_leftfinger", "panda_rightfinger"], - "panda_link7": ["panda_hand", "attached_object", "panda_leftfinger", "panda_rightfinger"], - "panda_hand": ["panda_leftfinger", "panda_rightfinger","attached_object"], - "panda_leftfinger": ["panda_rightfinger", "attached_object"], - "panda_rightfinger": ["attached_object"], - - } - - self_collision_buffer: - { - "panda_link0": 0.1, - "panda_link1": 0.05, - "panda_link2": 0.0, - "panda_link3": 0.0, - "panda_link4": 0.0, - "panda_link5": 0.0, - "panda_link6": 0.0, - "panda_link7": 0.0, - "panda_hand": 0.02, - "panda_leftfinger": 0.01, - "panda_rightfinger": 0.01, - "attached_object": 0.0, - } - #link_names: ["panda_link4"] - mesh_link_names: - [ - "panda_link0", - "panda_link1", - "panda_link2", - "panda_link3", - "panda_link4", - "panda_link5", - "panda_link6", - "panda_link7", - "panda_hand", - "panda_leftfinger", - "panda_rightfinger", - ] - lock_joints: {"panda_finger_joint1": 0.04, "panda_finger_joint2": 0.04} - extra_links: {"attached_object":{"parent_link_name": "panda_hand" , - "link_name": "attached_object", "fixed_transform": [0,0,0,1,0,0,0], "joint_type":"FIXED", - "joint_name": "attach_joint" }} - cspace: - joint_names: ["panda_joint1","panda_joint2","panda_joint3","panda_joint4", "panda_joint5", - "panda_joint6","panda_joint7","panda_finger_joint1", "panda_finger_joint2"] - retract_config: [0.0, -1.3, 0.0, -2.5, 0.0, 1.0, 0., 0.04, 0.04] - null_space_weight: [1,1,1,1,1,1,1,1,1] - cspace_distance_weight: [1,1,1,1,1,1,1,1,1] - max_acceleration: 15.0 - max_jerk: 500.0 \ No newline at end of file diff --git a/mani_skill2/utils/structs/articulation.py b/mani_skill2/utils/structs/articulation.py index 8c32d503f..64e54b840 100644 --- a/mani_skill2/utils/structs/articulation.py +++ b/mani_skill2/utils/structs/articulation.py @@ -360,7 +360,7 @@ def create_pinocchio_model(self): # NOTE (stao): This is available but not typed in SAPIEN if physx.is_gpu_enabled(): raise NotImplementedError( - "Cannot create a pinocchio model when GPU is enabled. If you wish to do inverse kinematics you must use curobo" + "Cannot create a pinocchio model when GPU is enabled. If you wish to do inverse kinematics you must use pytorch_kinematics" ) else: return self._objs[0].create_pinocchio_model() diff --git a/manualtest/visual_all_envs_cpu.py b/manualtest/visual_all_envs_cpu.py index f756db86c..24a362c2f 100644 --- a/manualtest/visual_all_envs_cpu.py +++ b/manualtest/visual_all_envs_cpu.py @@ -14,7 +14,7 @@ num_envs=num_envs, enable_shadow=True, render_mode="rgb_array", - control_mode="pd_ee_delta_pose", + control_mode="pd_ee_delta_pos", sim_freq=500, control_freq=100, ) @@ -36,6 +36,7 @@ while i < 50: obs, rew, terminated, truncated, info = env.step(env.action_space.sample()) done = np.logical_or(to_numpy(terminated), to_numpy(truncated)) + print(rew) if num_envs == 1: env.render_human() done = done.any()