Commit 79de877

Update wrappers to use __getattr__ instead of redefining attributes (#1140)

Co-authored-by: elliottower <[email protected]>
Authored by ffelten and elliottower on Nov 28, 2023
1 parent 7a67cde, commit 79de877
Showing 7 changed files with 44 additions and 154 deletions.
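
The change in a nutshell: instead of eagerly copying attributes such as `possible_agents`, `metadata`, `infos`, or `rewards` onto the wrapper in `__init__`, `reset`, and `step`, the wrappers now define `__getattr__`, so any attribute lookup that misses on the wrapper falls through to the wrapped environment. A minimal, self-contained sketch of the pattern; `_Env`, `_Wrapper`, and the agent/metadata names are illustrative stand-ins, not the real PettingZoo classes:

```python
# Minimal sketch of the delegation pattern (hypothetical _Env/_Wrapper classes,
# not the actual PettingZoo BaseWrapper).
from typing import Any


class _Env:
    def __init__(self) -> None:
        self.possible_agents = ["player_0", "player_1"]
        self.metadata = {"name": "sketch_v0"}
        self._secret = 42  # private state that should not leak through a wrapper


class _Wrapper:
    def __init__(self, env: _Env) -> None:
        # No attribute copying here: the wrapper only keeps a reference.
        self.env = env

    def __getattr__(self, name: str) -> Any:
        # Called only when normal lookup fails, i.e. for names the wrapper
        # itself does not define; private names are blocked.
        if name.startswith("_"):
            raise AttributeError(f"accessing private attribute '{name}' is prohibited")
        return getattr(self.env, name)


wrapped = _Wrapper(_Env())
print(wrapped.possible_agents)   # forwarded -> ['player_0', 'player_1']
print(wrapped.metadata["name"])  # forwarded -> sketch_v0
# print(wrapped._secret)         # would raise AttributeError
```

Because `__getattr__` only runs when normal lookup fails, optional attributes such as `state_space` no longer need the try/except guards removed in this diff: if the wrapped environment defines them the lookup is forwarded, otherwise the usual AttributeError propagates.
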
pettingzoo/atari/base_atari_env.py: 7 changes (4 additions & 3 deletions)
@@ -169,16 +169,17 @@ def __init__(
         self._screen = None
         self._seed(seed)
 
-    def _seed(self, seed=None):
-        if seed is None:
-            _, seed = seeding.np_random()
+    def _seed(self, seed):
+        self.np_random, seed = seeding.np_random(seed)
         self.ale.setInt(b"random_seed", seed)
         self.ale.loadROM(self.rom_path)
         self.ale.setMode(self.mode)
 
     def reset(self, seed=None, options=None):
         if seed is not None:
             self._seed(seed=seed)
+        else:
+            self.np_random, seed = seeding.np_random()
         self.ale.reset_game()
         self.agents = self.possible_agents[:]
         self.terminations = {agent: False for agent in self.possible_agents}

pettingzoo/utils/wrappers/base.py: 83 changes (6 additions & 77 deletions)
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import warnings
 from typing import Any
 
 import gymnasium.spaces
@@ -19,72 +18,12 @@ def __init__(self, env: AECEnv[AgentID, ObsType, ActionType]):
         super().__init__()
         self.env = env
 
-        try:
-            self.possible_agents = self.env.possible_agents
-        except AttributeError:
-            pass
-
-        self.metadata = self.env.metadata
-
-        # we don't want these defined as we don't want them used before they are gotten
-
-        # self.agent_selection = self.env.agent_selection
-
-        # self.rewards = self.env.rewards
-        # self.dones = self.env.dones
-
-        # we don't want to care one way or the other whether environments have an infos or not before reset
-        try:
-            self.infos = self.env.infos
-        except AttributeError:
-            pass
-
-        # Not every environment has the .state_space attribute implemented
-        try:
-            self.state_space = (
-                self.env.state_space  # pyright: ignore[reportGeneralTypeIssues]
-            )
-        except AttributeError:
-            pass
-
     def __getattr__(self, name: str) -> Any:
         """Returns an attribute with ``name``, unless ``name`` starts with an underscore."""
-        if name.startswith("_"):
+        if name.startswith("_") and name != "_cumulative_rewards":
             raise AttributeError(f"accessing private attribute '{name}' is prohibited")
         return getattr(self.env, name)
 
-    @property
-    def observation_spaces(self) -> dict[AgentID, gymnasium.spaces.Space]:
-        warnings.warn(
-            "The `observation_spaces` dictionary is deprecated. Use the `observation_space` function instead."
-        )
-        try:
-            return {
-                agent: self.observation_space(agent) for agent in self.possible_agents
-            }
-        except AttributeError as e:
-            raise AttributeError(
-                "The base environment does not have an `observation_spaces` dict attribute. Use the environment's `observation_space` method instead"
-            ) from e
-
-    @property
-    def action_spaces(self) -> dict[AgentID, gymnasium.spaces.Space]:
-        warnings.warn(
-            "The `action_spaces` dictionary is deprecated. Use the `action_space` function instead."
-        )
-        try:
-            return {agent: self.action_space(agent) for agent in self.possible_agents}
-        except AttributeError as e:
-            raise AttributeError(
-                "The base environment does not have an action_spaces dict attribute. Use the environment's `action_space` method instead"
-            ) from e
-
-    def observation_space(self, agent: AgentID) -> gymnasium.spaces.Space:
-        return self.env.observation_space(agent)
-
-    def action_space(self, agent: AgentID) -> gymnasium.spaces.Space:
-        return self.env.action_space(agent)
-
     @property
     def unwrapped(self) -> AECEnv:
         return self.env.unwrapped
@@ -98,14 +37,6 @@ def render(self) -> None | np.ndarray | str | list:
     def reset(self, seed: int | None = None, options: dict | None = None):
         self.env.reset(seed=seed, options=options)
 
-        self.agent_selection = self.env.agent_selection
-        self.rewards = self.env.rewards
-        self.terminations = self.env.terminations
-        self.truncations = self.env.truncations
-        self.infos = self.env.infos
-        self.agents = self.env.agents
-        self._cumulative_rewards = self.env._cumulative_rewards
-
     def observe(self, agent: AgentID) -> ObsType | None:
         return self.env.observe(agent)
 
@@ -115,13 +46,11 @@ def state(self) -> np.ndarray:
     def step(self, action: ActionType) -> None:
         self.env.step(action)
 
-        self.agent_selection = self.env.agent_selection
-        self.rewards = self.env.rewards
-        self.terminations = self.env.terminations
-        self.truncations = self.env.truncations
-        self.infos = self.env.infos
-        self.agents = self.env.agents
-        self._cumulative_rewards = self.env._cumulative_rewards
+    def observation_space(self, agent: AgentID) -> gymnasium.spaces.Space:
+        return self.env.observation_space(agent)
+
+    def action_space(self, agent: AgentID) -> gymnasium.spaces.Space:
+        return self.env.action_space(agent)
 
     def __str__(self) -> str:
         """Returns a name which looks like: "max_observation<space_invaders_v1>"."""

pettingzoo/utils/wrappers/base_parallel.py: 58 changes (8 additions & 50 deletions)
@@ -1,40 +1,26 @@
 from __future__ import annotations
 
-import warnings
-
 import gymnasium.spaces
 import numpy as np
-from gymnasium.utils import seeding
 
 from pettingzoo.utils.env import ActionType, AgentID, ObsType, ParallelEnv
 
 
 class BaseParallelWrapper(ParallelEnv[AgentID, ObsType, ActionType]):
     def __init__(self, env: ParallelEnv[AgentID, ObsType, ActionType]):
         super().__init__()
         self.env = env
 
-        self.metadata = env.metadata
-        try:
-            self.possible_agents = env.possible_agents
-        except AttributeError:
-            pass
-
-        # Not every environment has the .state_space attribute implemented
-        try:
-            self.state_space = (
-                self.env.state_space  # pyright: ignore[reportGeneralTypeIssues]
-            )
-        except AttributeError:
-            pass
+    def __getattr__(self, name: str):
+        """Returns an attribute with ``name``, unless ``name`` starts with an underscore."""
+        if name.startswith("_"):
+            raise AttributeError(f"accessing private attribute '{name}' is prohibited")
+        return getattr(self.env, name)
 
     def reset(
         self, seed: int | None = None, options: dict | None = None
     ) -> tuple[dict[AgentID, ObsType], dict[AgentID, dict]]:
-        self.np_random, _ = seeding.np_random(seed)
-
-        res, info = self.env.reset(seed=seed, options=options)
-        self.agents = self.env.agents
-        return res, info
+        return self.env.reset(seed=seed, options=options)
 
     def step(
         self, actions: dict[AgentID, ActionType]
@@ -45,9 +31,7 @@ def step(
         dict[AgentID, bool],
         dict[AgentID, dict],
     ]:
-        res = self.env.step(actions)
-        self.agents = self.env.agents
-        return res
+        return self.env.step(actions)
 
     def render(self) -> None | np.ndarray | str | list:
         return self.env.render()
@@ -62,32 +46,6 @@ def unwrapped(self) -> ParallelEnv:
     def state(self) -> np.ndarray:
         return self.env.state()
 
-    @property
-    def observation_spaces(self) -> dict[AgentID, gymnasium.spaces.Space]:
-        warnings.warn(
-            "The `observation_spaces` dictionary is deprecated. Use the `observation_space` function instead."
-        )
-        try:
-            return {
-                agent: self.observation_space(agent) for agent in self.possible_agents
-            }
-        except AttributeError as e:
-            raise AttributeError(
-                "The base environment does not have an `observation_spaces` dict attribute. Use the environments `observation_space` method instead"
-            ) from e
-
-    @property
-    def action_spaces(self) -> dict[AgentID, gymnasium.spaces.Space]:
-        warnings.warn(
-            "The `action_spaces` dictionary is deprecated. Use the `action_space` function instead."
-        )
-        try:
-            return {agent: self.action_space(agent) for agent in self.possible_agents}
-        except AttributeError as e:
-            raise AttributeError(
-                "The base environment does not have an action_spaces dict attribute. Use the environments `action_space` method instead"
-            ) from e
-
     def observation_space(self, agent: AgentID) -> gymnasium.spaces.Space:
         return self.env.observation_space(agent)
 
pettingzoo/utils/wrappers/multi_episode_env.py: 6 changes (2 additions & 4 deletions)
@@ -59,22 +59,20 @@ def step(self, action: ActionType) -> None:
             None:
         """
         super().step(action)
-        if self.agents:
+        if self.env.agents:
             return
 
         # if we've crossed num_episodes, truncate all agents
         # and let the environment terminate normally
         if self._episodes_elapsed >= self._num_episodes:
-            self.truncations = {agent: True for agent in self.agents}
+            self.env.unwrapped.truncations = {agent: True for agent in self.env.agents}
             return
 
         # if no more agents and haven't had enough episodes,
         # increment the number of episodes and the seed for reset
         self._episodes_elapsed += 1
         self._seed = self._seed + 1 if self._seed else None
         super().reset(seed=self._seed, options=self._options)
-        self.truncations = {agent: False for agent in self.agents}
-        self.terminations = {agent: False for agent in self.agents}
 
     def __str__(self) -> str:
         """__str__.

pettingzoo/utils/wrappers/order_enforcing.py: 31 changes (18 additions & 13 deletions)
@@ -45,7 +45,10 @@ def __getattr__(self, value: str) -> Any:
         elif value == "render_mode" and hasattr(self.env, "render_mode"):
             return self.env.render_mode  # pyright: ignore[reportGeneralTypeIssues]
         elif value == "possible_agents":
-            EnvLogger.error_possible_agents_attribute_missing("possible_agents")
+            try:
+                return self.env.possible_agents
+            except AttributeError:
+                EnvLogger.error_possible_agents_attribute_missing("possible_agents")
         elif value == "observation_spaces":
             raise AttributeError(
                 "The base environment does not have an possible_agents attribute. Use the environments `observation_space` method instead"
@@ -58,20 +61,22 @@ def __getattr__(self, value: str) -> Any:
             raise AttributeError(
                 "agent_order has been removed from the API. Please consider using agent_iter instead."
             )
-        elif value in {
-            "rewards",
-            "terminations",
-            "truncations",
-            "infos",
-            "agent_selection",
-            "num_agents",
-            "agents",
-        }:
+        elif (
+            value
+            in {
+                "rewards",
+                "terminations",
+                "truncations",
+                "infos",
+                "agent_selection",
+                "num_agents",
+                "agents",
+            }
+            and not self._has_reset
+        ):
             raise AttributeError(f"{value} cannot be accessed before reset")
         else:
-            raise AttributeError(
-                f"'{type(self).__name__}' object has no attribute '{value}'"
-            )
+            return super().__getattr__(value)
 
     def render(self) -> None | np.ndarray | str | list:
         if not self._has_reset:

pettingzoo/utils/wrappers/terminate_illegal.py: 10 changes (5 additions & 5 deletions)
@@ -67,13 +67,13 @@ def step(self, action: ActionType) -> None:
             and not _prev_action_mask[action]
         ):
             EnvLogger.warn_on_illegal_move()
-            self._cumulative_rewards[self.agent_selection] = 0
-            self.terminations = {d: True for d in self.agents}
-            self.truncations = {d: True for d in self.agents}
+            self.env.unwrapped._cumulative_rewards[self.agent_selection] = 0
+            self.env.unwrapped.terminations = {d: True for d in self.agents}
+            self.env.unwrapped.truncations = {d: True for d in self.agents}
             self._prev_obs = None
             self._prev_info = None
-            self.rewards = {d: 0 for d in self.truncations}
-            self.rewards[current_agent] = float(self._illegal_value)
+            self.env.unwrapped.rewards = {d: 0 for d in self.truncations}
+            self.env.unwrapped.rewards[current_agent] = float(self._illegal_value)
             self._accumulate_rewards()
             self._deads_step_first()
             self._terminated = True

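
The writes through `self.env.unwrapped` in the two wrappers above follow from the same delegation model: once reads are forwarded by `__getattr__`, assigning to `self.rewards` (or `self.truncations`, `self._cumulative_rewards`, ...) would only create a wrapper-local attribute that shadows the underlying environment's state rather than updating it. A small standalone sketch of that pitfall, again with hypothetical `_Env`/`_Wrapper` classes rather than the real TerminateIllegalWrapper:

```python
# Hypothetical sketch: why writes target the underlying env rather than `self`.
from typing import Any


class _Env:
    def __init__(self) -> None:
        self.rewards = {"player_0": 0.0}

    @property
    def unwrapped(self) -> "_Env":
        return self


class _Wrapper:
    def __init__(self, env: _Env) -> None:
        self.env = env

    def __getattr__(self, name: str) -> Any:
        # Reads that miss on the wrapper are forwarded to the wrapped env.
        return getattr(self.env, name)


env = _Env()
w = _Wrapper(env)

# Shadowing: this creates a new dict on the wrapper only.
w.rewards = {"player_0": -1.0}
print(env.rewards)   # {'player_0': 0.0} -- the env never saw the write

# Writing through `unwrapped` mutates the shared state that every layer reads.
del w.rewards        # drop the shadow so reads delegate again
w.env.unwrapped.rewards = {"player_0": -1.0}
print(env.rewards)   # {'player_0': -1.0}
print(w.rewards)     # {'player_0': -1.0} -- the read is forwarded
```
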
test/action_mask_test.py: 3 changes (1 addition & 2 deletions)
@@ -2,7 +2,7 @@
 
 import pytest
 
-from pettingzoo.test import api_test, seed_test
+from pettingzoo.test import seed_test
 from pettingzoo.test.example_envs import (
     generated_agents_env_action_mask_info_v0,
     generated_agents_env_action_mask_obs_v0,
@@ -20,7 +20,6 @@
 def test_action_mask(env_constructor: Type[AECEnv]):
     """Test that environments function deterministically in cases where action mask is in observation, or in info."""
     seed_test(env_constructor)
-    api_test(env_constructor())
 
     # Step through the environment according to example code given in AEC documentation (following action mask)
     env = env_constructor()