diff --git a/nnabla_rl/utils/misc.py b/nnabla_rl/utils/misc.py index 354bf76a..d48e9473 100644 --- a/nnabla_rl/utils/misc.py +++ b/nnabla_rl/utils/misc.py @@ -60,9 +60,13 @@ def create_variable(batch_size: int, shape: Shape) -> Union[nn.Variable, Tuple[n def create_variables(batch_size: int, shapes: Dict[str, Tuple[int, ...]]) -> Dict[str, nn.Variable]: variables: Dict[str, nn.Variale] = {} for name, shape in shapes.items(): - state: nn.Variable = create_variable(batch_size, shape) - state.data.zero() - variables[name] = state + var: nn.Variable = create_variable(batch_size, shape) + if isinstance(var, tuple): + for v in var: + v.data.zero() + else: + var.data.zero() + variables[name] = var return variables diff --git a/tests/utils/test_misc.py b/tests/utils/test_misc.py index d760f537..ac7198e9 100644 --- a/tests/utils/test_misc.py +++ b/tests/utils/test_misc.py @@ -16,7 +16,7 @@ import pytest import nnabla as nn -from nnabla_rl.utils.misc import create_attention_mask, create_variable +from nnabla_rl.utils.misc import create_attention_mask, create_variable, create_variables class TestMisc: @@ -54,6 +54,37 @@ def test_create_attention_mask(self): actual_mask = create_attention_mask(num_query, num_key) assert np.allclose(expected_mask.d, actual_mask.d) + def test_create_variables_int(self): + batch_size = 3 + shape = 5 + shapes = {"var_int": shape} + + actual_vars = create_variables(batch_size, shapes) + + for name, shape in shapes.items(): + assert actual_vars[name].shape == (batch_size, shape) + + def test_create_variables_tuple(self): + batch_size = 3 + shape = (5, 6) + shapes = {"var_tuple": shape} + + actual_vars = create_variables(batch_size, shapes) + + for name, shape in shapes.items(): + assert actual_vars[name].shape == (batch_size, *shape) + + def test_create_variables_tuples(self): + batch_size = 3 + shape = ((6,), (3,)) + shapes = {"var_tuples": shape} + + actual_vars = create_variables(batch_size, shapes) + + for name, shape in shapes.items(): + for actual_var, expected_shape in zip(actual_vars[name], shape): + assert actual_var.shape == (batch_size, *expected_shape) + if __name__ == "__main__": pytest.main()