Commit

sac impl
StoneT2000 committed Jan 22, 2024
1 parent 605a1db commit 1473fbb
Showing 6 changed files with 470 additions and 2 deletions.
1 change: 1 addition & 0 deletions examples/baselines/sac/.gitignore
@@ -0,0 +1 @@
runs
6 changes: 6 additions & 0 deletions examples/baselines/sac/README.md
@@ -0,0 +1,6 @@
# Soft Actor Critic

```bash
python sac.py --env-id="PickCube-v0" --num_envs=512 --total_timesteps=100000000
```
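Training logs (episodic return, success rate, and losses) are written to TensorBoard under `runs/<run_name>`. Assuming TensorBoard is installed, a run can be monitored with:

```bash
tensorboard --logdir runs
```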
382 changes: 382 additions & 0 deletions examples/baselines/sac/sac.py
@@ -0,0 +1,382 @@
"""
SAC Code adapted from https://github.com/vwxyzjn/cleanrl/
"""
import os
import random
import time
from dataclasses import dataclass
from typing import Tuple

import gymnasium as gym
import numpy as np
import sapien
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tyro
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter
from mani_skill2.utils.sapien_utils import to_numpy
from mani_skill2.utils.wrappers.record import RecordEpisode

from mani_skill2.vector.wrappers.gymnasium import ManiSkillVectorEnv


@dataclass
class Args:
exp_name: str = os.path.basename(__file__)[: -len(".py")]
"""the name of this experiment"""
seed: int = 1
"""seed of the experiment"""
torch_deterministic: bool = True
"""if toggled, `torch.backends.cudnn.deterministic=False`"""
cuda: bool = True
"""if toggled, cuda will be enabled by default"""
track: bool = False
"""if toggled, this experiment will be tracked with Weights and Biases"""
wandb_project_name: str = "cleanRL"
"""the wandb's project name"""
wandb_entity: str = None
"""the entity (team) of wandb's project"""
capture_video: bool = True
"""whether to capture videos of the agent performances (check out `videos` folder)"""
replay_buffer_on_gpu: bool = False  # TODO (stao): implement
"""whether to keep replay buffer data on the GPU instead of copying it to CPU numpy arrays (not yet implemented)"""

# Algorithm specific arguments
env_id: str = "PickCube-v0"
"""the environment id of the task"""
total_timesteps: int = 100_000_000
"""total timesteps of the experiments"""
num_envs: int = 256
"""the number of parallel game environments"""
num_eval_envs: int = 8
"""the number of evaluation environments"""
buffer_size: int = 1_000_000
"""the replay memory buffer size"""
gamma: float = 0.8
"""the discount factor gamma"""
tau: float = 0.005
"""target smoothing coefficient (default: 0.005)"""
batch_size: int = 4096
"""the batch size of sample from the replay memory"""
learning_starts: int = 25600
"""timestep to start learning"""
policy_lr: float = 3e-4
"""the learning rate of the policy network optimizer"""
q_lr: float = 3e-4
"""the learning rate of the Q network network optimizer"""
policy_frequency: int = 1
"""the frequency of training policy (delayed)"""
target_network_frequency: int = 1 # Denis Yarats' implementation delays this by 2.
"""the frequency of updates for the target nerworks"""
alpha: float = 1.0
"""Entropy regularization coefficient."""
autotune: bool = True
"""automatic tuning of the entropy coefficient"""
grad_updates_per_step: int = 16
"""number of critic gradient updates per parallel step through all environments"""
steps_per_env: int = 1
"""number of steps each parallel env takes before performing gradient updates"""
eval_freq: int = 1000
"""evaluation frequency in terms of iterations"""
num_iterations: int = 0
"""the number of iterations (computed in runtime)"""


def make_env(env_id, seed, idx, capture_video, run_name):
def thunk():
if capture_video and idx == 0:
env = gym.make(env_id, render_mode="rgb_array")
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
else:
env = gym.make(env_id)
env = gym.wrappers.RecordEpisodeStatistics(env)
env.action_space.seed(seed)
return env

return thunk


# ALGO LOGIC: initialize agent here:
class SoftQNetwork(nn.Module):
def __init__(self, env):
super().__init__()
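# Q(s, a): an MLP over the concatenated state-action vector, with a LayerNorm inserted before the output layer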
self.mlp = nn.Sequential(*[
nn.Linear(np.array(env.single_observation_space.shape).prod() + np.prod(env.single_action_space.shape), 256),
nn.ReLU(),
nn.Linear(256, 256),
nn.ReLU(),
nn.Linear(256, 256),
nn.LayerNorm(256),
nn.ReLU(),
nn.Linear(256, 1)
])
def forward(self, x, a):
x = torch.cat([x, a], 1)
x = self.mlp(x)
return x


LOG_STD_MAX = 2
LOG_STD_MIN = -5


class Actor(nn.Module):
def __init__(self, env):
super().__init__()
self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod(), 256)
self.fc2 = nn.Linear(256, 256)
self.fc3 = nn.Linear(256, 256)
self.fc_mean = nn.Linear(256, np.prod(env.single_action_space.shape))
self.fc_logstd = nn.Linear(256, np.prod(env.single_action_space.shape))
# action rescaling
self.register_buffer(
"action_scale", torch.tensor((env.single_action_space.high - env.single_action_space.low) / 2.0, dtype=torch.float32)
)
self.register_buffer(
"action_bias", torch.tensor((env.single_action_space.high + env.single_action_space.low) / 2.0, dtype=torch.float32)
)

def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
mean = self.fc_mean(x)
log_std = self.fc_logstd(x)
log_std = torch.tanh(log_std)
log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (log_std + 1) # From SpinUp / Denis Yarats

return mean, log_std

def get_action(self, x) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
mean, log_std = self(x)
std = log_std.exp()
normal = torch.distributions.Normal(mean, std)
x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1))
y_t = torch.tanh(x_t)
action = y_t * self.action_scale + self.action_bias
log_prob = normal.log_prob(x_t)
# Enforcing Action Bound
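# change of variables for the tanh squashing: subtract log|da/dx| = log(action_scale * (1 - tanh(x)^2)); the 1e-6 avoids log(0)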
log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + 1e-6)
log_prob = log_prob.sum(1, keepdim=True)
mean = torch.tanh(mean) * self.action_scale + self.action_bias
return action, log_prob, mean

def get_eval_action(self, x):
# deterministic evaluation action: the tanh-squashed mean of the policy, rescaled to the action bounds
mean, _ = self(x)
return torch.tanh(mean) * self.action_scale + self.action_bias


if __name__ == "__main__":
import stable_baselines3 as sb3

if sb3.__version__ < "2.0":
raise ValueError(
"""Ongoing migration: run the following command to install the new dependencies:
poetry run pip install "stable_baselines3==2.0.0a1"
"""
)

args = tyro.cli(Args)
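# each iteration collects num_envs * steps_per_env transitions before updating; this count assumes the default steps_per_env = 1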
args.num_iterations = args.total_timesteps // args.num_envs
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
if args.track:
import wandb

wandb.init(
project=args.wandb_project_name,
entity=args.wandb_entity,
sync_tensorboard=True,
config=vars(args),
name=run_name,
monitor_gym=True,
save_code=True,
)
writer = SummaryWriter(f"runs/{run_name}")
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)

# TRY NOT TO MODIFY: seeding
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic

device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

# env setup
# envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
sapien.physx.set_gpu_memory_config(found_lost_pairs_capacity=2**26, max_rigid_patch_count=200000)
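# state observations with delta joint-position control; simulation at 100 Hz and control at 20 Hz, i.e. 5 physics steps per action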
env_kwargs = dict(obs_mode="state", control_mode="pd_joint_delta_pos", render_mode="rgb_array", sim_freq=100, control_freq=20)
envs = ManiSkillVectorEnv(args.env_id, args.num_envs, env_kwargs)
eval_envs = gym.make(args.env_id, num_envs=args.num_eval_envs, **env_kwargs)
if args.capture_video:
eval_envs = RecordEpisode(eval_envs, output_dir=f"runs/{run_name}/videos", save_trajectory=False, video_fps=30)
eval_envs = ManiSkillVectorEnv(eval_envs, args.num_eval_envs, env_kwargs)

assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"

max_action = float(envs.single_action_space.high[0])

actor = Actor(envs).to(device)
qf1 = SoftQNetwork(envs).to(device)
qf2 = SoftQNetwork(envs).to(device)
qf1_target = SoftQNetwork(envs).to(device)
qf2_target = SoftQNetwork(envs).to(device)
qf1_target.load_state_dict(qf1.state_dict())
qf2_target.load_state_dict(qf2.state_dict())
q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=args.q_lr)
actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.policy_lr)

# Automatic entropy tuning
if args.autotune:
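# standard SAC heuristic: target entropy = -dim(action space)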
target_entropy = -torch.prod(torch.Tensor(envs.single_action_space.shape).to(device)).item()
log_alpha = torch.zeros(1, requires_grad=True, device=device)
alpha = log_alpha.exp().item()
a_optimizer = optim.Adam([log_alpha], lr=args.q_lr)
else:
alpha = args.alpha

envs.single_observation_space.dtype = np.float32
rb = ReplayBuffer(
args.buffer_size,
envs.single_observation_space,
envs.single_action_space,
device,
n_envs=args.num_envs,
handle_timeout_termination=False,
)
start_time = time.time()
global_step = 0

# TRY NOT TO MODIFY: start the game
obs, _ = envs.reset(seed=args.seed)
eval_obs, _ = eval_envs.reset(seed=args.seed)
for iteration in range(1, args.num_iterations + 1):
# Evaluation Code
if iteration % args.eval_freq == 1:
print("Evaluating")
eval_done = False
while not eval_done:
with torch.no_grad():
eval_obs, _, eval_terminations, eval_truncations, eval_infos = eval_envs.step(actor.get_eval_action(eval_obs))
if eval_truncations.any():
eval_done = True
info = eval_infos["final_info"]
episodic_return = info['episode']['r'].mean().cpu().numpy()
print(f"eval_episodic_return={episodic_return}")
writer.add_scalar("charts/eval_success_rate", info["success"].float().mean().cpu().numpy(), global_step)
writer.add_scalar("charts/eval_episodic_return", episodic_return, global_step)
writer.add_scalar("charts/eval_episodic_length", info["elapsed_steps"], global_step)

for _ in range(args.steps_per_env):
# ALGO LOGIC: put action logic here
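# warm-up: sample uniformly random actions until learning_starts environment steps have been collected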
if global_step < args.learning_starts:
actions = torch.from_numpy(envs.action_space.sample()).to(device)
else:
actions, _, _ = actor.get_action(torch.Tensor(obs).to(device))
actions = actions.detach()

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, terminations, truncations, infos = envs.step(actions)
global_step += envs.num_envs
dones = torch.logical_or(terminations, truncations)
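# all parallel envs are expected to finish together (fixed-horizon episodes), hence the assert below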
assert dones.any() == dones.all()

# TRY NOT TO MODIFY: record rewards for plotting purposes
if "final_info" in infos: # this means all parallel envs truncated
info = infos["final_info"]
episodic_return = info['episode']['r'].mean().cpu().numpy()
print(f"global_step={global_step}, episodic_return={episodic_return}")
writer.add_scalar("charts/success_rate", info["success"].float().mean().cpu().numpy(), global_step)
writer.add_scalar("charts/episodic_return", episodic_return, global_step)
writer.add_scalar("charts/episodic_length", info["elapsed_steps"], global_step)

# TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
real_next_obs = next_obs.clone()
if dones.any():
real_next_obs = infos["final_observation"].clone()
if not args.replay_buffer_on_gpu:
real_next_obs = real_next_obs.cpu().numpy()
obs = obs.cpu().numpy()
actions = actions.cpu().numpy()
rewards = rewards.cpu().numpy()
dones = dones.cpu().numpy()
infos = to_numpy(infos)
rb.add(obs, real_next_obs, actions, rewards, dones, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs

# ALGO LOGIC: training.
if global_step > args.learning_starts:
for _ in range(args.grad_updates_per_step):
data = rb.sample(args.batch_size)
with torch.no_grad():
next_state_actions, next_state_log_pi, _ = actor.get_action(data.next_observations)
qf1_next_target = qf1_target(data.next_observations, next_state_actions)
qf2_next_target = qf2_target(data.next_observations, next_state_actions)
min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - alpha * next_state_log_pi
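# SAC target: r + gamma * (1 - done) * (min_i Q_target_i(s', a') - alpha * log pi(a'|s'))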
next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (min_qf_next_target).view(-1)

qf1_a_values = qf1(data.observations, data.actions).view(-1)
qf2_a_values = qf2(data.observations, data.actions).view(-1)
qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
qf2_loss = F.mse_loss(qf2_a_values, next_q_value)
qf_loss = qf1_loss + qf2_loss

# optimize the model
q_optimizer.zero_grad()
qf_loss.backward()
q_optimizer.step()

if iteration % args.policy_frequency == 0:  # TD3 delayed update support
for _ in range(
args.policy_frequency
): # compensate for the delay by doing 'actor_update_interval' instead of 1
pi, log_pi, _ = actor.get_action(data.observations)
qf1_pi = qf1(data.observations, pi)
qf2_pi = qf2(data.observations, pi)
min_qf_pi = torch.min(qf1_pi, qf2_pi)
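# actor objective: maximize min_i Q_i(s, pi(s)) - alpha * log pi, implemented as minimizing its negation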
actor_loss = ((alpha * log_pi) - min_qf_pi).mean()

actor_optimizer.zero_grad()
actor_loss.backward()
actor_optimizer.step()

if args.autotune:
with torch.no_grad():
_, log_pi, _ = actor.get_action(data.observations)
alpha_loss = (-log_alpha.exp() * (log_pi + target_entropy)).mean()

a_optimizer.zero_grad()
alpha_loss.backward()
a_optimizer.step()
alpha = log_alpha.exp().item()

# update the target networks
if iteration % args.target_network_frequency == 0:
for param, target_param in zip(qf1.parameters(), qf1_target.parameters()):
target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
for param, target_param in zip(qf2.parameters(), qf2_target.parameters()):
target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)

if iteration % 10 == 0:
writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step)
writer.add_scalar("losses/qf2_values", qf2_a_values.mean().item(), global_step)
writer.add_scalar("losses/qf1_loss", qf1_loss.item(), global_step)
writer.add_scalar("losses/qf2_loss", qf2_loss.item(), global_step)
writer.add_scalar("losses/qf_loss", qf_loss.item() / 2.0, global_step)
writer.add_scalar("losses/actor_loss", actor_loss.item(), global_step)
writer.add_scalar("losses/alpha", alpha, global_step)
print("SPS:", int(global_step / (time.time() - start_time)))
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
if args.autotune:
writer.add_scalar("losses/alpha_loss", alpha_loss.item(), global_step)

envs.close()
writer.close()