Skip to content

Commit 25d2088

Browse files
committed
Working on auto-augment
1 parent aff194f commit 25d2088

File tree

3 files changed

+357
-7
lines changed

3 files changed

+357
-7
lines changed

timm/data/auto_augment.py

Lines changed: 299 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,299 @@
1+
import random
2+
import math
3+
from torchvision import transforms
4+
from PIL import Image, ImageOps, ImageEnhance, ImageChops, ImageDraw
5+
import PIL
6+
import numpy as np
7+
8+
9+
# (major, minor) of the installed Pillow, used to gate version-dependent keyword arguments.
_PIL_VER = tuple(int(part) for part in PIL.__version__.split('.')[:2])

# Fallback RGB value used for the `fillcolor` keyword and as the default `img_mean`.
_FILL = (128, 128, 128)

# This signifies the max integer that the controller RNN could predict for the
# augmentation scheme.
_MAX_LEVEL = 10.0

# Hyper-parameters used when a policy builder is called without explicit ones.
_HPARAMS_DEFAULT = {
    'translate_const': 250,
    'img_mean': _FILL,
}

# Candidate resampling filters; one is picked per call when no fixed filter is configured.
_RANDOM_INTERPOLATION = (Image.NEAREST, Image.BILINEAR, Image.BICUBIC)
23+
24+
25+
def _interpolation(kwargs):
    """Pop ``'resample'`` from ``kwargs`` and resolve it to a single filter.

    Falls back to ``Image.NEAREST`` when the key is absent. A list or tuple
    is treated as a set of candidates and one entry is chosen at random;
    any other value is returned unchanged.
    """
    choice = kwargs.pop('resample', Image.NEAREST)
    if not isinstance(choice, (list, tuple)):
        return choice
    return random.choice(choice)
31+
32+
33+
def _check_args_tf(kwargs):
    """Normalise transform keyword arguments in place.

    Removes ``fillcolor`` when the installed Pillow is older than 5.0 and
    replaces ``resample`` with the single filter chosen by ``_interpolation``.
    """
    if _PIL_VER < (5, 0) and 'fillcolor' in kwargs:
        del kwargs['fillcolor']
    kwargs['resample'] = _interpolation(kwargs)
37+
38+
39+
def shear_x(img, factor, **kwargs):
    """Affine-transform ``img`` with coefficients ``(1, factor, 0, 0, 1, 0)``."""
    _check_args_tf(kwargs)
    coeffs = (1, factor, 0, 0, 1, 0)
    return img.transform(img.size, Image.AFFINE, coeffs, **kwargs)
42+
43+
44+
def shear_y(img, factor, **kwargs):
    """Affine-transform ``img`` with coefficients ``(1, 0, 0, factor, 1, 0)``."""
    _check_args_tf(kwargs)
    coeffs = (1, 0, 0, factor, 1, 0)
    return img.transform(img.size, Image.AFFINE, coeffs, **kwargs)
47+
48+
49+
def translate_x_rel(img, pct, **kwargs):
    """Translate along x by ``pct`` times the image width (``img.size[0]``)."""
    offset = pct * img.size[0]
    _check_args_tf(kwargs)
    coeffs = (1, 0, offset, 0, 1, 0)
    return img.transform(img.size, Image.AFFINE, coeffs, **kwargs)
53+
54+
55+
def translate_y_rel(img, pct, **kwargs):
    """Translate along y by ``pct`` times the image height (``img.size[1]``)."""
    offset = pct * img.size[1]
    _check_args_tf(kwargs)
    coeffs = (1, 0, 0, 0, 1, offset)
    return img.transform(img.size, Image.AFFINE, coeffs, **kwargs)
59+
60+
61+
def translate_x_abs(img, pixels, **kwargs):
    """Affine-transform ``img`` with an x offset of ``pixels``."""
    _check_args_tf(kwargs)
    coeffs = (1, 0, pixels, 0, 1, 0)
    return img.transform(img.size, Image.AFFINE, coeffs, **kwargs)
