Commit 1473fbb, parent 605a1db. Showing 6 changed files with 470 additions and 2 deletions.
@@ -0,0 +1 @@
runs
@@ -0,0 +1,6 @@
# Soft Actor Critic

```bash
python sac.py --env-id="PickCube-v0" --num_envs=512 --total_timesteps=100000000
python cleanrl_ppo_liftcube_state_gpu.py --num_envs=512 --gamma=0.8 --gae_lambda=0.9 --update_epochs=8 --target_kl=0.1 --num_minibatches=16 --env_id="PickCube-v0" --total_timesteps=100000000 --num_steps=100
```
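`sac.py` writes its training and evaluation metrics to TensorBoard under `runs/<run_name>` (via the `SummaryWriter` it creates). Assuming TensorBoard is installed in the same environment, the curves can be inspected with:

```bash
# point TensorBoard at the log directory created by sac.py
tensorboard --logdir runs
```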
@@ -0,0 +1,382 @@
""" | ||
SAC Code adapted from https://github.com/vwxyzjn/cleanrl/ | ||
""" | ||
import os | ||
import random | ||
import time | ||
from dataclasses import dataclass | ||
from typing import Tuple | ||
|
||
import gymnasium as gym | ||
import numpy as np | ||
import sapien | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import torch.optim as optim | ||
import tyro | ||
from stable_baselines3.common.buffers import ReplayBuffer | ||
from torch.utils.tensorboard import SummaryWriter | ||
from mani_skill2.utils.sapien_utils import to_numpy | ||
from mani_skill2.utils.wrappers.record import RecordEpisode | ||
|
||
from mani_skill2.vector.wrappers.gymnasium import ManiSkillVectorEnv | ||
|
||
|
||
@dataclass | ||
class Args: | ||
exp_name: str = os.path.basename(__file__)[: -len(".py")] | ||
"""the name of this experiment""" | ||
seed: int = 1 | ||
"""seed of the experiment""" | ||
torch_deterministic: bool = True | ||
"""if toggled, `torch.backends.cudnn.deterministic=False`""" | ||
    cuda: bool = True
    """if toggled, cuda will be enabled by default"""
    track: bool = False
    """if toggled, this experiment will be tracked with Weights and Biases"""
    wandb_project_name: str = "cleanRL"
    """the wandb project name"""
    wandb_entity: str = None
    """the entity (team) of wandb's project"""
    capture_video: bool = True
    """whether to capture videos of the agent's performance (check out the `videos` folder)"""
    replay_buffer_on_gpu: bool = False  # TODO (stao): implement

    # Algorithm specific arguments
    env_id: str = "PickCube-v0"
    """the environment id of the task"""
    total_timesteps: int = 100_000_000
    """total timesteps of the experiments"""
    num_envs: int = 256
    """the number of parallel game environments"""
    num_eval_envs: int = 8
    """the number of evaluation environments"""
    buffer_size: int = 1_000_000
    """the replay memory buffer size"""
    gamma: float = 0.8
    """the discount factor gamma"""
    tau: float = 0.005
    """target smoothing coefficient (default: 0.005)"""
    batch_size: int = 4096
    """the batch size of samples drawn from the replay memory"""
    learning_starts: int = 25600
    """timestep to start learning"""
    policy_lr: float = 3e-4
    """the learning rate of the policy network optimizer"""
    q_lr: float = 3e-4
    """the learning rate of the Q network optimizer"""
    policy_frequency: int = 1
    """the frequency of training policy (delayed)"""
    target_network_frequency: int = 1  # Denis Yarats' implementation delays this by 2.
    """the frequency of updates for the target networks"""
    alpha: float = 1.0
    """Entropy regularization coefficient."""
    autotune: bool = True
    """automatic tuning of the entropy coefficient"""
    grad_updates_per_step: int = 16
    """number of critic gradient updates per parallel step through all environments"""
    steps_per_env: int = 1
    """number of steps each parallel env takes before performing gradient updates"""
    eval_freq: int = 1000
    """evaluation frequency in terms of iterations"""
    num_iterations: int = 0
    """the number of iterations (computed in runtime)"""


def make_env(env_id, seed, idx, capture_video, run_name):
    def thunk():
        if capture_video and idx == 0:
            env = gym.make(env_id, render_mode="rgb_array")
            env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        else:
            env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env.action_space.seed(seed)
        return env

    return thunk


# ALGO LOGIC: initialize agent here:
class SoftQNetwork(nn.Module):
    def __init__(self, env):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(np.array(env.single_observation_space.shape).prod() + np.prod(env.single_action_space.shape), 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )

    def forward(self, x, a):
        x = torch.cat([x, a], 1)
        x = self.mlp(x)
        return x


LOG_STD_MAX = 2
LOG_STD_MIN = -5


class Actor(nn.Module):
    def __init__(self, env):
        super().__init__()
        self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod(), 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 256)  # note: defined but not used in forward/get_eval_action
        self.fc_mean = nn.Linear(256, np.prod(env.single_action_space.shape))
        self.fc_logstd = nn.Linear(256, np.prod(env.single_action_space.shape))
        # action rescaling
        self.register_buffer(
            "action_scale", torch.tensor((env.single_action_space.high - env.single_action_space.low) / 2.0, dtype=torch.float32)
        )
        self.register_buffer(
            "action_bias", torch.tensor((env.single_action_space.high + env.single_action_space.low) / 2.0, dtype=torch.float32)
        )

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        mean = self.fc_mean(x)
        log_std = self.fc_logstd(x)
        log_std = torch.tanh(log_std)
        log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (log_std + 1)  # From SpinUp / Denis Yarats

        return mean, log_std

    def get_action(self, x) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        mean, log_std = self(x)
        std = log_std.exp()
        normal = torch.distributions.Normal(mean, std)
        x_t = normal.rsample()  # for reparameterization trick (mean + std * N(0,1))
        y_t = torch.tanh(x_t)
        action = y_t * self.action_scale + self.action_bias
        log_prob = normal.log_prob(x_t)
        # Enforcing Action Bound
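        # (change of variables for a = tanh(u) * action_scale + action_bias:
        #  log pi(a|s) = log N(u; mean, std) - sum_i log(action_scale_i * (1 - tanh(u_i)^2)); the 1e-6 avoids log(0))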
        log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + 1e-6)
        log_prob = log_prob.sum(1, keepdim=True)
        mean = torch.tanh(mean) * self.action_scale + self.action_bias
        return action, log_prob, mean

    def get_eval_action(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        mean = self.fc_mean(x)
        return torch.tanh(mean) * self.action_scale + self.action_bias


if __name__ == "__main__":
    import stable_baselines3 as sb3

    if sb3.__version__ < "2.0":
        raise ValueError(
            """Ongoing migration: run the following command to install the new dependencies:
poetry run pip install "stable_baselines3==2.0.0a1"
"""
        )

    args = tyro.cli(Args)
    args.num_iterations = args.total_timesteps // args.num_envs
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # env setup
    # envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
    sapien.physx.set_gpu_memory_config(found_lost_pairs_capacity=2**26, max_rigid_patch_count=200000)
    env_kwargs = dict(obs_mode="state", control_mode="pd_joint_delta_pos", render_mode="rgb_array", sim_freq=100, control_freq=20)
    envs = ManiSkillVectorEnv(args.env_id, args.num_envs, env_kwargs)
    eval_envs = gym.make(args.env_id, num_envs=args.num_eval_envs, **env_kwargs)
    if args.capture_video:
        eval_envs = RecordEpisode(eval_envs, output_dir=f"runs/{run_name}/videos", save_trajectory=False, video_fps=30)
    eval_envs = ManiSkillVectorEnv(eval_envs, args.num_eval_envs, env_kwargs)

    assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"

    max_action = float(envs.single_action_space.high[0])

    actor = Actor(envs).to(device)
    qf1 = SoftQNetwork(envs).to(device)
    qf2 = SoftQNetwork(envs).to(device)
    qf1_target = SoftQNetwork(envs).to(device)
    qf2_target = SoftQNetwork(envs).to(device)
    qf1_target.load_state_dict(qf1.state_dict())
    qf2_target.load_state_dict(qf2.state_dict())
    q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=args.q_lr)
    actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.policy_lr)

    # Automatic entropy tuning
    if args.autotune:
        target_entropy = -torch.prod(torch.Tensor(envs.single_action_space.shape).to(device)).item()
        log_alpha = torch.zeros(1, requires_grad=True, device=device)
        alpha = log_alpha.exp().item()
        a_optimizer = optim.Adam([log_alpha], lr=args.q_lr)
    else:
        alpha = args.alpha

    envs.single_observation_space.dtype = np.float32
    rb = ReplayBuffer(
        args.buffer_size,
        envs.single_observation_space,
        envs.single_action_space,
        device,
        n_envs=args.num_envs,
        handle_timeout_termination=False,
    )
    start_time = time.time()
    global_step = 0

    # TRY NOT TO MODIFY: start the game
    obs, _ = envs.reset(seed=args.seed)
    eval_obs, _ = eval_envs.reset(seed=args.seed)
    for iteration in range(1, args.num_iterations + 1):
        # Evaluation Code
        if iteration % args.eval_freq == 1:
            print("Evaluating")
            eval_done = False
            while not eval_done:
                with torch.no_grad():
                    eval_obs, _, eval_terminations, eval_truncations, eval_infos = eval_envs.step(actor.get_eval_action(eval_obs))
                    if eval_truncations.any():
                        eval_done = True
            info = eval_infos["final_info"]
            episodic_return = info['episode']['r'].mean().cpu().numpy()
            print(f"eval_episodic_return={episodic_return}")
            writer.add_scalar("charts/eval_success_rate", info["success"].float().mean().cpu().numpy(), global_step)
            writer.add_scalar("charts/eval_episodic_return", episodic_return, global_step)
writer.add_scalar("charts/eval_episodic_length", info["elapsed_steps"], global_step) | ||
|
||
for _ in range(args.steps_per_env): | ||
# ALGO LOGIC: put action logic here | ||
if global_step < args.learning_starts: | ||
actions = torch.from_numpy(envs.action_space.sample()).to(device) | ||
else: | ||
actions, _, _ = actor.get_action(torch.Tensor(obs).to(device)) | ||
actions = actions.detach() | ||
|
||
# TRY NOT TO MODIFY: execute the game and log data. | ||
next_obs, rewards, terminations, truncations, infos = envs.step(actions) | ||
global_step += envs.num_envs | ||
dones = torch.logical_or(terminations, truncations) | ||
assert dones.any() == dones.all() | ||
|
||
# TRY NOT TO MODIFY: record rewards for plotting purposes | ||
if "final_info" in infos: # this means all parallel envs truncated | ||
info = infos["final_info"] | ||
episodic_return = info['episode']['r'].mean().cpu().numpy() | ||
print(f"global_step={global_step}, episodic_return={episodic_return}") | ||
writer.add_scalar("charts/success_rate", info["success"].float().mean().cpu().numpy(), global_step) | ||
writer.add_scalar("charts/episodic_return", episodic_return, global_step) | ||
writer.add_scalar("charts/episodic_length", info["elapsed_steps"], global_step) | ||
|
||
# TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation` | ||
real_next_obs = next_obs.clone() | ||
if dones.any(): | ||
real_next_obs = infos["final_observation"].clone() | ||
if not args.replay_buffer_on_gpu: | ||
real_next_obs = real_next_obs.cpu().numpy() | ||
obs = obs.cpu().numpy() | ||
actions = actions.cpu().numpy() | ||
rewards = rewards.cpu().numpy() | ||
dones = dones.cpu().numpy() | ||
infos = to_numpy(infos) | ||
rb.add(obs, real_next_obs, actions, rewards, dones, infos) | ||
|
||
# TRY NOT TO MODIFY: CRUCIAL step easy to overlook | ||
obs = next_obs | ||
|
||
# ALGO LOGIC: training. | ||
if global_step > args.learning_starts: | ||
for _ in range(args.grad_updates_per_step): | ||
data = rb.sample(args.batch_size) | ||
with torch.no_grad(): | ||
next_state_actions, next_state_log_pi, _ = actor.get_action(data.next_observations) | ||
qf1_next_target = qf1_target(data.next_observations, next_state_actions) | ||
qf2_next_target = qf2_target(data.next_observations, next_state_actions) | ||
min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - alpha * next_state_log_pi | ||
                    next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (min_qf_next_target).view(-1)

                qf1_a_values = qf1(data.observations, data.actions).view(-1)
                qf2_a_values = qf2(data.observations, data.actions).view(-1)
                qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
                qf2_loss = F.mse_loss(qf2_a_values, next_q_value)
                qf_loss = qf1_loss + qf2_loss

                # optimize the model
                q_optimizer.zero_grad()
                qf_loss.backward()
                q_optimizer.step()

            if iteration % args.policy_frequency == 0:  # TD3 delayed update support
                for _ in range(
                    args.policy_frequency
                ):  # compensate for the delay by doing 'actor_update_interval' instead of 1
                    pi, log_pi, _ = actor.get_action(data.observations)
                    qf1_pi = qf1(data.observations, pi)
                    qf2_pi = qf2(data.observations, pi)
                    min_qf_pi = torch.min(qf1_pi, qf2_pi)
                    actor_loss = ((alpha * log_pi) - min_qf_pi).mean()

                    actor_optimizer.zero_grad()
                    actor_loss.backward()
                    actor_optimizer.step()

                    if args.autotune:
                        with torch.no_grad():
                            _, log_pi, _ = actor.get_action(data.observations)
                        alpha_loss = (-log_alpha.exp() * (log_pi + target_entropy)).mean()

                        a_optimizer.zero_grad()
                        alpha_loss.backward()
                        a_optimizer.step()
                        alpha = log_alpha.exp().item()

            # update the target networks
            if iteration % args.target_network_frequency == 0:
                for param, target_param in zip(qf1.parameters(), qf1_target.parameters()):
                    target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
                for param, target_param in zip(qf2.parameters(), qf2_target.parameters()):
                    target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)

            if iteration % 10 == 0:
                writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step)
                writer.add_scalar("losses/qf2_values", qf2_a_values.mean().item(), global_step)
                writer.add_scalar("losses/qf1_loss", qf1_loss.item(), global_step)
                writer.add_scalar("losses/qf2_loss", qf2_loss.item(), global_step)
                writer.add_scalar("losses/qf_loss", qf_loss.item() / 2.0, global_step)
                writer.add_scalar("losses/actor_loss", actor_loss.item(), global_step)
                writer.add_scalar("losses/alpha", alpha, global_step)
                print("SPS:", int(global_step / (time.time() - start_time)))
                writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
                if args.autotune:
                    writer.add_scalar("losses/alpha_loss", alpha_loss.item(), global_step)

    envs.close()
    writer.close()
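For reference, the critic update above regresses both Q-networks onto the standard SAC target computed under `torch.no_grad()`; this is a restatement of what the code already does, not an extra step:

$$
y = r + \gamma (1 - d)\left(\min\big(Q_{\bar\theta_1}(s', a'),\, Q_{\bar\theta_2}(s', a')\big) - \alpha \log \pi_\phi(a' \mid s')\right), \qquad a' \sim \pi_\phi(\cdot \mid s'),
$$

with the target parameters updated by Polyak averaging, $\bar\theta_i \leftarrow \tau \theta_i + (1 - \tau)\bar\theta_i$, and $\alpha$ either held fixed or tuned toward the target entropy $-\dim(\mathcal{A})$ when `autotune` is enabled.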