
Commit 025c107

revise cachedataset runtime_cache modes (#5630)
Fixes #5613

### Description

- `runtime_cache=False`: the default v1.0.1 behaviour
- `runtime_cache=True` or `"thread"`: single process, for caching cuda tensors
- `runtime_cache="process"`: single-process workflow + multiprocess dataloader
- `runtime_cache=` a user-provided object: can be used to pass a container shared among processes

I feel that this way the user can decide what to use, instead of us guessing and providing an automated solution... let me know what you think @myron @Nic-Ma, I'm fine if this is eventually merged or not.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: Wenqi Li <wenqil@nvidia.com>
1 parent 7b41e2e commit 025c107
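
A minimal usage sketch of the modes listed in the description above (the toy data list, the `ToTensor` transform and the loader settings are illustrative only, not part of this commit):

```python
from torch.multiprocessing import Manager

from monai.data import CacheDataset, DataLoader
from monai.transforms import ToTensor

data = [1, 2, 3, 4, 5]  # illustrative toy data
transform = ToTensor(track_meta=False)

# default: fill the cache while the dataset is being constructed (v1.0.1 behaviour)
ds = CacheDataset(data=data, transform=transform, runtime_cache=False)

# single process, cache filled lazily during the first epoch (e.g. for caching cuda tensors)
ds = CacheDataset(data=data, transform=transform, runtime_cache="thread")

# single-process workflow with a multiprocess dataloader: cache kept in a shared ListProxy
ds = CacheDataset(data=data, transform=transform, runtime_cache="process")
loader = DataLoader(ds, batch_size=1, num_workers=2)

# user-provided container, e.g. prepared in the main process and shared among processes
shared = Manager().list([None] * len(data))
ds = CacheDataset(data=data, transform=transform, runtime_cache=shared)
```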

File tree: 5 files changed (+58, -60 lines)


monai/apps/datasets.py

Lines changed: 6 additions & 6 deletions
@@ -71,7 +71,7 @@ class MedNISTDataset(Randomizable, CacheDataset):
         as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
             it may help improve the performance of following logic.
         runtime_cache: whether to compute cache at the runtime, default to `False` to prepare
-            the cache content at initializaiton.
+            the cache content at initialization. See: :py:class:`monai.data.CacheDataset`.

     Raises:
         ValueError: When ``root_dir`` is not a directory.
@@ -99,7 +99,7 @@ def __init__(
         progress: bool = True,
         copy_cache: bool = True,
         as_contiguous: bool = True,
-        runtime_cache: bool = False,
+        runtime_cache=False,
     ) -> None:
         root_dir = Path(root_dir)
         if not root_dir.is_dir():
@@ -228,7 +228,7 @@ class DecathlonDataset(Randomizable, CacheDataset):
         as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
             it may help improve the performance of following logic.
         runtime_cache: whether to compute cache at the runtime, default to `False` to prepare
-            the cache content at initializaiton.
+            the cache content at initialization. See: :py:class:`monai.data.CacheDataset`.

     Raises:
         ValueError: When ``root_dir`` is not a directory.
@@ -296,7 +296,7 @@ def __init__(
         progress: bool = True,
         copy_cache: bool = True,
         as_contiguous: bool = True,
-        runtime_cache: bool = False,
+        runtime_cache=False,
     ) -> None:
         root_dir = Path(root_dir)
         if not root_dir.is_dir():
@@ -458,7 +458,7 @@ class TciaDataset(Randomizable, CacheDataset):
         as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
             it may help improve the performance of following logic.
         runtime_cache: whether to compute cache at the runtime, default to `False` to prepare
-            the cache content at initializaiton.
+            the cache content at initialization. See: :py:class:`monai.data.CacheDataset`.

     Example::

@@ -514,7 +514,7 @@ def __init__(
         progress: bool = True,
         copy_cache: bool = True,
         as_contiguous: bool = True,
-        runtime_cache: bool = False,
+        runtime_cache=False,
     ) -> None:
         root_dir = Path(root_dir)
         if not root_dir.is_dir():

monai/data/dataloader.py

Lines changed: 0 additions & 6 deletions
@@ -80,12 +80,6 @@ def __init__(self, dataset: Dataset, num_workers: int = 0, **kwargs) -> None:
             init_seed = _g.initial_seed()
             _seed = torch.empty((), dtype=torch.int64).random_(generator=_g).item()
             set_rnd(dataset, int(_seed))
-            # disable unnecessary multiprocessing caching
-            from monai.data.dataset import CacheDataset  # avoid circular import
-
-            if isinstance(dataset, CacheDataset):
-                dataset.disable_share_memory_cache()
-
             _g.manual_seed(init_seed)
         if "collate_fn" not in kwargs:
             kwargs["collate_fn"] = list_data_collate
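
With the block above removed, `DataLoader` no longer converts a shared cache behind the scenes; the cache container is chosen explicitly when the dataset is constructed. A minimal sketch of the thread-based mode that the old `disable_share_memory_cache()` workaround targeted for CUDA tensors, assuming a CUDA device is available (the toy data and the `Lambda` transform are illustrative only):

```python
from monai.data import CacheDataset, DataLoader
from monai.transforms import Compose, Lambda, ToTensor

# deterministic transforms that leave the cached items on the GPU
transform = Compose([ToTensor(track_meta=False), Lambda(lambda x: x.to("cuda:0"))])

ds = CacheDataset(data=[1, 2, 3], transform=transform, runtime_cache="thread")
# GPU tensors generally cannot be shared across worker processes, so keep loading single-process
loader = DataLoader(ds, batch_size=1, num_workers=0)
for _batch in loader:  # the cache is populated lazily during this first pass
    pass
```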

monai/data/dataset.py

Lines changed: 44 additions & 46 deletions
@@ -26,7 +26,6 @@

 import numpy as np
 import torch
-import torch.distributed as dist
 from torch.multiprocessing import Manager
 from torch.serialization import DEFAULT_PROTOCOL
 from torch.utils.data import Dataset as _TorchDataset
@@ -751,7 +750,7 @@ def __init__(
         as_contiguous: bool = True,
         hash_as_key: bool = False,
         hash_func: Callable[..., bytes] = pickle_hashing,
-        runtime_cache: bool = False,
+        runtime_cache: Union[bool, str, List, ListProxy] = False,
     ) -> None:
         """
         Args:
@@ -777,18 +776,21 @@ def __init__(
                 the dataset has duplicated items or augmented dataset.
             hash_func: if `hash_as_key`, a callable to compute hash from data items to be cached.
                 defaults to `monai.data.utils.pickle_hashing`.
-            runtime_cache: whether to compute cache at the runtime, default to `False` to prepare
-                the cache content at initialization, if `True`, it will cache during the first epoch
-                of model training, so it can start the first mini-batch earlier. please note that:
-                1. when using this option in multi-gpu distributed training,
-                `torch.cuda.set_device()` must be called before initializing this class.
-                2. if caching data that is in GPU memory during multi-gpu distributed training, this option
-                should not be used, since the underlying shared cache only works for CPU shared memory.
-                3. to execute `runtime cache` on GPU memory, must co-work with
-                `monai.data.DataLoader`, and can't work with `monai.data.DistributedSampler`
-                as GPU Tensor usually can't be shared in the multiprocessing context.
-                (try ``cache_dataset.disable_share_memory_cache()`` in case of GPU caching issues.)
+            runtime_cache: mode of cache at the runtime. Default to `False` to prepare
+                the cache content for the entire ``data`` during initialization, this potentially largely increase the
+                time required between the constructor called and first mini-batch generated.
+                Three options are provided to compute the cache on the fly after the dataset initialization:

+                1. ``"threads"`` or ``True``: use a regular ``list`` to store the cache items.
+                2. ``"processes"``: use a ListProxy to store the cache items, it can be shared among processes.
+                3. A list-like object: a users-provided container to be used to store the cache items.
+
+                For `thread-based` caching (typically for caching cuda tensors), option 1 is recommended.
+                For single process workflows with multiprocessing data loading, option 2 is recommended.
+                For multiprocessing workflows (typically for distributed training),
+                where this class is initialized in subprocesses, option 3 is recommended,
+                and the list-like object should be prepared in the main process and passed to all subprocesses.
+                Not following these recommendations may lead to runtime errors or duplicated cache across processes.

         """
         if not isinstance(transform, Compose):
@@ -808,10 +810,9 @@ def __init__(
         self.cache_num = 0
         self._cache: Union[List, ListProxy] = []
         self._hash_keys: List = []
-        self._is_dist = dist.is_available() and dist.is_initialized()
         self.set_data(data)

-    def set_data(self, data: Sequence):
+    def set_data(self, data: Sequence) -> None:
         """
         Set the input data and run deterministic transforms to generate cache content.

@@ -825,44 +826,28 @@ def _compute_cache_num(data_len: int):
         def _compute_cache_num(data_len: int):
             self.cache_num = min(int(self.set_num), int(data_len * self.set_rate), data_len)

-        def _compute_cache(indices=None):
-            if self.runtime_cache:
-                cache = Manager().list([None for _ in range(self.cache_num)])
-                if self._is_dist:
-                    obj_list = [cache]
-                    # broadcast the ListProxy to all the ranks, then share the same cache content at runtime
-                    dist.broadcast_object_list(obj_list, src=0)
-                    cache = obj_list[0]
-            else:
-                cache = self._fill_cache(indices)
-            return cache
-
         if self.hash_as_key:
             # only compute cache for the unique items of dataset, and record the last index for duplicated items
-            mapping = {self.hash_func(v): i for i, v in enumerate(data)}
+            mapping = {self.hash_func(v): i for i, v in enumerate(self.data)}
             _compute_cache_num(len(mapping))
             self._hash_keys = list(mapping)[: self.cache_num]
             indices = list(mapping.values())[: self.cache_num]
         else:
             _compute_cache_num(len(self.data))
             indices = list(range(self.cache_num))

-        self._cache = _compute_cache(indices)
-
-    def disable_share_memory_cache(self):
-        """
-        If the cache content is a multiprocessing shared memory ListProxy, convert it to a regular python list.
-        Because multiprocessing ListProxy is not supported for the GPU caching, explicitly disable it.
-
-        """
-        if self.runtime_cache:
-            if not self._is_dist:
-                self._cache = list(self._cache)
-            else:
-                warnings.warn(
-                    "Unable to disable shared cache in DDP, when runtime_cache==True."
-                    "Please use runtime_cache=False option to explicitly not use the shared cache."
-                )
+        if self.runtime_cache in (False, None):  # prepare cache content immediately
+            self._cache = self._fill_cache(indices)
+            return
+        if isinstance(self.runtime_cache, str) and "process" in self.runtime_cache:
+            # this must be in the main process, not in dataloader's workers
+            self._cache = Manager().list([None] * self.cache_num)
+            return
+        if (self.runtime_cache is True) or (isinstance(self.runtime_cache, str) and "thread" in self.runtime_cache):
+            self._cache = [None] * self.cache_num
+            return
+        self._cache = self.runtime_cache  # type: ignore
+        return

     def _fill_cache(self, indices=None) -> List:
         """
@@ -1006,6 +991,7 @@ class SmartCacheDataset(Randomizable, CacheDataset):
             may set `copy=False` for better performance.
         as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
             it may help improve the performance of following logic.
+        runtime_cache: Default to `False`, other options are not implemented yet.

     """

@@ -1023,7 +1009,7 @@ def __init__(
         seed: int = 0,
         copy_cache: bool = True,
         as_contiguous: bool = True,
-        runtime_cache: bool = False,
+        runtime_cache=False,
     ) -> None:
         if shuffle:
             self.set_random_state(seed=seed)
@@ -1034,8 +1020,20 @@ def __init__(
         self._round: int = 1
         self._replace_done: bool = False
         self._replace_mgr: Optional[threading.Thread] = None
+        if runtime_cache is not False:
+            raise NotImplementedError("Options other than `runtime_cache=False` is not implemented yet.")

-        super().__init__(data, transform, cache_num, cache_rate, num_init_workers, progress, copy_cache, as_contiguous)
+        super().__init__(
+            data=data,
+            transform=transform,
+            cache_num=cache_num,
+            cache_rate=cache_rate,
+            num_workers=num_init_workers,
+            progress=progress,
+            copy_cache=copy_cache,
+            as_contiguous=as_contiguous,
+            runtime_cache=False,
+        )
         if self._cache is None:
             self._cache = self._fill_cache()
         if self.cache_num >= len(data):
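
Read outside the diff, the new dispatch in `set_data` resolves the requested mode to a cache container roughly as follows; this is a simplified standalone sketch, and `resolve_runtime_cache` with its arguments is a hypothetical helper, not part of the commit:

```python
from typing import Callable, List

from torch.multiprocessing import Manager


def resolve_runtime_cache(runtime_cache, cache_num: int, fill_cache: Callable[[], List]):
    """Pick the cache container for the requested ``runtime_cache`` mode (simplified sketch)."""
    if runtime_cache in (False, None):
        return fill_cache()  # eager: compute every cached item up front
    if isinstance(runtime_cache, str) and "process" in runtime_cache:
        # ListProxy created in the main process, shareable with dataloader worker processes
        return Manager().list([None] * cache_num)
    if runtime_cache is True or (isinstance(runtime_cache, str) and "thread" in runtime_cache):
        return [None] * cache_num  # plain list, single process (e.g. for cuda tensors)
    return runtime_cache  # any user-provided list-like container
```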

tests/test_integration_segmentation_3d.py

Lines changed: 1 addition & 1 deletion
@@ -84,7 +84,7 @@ def run_training_test(root_dir, device="cuda:0", cachedataset=0, readers=(None,
     # create a training data loader
     if cachedataset == 2:
         train_ds = monai.data.CacheDataset(
-            data=train_files, transform=train_transforms, cache_rate=0.8, runtime_cache=True
+            data=train_files, transform=train_transforms, cache_rate=0.8, runtime_cache="process"
         )
     elif cachedataset == 3:
         train_ds = monai.data.LMDBDataset(data=train_files, transform=train_transforms, cache_dir=root_dir)

tests/test_sampler_dist.py

Lines changed: 7 additions & 1 deletion
@@ -14,6 +14,7 @@
 import numpy as np
 import torch
 import torch.distributed as dist
+from torch.multiprocessing import Manager

 from monai.data import CacheDataset, DataLoader, DistributedSampler
 from monai.transforms import ToTensor
@@ -48,9 +49,14 @@ def test_uneven(self):
     @DistCall(nnodes=1, nproc_per_node=2, timeout=120)
     def test_cachedataset(self):
         data = [1, 2, 3, 4, 5]
-        dataset = CacheDataset(data=data, transform=ToTensor(track_meta=False), cache_rate=1.0, runtime_cache=True)
+        obj_list = [Manager().list([None] * len(data))]
+        dist.broadcast_object_list(obj_list, src=0)
+        dataset = CacheDataset(
+            data=data, transform=ToTensor(track_meta=False), cache_rate=1.0, runtime_cache=obj_list[0]
+        )
         sampler = DistributedSampler(dataset=dataset, shuffle=False, even_divisible=False)
         dataloader = DataLoader(dataset=dataset, sampler=sampler, batch_size=1, num_workers=2)
+        dist.barrier()
         for i in range(3):
             if i > 0:
                 # verify the runtime cache content is completed after first epoch

0 commit comments
