Skip to content

Commit

Permalink
Add support for info elements with tuple values
Browse files Browse the repository at this point in the history
  • Loading branch information
ishihara-y committed Nov 8, 2024
1 parent f595164 commit 4daa3c7
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 0 deletions.
2 changes: 2 additions & 0 deletions nnabla_rl/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ def marshal_dict_experiences(dict_experiences: Sequence[Dict[str, Any]]) -> Dict
try:
if isinstance(data[0], Dict):
marshaled_experiences.update({key: marshal_dict_experiences(data)})
elif isinstance(data[0], tuple):
marshaled_experiences.update({key: marshal_experiences(data)})
else:
marshaled_experiences.update({key: add_axis_if_single_dim(np.asarray(data))})
except ValueError as e:
Expand Down
17 changes: 17 additions & 0 deletions tests/utils/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,23 @@ def test_marshal_dict_experiences(self):
np.testing.assert_allclose(np.asarray(key1_experiences), 1)
np.testing.assert_allclose(np.asarray(key2_experiences), 2)

def test_marshal_dict_experiences_with_tuple_values(self):
experiences = {"key1": (1, 1, 1), "key2": (2, 2)}
dict_experiences = [{"key_parent": experiences}, {"key_parent": experiences}]
marshaled_experience = marshal_dict_experiences(dict_experiences)

key1_experiences = marshaled_experience["key_parent"]["key1"]
key2_experiences = marshaled_experience["key_parent"]["key2"]

assert len(key1_experiences) == 3
assert len(key2_experiences) == 2
for key1_value in key1_experiences:
assert key1_value.shape == (2, 1)
np.testing.assert_allclose(np.asarray(key1_value), 1)
for key2_value in key2_experiences:
assert key2_value.shape == (2, 1)
np.testing.assert_allclose(np.asarray(key2_value), 2)

def test_marshal_triple_nested_dict_experiences(self):
experiences = {"key1": 1, "key2": 2}
nested_experiences = {"nest1": experiences, "nest2": experiences}
Expand Down

0 comments on commit 4daa3c7

Please sign in to comment.