[Feature] RB compatibility with compile
ghstack-source-id: 803de44200c0e113df8abe17b059265ea794c627
Pull Request resolved: #2426
vmoens committed Sep 17, 2024
1 parent 4e618a7 commit fb92c14
Showing 3 changed files with 90 additions and 36 deletions.
21 changes: 18 additions & 3 deletions torchrl/data/replay_buffers/replay_buffers.py
@@ -131,6 +131,9 @@ class ReplayBuffer:
.. warning:: As of now, the generator has no effect on the transforms.
shared (bool, optional): whether the buffer will be shared using multiprocessing or not.
Defaults to ``False``.
compilable (bool, optional): whether the buffer (its storage and writer) should be made compatible with ``torch.compile``.
If ``True``, the buffer cannot be shared between multiple processes.
Defaults to ``False``.
Examples:
>>> import torch
@@ -216,11 +219,20 @@ def __init__(
checkpointer: "StorageCheckpointerBase" | None = None, # noqa: F821
generator: torch.Generator | None = None,
shared: bool = False,
compilable: bool = None,
) -> None:
self._storage = storage if storage is not None else ListStorage(max_size=1_000)
self._storage.attach(self)
if compilable is not None:
self._storage._compilable = compilable
# Re-assign through the setter so the length counter is rebuilt to match the new flag.
self._storage._len = self._storage._len

self._sampler = sampler if sampler is not None else RandomSampler()
self._writer = writer if writer is not None else RoundRobinWriter()
self._writer = (
writer
if writer is not None
else RoundRobinWriter(compilable=bool(compilable))
)
self._writer.register_storage(self._storage)

self._get_collate_fn(collate_fn)
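The constructor change above forwards the new flag to both the storage and the default writer. A minimal illustration of that plumbing (a sketch only; LazyTensorStorage is assumed from torchrl, and the private attributes are inspected purely for demonstration):

from torchrl.data import LazyTensorStorage, ReplayBuffer

rb = ReplayBuffer(storage=LazyTensorStorage(100), compilable=True)

# ReplayBuffer.__init__ marks the storage as compilable and builds a
# RoundRobinWriter(compilable=True) when no writer is passed explicitly.
assert rb._storage._compilable
assert rb._writer._compilable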
@@ -601,7 +613,9 @@ def _add(self, data):
return index

def _extend(self, data: Sequence) -> torch.Tensor:
with self._replay_lock, self._write_lock:
is_compiling = torch.compiler.is_dynamo_compiling()
nc = contextlib.nullcontext()
with self._replay_lock if not is_compiling else nc, self._write_lock if not is_compiling else nc:
if self.dim_extend > 0:
data = self._transpose(data)
index = self._writer.extend(data)
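The `_extend` change swaps the real locks for `contextlib.nullcontext()` while Dynamo is tracing, since lock acquisition cannot be traced. The same pattern in isolation (a standalone sketch, not torchrl code):

import contextlib
import threading

import torch

_write_lock = threading.RLock()
_buffer = []

def guarded_append(value):
    # Dynamo cannot trace lock acquisition, so while compiling the real
    # lock is swapped for a no-op context manager.
    is_compiling = torch.compiler.is_dynamo_compiling()
    ctx = contextlib.nullcontext() if is_compiling else _write_lock
    with ctx:
        _buffer.append(value)
        return len(_buffer)

print(guarded_append(torch.randn(3)))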
@@ -654,7 +668,7 @@ def update_priority(

@pin_memory_output
def _sample(self, batch_size: int) -> Tuple[Any, dict]:
with self._replay_lock:
with self._replay_lock if not torch.compiler.is_dynamo_compiling() else contextlib.nullcontext():
index, info = self._sampler.sample(self._storage, batch_size)
info["index"] = index
data = self._storage.get(index)
@@ -1753,6 +1767,7 @@ def __init__(
num_buffer_sampled: int | None = None,
generator: torch.Generator | None = None,
shared: bool = False,
compilable: bool = False,
**kwargs,
):

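Taken together, the replay_buffers.py changes let the buffer's write path be wrapped in `torch.compile` directly. A hedged usage sketch (assuming torchrl's LazyTensorStorage; exact behaviour depends on the PyTorch version in use):

import torch
from torchrl.data import LazyTensorStorage, ReplayBuffer

rb = ReplayBuffer(storage=LazyTensorStorage(1_000), compilable=True)

# With compilable=True the buffer keeps plain-int counters instead of
# mp.Value and bypasses the replay/write locks while Dynamo is tracing.
compiled_extend = torch.compile(rb.extend)
compiled_extend(torch.randn(32, 4))
batch = rb.sample(16)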
47 changes: 32 additions & 15 deletions torchrl/data/replay_buffers/storages.py
@@ -57,10 +57,15 @@ class Storage:
_rng: torch.Generator | None = None

def __init__(
self, max_size: int, checkpointer: StorageCheckpointerBase | None = None
self,
max_size: int,
checkpointer: StorageCheckpointerBase | None = None,
compilable: bool = False,
) -> None:
self.max_size = int(max_size)
self.checkpointer = checkpointer
self._compilable = compilable
self._attached_entities_set = set()

@property
def checkpointer(self):
@@ -80,11 +85,11 @@ def _is_full(self):
def _attached_entities(self):
# RBs that use a given instance of Storage should add
# themselves to this set.
_attached_entities = self.__dict__.get("_attached_entities_set", None)
if _attached_entities is None:
_attached_entities = set()
self.__dict__["_attached_entities_set"] = _attached_entities
return _attached_entities
return getattr(self, "_attached_entities_set", None)

@torch._dynamo.assume_constant_result
def _attached_entities_iter(self):
return list(self._attached_entities)

@abc.abstractmethod
def set(self, cursor: int, data: Any, *, set_cursor: bool = True):
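`torch._dynamo.assume_constant_result` marks `_attached_entities_iter` so that Dynamo bakes its return value into the graph as a constant instead of guarding on the set. A self-contained sketch of the same mechanism (illustrative names only, not torchrl code):

import torch
import torch._dynamo

class Registry:
    def __init__(self):
        self._entities = set()

    @torch._dynamo.assume_constant_result
    def entities(self):
        # Dynamo treats the result as a compile-time constant, so the set
        # can be iterated while tracing without installing guards on it.
        return list(self._entities)

reg = Registry()
reg._entities.add("buffer-0")

@torch.compile
def bump(x):
    for _ in reg.entities():
        x = x + 1
    return x

print(bump(torch.zeros(())))  # tensor(1.)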
@@ -140,6 +145,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
def _empty(self):
...

@torch._dynamo.disable()
def _rand_given_ndim(self, batch_size):
# a method to return random indices given the storage ndim
if self.ndim == 1:
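`@torch._dynamo.disable()` excludes `_rand_given_ndim` from tracing altogether: when called from compiled code it runs eagerly at a graph break, keeping the RNG bookkeeping out of the graph. A minimal sketch of the decorator (illustrative names only):

import torch
import torch._dynamo

@torch._dynamo.disable()
def draw_indices(n, high):
    # Never traced by Dynamo: the call falls back to eager, which keeps
    # untraceable RNG/bookkeeping logic out of the compiled graph.
    return torch.randint(high, (n,))

@torch.compile
def gather(storage):
    idx = draw_indices(4, storage.shape[0])
    return storage[idx]

print(gather(torch.arange(10.0)))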
@@ -330,6 +336,9 @@ class TensorStorage(Storage):
measuring the storage size. For instance, a storage of shape ``[3, 4]``
has capacity ``3`` if ``ndim=1`` and ``12`` if ``ndim=2``.
Defaults to ``1``.
compilable (bool, optional): whether the storage is compilable.
If ``True``, the storage cannot be shared between multiple processes.
Defaults to ``False``.
Examples:
>>> data = TensorDict({
@@ -389,6 +398,7 @@ def __init__(
*,
device: torch.device = "cpu",
ndim: int = 1,
compilable: bool = False,
):
if not ((storage is None) ^ (max_size is None)):
if storage is None:
@@ -404,7 +414,7 @@ def __init__(
else:
max_size = tree_flatten(storage)[0][0].shape[0]
self.ndim = ndim
super().__init__(max_size)
super().__init__(max_size, compilable=compilable)
self.initialized = storage is not None
if self.initialized:
self._len = max_size
@@ -423,16 +433,23 @@ def __init__(
@property
def _len(self):
_len_value = self.__dict__.get("_len_value", None)
if not self._compilable or not isinstance(self._len_value, int):
if _len_value is None:
_len_value = self._len_value = mp.Value("i", 0)
return _len_value.value
if _len_value is None:
_len_value = self._len_value = mp.Value("i", 0)
return _len_value.value
_len_value = self._len_value = 0
return _len_value

@_len.setter
def _len(self, value):
_len_value = self.__dict__.get("_len_value", None)
if _len_value is None:
_len_value = self._len_value = mp.Value("i", 0)
_len_value.value = value
if not self._compilable:
if _len_value is None:
_len_value = self._len_value = mp.Value("i", 0)
_len_value.value = value
else:
self._len_value = value

@property
def _total_shape(self):
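The reworked `_len` property is the heart of the change: `multiprocessing.Value` is an opaque handle that Dynamo cannot trace, so with `compilable=True` the length is tracked as a plain Python int, at the cost of no longer being shareable across processes. Schematically (a simplified, illustrative counter, not the torchrl implementation):

import multiprocessing as mp

class Counter:
    def __init__(self, compilable: bool = False):
        self._compilable = compilable
        # A plain int is traceable by torch.compile; mp.Value is not,
        # but only mp.Value can be shared between worker processes.
        self._value = 0 if compilable else mp.Value("i", 0)

    @property
    def value(self) -> int:
        return self._value if self._compilable else self._value.value

    @value.setter
    def value(self, v: int):
        if self._compilable:
            self._value = v
        else:
            self._value.value = v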
@@ -1184,9 +1201,9 @@ def _rng(self, value):
for storage in self._storages:
storage._rng = value

@property
def _attached_entities(self):
return set()
# @property
# def _attached_entities(self):
# return set()

def extend(self, value):
raise RuntimeError
58 changes: 40 additions & 18 deletions torchrl/data/replay_buffers/writers.py
@@ -40,8 +40,9 @@ class Writer(ABC):
_storage: Storage
_rng: torch.Generator | None = None

def __init__(self) -> None:
def __init__(self, compilable: bool = False) -> None:
self._storage = None
self._compilable = compilable

def register_storage(self, storage: Storage) -> None:
self._storage = storage
@@ -138,10 +139,17 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:


class RoundRobinWriter(Writer):
"""A RoundRobin Writer class for composable replay buffers."""
"""A RoundRobin Writer class for composable replay buffers.
def __init__(self, **kw) -> None:
super().__init__(**kw)
Args:
compilable (bool, optional): whether the writer is compilable.
If ``True``, the writer cannot be shared between multiple processes.
Defaults to ``False``.
"""

def __init__(self, compilable: bool = False) -> None:
super().__init__(compilable=compilable)
self._cursor = 0

def dumps(self, path):
@@ -197,7 +205,7 @@ def extend(self, data: Sequence) -> torch.Tensor:
# Other than that, a "flat" (1d) index is ok to write the data
self._storage.set(index, data)
index = self._replicate_index(index)
for ent in self._storage._attached_entities:
for ent in self._storage._attached_entities_iter():
ent.mark_update(index)
return index

@@ -213,30 +221,44 @@ def _empty(self):
@property
def _cursor(self):
_cursor_value = self.__dict__.get("_cursor_value", None)
if not self._compilable or not isinstance(_cursor_value, int):
if _cursor_value is None:
_cursor_value = self._cursor_value = mp.Value("i", 0)
return _cursor_value.value
if _cursor_value is None:
_cursor_value = self._cursor_value = mp.Value("i", 0)
return _cursor_value.value
_cursor_value = self._cursor_value = 0
return _cursor_value

@_cursor.setter
def _cursor(self, value):
_cursor_value = self.__dict__.get("_cursor_value", None)
if _cursor_value is None:
_cursor_value = self._cursor_value = mp.Value("i", 0)
_cursor_value.value = value
if not self._compilable:
_cursor_value = self.__dict__.get("_cursor_value", None)
if _cursor_value is None:
_cursor_value = self._cursor_value = mp.Value("i", 0)
_cursor_value.value = value
else:
self._cursor_value = value

@property
def _write_count(self):
_write_count = self.__dict__.get("_write_count_value", None)
if not self._compilable or not isinstance(_write_count, int):
if _write_count is None:
_write_count = self._write_count_value = mp.Value("i", 0)
return _write_count.value
if _write_count is None:
_write_count = self._write_count_value = mp.Value("i", 0)
return _write_count.value
_write_count = self._write_count_value = 0
return _write_count

@_write_count.setter
def _write_count(self, value):
_write_count = self.__dict__.get("_write_count_value", None)
if _write_count is None:
_write_count = self._write_count_value = mp.Value("i", 0)
_write_count.value = value
if not self._compilable:
_write_count = self.__dict__.get("_write_count_value", None)
if _write_count is None:
_write_count = self._write_count_value = mp.Value("i", 0)
_write_count.value = value
else:
self._write_count_value = value

def __getstate__(self):
state = super().__getstate__()
Expand All @@ -248,7 +270,7 @@ def __getstate__(self):

def __setstate__(self, state):
cursor = state.pop("cursor__context", None)
if cursor is not None:
if not state["_compilable"] and cursor is not None:
_cursor_value = mp.Value("i", cursor)
state["_cursor_value"] = _cursor_value
self.__dict__.update(state)

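One way to check that the new code path actually avoids graph breaks is to compile the sampling path with `fullgraph=True`, which errors out on any fallback to eager. A sketch (whether it passes cleanly depends on the PyTorch and TorchRL versions in use; LazyTensorStorage is assumed from torchrl):

import torch
from torchrl.data import LazyTensorStorage, ReplayBuffer

rb = ReplayBuffer(storage=LazyTensorStorage(1_000), compilable=True)
rb.extend(torch.randn(64, 4))

# fullgraph=True raises on any graph break, so a successful call suggests
# the sampling path compiles end to end on this setup.
compiled_sample = torch.compile(rb.sample, fullgraph=True)
print(compiled_sample(32).shape)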