@@ -33,16 +33,20 @@ class RandomErasing:
3333 'const' - erase block is constant color of 0 for all channels
3434 'rand' - erase block is same per-cannel random (normal) color
3535 'pixel' - erase block is per-pixel random (normal) color
36+ max_count: maximum number of erasing blocks per image, area per box is scaled by count.
37+ per-image count is randomly chosen between 1 and this value.
3638 """
3739
3840 def __init__ (
3941 self ,
4042 probability = 0.5 , sl = 0.02 , sh = 1 / 3 , min_aspect = 0.3 ,
41- mode = 'const' , device = 'cuda' ):
43+ mode = 'const' , max_count = 1 , device = 'cuda' ):
4244 self .probability = probability
4345 self .sl = sl
4446 self .sh = sh
4547 self .min_aspect = min_aspect
48+ self .min_count = 1
49+ self .max_count = max_count
4650 mode = mode .lower ()
4751 self .rand_color = False
4852 self .per_pixel = False
@@ -58,18 +62,22 @@ def _erase(self, img, chan, img_h, img_w, dtype):
5862 if random .random () > self .probability :
5963 return
6064 area = img_h * img_w
61- for attempt in range (100 ):
62- target_area = random .uniform (self .sl , self .sh ) * area
63- aspect_ratio = random .uniform (self .min_aspect , 1 / self .min_aspect )
64- h = int (round (math .sqrt (target_area * aspect_ratio )))
65- w = int (round (math .sqrt (target_area / aspect_ratio )))
66- if w < img_w and h < img_h :
67- top = random .randint (0 , img_h - h )
68- left = random .randint (0 , img_w - w )
69- img [:, top :top + h , left :left + w ] = _get_pixels (
70- self .per_pixel , self .rand_color , (chan , h , w ),
71- dtype = dtype , device = self .device )
72- break
65+ count = self .min_count if self .min_count == self .max_count else \
66+ random .randint (self .min_count , self .max_count )
67+ for _ in range (count ):
68+ for attempt in range (10 ):
69+ target_area = random .uniform (self .sl , self .sh ) * area / count
70+ log_ratio = (math .log (self .min_aspect ), math .log (1 / self .min_aspect ))
71+ aspect_ratio = math .exp (random .uniform (* log_ratio ))
72+ h = int (round (math .sqrt (target_area * aspect_ratio )))
73+ w = int (round (math .sqrt (target_area / aspect_ratio )))
74+ if w < img_w and h < img_h :
75+ top = random .randint (0 , img_h - h )
76+ left = random .randint (0 , img_w - w )
77+ img [:, top :top + h , left :left + w ] = _get_pixels (
78+ self .per_pixel , self .rand_color , (chan , h , w ),
79+ dtype = dtype , device = self .device )
80+ break
7381
7482 def __call__ (self , input ):
7583 if len (input .size ()) == 3 :
0 commit comments