Skip to content

Commit

Permalink
Support creating variables with tuple states
Browse files Browse the repository at this point in the history
  • Loading branch information
ishihara-y committed Nov 6, 2024
1 parent 8155185 commit 008876a
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 4 deletions.
10 changes: 7 additions & 3 deletions nnabla_rl/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,13 @@ def create_variable(batch_size: int, shape: Shape) -> Union[nn.Variable, Tuple[n
def create_variables(batch_size: int, shapes: Dict[str, Tuple[int, ...]]) -> Dict[str, nn.Variable]:
variables: Dict[str, nn.Variale] = {}
for name, shape in shapes.items():
state: nn.Variable = create_variable(batch_size, shape)
state.data.zero()
variables[name] = state
var: nn.Variable = create_variable(batch_size, shape)
if isinstance(var, tuple):
for v in var:
v.data.zero()
else:
var.data.zero()
variables[name] = var
return variables


Expand Down
33 changes: 32 additions & 1 deletion tests/utils/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import pytest

import nnabla as nn
from nnabla_rl.utils.misc import create_attention_mask, create_variable
from nnabla_rl.utils.misc import create_attention_mask, create_variable, create_variables


class TestMisc:
Expand Down Expand Up @@ -54,6 +54,37 @@ def test_create_attention_mask(self):
actual_mask = create_attention_mask(num_query, num_key)
assert np.allclose(expected_mask.d, actual_mask.d)

def test_create_variables_int(self):
batch_size = 3
shape = 5
shapes = {"var_int": shape}

actual_vars = create_variables(batch_size, shapes)

for name, shape in shapes.items():
assert actual_vars[name].shape == (batch_size, shape)

def test_create_variables_tuple(self):
batch_size = 3
shape = (5, 6)
shapes = {"var_tuple": shape}

actual_vars = create_variables(batch_size, shapes)

for name, shape in shapes.items():
assert actual_vars[name].shape == (batch_size, *shape)

def test_create_variables_tuples(self):
batch_size = 3
shape = ((6,), (3,))
shapes = {"var_tuples": shape}

actual_vars = create_variables(batch_size, shapes)

for name, shape in shapes.items():
for actual_var, expected_shape in zip(actual_vars[name], shape):
assert actual_var.shape == (batch_size, *expected_shape)


if __name__ == "__main__":
pytest.main()

0 comments on commit 008876a

Please sign in to comment.