From fdace2cb59cd43f6960bb771ac8bc24ad8f50b4d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 2 Jul 2024 09:50:17 +0100 Subject: [PATCH] Revert "[BugFix] Fix non-tensor passage in _StepMDP (#2260)" This reverts commit 5fa486ccdeaf179f63f5aeed9213f4da97e985c6. --- test/test_env.py | 51 +----------------------------------- torchrl/envs/batched_envs.py | 3 +-- torchrl/envs/utils.py | 5 ++-- 3 files changed, 4 insertions(+), 55 deletions(-) diff --git a/test/test_env.py b/test/test_env.py index e6ca38b729c..bfda10f0e93 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -59,7 +59,6 @@ dense_stack_tds, LazyStackedTensorDict, TensorDict, - TensorDictBase, ) from tensordict.nn import TensorDictModuleBase from tensordict.utils import _unravel_key_to_tuple @@ -69,7 +68,6 @@ from torchrl.data.tensor_specs import ( CompositeSpec, DiscreteTensorSpec, - NonTensorSpec, UnboundedContinuousTensorSpec, ) from torchrl.envs import ( @@ -86,11 +84,7 @@ from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv from torchrl.envs.libs.gym import _has_gym, gym_backend, GymEnv, GymWrapper from torchrl.envs.transforms import Compose, StepCounter, TransformedEnv -from torchrl.envs.transforms.transforms import ( - AutoResetEnv, - AutoResetTransform, - Transform, -) +from torchrl.envs.transforms.transforms import AutoResetEnv, AutoResetTransform from torchrl.envs.utils import ( _StepMDP, _terminated_or_truncated, @@ -3194,49 +3188,6 @@ def test_parallel(self, bwad, use_buffers): r = env.rollout(N, break_when_any_done=bwad) assert r.get("non_tensor").tolist() == [list(range(N))] * 2 - class AddString(Transform): - def __init__(self): - super().__init__() - self._str = "0" - - def _call(self, td): - td["string"] = str(int(self._str) + 1) - self._str = td["string"] - return td - - def _reset( - self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase - ) -> TensorDictBase: - self._str = "0" - tensordict_reset["string"] = self._str - return tensordict_reset - - def transform_observation_spec(self, observation_spec): - observation_spec["string"] = NonTensorSpec(()) - return observation_spec - - @pytest.mark.parametrize("batched", ["serial", "parallel"]) - def test_partial_rest(self, batched): - env0 = lambda: CountingEnv(5).append_transform(self.AddString()) - env1 = lambda: CountingEnv(6).append_transform(self.AddString()) - if batched == "parallel": - env = ParallelEnv(2, [env0, env1], mp_start_method=mp_ctx) - else: - env = SerialEnv(2, [env0, env1]) - s = env.reset() - i = 0 - for i in range(10): # noqa: B007 - s, s_ = env.step_and_maybe_reset( - s.set("action", torch.ones(2, 1, dtype=torch.int)) - ) - if s.get(("next", "done")).any(): - break - s = s_ - assert i == 5 - assert (s["next", "done"] == torch.tensor([[True], [False]])).all() - assert s_["string"] == ["0", "6"] - assert s["next", "string"] == ["6", "6"] - if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 7fb180ac121..500f457ad20 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -988,7 +988,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: tds = [] for i, _env in enumerate(self._envs): if not needs_resetting[i]: - if out_tds is not None and tensordict is not None: + if not self._use_buffers and tensordict is not None: out_tds[i] = tensordict[i].exclude(*self._envs[i].reset_keys) continue if tensordict is not None: @@ -1047,7 +1047,6 @@ def select_and_clone(name, tensor): filter_empty=True, ) if out_tds is not None: - print("out_tds", out_tds) out.update( LazyStackedTensorDict(*out_tds), keys_to_update=self._non_tensor_keys ) diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 61c210acffa..337b7ef8f9e 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -41,7 +41,7 @@ set_interaction_mode as set_exploration_mode, set_interaction_type as set_exploration_type, ) -from tensordict.utils import is_non_tensor, NestedKey +from tensordict.utils import NestedKey from torch import nn as nn from torch.utils._pytree import tree_map from torchrl._utils import _replace_last, _rng_decorator, logger as torchrl_logger @@ -254,8 +254,6 @@ def _grab_and_place( if not _allow_absent_keys: raise KeyError(f"key {key} not found.") else: - if is_non_tensor(val): - val = val.clone() data_out._set_str( key, val, validated=True, inplace=False, non_blocking=False ) @@ -1405,6 +1403,7 @@ def _update_during_reset( reset = reset.any(-1) reset = reset.reshape(node.shape) # node.update(node.where(~reset, other=node_reset, pad=0)) + node.where(~reset, other=node_reset, out=node, pad=0) # node = node.clone() # idx = reset.nonzero(as_tuple=True)[0]