64+
65+
66+
def translate_y_abs(img, pixels, **kwargs):
    """Affine-transform ``img`` with a y offset of ``pixels``."""
    _check_args_tf(kwargs)
    coeffs = (1, 0, 0, 0, 1, pixels)
    return img.transform(img.size, Image.AFFINE, coeffs, **kwargs)
69+
70+
71+
def rotate(img, degrees, **kwargs):
    """Rotate ``img`` by ``degrees``, choosing a code path by Pillow version.

    * Pillow >= 5.2: ``img.rotate`` receives every keyword argument.
    * Pillow 5.0 / 5.1: an affine matrix for a rotation about the image
      centre is built by hand and applied through ``img.transform``.
    * Older Pillow: ``img.rotate`` receives only ``resample``.
    """
    _check_args_tf(kwargs)

    if _PIL_VER >= (5, 2):
        return img.rotate(degrees, **kwargs)

    if _PIL_VER < (5, 0):
        return img.rotate(degrees, resample=kwargs['resample'])

    # Pillow 5.0 / 5.1: build the rotation as an explicit affine transform.
    width, height = img.size
    shift_x, shift_y = 0, 0
    center_x, center_y = width / 2.0, height / 2.0
    theta = -math.radians(degrees)
    coeffs = [
        round(math.cos(theta), 15),
        round(math.sin(theta), 15),
        0.0,
        round(-math.sin(theta), 15),
        round(math.cos(theta), 15),
        0.0,
    ]

    def _apply(x, y, m):
        a, b, c, d, e, f = m
        return a * x + b * y + c, d * x + e * y + f

    # Map the negated centre through the matrix, then add the centre back,
    # so the rotation pivots on the image centre.
    coeffs[2], coeffs[5] = _apply(-center_x - shift_x, -center_y - shift_y, coeffs)
    coeffs[2] += center_x
    coeffs[5] += center_y
    return img.transform(img.size, Image.AFFINE, coeffs, **kwargs)
101+
102+
103+
def auto_contrast(img, **__):
    """Return ``ImageOps.autocontrast(img)``; extra keyword arguments are ignored."""
    result = ImageOps.autocontrast(img)
    return result
105+
106+
107+
def invert(img, **__):
    """Return ``ImageOps.invert(img)``; extra keyword arguments are ignored."""
    result = ImageOps.invert(img)
    return result
109+
110+
111+
def equalize(img, **__):
    """Return ``ImageOps.equalize(img)``; extra keyword arguments are ignored."""
    result = ImageOps.equalize(img)
    return result
113+
114+
115+
def solarize(img, thresh, **__):
    """Return ``ImageOps.solarize(img, thresh)``; extra keyword arguments are ignored."""
    result = ImageOps.solarize(img, thresh)
    return result
117+
118+
119+
def solarize_add(img, add, thresh=128, **__):
    """Add ``add`` (capped at 255) to every pixel value below ``thresh``.

    Only images whose mode is "L" or "RGB" are processed; any other mode is
    returned untouched. For "RGB" the 256-entry table is repeated three
    times before being handed to ``img.point``.
    """
    table = [
        min(255, value + add) if value < thresh else value
        for value in range(256)
    ]
    if img.mode not in ("L", "RGB"):
        return img
    if img.mode == "RGB":
        table = table * 3
    return img.point(table)
132+
133+
134+
def posterize(img, bits, **__):
    """Return ``ImageOps.posterize(img, 4 - bits)``; extra keyword arguments are ignored."""
    # A larger `bits` argument therefore means fewer bits are passed on.
    kept = 4 - bits
    return ImageOps.posterize(img, kept)
136+
137+
138+
def contrast(img, factor, **__):
    """Return ``img`` enhanced by ``ImageEnhance.Contrast`` with ``factor``."""
    enhancer = ImageEnhance.Contrast(img)
    return enhancer.enhance(factor)
140+
141+
142+
def color(img, factor, **__):
    """Return ``img`` enhanced by ``ImageEnhance.Color`` with ``factor``."""
    enhancer = ImageEnhance.Color(img)
    return enhancer.enhance(factor)
