import numpy as np
import torch
import math
import numbers


def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
    ...  # body elided in this excerpt (mixup_target is also defined in the elided region)


def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False):
    # ... mixup body elided in this excerpt ...
    return input, target


def rand_bbox(size, lam, border=0., count=None):
    ratio = math.sqrt(1 - lam)
    img_h, img_w = size[-2:]
    cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
    margin_y, margin_x = int(border * cut_h), int(border * cut_w)
    cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
    cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
    yl = np.clip(cy - cut_h // 2, 0, img_h)
    yh = np.clip(cy + cut_h // 2, 0, img_h)
    xl = np.clip(cx - cut_w // 2, 0, img_w)
    xh = np.clip(cx + cut_w // 2, 0, img_w)
    return yl, yh, xl, xh
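

# Illustrative sketch (demo helper, not in the original file): the cut edge scales
# as sqrt(1 - lam), so lam=0.75 on a 224x224 image gives a box of at most
# int(224 * sqrt(0.25)) = 112 pixels per side, clipped at the image borders.
def _demo_rand_bbox():
    yl, yh, xl, xh = rand_bbox((8, 3, 224, 224), lam=0.75)
    assert (yh - yl) <= 112 and (xh - xl) <= 112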


def rand_bbox_minmax(size, minmax, count=None):
    assert len(minmax) == 2
    img_h, img_w = size[-2:]
    cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
    cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
    yl = np.random.randint(0, img_h - cut_h, size=count)
    xl = np.random.randint(0, img_w - cut_w, size=count)
    yu = yl + cut_h
    xu = xl + cut_w
    return yl, yu, xl, xu
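

# Sketch (demo helper, not in the original file): unlike rand_bbox, the minmax
# variant ignores lam and draws the cut size uniformly between the two ratios;
# the caller is expected to recompute lam from the resulting box area.
def _demo_rand_bbox_minmax():
    yl, yu, xl, xu = rand_bbox_minmax((8, 3, 224, 224), (0.2, 0.8))
    assert int(224 * 0.2) <= (yu - yl) < int(224 * 0.8)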


def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None):
    if ratio_minmax is not None:
        yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count)
    else:
        yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
    if correct_lam or ratio_minmax is not None:
        bbox_area = (yu - yl) * (xu - xl)
        lam = 1. - bbox_area / (img_shape[-2] * img_shape[-1])
    return (yl, yu, xl, xu), lam
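

# Sketch of the lam correction (demo helper, not in the original file): when the
# sampled box is clipped at an image edge the true mixed area shrinks, so with
# correct_lam=True the returned lam is recomputed from the actual box area.
def _demo_cutmix_bbox_and_lam():
    (yl, yu, xl, xu), lam = cutmix_bbox_and_lam((8, 3, 224, 224), lam=0.6, correct_lam=True)
    assert abs(lam - (1. - (yu - yl) * (xu - xl) / (224 * 224))) < 1e-6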


def cutmix_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False, correct_lam=False):
    lam = 1.
    if not disable:
        lam = np.random.beta(alpha, alpha)
    if lam != 1:
        yl, yh, xl, xh = rand_bbox(input.size(), lam)
        input[:, :, yl:yh, xl:xh] = input.flip(0)[:, :, yl:yh, xl:xh]
        if correct_lam:
            lam = 1 - (yh - yl) * (xh - xl) / (input.shape[-2] * input.shape[-1])
    target = mixup_target(target, num_classes, lam, smoothing)
    return input, target
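

# Minimal usage sketch (values are illustrative, not from the original file). Note
# one_hot defaults to device='cuda', so mixup_target's soft targets land on GPU:
#
#   x = torch.randn(8, 3, 224, 224, device='cuda')
#   y = torch.randint(0, 1000, (8,), device='cuda')
#   x_mixed, y_soft = cutmix_batch(x, y, alpha=0.2, num_classes=1000, smoothing=0.1)
#   # y_soft: (8, 1000) smoothed targets weighted by the (optionally corrected) lam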


def mix_batch(
        input, target, mixup_alpha=0.2, cutmix_alpha=0., prob=1.0, switch_prob=0.5,
        num_classes=1000, smoothing=0.1, disable=False):
    # FIXME test this version
    if np.random.rand() > prob:
        return input, target
    use_cutmix = cutmix_alpha > 0. and np.random.rand() <= switch_prob
    if use_cutmix:
        return cutmix_batch(input, target, cutmix_alpha, num_classes, smoothing, disable)
    else:
        return mixup_batch(input, target, mixup_alpha, num_classes, smoothing, disable)
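

# Worked numbers for the gates above (illustrative, not from the original file):
# with prob=0.5, roughly half of calls return the batch unmixed; of the rest,
# switch_prob=0.5 sends about half to cutmix_batch (when cutmix_alpha > 0.)
# and the other half to mixup_batch.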


class FastCollateMixup:
    """Fast Collate Mixup/Cutmix that applies different params to each element or whole batch

    NOTE once experiments are done, one of the three variants will remain with this class name
    """
    def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
                 elementwise=False, correct_lam=True, label_smoothing=0.1, num_classes=1000):
        """
        Args:
            mixup_alpha (float): mixup alpha value, mixup is active if > 0.
            cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
            cutmix_minmax (tuple of float): min/max ratio of the cutmix box relative to the image size;
                cutmix is active and uses this instead of alpha if not None
            prob (float): probability of applying mixup or cutmix per batch or element
            switch_prob (float): probability of using cutmix instead of mixup when both are active
            elementwise (bool): apply mixup/cutmix params per batch element instead of per batch
            correct_lam (bool): correct lambda based on the clipped cutmix box area
            label_smoothing (float): amount of label smoothing applied to the mixed targets
            num_classes (int): number of classes used to build the one-hot targets
        """
        self.mixup_alpha = mixup_alpha
        self.cutmix_alpha = cutmix_alpha
        self.cutmix_minmax = cutmix_minmax
        if self.cutmix_minmax is not None:
            assert len(self.cutmix_minmax) == 2
            # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
            self.cutmix_alpha = 1.0
        self.prob = prob
        self.switch_prob = switch_prob
        self.label_smoothing = label_smoothing
        self.num_classes = num_classes
        self.elementwise = elementwise
        self.correct_lam = correct_lam  # correct lambda based on clipped area for cutmix
        self.mixup_enabled = True  # set to False to disable mixing (intended to be set by the train loop)

    def _mix_elem(self, output, batch):
        batch_size = len(batch)
        lam_out = np.ones(batch_size)
        use_cutmix = np.zeros(batch_size).astype(bool)
        if self.mixup_enabled:
            if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
                use_cutmix = np.random.rand(batch_size) < self.switch_prob
                lam_mix = np.where(
                    use_cutmix,
                    np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size),
                    np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size))
            elif self.mixup_alpha > 0.:
                lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)
            elif self.cutmix_alpha > 0.:
                use_cutmix = np.ones(batch_size).astype(bool)
                lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
            else:
                assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
            lam_out = np.where(np.random.rand(batch_size) < self.prob, lam_mix, lam_out)

        for i in range(batch_size):
            j = batch_size - i - 1
            lam = lam_out[i]
            mixed = batch[i][0].astype(np.float32)
            if lam != 1.:
                if use_cutmix[i]:
                    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
                        output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
                    mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32)
                    lam_out[i] = lam
                else:
                    mixed = mixed * lam + batch[j][0].astype(np.float32) * (1 - lam)
                    lam_out[i] = lam
            np.round(mixed, out=mixed)
            output[i] += torch.from_numpy(mixed.astype(np.uint8))
        return torch.tensor(lam_out).unsqueeze(1)
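
    # NOTE the returned per-element lam has shape (batch_size, 1), so it broadcasts
    # against the (batch_size, num_classes) soft targets built by mixup_target in
    # __call__, whereas _mix_batch below returns a single scalar lam for the batch.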

    def _mix_batch(self, output, batch):
        batch_size = len(batch)
        lam = 1.
        use_cutmix = False
        if self.mixup_enabled and np.random.rand() < self.prob:
            if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
                use_cutmix = np.random.rand() < self.switch_prob
                lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
                    np.random.beta(self.mixup_alpha, self.mixup_alpha)
            elif self.mixup_alpha > 0.:
                lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
            elif self.cutmix_alpha > 0.:
                use_cutmix = True
                lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
            else:
                assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
            lam = lam_mix

        if use_cutmix:
            (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
                output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)

        for i in range(batch_size):
            j = batch_size - i - 1
            mixed = batch[i][0].astype(np.float32)
            if lam != 1.:
                if use_cutmix:
                    mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32)
                else:
                    mixed = mixed * lam + batch[j][0].astype(np.float32) * (1 - lam)
            np.round(mixed, out=mixed)
            output[i] += torch.from_numpy(mixed.astype(np.uint8))
        return lam

    def __call__(self, batch):
        batch_size = len(batch)
        assert batch_size % 2 == 0, 'Batch size should be even when using this'
        output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
        if self.elementwise:
            lam = self._mix_elem(output, batch)
        else:
            lam = self._mix_batch(output, batch)
        target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
        target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')

        return output, target
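

# Usage sketch as a DataLoader collate_fn (dataset name is hypothetical, not from
# this file); batch items are expected as (uint8 CHW numpy image, int label) pairs:
#
#   collate = FastCollateMixup(mixup_alpha=0.2, cutmix_alpha=1.0, prob=0.9,
#                              label_smoothing=0.1, num_classes=1000)
#   loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, collate_fn=collate)
#   collate.mixup_enabled = False  # e.g. train loop disabling mixing for final epochs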