Skip to content

Commit

Permalink
Fix buffers to hopefully support nested dict spaces
Browse files Browse the repository at this point in the history
  • Loading branch information
RedTachyon committed Mar 25, 2024
1 parent fadc1c8 commit 49e8c12
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions cogment_lab/utils/trial_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,9 @@ def initialize_buffer(space: gym.Space | None, length: int) -> np.ndarray | dict
if space is None:
return np.empty((length,), dtype=np.float32)
elif isinstance(space, gym.spaces.Dict):
return {key: np.empty((length,) + space[key].shape, dtype=space[key].dtype) for key in space.spaces.keys()} # type: ignore
return {key: initialize_buffer(space[key], length) for key in space.spaces.keys()} # type: ignore
elif isinstance(space, gym.spaces.Tuple):
return {i: np.empty((length,) + space[i].shape, dtype=space[i].dtype) for i in range(len(space.spaces))} # type: ignore
return {i: initialize_buffer(space[i], length) for i in range(len(space.spaces))} # type: ignore
elif isinstance(space, gym.spaces.Text):
return np.empty((length,), dtype="<U" + str(space.max_length))
else: # Simple space
Expand All @@ -130,7 +130,7 @@ def write_to_buffer(
"""
if isinstance(buffer, dict):
for key in buffer.keys():
buffer[key][idx] = data[key] # type: ignore
write_to_buffer(buffer[key], data[key], idx)
else:
buffer[idx] = data

Expand Down

0 comments on commit 49e8c12

Please sign in to comment.