Skip to content

Commit

Permalink
work
Browse files Browse the repository at this point in the history
  • Loading branch information
StoneT2000 committed Jan 24, 2024
1 parent 67c3e25 commit c205d4c
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 175 deletions.
93 changes: 34 additions & 59 deletions mani_skill2/agents/controllers/pd_ee_pose.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,67 +6,45 @@
import sapien.physx as physx
import torch
from gymnasium import spaces
from scipy.spatial.transform import Rotation

from mani_skill2.utils.common import clip_and_scale_action
from mani_skill2.utils.sapien_utils import get_obj_by_name, to_numpy, to_tensor
from mani_skill2.utils.structs.pose import Pose, vectorize_pose

from .base_controller import BaseController, ControllerConfig
from mani_skill2.utils.structs.types import Array
from contextlib import contextmanager,redirect_stderr,redirect_stdout
from os import devnull

@contextmanager
def suppress_stdout_stderr():
"""A context manager that redirects stdout and stderr to devnull"""
with open(devnull, 'w') as fnull:
with redirect_stderr(fnull) as err, redirect_stdout(fnull) as out:
yield (err, out)
# TODO (stao): https://github.com/UM-ARM-Lab/pytorch_kinematics/issues/35 pk requires mujoco as it is always imported despite not being used for the code we need
import pytorch_kinematics as pk
from .base_controller import ControllerConfig
from .pd_joint_pos import PDJointPosController

from mani_skill2.utils.geometry.rotation_conversions import (
quaternion_apply,
quaternion_multiply,
quaternion_to_matrix,
euler_angles_to_matrix,
matrix_to_quaternion,
)
try:
from curobo.types.base import TensorDeviceType
from curobo.types.math import Pose as Curobo_Pose
from curobo.types.robot import RobotConfig
from curobo.util_file import get_robot_configs_path, join_path, load_yaml
from curobo.wrap.reacher.ik_solver import IKSolver, IKSolverConfig
except:
pass
# NOTE(jigu): not necessary to inherit, just for convenience
class PDEEPosController(PDJointPosController):
config: "PDEEPosControllerConfig"

def _initialize_joints(self):
self.initial_qpos = None
super()._initialize_joints()

# Pinocchio model to compute IK
# TODO (stao): Batched IK? https://curobo.org/source/getting_started/2a_python_examples.html#inverse-kinematics
if physx.is_gpu_enabled():

tensor_args = TensorDeviceType()

config_file = load_yaml(
"/home/stao/work/research/maniskill/ManiSkill2/mani_skill2/assets/robots/panda/panda_v2.yml"
)
urdf_file = config_file["robot_cfg"]["kinematics"]["urdf_path"]
base_link = "panda_link0"
ee_link = config_file["robot_cfg"]["kinematics"]["ee_link"]
robot_cfg = RobotConfig.from_basic(
urdf_file, base_link, ee_link, tensor_args
)

ik_config = IKSolverConfig.load_from_robot_config(
robot_cfg,
None,
rotation_threshold=0.05,
position_threshold=0.005,
num_seeds=20,
self_collision_check=False,
self_collision_opt=False,
tensor_args=tensor_args,
use_cuda_graph=True,
)
self.curobo_ik_solver = IKSolver(ik_config)
with open(self.config.urdf_path, "r") as f:
urdf_str = f.read()
with suppress_stdout_stderr():
self.pk_chain = pk.build_serial_chain_from_urdf(urdf_str, end_link_name=self.config.ee_link,).to(device="cuda")
else:
self.pmodel = self.articulation.create_pinocchio_model()
# TODO should we just use jacobian inverse * delta method from pk?
self.pmodel = self.articulation._objs[0].create_pinocchio_model()
self.qmask = np.zeros(self.articulation.dof, dtype=bool)
self.qmask[self.joint_indices] = 1

Expand Down Expand Up @@ -101,23 +79,18 @@ def reset(self):
super().reset()
self._target_pose = self.ee_pose_at_base

def compute_ik(self, target_pose: Pose, max_iterations=100):
def compute_ik(self, target_pose: Pose, action: Array, max_iterations=100):
# Assume the target pose is defined in the base frame
# TODO (arth): currently ik only supports cpu, so input/output is managed as such
# in future, need to change input/output processing per gpu implementation

