WIP: parallelized rl
Ectras committed Mar 28, 2024
1 parent 022829e commit da5a06a
Showing 1 changed file with 6 additions and 9 deletions.
15 changes: 6 additions & 9 deletions src/scheduling/learning/train.py
```diff
@@ -4,8 +4,9 @@
 import logging
 
 from stable_baselines3 import PPO
-from stable_baselines3.common.env_checker import check_env
 from stable_baselines3.common.evaluation import evaluate_policy
+from stable_baselines3.common.env_util import make_vec_env
+from stable_baselines3.common.vec_env import SubprocVecEnv
 
 from gymnasium.wrappers import FlattenObservation
 import gymnasium as gym
@@ -16,7 +17,6 @@
 def train_for_settings(
     settings: list[dict[str, Any]],
     total_timesteps: int = 100000,
-    check_env: bool = False,
 ) -> None:
     """Train a PPO model for the scheduling environment.
 
@@ -32,18 +32,15 @@ def train_for_settings(
     """
     for i, setting in enumerate(settings):
         logging.info("Training model for setting %d", i)
-        env = gym.make("Scheduling-v0", **setting)
-        env = FlattenObservation(env)
-        if check_env:
-            check_env(env)
-            logging.info("Environment checked successfully. Training model...")
+        env = make_vec_env("Scheduling-v0", n_envs=2, env_kwargs=setting, wrapper_class=FlattenObservation)#, vec_env_cls=SubprocVecEnv)
         if i == 0:
             model = PPO(
-                "MlpPolicy", env, verbose=1
+                "MlpPolicy", env, verbose=1,
+                n_steps=24
             )  # Create a single PPO model instance
         else:
             model.set_env(env)
-        model.learn(total_timesteps)
+        model.learn(total_timesteps, progress_bar=True)
         logging.info("Setting completed.")
     env.close()
     model.save("ppo_scheduling")
```
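For reference, here is a minimal standalone sketch of the vectorized-training pattern this commit switches to, with the commented-out `vec_env_cls=SubprocVecEnv` enabled. It is not the repository's code: it assumes the repo's `Scheduling-v0` environment is registered with Gymnasium, and the `env_kwargs` and timestep count are placeholders. Note that `SubprocVecEnv` steps each environment copy in its own process, so on spawn-based platforms the training call must sit under an `if __name__ == "__main__":` guard.

```python
# Minimal sketch, not the repository's code: vectorized PPO training with
# subprocess-based parallelism. Assumes "Scheduling-v0" is registered.
from gymnasium.wrappers import FlattenObservation
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv

if __name__ == "__main__":  # required: SubprocVecEnv spawns worker processes
    # make_vec_env creates n_envs copies of the environment, applies the
    # wrapper to each, and exposes them behind one vectorized interface.
    env = make_vec_env(
        "Scheduling-v0",
        n_envs=2,
        env_kwargs={},  # placeholder for one `setting` dict from the repo
        wrapper_class=FlattenObservation,
        vec_env_cls=SubprocVecEnv,  # default DummyVecEnv steps in-process
    )
    # PPO collects n_steps transitions per env per rollout:
    # 2 envs * 24 steps = 48 transitions per update here.
    model = PPO(
        "MlpPolicy",
        env,
        verbose=1,
        n_steps=24,
        batch_size=24,  # divides the 48-step buffer evenly (default 64 warns)
    )
    model.learn(total_timesteps=100_000, progress_bar=True)
    env.close()
    model.save("ppo_scheduling")
```

The tradeoff the commented-out argument hedges: `SubprocVecEnv` pays process-spawn and inter-process communication overhead, which only pays off when a single environment step is expensive; for cheap steps the default in-process `DummyVecEnv` is often faster despite being sequential.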
