From da5a06a77d44deeb086b9627c1ca2f93301fc55f Mon Sep 17 00:00:00 2001
From: Ectras <40306539+Ectras@users.noreply.github.com>
Date: Thu, 28 Mar 2024 12:54:55 +0100
Subject: [PATCH] WPI: parallelize RL training over vectorized environments

Train PPO on two copies of the scheduling environment at once via
make_vec_env instead of a single gym.make instance. SubprocVecEnv
(one process per copy) is imported but left disabled for now. Also
drop the check_env flag, which shadowed the imported check_env
function, and show a progress bar during learning.
---
 src/scheduling/learning/train.py | 22 +++++++++++++---------
 1 file changed, 13 insertions(+), 9 deletions(-)

diff --git a/src/scheduling/learning/train.py b/src/scheduling/learning/train.py
index fccc13f..fdda290 100644
--- a/src/scheduling/learning/train.py
+++ b/src/scheduling/learning/train.py
@@ -4,8 +4,9 @@
 import logging
 
 from stable_baselines3 import PPO
-from stable_baselines3.common.env_checker import check_env
 from stable_baselines3.common.evaluation import evaluate_policy
+from stable_baselines3.common.env_util import make_vec_env
+from stable_baselines3.common.vec_env import SubprocVecEnv
 from gymnasium.wrappers import FlattenObservation
 import gymnasium as gym
 
@@ -16,7 +17,6 @@
 def train_for_settings(
     settings: list[dict[str, Any]],
     total_timesteps: int = 100000,
-    check_env: bool = False,
 ) -> None:
     """Train a PPO model for the scheduling environment.
 
@@ -32,18 +32,22 @@ def train_for_settings(
     """
     for i, setting in enumerate(settings):
         logging.info("Training model for setting %d", i)
-        env = gym.make("Scheduling-v0", **setting)
-        env = FlattenObservation(env)
-        if check_env:
-            check_env(env)
-            logging.info("Environment checked successfully. Training model...")
+        # Pass vec_env_cls=SubprocVecEnv here to step the environment copies
+        # in separate processes (left disabled for now).
+        env = make_vec_env(
+            "Scheduling-v0",
+            n_envs=2,
+            env_kwargs=setting,
+            wrapper_class=FlattenObservation,
+        )
         if i == 0:
             model = PPO(
-                "MlpPolicy", env, verbose=1
+                "MlpPolicy", env, verbose=1,
+                n_steps=24,  # rollout buffer: n_steps * n_envs = 48 transitions
             ) # Create a single PPO model instance
         else:
             model.set_env(env)
-        model.learn(total_timesteps)
+        model.learn(total_timesteps, progress_bar=True)
         logging.info("Setting completed.")
         env.close()
     model.save("ppo_scheduling")
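
Reviewer note (not part of the commit): with vec_env_cls left at its
default, make_vec_env builds a DummyVecEnv, which steps the two copies
sequentially in a single process; true process-level parallelism only
comes from SubprocVecEnv. Below is a minimal sketch of the enabled
variant, not a drop-in replacement: it assumes the Scheduling-v0
registration runs at import time and omits the per-setting env_kwargs
plumbing, and SubprocVecEnv additionally requires picklable environment
arguments and a __main__ guard so worker processes can be spawned.

    from gymnasium.wrappers import FlattenObservation
    from stable_baselines3 import PPO
    from stable_baselines3.common.env_util import make_vec_env
    from stable_baselines3.common.vec_env import SubprocVecEnv

    if __name__ == "__main__":  # required when SubprocVecEnv spawns workers
        env = make_vec_env(
            "Scheduling-v0",
            n_envs=2,
            wrapper_class=FlattenObservation,
            vec_env_cls=SubprocVecEnv,  # one process per environment copy
        )
        model = PPO("MlpPolicy", env, verbose=1, n_steps=24)
        model.learn(100_000, progress_bar=True)
        env.close()

One caveat on the n_steps=24 choice: the rollout buffer then holds
n_steps * n_envs = 48 transitions, which is smaller than PPO's default
batch_size of 64, so stable-baselines3 warns about a truncated
mini-batch; passing batch_size=48 (or another divisor of 48) avoids it.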