1- from enum import IntEnum
2- from typing import Any , Dict , List , Optional , Tuple , Union
1+ from enum import Enum
2+ from functools import partial
3+ from typing import List , NamedTuple , Optional , Tuple , Union
34
45import numpy as np
56
1213 train_test_split
1314)
1415
15- from typing_extensions import Protocol
16+ from torch . utils . data import Dataset
1617
1718
18- # Use callback protocol as workaround, since callable with function fields count 'self' as argument
19- class CrossValFunc (Protocol ):
20- def __call__ (self ,
21- random_state : np .random .RandomState ,
22- num_splits : int ,
23- indices : np .ndarray ,
24- stratify : Optional [Any ]) -> List [Tuple [np .ndarray , np .ndarray ]]:
25- ...
19+ class _ResamplingStrategyArgs (NamedTuple ):
20+ val_share : float = 0.33
21+ num_splits : int = 5
22+ shuffle : bool = False
23+ stratify : bool = False
2624
2725
28- class HoldOutFunc (Protocol ):
29- def __call__ (self , random_state : np .random .RandomState , val_share : float ,
30- indices : np .ndarray , stratify : Optional [Any ]
31- ) -> Tuple [np .ndarray , np .ndarray ]:
32- ...
33-
34-
35- class CrossValTypes (IntEnum ):
36- """The type of cross validation
37-
38- This class is used to specify the cross validation function
39- and is not supposed to be instantiated.
40-
41- Examples: This class is supposed to be used as follows
42- >>> cv_type = CrossValTypes.k_fold_cross_validation
43- >>> print(cv_type.name)
44-
45- k_fold_cross_validation
46-
47- >>> for cross_val_type in CrossValTypes:
48- print(cross_val_type.name, cross_val_type.value)
49-
50- stratified_k_fold_cross_validation 1
51- k_fold_cross_validation 2
52- stratified_shuffle_split_cross_validation 3
53- shuffle_split_cross_validation 4
54- time_series_cross_validation 5
55- """
56- stratified_k_fold_cross_validation = 1
57- k_fold_cross_validation = 2
58- stratified_shuffle_split_cross_validation = 3
59- shuffle_split_cross_validation = 4
60- time_series_cross_validation = 5
61-
62- def is_stratified (self ) -> bool :
63- stratified = [self .stratified_k_fold_cross_validation ,
64- self .stratified_shuffle_split_cross_validation ]
65- return getattr (self , self .name ) in stratified
66-
67-
68- class HoldoutValTypes (IntEnum ):
69- """TODO: change to enum using functools.partial"""
70- """The type of hold out validation (refer to CrossValTypes' doc-string)"""
71- holdout_validation = 6
72- stratified_holdout_validation = 7
73-
74- def is_stratified (self ) -> bool :
75- stratified = [self .stratified_holdout_validation ]
76- return getattr (self , self .name ) in stratified
77-
78-
79- # TODO: replace it with another way
80- RESAMPLING_STRATEGIES = [CrossValTypes , HoldoutValTypes ]
81-
82- DEFAULT_RESAMPLING_PARAMETERS = {
83- HoldoutValTypes .holdout_validation : {
84- 'val_share' : 0.33 ,
85- },
86- HoldoutValTypes .stratified_holdout_validation : {
87- 'val_share' : 0.33 ,
88- },
89- CrossValTypes .k_fold_cross_validation : {
90- 'num_splits' : 5 ,
91- },
92- CrossValTypes .stratified_k_fold_cross_validation : {
93- 'num_splits' : 5 ,
94- },
95- CrossValTypes .shuffle_split_cross_validation : {
96- 'num_splits' : 5 ,
97- },
98- CrossValTypes .time_series_cross_validation : {
99- 'num_splits' : 5 ,
100- },
101- } # type: Dict[Union[HoldoutValTypes, CrossValTypes], Dict[str, Any]]
102-
103-
104- class HoldOutFuncs ():
26+ class HoldoutFuncs ():
10527 @staticmethod
106- def holdout_validation (random_state : np .random .RandomState ,
107- val_share : float ,
108- indices : np .ndarray ,
109- ** kwargs : Any
110- ) -> Tuple [np .ndarray , np .ndarray ]:
111- shuffle = kwargs .get ('shuffle' , True )
112- train , val = train_test_split (indices , test_size = val_share ,
113- shuffle = shuffle ,
114- random_state = random_state if shuffle else None ,
115- )
28+ def holdout_validation (
29+ random_state : np .random .RandomState ,
30+ val_share : float ,
31+ indices : np .ndarray ,
32+ shuffle : bool = False ,
33+ labels_to_stratify : Optional [Union [Tuple [np .ndarray , np .ndarray ], Dataset ]] = None
34+ ):
35+
36+ train , val = train_test_split (
37+ indices , test_size = val_share , shuffle = shuffle ,
38+ random_state = random_state if shuffle else None ,
39+ stratify = labels_to_stratify
40+ )
11641 return train , val
11742
118- @staticmethod
119- def stratified_holdout_validation (random_state : np .random .RandomState ,
120- val_share : float ,
121- indices : np .ndarray ,
122- ** kwargs : Any
123- ) -> Tuple [np .ndarray , np .ndarray ]:
124- train , val = train_test_split (indices , test_size = val_share , shuffle = True , stratify = kwargs ["stratify" ],
125- random_state = random_state )
126- return train , val
127-
128- @classmethod
129- def get_holdout_validators (cls , * holdout_val_types : HoldoutValTypes ) -> Dict [str , HoldOutFunc ]:
130-
131- holdout_validators = {
132- holdout_val_type .name : getattr (cls , holdout_val_type .name )
133- for holdout_val_type in holdout_val_types
134- }
135- return holdout_validators
136-
13743
13844class CrossValFuncs ():
139- @staticmethod
140- def shuffle_split_cross_validation (random_state : np .random .RandomState ,
141- num_splits : int ,
142- indices : np .ndarray ,
143- ** kwargs : Any
144- ) -> List [Tuple [np .ndarray , np .ndarray ]]:
145- cv = ShuffleSplit (n_splits = num_splits , random_state = random_state )
146- splits = list (cv .split (indices ))
147- return splits
148-
149- @staticmethod
150- def stratified_shuffle_split_cross_validation (random_state : np .random .RandomState ,
151- num_splits : int ,
152- indices : np .ndarray ,
153- ** kwargs : Any
154- ) -> List [Tuple [np .ndarray , np .ndarray ]]:
155- cv = StratifiedShuffleSplit (n_splits = num_splits , random_state = random_state )
156- splits = list (cv .split (indices , kwargs ["stratify" ]))
157- return splits
158-
159- @staticmethod
160- def stratified_k_fold_cross_validation (random_state : np .random .RandomState ,
161- num_splits : int ,
162- indices : np .ndarray ,
163- ** kwargs : Any
164- ) -> List [Tuple [np .ndarray , np .ndarray ]]:
165- cv = StratifiedKFold (n_splits = num_splits , random_state = random_state )
166- splits = list (cv .split (indices , kwargs ["stratify" ]))
167- return splits
45+ # (shuffle, is_stratify) -> split_fn
46+ _args2split_fn = {
47+ (True , True ): StratifiedShuffleSplit ,
48+ (True , False ): ShuffleSplit ,
49+ (False , True ): StratifiedKFold ,
50+ (False , False ): KFold ,
51+ }
16852
16953 @staticmethod
170- def k_fold_cross_validation (random_state : np .random .RandomState ,
171- num_splits : int ,
172- indices : np .ndarray ,
173- ** kwargs : Any
174- ) -> List [Tuple [np .ndarray , np .ndarray ]]:
54+ def k_fold_cross_validation (
55+ random_state : np .random .RandomState ,
56+ num_splits : int ,
57+ indices : np .ndarray ,
58+ shuffle : bool = False ,
59+ labels_to_stratify : Optional [Union [Tuple [np .ndarray , np .ndarray ], Dataset ]] = None
60+ ) -> List [Tuple [np .ndarray , np .ndarray ]]:
17561 """
176- Standard k fold cross validation.
177-
178- Args:
179- indices (np.ndarray): array of indices to be split
180- num_splits (int): number of cross validation splits
181-
18262 Returns:
18363 splits (List[Tuple[List, List]]): list of tuples of training and validation indices
18464 """
185- shuffle = kwargs .get ('shuffle' , True )
186- cv = KFold (n_splits = num_splits , random_state = random_state if shuffle else None , shuffle = shuffle )
65+
66+ split_fn = CrossValFuncs ._args2split_fn [(shuffle , labels_to_stratify is not None )]
67+ cv = split_fn (n_splits = num_splits , random_state = random_state )
18768 splits = list (cv .split (indices ))
18869 return splits
18970
19071 @staticmethod
191- def time_series_cross_validation (random_state : np .random .RandomState ,
192- num_splits : int ,
193- indices : np .ndarray ,
194- ** kwargs : Any
195- ) -> List [Tuple [np .ndarray , np .ndarray ]]:
72+ def time_series (
73+ random_state : np .random .RandomState ,
74+ num_splits : int ,
75+ indices : np .ndarray ,
76+ shuffle : bool = False ,
77+ labels_to_stratify : Optional [Union [Tuple [np .ndarray , np .ndarray ], Dataset ]] = None
78+ ) -> List [Tuple [np .ndarray , np .ndarray ]]:
19679 """
19780 Returns train and validation indices respecting the temporal ordering of the data.
19881
@@ -215,10 +98,115 @@ def time_series_cross_validation(random_state: np.random.RandomState,
21598 splits = list (cv .split (indices ))
21699 return splits
217100
218- @classmethod
219- def get_cross_validators (cls , * cross_val_types : CrossValTypes ) -> Dict [str , CrossValFunc ]:
220- cross_validators = {
221- cross_val_type .name : getattr (cls , cross_val_type .name )
222- for cross_val_type in cross_val_types
223- }
224- return cross_validators
101+
102+ class CrossValTypes (Enum ):
103+ """The type of cross validation
104+
105+ This class is used to specify the cross validation function
106+ and is not supposed to be instantiated.
107+
108+ Examples: This class is supposed to be used as follows
109+ >>> cv_type = CrossValTypes.k_fold_cross_validation
110+ >>> print(cv_type.name)
111+
112+ k_fold_cross_validation
113+
114+ >>> for cross_val_type in CrossValTypes:
115+ print(cross_val_type.name, cross_val_type.value)
116+
117+ k_fold_cross_validation functools.partial(<function CrossValFuncs.k_fold_cross_validation at ...>)
118+ time_series functools.partial(<function CrossValFuncs.time_series at ...>)
119+ """
120+ k_fold_cross_validation = partial (CrossValFuncs .k_fold_cross_validation )
121+ time_series = partial (CrossValFuncs .time_series )
122+
123+ def __call__ (
124+ self ,
125+ random_state : np .random .RandomState ,
126+ indices : np .ndarray ,
127+ num_splits : int = 5 ,
128+ shuffle : bool = False ,
129+ labels_to_stratify : Optional [Union [Tuple [np .ndarray , np .ndarray ], Dataset ]] = None
130+ ) -> List [Tuple [np .ndarray , np .ndarray ]]:
131+ """
132+ This function allows calling and type-checking the specified function.
133+
134+ Args:
135+ random_state (np.random.RandomState): random number generator for reproducibility
136+ num_splits (int): The number of splits in cross validation
137+ indices (np.ndarray): The indices of data points in a dataset
138+ shuffle (bool): Whether to shuffle the indices
139+ labels_to_stratify (Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]]):
140+ The labels of the corresponding data points. It is used for the stratification.
141+
142+ Returns:
143+ splits (List[Tuple[np.ndarray, np.ndarray]]):
144+ splits[a split identifier][0: train, 1: val][a data point identifier]
145+
146+ """
147+ return self .value (
148+ random_state = random_state ,
149+ num_splits = num_splits ,
150+ indices = indices ,
151+ shuffle = shuffle ,
152+ labels_to_stratify = labels_to_stratify
153+ )
154+
155+
156+ class HoldoutValTypes (Enum ):
157+ """The type of holdout validation
158+
159+ This class is used to specify the holdout validation function
160+ and is not supposed to be instantiated.
161+
162+ Examples: This class is supposed to be used as follows
163+ >>> holdout_type = HoldoutValTypes.holdout
164+ >>> print(holdout_type.name)
165+
166+ holdout
167+
168+ >>> print(holdout_type.value)
169+
170+ functools.partial(<function HoldoutFuncs.holdout_validation at ...>)
171+
172+ >>> for holdout_type in HoldoutValTypes:
173+ print(holdout_type.name)
174+
175+ holdout
176+
177+ Additionally, HoldoutValTypes.<function> can be called directly.
178+ """
179+
180+ holdout = partial (HoldoutFuncs .holdout_validation )
181+
182+ def __call__ (
183+ self ,
184+ random_state : np .random .RandomState ,
185+ indices : np .ndarray ,
186+ val_share : float = 0.33 ,
187+ shuffle : bool = False ,
188+ labels_to_stratify : Optional [Union [Tuple [np .ndarray , np .ndarray ], Dataset ]] = None
189+ ) -> List [Tuple [np .ndarray , np .ndarray ]]:
190+ """
191+ This function allows calling and type-checking the specified function.
192+
193+ Args:
194+ random_state (np.random.RandomState): random number generator for reproducibility
195+ val_share (float): The ratio of validation dataset vs the given dataset
196+ indices (np.ndarray): The indices of data points in a dataset
197+ shuffle (bool): Whether to shuffle the indices
198+ labels_to_stratify (Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]]):
199+ The labels of the corresponding data points. It is used for the stratification.
200+
201+ Returns:
202+ split (Tuple[np.ndarray, np.ndarray]):
203+ split[0: train, 1: val][a data point identifier]
204+
205+ """
206+ return self .value (
207+ random_state = random_state ,
208+ val_share = val_share ,
209+ indices = indices ,
210+ shuffle = shuffle ,
211+ labels_to_stratify = labels_to_stratify
212+ )
0 commit comments