From ab64a60561c4c76ee650426360337bcc16fd094a Mon Sep 17 00:00:00 2001 From: David GERARD Date: Sun, 24 Nov 2024 20:13:12 +0000 Subject: [PATCH] fix: modifications from pre-commit --- .github/workflows/linux-tutorials-test.yml | 4 ++-- .../SB3/connect_four/sb3_connect_four_action_mask.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/linux-tutorials-test.yml b/.github/workflows/linux-tutorials-test.yml index 174836b2c..d559302eb 100644 --- a/.github/workflows/linux-tutorials-test.yml +++ b/.github/workflows/linux-tutorials-test.yml @@ -15,10 +15,10 @@ jobs: runs-on: ubuntu-latest strategy: fail-fast: false - + matrix: python-version: ['3.8', '3.9', '3.10', '3.11'] - tutorial: [Tianshou, CustomEnvironment, CleanRL, SB3/kaz, SB3/waterworld, SB3/test] # TODO: fix tutorials and add back Ray, fix SB3/connect_four tutorial + tutorial: [Tianshou, CustomEnvironment, CleanRL, SB3/kaz, SB3/waterworld, SB3/test] # TODO: fix tutorials and add back Ray, fix SB3/connect_four tutorial steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} diff --git a/tutorials/SB3/connect_four/sb3_connect_four_action_mask.py b/tutorials/SB3/connect_four/sb3_connect_four_action_mask.py index 27818640d..e3dc63d34 100644 --- a/tutorials/SB3/connect_four/sb3_connect_four_action_mask.py +++ b/tutorials/SB3/connect_four/sb3_connect_four_action_mask.py @@ -8,8 +8,8 @@ import glob import os import time -import gymnasium as gym +import gymnasium as gym from sb3_contrib import MaskablePPO from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy from sb3_contrib.common.wrappers import ActionMasker @@ -18,7 +18,6 @@ from pettingzoo.classic import connect_four_v3 - class SB3ActionMaskWrapper(pettingzoo.utils.BaseWrapper): """Wrapper to allow PettingZoo environments to be used with SB3 illegal action masking.""" @@ -176,10 +175,11 @@ def eval_action_mask(env_fn, num_games=100, render_mode=None, **env_kwargs): if __name__ == "__main__": - if gym.__version__ > "0.29.1": - raise ImportError(f"This script requires gymnasium version 0.29.1 or lower, but you have version {gym.__version__}.") - + raise ImportError( + f"This script requires gymnasium version 0.29.1 or lower, but you have version {gym.__version__}." + ) + env_fn = connect_four_v3 env_kwargs = {}