From 4daa3c7f818c98f9a67890be2934aed0604a86b8 Mon Sep 17 00:00:00 2001 From: Yu Ishihara Date: Fri, 8 Nov 2024 18:43:48 +0900 Subject: [PATCH] Add support for info elements with tuple values --- nnabla_rl/utils/data.py | 2 ++ tests/utils/test_data.py | 17 +++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/nnabla_rl/utils/data.py b/nnabla_rl/utils/data.py index 7a76c0d9..14907d2c 100644 --- a/nnabla_rl/utils/data.py +++ b/nnabla_rl/utils/data.py @@ -81,6 +81,8 @@ def marshal_dict_experiences(dict_experiences: Sequence[Dict[str, Any]]) -> Dict try: if isinstance(data[0], Dict): marshaled_experiences.update({key: marshal_dict_experiences(data)}) + elif isinstance(data[0], tuple): + marshaled_experiences.update({key: marshal_experiences(data)}) else: marshaled_experiences.update({key: add_axis_if_single_dim(np.asarray(data))}) except ValueError as e: diff --git a/tests/utils/test_data.py b/tests/utils/test_data.py index 181b7ed8..eee9adcb 100644 --- a/tests/utils/test_data.py +++ b/tests/utils/test_data.py @@ -128,6 +128,23 @@ def test_marshal_dict_experiences(self): np.testing.assert_allclose(np.asarray(key1_experiences), 1) np.testing.assert_allclose(np.asarray(key2_experiences), 2) + def test_marshal_dict_experiences_with_tuple_values(self): + experiences = {"key1": (1, 1, 1), "key2": (2, 2)} + dict_experiences = [{"key_parent": experiences}, {"key_parent": experiences}] + marshaled_experience = marshal_dict_experiences(dict_experiences) + + key1_experiences = marshaled_experience["key_parent"]["key1"] + key2_experiences = marshaled_experience["key_parent"]["key2"] + + assert len(key1_experiences) == 3 + assert len(key2_experiences) == 2 + for key1_value in key1_experiences: + assert key1_value.shape == (2, 1) + np.testing.assert_allclose(np.asarray(key1_value), 1) + for key2_value in key2_experiences: + assert key2_value.shape == (2, 1) + np.testing.assert_allclose(np.asarray(key2_value), 2) + def test_marshal_triple_nested_dict_experiences(self): experiences = {"key1": 1, "key2": 2} nested_experiences = {"nest1": experiences, "nest2": experiences}