144+
145+
146+
def brightness(img, factor, **__):
    """Return ``img`` enhanced by ``ImageEnhance.Brightness`` with ``factor``."""
    enhancer = ImageEnhance.Brightness(img)
    return enhancer.enhance(factor)
148+
149+
150+
def sharpness(img, factor, **__):
    """Return ``img`` enhanced by ``ImageEnhance.Sharpness`` with ``factor``."""
    enhancer = ImageEnhance.Sharpness(img)
    return enhancer.enhance(factor)
152+
153+
154+
def _randomly_negate(v):
    """Return ``-v`` when a fresh ``random.random()`` draw exceeds 0.5, else ``v``."""
    if random.random() > 0.5:
        return -v
    return v
157+
158+
159+
def _rotate_level_to_arg(level):
    """Scale ``level`` to ``level / _MAX_LEVEL * 30`` with a random sign, as a 1-tuple."""
    degrees = (level / _MAX_LEVEL) * 30.
    return (_randomly_negate(degrees),)
163+
164+
165+
def _enhance_level_to_arg(level):
    """Map ``level`` to ``level / _MAX_LEVEL * 1.8 + 0.1`` and return it as a 1-tuple."""
    fraction = level / _MAX_LEVEL
    factor = fraction * 1.8 + 0.1
    return (factor,)
167+
168+
169+
def _shear_level_to_arg(level):
    """Scale ``level`` to ``level / _MAX_LEVEL * 0.3`` with a random sign, as a 1-tuple."""
    shear = (level / _MAX_LEVEL) * 0.3
    return (_randomly_negate(shear),)
173+
174+
175+
def _translate_abs_level_to_arg(level, translate_const):
    """Scale ``level`` to ``level / _MAX_LEVEL * translate_const`` with a random sign, as a 1-tuple."""
    shift = (level / _MAX_LEVEL) * float(translate_const)
    return (_randomly_negate(shift),)
179+
180+
181+
def _translate_rel_level_to_arg(level):
    """Scale ``level`` to ``level / _MAX_LEVEL * 0.45`` with a random sign, as a 1-tuple."""
    fraction = (level / _MAX_LEVEL) * 0.45
    return (_randomly_negate(fraction),)
185+
186+
187+
def level_to_arg(hparams):
    """Return a dict mapping each op name to a ``level -> argument tuple`` converter.

    ``hparams['translate_const']`` is read lazily: the lookup happens only
    when the 'TranslateX' or 'TranslateY' converter is actually called.
    """
    def _no_args(level):
        return ()

    def _scaled_int(span):
        # Converter yielding int(level / _MAX_LEVEL * span) as a 1-tuple.
        def _convert(level):
            return (int((level / _MAX_LEVEL) * span),)
        return _convert

    def _translate_abs(level):
        return _translate_abs_level_to_arg(level, hparams['translate_const'])

    def _translate_rel(level):
        return _translate_rel_level_to_arg(level)

    return {
        'AutoContrast': _no_args,
        'Equalize': _no_args,
        'Invert': _no_args,
        'Rotate': _rotate_level_to_arg,
        'Posterize': _scaled_int(4),
        'Solarize': _scaled_int(256),
        'SolarizeAdd': _scaled_int(110),
        'Color': _enhance_level_to_arg,
        'Contrast': _enhance_level_to_arg,
        'Brightness': _enhance_level_to_arg,
        'Sharpness': _enhance_level_to_arg,
        'ShearX': _shear_level_to_arg,
        'ShearY': _shear_level_to_arg,
        'TranslateX': _translate_abs,
        'TranslateY': _translate_abs,
        'TranslateXRel': _translate_rel,
        'TranslateYRel': _translate_rel,
    }
