[Algorithm] IMPALA and VTrace module #1506

Merged
merged 118 commits into from
Nov 23, 2023
550d397
vtrace module
albertbou92 Sep 6, 2023
2a693a7
vtrace module
albertbou92 Sep 6, 2023
7a8ee38
vtrace module
albertbou92 Sep 6, 2023
a6434fd
vtrace module
albertbou92 Sep 6, 2023
3e76450
vtrace module
albertbou92 Sep 6, 2023
7613cb9
vtrace module
albertbou92 Sep 6, 2023
8c70b0a
vtrace module
albertbou92 Sep 6, 2023
e8dd5be
vtrace module
albertbou92 Sep 6, 2023
3ace31f
vtrace module
albertbou92 Sep 7, 2023
3efe601
impala
albertbou92 Sep 7, 2023
2458927
impala example
albertbou92 Sep 7, 2023
c6a60cc
impala example
albertbou92 Sep 7, 2023
f157766
impala example
albertbou92 Sep 7, 2023
f0cf4f5
impala example
albertbou92 Sep 7, 2023
ae6fb62
docs clarifications
albertbou92 Sep 18, 2023
081c2da
docs
albertbou92 Sep 19, 2023
89a5a9e
fixes
albertbou92 Sep 19, 2023
888fbcb
fixes
albertbou92 Sep 19, 2023
f3f9832
config
albertbou92 Sep 19, 2023
5b7d642
fixes
albertbou92 Sep 19, 2023
7dca35e
fixes
albertbou92 Sep 19, 2023
e8c35ef
fixes
albertbou92 Sep 19, 2023
a4af09f
fixes
albertbou92 Sep 20, 2023
8bf4787
move vtrace to adv script
albertbou92 Sep 21, 2023
8648e10
tests
albertbou92 Sep 21, 2023
dfc1c82
tests
albertbou92 Sep 22, 2023
3da748e
tests
albertbou92 Sep 22, 2023
a568378
fix
albertbou92 Sep 22, 2023
596c6cc
format
albertbou92 Sep 22, 2023
ee692f5
working impala script
albertbou92 Sep 25, 2023
a5eb8b6
working impala script
albertbou92 Sep 25, 2023
b9e81d2
test offpolicy losses
albertbou92 Sep 25, 2023
d04d050
minor script fixes
albertbou92 Sep 25, 2023
2a15708
test onpolicy losses
albertbou92 Sep 25, 2023
a5c2046
test fix
albertbou92 Sep 25, 2023
dbde27c
test fix
albertbou92 Sep 25, 2023
7411235
test fix
albertbou92 Sep 26, 2023
fa5f835
test fix
albertbou92 Sep 26, 2023
30e0cc1
test fix
albertbou92 Oct 2, 2023
0b9ed5c
fixes
albertbou92 Oct 2, 2023
15034f8
merge main
albertbou92 Oct 2, 2023
b3c0c9e
multi node
albertbou92 Oct 2, 2023
c634112
multi node
albertbou92 Oct 2, 2023
df16ace
multi node
albertbou92 Oct 2, 2023
3403e29
fix tests
albertbou92 Oct 3, 2023
9da24eb
fix tests
albertbou92 Oct 3, 2023
795620f
fix tests
albertbou92 Oct 3, 2023
bdc2392
Merge branch 'main' into vtrace
albertbou92 Oct 3, 2023
53aceba
merge main
albertbou92 Oct 3, 2023
02cebf6
multinode script
albertbou92 Oct 3, 2023
5c0aec0
call actor func
albertbou92 Oct 3, 2023
c8ef2c7
faster scripts
albertbou92 Oct 3, 2023
e024c09
multinode script
albertbou92 Oct 3, 2023
55b7947
simplify utils
albertbou92 Oct 3, 2023
6d6df00
revert tests
albertbou92 Oct 4, 2023
d4536d1
Merge branch 'main' into vtrace
albertbou92 Oct 4, 2023
5ebcfb8
Merge branch 'main' into vtrace
albertbou92 Oct 4, 2023
9e1d64b
introduce review feedback
albertbou92 Oct 4, 2023
224ae91
torch compile
albertbou92 Oct 4, 2023
e543888
torch compile
albertbou92 Oct 4, 2023
db541c0
fix
albertbou92 Oct 4, 2023
937b819
fix
albertbou92 Oct 4, 2023
9e33035
tests
albertbou92 Oct 4, 2023
199bc3b
adapt ppo tests
albertbou92 Oct 5, 2023
f1b11dd
adapt ppo tests
albertbou92 Oct 5, 2023
1d8d1ef
adapt ppo tests
albertbou92 Oct 5, 2023
ebf74b8
fix tests ppo
albertbou92 Oct 5, 2023
1180993
fix tests a2c
albertbou92 Oct 5, 2023
6e73acd
fix tests a2c
albertbou92 Oct 5, 2023
2ecb103
fix tests a2c
albertbou92 Oct 5, 2023
cd07719
fix tests reinforce
albertbou92 Oct 5, 2023
f1d2770
fix tests values
albertbou92 Oct 5, 2023
676d8f5
fix tests values
albertbou92 Oct 5, 2023
f491e7d
fix tests adv
albertbou92 Oct 5, 2023
7a63dd6
fix tests adv
albertbou92 Oct 5, 2023
d30bb9d
fix tests adv
albertbou92 Oct 5, 2023
e19c671
Merge branch 'main' into vtrace
albertbou92 Oct 5, 2023
a9e1db3
code examples
albertbou92 Oct 5, 2023
32cd518
code examples
albertbou92 Oct 5, 2023
53617bb
fix tests adv
albertbou92 Oct 5, 2023
0bc7b8c
fix tests adv
albertbou92 Oct 5, 2023
40cc02f
code examples tests
albertbou92 Oct 5, 2023
cbd923e
code examples tests
albertbou92 Oct 5, 2023
2235c02
code example with submitit
albertbou92 Oct 5, 2023
1a8efd1
code example with submitit
albertbou92 Oct 5, 2023
3ef4001
code example with submitit
albertbou92 Oct 5, 2023
fcc1121
code example with submitit
albertbou92 Oct 5, 2023
dd2a7f3
code example with submitit
albertbou92 Oct 5, 2023
624b2d6
code example with submitit
albertbou92 Oct 5, 2023
7e30069
code example with submitit
albertbou92 Oct 5, 2023
8d6c064
Merge branch 'main' into vtrace
albertbou92 Oct 5, 2023
157ad9b
Merge branch 'main' into vtrace
albertbou92 Nov 14, 2023
597623b
fix logging
albertbou92 Nov 14, 2023
5c21c1e
fix example
albertbou92 Nov 19, 2023
e47dbc3
fix example
albertbou92 Nov 19, 2023
607ad53
Merge branch 'main' into vtrace
albertbou92 Nov 22, 2023
c23401a
Update examples/impala/impala_multi_node_ray.py
albertbou92 Nov 22, 2023
886b4e0
Update torchrl/objectives/value/advantages.py
albertbou92 Nov 22, 2023
803fc4f
Update torchrl/objectives/value/advantages.py
albertbou92 Nov 22, 2023
72a3c6e
Update torchrl/objectives/value/advantages.py
albertbou92 Nov 22, 2023
4a061b5
Update examples/impala/impala_multi_node_ray.py
albertbou92 Nov 22, 2023
e7069e4
Update examples/impala/impala_multi_node_ray.py
albertbou92 Nov 22, 2023
5399cf1
merge main
albertbou92 Nov 22, 2023
39584ab
fixes
albertbou92 Nov 22, 2023
c68fd40
fixes
albertbou92 Nov 22, 2023
638c0d6
format
albertbou92 Nov 22, 2023
2f8b545
fixes
albertbou92 Nov 22, 2023
6ddcb3a
fixes
albertbou92 Nov 22, 2023
94306bf
fixes
albertbou92 Nov 22, 2023
9cc0284
fixes
albertbou92 Nov 23, 2023
63392f0
fixes
albertbou92 Nov 23, 2023
e61f342
fixes
albertbou92 Nov 23, 2023
6d384d5
fixes
albertbou92 Nov 23, 2023
89770e4
submitit example
albertbou92 Nov 23, 2023
9132a60
submitit example
albertbou92 Nov 23, 2023
89a803b
README
albertbou92 Nov 23, 2023
bd02b30
fix tests
albertbou92 Nov 23, 2023
0a382bb
fix unused_args
albertbou92 Nov 23, 2023
6 changes: 6 additions & 0 deletions .github/unittest/linux_examples/scripts/run_test.sh
@@ -52,6 +52,12 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/decision_trans
# ==================================================================================== #
# ================================ Gymnasium ========================================= #

python .github/unittest/helpers/coverage_run_parallel.py examples/impala/impala_single_node.py \
collector.total_frames=80 \
collector.frames_per_batch=20 \
collector.num_workers=1 \
logger.backend= \
logger.test_interval=10
python .github/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo_mujoco.py \
env.env_name=HalfCheetah-v4 \
collector.total_frames=40 \
2 changes: 1 addition & 1 deletion examples/distributed/collectors/multi_nodes/ray_train.py
@@ -117,7 +117,7 @@
"object_store_memory": 1024**3,
}
collector = RayCollector(
- env_makers=[env] * num_collectors,
+ create_env_fn=[env] * num_collectors,
policy=policy_module,
collector_class=SyncDataCollector,
collector_kwargs={
33 changes: 33 additions & 0 deletions examples/impala/README.md
@@ -0,0 +1,33 @@
## Reproducing Importance Weighted Actor-Learner Architecture (IMPALA) Algorithm Results

This repository contains scripts that enable training agents using the IMPALA Algorithm on MuJoCo and Atari environments. We follow the original paper [IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures](https://arxiv.org/abs/1802.01561) by Espeholt et al. 2018.
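
As a rough illustration of the off-policy correction that IMPALA relies on, the sketch below computes V-trace value targets for a single trajectory using the backward recursion from Espeholt et al. 2018. It is a minimal pure-Python sketch, not the API of the VTrace module added in this PR: the function name `vtrace_targets` and its arguments are hypothetical, the lambda parameter is omitted, and episode terminations inside the trajectory are ignored.

```python
import math
from typing import List, Sequence


def vtrace_targets(
    rewards: Sequence[float],
    values: Sequence[float],
    bootstrap_value: float,
    log_rhos: Sequence[float],
    gamma: float = 0.99,
    rho_bar: float = 1.0,
    c_bar: float = 1.0,
) -> List[float]:
    """V-trace targets v_s for one trajectory (hypothetical helper).

    log_rhos[s] is log(pi(a_s | x_s) / mu(a_s | x_s)), the log importance
    ratio between the learner policy pi and the behaviour policy mu.
    """
    horizon = len(rewards)
    # V(x_{s+1}) for every step; the last step bootstraps from bootstrap_value.
    next_values = list(values[1:]) + [bootstrap_value]
    targets = [0.0] * horizon
    correction = 0.0  # running value of v_{s+1} - V(x_{s+1})
    for s in reversed(range(horizon)):
        ratio = math.exp(log_rhos[s])
        rho = min(rho_bar, ratio)  # truncated weight on the TD error
        c = min(c_bar, ratio)      # truncated weight on the trace
        delta = rho * (rewards[s] + gamma * next_values[s] - values[s])
        correction = delta + gamma * c * correction
        targets[s] = values[s] + correction
    return targets


# On-policy data (all log ratios are zero) reduces to discounted n-step returns.
on_policy = vtrace_targets(
    rewards=[1.0, 1.0, 1.0],
    values=[0.0, 0.0, 0.0],
    bootstrap_value=0.0,
    log_rhos=[0.0, 0.0, 0.0],
    gamma=0.5,
)
```

When the behaviour policy is very unlikely to agree with the learner (strongly negative log ratios), both truncated weights shrink toward zero and the targets stay close to the current value estimates.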

## Examples Structure

Please note that we provide three examples: one for single-node training and two for distributed (multi-node) training, using Ray and submitit respectively. All examples rely on the same utils file but are otherwise independent. Each example contains the following files:

1. **Main Script:** The definition of algorithm components and the training loop can be found in the main script (e.g. impala_single_node.py).

2. **Utils File:** A utility file contains helper functions, mainly to create the environment and the models (e.g. utils.py).

3. **Configuration File:** This file includes default hyperparameters specified in the original paper. For the multi-node Ray case, the file also includes the configuration of the Ray cluster. Users can modify these hyperparameters to customize their experiments (e.g. config_single_node.yaml).
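
The CI test command in this PR customizes the single-node example with dotted `key=value` arguments such as `collector.total_frames=80`. The sketch below shows how such overrides can be read as edits to the nested configuration. It is a standard-library illustration only: `apply_overrides` and `parse_value` are hypothetical helpers, and the way the actual configuration library parses values (for example an empty value like `logger.backend=`) may differ.

```python
from copy import deepcopy
from typing import Any, Dict, Iterable


def parse_value(raw: str) -> Any:
    """Convert a raw override value to int or float when possible."""
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            continue
    return raw  # anything else (including the empty string) stays a string


def apply_overrides(config: Dict[str, Any], overrides: Iterable[str]) -> Dict[str, Any]:
    """Return a copy of config with dotted key=value overrides applied."""
    result = deepcopy(config)
    for item in overrides:
        dotted_key, _, raw = item.partition("=")
        *parents, leaf = dotted_key.split(".")
        node = result
        for name in parents:
            node = node.setdefault(name, {})  # create missing sections
        node[leaf] = parse_value(raw)
    return result


# Defaults copied from config_single_node.yaml (only the relevant keys).
defaults = {
    "collector": {"frames_per_batch": 80, "total_frames": 200_000_000, "num_workers": 12},
    "logger": {"backend": "wandb", "test_interval": 200_000_000},
}

# The overrides used by the CI test command for impala_single_node.py.
test_config = apply_overrides(
    defaults,
    [
        "collector.total_frames=80",
        "collector.frames_per_batch=20",
        "collector.num_workers=1",
        "logger.backend=",
        "logger.test_interval=10",
    ],
)
```

The defaults are left untouched, so the same base configuration can be reused for several experiments.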


## Running the Examples

You can execute the single node IMPALA algorithm on Atari environments by running the following command:

```bash
python impala_single_node.py
```

You can execute the multi-node IMPALA algorithm on Atari environments by running the following command:

```bash
python impala_multi_node_ray.py
```
or

```bash
python impala_multi_node_submitit.py
```
65 changes: 65 additions & 0 deletions examples/impala/config_multi_node_ray.yaml
@@ -0,0 +1,65 @@
# Environment
env:
  env_name: PongNoFrameskip-v4

# Ray init kwargs - https://docs.ray.io/en/latest/ray-core/api/doc/ray.init.html
ray_init_config:
  address: null
  num_cpus: null
  num_gpus: null
  resources: null
  object_store_memory: null
  local_mode: False
  ignore_reinit_error: False
  include_dashboard: null
  dashboard_host: 127.0.0.1
  dashboard_port: null
  job_config: null
  configure_logging: True
  logging_level: info
  logging_format: null
  log_to_driver: True
  namespace: null
  runtime_env: null
  storage: null

# Device for the forward and backward passes
local_device: "cuda:0"

# Resources assigned to each IMPALA rollout collection worker
remote_worker_resources:
  num_cpus: 1
  num_gpus: 0.25
  memory: 1073741824  # 1*1024**3 - 1GB

# collector
collector:
  frames_per_batch: 80
  total_frames: 200_000_000
  num_workers: 12

# logger
logger:
  backend: wandb
  exp_name: Atari_IMPALA
  test_interval: 200_000_000
  num_test_episodes: 3

# Optim
optim:
  lr: 0.0006
  eps: 1e-8
  weight_decay: 0.0
  momentum: 0.0
  alpha: 0.99
  max_grad_norm: 40.0
  anneal_lr: True

# loss
loss:
  gamma: 0.99
  batch_size: 32
  sgd_updates: 1
  critic_coef: 0.5
  entropy_coef: 0.01
  loss_critic_type: l2
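
Since `remote_worker_resources` is requested once per rollout worker, the defaults above imply a minimum amount of cluster capacity for the collectors alone. The sketch below works through that arithmetic; `total_worker_resources` is a hypothetical helper, and it assumes every one of the `num_workers` collectors requests exactly the listed resources.

```python
from typing import Dict


def total_worker_resources(per_worker: Dict[str, float], num_workers: int) -> Dict[str, float]:
    """Scale per-worker resource requests by the number of workers."""
    return {name: amount * num_workers for name, amount in per_worker.items()}


# Values copied from config_multi_node_ray.yaml.
per_worker = {"num_cpus": 1, "num_gpus": 0.25, "memory": 1024**3}
totals = total_worker_resources(per_worker, num_workers=12)

# 12 workers need 12 CPUs, 3 GPUs (four workers share each GPU) and 12 GiB of memory.
print(totals)
```

The learner running on `local_device` needs its own GPU on top of what the collectors request.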
46 changes: 46 additions & 0 deletions examples/impala/config_multi_node_submitit.yaml
@@ -0,0 +1,46 @@
# Environment
env:
  env_name: PongNoFrameskip-v4

# Device for the forward and backward passes
local_device: "cuda:0"

# SLURM config
slurm_config:
  timeout_min: 10
  slurm_partition: train
  slurm_cpus_per_task: 1
  slurm_gpus_per_node: 1

# collector
collector:
  backend: gloo
  frames_per_batch: 80
  total_frames: 200_000_000
  num_workers: 1

# logger
logger:
  backend: wandb
  exp_name: Atari_IMPALA
  test_interval: 200_000_000
  num_test_episodes: 3

# Optim
optim:
  lr: 0.0006
  eps: 1e-8
  weight_decay: 0.0
  momentum: 0.0
  alpha: 0.99
  max_grad_norm: 40.0
  anneal_lr: True

# loss
loss:
  gamma: 0.99
  batch_size: 32
  sgd_updates: 1
  critic_coef: 0.5
  entropy_coef: 0.01
  loss_critic_type: l2
38 changes: 38 additions & 0 deletions examples/impala/config_single_node.yaml
@@ -0,0 +1,38 @@
# Environment
env:
  env_name: PongNoFrameskip-v4

# Device for the forward and backward passes
device: "cuda:0"

# collector
collector:
  frames_per_batch: 80
  total_frames: 200_000_000
  num_workers: 12

# logger
logger:
  backend: wandb
  exp_name: Atari_IMPALA
  test_interval: 200_000_000
  num_test_episodes: 3

# Optim
optim:
  lr: 0.0006
  eps: 1e-8
  weight_decay: 0.0
  momentum: 0.0
  alpha: 0.99
  max_grad_norm: 40.0
  anneal_lr: True

# loss
loss:
  gamma: 0.99
  batch_size: 32
  sgd_updates: 1
  critic_coef: 0.5
  entropy_coef: 0.01
  loss_critic_type: l2
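
To get a feel for the scale of these defaults, the sketch below derives the number of collected batches and learner updates, together with a linearly annealed learning rate. It rests on two assumptions that this diff does not confirm: that one learner update cycle consumes `batch_size` collected batches and runs `sgd_updates` passes, and that `anneal_lr: True` means a linear decay to zero. The helper `training_schedule` is hypothetical.

```python
from typing import Callable, Tuple


def training_schedule(
    total_frames: int,
    frames_per_batch: int,
    batch_size: int,
    sgd_updates: int,
    lr: float,
) -> Tuple[int, int, Callable[[int], float]]:
    """Return (collected batches, learner updates, annealed-lr function)."""
    collected_batches = total_frames // frames_per_batch
    # Assumption: each update cycle consumes `batch_size` collected batches.
    total_updates = (collected_batches // batch_size) * sgd_updates

    def annealed_lr(update_index: int) -> float:
        # Assumption: linear decay from lr to zero over all updates.
        if total_updates == 0:
            return lr
        fraction_left = 1.0 - update_index / total_updates
        return lr * max(fraction_left, 0.0)

    return collected_batches, total_updates, annealed_lr


# Values copied from config_single_node.yaml.
collected_batches, total_updates, annealed_lr = training_schedule(
    total_frames=200_000_000,
    frames_per_batch=80,
    batch_size=32,
    sgd_updates=1,
    lr=0.0006,
)
```

Under these assumptions the default run collects 2,500,000 batches of 80 frames and performs 78,125 learner updates.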