feature/provide access to wrapped attr #1140

Merged
Changes from 6 commits
4 changes: 3 additions & 1 deletion pettingzoo/atari/base_atari_env.py
@@ -171,7 +171,9 @@ def __init__(

def _seed(self, seed=None):
if seed is None:
_, seed = seeding.np_random()
self.np_random, seed = seeding.np_random()
else:
self.np_random, seed = seeding.np_random(seed)
self.ale.setInt(b"random_seed", seed)
self.ale.loadROM(self.rom_path)
self.ale.setMode(self.mode)
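Note: the change above keeps the generator created in _seed on self.np_random for both the seeded and unseeded branches, instead of discarding it, and still passes the resulting seed on to ALE. A minimal sketch of the seeding behaviour this relies on; the make_rng helper is illustrative, not part of the diff:

from gymnasium.utils import seeding

def make_rng(seed=None):
    # Mirrors the two branches of the updated _seed: both return a generator plus the seed used.
    if seed is None:
        rng, used_seed = seeding.np_random()       # fresh generator with a randomly drawn seed
    else:
        rng, used_seed = seeding.np_random(seed)   # reproducible generator for a fixed seed
    return rng, used_seed

rng, used_seed = make_rng(42)
print(used_seed)            # 42
print(rng.integers(0, 10))  # deterministic for a fixed seed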
83 changes: 6 additions & 77 deletions pettingzoo/utils/wrappers/base.py
@@ -1,6 +1,5 @@
from __future__ import annotations

import warnings
from typing import Any

import gymnasium.spaces
@@ -19,72 +18,12 @@ def __init__(self, env: AECEnv[AgentID, ObsType, ActionType]):
super().__init__()
self.env = env

try:
self.possible_agents = self.env.possible_agents
except AttributeError:
pass

self.metadata = self.env.metadata

# we don't want these defined as we don't want them used before they are gotten

# self.agent_selection = self.env.agent_selection

# self.rewards = self.env.rewards
# self.dones = self.env.dones

# we don't want to care one way or the other whether environments have an infos or not before reset
try:
self.infos = self.env.infos
except AttributeError:
pass

# Not every environment has the .state_space attribute implemented
try:
self.state_space = (
self.env.state_space # pyright: ignore[reportGeneralTypeIssues]
)
except AttributeError:
pass

def __getattr__(self, name: str) -> Any:
"""Returns an attribute with ``name``, unless ``name`` starts with an underscore."""
if name.startswith("_"):
if name.startswith("_") and name != "_cumulative_rewards":
raise AttributeError(f"accessing private attribute '{name}' is prohibited")
return getattr(self.env, name)

@property
def observation_spaces(self) -> dict[AgentID, gymnasium.spaces.Space]:
warnings.warn(
"The `observation_spaces` dictionary is deprecated. Use the `observation_space` function instead."
)
try:
return {
agent: self.observation_space(agent) for agent in self.possible_agents
}
except AttributeError as e:
raise AttributeError(
"The base environment does not have an `observation_spaces` dict attribute. Use the environment's `observation_space` method instead"
) from e

@property
def action_spaces(self) -> dict[AgentID, gymnasium.spaces.Space]:
warnings.warn(
"The `action_spaces` dictionary is deprecated. Use the `action_space` function instead."
)
try:
return {agent: self.action_space(agent) for agent in self.possible_agents}
except AttributeError as e:
raise AttributeError(
"The base environment does not have an action_spaces dict attribute. Use the environment's `action_space` method instead"
) from e

def observation_space(self, agent: AgentID) -> gymnasium.spaces.Space:
return self.env.observation_space(agent)

def action_space(self, agent: AgentID) -> gymnasium.spaces.Space:
return self.env.action_space(agent)

@property
def unwrapped(self) -> AECEnv:
return self.env.unwrapped
@@ -98,14 +37,6 @@ def render(self) -> None | np.ndarray | str | list:
def reset(self, seed: int | None = None, options: dict | None = None):
self.env.reset(seed=seed, options=options)

self.agent_selection = self.env.agent_selection
self.rewards = self.env.rewards
self.terminations = self.env.terminations
self.truncations = self.env.truncations
self.infos = self.env.infos
self.agents = self.env.agents
self._cumulative_rewards = self.env._cumulative_rewards

def observe(self, agent: AgentID) -> ObsType | None:
return self.env.observe(agent)

@@ -115,13 +46,11 @@ def state(self) -> np.ndarray:
def step(self, action: ActionType) -> None:
self.env.step(action)

self.agent_selection = self.env.agent_selection
self.rewards = self.env.rewards
self.terminations = self.env.terminations
self.truncations = self.env.truncations
self.infos = self.env.infos
self.agents = self.env.agents
self._cumulative_rewards = self.env._cumulative_rewards
def observation_space(self, agent: AgentID) -> gymnasium.spaces.Space:
return self.env.observation_space(agent)

def action_space(self, agent: AgentID) -> gymnasium.spaces.Space:
return self.env.action_space(agent)

def __str__(self) -> str:
"""Returns a name which looks like: "max_observation<space_invaders_v1>"."""
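Note: with the eager copies removed, BaseWrapper no longer caches agents, rewards, infos, agent_selection, or _cumulative_rewards on itself; every lookup falls through __getattr__ to the wrapped environment, so reads cannot go stale after reset() or step(). A rough usage sketch, assuming any AEC environment (tictactoe_v3 here is only an example):

from pettingzoo.classic import tictactoe_v3
from pettingzoo.utils.wrappers.base import BaseWrapper

env = BaseWrapper(tictactoe_v3.env())
env.reset(seed=0)

# These attributes live on the wrapped env; the wrapper just delegates the lookup.
print(env.agents)                # ['player_1', 'player_2']
print(env.rewards)               # {'player_1': 0, 'player_2': 0}
print(env._cumulative_rewards)   # allowed despite the underscore (explicit carve-out above)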
58 changes: 8 additions & 50 deletions pettingzoo/utils/wrappers/base_parallel.py
@@ -1,40 +1,26 @@
from __future__ import annotations

import warnings

import gymnasium.spaces
import numpy as np
from gymnasium.utils import seeding

from pettingzoo.utils.env import ActionType, AgentID, ObsType, ParallelEnv


class BaseParallelWrapper(ParallelEnv[AgentID, ObsType, ActionType]):
def __init__(self, env: ParallelEnv[AgentID, ObsType, ActionType]):
super().__init__()
self.env = env

self.metadata = env.metadata
try:
self.possible_agents = env.possible_agents
except AttributeError:
pass

# Not every environment has the .state_space attribute implemented
try:
self.state_space = (
self.env.state_space # pyright: ignore[reportGeneralTypeIssues]
)
except AttributeError:
pass
def __getattr__(self, name: str):
"""Returns an attribute with ``name``, unless ``name`` starts with an underscore."""
if name.startswith("_"):
raise AttributeError(f"accessing private attribute '{name}' is prohibited")
return getattr(self.env, name)

def reset(
self, seed: int | None = None, options: dict | None = None
) -> tuple[dict[AgentID, ObsType], dict[AgentID, dict]]:
self.np_random, _ = seeding.np_random(seed)

res, info = self.env.reset(seed=seed, options=options)
self.agents = self.env.agents
return res, info
return self.env.reset(seed=seed, options=options)

def step(
self, actions: dict[AgentID, ActionType]
@@ -45,9 +31,7 @@ def step(
dict[AgentID, bool],
dict[AgentID, dict],
]:
res = self.env.step(actions)
self.agents = self.env.agents
return res
return self.env.step(actions)

def render(self) -> None | np.ndarray | str | list:
return self.env.render()
@@ -62,32 +46,6 @@ def unwrapped(self) -> ParallelEnv:
def state(self) -> np.ndarray:
return self.env.state()

@property
def observation_spaces(self) -> dict[AgentID, gymnasium.spaces.Space]:
warnings.warn(
"The `observation_spaces` dictionary is deprecated. Use the `observation_space` function instead."
)
try:
return {
agent: self.observation_space(agent) for agent in self.possible_agents
}
except AttributeError as e:
raise AttributeError(
"The base environment does not have an `observation_spaces` dict attribute. Use the environments `observation_space` method instead"
) from e

@property
def action_spaces(self) -> dict[AgentID, gymnasium.spaces.Space]:
warnings.warn(
"The `action_spaces` dictionary is deprecated. Use the `action_space` function instead."
)
try:
return {agent: self.action_space(agent) for agent in self.possible_agents}
except AttributeError as e:
raise AttributeError(
"The base environment does not have an action_spaces dict attribute. Use the environments `action_space` method instead"
) from e

def observation_space(self, agent: AgentID) -> gymnasium.spaces.Space:
return self.env.observation_space(agent)

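Note: BaseParallelWrapper follows the same pattern — reset() and step() now return the wrapped environment's results unchanged instead of re-copying agents, and attribute reads are delegated through __getattr__. A hedged sketch; pistonball_v6 is just an illustrative parallel environment:

from pettingzoo.butterfly import pistonball_v6
from pettingzoo.utils.wrappers.base_parallel import BaseParallelWrapper

env = BaseParallelWrapper(pistonball_v6.parallel_env())
observations, infos = env.reset(seed=0)

# `agents` is no longer stored on the wrapper; the read is forwarded,
# so it always matches the wrapped environment.
assert env.agents == env.unwrapped.agents
print(len(env.agents))  # number of pistons in the default configuration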
6 changes: 2 additions & 4 deletions pettingzoo/utils/wrappers/multi_episode_env.py
@@ -59,22 +59,20 @@ def step(self, action: ActionType) -> None:
None:
"""
super().step(action)
if self.agents:
if self.env.agents:
return

# if we've crossed num_episodes, truncate all agents
# and let the environment terminate normally
if self._episodes_elapsed >= self._num_episodes:
self.truncations = {agent: True for agent in self.agents}
self.env.unwrapped.truncations = {agent: True for agent in self.env.agents}
return

# if no more agents and haven't had enough episodes,
# increment the number of episodes and the seed for reset
self._episodes_elapsed += 1
self._seed = self._seed + 1 if self._seed else None
super().reset(seed=self._seed, options=self._options)
self.truncations = {agent: False for agent in self.agents}
self.terminations = {agent: False for agent in self.agents}

def __str__(self) -> str:
"""__str__.
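Note: because the per-agent dictionaries are now plain delegated reads, the wrapper writes its end-of-run truncation flags onto self.env.unwrapped; assigning to self.truncations would only create a shadow attribute on the wrapper that the underlying environment never sees. A small self-contained sketch of that distinction (Inner and Wrapper are illustrative stand-ins, not PettingZoo classes):

class Inner:
    def __init__(self):
        self.truncations = {"agent_0": False}

class Wrapper:
    def __init__(self, env):
        self.env = env

    def __getattr__(self, name):
        # Only consulted when the attribute is missing on the wrapper itself.
        return getattr(self.env, name)

w = Wrapper(Inner())
print(w.truncations)                   # delegated read: {'agent_0': False}

w.truncations = {"agent_0": True}      # shadows: stored on the wrapper only
print(w.env.truncations)               # still {'agent_0': False}

w.env.truncations = {"agent_0": True}  # what the updated wrappers do instead
print(w.env.truncations)               # {'agent_0': True}, visible to the base env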
31 changes: 18 additions & 13 deletions pettingzoo/utils/wrappers/order_enforcing.py
@@ -45,7 +45,10 @@ def __getattr__(self, value: str) -> Any:
elif value == "render_mode" and hasattr(self.env, "render_mode"):
return self.env.render_mode # pyright: ignore[reportGeneralTypeIssues]
elif value == "possible_agents":
EnvLogger.error_possible_agents_attribute_missing("possible_agents")
try:
return self.env.possible_agents
except AttributeError:
EnvLogger.error_possible_agents_attribute_missing("possible_agents")
elif value == "observation_spaces":
raise AttributeError(
"The base environment does not have an possible_agents attribute. Use the environments `observation_space` method instead"
@@ -58,20 +61,22 @@ def __getattr__(self, value: str) -> Any:
raise AttributeError(
"agent_order has been removed from the API. Please consider using agent_iter instead."
)
elif value in {
"rewards",
"terminations",
"truncations",
"infos",
"agent_selection",
"num_agents",
"agents",
}:
elif (
value
in {
"rewards",
"terminations",
"truncations",
"infos",
"agent_selection",
"num_agents",
"agents",
}
and not self._has_reset
):
raise AttributeError(f"{value} cannot be accessed before reset")
else:
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{value}'"
)
return super().__getattr__(value)

def render(self) -> None | np.ndarray | str | list:
if not self._has_reset:
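Note: after this change the guarded names are only blocked before the first reset(); once _has_reset is set, the lookup falls through to super().__getattr__ and resolves from the wrapped environment. Roughly (connect_four_v3 is just an example of an environment whose env() factory applies OrderEnforcingWrapper):

from pettingzoo.classic import connect_four_v3

env = connect_four_v3.env()

try:
    env.agent_selection            # blocked: no reset yet
except AttributeError as err:
    print(err)                     # "agent_selection cannot be accessed before reset"

env.reset(seed=0)
print(env.agent_selection)         # now delegated to the wrapped environment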
10 changes: 5 additions & 5 deletions pettingzoo/utils/wrappers/terminate_illegal.py
@@ -67,13 +67,13 @@ def step(self, action: ActionType) -> None:
and not _prev_action_mask[action]
):
EnvLogger.warn_on_illegal_move()
self._cumulative_rewards[self.agent_selection] = 0
self.terminations = {d: True for d in self.agents}
self.truncations = {d: True for d in self.agents}
self.env.unwrapped._cumulative_rewards[self.agent_selection] = 0
self.env.unwrapped.terminations = {d: True for d in self.agents}
self.env.unwrapped.truncations = {d: True for d in self.agents}
self._prev_obs = None
self._prev_info = None
self.rewards = {d: 0 for d in self.truncations}
self.rewards[current_agent] = float(self._illegal_value)
self.env.unwrapped.rewards = {d: 0 for d in self.truncations}
self.env.unwrapped.rewards[current_agent] = float(self._illegal_value)
self._accumulate_rewards()
self._deads_step_first()
self._terminated = True
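Note: same reasoning as in multi_episode_env.py — the illegal-move bookkeeping is written onto self.env.unwrapped so that _accumulate_rewards(), _deads_step_first(), and any later delegated read all see the same state. A hedged usage sketch; the game, the move sequence, and the -1 penalty shown in the comments are illustrative defaults, not taken from this diff:

from pettingzoo.classic import tictactoe_v3

env = tictactoe_v3.env()   # the classic env() factories wrap with TerminateIllegalWrapper
env.reset(seed=0)
env.step(0)                # player_1 takes square 0

offender = env.agent_selection
obs, reward, term, trunc, info = env.last()   # standard AEC pattern before acting
env.step(0)                # player_2 repeats the occupied square: an illegal move

# TerminateIllegalWrapper wrote these onto env.unwrapped; reads delegate back to it.
print(env.rewards[offender])   # the configured illegal reward (-1 for tic-tac-toe)
print(env.terminations)        # every agent terminated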