From 49e8c12b17fe3b63dd933f67523ab3a07d87f369 Mon Sep 17 00:00:00 2001 From: ariel Date: Mon, 25 Mar 2024 20:52:10 +0100 Subject: [PATCH] Fix buffers to hopefully support nested dict spaces --- cogment_lab/utils/trial_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cogment_lab/utils/trial_utils.py b/cogment_lab/utils/trial_utils.py index f42cdb9..69d0c4c 100644 --- a/cogment_lab/utils/trial_utils.py +++ b/cogment_lab/utils/trial_utils.py @@ -103,9 +103,9 @@ def initialize_buffer(space: gym.Space | None, length: int) -> np.ndarray | dict if space is None: return np.empty((length,), dtype=np.float32) elif isinstance(space, gym.spaces.Dict): - return {key: np.empty((length,) + space[key].shape, dtype=space[key].dtype) for key in space.spaces.keys()} # type: ignore + return {key: initialize_buffer(space[key], length) for key in space.spaces.keys()} # type: ignore elif isinstance(space, gym.spaces.Tuple): - return {i: np.empty((length,) + space[i].shape, dtype=space[i].dtype) for i in range(len(space.spaces))} # type: ignore + return {i: initialize_buffer(space[i], length) for i in range(len(space.spaces))} # type: ignore elif isinstance(space, gym.spaces.Text): return np.empty((length,), dtype="