Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Gym to Gymnasium #6166

Open
wants to merge 11 commits into
base: develop
Choose a base branch
from
Prev Previous commit
Next Next commit
Updated env.reset() to align with gymnasium standards
Zach-Attach committed Oct 30, 2024
commit 93e6a270f4d0d7b64c28f7143c879fd8e1fda166
21 changes: 12 additions & 9 deletions ml-agents-envs/mlagents_envs/envs/unity_gym_env.py
Original file line number Diff line number Diff line change
@@ -151,11 +151,21 @@ def __init__(
else:
self._observation_space = list_spaces[0] # only return the first one

def reset(self) -> Union[Tuple[List[np.ndarray], Dict], Tuple[np.ndarray, Dict]]:
def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] = None) -> Union[Tuple[List[np.ndarray], Dict], Tuple[np.ndarray, Dict]]:
"""Resets the state of the environment and returns an initial observation.
Returns: observation (object/list): the initial observation of the
Args:
seed (int, optional): The seed for the environment. Note that this does not set the seed for the Unity Environment.
options (dict, optional): Optional dict containing options for the environment. (Currently not implemented)
Returns:
observation (object/list): the initial observation of the
space.
info (dict): contains auxiliary diagnostic information.
"""
if options is not None:
logger.warning("Options are currently unsupported.")
if seed is not None:
super().reset(seed=seed)
logger.warning("reset(seed) does not change the seed in the Unity Environment or the action space")
self._env.reset()
decision_step, _ = self._env.get_steps(self.name)
n_agents = len(decision_step)
@@ -290,13 +300,6 @@ def close(self) -> None:
"""
self._env.close()

def seed(self, seed: Any = None) -> None:
"""Sets the seed for this env's random number generator(s).
Currently not implemented.
"""
logger.warning("Could not seed environment %s", self.name)
return

@staticmethod
def _check_agents(n_agents: int) -> None:
if n_agents > 1: