diff --git a/pettingzoo/atari/base_atari_env.py b/pettingzoo/atari/base_atari_env.py index 35dc342f4..05a47e8a8 100644 --- a/pettingzoo/atari/base_atari_env.py +++ b/pettingzoo/atari/base_atari_env.py @@ -169,9 +169,8 @@ def __init__( self._screen = None self._seed(seed) - def _seed(self, seed=None): - if seed is None: - _, seed = seeding.np_random() + def _seed(self, seed): + self.np_random, seed = seeding.np_random(seed) self.ale.setInt(b"random_seed", seed) self.ale.loadROM(self.rom_path) self.ale.setMode(self.mode) @@ -179,6 +178,8 @@ def _seed(self, seed=None): def reset(self, seed=None, options=None): if seed is not None: self._seed(seed=seed) + else: + self.np_random, seed = seeding.np_random() self.ale.reset_game() self.agents = self.possible_agents[:] self.terminations = {agent: False for agent in self.possible_agents} diff --git a/pettingzoo/utils/wrappers/base.py b/pettingzoo/utils/wrappers/base.py index b81d10613..cea324189 100644 --- a/pettingzoo/utils/wrappers/base.py +++ b/pettingzoo/utils/wrappers/base.py @@ -1,6 +1,5 @@ from __future__ import annotations -import warnings from typing import Any import gymnasium.spaces @@ -19,72 +18,12 @@ def __init__(self, env: AECEnv[AgentID, ObsType, ActionType]): super().__init__() self.env = env - try: - self.possible_agents = self.env.possible_agents - except AttributeError: - pass - - self.metadata = self.env.metadata - - # we don't want these defined as we don't want them used before they are gotten - - # self.agent_selection = self.env.agent_selection - - # self.rewards = self.env.rewards - # self.dones = self.env.dones - - # we don't want to care one way or the other whether environments have an infos or not before reset - try: - self.infos = self.env.infos - except AttributeError: - pass - - # Not every environment has the .state_space attribute implemented - try: - self.state_space = ( - self.env.state_space # pyright: ignore[reportGeneralTypeIssues] - ) - except AttributeError: - pass - def __getattr__(self, name: str) -> Any: """Returns an attribute with ``name``, unless ``name`` starts with an underscore.""" - if name.startswith("_"): + if name.startswith("_") and name != "_cumulative_rewards": raise AttributeError(f"accessing private attribute '{name}' is prohibited") return getattr(self.env, name) - @property - def observation_spaces(self) -> dict[AgentID, gymnasium.spaces.Space]: - warnings.warn( - "The `observation_spaces` dictionary is deprecated. Use the `observation_space` function instead." - ) - try: - return { - agent: self.observation_space(agent) for agent in self.possible_agents - } - except AttributeError as e: - raise AttributeError( - "The base environment does not have an `observation_spaces` dict attribute. Use the environment's `observation_space` method instead" - ) from e - - @property - def action_spaces(self) -> dict[AgentID, gymnasium.spaces.Space]: - warnings.warn( - "The `action_spaces` dictionary is deprecated. Use the `action_space` function instead." - ) - try: - return {agent: self.action_space(agent) for agent in self.possible_agents} - except AttributeError as e: - raise AttributeError( - "The base environment does not have an action_spaces dict attribute. Use the environment's `action_space` method instead" - ) from e - - def observation_space(self, agent: AgentID) -> gymnasium.spaces.Space: - return self.env.observation_space(agent) - - def action_space(self, agent: AgentID) -> gymnasium.spaces.Space: - return self.env.action_space(agent) - @property def unwrapped(self) -> AECEnv: return self.env.unwrapped @@ -98,14 +37,6 @@ def render(self) -> None | np.ndarray | str | list: def reset(self, seed: int | None = None, options: dict | None = None): self.env.reset(seed=seed, options=options) - self.agent_selection = self.env.agent_selection - self.rewards = self.env.rewards - self.terminations = self.env.terminations - self.truncations = self.env.truncations - self.infos = self.env.infos - self.agents = self.env.agents - self._cumulative_rewards = self.env._cumulative_rewards - def observe(self, agent: AgentID) -> ObsType | None: return self.env.observe(agent) @@ -115,13 +46,11 @@ def state(self) -> np.ndarray: def step(self, action: ActionType) -> None: self.env.step(action) - self.agent_selection = self.env.agent_selection - self.rewards = self.env.rewards - self.terminations = self.env.terminations - self.truncations = self.env.truncations - self.infos = self.env.infos - self.agents = self.env.agents - self._cumulative_rewards = self.env._cumulative_rewards + def observation_space(self, agent: AgentID) -> gymnasium.spaces.Space: + return self.env.observation_space(agent) + + def action_space(self, agent: AgentID) -> gymnasium.spaces.Space: + return self.env.action_space(agent) def __str__(self) -> str: """Returns a name which looks like: "max_observation".""" diff --git a/pettingzoo/utils/wrappers/base_parallel.py b/pettingzoo/utils/wrappers/base_parallel.py index 25b376a2a..499dd46f8 100644 --- a/pettingzoo/utils/wrappers/base_parallel.py +++ b/pettingzoo/utils/wrappers/base_parallel.py @@ -1,40 +1,26 @@ from __future__ import annotations -import warnings - import gymnasium.spaces import numpy as np -from gymnasium.utils import seeding from pettingzoo.utils.env import ActionType, AgentID, ObsType, ParallelEnv class BaseParallelWrapper(ParallelEnv[AgentID, ObsType, ActionType]): def __init__(self, env: ParallelEnv[AgentID, ObsType, ActionType]): + super().__init__() self.env = env - self.metadata = env.metadata - try: - self.possible_agents = env.possible_agents - except AttributeError: - pass - - # Not every environment has the .state_space attribute implemented - try: - self.state_space = ( - self.env.state_space # pyright: ignore[reportGeneralTypeIssues] - ) - except AttributeError: - pass + def __getattr__(self, name: str): + """Returns an attribute with ``name``, unless ``name`` starts with an underscore.""" + if name.startswith("_"): + raise AttributeError(f"accessing private attribute '{name}' is prohibited") + return getattr(self.env, name) def reset( self, seed: int | None = None, options: dict | None = None ) -> tuple[dict[AgentID, ObsType], dict[AgentID, dict]]: - self.np_random, _ = seeding.np_random(seed) - - res, info = self.env.reset(seed=seed, options=options) - self.agents = self.env.agents - return res, info + return self.env.reset(seed=seed, options=options) def step( self, actions: dict[AgentID, ActionType] @@ -45,9 +31,7 @@ def step( dict[AgentID, bool], dict[AgentID, dict], ]: - res = self.env.step(actions) - self.agents = self.env.agents - return res + return self.env.step(actions) def render(self) -> None | np.ndarray | str | list: return self.env.render() @@ -62,32 +46,6 @@ def unwrapped(self) -> ParallelEnv: def state(self) -> np.ndarray: return self.env.state() - @property - def observation_spaces(self) -> dict[AgentID, gymnasium.spaces.Space]: - warnings.warn( - "The `observation_spaces` dictionary is deprecated. Use the `observation_space` function instead." - ) - try: - return { - agent: self.observation_space(agent) for agent in self.possible_agents - } - except AttributeError as e: - raise AttributeError( - "The base environment does not have an `observation_spaces` dict attribute. Use the environments `observation_space` method instead" - ) from e - - @property - def action_spaces(self) -> dict[AgentID, gymnasium.spaces.Space]: - warnings.warn( - "The `action_spaces` dictionary is deprecated. Use the `action_space` function instead." - ) - try: - return {agent: self.action_space(agent) for agent in self.possible_agents} - except AttributeError as e: - raise AttributeError( - "The base environment does not have an action_spaces dict attribute. Use the environments `action_space` method instead" - ) from e - def observation_space(self, agent: AgentID) -> gymnasium.spaces.Space: return self.env.observation_space(agent) diff --git a/pettingzoo/utils/wrappers/multi_episode_env.py b/pettingzoo/utils/wrappers/multi_episode_env.py index d924a0c45..9c233f0c9 100644 --- a/pettingzoo/utils/wrappers/multi_episode_env.py +++ b/pettingzoo/utils/wrappers/multi_episode_env.py @@ -59,13 +59,13 @@ def step(self, action: ActionType) -> None: None: """ super().step(action) - if self.agents: + if self.env.agents: return # if we've crossed num_episodes, truncate all agents # and let the environment terminate normally if self._episodes_elapsed >= self._num_episodes: - self.truncations = {agent: True for agent in self.agents} + self.env.unwrapped.truncations = {agent: True for agent in self.env.agents} return # if no more agents and haven't had enough episodes, @@ -73,8 +73,6 @@ def step(self, action: ActionType) -> None: self._episodes_elapsed += 1 self._seed = self._seed + 1 if self._seed else None super().reset(seed=self._seed, options=self._options) - self.truncations = {agent: False for agent in self.agents} - self.terminations = {agent: False for agent in self.agents} def __str__(self) -> str: """__str__. diff --git a/pettingzoo/utils/wrappers/order_enforcing.py b/pettingzoo/utils/wrappers/order_enforcing.py index 6b78c9d7d..649c23caa 100644 --- a/pettingzoo/utils/wrappers/order_enforcing.py +++ b/pettingzoo/utils/wrappers/order_enforcing.py @@ -45,7 +45,10 @@ def __getattr__(self, value: str) -> Any: elif value == "render_mode" and hasattr(self.env, "render_mode"): return self.env.render_mode # pyright: ignore[reportGeneralTypeIssues] elif value == "possible_agents": - EnvLogger.error_possible_agents_attribute_missing("possible_agents") + try: + return self.env.possible_agents + except AttributeError: + EnvLogger.error_possible_agents_attribute_missing("possible_agents") elif value == "observation_spaces": raise AttributeError( "The base environment does not have an possible_agents attribute. Use the environments `observation_space` method instead" @@ -58,20 +61,22 @@ def __getattr__(self, value: str) -> Any: raise AttributeError( "agent_order has been removed from the API. Please consider using agent_iter instead." ) - elif value in { - "rewards", - "terminations", - "truncations", - "infos", - "agent_selection", - "num_agents", - "agents", - }: + elif ( + value + in { + "rewards", + "terminations", + "truncations", + "infos", + "agent_selection", + "num_agents", + "agents", + } + and not self._has_reset + ): raise AttributeError(f"{value} cannot be accessed before reset") else: - raise AttributeError( - f"'{type(self).__name__}' object has no attribute '{value}'" - ) + return super().__getattr__(value) def render(self) -> None | np.ndarray | str | list: if not self._has_reset: diff --git a/pettingzoo/utils/wrappers/terminate_illegal.py b/pettingzoo/utils/wrappers/terminate_illegal.py index 1522ff503..a49d9a0be 100644 --- a/pettingzoo/utils/wrappers/terminate_illegal.py +++ b/pettingzoo/utils/wrappers/terminate_illegal.py @@ -67,13 +67,13 @@ def step(self, action: ActionType) -> None: and not _prev_action_mask[action] ): EnvLogger.warn_on_illegal_move() - self._cumulative_rewards[self.agent_selection] = 0 - self.terminations = {d: True for d in self.agents} - self.truncations = {d: True for d in self.agents} + self.env.unwrapped._cumulative_rewards[self.agent_selection] = 0 + self.env.unwrapped.terminations = {d: True for d in self.agents} + self.env.unwrapped.truncations = {d: True for d in self.agents} self._prev_obs = None self._prev_info = None - self.rewards = {d: 0 for d in self.truncations} - self.rewards[current_agent] = float(self._illegal_value) + self.env.unwrapped.rewards = {d: 0 for d in self.truncations} + self.env.unwrapped.rewards[current_agent] = float(self._illegal_value) self._accumulate_rewards() self._deads_step_first() self._terminated = True diff --git a/test/action_mask_test.py b/test/action_mask_test.py index 3155b0845..05717abbd 100644 --- a/test/action_mask_test.py +++ b/test/action_mask_test.py @@ -2,7 +2,7 @@ import pytest -from pettingzoo.test import api_test, seed_test +from pettingzoo.test import seed_test from pettingzoo.test.example_envs import ( generated_agents_env_action_mask_info_v0, generated_agents_env_action_mask_obs_v0, @@ -20,7 +20,6 @@ def test_action_mask(env_constructor: Type[AECEnv]): """Test that environments function deterministically in cases where action mask is in observation, or in info.""" seed_test(env_constructor) - api_test(env_constructor()) # Step through the environment according to example code given in AEC documentation (following action mask) env = env_constructor()