From 1473fbbdf01ac5f4903378553253742c255a2258 Mon Sep 17 00:00:00 2001 From: StoneT2000 Date: Sun, 21 Jan 2024 17:26:40 -0800 Subject: [PATCH] sac impl --- examples/baselines/sac/.gitignore | 1 + examples/baselines/sac/README.md | 6 + examples/baselines/sac/sac.py | 382 ++++++++++++++++++ .../cleanrl_ppo_liftcube_state_gpu.py | 2 +- mani_skill2/envs/sapien_env.py | 2 +- mani_skill2/vector/wrappers/gymnasium.py | 79 ++++ 6 files changed, 470 insertions(+), 2 deletions(-) create mode 100644 examples/baselines/sac/.gitignore create mode 100644 examples/baselines/sac/README.md create mode 100644 examples/baselines/sac/sac.py create mode 100644 mani_skill2/vector/wrappers/gymnasium.py diff --git a/examples/baselines/sac/.gitignore b/examples/baselines/sac/.gitignore new file mode 100644 index 000000000..cb1d07bf8 --- /dev/null +++ b/examples/baselines/sac/.gitignore @@ -0,0 +1 @@ +runs \ No newline at end of file diff --git a/examples/baselines/sac/README.md b/examples/baselines/sac/README.md new file mode 100644 index 000000000..a97b84f9f --- /dev/null +++ b/examples/baselines/sac/README.md @@ -0,0 +1,6 @@ +# Soft Actor Critic + +```bash +python sac.py --env-id="PickCube-v0" --num_envs=512 --total_timesteps=100000000 + python cleanrl_ppo_liftcube_state_gpu.py --num_envs=512 --gamma=0.8 --gae_lambda=0.9 --update_epochs=8 --target_kl=0.1 --num_minibatches=16 --env_id="PickCube-v0" --total_timesteps=100000000 --num_steps=100 +``` \ No newline at end of file diff --git a/examples/baselines/sac/sac.py b/examples/baselines/sac/sac.py new file mode 100644 index 000000000..06eea7814 --- /dev/null +++ b/examples/baselines/sac/sac.py @@ -0,0 +1,382 @@ +""" +SAC Code adapted from https://github.com/vwxyzjn/cleanrl/ +""" +import os +import random +import time +from dataclasses import dataclass +from typing import Tuple + +import gymnasium as gym +import numpy as np +import sapien +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import tyro +from stable_baselines3.common.buffers import ReplayBuffer +from torch.utils.tensorboard import SummaryWriter +from mani_skill2.utils.sapien_utils import to_numpy +from mani_skill2.utils.wrappers.record import RecordEpisode + +from mani_skill2.vector.wrappers.gymnasium import ManiSkillVectorEnv + + +@dataclass +class Args: + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + torch_deterministic: bool = True + """if toggled, `torch.backends.cudnn.deterministic=False`""" + cuda: bool = True + """if toggled, cuda will be enabled by default""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "cleanRL" + """the wandb's project name""" + wandb_entity: str = None + """the entity (team) of wandb's project""" + capture_video: bool = True + """whether to capture videos of the agent performances (check out `videos` folder)""" + replay_buffer_on_gpu: bool = False # TODO (stao): implement + + # Algorithm specific arguments + env_id: str = "PickCube-v0" + """the environment id of the task""" + total_timesteps: int = 100_000_000 + """total timesteps of the experiments""" + num_envs: int = 256 + """the number of parallel game environments""" + num_eval_envs: int = 8 + """the number of evaluation environments""" + buffer_size: int = 1_000_000 + """the replay memory buffer size""" + gamma: float = 0.8 + """the discount factor gamma""" + tau: float = 0.005 + """target 
smoothing coefficient (default: 0.005)""" + batch_size: int = 4096 + """the batch size of sample from the replay memory""" + learning_starts: int = 25600 + """timestep to start learning""" + policy_lr: float = 3e-4 + """the learning rate of the policy network optimizer""" + q_lr: float = 3e-4 + """the learning rate of the Q network network optimizer""" + policy_frequency: int = 1 + """the frequency of training policy (delayed)""" + target_network_frequency: int = 1 # Denis Yarats' implementation delays this by 2. + """the frequency of updates for the target nerworks""" + alpha: float = 1.0 + """Entropy regularization coefficient.""" + autotune: bool = True + """automatic tuning of the entropy coefficient""" + grad_updates_per_step: int = 16 + """number of critic gradient updates per parallel step through all environments""" + steps_per_env: int = 1 + """number of steps each parallel env takes before performing gradient updates""" + eval_freq: int = 1000 + """evaluation frequency in terms of iterations""" + num_iterations: int = 0 + """the number of iterations (computed in runtime)""" + + +def make_env(env_id, seed, idx, capture_video, run_name): + def thunk(): + if capture_video and idx == 0: + env = gym.make(env_id, render_mode="rgb_array") + env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") + else: + env = gym.make(env_id) + env = gym.wrappers.RecordEpisodeStatistics(env) + env.action_space.seed(seed) + return env + + return thunk + + +# ALGO LOGIC: initialize agent here: +class SoftQNetwork(nn.Module): + def __init__(self, env): + super().__init__() + self.mlp = nn.Sequential(*[ + nn.Linear(np.array(env.single_observation_space.shape).prod() + np.prod(env.single_action_space.shape), 256), + nn.ReLU(), + nn.Linear(256, 256), + nn.ReLU(), + nn.Linear(256, 256), + nn.LayerNorm(256), + nn.ReLU(), + nn.Linear(256, 1) + ]) + def forward(self, x, a): + x = torch.cat([x, a], 1) + x = self.mlp(x) + return x + + +LOG_STD_MAX = 2 +LOG_STD_MIN = -5 + + +class Actor(nn.Module): + def __init__(self, env): + super().__init__() + self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod(), 256) + self.fc2 = nn.Linear(256, 256) + self.fc3 = nn.Linear(256, 256) + self.fc_mean = nn.Linear(256, np.prod(env.single_action_space.shape)) + self.fc_logstd = nn.Linear(256, np.prod(env.single_action_space.shape)) + # action rescaling + self.register_buffer( + "action_scale", torch.tensor((env.single_action_space.high - env.single_action_space.low) / 2.0, dtype=torch.float32) + ) + self.register_buffer( + "action_bias", torch.tensor((env.single_action_space.high + env.single_action_space.low) / 2.0, dtype=torch.float32) + ) + + def forward(self, x): + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + mean = self.fc_mean(x) + log_std = self.fc_logstd(x) + log_std = torch.tanh(log_std) + log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (log_std + 1) # From SpinUp / Denis Yarats + + return mean, log_std + + def get_action(self, x) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + mean, log_std = self(x) + std = log_std.exp() + normal = torch.distributions.Normal(mean, std) + x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1)) + y_t = torch.tanh(x_t) + action = y_t * self.action_scale + self.action_bias + log_prob = normal.log_prob(x_t) + # Enforcing Action Bound + log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + 1e-6) + log_prob = log_prob.sum(1, keepdim=True) + mean = torch.tanh(mean) * self.action_scale + self.action_bias + return action, 
log_prob, mean + + def get_eval_action(self, x): + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + mean = self.fc_mean(x) + return torch.tanh(mean) * self.action_scale + self.action_bias + + +if __name__ == "__main__": + import stable_baselines3 as sb3 + + if sb3.__version__ < "2.0": + raise ValueError( + """Ongoing migration: run the following command to install the new dependencies: +poetry run pip install "stable_baselines3==2.0.0a1" +""" + ) + + args = tyro.cli(Args) + args.num_iterations = args.total_timesteps // args.num_envs + run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=vars(args), + name=run_name, + monitor_gym=True, + save_code=True, + ) + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + + # TRY NOT TO MODIFY: seeding + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = args.torch_deterministic + + device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") + + # env setup + # envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)]) + sapien.physx.set_gpu_memory_config(found_lost_pairs_capacity=2**26, max_rigid_patch_count=200000) + env_kwargs = dict(obs_mode="state", control_mode="pd_joint_delta_pos", render_mode="rgb_array", sim_freq=100, control_freq=20) + envs = ManiSkillVectorEnv(args.env_id, args.num_envs, env_kwargs) + eval_envs = gym.make(args.env_id, num_envs=args.num_eval_envs, **env_kwargs) + if args.capture_video: + eval_envs = RecordEpisode(eval_envs, output_dir=f"runs/{run_name}/videos", save_trajectory=False, video_fps=30) + eval_envs = ManiSkillVectorEnv(eval_envs, args.num_eval_envs, env_kwargs) + + assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" + + max_action = float(envs.single_action_space.high[0]) + + actor = Actor(envs).to(device) + qf1 = SoftQNetwork(envs).to(device) + qf2 = SoftQNetwork(envs).to(device) + qf1_target = SoftQNetwork(envs).to(device) + qf2_target = SoftQNetwork(envs).to(device) + qf1_target.load_state_dict(qf1.state_dict()) + qf2_target.load_state_dict(qf2.state_dict()) + q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=args.q_lr) + actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.policy_lr) + + # Automatic entropy tuning + if args.autotune: + target_entropy = -torch.prod(torch.Tensor(envs.single_action_space.shape).to(device)).item() + log_alpha = torch.zeros(1, requires_grad=True, device=device) + alpha = log_alpha.exp().item() + a_optimizer = optim.Adam([log_alpha], lr=args.q_lr) + else: + alpha = args.alpha + + envs.single_observation_space.dtype = np.float32 + rb = ReplayBuffer( + args.buffer_size, + envs.single_observation_space, + envs.single_action_space, + device, + n_envs=args.num_envs, + handle_timeout_termination=False, + ) + start_time = time.time() + global_step = 0 + + # TRY NOT TO MODIFY: start the game + obs, _ = envs.reset(seed=args.seed) + eval_obs, _ = eval_envs.reset(seed=args.seed) + for iteration in range(1, args.num_iterations + 1): + # Evaluation Code + if iteration % args.eval_freq == 1: + print("Evaluating") + eval_done = False + while not eval_done: + with 
torch.no_grad():
+                    eval_obs, _, eval_terminations, eval_truncations, eval_infos = eval_envs.step(actor.get_eval_action(eval_obs))
+                    if eval_truncations.any():
+                        eval_done = True
+                        info = eval_infos["final_info"]
+                        episodic_return = info['episode']['r'].mean().cpu().numpy()
+                        print(f"eval_episodic_return={episodic_return}")
+                        writer.add_scalar("charts/eval_success_rate", info["success"].float().mean().cpu().numpy(), global_step)
+                        writer.add_scalar("charts/eval_episodic_return", episodic_return, global_step)
+                        writer.add_scalar("charts/eval_episodic_length", info["elapsed_steps"].float().mean().cpu().numpy(), global_step)
+
+        for _ in range(args.steps_per_env):
+            # ALGO LOGIC: put action logic here
+            if global_step < args.learning_starts:
+                actions = torch.from_numpy(envs.action_space.sample()).to(device)
+            else:
+                actions, _, _ = actor.get_action(torch.Tensor(obs).to(device))
+                actions = actions.detach()
+
+            # TRY NOT TO MODIFY: execute the game and log data.
+            next_obs, rewards, terminations, truncations, infos = envs.step(actions)
+            global_step += envs.num_envs
+            dones = torch.logical_or(terminations, truncations)
+            assert dones.any() == dones.all()
+
+            # TRY NOT TO MODIFY: record rewards for plotting purposes
+            if "final_info" in infos:  # this means all parallel envs truncated
+                info = infos["final_info"]
+                episodic_return = info['episode']['r'].mean().cpu().numpy()
+                print(f"global_step={global_step}, episodic_return={episodic_return}")
+                writer.add_scalar("charts/success_rate", info["success"].float().mean().cpu().numpy(), global_step)
+                writer.add_scalar("charts/episodic_return", episodic_return, global_step)
+                writer.add_scalar("charts/episodic_length", info["elapsed_steps"].float().mean().cpu().numpy(), global_step)
+
+            # TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
+            real_next_obs = next_obs.clone()
+            if dones.any():
+                real_next_obs = infos["final_observation"].clone()
+            if not args.replay_buffer_on_gpu:
+                real_next_obs = real_next_obs.cpu().numpy()
+                obs = obs.cpu().numpy()
+                actions = actions.cpu().numpy()
+                rewards = rewards.cpu().numpy()
+                dones = dones.cpu().numpy()
+                infos = to_numpy(infos)
+            rb.add(obs, real_next_obs, actions, rewards, dones, infos)
+
+            # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
+            obs = next_obs
+
+        # ALGO LOGIC: training.
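+        # The block below is the standard SAC update, run on minibatches sampled from
+        # the replay buffer:
+        #   1. Critic: both Q networks regress to the entropy-regularized TD target
+        #      y = r + gamma * (1 - done) * (min(Q1', Q2')(s', a') - alpha * log pi(a'|s')),
+        #      with a' sampled from the current policy at s'.
+        #   2. Actor (every `policy_frequency` iterations): minimize
+        #      alpha * log pi(a|s) - min(Q1, Q2)(s, pi(s)), i.e. maximize the entropy-regularized Q value.
+        #   3. Alpha (if `autotune`): adjust the entropy coefficient so the policy entropy
+        #      tracks `target_entropy`.
+        #   4. Target networks: Polyak-averaged toward the online critics with rate `tau`.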
+ if global_step > args.learning_starts: + for _ in range(args.grad_updates_per_step): + data = rb.sample(args.batch_size) + with torch.no_grad(): + next_state_actions, next_state_log_pi, _ = actor.get_action(data.next_observations) + qf1_next_target = qf1_target(data.next_observations, next_state_actions) + qf2_next_target = qf2_target(data.next_observations, next_state_actions) + min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - alpha * next_state_log_pi + next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (min_qf_next_target).view(-1) + + qf1_a_values = qf1(data.observations, data.actions).view(-1) + qf2_a_values = qf2(data.observations, data.actions).view(-1) + qf1_loss = F.mse_loss(qf1_a_values, next_q_value) + qf2_loss = F.mse_loss(qf2_a_values, next_q_value) + qf_loss = qf1_loss + qf2_loss + + # optimize the model + q_optimizer.zero_grad() + qf_loss.backward() + q_optimizer.step() + + if iteration % args.policy_frequency == 0: # TD 3 Delayed update support + for _ in range( + args.policy_frequency + ): # compensate for the delay by doing 'actor_update_interval' instead of 1 + pi, log_pi, _ = actor.get_action(data.observations) + qf1_pi = qf1(data.observations, pi) + qf2_pi = qf2(data.observations, pi) + min_qf_pi = torch.min(qf1_pi, qf2_pi) + actor_loss = ((alpha * log_pi) - min_qf_pi).mean() + + actor_optimizer.zero_grad() + actor_loss.backward() + actor_optimizer.step() + + if args.autotune: + with torch.no_grad(): + _, log_pi, _ = actor.get_action(data.observations) + alpha_loss = (-log_alpha.exp() * (log_pi + target_entropy)).mean() + + a_optimizer.zero_grad() + alpha_loss.backward() + a_optimizer.step() + alpha = log_alpha.exp().item() + + # update the target networks + if iteration % args.target_network_frequency == 0: + for param, target_param in zip(qf1.parameters(), qf1_target.parameters()): + target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data) + for param, target_param in zip(qf2.parameters(), qf2_target.parameters()): + target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data) + + if iteration % 10 == 0: + writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step) + writer.add_scalar("losses/qf2_values", qf2_a_values.mean().item(), global_step) + writer.add_scalar("losses/qf1_loss", qf1_loss.item(), global_step) + writer.add_scalar("losses/qf2_loss", qf2_loss.item(), global_step) + writer.add_scalar("losses/qf_loss", qf_loss.item() / 2.0, global_step) + writer.add_scalar("losses/actor_loss", actor_loss.item(), global_step) + writer.add_scalar("losses/alpha", alpha, global_step) + print("SPS:", int(global_step / (time.time() - start_time))) + writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) + if args.autotune: + writer.add_scalar("losses/alpha_loss", alpha_loss.item(), global_step) + + envs.close() + writer.close() \ No newline at end of file diff --git a/examples/tutorials/reinforcement-learning/cleanrl_ppo_liftcube_state_gpu.py b/examples/tutorials/reinforcement-learning/cleanrl_ppo_liftcube_state_gpu.py index 8558ebce6..7e90b3f1c 100644 --- a/examples/tutorials/reinforcement-learning/cleanrl_ppo_liftcube_state_gpu.py +++ b/examples/tutorials/reinforcement-learning/cleanrl_ppo_liftcube_state_gpu.py @@ -1,6 +1,6 @@ # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_continuous_actionpy # python cleanrl_ppo_liftcube_state_gpu.py --num_envs=512 --gamma=0.8 
--gae_lambda=0.9 --update_epochs=1 --num_minibatches=128 --env_id="PickCube-v0" --total_timesteps=100000000 -# python cleanrl_ppo_liftcube_state_gpu.py --num_envs=2048 --gamma=0.8 --gae_lambda=0.9 --update_epochs=1 --num_minibatches=32 --env_id="PushCube-v0" --total_timesteps=100000000 --num-steps=12 +# python cleanrl_ppo_liftcube_state_gpu.py --num_envs=512 --gamma=0.8 --gae_lambda=0.9 --update_epochs=8 --target_kl=0.1 --num_minibatches=16 --env_id="PickCube-v0" --total_timesteps=100000000 --num_steps=100 # TODO: train shorter horizon to leverage parallelization more. import os import random diff --git a/mani_skill2/envs/sapien_env.py b/mani_skill2/envs/sapien_env.py index f6851c570..24ed84130 100644 --- a/mani_skill2/envs/sapien_env.py +++ b/mani_skill2/envs/sapien_env.py @@ -641,7 +641,7 @@ def step_action(self, action): self.agent.set_control_mode(action["control_mode"]) self.agent.set_action(action["action"]) set_action = True - elif torch is not None and isinstance(action, torch.Tensor): + elif isinstance(action, torch.Tensor): self.agent.set_action(action) set_action = True else: diff --git a/mani_skill2/vector/wrappers/gymnasium.py b/mani_skill2/vector/wrappers/gymnasium.py new file mode 100644 index 000000000..baf364714 --- /dev/null +++ b/mani_skill2/vector/wrappers/gymnasium.py @@ -0,0 +1,79 @@ +from typing import Dict, List, Optional, Tuple, Union + +import gymnasium as gym +import torch +from gymnasium import Space +from gymnasium.vector import VectorEnv + +from mani_skill2.envs.sapien_env import BaseEnv +from mani_skill2.utils.structs.types import Array + + +class ManiSkillVectorEnv(VectorEnv): + """ + Gymnasium Vector Env implementation for ManiSkill environments running on the GPU for parallel simulation and optionally parallel rendering + + Note that currently this also assumes modeling tasks as infinite horizon (e.g. 
terminations are always False and environments only reset when the time limit is reached)
+    """
+
+    def __init__(
+        self,
+        env: Union[BaseEnv, str],
+        num_envs: int,
+        env_kwargs: Dict = dict(),
+        auto_reset: bool = True,
+    ):
+        if isinstance(env, str):
+            self._env = gym.make(env, num_envs=num_envs, **env_kwargs)
+        else:
+            self._env = env
+        self.auto_reset = auto_reset
+        super().__init__(
+            num_envs, self.env.single_observation_space, self.env.single_action_space
+        )
+
+        self.returns = torch.zeros(self.num_envs, device=self.env.device)
+
+    @property
+    def env(self) -> BaseEnv:
+        return self._env
+
+    def reset(
+        self,
+        *,
+        seed: Optional[Union[int, List[int]]] = None,
+        options: Optional[dict] = None,
+    ):
+        obs, info = self.env.reset(seed=seed, options=options)
+        self.returns *= 0
+        return obs, info
+
+    def step(
+        self, actions: Union[Array, Dict]
+    ) -> Tuple[Array, Array, Array, Array, Dict]:
+        obs, rew, terminations, truncations, infos = self.env.step(actions)
+        self.returns += rew
+        infos["episode"] = dict(r=self.returns)
+        # tasks are modeled as infinite horizon, so terminations are always False
+        terminations = torch.zeros(self.num_envs, device=self.env.device)
+        if truncations:
+            infos["episode"]["r"] = self.returns.clone()
+            final_obs = obs
+            obs, _ = self.reset()
+            new_infos = dict()
+            new_infos["final_info"] = infos
+            new_infos["final_observation"] = final_obs
+            infos = new_infos
+        # the gym TimeLimit wrapper returns a bool; convert it to a tensor for consistency
+        truncations = torch.ones_like(terminations) * truncations
+        return obs, rew, terminations, truncations, infos
+
+    def close(self):
+        return self.env.close()
+
+    def call(self, name: str, *args, **kwargs):
+        function = getattr(self.env, name)
+        return function(*args, **kwargs)
+
+    def get_attr(self, name: str):
+        raise RuntimeError(
+            "To get an attribute, access it via the .env property of this object"
+        )
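+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (illustrative only, not exercised by the training code):
+    # constructs the wrapper the same way sac.py does. Assumes a GPU-simulated task id
+    # such as "PickCube-v0" is registered and accepts these env kwargs.
+    num_envs = 4
+    envs = ManiSkillVectorEnv(
+        "PickCube-v0",
+        num_envs,
+        env_kwargs=dict(obs_mode="state", control_mode="pd_joint_delta_pos"),
+    )
+    obs, info = envs.reset(seed=0)
+    for _ in range(100):
+        # the batched action space samples a numpy array; the underlying env also accepts torch tensors
+        actions = torch.from_numpy(envs.action_space.sample()).to(envs.env.device)
+        obs, rew, terminations, truncations, infos = envs.step(actions)
+        if "final_info" in infos:
+            # all envs truncated and were auto-reset; per-env returns live in
+            # infos["final_info"]["episode"]["r"]
+            print(infos["final_info"]["episode"]["r"].mean().item())
+    envs.close()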