Commit: using torch rng
StoneT2000 committed Jan 24, 2024
1 parent 081e234 commit 4ffcdfe
Showing 9 changed files with 137 additions and 49 deletions.
6 changes: 6 additions & 0 deletions CONTRIBUTING.md
@@ -16,6 +16,12 @@
pip install -e . # install MS2 locally
pip install pytest coverage stable-baselines3 # add development dependencies for testing purposes
```

Then, to set up pre-commit, run

```
pre-commit install
```
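This installs git hooks that run the configured checks on each commit. To check every file at once rather than waiting for a commit, the standard pre-commit invocation is

```
pre-commit run --all-files
```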

## Testing

Testing is currently semi-automated and a work in progress. We rely on coverage.py and pytest to test ManiSkill2.
4 changes: 2 additions & 2 deletions examples/baselines/ppo/README.md
@@ -3,6 +3,6 @@
Code adapted from [CleanRL](https://github.com/vwxyzjn/cleanrl/)

```bash
-python ppo.py --num_envs=512 --gamma=0.8 --gae_lambda=0.9 --update_epochs=8 --target_kl=0.1 --num_minibatches=32 --env_id="PickCube-v1" --total_timesteps=100000000 --num_steps=100
-python cleanrl_ppo_liftcube_state_gpu.py --num_envs=2048 --gamma=0.8 --gae_lambda=0.9 --update_epochs=1 --num_minibatches=32 --env_id="PushCube-v0" --total_timesteps=100000000 --num-steps=12
+python ppo.py --num_envs=512 --update_epochs=8 --target_kl=0.1 --num_minibatches=32 --env_id="PickCube-v1" --total_timesteps=100000000 --num_steps=100
+python ppo.py --num_envs=2048 --update_epochs=1 --num_minibatches=32 --env_id="PushCube-v1" --total_timesteps=100000000 --num_steps=12
```
6 changes: 3 additions & 3 deletions examples/baselines/sac/sac.py
@@ -161,11 +161,11 @@ def get_action(self, x) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        log_prob = log_prob.sum(1, keepdim=True)
        mean = torch.tanh(mean) * self.action_scale + self.action_bias
        return action, log_prob, mean

    def get_eval_action(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        mean = self.fc_mean(x)
        return torch.tanh(mean) * self.action_scale + self.action_bias


@@ -379,4 +379,4 @@ def get_eval_action(self, x):
        writer.add_scalar("losses/alpha_loss", alpha_loss.item(), global_step)

    envs.close()
    writer.close()
3 changes: 2 additions & 1 deletion mani_skill2/envs/minimal_template.py
@@ -22,9 +22,10 @@ class CustomEnv(BaseEnv):
    Success Conditions
    ------------------
    Visualization: link to a video/gif of the task being solved
    """

    def __init__(self, *args, robot_uid="panda", robot_init_qpos_noise=0.02, **kwargs):
        self.robot_init_qpos_noise = robot_init_qpos_noise
        super().__init__(*args, robot_uid=robot_uid, **kwargs)
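For intuition, `robot_init_qpos_noise` is the scale of random perturbation a task can apply to the robot's initial joint positions each episode. A minimal sketch of how such a parameter is typically used; this is illustrative only, not ManiSkill2's exact internals, and `rest_qpos` is a hypothetical name:

```python
import torch

def sample_init_qpos(rest_qpos: torch.Tensor, noise_scale: float = 0.02) -> torch.Tensor:
    # perturb the rest configuration with small uniform noise per joint
    noise = (torch.rand_like(rest_qpos) * 2 - 1) * noise_scale
    return rest_qpos + noise
```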
85 changes: 51 additions & 34 deletions mani_skill2/envs/sapien_env.py
@@ -448,28 +448,30 @@ def reconfigure(self):
        Tasks like PegInsertionSide and TurnFaucet will call this each time as the peg
        shape changes each time and the faucet model changes each time respectively.
        """
-        self._clear()
-
-        # load everything into the scene first before initializing anything
-        self._setup_scene()
-        self._load_agent()
-        self._load_actors()
-        self._load_articulations()
-
-        self._setup_lighting()
-
-        # Cache entites and articulations
-        self._actors = self.get_actors()
-        self._articulations = self.get_articulations()
-        if sapien.physx.is_gpu_enabled():
-            self._scene._setup_gpu()
-            self.device = self.physx_system.cuda_rigid_body_data.device
-            self._scene._gpu_fetch_all()
-        # TODO (stao): unknown what happens when we do reconfigure more than once when GPU is one. figure this out
-        self.agent.initialize()
-        self._setup_sensors()  # for GPU sim, we have to setup sensors after we call setup gpu in order to enable loading mounted sensors
-        if self._viewer is not None:
-            self._setup_viewer()
+        with torch.random.fork_rng():
+            torch.manual_seed(seed=self._episode_seed)
+            self._clear()
+
+            # load everything into the scene first before initializing anything
+            self._setup_scene()
+            self._load_agent()
+            self._load_actors()
+            self._load_articulations()
+
+            self._setup_lighting()
+
+            # Cache entities and articulations
+            self._actors = self.get_actors()
+            self._articulations = self.get_articulations()
+            if sapien.physx.is_gpu_enabled():
+                self._scene._setup_gpu()
+                self.device = self.physx_system.cuda_rigid_body_data.device
+                self._scene._gpu_fetch_all()
+            # TODO (stao): unknown what happens when we reconfigure more than once while the GPU sim is on. figure this out
+            self.agent.initialize()
+            self._setup_sensors()  # for GPU sim, we have to set up sensors after we call _setup_gpu in order to enable loading mounted sensors
+            if self._viewer is not None:
+                self._setup_viewer()

    def _load_actors(self):
        """Loads all actors into the scene. Called by `self.reconfigure`"""
@@ -525,17 +527,30 @@ def _setup_lighting(self):
    # Reset
    # -------------------------------------------------------------------------- #
    def reset(self, seed=None, options=None):
"""
Reset the ManiSkill environment
Note that ManiSkill always holds two RNG states, a main RNG, and an episode RNG. The main RNG is used purely to sample an episode seed which
helps with reproducibility of episodes. The episode RNG is used by the environment/task itself to e.g. randomize object positions, randomize assets etc.
Upon environment creation via gym.make, the main RNG is set with a fixed seed of 2022.
During each reset call, if seed is None, main RNG is unchanged and an episode seed is sampled from the main RNG to create the episode RNG.
If seed is not None, main RNG is set to that seed and the episode seed is also set to that seed.
Note that when giving a specific seed via `reset(seed=...)`, we always set the main RNG based on that seed. This then deterministically changes the **sequence** of RNG
used for each episode after each call to reset with `seed=None`. By default this sequence of rng starts with the default main seed used which is 2022,
which means that when creating an environment and resetting without a seed, it will always have the same sequence of RNG for each episode.
"""
-        self._elapsed_steps = 0
        if options is None:
            options = dict()

-        # when giving a specific seed, we always set the main RNG based on that seed. This then deterministically changes the **sequence** of RNG
-        # used for each episode after each call to reset with seed=none. By default this sequence of rng starts with the default main seed used which is 2022,
-        # which means that when creating an environment and resetting without a seed, it will always have the same sequence of RNG for each episode.
        self._set_main_rng(seed)
-        self._set_episode_rng(
-            seed
-        )  # we first set the first episode seed to allow environments to use it to reconfigure the environment with a seed
+        self._elapsed_steps = 0
+        # we first set the first episode seed to allow environments to use it to reconfigure the environment with a seed
+        self._set_episode_rng(seed)

        reconfigure = options.get("reconfigure", False)
        if reconfigure:
            # Reconfigure the scene if assets change
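The docstring above fully determines the seeding behavior; a minimal sketch of what it implies in practice, assuming the usual pattern of importing `mani_skill2.envs` to register tasks and using the `PushCube-v1` id from this commit:

```python
import gymnasium as gym
import mani_skill2.envs  # noqa: F401 (importing registers the ManiSkill2 tasks)

env = gym.make("PushCube-v1")
env.reset(seed=0)   # main RNG and episode RNG are both seeded with 0
env.reset()         # episode seed is drawn from the main RNG: a deterministic sequence
env.reset(seed=0)   # re-seeding replays the exact same episode sequence from here on
```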
@@ -558,7 +573,7 @@ def reset(self, seed=None, options=None):
        return obs, {}

    def _set_main_rng(self, seed):
-        """Set the main random generator (e.g., to generate the seed for each episode)."""
+        """Set the main random generator, which is only used to seed the episode RNG, improving reproducibility."""
        if seed is None:
            if self._main_seed is not None:
                return
@@ -579,10 +594,12 @@ def initialize_episode(self):
"""Initialize the episode, e.g., poses of entities and articulations, and robot configuration.
No new assets are created. Task-relevant information can be initialized here, like goals.
"""
self._initialize_actors()
self._initialize_articulations()
self._initialize_agent()
self._initialize_task()
with torch.random.fork_rng():
torch.manual_seed(self._episode_seed)
self._initialize_actors()
self._initialize_articulations()
self._initialize_agent()
self._initialize_task()

def _initialize_actors(self):
"""Initialize the poses of actors. Called by `self.initialize_episode`"""
12 changes: 5 additions & 7 deletions mani_skill2/envs/tasks/push_cube.py
Original file line number Diff line number Diff line change
@@ -20,6 +20,7 @@

import numpy as np
import torch
+import torch.random
from transforms3d.euler import euler2quat

from mani_skill2.agents.robots.panda.panda import Panda
@@ -34,7 +35,7 @@
from mani_skill2.utils.structs.types import Array


@register_env("PushCube-v0", max_episode_steps=50)
@register_env("PushCube-v1", max_episode_steps=50)
class PushCubeEnv(BaseEnv):
    """
    Task Description
@@ -113,9 +114,7 @@ def _initialize_actors(self):

        # here we write some randomization code that randomizes the x, y position of the cube we are pushing in the range [-0.1, -0.1] to [0.1, 0.1]
        xyz = torch.zeros((self.num_envs, 3), device=self.device)
-        xyz[..., :2] = torch.from_numpy(
-            self._episode_rng.uniform(-0.1, 0.1, [self.num_envs, 2])
-        ).cuda()
+        xyz[..., :2] = torch.rand((self.num_envs, 2), device=self.device) * 0.2 - 0.1
        xyz[..., 2] = self.cube_half_size
        q = [1, 0, 0, 0]
        # we can then create a pose object using Pose.create_from_pq to then set the cube pose with. Note that even though our quaternion
@@ -128,9 +127,8 @@
        target_region_xyz = xyz + torch.tensor(
            [0.1 + self.goal_radius, 0, 0], device=self.device
        )
-        target_region_xyz[
-            ..., 2
-        ] = 1e-3  # set a little bit above 0 so the target is sitting on the table
+        # set a little bit above 0 so the target is sitting on the table
+        target_region_xyz[..., 2] = 1e-3
        self.goal_region.set_pose(
            Pose.create_from_pq(
                p=target_region_xyz,
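The new sampling line replaces a NumPy round-trip (`self._episode_rng.uniform(...)` followed by `torch.from_numpy(...).cuda()`) with sampling directly on the simulation device. The `* 0.2 - 0.1` is the usual affine rescaling of `torch.rand`'s `[0, 1)` output; sketched generically (the `uniform` helper below is illustrative, not a ManiSkill2 API):

```python
import torch

def uniform(low: float, high: float, size, device=None) -> torch.Tensor:
    # torch.rand samples from [0, 1); rescale to [low, high)
    return torch.rand(size, device=device) * (high - low) + low

xy = uniform(-0.1, 0.1, (512, 2))  # matches the cube x, y randomization above
```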
65 changes: 65 additions & 0 deletions mani_skill2/envs/tasks/stack_cube.py
@@ -0,0 +1,65 @@
from collections import OrderedDict
from typing import Any, Dict

import numpy as np
import torch

from mani_skill2.envs.sapien_env import BaseEnv
from mani_skill2.envs.utils.randomization.pose import random_quaternions
from mani_skill2.sensors.camera import CameraConfig
from mani_skill2.utils.building.actors import build_cube
from mani_skill2.utils.registration import register_env
from mani_skill2.utils.sapien_utils import look_at
from mani_skill2.utils.scene_builder.table.table_scene_builder import TableSceneBuilder
from mani_skill2.utils.structs.pose import Pose


@register_env(name="StackCube-v1", max_episode_steps=100)
class StackCubeEnv(BaseEnv):
    def __init__(self, *args, robot_uid="panda", robot_init_qpos_noise=0.02, **kwargs):
        self.robot_init_qpos_noise = robot_init_qpos_noise
        super().__init__(*args, robot_uid=robot_uid, **kwargs)

    def _register_sensors(self):
        pose = look_at(eye=[0.3, 0, 0.6], target=[-0.1, 0, 0.1])
        return [
            CameraConfig("base_camera", pose.p, pose.q, 128, 128, np.pi / 2, 0.01, 10)
        ]

    def _register_render_cameras(self):
        pose = look_at([0.6, 0.7, 0.6], [0.0, 0.0, 0.35])
        return CameraConfig("render_camera", pose.p, pose.q, 512, 512, 1, 0.01, 10)

    def _load_actors(self):
        self.table_scene = TableSceneBuilder(
            env=self, robot_init_qpos_noise=self.robot_init_qpos_noise
        )
        self.table_scene.build()
        # use torch.tensor: the legacy torch.Tensor constructor does not accept a device keyword
        self.box_half_size = torch.tensor([0.02] * 3, device=self.device)
        self.cubeA = build_cube(
            self._scene, half_size=0.02, color=[1, 0, 0, 1], name="cubeA"
        )
        self.cubeB = build_cube(
            self._scene, half_size=0.02, color=[0, 1, 0, 1], name="cubeB"
        )

    def _initialize_actors(self):
        self.table_scene.initialize()
        qs = random_quaternions(
            self._episode_rng, lock_x=True, lock_y=True, lock_z=False, n=self.num_envs
        )
        ps = [0, 0, 0.02]
        self.cubeA.set_pose(Pose.create_from_pq(p=ps, q=qs))

    def _get_obs_extra(self):
        return OrderedDict()

    def evaluate(self, obs: Any):
        return {"success": torch.zeros(self.num_envs, device=self.device, dtype=bool)}

    def compute_dense_reward(self, obs: Any, action: np.ndarray, info: Dict):
        return torch.zeros(self.num_envs, device=self.device)

    def compute_normalized_dense_reward(self, obs: Any, action: np.ndarray, info: Dict):
        max_reward = 1.0
        return self.compute_dense_reward(obs=obs, action=action, info=info) / max_reward
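In `_initialize_actors`, `random_quaternions` with `lock_x=True, lock_y=True` restricts the sampled rotations to the z axis (random yaw). A minimal sketch of that idea in plain torch, using the wxyz quaternion convention; `random_quaternions`'s exact signature is ManiSkill2's, and this helper is only illustrative:

```python
import math
import torch

def random_z_quaternions(n: int) -> torch.Tensor:
    # rotation about z by angle a -> quaternion (cos(a/2), 0, 0, sin(a/2)) in wxyz order
    half = torch.rand(n) * math.pi  # a/2 for a uniform in [0, 2*pi)
    zeros = torch.zeros(n)
    return torch.stack([torch.cos(half), zeros, zeros, torch.sin(half)], dim=-1)
```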
3 changes: 2 additions & 1 deletion mani_skill2/vector/wrappers/gymnasium.py
@@ -63,7 +63,8 @@ def step(
new_infos["final_info"] = infos
new_infos["final_observation"] = final_obs
infos = new_infos
truncations = torch.ones_like(terminations) * truncations # gym timelimit wrapper returns a bool, for consistency we convert to a tensor here
# gym timelimit wrapper returns a bool, for consistency we convert to a tensor here
truncations = torch.ones_like(terminations) * truncations
return obs, rew, terminations, truncations, infos

def close(self):
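`torch.ones_like(terminations) * truncations` broadcasts the single Python bool from the TimeLimit wrapper to one flag per parallel environment. For example:

```python
import torch

terminations = torch.zeros(4, dtype=bool)  # per-env termination flags
truncated = True                           # scalar bool from gym's TimeLimit wrapper
truncations = torch.ones_like(terminations) * truncated
print(truncations)  # tensor([True, True, True, True])
```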
2 changes: 1 addition & 1 deletion manualtest/visual_all_envs_cpu.py
@@ -8,7 +8,7 @@
if __name__ == "__main__":
    # , "StackCube-v0", "LiftCube-v0"
    num_envs = 2
-    for env_id in ["PushObject-v0"]:  # , "StackCube-v0", "LiftCube-v0"]:
+    for env_id in ["PushCube-v0"]:  # , "StackCube-v0", "LiftCube-v0"]:
        env = gym.make(
            env_id,
            num_envs=num_envs,
