diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index 2e0eeb80705..bf5f18bb375 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -131,6 +131,9 @@ class ReplayBuffer: .. warning:: As of now, the generator has no effect on the transforms. shared (bool, optional): whether the buffer will be shared using multiprocessing or not. Defaults to ``False``. + compilable (bool, optional): whether the writer is compilable. + If ``True``, the writer cannot be shared between multiple processes. + Defaults to ``False``. Examples: >>> import torch @@ -216,11 +219,20 @@ def __init__( checkpointer: "StorageCheckpointerBase" | None = None, # noqa: F821 generator: torch.Generator | None = None, shared: bool = False, + compilable: bool = None, ) -> None: self._storage = storage if storage is not None else ListStorage(max_size=1_000) self._storage.attach(self) + if compilable is not None: + self._storage._compilable = compilable + self._storage._len = self._storage._len + self._sampler = sampler if sampler is not None else RandomSampler() - self._writer = writer if writer is not None else RoundRobinWriter() + self._writer = ( + writer + if writer is not None + else RoundRobinWriter(compilable=bool(compilable)) + ) self._writer.register_storage(self._storage) self._get_collate_fn(collate_fn) @@ -601,7 +613,9 @@ def _add(self, data): return index def _extend(self, data: Sequence) -> torch.Tensor: - with self._replay_lock, self._write_lock: + is_compiling = torch.compiler.is_dynamo_compiling() + nc = contextlib.nullcontext() + with self._replay_lock if not is_compiling else nc, self._write_lock if not is_compiling else nc: if self.dim_extend > 0: data = self._transpose(data) index = self._writer.extend(data) @@ -654,7 +668,7 @@ def update_priority( @pin_memory_output def _sample(self, batch_size: int) -> Tuple[Any, dict]: - with self._replay_lock: + with self._replay_lock if not torch.compiler.is_dynamo_compiling() else contextlib.nullcontext(): index, info = self._sampler.sample(self._storage, batch_size) info["index"] = index data = self._storage.get(index) @@ -1753,6 +1767,7 @@ def __init__( num_buffer_sampled: int | None = None, generator: torch.Generator | None = None, shared: bool = False, + compilable: bool = False, **kwargs, ): diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 20b2169cc8e..9880f89db84 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -57,10 +57,15 @@ class Storage: _rng: torch.Generator | None = None def __init__( - self, max_size: int, checkpointer: StorageCheckpointerBase | None = None + self, + max_size: int, + checkpointer: StorageCheckpointerBase | None = None, + compilable: bool = False, ) -> None: self.max_size = int(max_size) self.checkpointer = checkpointer + self._compilable = compilable + self._attached_entities_set = set() @property def checkpointer(self): @@ -80,11 +85,11 @@ def _is_full(self): def _attached_entities(self): # RBs that use a given instance of Storage should add # themselves to this set. - _attached_entities = self.__dict__.get("_attached_entities_set", None) - if _attached_entities is None: - _attached_entities = set() - self.__dict__["_attached_entities_set"] = _attached_entities - return _attached_entities + return getattr(self, "_attached_entities_set", None) + + @torch._dynamo.assume_constant_result + def _attached_entities_iter(self): + return list(self._attached_entities) @abc.abstractmethod def set(self, cursor: int, data: Any, *, set_cursor: bool = True): @@ -140,6 +145,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: def _empty(self): ... + @torch._dynamo.disable() def _rand_given_ndim(self, batch_size): # a method to return random indices given the storage ndim if self.ndim == 1: @@ -330,6 +336,9 @@ class TensorStorage(Storage): measuring the storage size. For instance, a storage of shape ``[3, 4]`` has capacity ``3`` if ``ndim=1`` and ``12`` if ``ndim=2``. Defaults to ``1``. + compilable (bool, optional): whether the storage is compilable. + If ``True``, the writer cannot be shared between multiple processes. + Defaults to ``False``. Examples: >>> data = TensorDict({ @@ -389,6 +398,7 @@ def __init__( *, device: torch.device = "cpu", ndim: int = 1, + compilable: bool = False, ): if not ((storage is None) ^ (max_size is None)): if storage is None: @@ -404,7 +414,7 @@ def __init__( else: max_size = tree_flatten(storage)[0][0].shape[0] self.ndim = ndim - super().__init__(max_size) + super().__init__(max_size, compilable=compilable) self.initialized = storage is not None if self.initialized: self._len = max_size @@ -423,16 +433,23 @@ def __init__( @property def _len(self): _len_value = self.__dict__.get("_len_value", None) + if not self._compilable or not isinstance(self._len_value, int): + if _len_value is None: + _len_value = self._len_value = mp.Value("i", 0) + return _len_value.value if _len_value is None: - _len_value = self._len_value = mp.Value("i", 0) - return _len_value.value + _len_value = self._len_value = 0 + return _len_value @_len.setter def _len(self, value): _len_value = self.__dict__.get("_len_value", None) - if _len_value is None: - _len_value = self._len_value = mp.Value("i", 0) - _len_value.value = value + if not self._compilable: + if _len_value is None: + _len_value = self._len_value = mp.Value("i", 0) + _len_value.value = value + else: + self._len_value = value @property def _total_shape(self): @@ -1184,9 +1201,9 @@ def _rng(self, value): for storage in self._storages: storage._rng = value - @property - def _attached_entities(self): - return set() + # @property + # def _attached_entities(self): + # return set() def extend(self, value): raise RuntimeError diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py index 3a95c3975cc..bd1efcb21e5 100644 --- a/torchrl/data/replay_buffers/writers.py +++ b/torchrl/data/replay_buffers/writers.py @@ -40,8 +40,9 @@ class Writer(ABC): _storage: Storage _rng: torch.Generator | None = None - def __init__(self) -> None: + def __init__(self, compilable: bool = False) -> None: self._storage = None + self._compilable = compilable def register_storage(self, storage: Storage) -> None: self._storage = storage @@ -138,10 +139,17 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: class RoundRobinWriter(Writer): - """A RoundRobin Writer class for composable replay buffers.""" + """A RoundRobin Writer class for composable replay buffers. - def __init__(self, **kw) -> None: - super().__init__(**kw) + Args: + compilable (bool, optional): whether the writer is compilable. + If ``True``, the writer cannot be shared between multiple processes. + Defaults to ``False``. + + """ + + def __init__(self, compilable: bool = False) -> None: + super().__init__(compilable=compilable) self._cursor = 0 def dumps(self, path): @@ -197,7 +205,7 @@ def extend(self, data: Sequence) -> torch.Tensor: # Other than that, a "flat" (1d) index is ok to write the data self._storage.set(index, data) index = self._replicate_index(index) - for ent in self._storage._attached_entities: + for ent in self._storage._attached_entities_iter(): ent.mark_update(index) return index @@ -213,30 +221,44 @@ def _empty(self): @property def _cursor(self): _cursor_value = self.__dict__.get("_cursor_value", None) + if not self._compilable or not isinstance(_cursor_value, int): + if _cursor_value is None: + _cursor_value = self._cursor_value = mp.Value("i", 0) + return _cursor_value.value if _cursor_value is None: - _cursor_value = self._cursor_value = mp.Value("i", 0) - return _cursor_value.value + _cursor_value = self._cursor_value = 0 + return _cursor_value @_cursor.setter def _cursor(self, value): - _cursor_value = self.__dict__.get("_cursor_value", None) - if _cursor_value is None: - _cursor_value = self._cursor_value = mp.Value("i", 0) - _cursor_value.value = value + if not self._compilable: + _cursor_value = self.__dict__.get("_cursor_value", None) + if _cursor_value is None: + _cursor_value = self._cursor_value = mp.Value("i", 0) + _cursor_value.value = value + else: + self._cursor_value = value @property def _write_count(self): _write_count = self.__dict__.get("_write_count_value", None) + if not self._compilable or not isinstance(_write_count, int): + if _write_count is None: + _write_count = self._write_count_value = mp.Value("i", 0) + return _write_count.value if _write_count is None: - _write_count = self._write_count_value = mp.Value("i", 0) - return _write_count.value + _write_count = self._write_count_value = 0 + return _write_count @_write_count.setter def _write_count(self, value): - _write_count = self.__dict__.get("_write_count_value", None) - if _write_count is None: - _write_count = self._write_count_value = mp.Value("i", 0) - _write_count.value = value + if not self._compilable: + _write_count = self.__dict__.get("_write_count_value", None) + if _write_count is None: + _write_count = self._write_count_value = mp.Value("i", 0) + _write_count.value = value + else: + self._write_count_value = value def __getstate__(self): state = super().__getstate__() @@ -248,7 +270,7 @@ def __getstate__(self): def __setstate__(self, state): cursor = state.pop("cursor__context", None) - if cursor is not None: + if not state["_compilable"] and cursor is not None: _cursor_value = mp.Value("i", cursor) state["_cursor_value"] = _cursor_value self.__dict__.update(state)