Add HyAR #115

Merged 1 commit on Dec 21, 2023
3 changes: 3 additions & 0 deletions .github/workflows/build.yml
@@ -139,6 +139,9 @@ jobs:
        run: |
          pip install pybullet
          ./bin/test_reproductions --gpu_id -1 --base_env pybullet
      - name: HyAR reproductions test
        run: |
          ./bin/test_reproductions --gpu_id -1 --base_env hybrid_env --env FakeHybridNNablaRL-v1
  copyright:
    runs-on: ubuntu-latest
    timeout-minutes: 3
8 changes: 8 additions & 0 deletions bin/evaluate_algorithm
@@ -105,6 +105,11 @@ DELAYED_MUJOCO_ENV_LIST=(
    "DelayedWalker2d-v1"
)

HYBRID_ENV_LIST=(
    "Goal-v0"
    "Platform-v0"
)

GPU_ID=0
ALGO_NAME="dqn"
BASE_ENV_NAME="atari"
@@ -185,6 +190,9 @@ do
    if [ $BASE_ENV_NAME = "delayed_mujoco" ]; then
        ENV_NAME=${DELAYED_MUJOCO_ENV_LIST[$INDEX]}
    fi
    if [ $BASE_ENV_NAME = "hybrid_env" ]; then
        ENV_NAME=${HYBRID_ENV_LIST[$INDEX]}
    fi
    echo "Start running training for: " ${ENV_NAME}
    if [ -n "$BATCH_SIZE" ]; then
        ${ROOT_DIR}/bin/train_with_seeds "${REPRODUCTION_CODE_DIR}/${ALGO_NAME}_reproduction.py" $GPU_ID $ENV_NAME $SAVE_DIR $NUM_SEEDS $BATCH_SIZE &
9 changes: 9 additions & 0 deletions bin/test_reproductions
@@ -116,6 +116,15 @@ do
            python ${SCRIPT} --gpu ${GPU_ID} --env ${ENV} --snapshot-dir ${SNAPSHOT_DIR} --showcase \
                --showcase_runs ${SHOWCASE_RUNS} --dataset-path ${DATASET_PATH}
        fi
    elif [ ${ALGORITHM} = "hyar" ]; then
        echo "Test run of training for ${ALGORITHM}"
        python ${SCRIPT} --gpu ${GPU_ID} --env ${ENV} --save-dir "${RESULT_BASE_DIR}/${ALGORITHM}" --seed ${SEED} \
            --total_iterations ${TOTAL_ITERATIONS} --save_timing ${TOTAL_ITERATIONS} \
            --vae-pretrain-episodes 1 --vae-pretrain-times 1
        SNAPSHOT_DIR="${RESULT_BASE_DIR}/${ALGORITHM}/${ENV}_results/seed-${SEED}/iteration-${TOTAL_ITERATIONS}"
        echo "Test run of showcase for ${ALGORITHM}"
        python ${SCRIPT} --gpu ${GPU_ID} --env ${ENV} --snapshot-dir ${SNAPSHOT_DIR} --showcase \
            --showcase_runs ${SHOWCASE_RUNS}
    else
        echo "Test run of training for ${ALGORITHM}"
        python ${SCRIPT} --gpu ${GPU_ID} --env ${ENV} --save-dir "${RESULT_BASE_DIR}/${ALGORITHM}" --seed ${SEED} \
10 changes: 10 additions & 0 deletions docs/source/nnablarl_api/algorithms.rst
@@ -160,6 +160,16 @@ HER
   :members:
   :show-inheritance:

HyAR
====
.. autoclass:: nnabla_rl.algorithms.hyar.HyARConfig
   :members:
   :show-inheritance:

.. autoclass:: nnabla_rl.algorithms.hyar.HyAR
   :members:
   :show-inheritance:

iLQR
====
.. autoclass:: nnabla_rl.algorithms.ilqr.iLQRConfig
70 changes: 36 additions & 34 deletions nnabla_rl/algorithms/README.md
@@ -7,42 +7,44 @@ nnabla-rl offers various (deep) reinforcement learning and optimal control algor
- Online training: Training which is performed by interacting with the environment. You'll need to prepare an environment which is compatible with the [OpenAI gym's environment interface](https://gym.openai.com/docs/#environments).
- Offline(Batch) training: Training which is performed solely from provided data. You'll need to prepare a dataset encapsulated in the [ReplayBuffer](../replay_buffer.py).
- Continuous/Discrete action: If you are familiar with the training of deep neural nets, the difference between the action types is similar to the difference between regression and classification. A continuous action consists of real value(s) (e.g. a robot's motor torque). In contrast, a discrete action is one that can be labeled (e.g. UP, DOWN, RIGHT, LEFT). The choice of action type depends on the environment (problem), and the applicable algorithms depend on its action type.
- Hybrid action: A hybrid action is a pair consisting of a discrete action and a continuous action, for environments that require both.
- RNN layer support: Supports training of network models with recurrent layers.
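
As a rough illustration of the hybrid action described in the list above, the sketch below pairs a discrete choice with continuous parameters. The `HybridAction` class and the `sample_hybrid_action` helper are hypothetical names used only for this example and are not part of nnabla-rl; the [-1, 1] parameter range is likewise an assumption.

```python
import random
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class HybridAction:
    """Hypothetical container: one discrete choice plus continuous parameters."""

    discrete: int
    continuous: Tuple[float, ...]


def sample_hybrid_action(num_discrete: int, param_dim: int, rng: random.Random) -> HybridAction:
    """Draw a random (discrete, continuous) pair.

    The continuous part is sampled uniformly from [-1, 1] per dimension,
    which is an assumption made for this sketch only.
    """
    discrete = rng.randrange(num_discrete)
    continuous = tuple(rng.uniform(-1.0, 1.0) for _ in range(param_dim))
    return HybridAction(discrete=discrete, continuous=continuous)


# Both parts are produced together and handed to the environment as one action.
action = sample_hybrid_action(num_discrete=3, param_dim=2, rng=random.Random(0))
print(action.discrete, action.continuous)
```

An algorithm that supports only discrete or only continuous actions would produce just one half of this pair, which is why the table lists hybrid action support as a separate column.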

|Algorithm|Online training|Offline(Batch) training|Continuous action|Discrete action|RNN layer support|
|:---|:---:|:---:|:---:|:---:|:---:|
|[A2C](https://arxiv.org/abs/1602.01783)|:heavy_check_mark:|:x:|(We will support continuous action in the future)|:heavy_check_mark:|:x:|
|[ATRPO](https://arxiv.org/pdf/2106.07329)|:heavy_check_mark:|:x:|:heavy_check_mark:|(We will support discrete action in the future)|:x:|
|[BCQ](https://arxiv.org/abs/1812.02900)|:x:|:heavy_check_mark:|:heavy_check_mark:|:x:|:x:|
|[BEAR](https://arxiv.org/abs/1906.00949)|:x:|:heavy_check_mark:|:heavy_check_mark:|:x:|:x:|
|[Categorical DDQN](https://arxiv.org/abs/1710.02298)|:heavy_check_mark:|:heavy_check_mark:|:x:|:heavy_check_mark:|:heavy_check_mark:|
|[Categorical DQN](https://arxiv.org/abs/1707.06887)|:heavy_check_mark:|:heavy_check_mark:|:x:|:heavy_check_mark:|:heavy_check_mark:|
|[DDPG](https://arxiv.org/abs/1509.02971)|:heavy_check_mark:|:heavy_check_mark:|:heavy_check_mark:|:x:|:heavy_check_mark:|
|[DDQN](https://arxiv.org/abs/1509.06461)|:heavy_check_mark:|:heavy_check_mark:|:x:|:heavy_check_mark:|:heavy_check_mark:|
|[DecisionTransformer](https://arxiv.org/abs/2106.01345)|:x:|:heavy_check_mark:|:heavy_check_mark:|:heavy_check_mark:|:x:|
|[DEMME-SAC](https://arxiv.org/abs/2106.10517)|:heavy_check_mark:|:heavy_check_mark:|:heavy_check_mark:|:x:|:heavy_check_mark:|
|[DQN](https://www.nature.com/articles/nature14236)|:heavy_check_mark:|:heavy_check_mark:|:x:|:heavy_check_mark:|:heavy_check_mark:|
|[DRQN](https://arxiv.org/abs/1507.06527)|:heavy_check_mark:|:heavy_check_mark:|:x:|:heavy_check_mark:|:heavy_check_mark:|
|[GAIL](https://arxiv.org/abs/1606.03476)|:heavy_check_mark:|:x:|:heavy_check_mark:|(We will support discrete action in the future)|:x:|
|[HER](https://arxiv.org/abs/1707.01495)|:heavy_check_mark:|:heavy_check_mark:|:heavy_check_mark:|:x:|:heavy_check_mark:|
|[IQN](https://arxiv.org/abs/1806.06923)|:heavy_check_mark:|:heavy_check_mark:|:x:|:heavy_check_mark:|:heavy_check_mark:<sup>*</sup>|
|[MME-SAC](https://arxiv.org/abs/2106.10517)|:heavy_check_mark:|:heavy_check_mark:|:heavy_check_mark:|:x:|:heavy_check_mark:|
|[M-DQN](https://proceedings.neurips.cc/paper/2020/file/2c6a0bae0f071cbbf0bb3d5b11d90a82-Paper.pdf)|:heavy_check_mark:|:heavy_check_mark:|:x:|:heavy_check_mark:|:heavy_check_mark:|
|[M-IQN](https://proceedings.neurips.cc/paper/2020/file/2c6a0bae0f071cbbf0bb3d5b11d90a82-Paper.pdf)|:heavy_check_mark:|:heavy_check_mark:|:x:|:heavy_check_mark:|:heavy_check_mark:|
|[PPO](https://arxiv.org/abs/1707.06347)|:heavy_check_mark:|:x:|:heavy_check_mark:|:heavy_check_mark:|:x:|
|[QRSAC](https://www.nature.com/articles/s41586-021-04357-7)|:heavy_check_mark:|:heavy_check_mark:|:heavy_check_mark:|:x:|:heavy_check_mark:|
|[QRDQN](https://arxiv.org/abs/1710.10044)|:heavy_check_mark:|:heavy_check_mark:|:x:|:heavy_check_mark:|:heavy_check_mark:|
|[QtOpt (ICRA 2018 version)](https://arxiv.org/pdf/1802.10264)|:heavy_check_mark:|:heavy_check_mark:|:heavy_check_mark:|:x:|:heavy_check_mark:|
|[Rainbow](https://arxiv.org/abs/1710.02298)|:heavy_check_mark:|:heavy_check_mark:|:x:|:heavy_check_mark:|:heavy_check_mark:|
|[REDQ](https://arxiv.org/abs/2101.05982)|:heavy_check_mark:|:heavy_check_mark:|:heavy_check_mark:|:x:|:heavy_check_mark:|
|[REINFORCE](https://link.springer.com/content/pdf/10.1007/BF00992696.pdf)|:heavy_check_mark:|:x:|:heavy_check_mark:|:heavy_check_mark:|:x:|
|[SAC](https://arxiv.org/abs/1812.05905)|:heavy_check_mark:|:heavy_check_mark:|:heavy_check_mark:|:x:|:heavy_check_mark:|
|[SAC (ICML 2018 version)](https://arxiv.org/abs/1801.01290)|:heavy_check_mark:|:heavy_check_mark:|:heavy_check_mark:|:x:|:heavy_check_mark:|
|[SAC-D](https://arxiv.org/abs/2206.13901)|:heavy_check_mark:|:heavy_check_mark:|:heavy_check_mark:|:x:|:heavy_check_mark:|
|[TD3](https://arxiv.org/abs/1802.09477)|:heavy_check_mark:|:heavy_check_mark:|:heavy_check_mark:|:x:|:heavy_check_mark:|
|[TRPO](https://arxiv.org/abs/1502.05477)|:heavy_check_mark:|:x:|:heavy_check_mark:|(We will support discrete action in the future)|:x:|
|[TRPO (ICML 2015 version)](https://arxiv.org/abs/1502.05477)|:heavy_check_mark:|:x:|:heavy_check_mark:|:heavy_check_mark:|:x:|
|[XQL](https://arxiv.org/abs/2301.02328)|:x:|:heavy_check_mark:|:heavy_check_mark:|:x:|:heavy_check_mark:|
|Algorithm|Online training|Offline(Batch) training|Continuous action|Discrete action|Hybrid action|RNN layer support|
|:---|:---:|:---:|:---:|:---:|:---:|:---:|
|[A2C](https://arxiv.org/abs/1602.01783)|:heavy_check_mark:|:x:|(We will support continuous action in the future)|:heavy_check_mark:|:x:|:x:|
|[ATRPO](https://arxiv.org/pdf/2106.07329)|:heavy_check_mark:|:x:|:heavy_check_mark:|(We will support discrete action in the future)|:x:|:x:|
|[BCQ](https://arxiv.org/abs/1812.02900)|:x:|:heavy_check_mark:|:heavy_check_mark:|:x:|:x:|:x:|
|[BEAR](https://arxiv.org/abs/1906.00949)|:x:|:heavy_check_mark:|:heavy_check_mark:|:x:|:x:|:x:|
|[Categorical DDQN](https://arxiv.org/abs/1710.02298)|:heavy_check_mark:|:heavy_check_mark:|:x:|:heavy_check_mark:|:x:|:heavy_check_mark:|
|[Categorical DQN](https://arxiv.org/abs/1707.06887)|:heavy_check_mark:|:heavy_check_mark:|:x:|:heavy_check_mark:|:x:|:heavy_check_mark:|
|[DDPG](https://arxiv.org/abs/1509.02971)|:heavy_check_mark:|:heavy_check_mark:|:heavy_check_mark:|:x:|:x:|:heavy_check_mark:|
|[DDQN](https://arxiv.org/abs/1509.06461)|:heavy_check_mark:|:heavy_check_mark:|:x:|:heavy_check_mark:|:x:|:heavy_check_mark:|
|[DecisionTransformer](https://arxiv.org/abs/2106.01345)|:x:|:heavy_check_mark:|:heavy_check_mark:|:heavy_check_mark:|:x:|:x:|
|[DEMME-SAC](https://arxiv.org/abs/2106.10517)|:heavy_check_mark:|:heavy_check_mark:|:heavy_check_mark:|:x:|:x:|:heavy_check_mark:|
|[DQN](https://www.nature.com/articles/nature14236)|:heavy_check_mark:|:heavy_check_mark:|:x:|:heavy_check_mark:|:x:|:heavy_check_mark:|
|[DRQN](https://arxiv.org/abs/1507.06527)|:heavy_check_mark:|:heavy_check_mark:|:x:|:heavy_check_mark:|:x:|:heavy_check_mark:|
|[GAIL](https://arxiv.org/abs/1606.03476)|:heavy_check_mark:|:x:|:heavy_check_mark:|(We will support discrete action in the future)|:x:|:x:|
|[HER](https://arxiv.org/abs/1707.01495)|:heavy_check_mark:|:heavy_check_mark:|:heavy_check_mark:|:x:|:x:|:heavy_check_mark:|
|[HyAR](https://openreview.net/pdf?id=64trBbOhdGU)|:heavy_check_mark:|:x:|:x:|:x:|:heavy_check_mark:|:x:|
|[IQN](https://arxiv.org/abs/1806.06923)|:heavy_check_mark:|:heavy_check_mark:|:x:|:heavy_check_mark:|:x:|:heavy_check_mark:<sup>*</sup>|
|[MME-SAC](https://arxiv.org/abs/2106.10517)|:heavy_check_mark:|:heavy_check_mark:|:heavy_check_mark:|:x:|:x:|:heavy_check_mark:|
|[M-DQN](https://proceedings.neurips.cc/paper/2020/file/2c6a0bae0f071cbbf0bb3d5b11d90a82-Paper.pdf)|:heavy_check_mark:|:heavy_check_mark:|:x:|:heavy_check_mark:|:x:|:heavy_check_mark:|
|[M-IQN](https://proceedings.neurips.cc/paper/2020/file/2c6a0bae0f071cbbf0bb3d5b11d90a82-Paper.pdf)|:heavy_check_mark:|:heavy_check_mark:|:x:|:heavy_check_mark:|:x:|:heavy_check_mark:|
|[PPO](https://arxiv.org/abs/1707.06347)|:heavy_check_mark:|:x:|:heavy_check_mark:|:heavy_check_mark:|:x:|:x:|
|[QRSAC](https://www.nature.com/articles/s41586-021-04357-7)|:heavy_check_mark:|:heavy_check_mark:|:heavy_check_mark:|:x:|:x:|:heavy_check_mark:|
|[QRDQN](https://arxiv.org/abs/1710.10044)|:heavy_check_mark:|:heavy_check_mark:|:x:|:heavy_check_mark:|:x:|:heavy_check_mark:|
|[QtOpt (ICRA 2018 version)](https://arxiv.org/pdf/1802.10264)|:heavy_check_mark:|:heavy_check_mark:|:heavy_check_mark:|:x:|:x:|:heavy_check_mark:|
|[Rainbow](https://arxiv.org/abs/1710.02298)|:heavy_check_mark:|:heavy_check_mark:|:x:|:heavy_check_mark:|:x:|:heavy_check_mark:|
|[REDQ](https://arxiv.org/abs/2101.05982)|:heavy_check_mark:|:heavy_check_mark:|:heavy_check_mark:|:x:|:x:|:heavy_check_mark:|
|[REINFORCE](https://link.springer.com/content/pdf/10.1007/BF00992696.pdf)|:heavy_check_mark:|:x:|:heavy_check_mark:|:heavy_check_mark:|:x:|:x:|
|[SAC](https://arxiv.org/abs/1812.05905)|:heavy_check_mark:|:heavy_check_mark:|:heavy_check_mark:|:x:|:x:|:heavy_check_mark:|
|[SAC (ICML 2018 version)](https://arxiv.org/abs/1801.01290)|:heavy_check_mark:|:heavy_check_mark:|:heavy_check_mark:|:x:|:x:|:heavy_check_mark:|
|[SAC-D](https://arxiv.org/abs/2206.13901)|:heavy_check_mark:|:heavy_check_mark:|:heavy_check_mark:|:x:|:x:|:heavy_check_mark:|
|[TD3](https://arxiv.org/abs/1802.09477)|:heavy_check_mark:|:heavy_check_mark:|:heavy_check_mark:|:x:|:x:|:heavy_check_mark:|
|[TRPO](https://arxiv.org/abs/1502.05477)|:heavy_check_mark:|:x:|:heavy_check_mark:|(We will support discrete action in the future)|:x:|:x:|
|[TRPO (ICML 2015 version)](https://arxiv.org/abs/1502.05477)|:heavy_check_mark:|:x:|:heavy_check_mark:|:heavy_check_mark:|:x:|:x:|
|[XQL](https://arxiv.org/abs/2301.02328)|:x:|:heavy_check_mark:|:heavy_check_mark:|:x:|:x:|:heavy_check_mark:|

<sup>*</sup>May require special treatment to train with RNN layers.

2 changes: 2 additions & 0 deletions nnabla_rl/algorithms/__init__.py
@@ -30,6 +30,7 @@
from nnabla_rl.algorithms.dummy import Dummy, DummyConfig
from nnabla_rl.algorithms.gail import GAIL, GAILConfig
from nnabla_rl.algorithms.her import HER, HERConfig
from nnabla_rl.algorithms.hyar import HyAR, HyARConfig
from nnabla_rl.algorithms.icml2015_trpo import ICML2015TRPO, ICML2015TRPOConfig
from nnabla_rl.algorithms.icml2018_sac import ICML2018SAC, ICML2018SACConfig
from nnabla_rl.algorithms.icra2018_qtopt import ICRA2018QtOpt, ICRA2018QtOptConfig
@@ -94,6 +95,7 @@ def get_class_of(name):
register_algorithm(DRQN, DRQNConfig)
register_algorithm(Dummy, DummyConfig)
register_algorithm(HER, HERConfig)
register_algorithm(HyAR, HyARConfig)
register_algorithm(ICML2018SAC, ICML2018SACConfig)
register_algorithm(iLQR, iLQRConfig)
register_algorithm(IQN, IQNConfig)
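
The hunks above import `HyAR` and `HyARConfig` and pass them to `register_algorithm`. The diff does not show how the registry itself is implemented, so the following is only a minimal sketch of one way a name-based registry could work. The function names mirror those visible in the diff, but the bodies, the `_ALGORITHMS` dictionary, and the `ToyAlgorithm`/`ToyConfig` classes are assumptions made for illustration.

```python
from typing import Dict, Tuple

# Hypothetical registry: maps an algorithm's class name to (algorithm class, config class).
_ALGORITHMS: Dict[str, Tuple[type, type]] = {}


def register_algorithm(algorithm_class: type, config_class: type) -> None:
    """Store the pair under the algorithm's class name; reject duplicate names."""
    name = algorithm_class.__name__
    if name in _ALGORITHMS:
        raise ValueError(f"{name} is already registered")
    _ALGORITHMS[name] = (algorithm_class, config_class)


def get_class_of(name: str) -> Tuple[type, type]:
    """Return the (algorithm class, config class) pair registered under name."""
    if name not in _ALGORITHMS:
        raise KeyError(f"{name} is not registered")
    return _ALGORITHMS[name]


# Placeholder classes standing in for a real algorithm and its config.
class ToyConfig:
    pass


class ToyAlgorithm:
    pass


register_algorithm(ToyAlgorithm, ToyConfig)
algorithm_class, config_class = get_class_of("ToyAlgorithm")
print(algorithm_class.__name__, config_class.__name__)
```

With a registry of this kind, adding a new algorithm takes one import and one registration call, which matches the shape of the two-line change in this file.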