
Hacked together by / Copyright 2020 Ross Wightman
"""
-
import numpy as np
import torch
-import math
-import numbers


def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
@@ -30,20 +27,21 @@ def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
    return y1 * lam + y2 * (1. - lam)


-def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False):
-    lam = 1.
-    if not disable:
-        lam = np.random.beta(alpha, alpha)
-    input = input.mul(lam).add_(1 - lam, input.flip(0))
-    target = mixup_target(target, num_classes, lam, smoothing)
-    return input, target
-
+def rand_bbox(img_shape, lam, margin=0., count=None):
+    """ Standard CutMix bounding-box
+    Generates a random square bbox based on lambda value. This impl includes
+    support for enforcing a border margin as percent of bbox dimensions.

-def rand_bbox(size, lam, border=0., count=None):
-    ratio = math.sqrt(1 - lam)
-    img_h, img_w = size[-2:]
+    Args:
+        img_shape (tuple): Image shape as tuple
+        lam (float): Cutmix lambda value
+        margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
+        count (int): Number of bbox to generate
+    """
+    ratio = np.sqrt(1 - lam)
+    img_h, img_w = img_shape[-2:]
    cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
-    margin_y, margin_x = int(border * cut_h), int(border * cut_w)
+    margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
    cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
    cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
    yl = np.clip(cy - cut_h // 2, 0, img_h)
@@ -53,9 +51,20 @@ def rand_bbox(size, lam, border=0., count=None):
    return yl, yh, xl, xh
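# --- Editor's aside (illustration only, not part of this commit) ---
# With ratio = sqrt(1 - lam) the square cut covers roughly (1 - lam) of the image,
# so lam approximates the fraction of the original image that survives the cut.
# Hypothetical numbers for a 224x224 image:
import numpy as np

lam = 0.75
ratio = np.sqrt(1 - lam)                 # 0.5
cut_h = cut_w = int(224 * ratio)         # 112
print(cut_h * cut_w / (224. * 224.))     # ~0.25 == 1 - lam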


-def rand_bbox_minmax(size, minmax, count=None):
+def rand_bbox_minmax(img_shape, minmax, count=None):
+    """ Min-Max CutMix bounding-box
+    Inspired by Darknet cutmix impl, generates a random rectangular bbox
+    based on min/max percent values applied to each dimension of the input image.
+
+    Typical defaults for minmax are usually in the .2-.3 range for min and the .8-.9 range for max.
+
+    Args:
+        img_shape (tuple): Image shape as tuple
+        minmax (tuple or list): Min and max bbox ratios (as percent of image size)
+        count (int): Number of bbox to generate
+    """
    assert len(minmax) == 2
-    img_h, img_w = size[-2:]
+    img_h, img_w = img_shape[-2:]
    cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
    cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
    yl = np.random.randint(0, img_h - cut_h, size=count)
@@ -66,6 +75,8 @@ def rand_bbox_minmax(size, minmax, count=None):


def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None):
+    """ Generate bbox and apply lambda correction.
+    """
    if ratio_minmax is not None:
        yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count)
    else:
@@ -76,71 +87,40 @@ def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, cou
    return (yl, yu, xl, xu), lam
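# --- Editor's aside (illustration only, not part of this commit) ---
# When correct_lam is True (or ratio_minmax is given), lam is recomputed from the
# final, possibly clipped bbox so the label mixing weight matches the pixels actually
# replaced; this mirrors the formula visible in the removed cutmix_batch below.
# Hypothetical numbers:
img_h, img_w = 224, 224
yl, yu, xl, xu = 0, 100, 180, 224              # bbox clipped at the top/right image borders
bbox_area = (yu - yl) * (xu - xl)              # 100 * 44 = 4400 replaced pixels
lam = 1. - bbox_area / float(img_h * img_w)    # ~0.912 of the original image kept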


-def cutmix_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False, correct_lam=False):
-    lam = 1.
-    if not disable:
-        lam = np.random.beta(alpha, alpha)
-    if lam != 1:
-        yl, yh, xl, xh = rand_bbox(input.size(), lam)
-        input[:, :, yl:yh, xl:xh] = input.flip(0)[:, :, yl:yh, xl:xh]
-        if correct_lam:
-            lam = 1. - (yh - yl) * (xh - xl) / float(input.shape[-2] * input.shape[-1])
-    target = mixup_target(target, num_classes, lam, smoothing)
-    return input, target
-
-
-def mix_batch(
-        input, target, mixup_alpha=0.2, cutmix_alpha=0., prob=1.0, switch_prob=.5,
-        num_classes=1000, smoothing=0.1, disable=False):
-    # FIXME test this version
-    if np.random.rand() > prob:
-        return input, target
-    use_cutmix = cutmix_alpha > 0. and np.random.rand() <= switch_prob
-    if use_cutmix:
-        return cutmix_batch(input, target, cutmix_alpha, num_classes, smoothing, disable)
-    else:
-        return mixup_batch(input, target, mixup_alpha, num_classes, smoothing, disable)
-
-
-class FastCollateMixup:
-    """Fast Collate Mixup/Cutmix that applies different params to each element or whole batch
-
-    NOTE once experiments are done, one of the three variants will remain with this class name
+class Mixup:
+    """ Mixup/Cutmix that applies different params to each element or whole batch

+    Args:
+        mixup_alpha (float): mixup alpha value, mixup is active if > 0.
+        cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
+        cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
+        prob (float): probability of applying mixup or cutmix per batch or element
+        switch_prob (float): probability of switching to cutmix instead of mixup when both are active
+        elementwise (bool): apply mixup/cutmix params per batch element instead of per batch
+        correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
+        label_smoothing (float): apply label smoothing to the mixed target tensor
+        num_classes (int): number of classes for target
    """
    def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
                 elementwise=False, correct_lam=True, label_smoothing=0.1, num_classes=1000):
-        """
-
-        Args:
-            mixup_alpha (float): mixup alpha value, mixup is active if > 0.
-            cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
-            cutmix_minmax (float): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None
-            prob (float): probability of applying mixup or cutmix per batch or element
-            switch_prob (float): probability of using cutmix instead of mixup when both active
-            elementwise (bool): apply mixup/cutmix params per batch element instead of per batch
-            label_smoothing (float):
-            num_classes (int):
-        """
        self.mixup_alpha = mixup_alpha
        self.cutmix_alpha = cutmix_alpha
        self.cutmix_minmax = cutmix_minmax
        if self.cutmix_minmax is not None:
            assert len(self.cutmix_minmax) == 2
            # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
            self.cutmix_alpha = 1.0
-        self.prob = prob
+        self.mix_prob = prob
        self.switch_prob = switch_prob
        self.label_smoothing = label_smoothing
        self.num_classes = num_classes
        self.elementwise = elementwise
        self.correct_lam = correct_lam  # correct lambda based on clipped area for cutmix
        self.mixup_enabled = True  # set to false to disable mixing (intended to be set by train loop)

-    def _mix_elem(self, output, batch):
-        batch_size = len(batch)
-        lam_out = np.ones(batch_size, dtype=np.float32)
-        use_cutmix = np.zeros(batch_size).astype(np.bool)
+    def _params_per_elem(self, batch_size):
+        lam = np.ones(batch_size, dtype=np.float32)
+        use_cutmix = np.zeros(batch_size, dtype=np.bool)
        if self.mixup_enabled:
            if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
                use_cutmix = np.random.rand(batch_size) < self.switch_prob
@@ -151,35 +131,17 @@ def _mix_elem(self, output, batch):
            elif self.mixup_alpha > 0.:
                lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)
            elif self.cutmix_alpha > 0.:
-                use_cutmix = np.ones(batch_size).astype(np.bool)
+                use_cutmix = np.ones(batch_size, dtype=np.bool)
                lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
            else:
                assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
-            lam_out = np.where(np.random.rand(batch_size) < self.prob, lam_mix.astype(np.float32), lam_out)
-
-        for i in range(batch_size):
-            j = batch_size - i - 1
-            lam = lam_out[i]
-            mixed = batch[i][0]
-            if lam != 1.:
-                if use_cutmix[i]:
-                    mixed = mixed.copy()
-                    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
-                        output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
-                    mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
-                    lam_out[i] = lam
-                else:
-                    mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
-                    lam_out[i] = lam
-                    np.round(mixed, out=mixed)
-            output[i] += torch.from_numpy(mixed.astype(np.uint8))
-        return torch.tensor(lam_out).unsqueeze(1)
+            lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam)
+        return lam, use_cutmix

-    def _mix_batch(self, output, batch):
-        batch_size = len(batch)
+    def _params_per_batch(self):
        lam = 1.
        use_cutmix = False
-        if self.mixup_enabled and np.random.rand() < self.prob:
+        if self.mixup_enabled and np.random.rand() < self.mix_prob:
            if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
                use_cutmix = np.random.rand() < self.switch_prob
                lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
@@ -192,34 +154,100 @@ def _mix_batch(self, output, batch):
            else:
                assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
            lam = float(lam_mix)
+        return lam, use_cutmix
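# --- Editor's aside (illustration only, not part of this commit) ---
# Both _params_per_batch and _params_per_elem draw lam from a symmetric Beta(alpha, alpha):
# small alphas (e.g. 0.2) concentrate draws near 0 and 1 (mild mixing most of the time),
# while alpha = 1.0 is uniform on [0, 1]. Quick check with numpy:
import numpy as np

print(np.random.beta(0.2, 0.2, size=5))   # mostly values close to 0 or 1
print(np.random.beta(1.0, 1.0, size=5))   # roughly uniform draws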

+    def _mix_elem(self, x):
+        batch_size = len(x)
+        lam_batch, use_cutmix = self._params_per_elem(batch_size)
+        x_orig = x.clone()  # need to keep an unmodified original for mixing source
+        for i in range(batch_size):
+            j = batch_size - i - 1
+            lam = lam_batch[i]
+            if lam != 1.:
+                if use_cutmix[i]:
+                    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                        x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+                    x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
+                    lam_batch[i] = lam
+                else:
+                    x[i] = x[i] * lam + x_orig[j] * (1 - lam)
+        return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
+
+    def _mix_batch(self, x):
+        lam, use_cutmix = self._params_per_batch()
+        if lam == 1.:
+            return 1.
        if use_cutmix:
            (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
-                output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+                x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+            x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh]
+        else:
+            x_flipped = x.flip(0).mul_(1. - lam)
+            x.mul_(lam).add_(x_flipped)
+        return lam
+
+    def __call__(self, x, target):
+        assert len(x) % 2 == 0, 'Batch size should be even when using this'
+        lam = self._mix_elem(x) if self.elementwise else self._mix_batch(x)
+        target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
+        return x, target
+
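# --- Editor's aside (illustration only, not part of this commit) ---
# A minimal usage sketch for the GPU-side Mixup class above; `loader`, `model` and
# `soft_ce` (a soft-target cross-entropy) are hypothetical stand-ins for the training setup.
mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, prob=1.0, switch_prob=0.5,
                 label_smoothing=0.1, num_classes=1000)
for images, labels in loader:                   # hypothetical dataloader
    images, labels = images.cuda(), labels.cuda()
    images, labels = mixup_fn(images, labels)   # labels become mixed one-hot targets
    loss = soft_ce(model(images), labels)       # hypothetical soft-target loss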
+
+class FastCollateMixup(Mixup):
+    """ Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch

+    A Mixup impl that's performed while collating the batches.
+    """
+
+    def _mix_elem_collate(self, output, batch):
+        batch_size = len(batch)
+        lam_batch, use_cutmix = self._params_per_elem(batch_size)
        for i in range(batch_size):
            j = batch_size - i - 1
+            lam = lam_batch[i]
            mixed = batch[i][0]
            if lam != 1.:
-                if use_cutmix:
+                if use_cutmix[i]:
                    mixed = mixed.copy()
+                    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                        output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+                    mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
+                    lam_batch[i] = lam
+                else:
+                    mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
+                    lam_batch[i] = lam
+                    np.round(mixed, out=mixed)
+            output[i] += torch.from_numpy(mixed.astype(np.uint8))
+        return torch.tensor(lam_batch).unsqueeze(1)
+
+    def _mix_batch_collate(self, output, batch):
+        batch_size = len(batch)
+        lam, use_cutmix = self._params_per_batch()
+        if use_cutmix:
+            (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+        for i in range(batch_size):
+            j = batch_size - i - 1
+            mixed = batch[i][0]
+            if lam != 1.:
+                if use_cutmix:
+                    mixed = mixed.copy()  # don't want to modify the original while iterating
                    mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
                else:
                    mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
                    np.round(mixed, out=mixed)
            output[i] += torch.from_numpy(mixed.astype(np.uint8))
        return lam

-    def __call__(self, batch):
+    def __call__(self, batch, _=None):
        batch_size = len(batch)
        assert batch_size % 2 == 0, 'Batch size should be even when using this'
        output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
        if self.elementwise:
-            lam = self._mix_elem(output, batch)
+            lam = self._mix_elem_collate(output, batch)
        else:
-            lam = self._mix_batch(output, batch)
+            lam = self._mix_batch_collate(output, batch)
        target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
        target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
-
        return output, target

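# --- Editor's aside (illustration only, not part of this commit) ---
# FastCollateMixup mixes uint8 numpy images while collating, so it is passed to the
# DataLoader as collate_fn; `dataset` is a hypothetical dataset yielding
# (uint8 CHW numpy image, int label) pairs. drop_last keeps the batch size even.
collate_fn = FastCollateMixup(mixup_alpha=0.8, cutmix_alpha=1.0,
                              label_smoothing=0.1, num_classes=1000)
loader = torch.utils.data.DataLoader(
    dataset, batch_size=64, shuffle=True, collate_fn=collate_fn, drop_last=True)
for images, targets in loader:
    pass  # images: uint8 tensor [B, C, H, W]; targets: mixed soft labels on CPU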