207+
208+
209+
# Lookup table from policy op name to the function that implements it.
NAME_TO_OP = dict(
    AutoContrast=auto_contrast,
    Equalize=equalize,
    Invert=invert,
    Rotate=rotate,
    Posterize=posterize,
    Solarize=solarize,
    SolarizeAdd=solarize_add,
    Color=color,
    Contrast=contrast,
    Brightness=brightness,
    Sharpness=sharpness,
    ShearX=shear_x,
    ShearY=shear_y,
    TranslateX=translate_x_abs,
    TranslateY=translate_y_abs,
    TranslateXRel=translate_x_rel,
    TranslateYRel=translate_y_rel,
)
228+
229+
230+
class AutoAugmentOp:
    """A single augmentation op applied with a given probability and magnitude.

    Args:
        name: Key into ``NAME_TO_OP`` / ``level_to_arg`` selecting the op.
        prob: Probability of applying the op on a given call.
        magnitude: Nominal level; on each call a value is drawn from a normal
            distribution around it (std 0.5) and clamped to ``[0, _MAX_LEVEL]``.
        hparams: Optional dict of hyper-parameters. ``'img_mean'`` and
            ``'interpolation'`` are read here (falling back to ``_FILL`` and
            ``_RANDOM_INTERPOLATION``); the dict is also handed to
            ``level_to_arg``.

    Raises:
        KeyError: If ``name`` is not a known op.
    """

    def __init__(self, name, prob, magnitude, hparams=None):
        # A None sentinel replaces a `{}` default so no dict object is shared
        # between instances.
        if hparams is None:
            hparams = {}
        self.aug_fn = NAME_TO_OP[name]
        self.level_fn = level_to_arg(hparams)[name]
        self.prob = prob
        self.magnitude = magnitude
        self.kwargs = {
            'fillcolor': hparams.get('img_mean', _FILL),
            'resample': hparams.get('interpolation', _RANDOM_INTERPOLATION),
        }
        self.rand_magnitude = True

    def __call__(self, img):
        """Return ``img`` unchanged or augmented, according to ``self.prob``."""
        # random.random() is in [0, 1), so `>=` makes prob == 0.0 never fire
        # while prob == 1.0 still always fires.
        if random.random() >= self.prob:
            return img
        magnitude = self.magnitude
        if self.rand_magnitude:
            magnitude = random.normalvariate(magnitude, 0.5)
        # Keep the (possibly jittered) level inside the supported range.
        magnitude = min(_MAX_LEVEL, max(0, magnitude))
        level_args = self.level_fn(magnitude)
        # `**self.kwargs` gives the op its own dict, so ops that pop or rewrite
        # entries do not alter the stored defaults.
        return self.aug_fn(img, *level_args, **self.kwargs)
