diff --git a/test/test_env.py b/test/test_env.py
index bfda10f0e93..e6ca38b729c 100644
--- a/test/test_env.py
+++ b/test/test_env.py
@@ -59,6 +59,7 @@
     dense_stack_tds,
     LazyStackedTensorDict,
     TensorDict,
+    TensorDictBase,
 )
 from tensordict.nn import TensorDictModuleBase
 from tensordict.utils import _unravel_key_to_tuple
@@ -68,6 +69,7 @@
 from torchrl.data.tensor_specs import (
     CompositeSpec,
     DiscreteTensorSpec,
+    NonTensorSpec,
     UnboundedContinuousTensorSpec,
 )
 from torchrl.envs import (
@@ -84,7 +86,11 @@
 from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv
 from torchrl.envs.libs.gym import _has_gym, gym_backend, GymEnv, GymWrapper
 from torchrl.envs.transforms import Compose, StepCounter, TransformedEnv
-from torchrl.envs.transforms.transforms import AutoResetEnv, AutoResetTransform
+from torchrl.envs.transforms.transforms import (
+    AutoResetEnv,
+    AutoResetTransform,
+    Transform,
+)
 from torchrl.envs.utils import (
     _StepMDP,
     _terminated_or_truncated,
@@ -3188,6 +3194,49 @@ def test_parallel(self, bwad, use_buffers):
         r = env.rollout(N, break_when_any_done=bwad)
         assert r.get("non_tensor").tolist() == [list(range(N))] * 2
 
+    class AddString(Transform):
+        def __init__(self):
+            super().__init__()
+            self._str = "0"
+
+        def _call(self, td):
+            td["string"] = str(int(self._str) + 1)
+            self._str = td["string"]
+            return td
+
+        def _reset(
+            self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
+        ) -> TensorDictBase:
+            self._str = "0"
+            tensordict_reset["string"] = self._str
+            return tensordict_reset
+
+        def transform_observation_spec(self, observation_spec):
+            observation_spec["string"] = NonTensorSpec(())
+            return observation_spec
+
+    @pytest.mark.parametrize("batched", ["serial", "parallel"])
+    def test_partial_reset(self, batched):
+        env0 = lambda: CountingEnv(5).append_transform(self.AddString())
+        env1 = lambda: CountingEnv(6).append_transform(self.AddString())
+        if batched == "parallel":
+            env = ParallelEnv(2, [env0, env1], mp_start_method=mp_ctx)
+        else:
+            env = SerialEnv(2, [env0, env1])
+        s = env.reset()
+        i = 0
+        for i in range(10):  # noqa: B007
+            s, s_ = env.step_and_maybe_reset(
+                s.set("action", torch.ones(2, 1, dtype=torch.int))
+            )
+            if s.get(("next", "done")).any():
+                break
+            s = s_
+        assert i == 5
+        assert (s["next", "done"] == torch.tensor([[True], [False]])).all()
+        assert s_["string"] == ["0", "6"]
+        assert s["next", "string"] == ["6", "6"]
+
 
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py
index 500f457ad20..7fb180ac121 100644
--- a/torchrl/envs/batched_envs.py
+++ b/torchrl/envs/batched_envs.py
@@ -988,7 +988,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
         tds = []
         for i, _env in enumerate(self._envs):
             if not needs_resetting[i]:
-                if not self._use_buffers and tensordict is not None:
+                if out_tds is not None and tensordict is not None:
                     out_tds[i] = tensordict[i].exclude(*self._envs[i].reset_keys)
                 continue
             if tensordict is not None:
diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py
index 337b7ef8f9e..61c210acffa 100644
--- a/torchrl/envs/utils.py
+++ b/torchrl/envs/utils.py
@@ -41,7 +41,7 @@
     set_interaction_mode as set_exploration_mode,
     set_interaction_type as set_exploration_type,
 )
-from tensordict.utils import NestedKey
+from tensordict.utils import is_non_tensor, NestedKey
 from torch import nn as nn
 from torch.utils._pytree import tree_map
 from torchrl._utils import _replace_last, _rng_decorator, logger as torchrl_logger
@@ -254,6 +254,9 @@ def _grab_and_place(
             if not _allow_absent_keys:
                 raise KeyError(f"key {key} not found.")
         else:
+            # clone non-tensor leaves so the destination does not alias the source's value
+            if is_non_tensor(val):
+                val = val.clone()
             data_out._set_str(
                 key, val, validated=True, inplace=False, non_blocking=False
             )
@@ -1403,7 +1405,6 @@ def _update_during_reset(
         reset = reset.any(-1)
         reset = reset.reshape(node.shape)
         # node.update(node.where(~reset, other=node_reset, pad=0))
-        node.where(~reset, other=node_reset, out=node, pad=0)
         # node = node.clone()
         # idx = reset.nonzero(as_tuple=True)[0]
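
Note (not part of the patch): the `out_tds is not None` guard in `SerialEnv._reset`
is what lets non-tensor entries survive a partial reset -- a sub-env that is not
being reset keeps its incoming tensordict slice, minus the reset flags. A minimal
sketch of that carry-over, with a fabricated tensordict standing in for
`tensordict[i]`:

    import torch
    from tensordict import NonTensorData, TensorDict

    # Fabricated stand-in for the slice of a sub-env that does not need resetting.
    td_i = TensorDict(
        {"_reset": torch.tensor([False]), "string": NonTensorData("6")},
        batch_size=[],
    )
    # Mirrors `tensordict[i].exclude(*self._envs[i].reset_keys)`: drop the reset
    # flags, keep everything else, including the non-tensor "string" entry.
    carried = td_i.exclude("_reset")
    assert "string" in carried.keys()
    assert carried["string"] == "6"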
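Similarly, a sketch of why `_grab_and_place` clones non-tensor leaves:
`TensorDict.get` hands back the `NonTensorData` wrapper by reference, so copying
it into another tensordict without a clone would leave both pointing at the same
object. This uses only the public tensordict API:

    from tensordict import NonTensorData, TensorDict
    from tensordict.utils import is_non_tensor

    td = TensorDict({"string": NonTensorData("6")}, batch_size=[])
    val = td.get("string")   # the NonTensorData wrapper itself, not a copy
    assert is_non_tensor(val)
    val = val.clone()        # independent wrapper carrying the same "6" payload
    assert val.data == "6"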