@@ -72,7 +72,7 @@ def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, cou
     yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
     if correct_lam or ratio_minmax is not None:
         bbox_area = (yu - yl) * (xu - xl)
-        lam = 1. - bbox_area / (img_shape[-2] * img_shape[-1])
+        lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
     return (yl, yu, xl, xu), lam


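A minimal sketch of the truncation the added `float()` cast guards against: under Python 2's classic division (no `from __future__ import division`), `int / int` drops the fraction, so the corrected lam could collapse to 1.0 or 0.0. The explicit cast keeps the result fractional on either interpreter. The shape and box values below are illustrative:

```python
# bbox_area / float(H * W) stays a true division even under Python 2 semantics.
img_shape = (3, 224, 224)                 # (C, H, W), illustrative
yl, yu, xl, xu = 48, 176, 48, 176         # a 128x128 cut box
bbox_area = (yu - yl) * (xu - xl)         # 16384 pixels replaced
lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
print(lam)                                # ~0.673: fraction of original pixels kept
```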
@@ -84,7 +84,7 @@ def cutmix_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disa
         yl, yh, xl, xh = rand_bbox(input.size(), lam)
         input[:, :, yl:yh, xl:xh] = input.flip(0)[:, :, yl:yh, xl:xh]
         if correct_lam:
-            lam = 1 - (yh - yl) * (xh - xl) / (input.shape[-2] * input.shape[-1])
+            lam = 1. - (yh - yl) * (xh - xl) / float(input.shape[-2] * input.shape[-1])
     target = mixup_target(target, num_classes, lam, smoothing)
     return input, target

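`cutmix_batch` pairs each sample with its mirror across the batch via `input.flip(0)`: flipping along dim 0 matches sample i with sample B-1-i, so one slice assignment pastes every partner patch at once. A toy, self-contained sketch of that trick (sizes are illustrative, not the function above):

```python
# torch.flip returns a copy, so the in-place slice assignment on the left
# reads unclobbered source pixels from the flipped batch on the right.
import torch

x = torch.arange(4 * 1 * 4 * 4, dtype=torch.float32).view(4, 1, 4, 4)
yl, yh, xl, xh = 1, 3, 1, 3
x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh]
lam = 1. - (yh - yl) * (xh - xl) / float(x.shape[-2] * x.shape[-1])
print(lam)  # 0.75: each sample keeps 75% of its own pixels after the paste
```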
@@ -139,7 +139,7 @@ def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0

     def _mix_elem(self, output, batch):
         batch_size = len(batch)
-        lam_out = np.ones(batch_size)
+        lam_out = np.ones(batch_size, dtype=np.float32)
         use_cutmix = np.zeros(batch_size).astype(np.bool)
         if self.mixup_enabled:
             if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
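The lines elided between these two hunks draw the per-element lambdas. A hedged sketch of the selection `_mix_elem` sets up when both alphas are positive: a coin flip chooses cutmix vs mixup per sample, and lam comes from the matching Beta distribution. The alpha and switch-probability values here are illustrative, not this class's defaults:

```python
# Per-sample mode choice: use_cutmix masks which Beta draw each element keeps.
# np.where evaluates both draws, which is cheap for vectors of lambdas.
import numpy as np

batch_size, mixup_alpha, cutmix_alpha, switch_prob = 8, 0.2, 1.0, 0.5
use_cutmix = np.random.rand(batch_size) < switch_prob
lam_mix = np.where(
    use_cutmix,
    np.random.beta(cutmix_alpha, cutmix_alpha, size=batch_size),
    np.random.beta(mixup_alpha, mixup_alpha, size=batch_size))
```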
@@ -155,22 +155,23 @@ def _mix_elem(self, output, batch):
                 lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
             else:
                 assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
-            lam_out = np.where(np.random.rand(batch_size) < self.prob, lam_mix, lam_out)
+            lam_out = np.where(np.random.rand(batch_size) < self.prob, lam_mix.astype(np.float32), lam_out)

         for i in range(batch_size):
             j = batch_size - i - 1
             lam = lam_out[i]
-            mixed = batch[i][0].astype(np.float32)
+            mixed = batch[i][0]
             if lam != 1.:
                 if use_cutmix[i]:
+                    mixed = mixed.copy()
                     (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
                         output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
-                    mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32)
+                    mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
                     lam_out[i] = lam
                 else:
-                    mixed = mixed * lam + batch[j][0].astype(np.float32) * (1 - lam)
+                    mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
                     lam_out[i] = lam
-                np.round(mixed, out=mixed)
+                    np.round(mixed, out=mixed)
             output[i] += torch.from_numpy(mixed.astype(np.uint8))
         return torch.tensor(lam_out).unsqueeze(1)

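The `.astype(np.float32)` added to `lam_mix` addresses NumPy's type promotion: `np.random.beta` returns float64, and `np.where` widens mixed inputs to the larger dtype, silently undoing the float32 allocation of `lam_out`. A quick check of that behavior:

```python
# np.where promotes to the widest input dtype, so the float64 beta draw would
# turn the float32 lam_out back into float64 unless lam_mix is downcast first.
import numpy as np

lam_out = np.ones(8, dtype=np.float32)
lam_mix = np.random.beta(1., 1., size=8)      # float64
keep = np.random.rand(8) < 0.5
print(np.where(keep, lam_mix, lam_out).dtype)                     # float64
print(np.where(keep, lam_mix.astype(np.float32), lam_out).dtype)  # float32
```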
@@ -190,21 +191,22 @@ def _mix_batch(self, output, batch):
                 lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
             else:
                 assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
-            lam = lam_mix
+            lam = float(lam_mix)

         if use_cutmix:
             (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
                 output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)

         for i in range(batch_size):
             j = batch_size - i - 1
-            mixed = batch[i][0].astype(np.float32)
+            mixed = batch[i][0]
             if lam != 1.:
                 if use_cutmix:
-                    mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32)
+                    mixed = mixed.copy()
+                    mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
                 else:
-                    mixed = mixed * lam + batch[j][0].astype(np.float32) * (1 - lam)
-                np.round(mixed, out=mixed)
+                    mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
+                    np.round(mixed, out=mixed)
             output[i] += torch.from_numpy(mixed.astype(np.uint8))
         return lam

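The `mixed = mixed.copy()` added to both cutmix branches matters because `batch[i][0]` is now taken by reference rather than implicitly copied via `.astype(np.float32)`: without it, pasting a patch into sample i would mutate the loader's array in place, and since i and j sweep the batch from opposite ends, the partner sample would later read the corrupted pixels. A minimal reproduction of the hazard (toy batch, illustrative values):

```python
# `mixed = batch[i][0]` aliases the loader's array; an in-place patch write
# then mutates the very pixels the partner sample will read later in the loop.
import numpy as np

batch = [(np.full((3, 4, 4), v, dtype=np.uint8), label)
         for label, v in enumerate((10, 20))]
mixed = batch[0][0]                            # reference, not a copy
mixed[:, 0:2, 0:2] = batch[1][0][:, 0:2, 0:2]
print(batch[0][0][0, 0, 0])                    # 20: the source was corrupted in place
mixed = batch[0][0].copy()                     # the fix: patch a private copy
mixed[:, 2:4, 2:4] = batch[1][0][:, 2:4, 2:4]
print(batch[0][0][0, 3, 3])                    # still 10: the original is untouched
```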