Skip to content

Commit 66634d2

Browse files
committed
Add support to split random erasing blocks into randomly selected number with --recount arg. Fix random selection of aspect ratios.
1 parent 6946281 commit 66634d2

File tree

4 files changed

+32
-14
lines changed

4 files changed

+32
-14
lines changed

timm/data/loader.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def __init__(self,
2020
loader,
2121
rand_erase_prob=0.,
2222
rand_erase_mode='const',
23+
rand_erase_count=1,
2324
mean=IMAGENET_DEFAULT_MEAN,
2425
std=IMAGENET_DEFAULT_STD,
2526
fp16=False):
@@ -32,7 +33,7 @@ def __init__(self,
3233
self.std = self.std.half()
3334
if rand_erase_prob > 0.:
3435
self.random_erasing = RandomErasing(
35-
probability=rand_erase_prob, mode=rand_erase_mode)
36+
probability=rand_erase_prob, mode=rand_erase_mode, max_count=rand_erase_count)
3637
else:
3738
self.random_erasing = None
3839

@@ -94,6 +95,7 @@ def create_loader(
9495
use_prefetcher=True,
9596
rand_erase_prob=0.,
9697
rand_erase_mode='const',
98+
rand_erase_count=1,
9799
color_jitter=0.4,
98100
interpolation='bilinear',
99101
mean=IMAGENET_DEFAULT_MEAN,
@@ -160,6 +162,7 @@ def create_loader(
160162
loader,
161163
rand_erase_prob=rand_erase_prob if is_training else 0.,
162164
rand_erase_mode=rand_erase_mode,
165+
rand_erase_count=rand_erase_count,
163166
mean=mean,
164167
std=std,
165168
fp16=fp16)

timm/data/random_erasing.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,17 +33,20 @@ class RandomErasing:
3333
'const' - erase block is constant color of 0 for all channels
3434
'rand' - erase block is same per-cannel random (normal) color
3535
'pixel' - erase block is per-pixel random (normal) color
36+
max_count: maximum number of erasing blocks per image, area per box is scaled by count.
37+
per-image count is randomly chosen between 1 and this value.
3638
"""
3739

3840
def __init__(
3941
self,
4042
probability=0.5, sl=0.02, sh=1/3, min_aspect=0.3,
41-
mode='const', device='cuda'):
43+
mode='const', max_count=1, device='cuda'):
4244
self.probability = probability
4345
self.sl = sl
4446
self.sh = sh
4547
self.min_aspect = min_aspect
46-
self.max_count = 8
48+
self.min_count = 1
49+
self.max_count = max_count
4750
mode = mode.lower()
4851
self.rand_color = False
4952
self.per_pixel = False
@@ -59,11 +62,13 @@ def _erase(self, img, chan, img_h, img_w, dtype):
5962
if random.random() > self.probability:
6063
return
6164
area = img_h * img_w
62-
count = random.randint(1, self.max_count)
65+
count = self.min_count if self.min_count == self.max_count else \
66+
random.randint(self.min_count, self.max_count)
6367
for _ in range(count):
6468
for attempt in range(10):
65-
target_area = random.uniform(self.sl / count, self.sh / count) * area
66-
aspect_ratio = random.uniform(self.min_aspect, 1 / self.min_aspect)
69+
target_area = random.uniform(self.sl, self.sh) * area / count
70+
log_ratio = (math.log(self.min_aspect), math.log(1 / self.min_aspect))
71+
aspect_ratio = math.exp(random.uniform(*log_ratio))
6772
h = int(round(math.sqrt(target_area * aspect_ratio)))
6873
w = int(round(math.sqrt(target_area / aspect_ratio)))
6974
if w < img_w and h < img_h:

timm/data/transforms.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -107,24 +107,31 @@ def get_params(img, scale, ratio):
107107

108108
for attempt in range(10):
109109
target_area = random.uniform(*scale) * area
110-
aspect_ratio = random.uniform(*ratio)
110+
log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
111+
aspect_ratio = math.exp(random.uniform(*log_ratio))
111112

112113
w = int(round(math.sqrt(target_area * aspect_ratio)))
113114
h = int(round(math.sqrt(target_area / aspect_ratio)))
114115

115-
if random.random() < 0.5 and min(ratio) <= (h / w) <= max(ratio):
116-
w, h = h, w
117-
118116
if w <= img.size[0] and h <= img.size[1]:
119117
i = random.randint(0, img.size[1] - h)
120118
j = random.randint(0, img.size[0] - w)
121119
return i, j, h, w
122120

123-
# Fallback
124-
w = min(img.size[0], img.size[1])
125-
i = (img.size[1] - w) // 2
121+
# Fallback to central crop
122+
in_ratio = img.size[0] / img.size[1]
123+
if in_ratio < min(ratio):
124+
w = img.size[0]
125+
h = int(round(w / min(ratio)))
126+
elif in_ratio > max(ratio):
127+
h = img.size[1]
128+
w = int(round(h * max(ratio)))
129+
else: # whole image
130+
w = img.size[0]
131+
h = img.size[1]
132+
i = (img.size[1] - h) // 2
126133
j = (img.size[0] - w) // 2
127-
return i, j, w, w
134+
return i, j, h, w
128135

129136
def __call__(self, img):
130137
"""

train.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@
9191
help='Random erase prob (default: 0.)')
9292
parser.add_argument('--remode', type=str, default='const',
9393
help='Random erase mode (default: "const")')
94+
parser.add_argument('--recount', type=int, default=1,
95+
help='Random erase count (default: 1)')
9496
parser.add_argument('--mixup', type=float, default=0.0,
9597
help='mixup alpha, mixup enabled if > 0. (default: 0.)')
9698
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
@@ -273,6 +275,7 @@ def main():
273275
use_prefetcher=args.prefetcher,
274276
rand_erase_prob=args.reprob,
275277
rand_erase_mode=args.remode,
278+
rand_erase_count=args.recount,
276279
color_jitter=args.color_jitter,
277280
interpolation='random', # FIXME cleanly resolve this? data_config['interpolation'],
278281
mean=data_config['mean'],

0 commit comments

Comments
 (0)