
add minimal pick cube task with "task sheet"
StoneT2000 committed Jan 23, 2024
1 parent e07b630 commit 3d6a75b
Showing 12 changed files with 235 additions and 8 deletions.
46 changes: 46 additions & 0 deletions mani_skill2/envs/minimal_template.py
@@ -0,0 +1,46 @@
from collections import OrderedDict
from typing import Any, Dict

import numpy as np
import torch

from mani_skill2.envs.sapien_env import BaseEnv
from mani_skill2.sensors.camera import CameraConfig
from mani_skill2.utils.registration import register_env
from mani_skill2.utils.sapien_utils import look_at


@register_env("CustomEnv-v0", max_episode_steps=200)
class CustomEnv(BaseEnv):
def __init__(self, *args, robot_uid="panda", robot_init_qpos_noise=0.02, **kwargs):
self.robot_init_qpos_noise = robot_init_qpos_noise
super().__init__(*args, robot_uid=robot_uid, **kwargs)

def _register_sensors(self):
pose = look_at(eye=[0.3, 0, 0.6], target=[-0.1, 0, 0.1])
return [
CameraConfig("base_camera", pose.p, pose.q, 128, 128, np.pi / 2, 0.01, 10)
]

def _register_render_cameras(self):
pose = look_at([0.6, 0.7, 0.6], [0.0, 0.0, 0.35])
return CameraConfig("render_camera", pose.p, pose.q, 512, 512, 1, 0.01, 10)

def _load_actors(self):
pass

def _initialize_actors(self):
pass

def _get_obs_extra(self):
return OrderedDict()

def evaluate(self, obs: Any):
return {"success": torch.zeros(self.num_envs, device=self.device)}

def compute_dense_reward(self, obs: Any, action: np.ndarray, info: Dict):
return torch.zeros(self.num_envs, device=self.device)

def compute_normalized_dense_reward(self, obs: Any, action: np.ndarray, info: Dict):
max_reward = 1.0
return self.compute_dense_reward(obs=obs, action=action, info=info) / max_reward
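As a usage sketch (not part of this commit): once this module is imported, the @register_env call above makes the ID available to gym.make, mirroring the comment in envs/template.py. Whether gym or gymnasium is the underlying API here, and the obs_mode/control_mode values, are assumptions.

import gymnasium as gym

import mani_skill2.envs.minimal_template  # noqa: F401  (importing runs @register_env)

# "CustomEnv-v0" is the ID registered above; the kwargs are illustrative
env = gym.make("CustomEnv-v0", obs_mode="state", control_mode="pd_joint_delta_pos")
obs, info = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
env.close()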
1 change: 1 addition & 0 deletions mani_skill2/envs/tasks/README.md
@@ -13,3 +13,4 @@

env_id, task code, max_episode_steps, optimal solution_horizon, difficulty
PushCube-v0, push_cube.py, 50, ~12, 1
PickCube-v1, pick_cube.py, 100, ~20, 3?
1 change: 1 addition & 0 deletions mani_skill2/envs/tasks/__init__.py
@@ -1,3 +1,4 @@
from .fmb.fmb import FMBEnv
from .pick_cube import PickCubeEnv
from .push_cube import PushCubeEnv
from .push_object import PushObjectEnv
120 changes: 120 additions & 0 deletions mani_skill2/envs/tasks/pick_cube.py
@@ -0,0 +1,120 @@
from collections import OrderedDict
from typing import Any, Dict

import numpy as np
import torch

import mani_skill2.envs.utils.randomization as randomization
from mani_skill2.envs.sapien_env import BaseEnv
from mani_skill2.sensors.camera import CameraConfig
from mani_skill2.utils.building.actors import build_cube, build_sphere
from mani_skill2.utils.registration import register_env
from mani_skill2.utils.sapien_utils import look_at
from mani_skill2.utils.scene_builder.table.table_scene_builder import TableSceneBuilder
from mani_skill2.utils.structs.pose import Pose


@register_env("PickCube-v1", max_episode_steps=100)
class PickCubeEnv(BaseEnv):
"""
A simple task where the objective is to grasp a cube and move it to a target goal position.
Randomizations
--------------
    - the cube's xy position is randomized on top of a table in the region [-0.1, 0.1] x [-0.1, 0.1]. It is placed flat on the table
    - the cube's z-axis rotation is randomized to a random angle
    - the target goal position (marked by a green sphere) has its xy position randomized in the region [-0.1, 0.1] x [-0.1, 0.1] and its z randomized in [0, 0.3] above the cube
    Success Conditions
    ------------------
    - the cube's position is within goal_thresh (default 0.025) Euclidean distance of the goal position
    Visualization: TODO: ADD LINK HERE
    Changelog:
    Unlike v0, v1 does not require the robot to be static at the end of the episode, which makes this task similar to other benchmarks and also easier
"""

cube_half_size = 0.02
goal_thresh = 0.025

def __init__(self, *args, robot_uid="panda", robot_init_qpos_noise=0.02, **kwargs):
self.robot_init_qpos_noise = robot_init_qpos_noise
super().__init__(*args, robot_uid=robot_uid, **kwargs)

def _register_sensors(self):
pose = look_at(eye=[0.3, 0, 0.6], target=[-0.1, 0, 0.1])
return [
CameraConfig("base_camera", pose.p, pose.q, 128, 128, np.pi / 2, 0.01, 10)
]

def _register_render_cameras(self):
pose = look_at([0.6, 0.7, 0.6], [0.0, 0.0, 0.35])
return CameraConfig("render_camera", pose.p, pose.q, 512, 512, 1, 0.01, 10)

def _load_actors(self):
self.table_scene = TableSceneBuilder(
self, robot_init_qpos_noise=self.robot_init_qpos_noise
)
self.table_scene.build()
self.cube = build_cube(
self._scene, half_size=self.cube_half_size, color=[1, 0, 0, 1], name="cube"
)
self.goal_site = build_sphere(
self._scene,
radius=self.goal_thresh,
color=[0, 1, 0, 1],
name="goal_site",
body_type="kinematic",
add_collision=False,
)

def _initialize_actors(self):
self.table_scene.initialize()
xyz = np.zeros((self.num_envs, 3))
xyz[:, :2] = self._episode_rng.uniform(-0.1, 0.1, [self.num_envs, 2])
xyz[:, 2] = self.cube_half_size
qs = randomization.random_quaternions(
self._episode_rng, lock_x=True, lock_y=True, n=self.num_envs
)
self.cube.set_pose(Pose.create_from_pq(xyz, qs, device=self.device))

goal_xyz = np.zeros((self.num_envs, 3))
goal_xyz[:, :2] = self._episode_rng.uniform(-0.1, 0.1, [self.num_envs, 2])
goal_xyz[:, 2] = self._episode_rng.uniform(0, 0.3, [self.num_envs]) + xyz[:, 2]
self.goal_site.set_pose(Pose.create_from_pq(goal_xyz, device=self.device))

    def _get_obs_extra(self):
        obs = OrderedDict(
            tcp_pose=self.agent.tcp.pose.raw_pose, goal_pos=self.goal_site.pose.p
        )
        if "state" in self.obs_mode:
            obs.update(obj_pose=self.cube.pose.raw_pose)
        return obs

def evaluate(self, obs: Any):
is_obj_placed = (
torch.linalg.norm(self.goal_site.pose.p - self.cube.pose.p, axis=1)
<= self.goal_thresh
)
return {"success": is_obj_placed, "is_obj_placed": is_obj_placed}

def compute_dense_reward(self, obs: Any, action: np.ndarray, info: Dict):
tcp_to_obj_dist = torch.linalg.norm(
self.cube.pose.p - self.agent.tcp.pose.p, axis=1
)
reaching_reward = 1 - torch.tanh(5 * tcp_to_obj_dist)
reward = reaching_reward

is_grasped = self.agent.is_grasping(self.cube)
reward += is_grasped

obj_to_goal_dist = torch.linalg.norm(
self.goal_site.pose.p - self.cube.pose.p, axis=1
)
place_reward = 1 - torch.tanh(5 * obj_to_goal_dist)
reward += place_reward * is_grasped

reward[info["success"]] = 4
return reward

def compute_normalized_dense_reward(self, obs: Any, action: np.ndarray, info: Dict):
return self.compute_dense_reward(obs=obs, action=action, info=info) / 4
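For context on the division by 4 in compute_normalized_dense_reward: the shaped reward is at most 1 (reaching) + 1 (grasp) + 1 (placement) and is overridden to 4 on success, so the normalized reward stays in [0, 1]. Below is a standalone sketch of the same shaping on plain tensors; the function name and inputs are illustrative, not part of the commit.

import torch

def shaped_pick_reward(tcp_to_obj_dist, obj_to_goal_dist, is_grasped, success):
    # mirrors compute_dense_reward above: tanh-shaped reaching + grasp bonus + grasp-gated placement
    reward = 1 - torch.tanh(5 * tcp_to_obj_dist)
    reward = reward + is_grasped
    reward = reward + (1 - torch.tanh(5 * obj_to_goal_dist)) * is_grasped
    reward[success] = 4.0
    return reward / 4.0  # normalized variant, as in compute_normalized_dense_reward

r = shaped_pick_reward(
    tcp_to_obj_dist=torch.tensor([0.2, 0.0]),
    obj_to_goal_dist=torch.tensor([0.3, 0.0]),
    is_grasped=torch.tensor([0.0, 1.0]),
    success=torch.tensor([False, True]),
)
print(r)  # second entry is 1.0 (success); first is the partial reaching reward divided by 4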
22 changes: 19 additions & 3 deletions mani_skill2/envs/tasks/push_cube.py
@@ -30,12 +30,28 @@
from mani_skill2.utils.registration import register_env
from mani_skill2.utils.sapien_utils import look_at
from mani_skill2.utils.scene_builder.table.table_scene_builder import TableSceneBuilder
from mani_skill2.utils.structs.pose import Pose, vectorize_pose
from mani_skill2.utils.structs.pose import Pose
from mani_skill2.utils.structs.types import Array


@register_env("PushCube-v0", max_episode_steps=50)
class PushCubeEnv(BaseEnv):
"""
A simple task where the objective is to push and move a cube to a goal region in front of it
Randomizations
--------------
    - the cube's xy position is randomized on top of a table in the region [-0.1, 0.1] x [-0.1, 0.1]. It is placed flat on the table
    - the target goal region is marked by a red/white circular target. The position of the target is fixed to be the cube xy position + [0.1 + goal_radius, 0]
    Success Conditions
    ------------------
    - the cube's xy position is within goal_radius (default 0.1) Euclidean distance of the target's xy position
Visualization: TODO: ADD LINK HERE
"""

# Specify some supported robot types
agent: Union[Panda, Xmate3Robotiq]

@@ -124,14 +140,14 @@ def _get_obs_extra(self):
# some useful observation info for solving the task includes the pose of the tcp (tool center point) which is the point between the
# grippers of the robot
obs = OrderedDict(
tcp_pose=vectorize_pose(self.agent.tcp.pose),
tcp_pose=self.agent.tcp.pose.raw_pose,
goal_pos=self.goal_region.pose.p,
)
if self._obs_mode in ["state", "state_dict"]:
# if the observation mode is state/state_dict, we provide ground truth information about where the cube is.
# for visual observation modes one should rely on the sensed visual data to determine where the cube is
obs.update(
obj_pose=vectorize_pose(self.obj.pose),
obj_pose=self.obj.pose.raw_pose,
)
return obs

18 changes: 16 additions & 2 deletions mani_skill2/envs/template.py
@@ -28,12 +28,12 @@

from mani_skill2.envs.sapien_env import BaseEnv
from mani_skill2.sensors.camera import CameraConfig
from mani_skill2.utils.registration import register
from mani_skill2.utils.registration import register_env
from mani_skill2.utils.sapien_utils import look_at


# register the environment by a unique ID and specify a max time limit. Now once this file is imported you can do gym.make("CustomEnv-v0")
@register(name="CustomEnv-v0", max_episode_steps=200)
@register_env("CustomEnv-v0", max_episode_steps=200)
class CustomEnv(BaseEnv):
# in the __init__ function you can pick a default robot your task should use e.g. the panda robot
def __init__(self, *args, robot_uid="panda", robot_init_qpos_noise=0.02, **kwargs):
@@ -130,3 +130,17 @@ def compute_normalized_dense_reward(self, obs: Any, action: np.ndarray, info: Dict):
# this should be equal to compute_dense_reward / max possible reward
max_reward = 1.0
return self.compute_dense_reward(obs=obs, action=action, info=info) / max_reward

    def get_state(self):
        # this function is important for allowing accurate replay of trajectories. Make sure to include any
        # data that is not part of the simulation state, such as a randomly generated 3D goal position.
        # You can skip this if the environment's rewards, observations, success etc. depend only on simulation data
        # (e.g. self.your_custom_actor.pose.p will always give you your actor's 3D position)
        state = super().get_state()
        return torch.hstack([state, self.goal_pos])

    def set_state(self, state):
        # this function complements get_state and restores any non-simulation state data so that the environment
        # behaves exactly the same (rewards, observations, success etc.) if you reset to a given state and take the same actions
        self.goal_pos = state[:, -3:]
        super().set_state(state[:, :-3])
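A small sketch of the replay pattern the get_state/set_state comments describe; the helper below and the determinism assumption are illustrative, not part of the commit.

import torch

def check_state_roundtrip(env, action: torch.Tensor) -> bool:
    # restoring a saved state should reproduce the same step outputs
    state = env.get_state()            # simulation state plus the appended goal_pos columns
    _, rew1, *_ = env.step(action)
    env.set_state(state)               # restores both the simulation state and goal_pos
    _, rew2, *_ = env.step(action)     # expected to match rew1 for a deterministic env
    return bool(torch.allclose(rew1, rew2))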
2 changes: 2 additions & 0 deletions mani_skill2/envs/utils/randomization/__init__.py
@@ -0,0 +1,2 @@
from .pose import *
from .samplers import *
27 changes: 27 additions & 0 deletions mani_skill2/envs/utils/randomization/pose.py
@@ -0,0 +1,27 @@
import numpy as np
import transforms3d

from mani_skill2.utils.geometry.rotation_conversions import (
euler_angles_to_matrix,
matrix_to_quaternion,
)
from mani_skill2.utils.sapien_utils import to_tensor


def random_quaternions(
rng: np.random.RandomState,
lock_x: bool = False,
lock_y: bool = False,
lock_z: bool = False,
n=1,
):
xyz_angles = rng.uniform(0, np.pi * 2, (n, 3))
if lock_x:
xyz_angles[:, 0] *= 0
if lock_y:
xyz_angles[:, 1] *= 0
if lock_z:
xyz_angles[:, 2] *= 0
return matrix_to_quaternion(
euler_angles_to_matrix(to_tensor(xyz_angles), convention="XYZ")
)
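A quick usage sketch of the helper above, mirroring how pick_cube.py calls it; the batch size, seed, and expected output shape noted in the comments are assumptions.

import numpy as np

from mani_skill2.envs.utils.randomization import random_quaternions

rng = np.random.RandomState(0)
# yaw-only rotations for 16 envs: the x and y Euler angles are locked to zero
quats = random_quaternions(rng, lock_x=True, lock_y=True, n=16)
print(quats.shape)  # expected (16, 4): wxyz quaternions as a torch tensor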
Empty file.
2 changes: 2 additions & 0 deletions mani_skill2/utils/sapien_utils.py
@@ -76,6 +76,8 @@ def _unbatch(array: Union[Array, Sequence]):
if isinstance(array, torch.Tensor):
return array.squeeze(0)
if isinstance(array, np.ndarray):
if array.shape == (1,):
return array.item()
if np.iterable(array) and array.shape[0] == 1:
return array.squeeze(0)
if isinstance(array, list):
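The new branch above makes single-element numpy arrays unbatch to Python scalars rather than 0-d arrays; a quick illustration (calling the private helper directly, only for demonstration):

import numpy as np

from mani_skill2.utils.sapien_utils import _unbatch

print(_unbatch(np.array([3.0])))         # 3.0, a Python float, via the new shape == (1,) branch
print(_unbatch(np.zeros((1, 7))).shape)  # (7,), the existing squeeze(0) behavior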
2 changes: 0 additions & 2 deletions mani_skill2/utils/structs/pose.py
@@ -5,9 +5,7 @@
import sapien
import sapien.physx as physx
import torch
from transforms3d.quaternions import qmult, quat2mat

from mani_skill2.utils.geometry.geometry import rotate_vector
from mani_skill2.utils.geometry.rotation_conversions import (
quaternion_apply,
quaternion_multiply,
2 changes: 1 addition & 1 deletion tests/utils.py
@@ -6,7 +6,7 @@
# TODO (stao): reactivate old tasks once fixed
ENV_IDS = [
"LiftCube-v0",
"PickCube-v0",
"PickCube-v1",
"StackCube-v0",
"PickSingleYCB-v0",
# "PickClutterYCB-v0",
