revise cachedataset runtime_cache modes (#5630)
Fixes #5613

### Description

- `runtime_cache=False`: the default v1.0.1 behaviour; the cache is computed at initialization
- `runtime_cache=True` or `"thread"`: single process, a regular-list cache filled at runtime (suitable for caching cuda tensors)
- `runtime_cache="process"`: single-process workflow with a multiprocess dataloader, using a shared-memory `ListProxy` cache
- `runtime_cache=` a user-provided list-like object, which can be used to pass a container shared among processes (a usage sketch follows this list)
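
A minimal sketch of the first three modes (a toy `data`/`transform` pair mirroring the updated unit test, not part of the diff below):

```python
from monai.data import CacheDataset
from monai.transforms import ToTensor

data = [1, 2, 3, 4, 5]  # any sequence of items the transform can consume
transform = ToTensor(track_meta=False)

# default: compute the whole cache during __init__ (the v1.0.1 behaviour)
ds = CacheDataset(data=data, transform=transform, runtime_cache=False)

# single process: a regular list cache filled lazily during the first epoch (can hold CUDA tensors)
ds = CacheDataset(data=data, transform=transform, runtime_cache="thread")  # or runtime_cache=True

# single-process workflow with a multi-worker DataLoader: shared-memory ListProxy cache
ds = CacheDataset(data=data, transform=transform, runtime_cache="process")
```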

I feel that this way the user can decide explicitly what to use, instead of us guessing and providing an automated solution... let me know what you
think @myron @Nic-Ma, I'm fine either way if this is eventually merged or not.
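
For the user-provided container mode in distributed training, a sketch of the intended pattern (it mirrors the updated `tests/test_sampler_dist.py` and assumes the default process group is already initialized, e.g. via `torchrun`):

```python
import torch.distributed as dist
from torch.multiprocessing import Manager

from monai.data import CacheDataset, DataLoader, DistributedSampler
from monai.transforms import ToTensor

data = [1, 2, 3, 4, 5]
# create a shared ListProxy on each rank and broadcast rank 0's proxy,
# so every rank fills and reads the same cache container
obj_list = [Manager().list([None] * len(data))]
dist.broadcast_object_list(obj_list, src=0)

dataset = CacheDataset(
    data=data, transform=ToTensor(track_meta=False), cache_rate=1.0, runtime_cache=obj_list[0]
)
sampler = DistributedSampler(dataset=dataset, shuffle=False, even_divisible=False)
dataloader = DataLoader(dataset=dataset, sampler=sampler, batch_size=1, num_workers=2)
```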

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/`
folder.

Signed-off-by: Wenqi Li <[email protected]>
wyli authored Dec 2, 2022
1 parent 7b41e2e commit 025c107
Showing 5 changed files with 58 additions and 60 deletions.
12 changes: 6 additions & 6 deletions monai/apps/datasets.py
@@ -71,7 +71,7 @@ class MedNISTDataset(Randomizable, CacheDataset):
as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
it may help improve the performance of following logic.
runtime_cache: whether to compute cache at the runtime, default to `False` to prepare
the cache content at initializaiton.
the cache content at initialization. See: :py:class:`monai.data.CacheDataset`.
Raises:
ValueError: When ``root_dir`` is not a directory.
@@ -99,7 +99,7 @@ def __init__(
progress: bool = True,
copy_cache: bool = True,
as_contiguous: bool = True,
runtime_cache: bool = False,
runtime_cache=False,
) -> None:
root_dir = Path(root_dir)
if not root_dir.is_dir():
@@ -228,7 +228,7 @@ class DecathlonDataset(Randomizable, CacheDataset):
as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
it may help improve the performance of following logic.
runtime_cache: whether to compute cache at the runtime, default to `False` to prepare
the cache content at initializaiton.
the cache content at initialization. See: :py:class:`monai.data.CacheDataset`.
Raises:
ValueError: When ``root_dir`` is not a directory.
@@ -296,7 +296,7 @@ def __init__(
progress: bool = True,
copy_cache: bool = True,
as_contiguous: bool = True,
runtime_cache: bool = False,
runtime_cache=False,
) -> None:
root_dir = Path(root_dir)
if not root_dir.is_dir():
@@ -458,7 +458,7 @@ class TciaDataset(Randomizable, CacheDataset):
as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
it may help improve the performance of following logic.
runtime_cache: whether to compute cache at the runtime, default to `False` to prepare
the cache content at initializaiton.
the cache content at initialization. See: :py:class:`monai.data.CacheDataset`.
Example::
@@ -514,7 +514,7 @@ def __init__(
progress: bool = True,
copy_cache: bool = True,
as_contiguous: bool = True,
runtime_cache: bool = False,
runtime_cache=False,
) -> None:
root_dir = Path(root_dir)
if not root_dir.is_dir():
6 changes: 0 additions & 6 deletions monai/data/dataloader.py
@@ -80,12 +80,6 @@ def __init__(self, dataset: Dataset, num_workers: int = 0, **kwargs) -> None:
init_seed = _g.initial_seed()
_seed = torch.empty((), dtype=torch.int64).random_(generator=_g).item()
set_rnd(dataset, int(_seed))
# disable unnecessary multiprocessing caching
from monai.data.dataset import CacheDataset # avoid circular import

if isinstance(dataset, CacheDataset):
dataset.disable_share_memory_cache()

_g.manual_seed(init_seed)
if "collate_fn" not in kwargs:
kwargs["collate_fn"] = list_data_collate
90 changes: 44 additions & 46 deletions monai/data/dataset.py
@@ -26,7 +26,6 @@

import numpy as np
import torch
import torch.distributed as dist
from torch.multiprocessing import Manager
from torch.serialization import DEFAULT_PROTOCOL
from torch.utils.data import Dataset as _TorchDataset
@@ -751,7 +750,7 @@ def __init__(
as_contiguous: bool = True,
hash_as_key: bool = False,
hash_func: Callable[..., bytes] = pickle_hashing,
runtime_cache: bool = False,
runtime_cache: Union[bool, str, List, ListProxy] = False,
) -> None:
"""
Args:
@@ -777,18 +776,21 @@ def __init__(
the dataset has duplicated items or augmented dataset.
hash_func: if `hash_as_key`, a callable to compute hash from data items to be cached.
defaults to `monai.data.utils.pickle_hashing`.
runtime_cache: whether to compute cache at the runtime, default to `False` to prepare
the cache content at initialization, if `True`, it will cache during the first epoch
of model training, so it can start the first mini-batch earlier. please note that:
1. when using this option in multi-gpu distributed training,
`torch.cuda.set_device()` must be called before initializing this class.
2. if caching data that is in GPU memory during multi-gpu distributed training, this option
should not be used, since the underlying shared cache only works for CPU shared memory.
3. to execute `runtime cache` on GPU memory, must co-work with
`monai.data.DataLoader`, and can't work with `monai.data.DistributedSampler`
as GPU Tensor usually can't be shared in the multiprocessing context.
(try ``cache_dataset.disable_share_memory_cache()`` in case of GPU caching issues.)
runtime_cache: mode of cache at the runtime. Defaults to `False` to prepare
the cache content for the entire ``data`` during initialization; this can substantially increase the
time between the constructor call and the first mini-batch being generated.
Three options are provided to compute the cache on the fly after the dataset initialization:
1. ``"threads"`` or ``True``: use a regular ``list`` to store the cache items.
2. ``"processes"``: use a ListProxy to store the cache items, it can be shared among processes.
3. A list-like object: a user-provided container used to store the cache items.
For `thread-based` caching (typically for caching cuda tensors), option 1 is recommended.
For single process workflows with multiprocessing data loading, option 2 is recommended.
For multiprocessing workflows (typically for distributed training),
where this class is initialized in subprocesses, option 3 is recommended,
and the list-like object should be prepared in the main process and passed to all subprocesses.
Not following these recommendations may lead to runtime errors or duplicated cache across processes.
"""
if not isinstance(transform, Compose):
@@ -808,10 +810,9 @@ def __init__(
self.cache_num = 0
self._cache: Union[List, ListProxy] = []
self._hash_keys: List = []
self._is_dist = dist.is_available() and dist.is_initialized()
self.set_data(data)

def set_data(self, data: Sequence):
def set_data(self, data: Sequence) -> None:
"""
Set the input data and run deterministic transforms to generate cache content.
@@ -825,44 +826,28 @@ def _compute_cache_num(data_len: int):
def _compute_cache_num(data_len: int):
self.cache_num = min(int(self.set_num), int(data_len * self.set_rate), data_len)

def _compute_cache(indices=None):
if self.runtime_cache:
cache = Manager().list([None for _ in range(self.cache_num)])
if self._is_dist:
obj_list = [cache]
# broadcast the ListProxy to all the ranks, then share the same cache content at runtime
dist.broadcast_object_list(obj_list, src=0)
cache = obj_list[0]
else:
cache = self._fill_cache(indices)
return cache

if self.hash_as_key:
# only compute cache for the unique items of dataset, and record the last index for duplicated items
mapping = {self.hash_func(v): i for i, v in enumerate(data)}
mapping = {self.hash_func(v): i for i, v in enumerate(self.data)}
_compute_cache_num(len(mapping))
self._hash_keys = list(mapping)[: self.cache_num]
indices = list(mapping.values())[: self.cache_num]
else:
_compute_cache_num(len(self.data))
indices = list(range(self.cache_num))

self._cache = _compute_cache(indices)

def disable_share_memory_cache(self):
"""
If the cache content is a multiprocessing shared memory ListProxy, convert it to a regular python list.
Because multiprocessing ListProxy is not supported for the GPU caching, explicitly disable it.
"""
if self.runtime_cache:
if not self._is_dist:
self._cache = list(self._cache)
else:
warnings.warn(
"Unable to disable shared cache in DDP, when runtime_cache==True."
"Please use runtime_cache=False option to explicitly not use the shared cache."
)
if self.runtime_cache in (False, None): # prepare cache content immediately
self._cache = self._fill_cache(indices)
return
if isinstance(self.runtime_cache, str) and "process" in self.runtime_cache:
# this must be in the main process, not in dataloader's workers
self._cache = Manager().list([None] * self.cache_num)
return
if (self.runtime_cache is True) or (isinstance(self.runtime_cache, str) and "thread" in self.runtime_cache):
self._cache = [None] * self.cache_num
return
self._cache = self.runtime_cache # type: ignore
return

def _fill_cache(self, indices=None) -> List:
"""
@@ -1006,6 +991,7 @@ class SmartCacheDataset(Randomizable, CacheDataset):
may set `copy=False` for better performance.
as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
it may help improve the performance of following logic.
runtime_cache: default to `False`; other options are not implemented yet.
"""

@@ -1023,7 +1009,7 @@ def __init__(
seed: int = 0,
copy_cache: bool = True,
as_contiguous: bool = True,
runtime_cache: bool = False,
runtime_cache=False,
) -> None:
if shuffle:
self.set_random_state(seed=seed)
@@ -1034,8 +1020,20 @@ def __init__(
self._round: int = 1
self._replace_done: bool = False
self._replace_mgr: Optional[threading.Thread] = None
if runtime_cache is not False:
raise NotImplementedError("Options other than `runtime_cache=False` are not implemented yet.")

super().__init__(data, transform, cache_num, cache_rate, num_init_workers, progress, copy_cache, as_contiguous)
super().__init__(
data=data,
transform=transform,
cache_num=cache_num,
cache_rate=cache_rate,
num_workers=num_init_workers,
progress=progress,
copy_cache=copy_cache,
as_contiguous=as_contiguous,
runtime_cache=False,
)
if self._cache is None:
self._cache = self._fill_cache()
if self.cache_num >= len(data):
2 changes: 1 addition & 1 deletion tests/test_integration_segmentation_3d.py
@@ -84,7 +84,7 @@ def run_training_test(root_dir, device="cuda:0", cachedataset=0, readers=(None,
# create a training data loader
if cachedataset == 2:
train_ds = monai.data.CacheDataset(
data=train_files, transform=train_transforms, cache_rate=0.8, runtime_cache=True
data=train_files, transform=train_transforms, cache_rate=0.8, runtime_cache="process"
)
elif cachedataset == 3:
train_ds = monai.data.LMDBDataset(data=train_files, transform=train_transforms, cache_dir=root_dir)
8 changes: 7 additions & 1 deletion tests/test_sampler_dist.py
@@ -14,6 +14,7 @@
import numpy as np
import torch
import torch.distributed as dist
from torch.multiprocessing import Manager

from monai.data import CacheDataset, DataLoader, DistributedSampler
from monai.transforms import ToTensor
@@ -48,9 +49,14 @@ def test_uneven(self):
@DistCall(nnodes=1, nproc_per_node=2, timeout=120)
def test_cachedataset(self):
data = [1, 2, 3, 4, 5]
dataset = CacheDataset(data=data, transform=ToTensor(track_meta=False), cache_rate=1.0, runtime_cache=True)
obj_list = [Manager().list([None] * len(data))]
dist.broadcast_object_list(obj_list, src=0)
dataset = CacheDataset(
data=data, transform=ToTensor(track_meta=False), cache_rate=1.0, runtime_cache=obj_list[0]
)
sampler = DistributedSampler(dataset=dataset, shuffle=False, even_divisible=False)
dataloader = DataLoader(dataset=dataset, sampler=sampler, batch_size=1, num_workers=2)
dist.barrier()
for i in range(3):
if i > 0:
# verify the runtime cache content is completed after first epoch
