From 23ca67c2a4b47c1cf1725246931161cc41619241 Mon Sep 17 00:00:00 2001 From: yoavnavon Date: Fri, 22 Jul 2022 21:18:59 -0700 Subject: [PATCH] [Feature] Rename _TensorDict into TensorDictBase (#316) --- test/mocking_classes.py | 18 +- test/test_distributions.py | 4 +- test/test_rb.py | 14 +- test/test_tensor_spec.py | 4 +- test/test_tensordict.py | 6 +- torchrl/collectors/collectors.py | 62 ++- torchrl/collectors/utils.py | 4 +- torchrl/data/postprocs/postprocs.py | 4 +- torchrl/data/replay_buffers/replay_buffers.py | 16 +- torchrl/data/replay_buffers/storages.py | 12 +- torchrl/data/tensor_specs.py | 8 +- torchrl/data/tensordict/tensordict.py | 500 ++++++++++-------- torchrl/data/utils.py | 4 +- torchrl/envs/common.py | 50 +- torchrl/envs/env_creator.py | 4 +- torchrl/envs/gym_like.py | 10 +- torchrl/envs/transforms/transforms.py | 50 +- torchrl/envs/utils.py | 12 +- torchrl/envs/vec_env.py | 8 +- torchrl/modules/models/recipes/impala.py | 4 +- torchrl/modules/tensordict_module/common.py | 20 +- torchrl/modules/tensordict_module/deprec.py | 16 +- .../modules/tensordict_module/exploration.py | 12 +- .../tensordict_module/probabilistic.py | 18 +- torchrl/modules/tensordict_module/sequence.py | 8 +- torchrl/objectives/costs/common.py | 4 +- torchrl/objectives/costs/ddpg.py | 10 +- torchrl/objectives/costs/deprecated.py | 8 +- torchrl/objectives/costs/dqn.py | 8 +- torchrl/objectives/costs/impala.py | 6 +- torchrl/objectives/costs/ppo.py | 14 +- torchrl/objectives/costs/redq.py | 4 +- torchrl/objectives/costs/reinforce.py | 8 +- torchrl/objectives/costs/sac.py | 12 +- torchrl/objectives/costs/utils.py | 6 +- torchrl/objectives/returns/advantages.py | 20 +- torchrl/record/recorder.py | 4 +- torchrl/trainers/helpers/collectors.py | 4 +- torchrl/trainers/trainers.py | 66 +-- 39 files changed, 548 insertions(+), 494 deletions(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 23d79e295ad..48b2a3d8270 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -16,7 +16,7 @@ UnboundedContinuousTensorSpec, OneHotDiscreteTensorSpec, ) -from torchrl.data.tensordict.tensordict import _TensorDict, TensorDict +from torchrl.data.tensordict.tensordict import TensorDictBase, TensorDict from torchrl.envs.common import _EnvClass spec_dict = { @@ -110,7 +110,7 @@ def _step(self, tensordict): done = torch.tensor([done], dtype=torch.bool, device=self.device) return TensorDict({"reward": n, "done": done, "next_observation": n}, []) - def _reset(self, tensordict: _TensorDict, **kwargs) -> _TensorDict: + def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: self.max_val = max(self.counter + 100, self.counter * 2) n = torch.tensor([self.counter]).to(self.device).to(torch.get_default_dtype()) @@ -118,7 +118,7 @@ def _reset(self, tensordict: _TensorDict, **kwargs) -> _TensorDict: done = torch.tensor([done], dtype=torch.bool, device=self.device) return TensorDict({"done": done, "next_observation": n}, []) - def rand_step(self, tensordict: Optional[_TensorDict] = None) -> _TensorDict: + def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase: return self.step(tensordict) @@ -144,7 +144,7 @@ def _get_in_obs(self, obs): def _get_out_obs(self, obs): return obs - def _reset(self, tensordict: _TensorDict) -> _TensorDict: + def _reset(self, tensordict: TensorDictBase) -> TensorDictBase: self.counter += 1 state = torch.zeros(self.size) + self.counter tensordict = tensordict.select().set( @@ -156,8 +156,8 @@ def _reset(self, tensordict: _TensorDict) -> _TensorDict: def _step( self, - tensordict: _TensorDict, - ) -> _TensorDict: + tensordict: TensorDictBase, + ) -> TensorDictBase: tensordict = tensordict.to(self.device) a = tensordict.get("action") assert (a.sum(-1) == 1).all() @@ -199,7 +199,7 @@ def _get_in_obs(self, obs): def _get_out_obs(self, obs): return obs - def _reset(self, tensordict: _TensorDict) -> _TensorDict: + def _reset(self, tensordict: TensorDictBase) -> TensorDictBase: self.counter += 1 self.step_count = 0 state = torch.zeros(self.size) + self.counter @@ -211,8 +211,8 @@ def _reset(self, tensordict: _TensorDict) -> _TensorDict: def _step( self, - tensordict: _TensorDict, - ) -> _TensorDict: + tensordict: TensorDictBase, + ) -> TensorDictBase: self.step_count += 1 tensordict = tensordict.to(self.device) a = tensordict.get("action") diff --git a/test/test_distributions.py b/test/test_distributions.py index 43433b95fd3..c36ad2d7b40 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -9,7 +9,7 @@ import torch from _utils_internal import get_available_devices from torch import nn, autograd -from torchrl.data.tensordict.tensordict import _TensorDict +from torchrl.data.tensordict.tensordict import TensorDictBase from torchrl.modules import ( TanhNormal, NormalParamWrapper, @@ -59,7 +59,7 @@ def test_delta(device, div_up, div_down): def _map_all(*tensors_or_other, device): for t in tensors_or_other: - if isinstance(t, (torch.Tensor, _TensorDict)): + if isinstance(t, (torch.Tensor, TensorDictBase)): yield t.to(device) else: yield t diff --git a/test/test_rb.py b/test/test_rb.py index dcb4c81e839..9b9e3fe5517 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -21,7 +21,7 @@ LazyMemmapStorage, LazyTensorStorage, ) -from torchrl.data.tensordict.tensordict import assert_allclose_td, _TensorDict +from torchrl.data.tensordict.tensordict import assert_allclose_td, TensorDictBase collate_fn_dict = { @@ -128,7 +128,7 @@ def test_add(self, rbtype, storage, size, prefetch): data = self._get_datum(rbtype) rb.add(data) s = rb._storage[0] - if isinstance(s, _TensorDict): + if isinstance(s, TensorDictBase): assert (s == data.select(*s.keys())).all() else: assert (s == data).all() @@ -142,12 +142,12 @@ def test_extend(self, rbtype, storage, size, prefetch): for d in data[-length:]: found_similar = False for b in rb._storage: - if isinstance(b, _TensorDict): + if isinstance(b, TensorDictBase): b = b.exclude("index").select(*set(d.keys()).intersection(b.keys())) d = d.select(*set(d.keys()).intersection(b.keys())) value = b == d - if isinstance(value, (torch.Tensor, _TensorDict)): + if isinstance(value, (torch.Tensor, TensorDictBase)): value = value.all() if value: found_similar = True @@ -160,18 +160,18 @@ def test_sample(self, rbtype, storage, size, prefetch): data = self._get_data(rbtype, size=5) rb.extend(data) new_data = rb.sample(3) - if not isinstance(new_data, (torch.Tensor, _TensorDict)): + if not isinstance(new_data, (torch.Tensor, TensorDictBase)): new_data = new_data[0] for d in new_data: found_similar = False for b in data: - if isinstance(b, _TensorDict): + if isinstance(b, TensorDictBase): b = b.exclude("index").select(*set(d.keys()).intersection(b.keys())) d = d.select(*set(d.keys()).intersection(b.keys())) value = b == d - if isinstance(value, (torch.Tensor, _TensorDict)): + if isinstance(value, (torch.Tensor, TensorDictBase)): value = value.all() if value: found_similar = True diff --git a/test/test_tensor_spec.py b/test/test_tensor_spec.py index 0467b6429c6..23ab668f41a 100644 --- a/test/test_tensor_spec.py +++ b/test/test_tensor_spec.py @@ -18,7 +18,7 @@ UnboundedContinuousTensorSpec, OneHotDiscreteTensorSpec, ) -from torchrl.data.tensordict.tensordict import TensorDict, _TensorDict +from torchrl.data.tensordict.tensordict import TensorDict, TensorDictBase @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.float64, None]) @@ -376,7 +376,7 @@ def test_nested_composite_spec(self, is_complete, device, dtype): ts = self._composite_spec(is_complete, device, dtype) ts["nested_cp"] = self._composite_spec(is_complete, device, dtype) td = ts.rand() - assert isinstance(td["nested_cp"], _TensorDict) + assert isinstance(td["nested_cp"], TensorDictBase) keys = list(td.keys()) for key in keys: if key != "nested_cp": diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 2411f8e9b77..f9082e2d019 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -18,7 +18,7 @@ LazyStackedTensorDict, stack as stack_td, pad, - _TensorDict, + TensorDictBase, ) from torchrl.data.tensordict.utils import _getitem_batch_size, convert_ellipsis_to_idx @@ -833,7 +833,7 @@ def test_masking_set(self, td_name, device, from_list): def zeros_like(item, n, d): if isinstance(item, (MemmapTensor, torch.Tensor)): return torch.zeros(n, *item.shape[d:], dtype=item.dtype, device=device) - elif isinstance(item, _TensorDict): + elif isinstance(item, TensorDictBase): batch_size = item.batch_size batch_size = [n, *batch_size[d:]] out = TensorDict( @@ -1344,7 +1344,7 @@ def test_flatten_keys(self, td_name, device, inplace, separator): td_flatten = td.flatten_keys(inplace=inplace, separator=separator) for key, value in td_flatten.items(): - assert not isinstance(value, _TensorDict) + assert not isinstance(value, TensorDictBase) assert ( separator.join(["nested_tensordict", "nested_nested_tensordict", "a"]) in td_flatten.keys() diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index b09d933f6cc..191a7a3bd78 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -32,7 +32,7 @@ from torchrl.envs.transforms import TransformedEnv from ..data import TensorSpec -from ..data.tensordict.tensordict import _TensorDict, TensorDict +from ..data.tensordict.tensordict import TensorDictBase, TensorDict from ..data.utils import CloudpickleWrapper, DEVICE_TYPING from ..envs.common import _EnvClass from ..envs.vec_env import _BatchedEnv @@ -62,7 +62,7 @@ def __init__(self, action_spec: TensorSpec): """ self.action_spec = action_spec - def __call__(self, td: _TensorDict) -> _TensorDict: + def __call__(self, td: TensorDictBase) -> TensorDictBase: return td.set("action", self.action_spec.rand(td.batch_size)) @@ -81,7 +81,10 @@ class _DataCollector(IterableDataset, metaclass=abc.ABCMeta): def _get_policy_and_device( self, policy: Optional[ - Union[ProbabilisticTensorDictModule, Callable[[_TensorDict], _TensorDict]] + Union[ + ProbabilisticTensorDictModule, + Callable[[TensorDictBase], TensorDictBase], + ] ] = None, device: Optional[DEVICE_TYPING] = None, ) -> Tuple[ @@ -143,11 +146,11 @@ def update_policy_weights_(self) -> None: if self.get_weights_fn is not None: self.policy.load_state_dict(self.get_weights_fn()) - def __iter__(self) -> Iterator[_TensorDict]: + def __iter__(self) -> Iterator[TensorDictBase]: return self.iterator() @abc.abstractmethod - def iterator(self) -> Iterator[_TensorDict]: + def iterator(self) -> Iterator[TensorDictBase]: raise NotImplementedError @abc.abstractmethod @@ -174,7 +177,7 @@ class SyncDataCollector(_DataCollector): Args: create_env_fn (Callable), returns an instance of _EnvClass class. policy (Callable, optional): Policy to be executed in the environment. - Must accept _TensorDict object as input. + Must accept TensorDictBase object as input. total_frames (int): lower bound of the total number of frames returned by the collector. The iterator will stop once the total number of frames equates or exceeds the total number of frames passed to the collector. @@ -230,7 +233,10 @@ def __init__( _EnvClass, "EnvCreator", Sequence[Callable[[], _EnvClass]] ], policy: Optional[ - Union[ProbabilisticTensorDictModule, Callable[[_TensorDict], _TensorDict]] + Union[ + ProbabilisticTensorDictModule, + Callable[[TensorDictBase], TensorDictBase], + ] ] = None, total_frames: Optional[int] = -1, create_env_kwargs: Optional[dict] = None, @@ -238,7 +244,7 @@ def __init__( frames_per_batch: int = 200, init_random_frames: int = -1, reset_at_each_iter: bool = False, - postproc: Optional[Callable[[_TensorDict], _TensorDict]] = None, + postproc: Optional[Callable[[TensorDictBase], TensorDictBase]] = None, split_trajs: bool = True, device: DEVICE_TYPING = None, passing_device: DEVICE_TYPING = "cpu", @@ -336,10 +342,10 @@ def set_seed(self, seed: int) -> int: """ return self.env.set_seed(seed) - def iterator(self) -> Iterator[_TensorDict]: + def iterator(self) -> Iterator[TensorDictBase]: """Iterates through the DataCollector. - Yields: _TensorDict objects containing (chunks of) trajectories + Yields: TensorDictBase objects containing (chunks of) trajectories """ total_frames = self.total_frames @@ -371,7 +377,7 @@ def iterator(self) -> Iterator[_TensorDict]: if self._frames >= self.total_frames: break - def _cast_to_policy(self, td: _TensorDict) -> _TensorDict: + def _cast_to_policy(self, td: TensorDictBase) -> TensorDictBase: policy_device = self.device if hasattr(self.policy, "in_keys"): td = td.select(*self.policy.in_keys) @@ -384,8 +390,8 @@ def _cast_to_policy(self, td: _TensorDict) -> _TensorDict: return self._td_policy def _cast_to_env( - self, td: _TensorDict, dest: Optional[_TensorDict] = None - ) -> _TensorDict: + self, td: TensorDictBase, dest: Optional[TensorDictBase] = None + ) -> TensorDictBase: env_device = self.env_device if dest is None: if self._td_env is None: @@ -434,11 +440,11 @@ def _reset_if_necessary(self) -> None: self._tensordict.set("step_count", steps) @torch.no_grad() - def rollout(self) -> _TensorDict: + def rollout(self) -> TensorDictBase: """Computes a rollout in the environment using the provided policy. Returns: - _TensorDict containing the computed rollout. + TensorDictBase containing the computed rollout. """ if self.reset_at_each_iter: @@ -573,7 +579,7 @@ class _MultiDataCollector(_DataCollector): Args: create_env_fn (list of Callabled): list of Callables, each returning an instance of _EnvClass policy (Callable, optional): Instance of ProbabilisticTensorDictModule class. - Must accept _TensorDict object as input. + Must accept TensorDictBase object as input. total_frames (int): lower bound of the total number of frames returned by the collector. In parallel settings, the actual number of frames may well be greater than this as the closing signals are sent to the workers only once the total number of frames has been collected on the server. @@ -623,7 +629,10 @@ def __init__( self, create_env_fn: Sequence[Callable[[], _EnvClass]], policy: Optional[ - Union[ProbabilisticTensorDictModule, Callable[[_TensorDict], _TensorDict]] + Union[ + ProbabilisticTensorDictModule, + Callable[[TensorDictBase], TensorDictBase], + ] ] = None, total_frames: Optional[int] = -1, create_env_kwargs: Optional[Sequence[dict]] = None, @@ -631,7 +640,7 @@ def __init__( frames_per_batch: int = 200, init_random_frames: int = -1, reset_at_each_iter: bool = False, - postproc: Optional[Callable[[_TensorDict], _TensorDict]] = None, + postproc: Optional[Callable[[TensorDictBase], TensorDictBase]] = None, split_trajs: bool = True, devices: DEVICE_TYPING = None, seed: Optional[int] = None, @@ -918,7 +927,7 @@ def frames_per_batch_worker(self): def _queue_len(self) -> int: return self.num_workers - def iterator(self) -> Iterator[_TensorDict]: + def iterator(self) -> Iterator[TensorDictBase]: i = -1 frames = 0 out_tensordicts_shared = OrderedDict() @@ -1021,7 +1030,7 @@ def __init__(self, *args, **kwargs): def frames_per_batch_worker(self): return self.frames_per_batch - def _get_from_queue(self, timeout=None) -> Tuple[int, int, _TensorDict]: + def _get_from_queue(self, timeout=None) -> Tuple[int, int, TensorDictBase]: new_data, j = self.queue_out.get(timeout=timeout) if j == 0: data, idx = new_data @@ -1036,7 +1045,7 @@ def _get_from_queue(self, timeout=None) -> Tuple[int, int, _TensorDict]: def _queue_len(self) -> int: return 1 - def iterator(self) -> Iterator[_TensorDict]: + def iterator(self) -> Iterator[TensorDictBase]: if self.update_at_each_batch: self.update_policy_weights_() @@ -1116,7 +1125,7 @@ class aSyncDataCollector(MultiaSyncDataCollector): Args: create_env_fn (Callabled): Callable returning an instance of _EnvClass policy (Callable, optional): Instance of ProbabilisticTensorDictModule class. - Must accept _TensorDict object as input. + Must accept TensorDictBase object as input. total_frames (int): lower bound of the total number of frames returned by the collector. In parallel settings, the actual number of frames may well be greater than this as the closing signals are @@ -1169,7 +1178,10 @@ def __init__( self, create_env_fn: Callable[[], _EnvClass], policy: Optional[ - Union[ProbabilisticTensorDictModule, Callable[[_TensorDict], _TensorDict]] + Union[ + ProbabilisticTensorDictModule, + Callable[[TensorDictBase], TensorDictBase], + ] ] = None, total_frames: Optional[int] = -1, create_env_kwargs: Optional[dict] = None, @@ -1177,7 +1189,7 @@ def __init__( frames_per_batch: int = 200, init_random_frames: int = -1, reset_at_each_iter: bool = False, - postproc: Optional[Callable[[_TensorDict], _TensorDict]] = None, + postproc: Optional[Callable[[TensorDictBase], TensorDictBase]] = None, split_trajs: bool = True, device: Optional[Union[int, str, torch.device]] = None, passing_device: Union[int, str, torch.device] = "cpu", @@ -1208,7 +1220,7 @@ def _main_async_collector( queue_out: queues.Queue, create_env_fn: Union[_EnvClass, "EnvCreator", Callable[[], _EnvClass]], create_env_kwargs: dict, - policy: Callable[[_TensorDict], _TensorDict], + policy: Callable[[TensorDictBase], TensorDictBase], frames_per_worker: int, max_frames_per_traj: int, frames_per_batch: int, diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index 560bfedf17c..949f8e6f5c1 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -8,7 +8,7 @@ import torch from torchrl.data import TensorDict -from torchrl.data.tensordict.tensordict import _TensorDict +from torchrl.data.tensordict.tensordict import TensorDictBase def _stack_output(fun) -> Callable: @@ -27,7 +27,7 @@ def stacked_output_fun(*args, **kwargs): return stacked_output_fun -def split_trajectories(rollout_tensordict: _TensorDict) -> _TensorDict: +def split_trajectories(rollout_tensordict: TensorDictBase) -> TensorDictBase: """Takes a tensordict with a key traj_ids that indicates the id of each trajectory. From there, builds a B x T x ... zero-padded tensordict with B batches on max duration T """ diff --git a/torchrl/data/postprocs/postprocs.py b/torchrl/data/postprocs/postprocs.py index 504bf8476d3..db9538b0793 100644 --- a/torchrl/data/postprocs/postprocs.py +++ b/torchrl/data/postprocs/postprocs.py @@ -11,7 +11,7 @@ from torch import nn from torch.nn import functional as F -from torchrl.data.tensordict.tensordict import _TensorDict +from torchrl.data.tensordict.tensordict import TensorDictBase from torchrl.data.utils import expand_as_right __all__ = ["MultiStep"] @@ -139,7 +139,7 @@ def __init__( ).reshape(1, 1, -1), ) - def forward(self, tensordict: _TensorDict) -> _TensorDict: + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: """Args: tensordict: TennsorDict instance with Batch x Time-steps x ... dimensions. diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index 97308ba4c41..d5573c23d69 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -27,7 +27,7 @@ to_torch, ) from torchrl.data.tensordict.tensordict import ( - _TensorDict, + TensorDictBase, stack as stack_td, LazyStackedTensorDict, ) @@ -616,7 +616,7 @@ def collate_fn(x): ) self.priority_key = priority_key - def _get_priority(self, tensordict: _TensorDict) -> torch.Tensor: + def _get_priority(self, tensordict: TensorDictBase) -> torch.Tensor: if self.priority_key in tensordict.keys(): if tensordict.batch_dims: tensordict = tensordict.clone(recursive=False) @@ -633,16 +633,16 @@ def _get_priority(self, tensordict: _TensorDict) -> torch.Tensor: priority = self._default_priority return priority - def add(self, tensordict: _TensorDict) -> torch.Tensor: + def add(self, tensordict: TensorDictBase) -> torch.Tensor: priority = self._get_priority(tensordict) index = super().add(tensordict, priority) tensordict.set("index", index) return index def extend( - self, tensordicts: Union[_TensorDict, List[_TensorDict]] + self, tensordicts: Union[TensorDictBase, List[TensorDictBase]] ) -> torch.Tensor: - if isinstance(tensordicts, _TensorDict): + if isinstance(tensordicts, TensorDictBase): if self.priority_key in tensordicts.keys(): priorities = tensordicts.get(self.priority_key) else: @@ -665,7 +665,7 @@ def extend( else: priorities = [self._get_priority(td) for td in tensordicts] - if not isinstance(tensordicts, _TensorDict): + if not isinstance(tensordicts, TensorDictBase): stacked_td = torch.stack(tensordicts, 0) else: stacked_td = tensordicts @@ -673,7 +673,7 @@ def extend( stacked_td.set("index", idx, inplace=True) return idx - def update_priority(self, tensordict: _TensorDict) -> None: + def update_priority(self, tensordict: TensorDictBase) -> None: """Updates the priorities of the tensordicts stored in the replay buffer. @@ -691,7 +691,7 @@ def update_priority(self, tensordict: _TensorDict) -> None: ) return super().update_priority(tensordict.get("index"), priority=priority) - def sample(self, size: int, return_weight: bool = False) -> _TensorDict: + def sample(self, size: int, return_weight: bool = False) -> TensorDictBase: """ Gather a batch of tensordicts according to the non-uniform multinomial distribution with weights computed with the priority_key of each diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index abeac5d4e2d..8ad9c3a9624 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -6,7 +6,7 @@ from torchrl.data.replay_buffers.utils import INT_CLASSES from torchrl.data.tensordict.memmap import MemmapTensor -from torchrl.data.tensordict.tensordict import _TensorDict, TensorDict +from torchrl.data.tensordict.tensordict import TensorDictBase, TensorDict __all__ = ["Storage", "ListStorage", "LazyMemmapStorage", "LazyTensorStorage"] @@ -95,7 +95,7 @@ def __init__(self, size, scratch_dir=None, device=None): self.device = device if device else torch.device("cpu") self._len = 0 - def _init(self, data: Union[_TensorDict, torch.Tensor]) -> None: + def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: print("Creating a TensorStorage...") if isinstance(data, torch.Tensor): # if Tensor, we just create a MemmapTensor of the desired shape, device and dtype @@ -109,7 +109,7 @@ def _init(self, data: Union[_TensorDict, torch.Tensor]) -> None: out = TensorDict({}, [self.size, *data.shape]) print("The storage is being created: ") for key, tensor in data.items(): - if isinstance(tensor, _TensorDict): + if isinstance(tensor, TensorDictBase): out[key] = tensor.expand(self.size).clone().zero_() else: out[key] = torch.empty( @@ -125,7 +125,7 @@ def _init(self, data: Union[_TensorDict, torch.Tensor]) -> None: def set( self, cursor: Union[int, Sequence[int], slice], - data: Union[_TensorDict, torch.Tensor], + data: Union[TensorDictBase, torch.Tensor], ): if isinstance(cursor, INT_CLASSES): self._len = max(self._len, cursor + 1) @@ -173,7 +173,7 @@ def __init__(self, size, scratch_dir=None, device=None): self.device = device if device else torch.device("cpu") self._len = 0 - def _init(self, data: Union[_TensorDict, torch.Tensor]) -> None: + def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: print("Creating a MemmapStorage...") if isinstance(data, torch.Tensor): # if Tensor, we just create a MemmapTensor of the desired shape, device and dtype @@ -188,7 +188,7 @@ def _init(self, data: Union[_TensorDict, torch.Tensor]) -> None: out = TensorDict({}, [self.size, *data.shape]) print("The storage is being created: ") for key, tensor in data.items(): - if isinstance(tensor, _TensorDict): + if isinstance(tensor, TensorDictBase): out[key] = ( tensor.expand(self.size) .clone() diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index a49c98e806e..d2e318f2fa0 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -36,7 +36,7 @@ "CompositeSpec", ] -from torchrl.data.tensordict.tensordict import _TensorDict, TensorDict +from torchrl.data.tensordict.tensordict import TensorDictBase, TensorDict DEVICE_TYPING = Union[torch.device, str, int] @@ -989,7 +989,7 @@ def __repr__(self) -> str: def type_check( self, - value: Union[torch.Tensor, _TensorDict], + value: Union[torch.Tensor, TensorDictBase], selected_keys: Union[str, Optional[Sequence[str]]] = None, ): if isinstance(value, torch.Tensor) and isinstance(selected_keys, str): @@ -1002,7 +1002,7 @@ def type_check( ): self._specs[_key].type_check(value[_key], _key) - def is_in(self, val: Union[dict, _TensorDict]) -> bool: + def is_in(self, val: Union[dict, TensorDictBase]) -> bool: return all( [ item.is_in(val.get(key)) @@ -1011,7 +1011,7 @@ def is_in(self, val: Union[dict, _TensorDict]) -> bool: ] ) - def project(self, val: _TensorDict) -> _TensorDict: + def project(self, val: TensorDictBase) -> TensorDictBase: for key, item in self.items(): if item is None: continue diff --git a/torchrl/data/tensordict/tensordict.py b/torchrl/data/tensordict/tensordict.py index 80182d03b30..5434beed110 100644 --- a/torchrl/data/tensordict/tensordict.py +++ b/torchrl/data/tensordict/tensordict.py @@ -61,14 +61,14 @@ COMPATIBLE_TYPES = Union[ torch.Tensor, MemmapTensor, -] # None? # leaves space for _TensorDict +] # None? # leaves space for TensorDictBase _STR_MIXED_INDEX_ERROR = "Received a mixed string-non string index. Only string-only or string-free indices are supported." -class _TensorDict(Mapping, metaclass=abc.ABCMeta): +class TensorDictBase(Mapping, metaclass=abc.ABCMeta): """ - _TensorDict is an abstract parent class for TensorDicts, the torchrl + TensorDictBase is an abstract parent class for TensorDicts, the torchrl data container. """ @@ -94,7 +94,7 @@ def _make_meta(self, key: str) -> MetaTensor: @property def shape(self) -> torch.Size: - """See _TensorDict.batch_size""" + """See TensorDictBase.batch_size""" return self.batch_size @property @@ -233,7 +233,7 @@ def _check_device(self) -> None: def set( self, key: str, item: COMPATIBLE_TYPES, inplace: bool = False, **kwargs - ) -> _TensorDict: + ) -> TensorDictBase: """Sets a new key-value pair. Args: @@ -252,7 +252,7 @@ def set( @abc.abstractmethod def set_( self, key: str, item: COMPATIBLE_TYPES, no_check: bool = False - ) -> _TensorDict: + ) -> TensorDictBase: """Sets a value to an existing key while keeping the original storage. Args: @@ -273,7 +273,7 @@ def _stack_onto_( key: str, list_item: List[COMPATIBLE_TYPES], dim: int, - ) -> _TensorDict: + ) -> TensorDictBase: """Stacks a list of values onto an existing key while keeping the original storage. Args: @@ -295,7 +295,7 @@ def _stack_onto_at_( list_item: List[COMPATIBLE_TYPES], dim: int, idx: INDEX_TYPING, - ) -> _TensorDict: + ) -> TensorDictBase: """Similar to _stack_onto_ but on a specific index. Only works with regular TensorDicts.""" raise RuntimeError( f"Cannot call _stack_onto_at_ with {self.__class__.__name__}. " @@ -345,7 +345,7 @@ def _get_meta(self, key) -> MetaTensor: f" {sorted(list(self.keys()))}" ) - def apply_(self, fn: Callable) -> _TensorDict: + def apply_(self, fn: Callable) -> TensorDictBase: """Applies a callable to all values stored in the tensordict and re-writes them in-place. @@ -365,7 +365,7 @@ def apply_(self, fn: Callable) -> _TensorDict: def apply( self, fn: Callable, batch_size: Optional[Sequence[int]] = None - ) -> _TensorDict: + ) -> TensorDictBase: """Applies a callable to all values stored in the tensordict and sets them in a new tensordict. @@ -394,16 +394,16 @@ def apply( def update( self, - input_dict_or_td: Union[Dict[str, COMPATIBLE_TYPES], _TensorDict], + input_dict_or_td: Union[Dict[str, COMPATIBLE_TYPES], TensorDictBase], clone: bool = False, inplace: bool = False, **kwargs, - ) -> _TensorDict: + ) -> TensorDictBase: """Updates the TensorDict with values from either a dictionary or another TensorDict. Args: - input_dict_or_td (_TensorDict or dict): Does not keyword arguments + input_dict_or_td (TensorDictBase or dict): Does not keyword arguments (unlike `dict.update()`). clone (bool, optional): whether the tensors in the input ( tensor) dict should be cloned before being set. Default is @@ -433,9 +433,9 @@ def update( def update_( self, - input_dict_or_td: Union[Dict[str, COMPATIBLE_TYPES], _TensorDict], + input_dict_or_td: Union[Dict[str, COMPATIBLE_TYPES], TensorDictBase], clone: bool = False, - ) -> _TensorDict: + ) -> TensorDictBase: """Updates the TensorDict in-place with values from either a dictionary or another TensorDict. @@ -443,7 +443,7 @@ def update_( throw an error if the key is unknown to the TensorDict Args: - input_dict_or_td (_TensorDict or dict): Does not keyword + input_dict_or_td (TensorDictBase or dict): Does not keyword arguments (unlike `dict.update()`). clone (bool, optional): whether the tensors in the input ( tensor) dict should be cloned before being set. Default is @@ -469,10 +469,10 @@ def update_( def update_at_( self, - input_dict_or_td: Union[Dict[str, COMPATIBLE_TYPES], _TensorDict], + input_dict_or_td: Union[Dict[str, COMPATIBLE_TYPES], TensorDictBase], idx: INDEX_TYPING, clone: bool = False, - ) -> _TensorDict: + ) -> TensorDictBase: """Updates the TensorDict in-place at the specified index with values from either a dictionary or another TensorDict. @@ -480,7 +480,7 @@ def update_at_( key is unknown to the TensorDict. Args: - input_dict_or_td (_TensorDict or dict): Does not keyword arguments + input_dict_or_td (TensorDictBase or dict): Does not keyword arguments (unlike `dict.update()`). idx (int, torch.Tensor, iterable, slice): index of the tensordict where the update should occur. @@ -545,7 +545,7 @@ def _process_tensor( device = self.device tensor = tensor.to(device) elif self._device_safe() is None: - if isinstance(tensor, _TensorDict): + if isinstance(tensor, TensorDictBase): device = tensor._device_safe() if device is not None: self.device = device @@ -565,7 +565,7 @@ def _process_tensor( if check_tensor_shape and tensor.shape[: self.batch_dims] != self.batch_size: # if TensorDict, let's try to map it to the desired shape if ( - isinstance(tensor, _TensorDict) + isinstance(tensor, TensorDictBase) and tensor.batch_size[: self.batch_dims] != self.batch_size ): tensor = tensor.clone(recursive=False) @@ -579,14 +579,14 @@ def _process_tensor( # minimum ndimension is 1 if tensor.ndimension() == self.ndimension() and not isinstance( - tensor, _TensorDict + tensor, TensorDictBase ): tensor = tensor.unsqueeze(-1) return tensor @abc.abstractmethod - def pin_memory(self) -> _TensorDict: + def pin_memory(self) -> TensorDictBase: """Calls pin_memory() on the stored tensors.""" raise NotImplementedError(f"{self.__class__.__name__}") @@ -639,7 +639,7 @@ def keys(self) -> KeysView: raise NotImplementedError(f"{self.__class__.__name__}") - def expand(self, *shape: int) -> _TensorDict: + def expand(self, *shape: int) -> TensorDictBase: """Expands each tensors of the tensordict according to `tensor.expand(*shape, *tensor.shape)` @@ -652,7 +652,7 @@ def expand(self, *shape: int) -> _TensorDict: """ d = dict() for key, value in self.items(): - if isinstance(value, _TensorDict): + if isinstance(value, TensorDictBase): d[key] = value.expand(*shape) else: d[key] = value.expand(*shape, *value.shape) @@ -665,7 +665,7 @@ def expand(self, *shape: int) -> _TensorDict: def __bool__(self) -> bool: raise ValueError("Converting a tensordict to boolean value is not permitted") - def __ne__(self, other: object) -> _TensorDict: + def __ne__(self, other: object) -> TensorDictBase: """XOR operation over two tensordicts, for evey key. The two tensordicts must have the same key set. @@ -675,10 +675,10 @@ def __ne__(self, other: object) -> _TensorDict: """ - if not isinstance(other, _TensorDict): + if not isinstance(other, TensorDictBase): raise TypeError( f"TensorDict comparision requires both objects to be " - f"_TensorDict subclass, got {type(other)}" + f"TensorDictBase subclass, got {type(other)}" ) keys1 = set(self.keys()) keys2 = set(other.keys()) @@ -693,7 +693,7 @@ def __ne__(self, other: object) -> _TensorDict: batch_size=self.batch_size, source=d, device=self._device_safe() ) - def __eq__(self, other: object) -> _TensorDict: + def __eq__(self, other: object) -> TensorDictBase: """Compares two tensordicts against each other, for evey key. The two tensordicts must have the same key set. @@ -702,12 +702,12 @@ def __eq__(self, other: object) -> _TensorDict: tensors of the same shape as the original tensors. """ - if not isinstance(other, (_TensorDict, float, int)): + if not isinstance(other, (TensorDictBase, float, int)): raise TypeError( f"TensorDict comparision requires both objects to be " - f"_TensorDict subclass, got {type(other)}" + f"TensorDictBase subclass, got {type(other)}" ) - if not isinstance(other, _TensorDict): + if not isinstance(other, TensorDictBase): return TensorDict( {key: value == other for key, value in self.items()}, self.batch_size ) @@ -723,7 +723,7 @@ def __eq__(self, other: object) -> _TensorDict: ) @abc.abstractmethod - def del_(self, key: str) -> _TensorDict: + def del_(self, key: str) -> TensorDictBase: """Deletes a key of the tensordict. Args: @@ -736,7 +736,7 @@ def del_(self, key: str) -> _TensorDict: raise NotImplementedError(f"{self.__class__.__name__}") @abc.abstractmethod - def select(self, *keys: str, inplace: bool = False) -> _TensorDict: + def select(self, *keys: str, inplace: bool = False) -> TensorDictBase: """Selects the keys of the tensordict and returns an new tensordict with only the selected keys. @@ -755,14 +755,14 @@ def select(self, *keys: str, inplace: bool = False) -> _TensorDict: """ raise NotImplementedError(f"{self.__class__.__name__}") - def exclude(self, *keys: str, inplace: bool = False) -> _TensorDict: + def exclude(self, *keys: str, inplace: bool = False) -> TensorDictBase: keys = [key for key in self.keys() if key not in keys] return self.select(*keys, inplace=inplace) @abc.abstractmethod def set_at_( self, key: str, value: COMPATIBLE_TYPES, idx: INDEX_TYPING - ) -> _TensorDict: + ) -> TensorDictBase: """Sets the values in-place at the index indicated by `idx`. Args: @@ -776,12 +776,12 @@ def set_at_( """ raise NotImplementedError(f"{self.__class__.__name__}") - def copy_(self, tensordict: _TensorDict) -> _TensorDict: - """See `_TensorDict.update_`.""" + def copy_(self, tensordict: TensorDictBase) -> TensorDictBase: + """See `TensorDictBase.update_`.""" return self.update_(tensordict) - def copy_at_(self, tensordict: _TensorDict, idx: INDEX_TYPING) -> _TensorDict: - """See `_TensorDict.update_at_`.""" + def copy_at_(self, tensordict: TensorDictBase, idx: INDEX_TYPING) -> TensorDictBase: + """See `TensorDictBase.update_at_`.""" return self.update_at_(tensordict, idx) def get_at( @@ -805,7 +805,7 @@ def get_at( return value @abc.abstractmethod - def share_memory_(self) -> _TensorDict: + def share_memory_(self) -> TensorDictBase: """Places all the tensors in shared memory. Returns: @@ -815,7 +815,7 @@ def share_memory_(self) -> _TensorDict: raise NotImplementedError(f"{self.__class__.__name__}") @abc.abstractmethod - def memmap_(self, prefix=None) -> _TensorDict: + def memmap_(self, prefix=None) -> TensorDictBase: """Writes all tensors onto a MemmapTensor. Args: @@ -830,7 +830,7 @@ def memmap_(self, prefix=None) -> _TensorDict: raise NotImplementedError(f"{self.__class__.__name__}") @abc.abstractmethod - def detach_(self) -> _TensorDict: + def detach_(self) -> TensorDictBase: """Detach the tensors in the tensordict in-place. Returns: @@ -839,7 +839,7 @@ def detach_(self) -> _TensorDict: """ raise NotImplementedError(f"{self.__class__.__name__}") - def detach(self) -> _TensorDict: + def detach(self) -> TensorDictBase: """Detach the tensors in the tensordict. Returns: @@ -854,7 +854,7 @@ def detach(self) -> _TensorDict: ) def to_tensordict(self): - """Returns a regular TensorDict instance from the _TensorDict. + """Returns a regular TensorDict instance from the TensorDictBase. Returns: a new TensorDict object containing the same values. @@ -868,13 +868,13 @@ def to_tensordict(self): _run_checks=False, ) - def zero_(self) -> _TensorDict: + def zero_(self) -> TensorDictBase: """Zeros all tensors in the tensordict in-place.""" for key in self.keys(): self.fill_(key, 0) return self - def unbind(self, dim: int) -> Tuple[_TensorDict, ...]: + def unbind(self, dim: int) -> Tuple[TensorDictBase, ...]: """Returns a tuple of indexed tensordicts unbound along the indicated dimension. Resulting tensordicts will share the storage of the initial tensordict. @@ -886,7 +886,7 @@ def unbind(self, dim: int) -> Tuple[_TensorDict, ...]: ] return tuple(self[_idx] for _idx in idx) - def chunk(self, chunks: int, dim: int = 0) -> Tuple[_TensorDict, ...]: + def chunk(self, chunks: int, dim: int = 0) -> Tuple[TensorDictBase, ...]: """Attempts to split a tendordict into the specified number of chunks. Each chunk is a view of the input tensordict. @@ -917,8 +917,8 @@ def chunk(self, chunks: int, dim: int = 0) -> Tuple[_TensorDict, ...]: dim = len(self.batch_size) + dim return tuple(self[(*[slice(None) for _ in range(dim)], idx)] for idx in indices) - def clone(self, recursive: bool = True) -> _TensorDict: - """Clones a _TensorDict subclass instance onto a new TensorDict. + def clone(self, recursive: bool = True) -> TensorDictBase: + """Clones a TensorDictBase subclass instance onto a new TensorDict. Args: recursive (bool, optional): if True, each tensor contained in the @@ -944,20 +944,22 @@ def __torch_function__( if kwargs is None: kwargs = {} if func not in TD_HANDLED_FUNCTIONS or not all( - issubclass(t, (torch.Tensor, _TensorDict)) for t in types + issubclass(t, (torch.Tensor, TensorDictBase)) for t in types ): return NotImplemented return TD_HANDLED_FUNCTIONS[func](*args, **kwargs) @abc.abstractmethod - def to(self, dest: Union[DEVICE_TYPING, Type, torch.Size], **kwargs) -> _TensorDict: - """Maps a _TensorDict subclass either on a new device or to another - _TensorDict subclass (if permitted). Casting tensors to a new dtype + def to( + self, dest: Union[DEVICE_TYPING, Type, torch.Size], **kwargs + ) -> TensorDictBase: + """Maps a TensorDictBase subclass either on a new device or to another + TensorDictBase subclass (if permitted). Casting tensors to a new dtype is not allowed, as tensordicts are not bound to contain a single tensor dtype. Args: - dest (device, size or _TensorDict subclass): destination of the + dest (device, size or TensorDictBase subclass): destination of the tensordict. If it is a torch.Size object, the batch_size will be updated provided that it is compatible with the stored tensors. @@ -988,18 +990,18 @@ def _check_new_batch_size(self, new_size: torch.Size): def _change_batch_size(self, new_size: torch.Size): raise NotImplementedError - def cpu(self) -> _TensorDict: + def cpu(self) -> TensorDictBase: """Casts a tensordict to cpu (if not already on cpu).""" return self.to("cpu") - def cuda(self, device: int = 0) -> _TensorDict: + def cuda(self, device: int = 0) -> TensorDictBase: """Casts a tensordict to a cuda device (if not already on it).""" return self.to(f"cuda:{device}") @abc.abstractmethod def masked_fill_( self, mask: torch.Tensor, value: Union[float, bool] - ) -> _TensorDict: + ) -> TensorDictBase: """Fills the values corresponding to the mask with the desired value. Args: @@ -1023,7 +1025,9 @@ def masked_fill_( raise NotImplementedError @abc.abstractmethod - def masked_fill(self, mask: torch.Tensor, value: Union[float, bool]) -> _TensorDict: + def masked_fill( + self, mask: torch.Tensor, value: Union[float, bool] + ) -> TensorDictBase: """Out-of-place version of masked_fill Args: @@ -1046,7 +1050,7 @@ def masked_fill(self, mask: torch.Tensor, value: Union[float, bool]) -> _TensorD """ raise NotImplementedError - def masked_select(self, mask: torch.Tensor) -> _TensorDict: + def masked_select(self, mask: torch.Tensor) -> TensorDictBase: """Masks all tensors of the TensorDict and return a new TensorDict instance with similar keys pointing to masked values. @@ -1087,7 +1091,7 @@ def is_contiguous(self) -> bool: raise NotImplementedError @abc.abstractmethod - def contiguous(self) -> _TensorDict: + def contiguous(self) -> TensorDictBase: """ Returns: @@ -1107,7 +1111,7 @@ def to_dict(self) -> dict: """ return {key: value for key, value in self.items()} - def unsqueeze(self, dim: int) -> _TensorDict: + def unsqueeze(self, dim: int) -> TensorDictBase: """Unsqueeze all tensors for a dimension comprised in between `-td.batch_dims` and `td.batch_dims` and returns them in a new tensordict. @@ -1133,7 +1137,7 @@ def unsqueeze(self, dim: int) -> _TensorDict: inv_op_kwargs={"dim": dim}, ) - def squeeze(self, dim: int) -> _TensorDict: + def squeeze(self, dim: int) -> TensorDictBase: """Squeezes all tensors for a dimension comprised in between `-td.batch_dims+1` and `td.batch_dims-1` and returns them in a new tensordict. @@ -1166,7 +1170,7 @@ def reshape( self, *shape: int, size: Optional[Union[List, Tuple, torch.Size]] = None, - ) -> _TensorDict: + ) -> TensorDictBase: """Returns a contiguous, reshaped tensor of the desired shape. Args: @@ -1201,7 +1205,7 @@ def view( self, *shape: int, size: Optional[Union[List, Tuple, torch.Size]] = None, - ) -> _TensorDict: + ) -> TensorDictBase: """Returns a tensordict with views of the tensors according to a new shape, compatible with the tensordict batch_size. @@ -1241,7 +1245,7 @@ def permute( self, *dims_list: int, dims=None, - ) -> _TensorDict: + ) -> TensorDictBase: """Returns a view of a tensordict with the batch dimensions permuted according to dims Args: @@ -1323,7 +1327,7 @@ def __repr__(self) -> str: string = ",\n".join([field_str, batch_size_str, device_str, is_shared_str]) return f"{type(self).__name__}(\n{string})" - def all(self, dim: int = None) -> Union[bool, _TensorDict]: + def all(self, dim: int = None) -> Union[bool, TensorDictBase]: """Checks if all values are True/non-null in the tensordict. Args: @@ -1349,7 +1353,7 @@ def all(self, dim: int = None) -> Union[bool, _TensorDict]: ) return all(value.all() for value in self.values()) - def any(self, dim: int = None) -> Union[bool, _TensorDict]: + def any(self, dim: int = None) -> Union[bool, TensorDictBase]: """Checks if any value is True/non-null in the tensordict. Args: @@ -1375,7 +1379,7 @@ def any(self, dim: int = None) -> Union[bool, _TensorDict]: ) return any([value.any() for key, value in self.items()]) - def get_sub_tensordict(self, idx: INDEX_TYPING) -> _TensorDict: + def get_sub_tensordict(self, idx: INDEX_TYPING) -> TensorDictBase: """Returns a SubTensorDict with the desired index.""" sub_td = SubTensorDict( source=self, @@ -1390,7 +1394,9 @@ def __iter__(self) -> Generator: for i in range(length): yield self[i] - def flatten_keys(self, separator: str = ",", inplace: bool = True) -> _TensorDict: + def flatten_keys( + self, separator: str = ",", inplace: bool = True + ) -> TensorDictBase: to_flatten = [] for key, meta_value in self.items_meta(): if meta_value.is_tensordict(): @@ -1421,7 +1427,9 @@ def flatten_keys(self, separator: str = ",", inplace: bool = True) -> _TensorDic tensordict_out.set(key, value) return tensordict_out - def unflatten_keys(self, separator: str = ",", inplace: bool = True) -> _TensorDict: + def unflatten_keys( + self, separator: str = ",", inplace: bool = True + ) -> TensorDictBase: to_unflatten = defaultdict(lambda: list()) for key in self.keys(): if separator in key[1:-1]: @@ -1460,7 +1468,7 @@ def __len__(self) -> int: """ return self.shape[0] if self.batch_dims else 0 - def __getitem__(self, idx: INDEX_TYPING) -> _TensorDict: + def __getitem__(self, idx: INDEX_TYPING) -> TensorDictBase: """Indexes all tensors according to idx and returns a new tensordict where the values share the storage of the original tensors (even when the index is a torch.Tensor). Any in-place modification to the @@ -1536,7 +1544,7 @@ def __getitem__(self, idx: INDEX_TYPING) -> _TensorDict: # in all cases not accounted for above return self.get_sub_tensordict(idx) - def __setitem__(self, index: INDEX_TYPING, value: _TensorDict) -> None: + def __setitem__(self, index: INDEX_TYPING, value: TensorDictBase) -> None: if index is Ellipsis or (isinstance(index, tuple) and Ellipsis in index): index = convert_ellipsis_to_idx(index, self.batch_size) if isinstance(index, list): @@ -1592,13 +1600,15 @@ def __setitem__(self, index: INDEX_TYPING, value: _TensorDict) -> None: else: subtd.set(key, item) - def __delitem__(self, index: INDEX_TYPING) -> _TensorDict: + def __delitem__(self, index: INDEX_TYPING) -> TensorDictBase: if isinstance(index, str): return self.del_(index) raise IndexError(f"Index has to a string but received {index}.") @abc.abstractmethod - def rename_key(self, old_key: str, new_key: str, safe: bool = False) -> _TensorDict: + def rename_key( + self, old_key: str, new_key: str, safe: bool = False + ) -> TensorDictBase: """Renames a key with a new string. Args: @@ -1613,7 +1623,7 @@ def rename_key(self, old_key: str, new_key: str, safe: bool = False) -> _TensorD """ raise NotImplementedError - def fill_(self, key: str, value: Union[float, bool]) -> _TensorDict: + def fill_(self, key: str, value: Union[float, bool]) -> TensorDictBase: """Fills a tensor pointed by the key with the a given value. Args: @@ -1637,7 +1647,7 @@ def fill_(self, key: str, value: Union[float, bool]) -> _TensorDict: self.set_(key, tensor) return self - def empty(self) -> _TensorDict: + def empty(self) -> TensorDictBase: """Returns a new, empty tensordict with the same device and batch size.""" return self.select() @@ -1657,7 +1667,7 @@ def is_locked(self, value: bool): self._is_locked = value -class TensorDict(_TensorDict): +class TensorDict(TensorDictBase): """A batched dictionary of tensors. TensorDict is a tensor container where all tensors are stored in a @@ -1746,11 +1756,11 @@ def __new__(cls, *args, **kwargs): cls._lazy = False cls._is_shared = None cls._is_memmap = None - return _TensorDict.__new__(cls) + return TensorDictBase.__new__(cls) def __init__( self, - source: Union[_TensorDict, dict], + source: Union[TensorDictBase, dict], batch_size: Optional[Union[Sequence[int], torch.Size, int]] = None, device: Optional[DEVICE_TYPING] = None, _meta_source: Optional[dict] = None, @@ -1765,9 +1775,9 @@ def __init__( self._is_shared = _is_shared self._is_memmap = _is_memmap - if not isinstance(source, (_TensorDict, dict)): + if not isinstance(source, (TensorDictBase, dict)): raise ValueError( - "A TensorDict source is expected to be a _TensorDict " + "A TensorDict source is expected to be a TensorDictBase " f"sub-type or a dictionary, found type(source)={type(source)}." ) if isinstance( @@ -1784,7 +1794,7 @@ def __init__( batch_size = torch.Size(batch_size) self._batch_size = batch_size - elif isinstance(source, _TensorDict): + elif isinstance(source, TensorDictBase): self._batch_size = source.batch_size else: raise ValueError( @@ -1792,7 +1802,7 @@ def __init__( "instance and it could not be retrieved from source." ) - if isinstance(source, _TensorDict) and device is None: + if isinstance(source, TensorDictBase) and device is None: device = source._device_safe() elif device is not None: device = torch.device(device) @@ -1810,7 +1820,7 @@ def __init__( else _meta_source[key] ) if ( - isinstance(value, _TensorDict) + isinstance(value, TensorDictBase) and value.batch_size[: self.batch_dims] != self.batch_size ): value.batch_size = self.batch_size @@ -1834,7 +1844,7 @@ def _make_meta(self, key: str) -> MetaTensor: proc_value, _is_memmap=is_memmap, _is_shared=is_shared, - _is_tensordict=isinstance(proc_value, _TensorDict), + _is_tensordict=isinstance(proc_value, TensorDictBase), ) @property @@ -1921,23 +1931,23 @@ def _check_device(self) -> None: if len(devices) > 1: raise RuntimeError(f"Found more than one device: {devices}") - def pin_memory(self) -> _TensorDict: + def pin_memory(self) -> TensorDictBase: if self.device == torch.device("cpu"): for key, value in self.items(): - if isinstance(value, _TensorDict) or ( + if isinstance(value, TensorDictBase) or ( value.dtype in (torch.half, torch.float, torch.double) ): self.set(key, value.pin_memory(), inplace=False) return self - def expand(self, *shape: int) -> _TensorDict: + def expand(self, *shape: int) -> TensorDictBase: """Expands every tensor with `(*shape, *tensor.shape)` and returns the same tensordict with new tensors with expanded shapes. """ _batch_size = torch.Size([*shape, *self.batch_size]) d = dict() for key, value in self.items(): - if isinstance(value, _TensorDict): + if isinstance(value, TensorDictBase): d[key] = value.expand(*shape) else: d[key] = value.expand(*shape, *value.shape) @@ -1950,7 +1960,7 @@ def set( inplace: bool = False, _run_checks: bool = True, _meta_val: Optional[MetaTensor] = None, - ) -> _TensorDict: + ) -> TensorDictBase: """Sets a value in the TensorDict. If inplace=True (default is False), and if the key already exists, set will call set_ (in place setting). """ @@ -1987,13 +1997,15 @@ def set( del self._dict_meta[key] return self - def del_(self, key: str) -> _TensorDict: + def del_(self, key: str) -> TensorDictBase: del self._tensordict[key] if key in self._dict_meta: del self._dict_meta[key] return self - def rename_key(self, old_key: str, new_key: str, safe: bool = False) -> _TensorDict: + def rename_key( + self, old_key: str, new_key: str, safe: bool = False + ) -> TensorDictBase: if not isinstance(old_key, str): raise TypeError( f"Expected old_name to be a string but found {type(old_key)}" @@ -2016,7 +2028,7 @@ def rename_key(self, old_key: str, new_key: str, safe: bool = False) -> _TensorD def set_( self, key: str, value: COMPATIBLE_TYPES, no_check: bool = False - ) -> _TensorDict: + ) -> TensorDictBase: if not no_check: if self.is_locked: raise RuntimeError("Cannot modify immutable TensorDict") @@ -2079,7 +2091,7 @@ def _stack_onto_at_( def set_at_( self, key: str, value: COMPATIBLE_TYPES, idx: INDEX_TYPING - ) -> _TensorDict: + ) -> TensorDictBase: if self.is_locked: raise RuntimeError("Cannot modify immutable TensorDict") if not isinstance(key, str): @@ -2118,7 +2130,7 @@ def get( else: return self._default_get(key, default) - def share_memory_(self) -> _TensorDict: + def share_memory_(self) -> TensorDictBase: if self.is_memmap(): raise RuntimeError( "memmap and shared memory are mutually exclusive features." @@ -2138,12 +2150,12 @@ def share_memory_(self) -> _TensorDict: self._is_shared = True return self - def detach_(self) -> _TensorDict: + def detach_(self) -> TensorDictBase: for key, value in self.items(): value.detach_() return self - def memmap_(self, prefix=None) -> _TensorDict: + def memmap_(self, prefix=None) -> TensorDictBase: if self.is_shared() and self.device == torch.device("cpu"): raise RuntimeError( "memmap and shared memory are mutually exclusive features." @@ -2164,8 +2176,10 @@ def memmap_(self, prefix=None) -> _TensorDict: self._is_memmap = True return self - def to(self, dest: Union[DEVICE_TYPING, torch.Size, Type], **kwargs) -> _TensorDict: - if isinstance(dest, type) and issubclass(dest, _TensorDict): + def to( + self, dest: Union[DEVICE_TYPING, torch.Size, Type], **kwargs + ) -> TensorDictBase: + if isinstance(dest, type) and issubclass(dest, TensorDictBase): if isinstance(self, dest): return self td = dest( @@ -2200,25 +2214,27 @@ def to(self, dest: Union[DEVICE_TYPING, torch.Size, Type], **kwargs) -> _TensorD def masked_fill_( self, mask: torch.Tensor, value: Union[float, int, bool] - ) -> _TensorDict: + ) -> TensorDictBase: for key, item in self.items(): mask_expand = expand_as_right(mask, item) item.masked_fill_(mask_expand, value) return self - def masked_fill(self, mask: torch.Tensor, value: Union[float, bool]) -> _TensorDict: + def masked_fill( + self, mask: torch.Tensor, value: Union[float, bool] + ) -> TensorDictBase: td_copy = self.clone() return td_copy.masked_fill_(mask, value) def is_contiguous(self) -> bool: return all([value.is_contiguous() for _, value in self.items()]) - def contiguous(self) -> _TensorDict: + def contiguous(self) -> TensorDictBase: if not self.is_contiguous(): return self.clone() return self - def select(self, *keys: str, inplace: bool = False) -> _TensorDict: + def select(self, *keys: str, inplace: bool = False) -> TensorDictBase: d = {key: value for (key, value) in self.items() if key in keys} d_meta = { key: value @@ -2258,14 +2274,16 @@ def decorator(func): # @implements_for_td(torch.testing.assert_allclose) TODO def assert_allclose_td( - actual: _TensorDict, - expected: _TensorDict, + actual: TensorDictBase, + expected: TensorDictBase, rtol: float = None, atol: float = None, equal_nan: bool = True, msg: str = "", ) -> bool: - if not isinstance(actual, _TensorDict) or not isinstance(expected, _TensorDict): + if not isinstance(actual, TensorDictBase) or not isinstance( + expected, TensorDictBase + ): raise TypeError("assert_allclose inputs must be of TensorDict type") set1 = set(actual.keys()) set2 = set(expected.keys()) @@ -2279,7 +2297,7 @@ def assert_allclose_td( for key in keys: input1 = actual.get(key) input2 = expected.get(key) - if isinstance(input1, _TensorDict): + if isinstance(input1, TensorDictBase): assert_allclose_td(input1, input2, rtol=rtol, atol=atol) continue @@ -2302,42 +2320,42 @@ def assert_allclose_td( @implements_for_td(torch.unbind) -def unbind(td: _TensorDict, *args, **kwargs) -> Tuple[_TensorDict, ...]: +def unbind(td: TensorDictBase, *args, **kwargs) -> Tuple[TensorDictBase, ...]: return td.unbind(*args, **kwargs) @implements_for_td(torch.clone) -def clone(td: _TensorDict, *args, **kwargs) -> _TensorDict: +def clone(td: TensorDictBase, *args, **kwargs) -> TensorDictBase: return td.clone(*args, **kwargs) @implements_for_td(torch.squeeze) -def squeeze(td: _TensorDict, *args, **kwargs) -> _TensorDict: +def squeeze(td: TensorDictBase, *args, **kwargs) -> TensorDictBase: return td.squeeze(*args, **kwargs) @implements_for_td(torch.unsqueeze) -def unsqueeze(td: _TensorDict, *args, **kwargs) -> _TensorDict: +def unsqueeze(td: TensorDictBase, *args, **kwargs) -> TensorDictBase: return td.unsqueeze(*args, **kwargs) @implements_for_td(torch.masked_select) -def masked_select(td: _TensorDict, *args, **kwargs) -> _TensorDict: +def masked_select(td: TensorDictBase, *args, **kwargs) -> TensorDictBase: return td.masked_select(*args, **kwargs) @implements_for_td(torch.permute) -def permute(td: _TensorDict, dims) -> _TensorDict: +def permute(td: TensorDictBase, dims) -> TensorDictBase: return td.permute(*dims) @implements_for_td(torch.cat) def cat( - list_of_tensordicts: Sequence[_TensorDict], + list_of_tensordicts: Sequence[TensorDictBase], dim: int = 0, device: DEVICE_TYPING = None, - out: _TensorDict = None, -) -> _TensorDict: + out: TensorDictBase = None, +) -> TensorDictBase: if not list_of_tensordicts: raise RuntimeError("list_of_tensordicts cannot be empty") if not device: @@ -2380,12 +2398,12 @@ def cat( @implements_for_td(torch.stack) def stack( - list_of_tensordicts: Sequence[_TensorDict], + list_of_tensordicts: Sequence[TensorDictBase], dim: int = 0, - out: _TensorDict = None, + out: TensorDictBase = None, strict=False, contiguous=False, -) -> _TensorDict: +) -> TensorDictBase: if not list_of_tensordicts: raise RuntimeError("list_of_tensordicts cannot be empty") batch_size = list_of_tensordicts[0].batch_size @@ -2473,7 +2491,7 @@ def stack( return out -def pad(tensordict: _TensorDict, pad_size: Sequence[int], value: float = 0.0): +def pad(tensordict: TensorDictBase, pad_size: Sequence[int], value: float = 0.0): """Pads all tensors in a tensordict along the batch dimensions with a constant value, returning a new tensordict @@ -2528,7 +2546,7 @@ def pad(tensordict: _TensorDict, pad_size: Sequence[int], value: float = 0.0): if len(pad_size) < len(tensor.shape) * 2: cur_pad = [0] * (len(tensor.shape) * 2 - len(pad_size)) + reverse_pad - if isinstance(tensor, _TensorDict): + if isinstance(tensor, TensorDictBase): padded = pad(tensor, pad_size, value) else: padded = torch.nn.functional.pad(tensor, cur_pad, value=value) @@ -2539,10 +2557,10 @@ def pad(tensordict: _TensorDict, pad_size: Sequence[int], value: float = 0.0): # @implements_for_td(torch.nn.utils.rnn.pad_sequence) def pad_sequence_td( - list_of_tensordicts: Sequence[_TensorDict], + list_of_tensordicts: Sequence[TensorDictBase], batch_first: bool = True, padding_value: float = 0.0, - out: _TensorDict = None, + out: TensorDictBase = None, device: Optional[DEVICE_TYPING] = None, ): if not list_of_tensordicts: @@ -2574,7 +2592,7 @@ def pad_sequence_td( return out -class SubTensorDict(_TensorDict): +class SubTensorDict(TensorDictBase): """ A TensorDict that only sees an index of the stored tensors. @@ -2626,7 +2644,7 @@ class SubTensorDict(_TensorDict): def __init__( self, - source: _TensorDict, + source: TensorDictBase, idx: INDEX_TYPING, batch_size: Optional[Sequence[int]] = None, ): @@ -2634,9 +2652,9 @@ def __init__( self._is_shared = None self._is_memmap = None - if not isinstance(source, _TensorDict): + if not isinstance(source, TensorDictBase): raise TypeError( - f"Expected source to be a subclass of _TensorDict, " + f"Expected source to be a subclass of TensorDictBase, " f"got {type(source)}" ) self._source = source @@ -2672,7 +2690,7 @@ def device(self, value: DEVICE_TYPING) -> None: def _device_safe(self) -> Union[None, torch.device]: return self._source._device_safe() - def _preallocate(self, key: str, value: COMPATIBLE_TYPES) -> _TensorDict: + def _preallocate(self, key: str, value: COMPATIBLE_TYPES) -> TensorDictBase: return self._source.set(key, value) def set( @@ -2681,11 +2699,11 @@ def set( tensor: COMPATIBLE_TYPES, inplace: bool = False, _run_checks: bool = True, - ) -> _TensorDict: + ) -> TensorDictBase: if self.is_locked: raise RuntimeError("Cannot modify immutable TensorDict") keys = set(self.keys()) - if isinstance(tensor, _TensorDict) and tensor.batch_size != self.batch_size: + if isinstance(tensor, TensorDictBase) and tensor.batch_size != self.batch_size: tensor.batch_size = self.batch_size if inplace and key in keys: return self.set_(key, tensor) @@ -2700,7 +2718,7 @@ def set( ) parent = self.get_parent_tensordict() - if isinstance(tensor, _TensorDict): + if isinstance(tensor, TensorDictBase): tensor_expand = TensorDict( { key: _expand_to_match_shape( @@ -2756,8 +2774,10 @@ def _stack_onto_( self._source._stack_onto_at_(key, list_item, dim=dim, idx=self.idx) return self - def to(self, dest: Union[DEVICE_TYPING, torch.Size, Type], **kwargs) -> _TensorDict: - if isinstance(dest, type) and issubclass(dest, _TensorDict): + def to( + self, dest: Union[DEVICE_TYPING, torch.Size, Type], **kwargs + ) -> TensorDictBase: + if isinstance(dest, type) and issubclass(dest, TensorDictBase): if isinstance(self, dest): return self return dest( @@ -2837,7 +2857,7 @@ def get_at( def update_( self, - input_dict: Union[Dict[str, COMPATIBLE_TYPES], _TensorDict], + input_dict: Union[Dict[str, COMPATIBLE_TYPES], TensorDictBase], clone: bool = False, ) -> SubTensorDict: return self.update_at_( @@ -2846,7 +2866,7 @@ def update_( def update_at_( self, - input_dict: Union[Dict[str, COMPATIBLE_TYPES], _TensorDict], + input_dict: Union[Dict[str, COMPATIBLE_TYPES], TensorDictBase], idx: INDEX_TYPING, discard_idx_attr: bool = False, clone: bool = False, @@ -2867,8 +2887,8 @@ def update_at_( ) return self - def get_parent_tensordict(self) -> _TensorDict: - if not isinstance(self._source, _TensorDict): + def get_parent_tensordict(self) -> TensorDictBase: + if not isinstance(self._source, TensorDictBase): raise TypeError( f"SubTensorDict was initialized with a source of type" f" {self._source.__class__.__name__}, " @@ -2876,7 +2896,7 @@ def get_parent_tensordict(self) -> _TensorDict: ) return self._source - def del_(self, key: str) -> _TensorDict: + def del_(self, key: str) -> TensorDictBase: self._source = self._source.del_(key) return self @@ -2891,7 +2911,7 @@ def clone(self, recursive: bool = True) -> SubTensorDict: def is_contiguous(self) -> bool: return all([value.is_contiguous() for _, value in self.items()]) - def contiguous(self) -> _TensorDict: + def contiguous(self) -> TensorDictBase: if self.is_contiguous(): return self return TensorDict( @@ -2900,13 +2920,13 @@ def contiguous(self) -> _TensorDict: device=self._device_safe(), ) - def select(self, *keys: str, inplace: bool = False) -> _TensorDict: + def select(self, *keys: str, inplace: bool = False) -> TensorDictBase: if inplace: self._source = self._source.select(*keys) return self return self._source.select(*keys)[self.idx] - def expand(self, *shape: int, inplace: bool = False) -> _TensorDict: + def expand(self, *shape: int, inplace: bool = False) -> TensorDictBase: new_source = self._source.expand(*shape) idx = tuple(slice(None) for _ in shape) + tuple(self.idx) if inplace: @@ -2926,36 +2946,38 @@ def rename_key( self._source.rename_key(old_key, new_key, safe=safe) return self - def pin_memory(self) -> _TensorDict: + def pin_memory(self) -> TensorDictBase: self._source.pin_memory() return self - def detach_(self) -> _TensorDict: + def detach_(self) -> TensorDictBase: raise RuntimeError("Detaching a sub-tensordict in-place cannot be done.") def masked_fill_( self, mask: torch.Tensor, value: Union[float, bool] - ) -> _TensorDict: + ) -> TensorDictBase: for key, item in self.items(): self.set_(key, torch.full_like(item, value)) return self - def masked_fill(self, mask: torch.Tensor, value: Union[float, bool]) -> _TensorDict: + def masked_fill( + self, mask: torch.Tensor, value: Union[float, bool] + ) -> TensorDictBase: td_copy = self.clone() return td_copy.masked_fill_(mask, value) - def memmap_(self, prefix=None) -> _TensorDict: + def memmap_(self, prefix=None) -> TensorDictBase: raise RuntimeError( "Converting a sub-tensordict values to memmap cannot be done." ) - def share_memory_(self) -> _TensorDict: + def share_memory_(self) -> TensorDictBase: raise RuntimeError( "Casting a sub-tensordict values to shared memory cannot be done." ) -def merge_tensordicts(*tensordicts: _TensorDict) -> _TensorDict: +def merge_tensordicts(*tensordicts: TensorDictBase) -> TensorDictBase: if len(tensordicts) < 2: raise RuntimeError( f"at least 2 tensordicts must be provided, got" f" {len(tensordicts)}" @@ -2966,7 +2988,7 @@ def merge_tensordicts(*tensordicts: _TensorDict) -> _TensorDict: return TensorDict({}, [], device=td._device_safe()).update(d) -class LazyStackedTensorDict(_TensorDict): +class LazyStackedTensorDict(TensorDictBase): """A Lazy stack of TensorDicts. When stacking TensorDicts together, the default behaviour is to put them @@ -2999,7 +3021,7 @@ class LazyStackedTensorDict(_TensorDict): def __init__( self, - *tensordicts: _TensorDict, + *tensordicts: TensorDictBase, stack_dim: int = 0, batch_size: Optional[Sequence[int]] = None, # TODO: remove ): @@ -3015,9 +3037,9 @@ def __init__( "at least one tensordict must be provided to " "StackedTensorDict to be instantiated" ) - if not isinstance(tensordicts[0], _TensorDict): + if not isinstance(tensordicts[0], TensorDictBase): raise TypeError( - f"Expected input to be _TensorDict instance" + f"Expected input to be TensorDictBase instance" f" but got {type(tensordicts[0])} instead." ) if stack_dim < 0: @@ -3028,9 +3050,9 @@ def __init__( device = tensordicts[0]._device_safe() for i, td in enumerate(tensordicts[1:]): - if not isinstance(td, _TensorDict): + if not isinstance(td, TensorDictBase): raise TypeError( - f"Expected input to be _TensorDict instance" + f"Expected input to be TensorDictBase instance" f" but got {type(tensordicts[0])} instead." ) _bs = td.batch_size @@ -3043,7 +3065,7 @@ def __init__( f"cannot be created. Got td[0].batch_size={_batch_size} " f"and td[i].batch_size={_bs} " ) - self.tensordicts: List[_TensorDict] = list(tensordicts) + self.tensordicts: List[TensorDictBase] = list(tensordicts) self.stack_dim = stack_dim self._batch_size = self._compute_batch_size(_batch_size, stack_dim, N) self._update_valid_keys() @@ -3118,10 +3140,10 @@ def _compute_batch_size( s.insert(stack_dim, N) return torch.Size(s) - def set(self, key: str, tensor: COMPATIBLE_TYPES, **kwargs) -> _TensorDict: + def set(self, key: str, tensor: COMPATIBLE_TYPES, **kwargs) -> TensorDictBase: if self.is_locked: raise RuntimeError("Cannot modify immutable TensorDict") - if isinstance(tensor, _TensorDict): + if isinstance(tensor, TensorDictBase): if tensor.batch_size[: self.batch_dims] != self.batch_size: tensor.batch_size = self.clone(recursive=False).batch_size if self.batch_size != tensor.shape[: self.batch_dims]: @@ -3144,11 +3166,11 @@ def set(self, key: str, tensor: COMPATIBLE_TYPES, **kwargs) -> _TensorDict: def set_( self, key: str, tensor: COMPATIBLE_TYPES, no_check: bool = False - ) -> _TensorDict: + ) -> TensorDictBase: if not no_check: if self.is_locked: raise RuntimeError("Cannot modify immutable TensorDict") - if isinstance(tensor, _TensorDict): + if isinstance(tensor, TensorDictBase): if tensor.batch_size[: self.batch_dims] != self.batch_size: tensor.batch_size = self.clone(recursive=False).batch_size if self.batch_size != tensor.shape[: self.batch_dims]: @@ -3178,7 +3200,7 @@ def set_( def set_at_( self, key: str, value: COMPATIBLE_TYPES, idx: INDEX_TYPING - ) -> _TensorDict: + ) -> TensorDictBase: if self.is_locked: raise RuntimeError("Cannot modify immutable TensorDict") sub_td = self[idx] @@ -3190,7 +3212,7 @@ def _stack_onto_( key: str, list_item: List[COMPATIBLE_TYPES], dim: int, - ) -> _TensorDict: + ) -> TensorDictBase: if dim == self.stack_dim: for source, tensordict_dest in zip(list_item, self.tensordicts): tensordict_dest.set_(key, source) @@ -3229,7 +3251,7 @@ def _make_meta(self, key: str) -> MetaTensor: def is_contiguous(self) -> bool: return False - def contiguous(self) -> _TensorDict: + def contiguous(self) -> TensorDictBase: source = {key: value for key, value in self.items()} batch_size = self.batch_size device = self._device_safe() @@ -3242,7 +3264,7 @@ def contiguous(self) -> _TensorDict: ) return out - def clone(self, recursive: bool = True) -> _TensorDict: + def clone(self, recursive: bool = True) -> TensorDictBase: if recursive: return LazyStackedTensorDict( *[td.clone() for td in self.tensordicts], @@ -3252,13 +3274,13 @@ def clone(self, recursive: bool = True) -> _TensorDict: *[td for td in self.tensordicts], stack_dim=self.stack_dim ) - def pin_memory(self) -> _TensorDict: + def pin_memory(self) -> TensorDictBase: for td in self.tensordicts: td.pin_memory() return self - def to(self, dest: Union[DEVICE_TYPING, Type], **kwargs) -> _TensorDict: - if isinstance(dest, type) and issubclass(dest, _TensorDict): + def to(self, dest: Union[DEVICE_TYPING, Type], **kwargs) -> TensorDictBase: + if isinstance(dest, type) and issubclass(dest, TensorDictBase): if isinstance(self, dest): return self kwargs.update({"batch_size": self.batch_size}) @@ -3306,7 +3328,7 @@ def _update_valid_keys(self) -> None: valid_keys = valid_keys.intersection(td.keys()) self._valid_keys = sorted(list(valid_keys)) - def select(self, *keys: str, inplace: bool = False) -> _TensorDict: + def select(self, *keys: str, inplace: bool = False) -> TensorDictBase: # if len(set(self.valid_keys).intersection(keys)) != len(keys): # raise KeyError( # f"Selected and existing keys mismatch, got self.valid_keys" @@ -3320,7 +3342,7 @@ def select(self, *keys: str, inplace: bool = False) -> _TensorDict: stack_dim=self.stack_dim, ) - def __setitem__(self, item: INDEX_TYPING, value: _TensorDict) -> _TensorDict: + def __setitem__(self, item: INDEX_TYPING, value: TensorDictBase) -> TensorDictBase: if isinstance(item, list): item = torch.tensor(item, device=self.device) if isinstance(item, tuple) and any( @@ -3345,7 +3367,7 @@ def __setitem__(self, item: INDEX_TYPING, value: _TensorDict) -> _TensorDict: ) return super().__setitem__(item, value) - def __getitem__(self, item: INDEX_TYPING) -> _TensorDict: + def __getitem__(self, item: INDEX_TYPING) -> TensorDictBase: if item is Ellipsis or (isinstance(item, tuple) and Ellipsis in item): item = convert_ellipsis_to_idx(item, self.batch_size) if isinstance(item, tuple) and sum( @@ -3415,7 +3437,7 @@ def __getitem__(self, item: INDEX_TYPING) -> _TensorDict: ) if len(_sub_item): tensordicts = self.tensordicts[_sub_item[0]] - if isinstance(tensordicts, _TensorDict): + if isinstance(tensordicts, TensorDictBase): return tensordicts else: tensordicts = self.tensordicts @@ -3435,30 +3457,30 @@ def __getitem__(self, item: INDEX_TYPING) -> _TensorDict: f"{item.__class__.__name__} is not supported yet" ) - def del_(self, key: str, **kwargs) -> _TensorDict: + def del_(self, key: str, **kwargs) -> TensorDictBase: for td in self.tensordicts: td.del_(key, **kwargs) self._valid_keys.remove(key) return self - def share_memory_(self) -> _TensorDict: + def share_memory_(self) -> TensorDictBase: for td in self.tensordicts: td.share_memory_() self._is_shared = True return self - def detach_(self) -> _TensorDict: + def detach_(self) -> TensorDictBase: for td in self.tensordicts: td.detach_() return self - def memmap_(self, prefix=None) -> _TensorDict: + def memmap_(self, prefix=None) -> TensorDictBase: for td in self.tensordicts: td.memmap_(prefix=prefix) self._is_memmap = True return self - def expand(self, *shape: int, inplace: bool = False) -> _TensorDict: + def expand(self, *shape: int, inplace: bool = False) -> TensorDictBase: stack_dim = self.stack_dim + len(shape) tensordicts = [td.expand(*shape) for td in self.tensordicts] if inplace: @@ -3468,8 +3490,8 @@ def expand(self, *shape: int, inplace: bool = False) -> _TensorDict: return torch.stack(tensordicts, stack_dim) def update( - self, input_dict_or_td: _TensorDict, clone: bool = False, **kwargs - ) -> _TensorDict: + self, input_dict_or_td: TensorDictBase, clone: bool = False, **kwargs + ) -> TensorDictBase: if input_dict_or_td is self: # no op return self @@ -3486,10 +3508,10 @@ def update( def update_( self, - input_dict_or_td: Union[Dict[str, COMPATIBLE_TYPES], _TensorDict], + input_dict_or_td: Union[Dict[str, COMPATIBLE_TYPES], TensorDictBase], clone: bool = False, **kwargs, - ) -> _TensorDict: + ) -> TensorDictBase: if input_dict_or_td is self: # no op return self @@ -3504,7 +3526,9 @@ def update_( self.set_(key, value, **kwargs) return self - def rename_key(self, old_key: str, new_key: str, safe: bool = False) -> _TensorDict: + def rename_key( + self, old_key: str, new_key: str, safe: bool = False + ) -> TensorDictBase: for td in self.tensordicts: td.rename_key(old_key, new_key, safe=safe) self._valid_keys = sorted( @@ -3514,30 +3538,32 @@ def rename_key(self, old_key: str, new_key: str, safe: bool = False) -> _TensorD def masked_fill_( self, mask: torch.Tensor, value: Union[float, bool] - ) -> _TensorDict: + ) -> TensorDictBase: mask_unbind = mask.unbind(dim=self.stack_dim) for _mask, td in zip(mask_unbind, self.tensordicts): td.masked_fill_(_mask, value) return self - def masked_fill(self, mask: torch.Tensor, value: Union[float, bool]) -> _TensorDict: + def masked_fill( + self, mask: torch.Tensor, value: Union[float, bool] + ) -> TensorDictBase: td_copy = self.clone() return td_copy.masked_fill_(mask, value) -class SavedTensorDict(_TensorDict): +class SavedTensorDict(TensorDictBase): _safe = False _lazy = False def __init__( self, - source: _TensorDict, + source: TensorDictBase, device: Optional[torch.device] = None, batch_size: Optional[Sequence[int]] = None, ): - if not isinstance(source, _TensorDict): + if not isinstance(source, TensorDictBase): raise TypeError( - f"Expected source to be a _TensorDict instance, but got {type(source)} instead." + f"Expected source to be a TensorDictBase instance, but got {type(source)} instead." ) elif isinstance(source, SavedTensorDict): source = source._load() @@ -3562,7 +3588,7 @@ def __init__( if batch_size is not None and batch_size != self.batch_size: raise RuntimeError("batch_size does not match self.batch_size.") - def _save(self, tensordict: _TensorDict) -> None: + def _save(self, tensordict: TensorDictBase) -> None: self._version = uuid.uuid1() self._keys = list(tensordict.keys()) self._batch_size = tensordict.batch_size @@ -3577,7 +3603,7 @@ def _make_meta(self, key: str) -> MetaTensor: ) return self._dict_meta["key"] - def _load(self) -> _TensorDict: + def _load(self) -> TensorDictBase: return torch.load(self.filename, map_location=self._device_safe()) @property @@ -3627,7 +3653,7 @@ def get( td = self._load() return td.get(key, default=default) - def set(self, key: str, value: COMPATIBLE_TYPES, **kwargs) -> _TensorDict: + def set(self, key: str, value: COMPATIBLE_TYPES, **kwargs) -> TensorDictBase: if self.is_locked: raise RuntimeError("Cannot modify immutable TensorDict") td = self._load() @@ -3635,7 +3661,7 @@ def set(self, key: str, value: COMPATIBLE_TYPES, **kwargs) -> _TensorDict: self._save(td) return self - def expand(self, *shape: int, inplace: bool = False) -> _TensorDict: + def expand(self, *shape: int, inplace: bool = False) -> TensorDictBase: td = self._load() td = td.expand(*shape) if inplace: @@ -3648,7 +3674,7 @@ def _stack_onto_( key: str, list_item: List[COMPATIBLE_TYPES], dim: int, - ) -> _TensorDict: + ) -> TensorDictBase: if self.is_locked: raise RuntimeError("Cannot modify immutable TensorDict") td = self._load() @@ -3658,7 +3684,7 @@ def _stack_onto_( def set_( self, key: str, value: COMPATIBLE_TYPES, no_check: bool = False - ) -> _TensorDict: + ) -> TensorDictBase: if not no_check and self.is_locked: raise RuntimeError("Cannot modify immutable TensorDict") self.set(key, value) @@ -3666,7 +3692,7 @@ def set_( def set_at_( self, key: str, value: COMPATIBLE_TYPES, idx: INDEX_TYPING - ) -> _TensorDict: + ) -> TensorDictBase: if self.is_locked: raise RuntimeError("Cannot modify immutable TensorDict") td = self._load() @@ -3676,10 +3702,10 @@ def set_at_( def update( self, - input_dict_or_td: Union[Dict[str, COMPATIBLE_TYPES], _TensorDict], + input_dict_or_td: Union[Dict[str, COMPATIBLE_TYPES], TensorDictBase], clone: bool = False, **kwargs, - ) -> _TensorDict: + ) -> TensorDictBase: if input_dict_or_td is self: # no op return self @@ -3698,9 +3724,9 @@ def update( def update_( self, - input_dict_or_td: Union[Dict[str, COMPATIBLE_TYPES], _TensorDict], + input_dict_or_td: Union[Dict[str, COMPATIBLE_TYPES], TensorDictBase], clone: bool = False, - ) -> _TensorDict: + ) -> TensorDictBase: if input_dict_or_td is self: return self return self.update(input_dict_or_td, clone=clone) @@ -3715,15 +3741,15 @@ def is_shared(self, no_check: bool = False) -> bool: def is_memmap(self, no_check: bool = False) -> bool: return False - def share_memory_(self) -> _TensorDict: + def share_memory_(self) -> TensorDictBase: raise RuntimeError("SavedTensorDict cannot be put in shared memory.") - def memmap_(self, prefix=None) -> _TensorDict: + def memmap_(self, prefix=None) -> TensorDictBase: raise RuntimeError( "SavedTensorDict and memmap are mutually exclusive features." ) - def detach_(self) -> _TensorDict: + def detach_(self) -> TensorDictBase: raise RuntimeError("SavedTensorDict cannot be put detached.") def items(self) -> Iterator[Tuple[str, COMPATIBLE_TYPES]]: @@ -3743,20 +3769,22 @@ def values(self) -> Iterator[COMPATIBLE_TYPES]: def is_contiguous(self) -> bool: return False - def contiguous(self) -> _TensorDict: + def contiguous(self) -> TensorDictBase: return self._load().contiguous() - def clone(self, recursive: bool = True) -> _TensorDict: + def clone(self, recursive: bool = True) -> TensorDictBase: return SavedTensorDict(self, device=self.device) - def select(self, *keys: str, inplace: bool = False) -> _TensorDict: + def select(self, *keys: str, inplace: bool = False) -> TensorDictBase: _source = self.contiguous().select(*keys) if inplace: self._save(_source) return self return SavedTensorDict(source=_source) - def rename_key(self, old_key: str, new_key: str, safe: bool = False) -> _TensorDict: + def rename_key( + self, old_key: str, new_key: str, safe: bool = False + ) -> TensorDictBase: td = self._load() td.rename_key(old_key, new_key, safe=safe) self._save(td) @@ -3769,7 +3797,7 @@ def __repr__(self) -> str: ) def to(self, dest: Union[DEVICE_TYPING, Type], **kwargs): - if isinstance(dest, type) and issubclass(dest, _TensorDict): + if isinstance(dest, type) and issubclass(dest, TensorDictBase): if isinstance(self, dest): return self kwargs.update({"batch_size": self.batch_size}) @@ -3807,13 +3835,13 @@ def _change_batch_size(self, new_size: torch.Size): del self._orig_batch_size self._batch_size = new_size - def del_(self, key: str) -> _TensorDict: + def del_(self, key: str) -> TensorDictBase: td = self._load() td = td.del_(key) self._save(td) return self - def pin_memory(self) -> _TensorDict: + def pin_memory(self) -> TensorDictBase: raise RuntimeError("pin_memory requires tensordicts that live in memory.") def __reduce__(self, *args, **kwargs): @@ -3825,7 +3853,7 @@ def __reduce__(self, *args, **kwargs): return super(SavedTensorDict, self_copy).__reduce__(*args, **kwargs) return super().__reduce__(*args, **kwargs) - def __getitem__(self, idx: INDEX_TYPING) -> _TensorDict: + def __getitem__(self, idx: INDEX_TYPING) -> TensorDictBase: if isinstance(idx, list): idx = torch.tensor(idx, device=self.device) if isinstance(idx, tuple) and any( @@ -3867,25 +3895,27 @@ def __getitem__(self, idx: INDEX_TYPING) -> _TensorDict: def masked_fill_( self, mask: torch.Tensor, value: Union[float, bool] - ) -> _TensorDict: + ) -> TensorDictBase: td = self._load() td.masked_fill_(mask, value) self._save(td) return self - def masked_fill(self, mask: torch.Tensor, value: Union[float, bool]) -> _TensorDict: + def masked_fill( + self, mask: torch.Tensor, value: Union[float, bool] + ) -> TensorDictBase: td_copy = self.clone() return td_copy.masked_fill_(mask, value) -class _CustomOpTensorDict(_TensorDict): +class _CustomOpTensorDict(TensorDictBase): """Encodes lazy operations on tensors contained in a TensorDict.""" _lazy = True def __init__( self, - source: _TensorDict, + source: TensorDictBase, custom_op: str, inv_op: Optional[str] = None, custom_op_kwargs: Optional[dict] = None, @@ -3897,9 +3927,9 @@ def __init__( self._is_shared = None self._is_memmap = None - if not isinstance(source, _TensorDict): + if not isinstance(source, TensorDictBase): raise TypeError( - f"Expected source to be a _TensorDict isntance, " + f"Expected source to be a TensorDictBase isntance, " f"but got {type(source)} instead." ) self._source = source @@ -4001,7 +4031,7 @@ def get( ) return self._default_get(key, default) - def set(self, key: str, value: COMPATIBLE_TYPES, **kwargs) -> _TensorDict: + def set(self, key: str, value: COMPATIBLE_TYPES, **kwargs) -> TensorDictBase: if self.inv_op is None: raise Exception( f"{self.__class__.__name__} does not support setting values. " @@ -4058,7 +4088,7 @@ def _stack_onto_( key: str, list_item: List[COMPATIBLE_TYPES], dim: int, - ) -> _TensorDict: + ) -> TensorDictBase: raise RuntimeError( f"stacking tensordicts is not allowed for type {type(self)}" f"consider calling 'to_tensordict()` first" @@ -4085,7 +4115,7 @@ def select(self, *keys: str, inplace: bool = False) -> _CustomOpTensorDict: self_copy._source = self_copy._source.select(*keys) return self_copy - def clone(self, recursive: bool = True) -> _TensorDict: + def clone(self, recursive: bool = True) -> TensorDictBase: if not recursive: return copy(self) return TensorDict( @@ -4097,7 +4127,7 @@ def clone(self, recursive: bool = True) -> _TensorDict: def is_contiguous(self) -> bool: return all([value.is_contiguous() for _, value in self.items()]) - def contiguous(self) -> _TensorDict: + def contiguous(self) -> TensorDictBase: if self.is_contiguous(): return self return self.to(TensorDict) @@ -4112,8 +4142,8 @@ def del_(self, key: str) -> _CustomOpTensorDict: self._source = self._source.del_(key) return self - def to(self, dest: Union[DEVICE_TYPING, Type], **kwargs) -> _TensorDict: - if isinstance(dest, type) and issubclass(dest, _TensorDict): + def to(self, dest: Union[DEVICE_TYPING, Type], **kwargs) -> TensorDictBase: + if isinstance(dest, type) and issubclass(dest, TensorDictBase): if isinstance(self, dest): return self return dest(source=self) @@ -4130,7 +4160,7 @@ def to(self, dest: Union[DEVICE_TYPING, Type], **kwargs) -> _TensorDict: f"instance, {dest} not allowed" ) - def pin_memory(self) -> _TensorDict: + def pin_memory(self) -> TensorDictBase: self._source.pin_memory() return self @@ -4139,7 +4169,7 @@ def detach_(self): def masked_fill_( self, mask: torch.Tensor, value: Union[float, bool] - ) -> _TensorDict: + ) -> TensorDictBase: for key, item in self.items(): # source_meta_tensor = self._get_meta(key) val = self._source.get(key) @@ -4153,7 +4183,9 @@ def masked_fill_( self._source.set(key, val) return self - def masked_fill(self, mask: torch.Tensor, value: Union[float, bool]) -> _TensorDict: + def masked_fill( + self, mask: torch.Tensor, value: Union[float, bool] + ) -> TensorDictBase: td_copy = self.clone() return td_copy.masked_fill_(mask, value) @@ -4187,7 +4219,7 @@ class UnsqueezedTensorDict(_CustomOpTensorDict): True """ - def squeeze(self, dim: int) -> _TensorDict: + def squeeze(self, dim: int) -> TensorDictBase: if dim < 0: dim = self.batch_dims + dim if dim == self.custom_op_kwargs.get("dim"): @@ -4199,7 +4231,7 @@ def _stack_onto_( key: str, list_item: List[COMPATIBLE_TYPES], dim: int, - ) -> _TensorDict: + ) -> TensorDictBase: unsqueezed_dim = self.custom_op_kwargs["dim"] diff_to_apply = 1 if dim < unsqueezed_dim else 0 list_item_unsqueeze = [ @@ -4214,7 +4246,7 @@ class SqueezedTensorDict(_CustomOpTensorDict): See the `UnsqueezedTensorDict` class documentation for more information. """ - def unsqueeze(self, dim: int) -> _TensorDict: + def unsqueeze(self, dim: int) -> TensorDictBase: if dim < 0: dim = self.batch_dims + dim + 1 inv_op_dim = self.inv_op_kwargs.get("dim") @@ -4229,7 +4261,7 @@ def _stack_onto_( key: str, list_item: List[COMPATIBLE_TYPES], dim: int, - ) -> _TensorDict: + ) -> TensorDictBase: squeezed_dim = self.custom_op_kwargs["dim"] # dim=0, squeezed_dim=2, [3, 4, 5] [3, 4, 1, 5] [[4, 5], [4, 5], [4, 5]] => unsq 1 # dim=1, squeezed_dim=2, [3, 4, 5] [3, 4, 1, 5] [[3, 5], [3, 5], [3, 5], [3, 4]] => unsq 1 @@ -4260,7 +4292,7 @@ def _update_inv_op_kwargs(self, tensor: torch.Tensor) -> Dict: def view( self, *shape, size: Optional[Union[List, Tuple, torch.Size]] = None - ) -> _TensorDict: + ) -> TensorDictBase: if len(shape) == 0 and size is not None: return self.view(*size) elif len(shape) == 1 and isinstance(shape[0], (list, tuple, torch.Size)): @@ -4297,7 +4329,7 @@ def permute( self, *dims_list: int, dims=None, - ) -> _TensorDict: + ) -> TensorDictBase: if len(dims_list) == 0: dims_list = dims elif len(dims_list) == 1 and not isinstance(dims_list[0], int): @@ -4345,7 +4377,7 @@ def _stack_onto_( key: str, list_item: List[COMPATIBLE_TYPES], dim: int, - ) -> _TensorDict: + ) -> TensorDictBase: permute_dims = self.custom_op_kwargs["dims"] inv_permute_dims = np.argsort(permute_dims) @@ -4363,7 +4395,7 @@ def _stack_onto_( return self -def _td_fields(td: _TensorDict) -> str: +def _td_fields(td: TensorDictBase) -> str: return indent( "\n" + ",\n".join( @@ -4374,7 +4406,7 @@ def _td_fields(td: _TensorDict) -> str: def _check_keys( - list_of_tensordicts: Sequence[_TensorDict], strict: bool = False + list_of_tensordicts: Sequence[TensorDictBase], strict: bool = False ) -> Set[str]: keys: Set[str] = set() for td in list_of_tensordicts: @@ -4394,7 +4426,7 @@ def _check_keys( return keys -_accepted_classes = (torch.Tensor, MemmapTensor, _TensorDict) +_accepted_classes = (torch.Tensor, MemmapTensor, TensorDictBase) def _expand_to_match_shape(parent_batch_size, tensor, self_batch_dims, self_device): diff --git a/torchrl/data/utils.py b/torchrl/data/utils.py index ba55e1a0039..73eb53ea913 100644 --- a/torchrl/data/utils.py +++ b/torchrl/data/utils.py @@ -62,8 +62,8 @@ def __call__(self, **kwargs) -> Any: def expand_as_right( - tensor: Union[torch.Tensor, "MemmapTensor", "_TensorDict"], - dest: Union[torch.Tensor, "MemmapTensor", "_TensorDict"], + tensor: Union[torch.Tensor, "MemmapTensor", "TensorDictBase"], + dest: Union[torch.Tensor, "MemmapTensor", "TensorDictBase"], ): """Expand a tensor on the right to match another tensor shape. Args: diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 74a6f075992..c9485d0c599 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -15,7 +15,7 @@ from torchrl import seed_generator, prod from torchrl.data import CompositeSpec, TensorDict, TensorSpec -from ..data.tensordict.tensordict import _TensorDict +from ..data.tensordict.tensordict import TensorDictBase from ..data.utils import DEVICE_TYPING from .utils import get_available_libraries, step_tensordict @@ -68,7 +68,7 @@ def keys(self) -> dict: def build_tensordict( self, next_observation: bool = True, log_prob: bool = False - ) -> _TensorDict: + ) -> TensorDictBase: """returns a TensorDict with empty tensors of the desired shape""" # build a tensordict from specs td = TensorDict({}, batch_size=torch.Size([])) @@ -123,11 +123,11 @@ class _EnvClass: last reset Methods: - step (_TensorDict -> _TensorDict): step in the environment - reset (_TensorDict, optional -> _TensorDict): reset the environment + step (TensorDictBase -> TensorDictBase): step in the environment + reset (TensorDictBase, optional -> TensorDictBase): reset the environment set_seed (int -> int): sets the seed of the environment - rand_step (_TensorDict, optional -> _TensorDict): random step given the action spec - rollout (Callable, ... -> _TensorDict): executes a rollout in the environment with the given policy (or random + rand_step (TensorDictBase, optional -> TensorDictBase): random step given the action spec + rollout (Callable, ... -> TensorDictBase): executes a rollout in the environment with the given policy (or random steps if no policy is provided) """ @@ -204,14 +204,14 @@ def observation_spec(self) -> TensorSpec: def observation_spec(self, value: TensorSpec) -> None: self._observation_spec = value - def step(self, tensordict: _TensorDict) -> _TensorDict: + def step(self, tensordict: TensorDictBase) -> TensorDictBase: """Makes a step in the environment. Step accepts a single argument, tensordict, which usually carries an 'action' key which indicates the action to be taken. Step will call an out-place private method, _step, which is the method to be re-written by _EnvClass subclasses. Args: - tensordict (_TensorDict): Tensordict containing the action to be taken. + tensordict (TensorDictBase): Tensordict containing the action to be taken. Returns: the input tensordict, modified in place with the resulting observations, done state and reward @@ -271,24 +271,24 @@ def train(self, mode: bool = True) -> _EnvClass: def _step( self, - tensordict: _TensorDict, - ) -> _TensorDict: + tensordict: TensorDictBase, + ) -> TensorDictBase: raise NotImplementedError - def _reset(self, tensordict: _TensorDict, **kwargs) -> _TensorDict: + def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: raise NotImplementedError def reset( self, - tensordict: Optional[_TensorDict] = None, + tensordict: Optional[TensorDictBase] = None, execute_step: bool = True, **kwargs, - ) -> _TensorDict: + ) -> TensorDictBase: """Resets the environment. As for step and _step, only the private method `_reset` should be overwritten by _EnvClass subclasses. Args: - tensordict (_TensorDict, optional): tensordict to be used to contain the resulting new observation. + tensordict (TensorDictBase, optional): tensordict to be used to contain the resulting new observation. In some cases, this input can also be used to pass argument to the reset function. execute_step (bool, optional): if True, a `step_tensordict` is executed on the output TensorDict, hereby removing the `"next_"` prefixes from the keys. @@ -307,7 +307,7 @@ def reset( "tensordict. Consider emptying the TensorDict first (e.g. tensordict.empty() or " "tensordict.select()) inside _reset before writing new tensors onto this new instance." ) - if not isinstance(tensordict_reset, _TensorDict): + if not isinstance(tensordict_reset, TensorDictBase): raise RuntimeError( f"env._reset returned an object of type {type(tensordict_reset)} but a TensorDict was expected." ) @@ -362,7 +362,7 @@ def _set_seed(self, seed: Optional[int]): def set_state(self): raise NotImplementedError - def _assert_tensordict_shape(self, tensordict: _TensorDict) -> None: + def _assert_tensordict_shape(self, tensordict: TensorDictBase) -> None: if tensordict.batch_size != self.batch_size: raise RuntimeError( f"Expected a tensordict with shape==env.shape, " @@ -379,11 +379,11 @@ def is_done_set_fn(self, val: torch.Tensor) -> None: is_done = property(is_done_get_fn, is_done_set_fn) - def rand_step(self, tensordict: Optional[_TensorDict] = None) -> _TensorDict: + def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase: """Performs a random step in the environment given the action_spec attribute. Args: - tensordict (_TensorDict, optional): tensordict where the resulting info should be written. + tensordict (TensorDictBase, optional): tensordict where the resulting info should be written. Returns: a tensordict object with the new observation after a random step in the environment. The action will @@ -410,13 +410,13 @@ def specs(self) -> Specs: def rollout( self, max_steps: int, - policy: Optional[Callable[[_TensorDict], _TensorDict]] = None, - callback: Optional[Callable[[_TensorDict, ...], _TensorDict]] = None, + policy: Optional[Callable[[TensorDictBase], TensorDictBase]] = None, + callback: Optional[Callable[[TensorDictBase, ...], TensorDictBase]] = None, auto_reset: bool = True, auto_cast_to_device: bool = False, break_when_any_done: bool = True, - tensordict: Optional[_TensorDict] = None, - ) -> _TensorDict: + tensordict: Optional[TensorDictBase] = None, + ) -> TensorDictBase: """Executes a rollout in the environment. The function will stop as soon as one of the contained environments @@ -487,7 +487,7 @@ def policy(td): out_td = torch.stack(tensordicts, len(self.batch_size)) return out_td - def _select_observation_keys(self, tensordict: _TensorDict) -> Iterator[str]: + def _select_observation_keys(self, tensordict: TensorDictBase) -> Iterator[str]: for key in tensordict.keys(): if key.rfind("observation") >= 0: yield key @@ -679,8 +679,8 @@ def _set_seed(self, seed: Optional[int]): def make_tensordict( env: _EnvClass, - policy: Optional[Callable[[_TensorDict, ...], _TensorDict]] = None, -) -> _TensorDict: + policy: Optional[Callable[[TensorDictBase, ...], TensorDictBase]] = None, +) -> TensorDictBase: """ Returns a zeroed-tensordict with fields matching those required for a full step (action selection and environment step) in the environment diff --git a/torchrl/envs/env_creator.py b/torchrl/envs/env_creator.py index 4558eb9a8f2..5da7f5dc16d 100644 --- a/torchrl/envs/env_creator.py +++ b/torchrl/envs/env_creator.py @@ -10,7 +10,7 @@ import torch -from torchrl.data.tensordict.tensordict import _TensorDict +from torchrl.data.tensordict.tensordict import TensorDictBase from torchrl.data.utils import CloudpickleWrapper from torchrl.envs.common import _EnvClass @@ -94,7 +94,7 @@ def __init__( def share_memory(self, state_dict: OrderedDict) -> None: for key, item in list(state_dict.items()): - if isinstance(item, (_TensorDict,)): + if isinstance(item, (TensorDictBase,)): if not item.is_shared(): item.share_memory_() else: diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 50860f56892..ed39222ca39 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -7,7 +7,7 @@ import torch from torchrl.data import TensorDict -from torchrl.data.tensordict.tensordict import _TensorDict +from torchrl.data.tensordict.tensordict import TensorDictBase from torchrl.envs.common import _EnvWrapper __all__ = ["GymLikeEnv", "default_info_dict_reader"] @@ -38,7 +38,7 @@ def __init__(self, keys=None): keys = [] self.keys = keys - def __call__(self, info_dict: dict, tensordict: _TensorDict) -> _TensorDict: + def __call__(self, info_dict: dict, tensordict: TensorDictBase) -> TensorDictBase: if not isinstance(info_dict, dict) and len(self.keys): warnings.warn( f"Found an info_dict of type {type(info_dict)} " @@ -78,7 +78,7 @@ def __new__(cls, *args, **kwargs): cls._info_dict_reader = None return super().__new__(cls, *args, **kwargs) - def _step(self, tensordict: _TensorDict) -> _TensorDict: + def _step(self, tensordict: TensorDictBase) -> TensorDictBase: action = tensordict.get("action") action_np = self.action_spec.to_numpy(action, safe=False) @@ -114,7 +114,9 @@ def _step(self, tensordict: _TensorDict) -> _TensorDict: return tensordict_out - def _reset(self, tensordict: Optional[_TensorDict] = None, **kwargs) -> _TensorDict: + def _reset( + self, tensordict: Optional[TensorDictBase] = None, **kwargs + ) -> TensorDictBase: obs, *_ = self._output_transform((self._env.reset(**kwargs),)) tensordict_out = TensorDict( source=self._read_obs(obs), diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index b148ebf0cd3..e1b5ae28f3e 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -31,7 +31,7 @@ BinaryDiscreteTensorSpec, DEVICE_TYPING, ) -from torchrl.data.tensordict.tensordict import _TensorDict, TensorDict +from torchrl.data.tensordict.tensordict import TensorDictBase, TensorDict from torchrl.envs.common import _EnvClass, make_tensordict from torchrl.envs.transforms import functional as F from torchrl.envs.transforms.utils import FiniteTensor @@ -125,7 +125,7 @@ def __init__( keys_inv_out = copy(self.keys_inv_in) self.keys_inv_out = keys_inv_out - def reset(self, tensordict: _TensorDict) -> _TensorDict: + def reset(self, tensordict: TensorDictBase) -> TensorDictBase: """Resets a tranform if it is stateful.""" return tensordict @@ -147,7 +147,7 @@ def _apply_transform(self, obs: torch.Tensor) -> None: """ raise NotImplementedError - def _call(self, tensordict: _TensorDict) -> _TensorDict: + def _call(self, tensordict: TensorDictBase) -> TensorDictBase: """Reads the input tensordict, and for the selected keys, applies the transform. @@ -159,7 +159,7 @@ def _call(self, tensordict: _TensorDict) -> _TensorDict: tensordict.set(key_out, observation, inplace=self.inplace) return tensordict - def forward(self, tensordict: _TensorDict) -> _TensorDict: + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: self._call(tensordict) return tensordict @@ -169,7 +169,7 @@ def _inv_apply_transform(self, obs: torch.Tensor) -> torch.Tensor: else: return obs - def _inv_call(self, tensordict: _TensorDict) -> _TensorDict: + def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: self._check_inplace() for key_in, key_out in zip(self.keys_inv_in, self.keys_inv_out): for key_in in tensordict.keys(): @@ -177,7 +177,7 @@ def _inv_call(self, tensordict: _TensorDict) -> _TensorDict: tensordict.set(key_out, observation, inplace=self.inplace) return tensordict - def inv(self, tensordict: _TensorDict) -> _TensorDict: + def inv(self, tensordict: TensorDictBase) -> TensorDictBase: self._inv_call(tensordict) return tensordict @@ -357,7 +357,7 @@ def reward_spec(self) -> TensorSpec: reward_spec = self._reward_spec return reward_spec - def _step(self, tensordict: _TensorDict) -> _TensorDict: + def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # selected_keys = [key for key in tensordict.keys() if "action" in key] # tensordict_in = tensordict.select(*selected_keys).clone() tensordict_in = self.transform.inv(tensordict.clone(recursive=False)) @@ -371,7 +371,7 @@ def set_seed(self, seed: int) -> int: """Set the seeds of the environment""" return self.base_env.set_seed(seed) - def _reset(self, tensordict: Optional[_TensorDict] = None, **kwargs): + def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs): out_tensordict = self.base_env.reset(execute_step=False, **kwargs) out_tensordict = self.transform.reset(out_tensordict) out_tensordict = self.transform(out_tensordict) @@ -547,7 +547,7 @@ def __init__(self, *transforms: Transform): for t in self.transforms: t.set_parent(self) - def _call(self, tensordict: _TensorDict) -> _TensorDict: + def _call(self, tensordict: TensorDictBase) -> TensorDictBase: for t in self.transforms: tensordict = t(tensordict) return tensordict @@ -578,12 +578,12 @@ def dump(self, **kwargs) -> None: for t in self: t.dump(**kwargs) - def reset(self, tensordict: _TensorDict) -> _TensorDict: + def reset(self, tensordict: TensorDictBase) -> TensorDictBase: for t in self.transforms: tensordict = t.reset(tensordict) return tensordict - def init(self, tensordict: _TensorDict) -> None: + def init(self, tensordict: TensorDictBase) -> None: for t in self.transforms: t.init(tensordict) @@ -1140,7 +1140,7 @@ def __init__( self.cat_dim = cat_dim self.buffer = [] - def reset(self, tensordict: _TensorDict) -> _TensorDict: + def reset(self, tensordict: TensorDictBase) -> TensorDictBase: self.buffer = [] return tensordict @@ -1240,7 +1240,7 @@ class FiniteTensorDictCheck(Transform): def __init__(self): super().__init__(keys_in=[]) - def _call(self, tensordict: _TensorDict) -> _TensorDict: + def _call(self, tensordict: TensorDictBase) -> TensorDictBase: source = {} for key, item in tensordict.items(): try: @@ -1384,7 +1384,7 @@ def __init__( self.dim = dim self.del_keys = del_keys - def _call(self, tensordict: _TensorDict) -> _TensorDict: + def _call(self, tensordict: TensorDictBase) -> TensorDictBase: if all([key in tensordict.keys() for key in self.keys_in]): out_tensor = torch.cat( [tensordict.get(key) for key in self.keys_in], dim=self.dim @@ -1548,7 +1548,7 @@ def __init__(self, noops: int = 30, random: bool = True): def base_env(self): return self.parent - def reset(self, tensordict: _TensorDict) -> _TensorDict: + def reset(self, tensordict: TensorDictBase) -> TensorDictBase: """Do no-op action for a number of steps in [1, noop_max].""" parent = self.parent keys = tensordict.keys() @@ -1602,7 +1602,7 @@ class PinMemoryTransform(Transform): def __init__(self): super().__init__([]) - def _call(self, tensordict: _TensorDict) -> _TensorDict: + def _call(self, tensordict: TensorDictBase) -> TensorDictBase: return tensordict.pin_memory() @@ -1624,7 +1624,7 @@ def __init__( self.state_dim = state_dim self.action_dim = action_dim - def reset(self, tensordict: _TensorDict) -> _TensorDict: + def reset(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict = super().reset(tensordict=tensordict) if self.state_dim is None or self.action_dim is None: tensordict.set( @@ -1657,14 +1657,14 @@ class VecNorm(Transform): statistics are not updated. If multiple processes are running a similar environment, one can pass a - _TensorDict instance that is placed in shared memory: if so, every time + TensorDictBase instance that is placed in shared memory: if so, every time the normalization layer is queried it will update the values for all processes that share the same reference. Args: keys_in (iterable of str, optional): keys to be updated. default: ["next_observation", "reward"] - shared_td (_TensorDict, optional): A shared tensordict containing the + shared_td (TensorDictBase, optional): A shared tensordict containing the keys of the transform. decay (number, optional): decay rate of the moving average. default: 0.99 @@ -1695,7 +1695,7 @@ class VecNorm(Transform): def __init__( self, keys_in: Optional[Sequence[str]] = None, - shared_td: Optional[_TensorDict] = None, + shared_td: Optional[TensorDictBase] = None, decay: float = 0.9999, eps: float = 1e-4, ) -> None: @@ -1724,7 +1724,7 @@ def __init__( self.decay = decay self.eps = eps - def _call(self, tensordict: _TensorDict) -> _TensorDict: + def _call(self, tensordict: TensorDictBase) -> TensorDictBase: for key in self.keys_in: if key not in tensordict.keys(): continue @@ -1737,7 +1737,7 @@ def _call(self, tensordict: _TensorDict) -> _TensorDict: tensordict.set_(key, new_val) return tensordict - def _init(self, tensordict: _TensorDict, key: str) -> None: + def _init(self, tensordict: TensorDictBase, key: str) -> None: if self._td is None or key + "_sum" not in self._td.keys(): td_view = tensordict.view(-1) td_select = td_view[0] @@ -1794,7 +1794,7 @@ def build_td_for_shared_vecnorm( env: _EnvClass, keys_prefix: Optional[Sequence[str]] = None, memmap: bool = False, - ) -> _TensorDict: + ) -> TensorDictBase: """Creates a shared tensordict that can be sent to different processes for normalization across processes. @@ -1854,10 +1854,10 @@ def build_td_for_shared_vecnorm( return td_select.memmap_() return td_select.share_memory_() - def get_extra_state(self) -> _TensorDict: + def get_extra_state(self) -> TensorDictBase: return self._td - def set_extra_state(self, td: _TensorDict) -> None: + def set_extra_state(self, td: TensorDictBase) -> None: if not td.is_shared(): raise RuntimeError( "Only shared tensordicts can be set in VecNorm transforms" diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 41b1e538325..67d332886d2 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -8,7 +8,7 @@ import pkg_resources from torch.autograd.grad_mode import _DecoratorContextManager -from torchrl.data.tensordict.tensordict import _TensorDict +from torchrl.data.tensordict.tensordict import TensorDictBase AVAILABLE_LIBRARIES = {pkg.key for pkg in pkg_resources.working_set} @@ -19,21 +19,21 @@ def __get__(self, cls, owner): def step_tensordict( - tensordict: _TensorDict, - next_tensordict: _TensorDict = None, + tensordict: TensorDictBase, + next_tensordict: TensorDictBase = None, keep_other: bool = True, exclude_reward: bool = True, exclude_done: bool = True, exclude_action: bool = True, -) -> _TensorDict: +) -> TensorDictBase: """ Given a tensordict retrieved after a step, returns another tensordict with all the 'next_' prefixes are removed, i.e. all the `'next_some_other_string'` keys will be renamed onto `'some_other_string'` keys. Args: - tensordict (_TensorDict): tensordict with keys to be renamed - next_tensordict (_TensorDict, optional): destination tensordict + tensordict (TensorDictBase): tensordict with keys to be renamed + next_tensordict (TensorDictBase, optional): destination tensordict keep_other (bool, optional): if True, all keys that do not start with `'next_'` will be kept. Default is True. exclude_reward (bool, optional): if True, the `"reward"` key will be discarded diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index c3c9dd08629..5b09ec1eb50 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -18,7 +18,7 @@ from torchrl import _check_for_faulty_process from torchrl.data import TensorDict, TensorSpec -from torchrl.data.tensordict.tensordict import _TensorDict +from torchrl.data.tensordict.tensordict import TensorDictBase from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING from torchrl.envs.common import _EnvClass, make_tensordict from torchrl.envs.env_creator import EnvCreator @@ -513,7 +513,7 @@ def set_seed(self, seed: int) -> int: return seed @_check_start - def _reset(self, tensordict: _TensorDict, **kwargs) -> _TensorDict: + def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: if tensordict is not None and "reset_workers" in tensordict.keys(): self._assert_tensordict_shape(tensordict) reset_workers = tensordict.get("reset_workers") @@ -654,7 +654,7 @@ def load_state_dict(self, state_dict: OrderedDict) -> None: raise RuntimeError(f"Expected 'loaded' but received {msg}") @_check_start - def _step(self, tensordict: _TensorDict) -> _TensorDict: + def _step(self, tensordict: TensorDictBase) -> TensorDictBase: self._assert_tensordict_shape(tensordict) self.shared_tensordict_parent.update_(tensordict.select(*self.env_input_keys)) @@ -716,7 +716,7 @@ def set_seed(self, seed: int) -> int: return seed @_check_start - def _reset(self, tensordict: _TensorDict, **kwargs) -> _TensorDict: + def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: cmd_out = "reset" if tensordict is not None and "reset_workers" in tensordict.keys(): self._assert_tensordict_shape(tensordict) diff --git a/torchrl/modules/models/recipes/impala.py b/torchrl/modules/models/recipes/impala.py index 89140f8cc72..88a8c50362f 100644 --- a/torchrl/modules/models/recipes/impala.py +++ b/torchrl/modules/models/recipes/impala.py @@ -6,7 +6,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torchrl.data.tensordict.tensordict import _TensorDict +from torchrl.data.tensordict.tensordict import TensorDictBase # TODO: code small architecture ref in Impala paper @@ -173,7 +173,7 @@ def _allocate_masked_x(self, x, mask): class ImpalaNetTensorDict(ImpalaNet): observation_key = "pixels" - def forward(self, tensordict: _TensorDict): + def forward(self, tensordict: TensorDictBase): x = tensordict.get(self.observation_key) done = tensordict.get("done").squeeze(-1) reward = tensordict.get("reward").squeeze(-1) diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index ee588be1c3c..bacef78a9a0 100644 --- a/torchrl/modules/tensordict_module/common.py +++ b/torchrl/modules/tensordict_module/common.py @@ -27,7 +27,7 @@ TensorSpec, CompositeSpec, ) -from torchrl.data.tensordict.tensordict import _TensorDict +from torchrl.data.tensordict.tensordict import TensorDictBase __all__ = [ "TensorDictModule", @@ -234,12 +234,12 @@ def spec(self, spec: TensorSpec) -> None: def _write_to_tensordict( self, - tensordict: _TensorDict, + tensordict: TensorDictBase, tensors: List, - tensordict_out: Optional[_TensorDict] = None, + tensordict_out: Optional[TensorDictBase] = None, out_keys: Optional[Iterable[str]] = None, vmap: Optional[int] = None, - ) -> _TensorDict: + ) -> TensorDictBase: if out_keys is None: out_keys = self.out_keys @@ -338,10 +338,10 @@ def _call_module( def forward( self, - tensordict: _TensorDict, - tensordict_out: Optional[_TensorDict] = None, + tensordict: TensorDictBase, + tensordict_out: Optional[TensorDictBase] = None, **kwargs, - ) -> _TensorDict: + ) -> TensorDictBase: tensors = tuple(tensordict.get(in_key, None) for in_key in self.in_keys) tensors = self._call_module(tensors, **kwargs) if not isinstance(tensors, tuple): @@ -354,12 +354,12 @@ def forward( ) return tensordict_out - def random(self, tensordict: _TensorDict) -> _TensorDict: + def random(self, tensordict: TensorDictBase) -> TensorDictBase: """Samples a random element in the target space, irrespective of any input. If multiple output keys are present, only the first will be written in the input `tensordict`. Args: - tensordict (_TensorDict): tensordict where the output value should be written. + tensordict (TensorDictBase): tensordict where the output value should be written. Returns: the original tensordict with a new/updated value for the output key. @@ -369,7 +369,7 @@ def random(self, tensordict: _TensorDict) -> _TensorDict: tensordict.set(key0, self.spec.rand(tensordict.batch_size)) return tensordict - def random_sample(self, tensordict: _TensorDict) -> _TensorDict: + def random_sample(self, tensordict: TensorDictBase) -> TensorDictBase: """see TensorDictModule.random(...)""" return self.random(tensordict) diff --git a/torchrl/modules/tensordict_module/deprec.py b/torchrl/modules/tensordict_module/deprec.py index 98f44d18f5b..e0ec0f7664b 100644 --- a/torchrl/modules/tensordict_module/deprec.py +++ b/torchrl/modules/tensordict_module/deprec.py @@ -8,7 +8,7 @@ from torch import Tensor, nn, distributions as d from torchrl.data import TensorSpec, DEVICE_TYPING -from torchrl.data.tensordict.tensordict import _TensorDict +from torchrl.data.tensordict.tensordict import TensorDictBase from torchrl.envs.utils import exploration_mode from torchrl.modules import TensorDictModule, Delta, distributions_maps @@ -161,14 +161,14 @@ def __init__( def get_dist( self, - tensordict: _TensorDict, + tensordict: TensorDictBase, **kwargs, ) -> Tuple[torch.distributions.Distribution, ...]: """Calls the module using the tensors retrieved from the 'in_keys' attribute and returns a distribution using its output. Args: - tensordict (_TensorDict): tensordict with the input values for the creation of the distribution. + tensordict (TensorDictBase): tensordict with the input values for the creation of the distribution. Returns: a distribution along with other tensors returned by the module. @@ -216,10 +216,10 @@ def build_dist_from_params( def forward( self, - tensordict: _TensorDict, - tensordict_out: Optional[_TensorDict] = None, + tensordict: TensorDictBase, + tensordict_out: Optional[TensorDictBase] = None, **kwargs, - ) -> _TensorDict: + ) -> TensorDictBase: dist, *tensors = self.get_dist(tensordict, **kwargs) out_tensor = self._dist_sample( @@ -236,13 +236,13 @@ def forward( tensordict_out.set("_".join([self.out_keys[0], "log_prob"]), log_prob) return tensordict_out - def log_prob(self, tensordict: _TensorDict, **kwargs) -> _TensorDict: + def log_prob(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: """ Samples/computes an action using the module and writes this value onto the input tensordict along with its log-probability. Args: - tensordict (_TensorDict): tensordict containing the in_keys specified in the initializer. + tensordict (TensorDictBase): tensordict containing the in_keys specified in the initializer. Returns: the same tensordict with the out_keys values added/updated as well as a diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index f7001d1a7fd..59e3db57c21 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -19,7 +19,7 @@ __all__ = ["EGreedyWrapper", "OrnsteinUhlenbeckProcessWrapper"] -from torchrl.data.tensordict.tensordict import _TensorDict +from torchrl.data.tensordict.tensordict import TensorDictBase class EGreedyWrapper(TensorDictModuleWrapper): @@ -95,7 +95,7 @@ def step(self, frames: int = 1) -> None: ).item(), ) - def forward(self, tensordict: _TensorDict) -> _TensorDict: + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict = self.td_module.forward(tensordict) if exploration_mode() == "random" or exploration_mode() is None: out = tensordict.get(self.td_module.out_keys[0]) @@ -241,7 +241,7 @@ def step(self, frames: int = 1) -> None: f"number of frames." ) - def forward(self, tensordict: _TensorDict) -> _TensorDict: + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict = super().forward(tensordict) if exploration_mode() == "random" or exploration_mode() is None: tensordict = self.ou.add_sample(tensordict, self.eps.item()) @@ -290,7 +290,7 @@ def noise_key(self): def steps_key(self): return self._steps_key # + str(id(self)) - def _make_noise_pair(self, tensordict: _TensorDict) -> None: + def _make_noise_pair(self, tensordict: TensorDictBase) -> None: tensordict.set( self.noise_key, torch.zeros(tensordict.get(self.key).shape, device=tensordict.device), @@ -304,7 +304,9 @@ def _make_noise_pair(self, tensordict: _TensorDict) -> None: ), ) - def add_sample(self, tensordict: _TensorDict, eps: float = 1.0) -> _TensorDict: + def add_sample( + self, tensordict: TensorDictBase, eps: float = 1.0 + ) -> TensorDictBase: if self.noise_key not in tensordict.keys(): self._make_noise_pair(tensordict) diff --git a/torchrl/modules/tensordict_module/probabilistic.py b/torchrl/modules/tensordict_module/probabilistic.py index 878c23c7216..b2740c929ac 100644 --- a/torchrl/modules/tensordict_module/probabilistic.py +++ b/torchrl/modules/tensordict_module/probabilistic.py @@ -12,7 +12,7 @@ from torch import distributions as d from torchrl.data import TensorSpec -from torchrl.data.tensordict.tensordict import _TensorDict +from torchrl.data.tensordict.tensordict import TensorDictBase from torchrl.envs.utils import exploration_mode, set_exploration_mode from torchrl.modules.distributions import distributions_maps, Delta from torchrl.modules.tensordict_module.common import TensorDictModule, _check_all_str @@ -189,7 +189,7 @@ def __init__( self.cache_dist = cache_dist if hasattr(distribution_class, "update") else False self.return_log_prob = return_log_prob - def _call_module(self, tensordict: _TensorDict, **kwargs) -> _TensorDict: + def _call_module(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: return self.module(tensordict, **kwargs) def make_functional_with_buffers(self, clone: bool = True): @@ -211,10 +211,10 @@ def make_functional_with_buffers(self, clone: bool = True): def get_dist( self, - tensordict: _TensorDict, - tensordict_out: Optional[_TensorDict] = None, + tensordict: TensorDictBase, + tensordict_out: Optional[TensorDictBase] = None, **kwargs, - ) -> Tuple[d.Distribution, _TensorDict]: + ) -> Tuple[d.Distribution, TensorDictBase]: interaction_mode = exploration_mode() if interaction_mode is None: interaction_mode = self.default_interaction_mode @@ -225,7 +225,7 @@ def get_dist( dist = self.build_dist_from_params(tensordict_out) return dist, tensordict_out - def build_dist_from_params(self, tensordict_out: _TensorDict) -> d.Distribution: + def build_dist_from_params(self, tensordict_out: TensorDictBase) -> d.Distribution: try: dist = self.distribution_class( **tensordict_out.select(*self.dist_param_keys) @@ -246,10 +246,10 @@ def build_dist_from_params(self, tensordict_out: _TensorDict) -> d.Distribution: def forward( self, - tensordict: _TensorDict, - tensordict_out: Optional[_TensorDict] = None, + tensordict: TensorDictBase, + tensordict_out: Optional[TensorDictBase] = None, **kwargs, - ) -> _TensorDict: + ) -> TensorDictBase: dist, tensordict_out = self.get_dist( tensordict, tensordict_out=tensordict_out, **kwargs diff --git a/torchrl/modules/tensordict_module/sequence.py b/torchrl/modules/tensordict_module/sequence.py index 48f12a7b2be..78d6d7574ec 100644 --- a/torchrl/modules/tensordict_module/sequence.py +++ b/torchrl/modules/tensordict_module/sequence.py @@ -16,7 +16,7 @@ TensorSpec, CompositeSpec, ) -from torchrl.data.tensordict.tensordict import _TensorDict +from torchrl.data.tensordict.tensordict import TensorDictBase from torchrl.modules.tensordict_module.common import TensorDictModule from torchrl.modules.tensordict_module.probabilistic import ( ProbabilisticTensorDictModule, @@ -192,8 +192,8 @@ def _split_param( return out def forward( - self, tensordict: _TensorDict, tensordict_out=None, **kwargs - ) -> _TensorDict: + self, tensordict: TensorDictBase, tensordict_out=None, **kwargs + ) -> TensorDictBase: if "params" in kwargs and "buffers" in kwargs: param_splits = self._split_param(kwargs["params"], "params") buffer_splits = self._split_param(kwargs["buffers"], "buffers") @@ -316,7 +316,7 @@ def make_functional_with_buffers(self, clone: bool = True): def get_dist( self, - tensordict: _TensorDict, + tensordict: TensorDictBase, **kwargs, ) -> Tuple[torch.distributions.Distribution, ...]: L = len(self.module) diff --git a/torchrl/objectives/costs/common.py b/torchrl/objectives/costs/common.py index d3788a090b9..b4a9d74104f 100644 --- a/torchrl/objectives/costs/common.py +++ b/torchrl/objectives/costs/common.py @@ -15,7 +15,7 @@ from torch import nn from torch.nn import Parameter -from torchrl.data.tensordict.tensordict import _TensorDict +from torchrl.data.tensordict.tensordict import TensorDictBase from torchrl.modules import TensorDictModule @@ -32,7 +32,7 @@ def __init__(self): super().__init__() self._param_maps = dict() - def forward(self, tensordict: _TensorDict) -> _TensorDict: + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: """It is designed to read an input TensorDict and return another tensordict with loss keys named "loss*". Splitting the loss in its component can then be used by the trainer to log the various loss values throughout diff --git a/torchrl/objectives/costs/ddpg.py b/torchrl/objectives/costs/ddpg.py index 27031f754c1..85e8ea88d53 100644 --- a/torchrl/objectives/costs/ddpg.py +++ b/torchrl/objectives/costs/ddpg.py @@ -9,7 +9,7 @@ import torch -from torchrl.data.tensordict.tensordict import _TensorDict, TensorDict +from torchrl.data.tensordict.tensordict import TensorDictBase, TensorDict from torchrl.modules import TensorDictModule from torchrl.modules.tensordict_module.actors import ActorCriticWrapper from torchrl.objectives.costs.utils import ( @@ -66,13 +66,13 @@ def __init__( self.gamma = gamma self.loss_funtion = loss_function - def forward(self, input_tensordict: _TensorDict) -> TensorDict: + def forward(self, input_tensordict: TensorDictBase) -> TensorDict: """Computes the DDPG losses given a tensordict sampled from the replay buffer. This function will also write a "td_error" key that can be used by prioritized replay buffers to assign a priority to items in the tensordict. Args: - input_tensordict (_TensorDict): a tensordict with keys ["done", "reward"] and the in_keys of the actor + input_tensordict (TensorDictBase): a tensordict with keys ["done", "reward"] and the in_keys of the actor and value networks. Returns: @@ -111,7 +111,7 @@ def forward(self, input_tensordict: _TensorDict) -> TensorDict: def _loss_actor( self, - tensordict: _TensorDict, + tensordict: TensorDictBase, ) -> torch.Tensor: td_copy = tensordict.select(*self.actor_in_keys).detach() td_copy = self.actor_network( @@ -127,7 +127,7 @@ def _loss_actor( def _loss_value( self, - tensordict: _TensorDict, + tensordict: TensorDictBase, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: # value loss td_copy = tensordict.select(*self.value_network.in_keys).detach() diff --git a/torchrl/objectives/costs/deprecated.py b/torchrl/objectives/costs/deprecated.py index bdad5cad70e..aa959462460 100644 --- a/torchrl/objectives/costs/deprecated.py +++ b/torchrl/objectives/costs/deprecated.py @@ -7,7 +7,7 @@ from torch import Tensor from torchrl.data import TensorDict -from torchrl.data.tensordict.tensordict import _TensorDict +from torchrl.data.tensordict.tensordict import TensorDictBase from torchrl.envs.utils import set_exploration_mode, step_tensordict from torchrl.modules import TensorDictModule from torchrl.objectives import ( @@ -127,7 +127,7 @@ def alpha(self): alpha = self.log_alpha.exp() return alpha - def forward(self, tensordict: _TensorDict) -> _TensorDict: + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: loss_actor, sample_log_prob = self._actor_loss(tensordict) loss_qval = self._qvalue_loss(tensordict) @@ -149,7 +149,7 @@ def forward(self, tensordict: _TensorDict) -> _TensorDict: return td_out - def _actor_loss(self, tensordict: _TensorDict) -> Tuple[Tensor, Tensor]: + def _actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: obs_keys = self.actor_network.in_keys tensordict_clone = tensordict.select(*obs_keys) # to avoid overwriting keys with set_exploration_mode("random"): @@ -176,7 +176,7 @@ def _actor_loss(self, tensordict: _TensorDict) -> Tuple[Tensor, Tensor]: ).mean(0) return loss_actor, tensordict_clone.get("sample_log_prob") - def _qvalue_loss(self, tensordict: _TensorDict) -> Tensor: + def _qvalue_loss(self, tensordict: TensorDictBase) -> Tensor: tensordict_save = tensordict next_obs_keys = [key for key in tensordict.keys() if key.startswith("next_")] diff --git a/torchrl/objectives/costs/dqn.py b/torchrl/objectives/costs/dqn.py index a11a399cdd4..e68e5dabee8 100644 --- a/torchrl/objectives/costs/dqn.py +++ b/torchrl/objectives/costs/dqn.py @@ -11,7 +11,7 @@ DistributionalQValueActor, QValueActor, ) -from ...data.tensordict.tensordict import _TensorDict +from ...data.tensordict.tensordict import TensorDictBase from .common import LossModule from .utils import distance_loss, next_state_value @@ -58,14 +58,14 @@ def __init__( self.loss_function = loss_function self.priority_key = priority_key - def forward(self, input_tensordict: _TensorDict) -> TensorDict: + def forward(self, input_tensordict: TensorDictBase) -> TensorDict: """ Computes the DQN loss given a tensordict sampled from the replay buffer. This function will also write a "td_error" key that can be used by prioritized replay buffers to assign a priority to items in the tensordict. Args: - input_tensordict (_TensorDict): a tensordict with keys ["done", "reward", "action"] and the in_keys of + input_tensordict (TensorDictBase): a tensordict with keys ["done", "reward", "action"] and the in_keys of the value network. Returns: @@ -167,7 +167,7 @@ def __init__( create_target_params=self.delay_value, ) - def forward(self, input_tensordict: _TensorDict) -> TensorDict: + def forward(self, input_tensordict: TensorDictBase) -> TensorDict: # from https://github.com/Kaixhin/Rainbow/blob/9ff5567ad1234ae0ed30d8471e8f13ae07119395/agent.py device = self.device tensordict = TensorDict( diff --git a/torchrl/objectives/costs/impala.py b/torchrl/objectives/costs/impala.py index e6f0ff74988..66be170fd72 100644 --- a/torchrl/objectives/costs/impala.py +++ b/torchrl/objectives/costs/impala.py @@ -5,7 +5,7 @@ import torch -from torchrl.data.tensordict.tensordict import _TensorDict +from torchrl.data.tensordict.tensordict import TensorDictBase from torchrl.modules import ProbabilisticTensorDictModule from torchrl.objectives.returns.vtrace import vtrace @@ -18,7 +18,7 @@ def __init__(self, value_model: ProbabilisticTensorDictModule): def device(self) -> torch.device: return next(self.value_model.parameters()).device - def forward(self, tensordict: _TensorDict) -> None: + def forward(self, tensordict: TensorDictBase) -> None: tensordict_device = tensordict.to(self.device) self.value_model_device(tensordict_device) # udpates the value key gamma = tensordict_device.get("gamma") @@ -35,7 +35,7 @@ def forward(self, tensordict: _TensorDict) -> None: class VTraceEstimator: - def forward(self, tensordict: _TensorDict) -> _TensorDict: + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict_device = tensordict.to(device) rewards = tensordict_device.get("reward") vals = tensordict_device.get("value") diff --git a/torchrl/objectives/costs/ppo.py b/torchrl/objectives/costs/ppo.py index 80a37cadb5f..560a16138c8 100644 --- a/torchrl/objectives/costs/ppo.py +++ b/torchrl/objectives/costs/ppo.py @@ -9,7 +9,7 @@ import torch from torch import distributions as d -from torchrl.data.tensordict.tensordict import _TensorDict, TensorDict +from torchrl.data.tensordict.tensordict import TensorDictBase, TensorDict from torchrl.envs.utils import step_tensordict from torchrl.modules import TensorDictModule from ...modules.tensordict_module import ProbabilisticTensorDictModule @@ -66,7 +66,7 @@ def __init__( critic_coef: float = 1.0, gamma: float = 0.99, loss_critic_type: str = "smooth_l1", - advantage_module: Optional[Callable[[_TensorDict], _TensorDict]] = None, + advantage_module: Optional[Callable[[TensorDictBase], TensorDictBase]] = None, ): super().__init__() self.convert_to_functional(actor, "actor") @@ -95,7 +95,7 @@ def get_entropy_bonus(self, dist: Optional[d.Distribution] = None) -> torch.Tens return entropy.unsqueeze(-1) def _log_weight( - self, tensordict: _TensorDict + self, tensordict: TensorDictBase ) -> Tuple[torch.Tensor, d.Distribution]: # current log_prob of actions action = tensordict.get("action") @@ -116,7 +116,7 @@ def _log_weight( log_weight = log_prob - prev_log_prob return log_weight, dist - def loss_critic(self, tensordict: _TensorDict) -> torch.Tensor: + def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: if self.advantage_diff_key in tensordict.keys(): advantage_diff = tensordict.get(self.advantage_diff_key) if not advantage_diff.requires_grad: @@ -147,7 +147,7 @@ def loss_critic(self, tensordict: _TensorDict) -> torch.Tensor: ) return self.critic_coef * loss_value - def forward(self, tensordict: _TensorDict) -> _TensorDict: + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: if self.advantage_module is not None: tensordict = self.advantage_module( tensordict, @@ -227,7 +227,7 @@ def __init__( math.log1p(self.clip_epsilon), ) - def forward(self, tensordict: _TensorDict) -> _TensorDict: + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: if self.advantage_module is not None: tensordict = self.advantage_module(tensordict) tensordict = tensordict.clone() @@ -351,7 +351,7 @@ def __init__( self.decrement = decrement self.samples_mc_kl = samples_mc_kl - def forward(self, tensordict: _TensorDict) -> TensorDict: + def forward(self, tensordict: TensorDictBase) -> TensorDict: if self.advantage_module is not None: tensordict = self.advantage_module(tensordict) tensordict = tensordict.clone() diff --git a/torchrl/objectives/costs/redq.py b/torchrl/objectives/costs/redq.py index f15530f0800..feb2bd25e75 100644 --- a/torchrl/objectives/costs/redq.py +++ b/torchrl/objectives/costs/redq.py @@ -11,7 +11,7 @@ import torch from torch import Tensor -from torchrl.data.tensordict.tensordict import _TensorDict, TensorDict +from torchrl.data.tensordict.tensordict import TensorDictBase, TensorDict from torchrl.envs.utils import set_exploration_mode, step_tensordict from torchrl.modules import TensorDictModule from torchrl.objectives.costs.common import LossModule @@ -144,7 +144,7 @@ def alpha(self): alpha = self.log_alpha.exp() return alpha - def forward(self, tensordict: _TensorDict) -> _TensorDict: + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: obs_keys = self.actor_network.in_keys next_obs_keys = [key for key in tensordict.keys() if key.startswith("next_")] tensordict_select = tensordict.select( diff --git a/torchrl/objectives/costs/reinforce.py b/torchrl/objectives/costs/reinforce.py index 6f2f1deff58..7c7a11e8908 100644 --- a/torchrl/objectives/costs/reinforce.py +++ b/torchrl/objectives/costs/reinforce.py @@ -2,7 +2,7 @@ import torch -from torchrl.data.tensordict.tensordict import _TensorDict, TensorDict +from torchrl.data.tensordict.tensordict import TensorDictBase, TensorDict from torchrl.envs.utils import step_tensordict from torchrl.modules import TensorDictModule, ProbabilisticTensorDictModule from torchrl.objectives import distance_loss @@ -19,7 +19,7 @@ class ReinforceLoss(LossModule): def __init__( self, actor_network: ProbabilisticTensorDictModule, - advantage_module: Callable[[_TensorDict], _TensorDict], + advantage_module: Callable[[TensorDictBase], TensorDictBase], critic: Optional[TensorDictModule] = None, delay_value: bool = False, gamma: float = 0.99, @@ -61,7 +61,7 @@ def __init__( self.advantage_module = advantage_module - def forward(self, tensordict: _TensorDict) -> _TensorDict: + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: # get advantage tensordict = self.advantage_module( tensordict, @@ -88,7 +88,7 @@ def forward(self, tensordict: _TensorDict) -> _TensorDict: return td_out - def loss_critic(self, tensordict: _TensorDict) -> torch.Tensor: + def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: if self.advantage_diff_key in tensordict.keys(): advantage_diff = tensordict.get(self.advantage_diff_key) if not advantage_diff.requires_grad: diff --git a/torchrl/objectives/costs/sac.py b/torchrl/objectives/costs/sac.py index 8386a236764..cbb42c5aec3 100644 --- a/torchrl/objectives/costs/sac.py +++ b/torchrl/objectives/costs/sac.py @@ -11,7 +11,7 @@ import torch from torch import Tensor -from torchrl.data.tensordict.tensordict import _TensorDict, TensorDict +from torchrl.data.tensordict.tensordict import TensorDictBase, TensorDict from torchrl.modules import ProbabilisticActor from torchrl.modules import TensorDictModule from torchrl.modules.tensordict_module.actors import ( @@ -167,7 +167,7 @@ def device(self) -> torch.device: "At least one of the networks of SACLoss must have trainable " "parameters." ) - def forward(self, tensordict: _TensorDict) -> _TensorDict: + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: if tensordict.ndimension() > 1: tensordict = tensordict.view(-1) @@ -197,7 +197,7 @@ def forward(self, tensordict: _TensorDict) -> _TensorDict: [], ) - def _loss_actor(self, tensordict: _TensorDict) -> Tensor: + def _loss_actor(self, tensordict: TensorDictBase) -> Tensor: # KL lossa with set_exploration_mode("random"): dist = self.actor_network.get_dist( @@ -229,7 +229,7 @@ def _loss_actor(self, tensordict: _TensorDict) -> Tensor: tensordict.set("_log_prob", log_prob.detach()) return self._alpha * log_prob - min_q_logprob - def _loss_qvalue(self, tensordict: _TensorDict) -> Tuple[Tensor, Tensor]: + def _loss_qvalue(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: actor_critic = ActorCriticWrapper(self.actor_network, self.value_network) params = list(self.target_actor_network_params) + list( self.target_value_network_params @@ -283,7 +283,7 @@ def _loss_qvalue(self, tensordict: _TensorDict) -> Tuple[Tensor, Tensor]: return loss_value, priority_value - def _loss_value(self, tensordict: _TensorDict) -> Tensor: + def _loss_value(self, tensordict: TensorDictBase) -> Tensor: # value loss td_copy = tensordict.select(*self.value_network.in_keys).detach() self.value_network( @@ -328,7 +328,7 @@ def _loss_value(self, tensordict: _TensorDict) -> Tensor: ) return loss_value - def _loss_alpha(self, tensordict: _TensorDict) -> Tensor: + def _loss_alpha(self, tensordict: TensorDictBase) -> Tensor: log_pi = tensordict.get("_log_prob") if self.target_entropy is not None: # we can compute this loss even if log_alpha is not a parameter diff --git a/torchrl/objectives/costs/utils.py b/torchrl/objectives/costs/utils.py index 61a0f7a81ac..a7adbd900e5 100644 --- a/torchrl/objectives/costs/utils.py +++ b/torchrl/objectives/costs/utils.py @@ -11,7 +11,7 @@ from torch import nn, Tensor from torch.nn import functional as F -from torchrl.data.tensordict.tensordict import _TensorDict +from torchrl.data.tensordict.tensordict import TensorDictBase from torchrl.envs.utils import step_tensordict from torchrl.modules import TensorDictModule @@ -291,7 +291,7 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None: @torch.no_grad() def next_state_value( - tensordict: _TensorDict, + tensordict: TensorDictBase, operator: Optional[TensorDictModule] = None, next_val_key: str = "state_action_value", gamma: float = 0.99, @@ -307,7 +307,7 @@ def next_state_value( from the input tensordict. Args: - tensordict (_TensorDict): Tensordict containing a reward and done key (and a n_steps_to_next key for n-steps + tensordict (TensorDictBase): Tensordict containing a reward and done key (and a n_steps_to_next key for n-steps rewards). operator (ProbabilisticTDModule, optional): the value function operator. Should write a 'next_val_key' key-value in the input tensordict when called. It does not need to be provided if pred_next_val is given. diff --git a/torchrl/objectives/returns/advantages.py b/torchrl/objectives/returns/advantages.py index 3a97696200e..59d88b5b2d4 100644 --- a/torchrl/objectives/returns/advantages.py +++ b/torchrl/objectives/returns/advantages.py @@ -17,7 +17,7 @@ # entropy_loss = entropy_loss + entropy from torch import Tensor -from torchrl.data.tensordict.tensordict import _TensorDict +from torchrl.data.tensordict.tensordict import TensorDictBase from torchrl.envs.utils import step_tensordict from torchrl.modules import TensorDictModule from torchrl.objectives.returns.functional import ( @@ -62,17 +62,17 @@ def __init__( def __call__( self, - tensordict: _TensorDict, + tensordict: TensorDictBase, *unused_args, params: Optional[List[Tensor]] = None, buffers: Optional[List[Tensor]] = None, target_params: Optional[List[Tensor]] = None, target_buffers: Optional[List[Tensor]] = None, - ) -> _TensorDict: + ) -> TensorDictBase: """Computes the GAE given the data in tensordict. Args: - tensordict (_TensorDict): A TensorDict containing the data (observation, action, reward, done state) + tensordict (TensorDictBase): A TensorDict containing the data (observation, action, reward, done state) necessary to compute the value estimates and the GAE. Returns: @@ -163,17 +163,17 @@ def __init__( def __call__( self, - tensordict: _TensorDict, + tensordict: TensorDictBase, *unused_args, params: Optional[List[Tensor]] = None, buffers: Optional[List[Tensor]] = None, target_params: Optional[List[Tensor]] = None, target_buffers: Optional[List[Tensor]] = None, - ) -> _TensorDict: + ) -> TensorDictBase: """Computes the GAE given the data in tensordict. Args: - tensordict (_TensorDict): A TensorDict containing the data (observation, action, reward, done state) + tensordict (TensorDictBase): A TensorDict containing the data (observation, action, reward, done state) necessary to compute the value estimates and the GAE. Returns: @@ -268,17 +268,17 @@ def __init__( def __call__( self, - tensordict: _TensorDict, + tensordict: TensorDictBase, *unused_args, params: Optional[List[Tensor]] = None, buffers: Optional[List[Tensor]] = None, target_params: Optional[List[Tensor]] = None, target_buffers: Optional[List[Tensor]] = None, - ) -> _TensorDict: + ) -> TensorDictBase: """Computes the GAE given the data in tensordict. Args: - tensordict (_TensorDict): A TensorDict containing the data (observation, action, reward, done state) + tensordict (TensorDictBase): A TensorDict containing the data (observation, action, reward, done state) necessary to compute the value estimates and the GAE. Returns: diff --git a/torchrl/record/recorder.py b/torchrl/record/recorder.py index fd3d6b6bfd5..99b1b684a07 100644 --- a/torchrl/record/recorder.py +++ b/torchrl/record/recorder.py @@ -13,7 +13,7 @@ except ImportError: center_crop_fn = None -from torchrl.data.tensordict.tensordict import _TensorDict +from torchrl.data.tensordict.tensordict import TensorDictBase from torchrl.envs.transforms import ObservationTransform, Transform from torchrl.trainers.loggers import Logger @@ -163,7 +163,7 @@ def __init__( self.skip = skip self.count = 0 - def _call(self, td: _TensorDict) -> _TensorDict: + def _call(self, td: TensorDictBase) -> TensorDictBase: self.count += 1 if self.count % self.skip == 0: _td = td diff --git a/torchrl/trainers/helpers/collectors.py b/torchrl/trainers/helpers/collectors.py index 3152073ec85..2bc4812fdbb 100644 --- a/torchrl/trainers/helpers/collectors.py +++ b/torchrl/trainers/helpers/collectors.py @@ -13,7 +13,7 @@ MultiSyncDataCollector, ) from torchrl.data import MultiStep -from torchrl.data.tensordict.tensordict import _TensorDict +from torchrl.data.tensordict.tensordict import TensorDictBase from torchrl.envs import ParallelEnv __all__ = [ @@ -163,7 +163,7 @@ def _make_collector( collector_class: Type, env_fns: Union[Callable, List[Callable]], env_kwargs: Optional[Union[dict, List[dict]]], - policy: Callable[[_TensorDict], _TensorDict], + policy: Callable[[TensorDictBase], TensorDictBase], max_frames_per_traj: int = -1, frames_per_batch: int = 200, total_frames: Optional[int] = None, diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index 963eb135413..3a230252841 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -30,7 +30,7 @@ TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer, ) -from torchrl.data.tensordict.tensordict import _TensorDict +from torchrl.data.tensordict.tensordict import TensorDictBase from torchrl.data.utils import expand_right, DEVICE_TYPING from torchrl.envs.common import _EnvClass from torchrl.envs.transforms import TransformedEnv @@ -76,7 +76,7 @@ class Trainer: loss module and an optimizer. Args: - collector (Sequence[_TensorDict]): An iterable returning batches of + collector (Sequence[TensorDictBase]): An iterable returning batches of data in a TensorDict form of shape [batch x time steps]. total_frames (int): Total number of frames to be collected during training. @@ -120,7 +120,7 @@ def __init__( collector: _DataCollector, total_frames: int, frame_skip: int, - loss_module: Union[LossModule, Callable[[_TensorDict], _TensorDict]], + loss_module: Union[LossModule, Callable[[TensorDictBase], TensorDictBase]], optimizer: optim.Optimizer, logger: Optional[Logger] = None, optim_steps_per_batch: int = 500, @@ -242,7 +242,9 @@ def collector(self, collector: _DataCollector) -> None: def register_op(self, dest: str, op: Callable, **kwargs) -> None: if dest == "batch_process": - _check_input_output_typehint(op, input=_TensorDict, output=_TensorDict) + _check_input_output_typehint( + op, input=TensorDictBase, output=TensorDictBase + ) self._batch_process_ops.append((op, kwargs)) elif dest == "pre_optim_steps": @@ -250,11 +252,15 @@ def register_op(self, dest: str, op: Callable, **kwargs) -> None: self._pre_optim_ops.append((op, kwargs)) elif dest == "process_optim_batch": - _check_input_output_typehint(op, input=_TensorDict, output=_TensorDict) + _check_input_output_typehint( + op, input=TensorDictBase, output=TensorDictBase + ) self._process_optim_batch_ops.append((op, kwargs)) elif dest == "post_loss": - _check_input_output_typehint(op, input=_TensorDict, output=_TensorDict) + _check_input_output_typehint( + op, input=TensorDictBase, output=TensorDictBase + ) self._post_loss_ops.append((op, kwargs)) elif dest == "post_steps": @@ -267,19 +273,19 @@ def register_op(self, dest: str, op: Callable, **kwargs) -> None: elif dest == "pre_steps_log": _check_input_output_typehint( - op, input=_TensorDict, output=Tuple[str, float] + op, input=TensorDictBase, output=Tuple[str, float] ) self._pre_steps_log_ops.append((op, kwargs)) elif dest == "post_steps_log": _check_input_output_typehint( - op, input=_TensorDict, output=Tuple[str, float] + op, input=TensorDictBase, output=Tuple[str, float] ) self._post_steps_log_ops.append((op, kwargs)) elif dest == "post_optim_log": _check_input_output_typehint( - op, input=_TensorDict, output=Tuple[str, float] + op, input=TensorDictBase, output=Tuple[str, float] ) self._post_optim_log_ops.append((op, kwargs)) @@ -291,10 +297,10 @@ def register_op(self, dest: str, op: Callable, **kwargs) -> None: ) # Process batch - def _process_batch_hook(self, batch: _TensorDict) -> _TensorDict: + def _process_batch_hook(self, batch: TensorDictBase) -> TensorDictBase: for op, kwargs in self._batch_process_ops: out = op(batch, **kwargs) - if isinstance(out, _TensorDict): + if isinstance(out, TensorDictBase): batch = out return batch @@ -302,7 +308,7 @@ def _post_steps_hook(self) -> None: for op, kwargs in self._post_steps_ops: op(**kwargs) - def _post_optim_log(self, batch: _TensorDict) -> None: + def _post_optim_log(self, batch: TensorDictBase) -> None: for op, kwargs in self._post_optim_log_ops: result = op(batch, **kwargs) if result is not None: @@ -315,14 +321,14 @@ def _pre_optim_hook(self): def _process_optim_batch_hook(self, batch): for op, kwargs in self._process_optim_batch_ops: out = op(batch, **kwargs) - if isinstance(out, _TensorDict): + if isinstance(out, TensorDictBase): batch = out return batch def _post_loss_hook(self, batch): for op, kwargs in self._post_loss_ops: out = op(batch, **kwargs) - if isinstance(out, _TensorDict): + if isinstance(out, TensorDictBase): batch = out return batch @@ -330,13 +336,13 @@ def _post_optim_hook(self): for op, kwargs in self._post_optim_ops: op(**kwargs) - def _pre_steps_log_hook(self, batch: _TensorDict) -> None: + def _pre_steps_log_hook(self, batch: TensorDictBase) -> None: for op, kwargs in self._pre_steps_log_ops: result = op(batch, **kwargs) if result is not None: self._log(**result) - def _post_steps_log_hook(self, batch: _TensorDict) -> None: + def _post_steps_log_hook(self, batch: TensorDictBase) -> None: for op, kwargs in self._post_steps_log_ops: result = op(batch, **kwargs) if result is not None: @@ -381,7 +387,7 @@ def shutdown(self): print("shutting down collector") self.collector.shutdown() - def _optimizer_step(self, losses_td: _TensorDict) -> _TensorDict: + def _optimizer_step(self, losses_td: TensorDictBase) -> TensorDictBase: # sum all keys that start with 'loss_' loss = sum([item for key, item in losses_td.items() if key.startswith("loss")]) loss.backward() @@ -391,7 +397,7 @@ def _optimizer_step(self, losses_td: _TensorDict) -> _TensorDict: self.optimizer.zero_grad() return losses_td.detach().set("grad_norm", grad_norm) - def optim_steps(self, batch: _TensorDict) -> None: + def optim_steps(self, batch: TensorDictBase) -> None: # average_grad_norm = 0.0 average_losses = None @@ -410,7 +416,7 @@ def optim_steps(self, batch: _TensorDict) -> None: self._post_optim_log(sub_batch_device) if average_losses is None: - average_losses: _TensorDict = losses_detached + average_losses: TensorDictBase = losses_detached else: for key, item in losses_detached.items(): val = average_losses.get(key) @@ -510,7 +516,7 @@ def __init__(self, keys: Sequence[str]): ) self.keys = keys - def __call__(self, batch: _TensorDict) -> _TensorDict: + def __call__(self, batch: TensorDictBase) -> TensorDictBase: return batch.select(*self.keys) @@ -546,7 +552,7 @@ def __init__( self.memmap = memmap self.device = device - def extend(self, batch: _TensorDict) -> _TensorDict: + def extend(self, batch: TensorDictBase) -> TensorDictBase: if "mask" in batch.keys(): batch = batch[batch.get("mask").squeeze(-1)] else: @@ -559,12 +565,12 @@ def extend(self, batch: _TensorDict) -> _TensorDict: batch = batch.memmap_().to(self.device) self.replay_buffer.extend(batch) - def sample(self, batch: _TensorDict) -> _TensorDict: + def sample(self, batch: TensorDictBase) -> TensorDictBase: sample = self.replay_buffer.sample(self.batch_size) sample = sample.contiguous() return sample.to(self.device) - def update_priority(self, batch: _TensorDict) -> None: + def update_priority(self, batch: TensorDictBase) -> None: if isinstance(self.replay_buffer, TensorDictPrioritizedReplayBuffer): self.replay_buffer.update_priority(batch) @@ -587,7 +593,7 @@ def __init__(self, logname="r_training", log_pbar: bool = False): self.logname = logname self.log_pbar = log_pbar - def __call__(self, batch: _TensorDict) -> Dict: + def __call__(self, batch: TensorDictBase) -> Dict: if "mask" in batch.keys(): return { self.logname: batch.get("reward")[batch.get("mask").squeeze(-1)] @@ -625,7 +631,7 @@ def __init__( self.scale = scale @torch.no_grad() - def update_reward_stats(self, batch: _TensorDict) -> None: + def update_reward_stats(self, batch: TensorDictBase) -> None: reward = batch.get("reward") if "mask" in batch.keys(): reward = reward[batch.get("mask").squeeze(-1)] @@ -656,7 +662,7 @@ def update_reward_stats(self, batch: _TensorDict) -> None: self._reward_stats["std"] = var.clamp_min(1e-6).sqrt() self._update_has_been_called = True - def normalize_reward(self, tensordict: _TensorDict) -> _TensorDict: + def normalize_reward(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict = tensordict.to_tensordict() # make sure it is not a SubTensorDict reward = tensordict.get("reward") reward = reward - self._reward_stats["mean"].to(tensordict.device) @@ -666,7 +672,7 @@ def normalize_reward(self, tensordict: _TensorDict) -> _TensorDict: return tensordict -def mask_batch(batch: _TensorDict) -> _TensorDict: +def mask_batch(batch: TensorDictBase) -> TensorDictBase: """Batch masking hook. If a tensordict contained padded trajectories but only single events are @@ -728,7 +734,7 @@ def __init__( self.sub_traj_len = sub_traj_len self.min_sub_traj_len = min_sub_traj_len - def __call__(self, batch: _TensorDict) -> _TensorDict: + def __call__(self, batch: TensorDictBase) -> TensorDictBase: """Sub-sampled part of a batch randomly. If the batch has one dimension, a random subsample of length @@ -872,7 +878,7 @@ def __init__( self.log_pbar = log_pbar @torch.inference_mode() - def __call__(self, batch: _TensorDict) -> Dict: + def __call__(self, batch: TensorDictBase) -> Dict: out = None if self._count % self.record_interval == 0: with set_exploration_mode(self.exploration_mode): @@ -963,7 +969,7 @@ def __init__(self, frame_skip: int, log_pbar: bool = False): self.frame_skip = frame_skip self.log_pbar = log_pbar - def __call__(self, batch: _TensorDict) -> Dict: + def __call__(self, batch: TensorDictBase) -> Dict: if "mask" in batch.keys(): current_frames = batch.get("mask").sum().item() * self.frame_skip else: