Skip to content

Commit b750b76

Browse files
committed
More AutoAugment work. Ready to roll...
1 parent 25d2088 commit b750b76

File tree

4 files changed

+107
-82
lines changed

4 files changed

+107
-82
lines changed

timm/data/auto_augment.py

Lines changed: 68 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
1+
""" Auto Augment
2+
Implementation adapted from:
3+
https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
4+
Papers: https://arxiv.org/abs/1805.09501 and https://arxiv.org/abs/1906.11172
5+
6+
Hacked together by Ross Wightman
7+
"""
18
import random
29
import math
3-
from torchvision import transforms
4-
from PIL import Image, ImageOps, ImageEnhance, ImageChops, ImageDraw
10+
from PIL import Image, ImageOps, ImageEnhance
511
import PIL
612
import numpy as np
713

@@ -131,8 +137,11 @@ def solarize_add(img, add, thresh=128, **__):
131137
return img
132138

133139

134-
def posterize(img, bits, **__):
135-
return ImageOps.posterize(img, 4 - bits)
140+
def posterize(img, bits_to_keep, **__):
141+
if bits_to_keep >= 8:
142+
return img
143+
bits_to_keep = max(1, bits_to_keep) # prevent all 0 images
144+
return ImageOps.posterize(img, bits_to_keep)
136145

137146

138147
def contrast(img, factor, **__):
@@ -157,16 +166,19 @@ def _randomly_negate(v):
157166

158167

159168
def _rotate_level_to_arg(level):
169+
# range [-30, 30]
160170
level = (level / _MAX_LEVEL) * 30.
161171
level = _randomly_negate(level)
162172
return (level,)
163173

164174

165175
def _enhance_level_to_arg(level):
176+
# range [0.1, 1.9]
166177
return ((level / _MAX_LEVEL) * 1.8 + 0.1,)
167178

168179

169180
def _shear_level_to_arg(level):
181+
# range [-0.3, 0.3]
170182
level = (level / _MAX_LEVEL) * 0.3
171183
level = _randomly_negate(level)
172184
return (level,)
@@ -179,6 +191,7 @@ def _translate_abs_level_to_arg(level, translate_const):
179191

180192

181193
def _translate_rel_level_to_arg(level):
194+
# range [-0.45, 0.45]
182195
level = (level / _MAX_LEVEL) * 0.45
183196
level = _randomly_negate(level)
184197
return (level,)
@@ -190,9 +203,12 @@ def level_to_arg(hparams):
190203
'Equalize': lambda level: (),
191204
'Invert': lambda level: (),
192205
'Rotate': _rotate_level_to_arg,
193-
'Posterize': lambda level: (int((level / _MAX_LEVEL) * 4),),
194-
'Solarize': lambda level: (int((level / _MAX_LEVEL) * 256),),
195-
'SolarizeAdd': lambda level: (int((level / _MAX_LEVEL) * 110),),
206+
# FIXME these are both different from original impl as I believe there is a bug,
207+
# not sure what is the correct alternative, hence 2 options that look better
208+
'Posterize': lambda level: (int((level / _MAX_LEVEL) * 4) + 4,), # range [4, 8]
209+
'Posterize2': lambda level: (4 - int((level / _MAX_LEVEL) * 4),), # range [4, 0]
210+
'Solarize': lambda level: (int((level / _MAX_LEVEL) * 256),), # range [0, 256]
211+
'SolarizeAdd': lambda level: (int((level / _MAX_LEVEL) * 110),), # range [0, 110]
196212
'Color': _enhance_level_to_arg,
197213
'Contrast': _enhance_level_to_arg,
198214
'Brightness': _enhance_level_to_arg,
@@ -212,6 +228,7 @@ def level_to_arg(hparams):
212228
'Invert': invert,
213229
'Rotate': rotate,
214230
'Posterize': posterize,
231+
'Posterize2': posterize,
215232
'Solarize': solarize,
216233
'SolarizeAdd': solarize_add,
217234
'Color': color,
@@ -252,10 +269,8 @@ def __call__(self, img):
252269

253270

254271
def auto_augment_policy_v0(hparams=_HPARAMS_DEFAULT):
255-
"""Autoaugment policy that was used in AutoAugment Paper."""
256-
# Each tuple is an augmentation operation of the form
257-
# (operation, probability, magnitude). Each element in policy is a
258-
# sub-policy that will be applied sequentially on the image.
272+
# ImageNet policy from TPU EfficientNet impl, cannot find
273+
# a paper reference.
259274
policy = [
260275
[('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
261276
[('Color', 0.4, 9), ('Equalize', 0.6, 3)],
@@ -287,6 +302,48 @@ def auto_augment_policy_v0(hparams=_HPARAMS_DEFAULT):
287302
return pc
288303

289304

305+
def auto_augment_policy_original(hparams=_HPARAMS_DEFAULT):
306+
# ImageNet policy from https://arxiv.org/abs/1805.09501
307+
policy = [
308+
[('Posterize', 0.4, 8), ('Rotate', 0.6, 9)],
309+
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
310+
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
311+
[('Posterize', 0.6, 7), ('Posterize', 0.6, 6)],
312+
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
313+
[('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
314+
[('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
315+
[('Posterize', 0.8, 5), ('Equalize', 1.0, 2)],
316+
[('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
317+
[('Equalize', 0.6, 8), ('Posterize', 0.4, 6)],
318+
[('Rotate', 0.8, 8), ('Color', 0.4, 0)],
319+
[('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
320+
[('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
321+
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
322+
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
323+
[('Rotate', 0.8, 8), ('Color', 1.0, 2)],
324+
[('Color', 0.8, 8), ('Solarize', 0.8, 7)],
325+
[('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
326+
[('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
327+
[('Color', 0.4, 0), ('Equalize', 0.6, 3)],
328+
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
329+
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
330+
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
331+
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
332+
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
333+
]
334+
pc = [[AutoAugmentOp(*a, hparams) for a in sp] for sp in policy]
335+
return pc
336+
337+
338+
def auto_augment_policy(name='v0', hparams=_HPARAMS_DEFAULT):
339+
if name == 'original':
340+
return auto_augment_policy_original(hparams)
341+
elif name == 'v0':
342+
return auto_augment_policy_v0(hparams)
343+
else:
344+
assert False, 'Unknown AA policy (%s)' % name
345+
346+
290347
class AutoAugment:
291348

292349
def __init__(self, policy):

timm/data/loader.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def create_transform(
9292
is_training=False,
9393
use_prefetcher=False,
9494
color_jitter=0.4,
95+
auto_augment=None,
9596
interpolation='bilinear',
9697
mean=IMAGENET_DEFAULT_MEAN,
9798
std=IMAGENET_DEFAULT_STD,
@@ -109,21 +110,14 @@ def create_transform(
109110
is_training=is_training, size=img_size, interpolation=interpolation)
110111
else:
111112
if is_training:
112-
if True:
113-
transform = transforms_imagenet_aa(
114-
img_size,
115-
interpolation=interpolation,
116-
use_prefetcher=use_prefetcher,
117-
mean=mean,
118-
std=std)
119-
else:
120-
transform = transforms_imagenet_train(
121-
img_size,
122-
color_jitter=color_jitter,
123-
interpolation=interpolation,
124-
use_prefetcher=use_prefetcher,
125-
mean=mean,
126-
std=std)
113+
transform = transforms_imagenet_train(
114+
img_size,
115+
color_jitter=color_jitter,
116+
auto_augment=auto_augment,
117+
interpolation=interpolation,
118+
use_prefetcher=use_prefetcher,
119+
mean=mean,
120+
std=std)
127121
else:
128122
transform = transforms_imagenet_eval(
129123
img_size,
@@ -146,6 +140,7 @@ def create_loader(
146140
rand_erase_mode='const',
147141
rand_erase_count=1,
148142
color_jitter=0.4,
143+
auto_augment=None,
149144
interpolation='bilinear',
150145
mean=IMAGENET_DEFAULT_MEAN,
151146
std=IMAGENET_DEFAULT_STD,
@@ -161,6 +156,7 @@ def create_loader(
161156
is_training=is_training,
162157
use_prefetcher=use_prefetcher,
163158
color_jitter=color_jitter,
159+
auto_augment=auto_augment,
164160
interpolation=interpolation,
165161
mean=mean,
166162
std=std,

timm/data/transforms.py

Lines changed: 25 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from .constants import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
1111
from .random_erasing import RandomErasing
12-
from .auto_augment import AutoAugment, auto_augment_policy_v0
12+
from .auto_augment import AutoAugment, auto_augment_policy
1313

1414

1515
class ToNumpy:
@@ -57,10 +57,10 @@ def _pil_interp(method):
5757
return Image.BILINEAR
5858

5959

60-
RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
60+
_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
6161

6262

63-
class RandomResizedCropAndInterpolation(object):
63+
class RandomResizedCropAndInterpolation:
6464
"""Crop the given PIL Image to random size and aspect ratio with random interpolation.
6565
6666
A crop of random size (default: of 0.08 to 1.0) of the original size and a random
@@ -85,7 +85,7 @@ def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
8585
warnings.warn("range should be of kind (min, max)")
8686

8787
if interpolation == 'random':
88-
self.interpolation = RANDOM_INTERPOLATION
88+
self.interpolation = _RANDOM_INTERPOLATION
8989
else:
9090
self.interpolation = _pil_interp(interpolation)
9191
self.scale = scale
@@ -161,73 +161,42 @@ def __repr__(self):
161161
return format_string
162162

163163

164-
def transforms_imagenet_aa(
165-
img_size=224,
166-
scale=(0.08, 1.0),
167-
interpolation='random',
168-
random_erasing=0.4,
169-
random_erasing_mode='const',
170-
use_prefetcher=False,
171-
mean=IMAGENET_DEFAULT_MEAN,
172-
std=IMAGENET_DEFAULT_STD
173-
):
174-
aa_params = dict(
175-
cutout_max_pad_fraction=0.75,
176-
cutout_const=100,
177-
translate_const=img_size[-1] // 2 - 1,
178-
img_mean=tuple([min(255, round(255*x)) for x in mean]),
179-
)
180-
if interpolation and interpolation != 'random':
181-
aa_params['interpolation'] = _pil_interp(interpolation)
182-
aa_policy = auto_augment_policy_v0(aa_params)
183-
184-
tfl = [
185-
RandomResizedCropAndInterpolation(
186-
img_size, scale=scale, interpolation=interpolation),
187-
transforms.RandomHorizontalFlip(),
188-
AutoAugment(aa_policy)
189-
]
190-
191-
if use_prefetcher:
192-
# prefetcher and collate will handle tensor conversion and norm
193-
tfl += [ToNumpy()]
194-
else:
195-
tfl += [
196-
transforms.ToTensor(),
197-
transforms.Normalize(
198-
mean=torch.tensor(mean),
199-
std=torch.tensor(std))
200-
]
201-
if random_erasing > 0.:
202-
tfl.append(RandomErasing(random_erasing, mode=random_erasing_mode, device='cpu'))
203-
return transforms.Compose(tfl)
204-
205-
206164
def transforms_imagenet_train(
207165
img_size=224,
208166
scale=(0.08, 1.0),
209167
color_jitter=0.4,
168+
auto_augment=None,
210169
interpolation='random',
211170
random_erasing=0.4,
212171
random_erasing_mode='const',
213172
use_prefetcher=False,
214173
mean=IMAGENET_DEFAULT_MEAN,
215174
std=IMAGENET_DEFAULT_STD
216175
):
217-
if isinstance(color_jitter, (list, tuple)):
218-
# color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
219-
# or 4 if also augmenting hue
220-
assert len(color_jitter) in (3, 4)
221-
else:
222-
# if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
223-
color_jitter = (float(color_jitter),) * 3
224-
225176
tfl = [
226177
RandomResizedCropAndInterpolation(
227178
img_size, scale=scale, interpolation=interpolation),
228-
transforms.RandomHorizontalFlip(),
229-
transforms.ColorJitter(*color_jitter),
179+
transforms.RandomHorizontalFlip()
230180
]
181+
if auto_augment:
182+
aa_params = dict(
183+
translate_const=img_size[-1] // 2 - 1,
184+
img_mean=tuple([min(255, round(255 * x)) for x in mean]),
185+
)
186+
if interpolation and interpolation != 'random':
187+
aa_params['interpolation'] = _pil_interp(interpolation)
188+
aa_policy = auto_augment_policy(auto_augment, aa_params)
189+
tfl += [AutoAugment(aa_policy)]
190+
else:
191+
# color jitter is enabled when not using AA
192+
if isinstance(color_jitter, (list, tuple)):
193+
# color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
194+
# or 4 if also augmenting hue
195+
assert len(color_jitter) in (3, 4)
196+
else:
197+
# if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
198+
color_jitter = (float(color_jitter),) * 3
199+
tfl += [transforms.ColorJitter(*color_jitter)]
231200

232201
if use_prefetcher:
233202
# prefetcher and collate will handle tensor conversion and norm

train.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@
8989
# Augmentation parameters
9090
parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
9191
help='Color jitter factor (default: 0.4)')
92+
parser.add_argument('--aa', type=str, default=None, metavar='NAME',
93+
help='Use AutoAugment policy. "v0" or "original". (default: None)'),
9294
parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
9395
help='Random erase prob (default: 0.)')
9496
parser.add_argument('--remode', type=str, default='const',
@@ -287,6 +289,7 @@ def main():
287289
rand_erase_mode=args.remode,
288290
rand_erase_count=args.recount,
289291
color_jitter=args.color_jitter,
292+
auto_augment=args.aa,
290293
interpolation='random', # FIXME cleanly resolve this? data_config['interpolation'],
291294
mean=data_config['mean'],
292295
std=data_config['std'],

0 commit comments

Comments
 (0)