if physx.is_gpu_enabled():
initial_qpos = self.curobo_ik_solver.sample_configs(self.scene.num_envs)
initial_qpos = self.articulation.get_qpos()[:, self.qmask]
kin_state = self.curobo_ik_solver.fk(initial_qpos)
# goal = Curobo_Pose(kin_state.ee_position, kin_state.ee_quaternion)
goal = Curobo_Pose(target_pose.p, target_pose.q)
# import ipdb;ipdb.set_trace()
result = self.curobo_ik_solver.solve_batch(goal)
q_solution = result.solution[result.success] # (N, 1, dof)
if torch.all(result.success):
return q_solution
return None
jacobian = self.pk_chain.jacobian(self.articulation.get_qpos()[:, self.qmask])
# NOTE (stao): a bit of a hacky way to check if we want to do IK on position or pose here
if action.shape[1] == 3:
jacobian_pos = jacobian[:, 0:3]
jacobian_pinv = torch.linalg.pinv(jacobian_pos)
delta_joint_pos = 1.0 * jacobian_pinv @ action.unsqueeze(-1)
delta_joint_pos = delta_joint_pos.squeeze(-1)
return self.articulation.get_qpos()[:, self.qmask] + delta_joint_pos

else:
result, success, error = self.pmodel.compute_inverse_kinematics(
self.ee_link_idx,
Expand Down Expand Up @@ -148,7 +121,7 @@ def compute_target_pose(self, prev_ee_pose_at_base, action):

return target_pose

def set_action(self, action: np.ndarray):
def set_action(self, action: Array):
action = self._preprocess_action(action)
self._step = 0
self._start_qpos = self.qpos
Expand All @@ -159,7 +132,7 @@ def set_action(self, action: np.ndarray):
prev_ee_pose_at_base = self.ee_pose_at_base

self._target_pose = self.compute_target_pose(prev_ee_pose_at_base, action)
self._target_qpos = self.compute_ik(self._target_pose)
self._target_qpos = self.compute_ik(self._target_pose, action)
if self._target_qpos is None:
self._target_qpos = self._start_qpos
if self.config.interpolate:
Expand Down Expand Up @@ -187,6 +160,7 @@ class PDEEPosControllerConfig(ControllerConfig):
force_limit: Union[float, Sequence[float]] = 1e10
friction: Union[float, Sequence[float]] = 0.0
ee_link: str = None
urdf_path: str = None
frame: str = "ee" # [base, ee]
use_delta: bool = True
use_target: bool = False
Expand Down Expand Up @@ -267,6 +241,7 @@ class PDEEPoseControllerConfig(ControllerConfig):
force_limit: Union[float, Sequence[float]] = 1e10
friction: Union[float, Sequence[float]] = 0.0
ee_link: str = None
urdf_path: str = None
frame: str = "ee" # [base, ee, ee_align]
use_delta: bool = True
use_target: bool = False
Expand Down
2 changes: 2 additions & 0 deletions mani_skill2/agents/robots/panda/panda.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def controller_configs(self):
self.arm_damping,
self.arm_force_limit,
ee_link=self.ee_link_name,
urdf_path=self.urdf_path
)
arm_pd_ee_delta_pose = PDEEPoseControllerConfig(
self.arm_joint_names,
Expand All @@ -108,6 +109,7 @@ def controller_configs(self):
self.arm_damping,
self.arm_force_limit,
ee_link=self.ee_link_name,
urdf_path=self.urdf_path
)

arm_pd_ee_target_delta_pos = deepcopy(arm_pd_ee_delta_pos)
Expand Down
114 changes: 0 additions & 114 deletions mani_skill2/assets/robots/panda/panda_v2.yml

This file was deleted.

2 changes: 1 addition & 1 deletion mani_skill2/utils/structs/articulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def create_pinocchio_model(self):
# NOTE (stao): This is available but not typed in SAPIEN
if physx.is_gpu_enabled():
raise NotImplementedError(
"Cannot create a pinocchio model when GPU is enabled. If you wish to do inverse kinematics you must use curobo"
"Cannot create a pinocchio model when GPU is enabled. If you wish to do inverse kinematics you must use pytorch_kinematics"
)
else:
return self._objs[0].create_pinocchio_model()
Expand Down
3 changes: 2 additions & 1 deletion manualtest/visual_all_envs_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
num_envs=num_envs,
enable_shadow=True,
render_mode="rgb_array",
control_mode="pd_ee_delta_pose",
control_mode="pd_ee_delta_pos",
sim_freq=500,
control_freq=100,
)
Expand All @@ -36,6 +36,7 @@
while i < 50:
obs, rew, terminated, truncated, info = env.step(env.action_space.sample())
done = np.logical_or(to_numpy(terminated), to_numpy(truncated))
print(rew)
if num_envs == 1:
env.render_human()
done = done.any()
Expand Down

0 comments on commit c205d4c

Please sign in to comment.