diff --git a/docs/source/background/introduction.rst b/docs/source/background/introduction.rst
index 93a3a2d..66f8271 100644
--- a/docs/source/background/introduction.rst
+++ b/docs/source/background/introduction.rst
@@ -75,5 +75,6 @@ random walk down a 1D corridor:
     while not done:
         action = random.choice(range(len(env.state.children)))
-        obs, reward, done, info = env.step(action)
+        obs, reward, terminated, truncated, info = env.step(action)
+        done = terminated or truncated
         total_reward += reward

diff --git a/docs/source/examples/hallway.ipynb b/docs/source/examples/hallway.ipynb
index a05f659..df1c107 100644
--- a/docs/source/examples/hallway.ipynb
+++ b/docs/source/examples/hallway.ipynb
@@ -334,8 +334,9 @@
     }
    ],
    "source": [
-    "obs = env.reset()\n",
-    "print(obs)"
+    "obs, info = env.reset()\n",
+    "print(obs)\n",
+    "print(info)"
    ]
   },
   {
@@ -379,7 +380,7 @@
    ],
    "source": [
     "# Not a valid action\n",
-    "obs, rew, done, info = env.step(1)"
+    "obs, rew, terminated, truncated, info = env.step(1)"
    ]
   },
   {
@@ -390,7 +391,7 @@
    "outputs": [],
    "source": [
     "# A valid action\n",
-    "obs, rew, done, info = env.step(0)"
+    "obs, rew, terminated, truncated, info = env.step(0)"
    ]
   },
   {
@@ -504,7 +505,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "obs, rew, done, info = env.step(1)"
+    "obs, rew, terminated, truncated, info = env.step(1)"
    ]
   },
   {
@@ -604,7 +605,7 @@
     "env.step(0)\n",
     "\n",
     "for _ in range(5):\n",
-    "    obs, rew, done, info = env.step(1)\n",
+    "    obs, rew, terminated, truncated, info = env.step(1)\n",
     "\n",
     "env.make_observation()"
    ]
diff --git a/docs/source/examples/tsp_docs.ipynb b/docs/source/examples/tsp_docs.ipynb
index 9755f38..da34f73 100644
--- a/docs/source/examples/tsp_docs.ipynb
+++ b/docs/source/examples/tsp_docs.ipynb
@@ -381,7 +381,8 @@
     "rand_rew = 0.\n",
     "while not done:\n",
     "    action = env.action_space.sample()\n",
-    "    _, rew, done, _ = env.step(action)\n",
+    "    _, rew, terminated, truncated, _ = env.step(action)\n",
+    "    done = terminated or truncated\n",
     "    rand_rew += rew\n",
     "    \n",
     "print(f\"Random reward = {rand_rew}\")\n",
@@ -425,7 +426,7 @@
     }
    ],
    "source": [
-    "obs = env.reset()\n",
+    "obs, info = env.reset()\n",
     "\n",
     "done = False\n",
     "greedy_rew = 0.\n",
@@ -433,7 +434,8 @@
     "while not done:\n",
     "    # Get the node with shortest distance to the parent (current) node\n",
     "    idx = np.argmin([x[\"parent_dist\"] for x in obs[1:]]) \n",
-    "    obs, rew, done, _ = env.step(idx)\n",
+    "    obs, rew, terminated, truncated, _ = env.step(idx)\n",
+    "    done = terminated or truncated\n",
    "    greedy_rew += rew\n",
     "    \n",
     "print(f\"Greedy reward = {greedy_rew}\")\n",
diff --git a/docs/source/examples/tsp_env.ipynb b/docs/source/examples/tsp_env.ipynb
index 51cbc2f..d93b65c 100644
--- a/docs/source/examples/tsp_env.ipynb
+++ b/docs/source/examples/tsp_env.ipynb
@@ -62,7 +62,7 @@
     "%%capture\n",
     "\n",
     "# Reset the environment and initialize the observation, reward, and done fields\n",
-    "obs = env.reset()\n",
+    "obs, info = env.reset()\n",
     "greedy_reward = 0\n",
     "done = False\n",
     "\n",
@@ -74,7 +74,8 @@
     "\n",
     "    # Get the observation for the next set of candidate nodes,\n",
     "    # incremental reward, and done flags\n",
-    "    obs, reward, done, info = env.step(action)\n",
+    "    obs, reward, terminated, truncated, info = env.step(action)\n",
+    "    done = terminated or truncated\n",
     "\n",
     "    # Append the step's reward to the running total\n",
     "    greedy_reward += reward\n",
@@ -182,7 +183,8 @@
     "    )[:k]\n",
     "\n",
     "    for entry in top_actions:\n",
-    "        obs, reward, done, info = entry[\"env\"].step(entry[\"action_index\"])\n",
+    "        obs, reward, terminated, truncated, info = entry[\"env\"].step(entry[\"action_index\"])\n",
+    "        done = terminated or truncated\n",
     "\n",
     "    return [(entry[\"env\"], entry[\"reward\"]) for entry in top_actions], done"
    ]
@@ -194,7 +196,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "obs = env.reset()\n",
+    "obs, info = env.reset()\n",
     "env_list = [(env, 0)]\n",
     "done = False\n",
     "\n",
@@ -212,7 +214,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "obs = env.reset()\n",
+    "obs, info = env.reset()\n",
     "env_list = [(env, 0)]\n",
     "done = False\n",
     "\n",
diff --git a/experiments/hallway/custom_env.py b/experiments/hallway/custom_env.py
index 9b2e024..105806d 100644
--- a/experiments/hallway/custom_env.py
+++ b/experiments/hallway/custom_env.py
@@ -17,7 +17,7 @@
 import os
 import random

-import gym
+import gymnasium as gym
 import ray
 from gym.spaces import Box, Discrete
 from ray import tune
@@ -87,9 +87,10 @@ def __init__(self, config: EnvContext):
         # Set the seed. This is only used for the final (reach goal) reward.
         self.seed(config.worker_index * config.num_workers)

-    def reset(self):
+    def reset(self, *, seed=None, options=None):
         self.cur_pos = 0
-        return [self.cur_pos]
+        info_dict = {}
+        return [self.cur_pos], info_dict

     def step(self, action):
         assert action in [0, 1], action
@@ -97,9 +98,16 @@ def step(self, action):
             self.cur_pos -= 1
         elif action == 1:
             self.cur_pos += 1
-        done = self.cur_pos >= self.end_pos
+        terminated = self.cur_pos >= self.end_pos
+        truncated = False
         # Produce a random reward when we reach the goal.
-        return [self.cur_pos], random.random() * 2 if done else -0.1, done, {}
+        return (
+            [self.cur_pos],
+            random.random() * 2 if terminated else -0.1,
+            terminated,
+            truncated,
+            {}
+        )

     def seed(self, seed=None):
         random.seed(seed)
diff --git a/experiments/tsp/untrained_model_sampling.ipynb b/experiments/tsp/untrained_model_sampling.ipynb
index 156cce7..2ff6dd6 100644
--- a/experiments/tsp/untrained_model_sampling.ipynb
+++ b/experiments/tsp/untrained_model_sampling.ipynb
@@ -137,7 +137,8 @@
     "        )\n",
     "        action_probabilities = tf.nn.softmax(masked_action_values).numpy()\n",
     "        action = np.random.choice(env.max_num_children, size=1, p=action_probabilities)[0]\n",
-    "        obs, reward, done, info = env.step(action)\n",
+    "        obs, reward, terminated, truncated, info = env.step(action)\n",
+    "        done = terminated or truncated\n",
     "        total_reward += reward\n",
     "    \n",
     "    return total_reward"
@@ -213,7 +214,7 @@
     }
    ],
    "source": [
-    "obs = env.reset()\n",
+    "obs, info = env.reset()\n",
     "env.observation_space.contains(obs)"
    ]
   },
@@ -376,11 +377,12 @@
     "    # run until episode ends\n",
     "    episode_reward = 0\n",
     "    done = False\n",
-    "    obs = env.reset()\n",
+    "    obs, info = env.reset()\n",
     "\n",
     "    while not done:\n",
     "        action = agent.compute_single_action(obs)\n",
-    "        obs, reward, done, info = env.step(action)\n",
+    "        obs, reward, terminated, truncated, info = env.step(action)\n",
+    "        done = terminated or truncated\n",
     "        episode_reward += reward\n",
     "    \n",
     "    return episode_reward"
diff --git a/graphenv/examples/hallway/hallway_state.py b/graphenv/examples/hallway/hallway_state.py
index 8d20d43..4d5273d 100644
--- a/graphenv/examples/hallway/hallway_state.py
+++ b/graphenv/examples/hallway/hallway_state.py
@@ -1,7 +1,7 @@
 import random
 from typing import Dict, Sequence

-import gym
+import gymnasium as gym
 import numpy as np
 from graphenv import tf
 from graphenv.vertex import Vertex
diff --git a/graphenv/examples/tsp/tsp_nfp_state.py b/graphenv/examples/tsp/tsp_nfp_state.py
index 4d50b46..aa7da1d 100644
--- a/graphenv/examples/tsp/tsp_nfp_state.py
+++ b/graphenv/examples/tsp/tsp_nfp_state.py
@@ -1,7 +1,7 @@
 from math import sqrt
 from typing import Dict, Optional

-import gym
+import gymnasium as gym
 import numpy as np
 from graphenv.examples.tsp.tsp_preprocessor import TSPPreprocessor
 from graphenv.examples.tsp.tsp_state import TSPState
diff --git a/graphenv/examples/tsp/tsp_state.py b/graphenv/examples/tsp/tsp_state.py
index 2eae36e..c50c8cf 100644
--- a/graphenv/examples/tsp/tsp_state.py
+++ b/graphenv/examples/tsp/tsp_state.py
@@ -1,6 +1,6 @@
 from typing import Any, Callable, Dict, List, Optional, Sequence

-import gym
+import gymnasium as gym
 import networkx as nx
 import numpy as np
 from graphenv import tf
diff --git a/graphenv/graph_env.py b/graphenv/graph_env.py
index 8dfc3cd..46b6934 100644
--- a/graphenv/graph_env.py
+++ b/graphenv/graph_env.py
@@ -2,7 +2,7 @@
 import warnings
 from typing import Any, Dict, List, Optional, Tuple

-import gym
+import gymnasium as gym
 import numpy as np
 from ray.rllib.env.env_context import EnvContext
 from ray.rllib.utils.spaces.repeated import Repeated
@@ -61,7 +61,7 @@ def __init__(self, env_config: EnvContext) -> None:
         self.action_space = gym.spaces.Discrete(self.max_num_children)
         logger.debug("leaving graphenv construction")

-    def reset(self) -> Dict[str, np.ndarray]:
+    def reset(self, *, seed=None, options=None) -> Tuple[Dict[str, np.ndarray], Dict]:
         """Reset this state to the root vertex. It is possible for state.root to
         return different root vertices on each call.

@@ -69,9 +69,9 @@ def reset(self) -> Dict[str, np.ndarray]:
             Dict[str, np.ndarray]: Observation of the root vertex.
         """
         self.state = self.state.root
-        return self.make_observation()
+        return self.make_observation(), self.state.info

-    def step(self, action: int) -> Tuple[Dict[str, np.ndarray], float, bool, dict]:
+    def step(self, action: int) -> Tuple[Dict[str, np.ndarray], float, bool, bool, dict]:
         """Steps the environment to a new state by taking an action. In the case of
         GraphEnv, the action specifies which next vertex to move to and this method
         advances the environment to that vertex.
@@ -86,7 +86,8 @@ def step(self, action: int) -> Tuple[Dict[str, np.ndarray], float, bool, dict]:
             Tuple[Dict[str, np.ndarray], float, bool, dict]: Tuple of:
                 a dictionary of the new state's observation,
                 the reward received by moving to the new state's vertex,
-                a bool which is true iff the new stae is a terminal vertex,
+                a bool which is true iff the new state is a terminal vertex,
+                a bool which is true if the search is truncated,
                 a dictionary of debugging information related to this call


@@ -115,10 +116,17 @@ def step(self, action: int) -> Tuple[Dict[str, np.ndarray], float, bool, dict]:
                 RuntimeWarning,
             )

+        # In RLlib 2.3, the config options "no_done_at_end", "horizon", and "soft_horizon" are no longer supported
+        # according to the migration guide https://docs.google.com/document/d/1lxYK1dI5s0Wo_jmB6V6XiP-_aEBsXDykXkD1AXRase4/edit#
+        # Instead, wrap your gymnasium environment with a TimeLimit wrapper,
+        # which will set truncated according to the number of timesteps
+        # see https://gymnasium.farama.org/api/wrappers/misc_wrappers/#gymnasium.wrappers.TimeLimit
+        truncated = False
         result = (
             self.make_observation(),
             self.state.reward,
             self.state.terminal,
+            truncated,
             self.state.info,
         )
         logger.debug(
diff --git a/graphenv/graph_model.py b/graphenv/graph_model.py
index b64d1bd..34a3c09 100644
--- a/graphenv/graph_model.py
+++ b/graphenv/graph_model.py
@@ -2,7 +2,7 @@
 from abc import abstractmethod
 from typing import Dict, List, Tuple

-import gym
+import gymnasium as gym
 from ray.rllib.models.repeated_values import RepeatedValues
 from ray.rllib.utils.typing import TensorStructType, TensorType

diff --git a/graphenv/vertex.py b/graphenv/vertex.py
index a7149c0..2cca659 100644
--- a/graphenv/vertex.py
+++ b/graphenv/vertex.py
@@ -1,7 +1,7 @@
 from abc import abstractmethod
 from typing import Any, Dict, Generic, List, Optional, Sequence, TypeVar

-import gym
+import gymnasium as gym

 V = TypeVar("V")

diff --git a/setup.cfg b/setup.cfg
index 815ef68..3741be8 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -12,9 +12,10 @@ classifiers =
 packages = find:
 install_requires =
     networkx==3.0
-    ray[tune,rllib]==2.2.0
+    ray[tune,rllib]==2.3.1
     numpy<1.24.0
     tqdm==4.64.1
+    matplotlib

 [options.extras_require]
 tensorflow = tensorflow
diff --git a/tests/test_hallway.py b/tests/test_hallway.py
index e8c8658..ba3d372 100644
--- a/tests/test_hallway.py
+++ b/tests/test_hallway.py
@@ -61,14 +61,14 @@ def test_graphenv_reset(hallway_env: GraphEnv):


 def test_graphenv_step(hallway_env: GraphEnv):
-    obs, reward, terminal, info = hallway_env.step(0)
+    obs, reward, terminal, truncated, info = hallway_env.step(0)

     for _ in range(3):
         assert terminal is False
         assert reward == -0.1
         assert hallway_env.observation_space.contains(obs)
         assert hallway_env.action_space.contains(1)
-        obs, reward, terminal, info = hallway_env.step(1)
+        obs, reward, terminal, truncated, info = hallway_env.step(1)

     assert terminal is True
     assert reward > 0
diff --git a/tests/test_tsp.py b/tests/test_tsp.py
index 1d2d53a..88719ed 100644
--- a/tests/test_tsp.py
+++ b/tests/test_tsp.py
@@ -27,7 +27,7 @@ def test_graphenv():
         }
     )

-    obs = env.reset()
+    obs, info = env.reset()
     assert env.observation_space.contains(obs)
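
GraphEnv.step now always returns truncated=False, so any time-limit truncation has to come from a wrapper, as the new comment in graph_env.py notes. Below is a minimal sketch of that pattern against the migrated API. It is illustrative only: the HallwayState(corridor_length=10) call, the "state"/"max_num_children" config keys, and the 100-step budget are assumptions modeled on the hallway example, not part of this change.

import random

from gymnasium.wrappers import TimeLimit

from graphenv.examples.hallway.hallway_state import HallwayState
from graphenv.graph_env import GraphEnv

# Construct the graph environment (config keys and state arguments are assumed
# here for illustration; see the hallway example above for the real configuration).
env = GraphEnv({"state": HallwayState(corridor_length=10), "max_num_children": 2})

# TimeLimit sets truncated=True once the step budget is spent, replacing the
# removed RLlib "horizon" / "no_done_at_end" / "soft_horizon" options.
env = TimeLimit(env, max_episode_steps=100)

obs, info = env.reset()
done = False
total_reward = 0.0
while not done:
    # Only indices of actual child vertices are valid actions.
    action = random.choice(range(len(env.unwrapped.state.children)))
    obs, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated
    total_reward += reward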
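
Relatedly, the new reset(self, *, seed=None, options=None) signatures above accept a seed but do not use it; SimpleCorridor keeps its explicit seed() method with random.seed(). For comparison, the Gymnasium-recommended pattern is to forward the seed to gymnasium.Env.reset(), which seeds the environment's built-in self.np_random generator. The class below is only an illustrative sketch of that pattern, not the repository's SimpleCorridor.

import gymnasium as gym
import numpy as np
from gymnasium.spaces import Box, Discrete


class CorridorSketch(gym.Env):
    """Toy 1D corridor used only to illustrate Gymnasium-style seeding."""

    def __init__(self, end_pos: int = 5):
        self.end_pos = end_pos
        self.cur_pos = 0
        self.action_space = Discrete(2)
        self.observation_space = Box(0.0, float(end_pos), shape=(1,), dtype=np.float32)

    def reset(self, *, seed=None, options=None):
        # Forwarding the seed here seeds self.np_random in the standard way.
        super().reset(seed=seed)
        self.cur_pos = 0
        return np.array([self.cur_pos], dtype=np.float32), {}

    def step(self, action):
        self.cur_pos = max(self.cur_pos + (1 if action == 1 else -1), 0)
        terminated = self.cur_pos >= self.end_pos
        truncated = False
        # Draw the goal reward from the seeded generator instead of random.random().
        reward = self.np_random.random() * 2 if terminated else -0.1
        return np.array([self.cur_pos], dtype=np.float32), reward, terminated, truncated, {}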