[Algorithm] Update SAC Example #1524

Merged: 40 commits, Oct 3, 2023

Commits (changes from all commits)
9045977  fix (BY571, Sep 12, 2023)
7d42ba5  update optimizer (BY571, Sep 12, 2023)
738c6df  fix (BY571, Sep 13, 2023)
e3e4ced  add init alpha option (BY571, Sep 14, 2023)
2897340  logalpha fix (BY571, Sep 14, 2023)
f83c6e6  naming fixes (BY571, Sep 14, 2023)
e58b9b0  fix (BY571, Sep 14, 2023)
1956d80  update logging small fixes (BY571, Sep 15, 2023)
8a60301  add wd (BY571, Sep 15, 2023)
e56b46b  add eps (BY571, Sep 17, 2023)
220861a  no eps (BY571, Sep 18, 2023)
0974772  undetach q at actorloss (BY571, Sep 18, 2023)
ac54930  tests (BY571, Sep 19, 2023)
1cfc821  update test (BY571, Sep 19, 2023)
500bd5d  update test (BY571, Sep 19, 2023)
4b23446  update config, test add set_gym_backend (BY571, Sep 21, 2023)
567cd2b  update header (BY571, Sep 21, 2023)
e5a96af  Merge remote-tracking branch 'origin/main' into sac_benchmark (vmoens, Sep 21, 2023)
ede6064  fix max episode steps (BY571, Sep 22, 2023)
b2a04e6  update objective (BY571, Sep 26, 2023)
1bf7382  update objective (BY571, Sep 26, 2023)
a04437d  Merge branch 'main' into sac_benchmark (BY571, Sep 26, 2023)
01d6e56  sep critic opti (BY571, Sep 26, 2023)
522d061  fixes (BY571, Sep 27, 2023)
67e47b6  fix (BY571, Sep 27, 2023)
5af2d9a  logexp test (BY571, Sep 27, 2023)
7546aad  frameskip weight decay (BY571, Sep 28, 2023)
06c2e68  fix frameskip, scratchdir buffer (BY571, Sep 28, 2023)
25cf664  update config (BY571, Sep 28, 2023)
9a7b0b4  undo stepcount (BY571, Oct 2, 2023)
0c8a1c4  merge main (BY571, Oct 3, 2023)
0272f03  Merge branch 'main' into sac_benchmark (vmoens, Oct 3, 2023)
f4f65a5  fix config (BY571, Oct 3, 2023)
d0a6fab  merge main (BY571, Oct 3, 2023)
b0a3799  amend (vmoens, Oct 3, 2023)
b758607  amend (vmoens, Oct 3, 2023)
f0482c7  amend (vmoens, Oct 3, 2023)
fbbc287  amend (vmoens, Oct 3, 2023)
b5673fb  empty (vmoens, Oct 3, 2023)
727776e  Merge branch 'main' into sac_benchmark (vmoens, Oct 3, 2023)
14 changes: 6 additions & 8 deletions .github/unittest/linux_examples/scripts/run_test.sh
@@ -118,11 +118,10 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
collector.num_workers=4 \
collector.env_per_collector=2 \
collector.collector_device=cuda:0 \
optimization.batch_size=10 \
optimization.utd_ratio=1 \
optim.batch_size=10 \
optim.utd_ratio=1 \
replay_buffer.size=120 \
env.name=Pendulum-v1 \
network.device=cuda:0 \
@@ -221,17 +220,16 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/sac/sac.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
collector.frames_per_batch=16 \
collector.num_workers=2 \
collector.env_per_collector=1 \
collector.collector_device=cuda:0 \
optim.batch_size=10 \
optim.utd_ratio=1 \
network.device=cuda:0 \
optimization.batch_size=10 \
optimization.utd_ratio=1 \
optim.batch_size=10 \
optim.utd_ratio=1 \
replay_buffer.size=120 \
env.name=Pendulum-v1 \
logger.backend=
# record_video=True \
# record_frames=4 \
python .github/unittest/helpers/coverage_run_parallel.py examples/iql/iql_online.py \
total_frames=48 \
batch_size=10 \
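The override rename in this test script mirrors the config change below: Hydra's dotted command-line overrides must match the group name in config.yaml, so `optimization.batch_size=10` stops resolving once the group is renamed to `optim`. The following is only a minimal sketch of that resolution, using OmegaConf directly rather than the actual Hydra entry point in sac.py, with the group contents trimmed for illustration:

from omegaconf import OmegaConf

# Trimmed stand-in for the renamed `optim` group in examples/sac/config.yaml.
base = OmegaConf.create({"optim": {"batch_size": 256, "utd_ratio": 1.0}})

# Dotted CLI overrides resolve against the group name, which is why the test
# script switches from `optimization.*` to `optim.*`.
overrides = OmegaConf.from_dotlist(["optim.batch_size=10", "optim.utd_ratio=1"])
cfg = OmegaConf.merge(base, overrides)
assert cfg.optim.batch_size == 10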
40 changes: 20 additions & 20 deletions examples/sac/config.yaml
@@ -1,49 +1,49 @@
# Environment
# environment and task
env:
name: HalfCheetah-v3
task: ""
exp_name: "HalfCheetah-SAC"
library: gym
frame_skip: 1
seed: 1
exp_name: ${env.name}_SAC
library: gymnasium
max_episode_steps: 1000
seed: 42

# Collection
# collector
collector:
total_frames: 1000000
init_random_frames: 10000
total_frames: 1_000_000
init_random_frames: 25000
frames_per_batch: 1000
max_frames_per_traj: 1000
init_env_steps: 1000
async_collection: 1
collector_device: cpu
env_per_collector: 1
num_workers: 1
reset_at_each_iter: False

# Replay Buffer
# replay buffer
replay_buffer:
size: 1000000
prb: 0 # use prioritized experience replay
scratch_dir: ${env.exp_name}_${env.seed}

# Optimization
optimization:
# optim
optim:
utd_ratio: 1.0
gamma: 0.99
loss_function: smooth_l1
lr: 3e-4
weight_decay: 2e-4
lr_scheduler: ""
loss_function: l2
lr: 3.0e-4
weight_decay: 0.0
batch_size: 256
target_update_polyak: 0.995
alpha_init: 1.0
adam_eps: 1.0e-8

# Algorithm
# network
network:
hidden_sizes: [256, 256]
activation: relu
default_policy_scale: 1.0
scale_lb: 0.1
device: "cuda:0"

# Logging
# logging
logger:
backend: wandb
mode: online
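For context on how the reorganized `optim` section is consumed: sac.py (next file) now unpacks three separate optimizers from make_sac_optimizer. utils.py is not part of this diff, so the following is only a sketch of what such a helper could look like; the attribute names (actor_network_params, qvalue_network_params, log_alpha) are assumed from torchrl's SACLoss conventions.

import torch

def make_sac_optimizer_sketch(cfg, loss_module):
    # Split the loss module's parameters into actor, critic, and alpha groups.
    actor_params = list(loss_module.actor_network_params.flatten_keys().values())
    critic_params = list(loss_module.qvalue_network_params.flatten_keys().values())

    optimizer_actor = torch.optim.Adam(
        actor_params,
        lr=cfg.optim.lr,
        weight_decay=cfg.optim.weight_decay,
        eps=cfg.optim.adam_eps,
    )
    optimizer_critic = torch.optim.Adam(
        critic_params,
        lr=cfg.optim.lr,
        weight_decay=cfg.optim.weight_decay,
        eps=cfg.optim.adam_eps,
    )
    # The entropy coefficient is optimized through log_alpha, which the loss
    # module initializes from cfg.optim.alpha_init.
    optimizer_alpha = torch.optim.Adam(
        [loss_module.log_alpha],
        lr=cfg.optim.lr,
    )
    return optimizer_actor, optimizer_critic, optimizer_alpha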
158 changes: 94 additions & 64 deletions examples/sac/sac.py
@@ -11,17 +11,20 @@
The helper functions are coded in the utils.py associated with this script.
"""

import time

import hydra

import numpy as np
import torch
import torch.cuda
import tqdm

from tensordict import TensorDict
from torchrl.envs.utils import ExplorationType, set_exploration_type

from torchrl.record.loggers import generate_exp_name, get_logger
from utils import (
log_metrics,
make_collector,
make_environment,
make_loss_module,
@@ -35,6 +38,7 @@
def main(cfg: "DictConfig"): # noqa: F821
device = torch.device(cfg.network.device)

# Create logger
exp_name = generate_exp_name("SAC", cfg.env.exp_name)
logger = None
if cfg.logger.backend:
@@ -48,132 +52,158 @@ def main(cfg: "DictConfig"): # noqa: F821
torch.manual_seed(cfg.env.seed)
np.random.seed(cfg.env.seed)

# Create Environments
# Create environments
train_env, eval_env = make_environment(cfg)
# Create Agent

# Create agent
model, exploration_policy = make_sac_agent(cfg, train_env, eval_env, device)

# Create TD3 loss
# Create SAC loss
loss_module, target_net_updater = make_loss_module(cfg, model)

# Make Off-Policy Collector
# Create off-policy collector
collector = make_collector(cfg, train_env, exploration_policy)

# Make Replay Buffer
# Create replay buffer
replay_buffer = make_replay_buffer(
batch_size=cfg.optimization.batch_size,
batch_size=cfg.optim.batch_size,
prb=cfg.replay_buffer.prb,
buffer_size=cfg.replay_buffer.size,
buffer_scratch_dir="/tmp/" + cfg.replay_buffer.scratch_dir,
device=device,
)

# Make Optimizers
optimizer = make_sac_optimizer(cfg, loss_module)

rewards = []
rewards_eval = []
# Create optimizers
(
optimizer_actor,
optimizer_critic,
optimizer_alpha,
) = make_sac_optimizer(cfg, loss_module)

# Main loop
start_time = time.time()
collected_frames = 0
pbar = tqdm.tqdm(total=cfg.collector.total_frames)
r0 = None
q_loss = None

init_random_frames = cfg.collector.init_random_frames
num_updates = int(
cfg.collector.env_per_collector
* cfg.collector.frames_per_batch
* cfg.optimization.utd_ratio
* cfg.optim.utd_ratio
)
prb = cfg.replay_buffer.prb
env_per_collector = cfg.collector.env_per_collector
eval_iter = cfg.logger.eval_iter
frames_per_batch, frame_skip = cfg.collector.frames_per_batch, cfg.env.frame_skip
eval_rollout_steps = cfg.collector.max_frames_per_traj // frame_skip
frames_per_batch = cfg.collector.frames_per_batch
eval_rollout_steps = cfg.env.max_episode_steps

sampling_start = time.time()
for i, tensordict in enumerate(collector):
# update weights of the inference policy
sampling_time = time.time() - sampling_start

# Update weights of the inference policy
collector.update_policy_weights_()

if r0 is None:
r0 = tensordict["next", "reward"].sum(-1).mean().item()
pbar.update(tensordict.numel())

tensordict = tensordict.view(-1)
tensordict = tensordict.reshape(-1)
current_frames = tensordict.numel()
# Add to replay buffer
replay_buffer.extend(tensordict.cpu())
collected_frames += current_frames

# optimization steps
# Optimization steps
training_start = time.time()
if collected_frames >= init_random_frames:
(actor_losses, q_losses, alpha_losses) = ([], [], [])
for _ in range(num_updates):
# sample from replay buffer
losses = TensorDict(
{},
batch_size=[
num_updates,
],
)
for i in range(num_updates):
# Sample from replay buffer
sampled_tensordict = replay_buffer.sample().clone()

# Compute loss
loss_td = loss_module(sampled_tensordict)

actor_loss = loss_td["loss_actor"]
q_loss = loss_td["loss_qvalue"]
alpha_loss = loss_td["loss_alpha"]
loss = actor_loss + q_loss + alpha_loss

optimizer.zero_grad()
loss.backward()
optimizer.step()
# Update actor
optimizer_actor.zero_grad()
actor_loss.backward()
optimizer_actor.step()

q_losses.append(q_loss.item())
actor_losses.append(actor_loss.item())
alpha_losses.append(alpha_loss.item())
# Update critic
optimizer_critic.zero_grad()
q_loss.backward()
optimizer_critic.step()

# update qnet_target params
# Update alpha
optimizer_alpha.zero_grad()
alpha_loss.backward()
optimizer_alpha.step()

losses[i] = loss_td.select(
"loss_actor", "loss_qvalue", "loss_alpha"
).detach()

# Update qnet_target params
target_net_updater.step()

# update priority
# Update priority
if prb:
replay_buffer.update_priority(sampled_tensordict)

rewards.append(
(i, tensordict["next", "reward"].sum().item() / env_per_collector)
training_time = time.time() - training_start
episode_end = (
tensordict["next", "done"]
if tensordict["next", "done"].any()
else tensordict["next", "truncated"]
)
train_log = {
"train_reward": rewards[-1][1],
"collected_frames": collected_frames,
}
if q_loss is not None:
train_log.update(
{
"actor_loss": np.mean(actor_losses),
"q_loss": np.mean(q_losses),
"alpha_loss": np.mean(alpha_losses),
"alpha": loss_td["alpha"],
"entropy": loss_td["entropy"],
}
episode_rewards = tensordict["next", "episode_reward"][episode_end]

# Logging
metrics_to_log = {}
if len(episode_rewards) > 0:
Review comment (Contributor): Shouldn't we group all the logging in a single function, to avoid overloading the training loop?

Reply (Contributor Author): I agree. If you can have a look at the DDPG PR, I tried to compress the logging there, but I'm open to other ideas on how to do it.

(A sketch of such a log_metrics helper follows the sac.py diff below.)

episode_length = tensordict["next", "step_count"][episode_end]
metrics_to_log["train/reward"] = episode_rewards.mean().item()
metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
episode_length
)
if logger is not None:
for key, value in train_log.items():
logger.log_scalar(key, value, step=collected_frames)
if abs(collected_frames % eval_iter) < frames_per_batch * frame_skip:
if collected_frames >= init_random_frames:
metrics_to_log["train/q_loss"] = losses.get("loss_qvalue").mean().item()
metrics_to_log["train/actor_loss"] = losses.get("loss_actor").mean().item()
metrics_to_log["train/alpha_loss"] = losses.get("loss_alpha").mean().item()
metrics_to_log["train/alpha"] = loss_td["alpha"].item()
metrics_to_log["train/entropy"] = loss_td["entropy"].item()
metrics_to_log["train/sampling_time"] = sampling_time
metrics_to_log["train/training_time"] = training_time

# Evaluation
if abs(collected_frames % eval_iter) < frames_per_batch:
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
eval_start = time.time()
eval_rollout = eval_env.rollout(
eval_rollout_steps,
model[0],
auto_cast_to_device=True,
break_when_any_done=True,
)
eval_time = time.time() - eval_start
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
rewards_eval.append((i, eval_reward))
eval_str = f"eval cumulative reward: {rewards_eval[-1][1]: 4.4f} (init: {rewards_eval[0][1]: 4.4f})"
if logger is not None:
logger.log_scalar(
"evaluation_reward", rewards_eval[-1][1], step=collected_frames
)
if len(rewards_eval):
pbar.set_description(
f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f})," + eval_str
)
metrics_to_log["eval/reward"] = eval_reward
metrics_to_log["eval/time"] = eval_time
if logger is not None:
log_metrics(logger, metrics_to_log, collected_frames)
sampling_start = time.time()

collector.shutdown()
end_time = time.time()
execution_time = end_time - start_time
print(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
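As a follow-up to the review thread about grouping the logging: the new sac.py imports a log_metrics helper from utils.py. That file is not shown in this diff, so the following is only a minimal sketch of what such a helper could look like, built on the same logger.log_scalar call used by the old training loop.

def log_metrics(logger, metrics, step):
    # Write every scalar collected in metrics_to_log in one place, keeping the
    # training loop free of logging details.
    for metric_name, metric_value in metrics.items():
        logger.log_scalar(metric_name, metric_value, step=step)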