 
 import numpy as np
 import torch
-import torch.distributed as dist
 from torch.multiprocessing import Manager
 from torch.serialization import DEFAULT_PROTOCOL
 from torch.utils.data import Dataset as _TorchDataset
@@ -751,7 +750,7 @@ def __init__(
         as_contiguous: bool = True,
         hash_as_key: bool = False,
         hash_func: Callable[..., bytes] = pickle_hashing,
-        runtime_cache: bool = False,
+        runtime_cache: Union[bool, str, List, ListProxy] = False,
     ) -> None:
         """
         Args:
@@ -777,18 +776,21 @@ def __init__(
                 the dataset has duplicated items or augmented dataset.
             hash_func: if `hash_as_key`, a callable to compute hash from data items to be cached.
                 defaults to `monai.data.utils.pickle_hashing`.
-            runtime_cache: whether to compute cache at the runtime, default to `False` to prepare
-                the cache content at initialization, if `True`, it will cache during the first epoch
-                of model training, so it can start the first mini-batch earlier. please note that:
-                1. when using this option in multi-gpu distributed training,
-                `torch.cuda.set_device()` must be called before initializing this class.
-                2. if caching data that is in GPU memory during multi-gpu distributed training, this option
-                should not be used, since the underlying shared cache only works for CPU shared memory.
-                3. to execute `runtime cache` on GPU memory, must co-work with
-                `monai.data.DataLoader`, and can't work with `monai.data.DistributedSampler`
-                as GPU Tensor usually can't be shared in the multiprocessing context.
-                (try ``cache_dataset.disable_share_memory_cache()`` in case of GPU caching issues.)
+            runtime_cache: mode of cache at runtime. Defaults to `False`, which prepares the cache
+                content for the entire ``data`` during initialization; this may significantly increase
+                the time between constructing the dataset and generating the first mini-batch.
+                Three options are provided to compute the cache on the fly after dataset initialization:
 
+                1. ``"threads"`` or ``True``: use a regular ``list`` to store the cache items.
+                2. ``"processes"``: use a ListProxy to store the cache items; it can be shared among processes.
+                3. A list-like object: a user-provided container to store the cache items.
+
+                For thread-based caching (typically for caching cuda tensors), option 1 is recommended.
+                For single-process workflows with multiprocessing data loading, option 2 is recommended.
+                For multiprocessing workflows (typically for distributed training),
+                where this class is initialized in subprocesses, option 3 is recommended,
+                and the list-like object should be prepared in the main process and passed to all subprocesses.
+                Not following these recommendations may lead to runtime errors or duplicated cache across processes.
 
         """
         if not isinstance(transform, Compose):
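
A minimal usage sketch of the three `runtime_cache` modes described in the docstring above. The dict data and `Lambdad` transform are placeholders, and sizing the user-provided container to one slot per item (option 3, assuming the default `cache_rate=1.0`) is an assumption about the expected container shape rather than part of this change:

```python
from torch.multiprocessing import Manager

from monai.data import CacheDataset, DataLoader
from monai.transforms import Compose, Lambdad

# placeholder data and transform; any MONAI-compatible inputs would do
data = [{"img": i} for i in range(10)]
xform = Compose([Lambdad(keys="img", func=abs)])

# option 1: lazily fill a plain list ("threads" or True), e.g. for thread-based / cuda caching
ds_threads = CacheDataset(data, xform, runtime_cache="threads")

# option 2: lazily fill an internally created ListProxy ("processes"), shared with
# DataLoader worker processes; the dataset must be created in the main process
ds_processes = CacheDataset(data, xform, runtime_cache="processes")
loader = DataLoader(ds_processes, batch_size=2, num_workers=2)

# option 3: a user-provided list-like container, e.g. one prepared in the main
# process and later handed to subprocesses (sized one slot per item to cache)
shared_cache = Manager().list([None] * len(data))
ds_shared = CacheDataset(data, xform, runtime_cache=shared_cache)
```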
@@ -808,10 +810,9 @@ def __init__(
         self.cache_num = 0
         self._cache: Union[List, ListProxy] = []
         self._hash_keys: List = []
-        self._is_dist = dist.is_available() and dist.is_initialized()
         self.set_data(data)
 
-    def set_data(self, data: Sequence):
+    def set_data(self, data: Sequence) -> None:
         """
         Set the input data and run deterministic transforms to generate cache content.
 
@@ -825,44 +826,28 @@ def set_data(self, data: Sequence):
         def _compute_cache_num(data_len: int):
             self.cache_num = min(int(self.set_num), int(data_len * self.set_rate), data_len)
 
-        def _compute_cache(indices=None):
-            if self.runtime_cache:
-                cache = Manager().list([None for _ in range(self.cache_num)])
-                if self._is_dist:
-                    obj_list = [cache]
-                    # broadcast the ListProxy to all the ranks, then share the same cache content at runtime
-                    dist.broadcast_object_list(obj_list, src=0)
-                    cache = obj_list[0]
-            else:
-                cache = self._fill_cache(indices)
-            return cache
-
         if self.hash_as_key:
             # only compute cache for the unique items of dataset, and record the last index for duplicated items
-            mapping = {self.hash_func(v): i for i, v in enumerate(data)}
+            mapping = {self.hash_func(v): i for i, v in enumerate(self.data)}
             _compute_cache_num(len(mapping))
             self._hash_keys = list(mapping)[: self.cache_num]
             indices = list(mapping.values())[: self.cache_num]
         else:
             _compute_cache_num(len(self.data))
             indices = list(range(self.cache_num))
 
-        self._cache = _compute_cache(indices)
-
-    def disable_share_memory_cache(self):
-        """
-        If the cache content is a multiprocessing shared memory ListProxy, convert it to a regular python list.
-        Because multiprocessing ListProxy is not supported for the GPU caching, explicitly disable it.
-
-        """
-        if self.runtime_cache:
-            if not self._is_dist:
-                self._cache = list(self._cache)
-            else:
-                warnings.warn(
-                    "Unable to disable shared cache in DDP, when runtime_cache==True."
-                    "Please use runtime_cache=False option to explicitly not use the shared cache."
-                )
+        if self.runtime_cache in (False, None):  # prepare cache content immediately
+            self._cache = self._fill_cache(indices)
+            return
+        if isinstance(self.runtime_cache, str) and "process" in self.runtime_cache:
+            # this must be in the main process, not in dataloader's workers
+            self._cache = Manager().list([None] * self.cache_num)
+            return
+        if (self.runtime_cache is True) or (isinstance(self.runtime_cache, str) and "thread" in self.runtime_cache):
+            self._cache = [None] * self.cache_num
+            return
+        self._cache = self.runtime_cache  # type: ignore
+        return
 
     def _fill_cache(self, indices=None) -> List:
         """
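
To make the option 3 recommendation concrete (the shared container is prepared once in the main process and passed to every subprocess, replacing the removed `dist.broadcast_object_list` logic), here is a hedged sketch using `torch.multiprocessing.spawn`. The `worker` helper, toy data, and transform are illustrative assumptions; process-group setup, samplers, and training steps are omitted:

```python
from torch.multiprocessing import Manager, spawn

from monai.data import CacheDataset
from monai.transforms import Compose, Lambdad


def worker(rank: int, data, shared_cache):
    # each spawned process builds its own CacheDataset around the same ListProxy,
    # so an item cached by one rank becomes visible to the other ranks at runtime
    xform = Compose([Lambdad(keys="img", func=abs)])
    ds = CacheDataset(data, xform, runtime_cache=shared_cache)
    _ = ds[rank]  # first access applies the transform and stores the result in the shared cache


if __name__ == "__main__":
    data = [{"img": i} for i in range(8)]
    # sized to one slot per item to cache (default cache_rate=1.0); an assumption about the container shape
    shared_cache = Manager().list([None] * len(data))
    spawn(worker, args=(data, shared_cache), nprocs=2, join=True)
```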
@@ -1006,6 +991,7 @@ class SmartCacheDataset(Randomizable, CacheDataset):
             may set `copy=False` for better performance.
         as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
             it may help improve the performance of following logic.
+        runtime_cache: defaults to `False`; other options are not implemented yet.
 
     """
 
@@ -1023,7 +1009,7 @@ def __init__(
         seed: int = 0,
         copy_cache: bool = True,
         as_contiguous: bool = True,
-        runtime_cache: bool = False,
+        runtime_cache=False,
     ) -> None:
         if shuffle:
             self.set_random_state(seed=seed)
@@ -1034,8 +1020,20 @@ def __init__(
         self._round: int = 1
         self._replace_done: bool = False
         self._replace_mgr: Optional[threading.Thread] = None
+        if runtime_cache is not False:
+            raise NotImplementedError("Options other than `runtime_cache=False` are not implemented yet.")
 
-        super().__init__(data, transform, cache_num, cache_rate, num_init_workers, progress, copy_cache, as_contiguous)
+        super().__init__(
+            data=data,
+            transform=transform,
+            cache_num=cache_num,
+            cache_rate=cache_rate,
+            num_workers=num_init_workers,
+            progress=progress,
+            copy_cache=copy_cache,
+            as_contiguous=as_contiguous,
+            runtime_cache=False,
+        )
         if self._cache is None:
             self._cache = self._fill_cache()
         if self.cache_num >= len(data):
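
Since `SmartCacheDataset` now forwards `runtime_cache=False` to the parent class and rejects anything else, a quick hedged check of the new guard might look like this (dataset contents and transform are placeholders):

```python
from monai.data import SmartCacheDataset
from monai.transforms import Compose, Lambdad

data = [{"img": i} for i in range(10)]
xform = Compose([Lambdad(keys="img", func=abs)])

try:
    # any value other than the default False trips the NotImplementedError guard added above
    SmartCacheDataset(data, xform, replace_rate=0.2, cache_num=4, runtime_cache=True)
except NotImplementedError as e:
    print(e)
```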