252+
253+
254+
def auto_augment_policy_v0(hparams=_HPARAMS_DEFAULT):
    """Build the "v0" policy as a list of sub-policies of ``AutoAugmentOp`` objects.

    The original docstring describes this as the Autoaugment policy that was
    used in the AutoAugment paper.

    Args:
        hparams: Hyper-parameter dict passed to every ``AutoAugmentOp``.

    Returns:
        A list of sub-policies, each a list of two ``AutoAugmentOp`` instances.
    """
    # Every entry is (operation, probability, magnitude); each inner list is a
    # sub-policy whose ops are applied to the image one after the other.
    sub_policy_specs = [
        [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
        [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
        [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
        [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
        [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
        [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
        [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
        [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
        [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
        [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
        [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
        [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
        [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
        [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
        [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
        [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
        [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
        [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
        [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
        [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
        [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
        [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
        [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],
        [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
        [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
    ]
    return [
        [AutoAugmentOp(*spec, hparams) for spec in sub_policy]
        for sub_policy in sub_policy_specs
    ]
288+
289+
290+
class AutoAugment:
    """Callable that applies one randomly chosen sub-policy to an image.

    Args:
        policy: Sequence of sub-policies, each an iterable of callables that
            take an image and return an image.
    """

    def __init__(self, policy):
        self.policy = policy

    def __call__(self, img):
        """Pick a sub-policy with ``random.choice`` and run its ops in order."""
        chosen = random.choice(self.policy)
        result = img
        for step in chosen:
            result = step(result)
        return result

timm/data/loader.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -109,13 +109,21 @@ def create_transform(
109109
is_training=is_training, size=img_size, interpolation=interpolation)
110110
else:
111111
if is_training:
112-
transform = transforms_imagenet_train(
113-
img_size,
114-
color_jitter=color_jitter,
115-
interpolation=interpolation,
116-
use_prefetcher=use_prefetcher,
117-
mean=mean,
118-
std=std)
112+
if True:
113+
transform = transforms_imagenet_aa(
114+
img_size,
115+
interpolation=interpolation,
116+
use_prefetcher=use_prefetcher,
117+
mean=mean,
118+
std=std)
119+
else:
120+
transform = transforms_imagenet_train(
121+
img_size,
122+
color_jitter=color_jitter,
123+
interpolation=interpolation,
124+
use_prefetcher=use_prefetcher,
125+
mean=mean,
126+
std=std)
119127
else:
120128
transform = transforms_imagenet_eval(
121129
img_size,

timm/data/transforms.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from .constants import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
1111
from .random_erasing import RandomErasing
12+
from .auto_augment import AutoAugment, auto_augment_policy_v0
1213

1314

1415
class ToNumpy:
@@ -160,6 +161,48 @@ def __repr__(self):
160161
return format_string
161162

162163

164+
def transforms_imagenet_aa(
        img_size=224,
        scale=(0.08, 1.0),
        interpolation='random',
        random_erasing=0.4,
        random_erasing_mode='const',
        use_prefetcher=False,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD
):
    """Compose the training transform pipeline that uses the v0 AutoAugment policy.

    Pipeline: random resized crop -> random horizontal flip -> AutoAugment,
    followed by either ``ToNumpy`` (prefetcher path) or
    ``ToTensor`` + ``Normalize`` (+ optional ``RandomErasing``).

    Args:
        img_size: Target size as an int or a tuple/list; for a sequence the
            last entry is used to derive the translation constant.
        scale: Passed to ``RandomResizedCropAndInterpolation``.
        interpolation: ``'random'`` (or a falsy value) leaves the AutoAugment
            ops on their default filter choice; any other value is mapped
            through ``_pil_interp`` and used as a fixed filter.
        random_erasing: Value passed to ``RandomErasing``; the step is added
            only when this is > 0 and the prefetcher is not used.
        random_erasing_mode: ``mode`` argument for ``RandomErasing``.
        use_prefetcher: When true, tensor conversion and normalisation are
            left to the prefetcher/collate step.
        mean: Per-channel mean, used for normalisation and, scaled to 0-255,
            as the AutoAugment ``img_mean``.
        std: Per-channel std used for normalisation.

    Returns:
        A ``transforms.Compose`` holding the assembled pipeline.
    """
    # The default img_size is a plain int, which cannot be indexed; only take
    # the last element when a sequence was supplied.
    if isinstance(img_size, (tuple, list)):
        size_ref = img_size[-1]
    else:
        size_ref = img_size

    aa_params = dict(
        cutout_max_pad_fraction=0.75,
        cutout_const=100,
        translate_const=size_ref // 2 - 1,
        img_mean=tuple(min(255, round(255 * x)) for x in mean),
    )
    if interpolation and interpolation != 'random':
        aa_params['interpolation'] = _pil_interp(interpolation)
    aa_policy = auto_augment_policy_v0(aa_params)

    tfl = [
        RandomResizedCropAndInterpolation(
            img_size, scale=scale, interpolation=interpolation),
        transforms.RandomHorizontalFlip(),
        AutoAugment(aa_policy)
    ]

    if use_prefetcher:
        # prefetcher and collate will handle tensor conversion and norm
        tfl += [ToNumpy()]
    else:
        tfl += [
            transforms.ToTensor(),
            transforms.Normalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std))
        ]
        # RandomErasing is created with device='cpu' and placed after the
        # tensor conversion, so it is only added on this branch.
        if random_erasing > 0.:
            tfl.append(RandomErasing(random_erasing, mode=random_erasing_mode, device='cpu'))
    return transforms.Compose(tfl)
204+
205+
163206
def transforms_imagenet_train(
164207
img_size=224,
165208
scale=(0.08, 1.0),

0 commit comments

Comments
 (0)