Skip to content

Commit f37e633

Browse files
committed
Merge remote-tracking branch 'origin/re-exp' into opt
2 parents b06dce8 + 66634d2 commit f37e633

File tree

4 files changed

+43
-22
lines changed

4 files changed

+43
-22
lines changed

timm/data/loader.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def __init__(self,
2020
loader,
2121
rand_erase_prob=0.,
2222
rand_erase_mode='const',
23+
rand_erase_count=1,
2324
mean=IMAGENET_DEFAULT_MEAN,
2425
std=IMAGENET_DEFAULT_STD,
2526
fp16=False):
@@ -32,7 +33,7 @@ def __init__(self,
3233
self.std = self.std.half()
3334
if rand_erase_prob > 0.:
3435
self.random_erasing = RandomErasing(
35-
probability=rand_erase_prob, mode=rand_erase_mode)
36+
probability=rand_erase_prob, mode=rand_erase_mode, max_count=rand_erase_count)
3637
else:
3738
self.random_erasing = None
3839

@@ -135,6 +136,7 @@ def create_loader(
135136
use_prefetcher=True,
136137
rand_erase_prob=0.,
137138
rand_erase_mode='const',
139+
rand_erase_count=1,
138140
color_jitter=0.4,
139141
interpolation='bilinear',
140142
mean=IMAGENET_DEFAULT_MEAN,
@@ -184,6 +186,7 @@ def create_loader(
184186
loader,
185187
rand_erase_prob=rand_erase_prob if is_training else 0.,
186188
rand_erase_mode=rand_erase_mode,
189+
rand_erase_count=rand_erase_count,
187190
mean=mean,
188191
std=std,
189192
fp16=fp16)

timm/data/random_erasing.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,20 @@ class RandomErasing:
3333
'const' - erase block is constant color of 0 for all channels
3434
'rand' - erase block is same per-cannel random (normal) color
3535
'pixel' - erase block is per-pixel random (normal) color
36+
max_count: maximum number of erasing blocks per image, area per box is scaled by count.
37+
per-image count is randomly chosen between 1 and this value.
3638
"""
3739

3840
def __init__(
3941
self,
4042
probability=0.5, sl=0.02, sh=1/3, min_aspect=0.3,
41-
mode='const', device='cuda'):
43+
mode='const', max_count=1, device='cuda'):
4244
self.probability = probability
4345
self.sl = sl
4446
self.sh = sh
4547
self.min_aspect = min_aspect
48+
self.min_count = 1
49+
self.max_count = max_count
4650
mode = mode.lower()
4751
self.rand_color = False
4852
self.per_pixel = False
@@ -58,18 +62,22 @@ def _erase(self, img, chan, img_h, img_w, dtype):
5862
if random.random() > self.probability:
5963
return
6064
area = img_h * img_w
61-
for attempt in range(100):
62-
target_area = random.uniform(self.sl, self.sh) * area
63-
aspect_ratio = random.uniform(self.min_aspect, 1 / self.min_aspect)
64-
h = int(round(math.sqrt(target_area * aspect_ratio)))
65-
w = int(round(math.sqrt(target_area / aspect_ratio)))
66-
if w < img_w and h < img_h:
67-
top = random.randint(0, img_h - h)
68-
left = random.randint(0, img_w - w)
69-
img[:, top:top + h, left:left + w] = _get_pixels(
70-
self.per_pixel, self.rand_color, (chan, h, w),
71-
dtype=dtype, device=self.device)
72-
break
65+
count = self.min_count if self.min_count == self.max_count else \
66+
random.randint(self.min_count, self.max_count)
67+
for _ in range(count):
68+
for attempt in range(10):
69+
target_area = random.uniform(self.sl, self.sh) * area / count
70+
log_ratio = (math.log(self.min_aspect), math.log(1 / self.min_aspect))
71+
aspect_ratio = math.exp(random.uniform(*log_ratio))
72+
h = int(round(math.sqrt(target_area * aspect_ratio)))
73+
w = int(round(math.sqrt(target_area / aspect_ratio)))
74+
if w < img_w and h < img_h:
75+
top = random.randint(0, img_h - h)
76+
left = random.randint(0, img_w - w)
77+
img[:, top:top + h, left:left + w] = _get_pixels(
78+
self.per_pixel, self.rand_color, (chan, h, w),
79+
dtype=dtype, device=self.device)
80+
break
7381

7482
def __call__(self, input):
7583
if len(input.size()) == 3:

timm/data/transforms.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -107,24 +107,31 @@ def get_params(img, scale, ratio):
107107

108108
for attempt in range(10):
109109
target_area = random.uniform(*scale) * area
110-
aspect_ratio = random.uniform(*ratio)
110+
log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
111+
aspect_ratio = math.exp(random.uniform(*log_ratio))
111112

112113
w = int(round(math.sqrt(target_area * aspect_ratio)))
113114
h = int(round(math.sqrt(target_area / aspect_ratio)))
114115

115-
if random.random() < 0.5 and min(ratio) <= (h / w) <= max(ratio):
116-
w, h = h, w
117-
118116
if w <= img.size[0] and h <= img.size[1]:
119117
i = random.randint(0, img.size[1] - h)
120118
j = random.randint(0, img.size[0] - w)
121119
return i, j, h, w
122120

123-
# Fallback
124-
w = min(img.size[0], img.size[1])
125-
i = (img.size[1] - w) // 2
121+
# Fallback to central crop
122+
in_ratio = img.size[0] / img.size[1]
123+
if in_ratio < min(ratio):
124+
w = img.size[0]
125+
h = int(round(w / min(ratio)))
126+
elif in_ratio > max(ratio):
127+
h = img.size[1]
128+
w = int(round(h * max(ratio)))
129+
else: # whole image
130+
w = img.size[0]
131+
h = img.size[1]
132+
i = (img.size[1] - h) // 2
126133
j = (img.size[0] - w) // 2
127-
return i, j, w, w
134+
return i, j, h, w
128135

129136
def __call__(self, img):
130137
"""

train.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@
9191
help='Random erase prob (default: 0.)')
9292
parser.add_argument('--remode', type=str, default='const',
9393
help='Random erase mode (default: "const")')
94+
parser.add_argument('--recount', type=int, default=1,
95+
help='Random erase count (default: 1)')
9496
parser.add_argument('--mixup', type=float, default=0.0,
9597
help='mixup alpha, mixup enabled if > 0. (default: 0.)')
9698
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
@@ -273,6 +275,7 @@ def main():
273275
use_prefetcher=args.prefetcher,
274276
rand_erase_prob=args.reprob,
275277
rand_erase_mode=args.remode,
278+
rand_erase_count=args.recount,
276279
color_jitter=args.color_jitter,
277280
interpolation='random', # FIXME cleanly resolve this? data_config['interpolation'],
278281
mean=data_config['mean'],

0 commit comments

Comments
 (0)