1+ """ Auto Augment
2+ Implementation adapted from:
3+ https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
4+ Papers: https://arxiv.org/abs/1805.09501 and https://arxiv.org/abs/1906.11172
5+
6+ Hacked together by Ross Wightman
7+ """
18import random
29import math
3- from torchvision import transforms
4- from PIL import Image , ImageOps , ImageEnhance , ImageChops , ImageDraw
10+ from PIL import Image , ImageOps , ImageEnhance
511import PIL
612import numpy as np
713
@@ -131,8 +137,11 @@ def solarize_add(img, add, thresh=128, **__):
131137 return img
132138
133139
def posterize(img, bits_to_keep, **__):
    """Posterize ``img``, keeping ``bits_to_keep`` bits per channel.

    Args:
        img: image to process; passed straight to ``ImageOps.posterize``.
        bits_to_keep: number of bits to keep for each channel.
        **__: extra keyword arguments, accepted and ignored.

    Returns:
        ``img`` unchanged when ``bits_to_keep`` is 8 or more, otherwise the
        result of ``ImageOps.posterize``.
    """
    if bits_to_keep >= 8:
        # Keeping 8 or more bits is treated as a no-op.
        return img
    bits_to_keep = max(1, bits_to_keep)  # prevent all 0 images
    return ImageOps.posterize(img, bits_to_keep)
136145
137146
138147def contrast (img , factor , ** __ ):
@@ -157,16 +166,19 @@ def _randomly_negate(v):
157166
158167
def _rotate_level_to_arg(level):
    """Map a magnitude ``level`` to a one-element rotation argument tuple.

    The level is scaled by ``30 / _MAX_LEVEL`` and then passed through
    ``_randomly_negate``, giving the range [-30, 30] noted in the source
    (for ``level`` in ``[0, _MAX_LEVEL]``).
    """
    scaled = (level / _MAX_LEVEL) * 30.
    return (_randomly_negate(scaled),)
163173
164174
def _enhance_level_to_arg(level):
    """Map a magnitude ``level`` to a one-element enhancement-factor tuple.

    The factor is ``(level / _MAX_LEVEL) * 1.8 + 0.1``, i.e. the range
    [0.1, 1.9] for ``level`` in ``[0, _MAX_LEVEL]``.
    """
    factor = (level / _MAX_LEVEL) * 1.8 + 0.1
    return (factor,)
167178
168179
def _shear_level_to_arg(level):
    """Map a magnitude ``level`` to a one-element shear argument tuple.

    The level is scaled by ``0.3 / _MAX_LEVEL`` and then passed through
    ``_randomly_negate``, giving the range [-0.3, 0.3] noted in the source
    (for ``level`` in ``[0, _MAX_LEVEL]``).
    """
    scaled = (level / _MAX_LEVEL) * 0.3
    return (_randomly_negate(scaled),)
@@ -179,6 +191,7 @@ def _translate_abs_level_to_arg(level, translate_const):
179191
180192
def _translate_rel_level_to_arg(level):
    """Map a magnitude ``level`` to a one-element relative-translation tuple.

    The level is scaled by ``0.45 / _MAX_LEVEL`` and then passed through
    ``_randomly_negate``, giving the range [-0.45, 0.45] noted in the source
    (for ``level`` in ``[0, _MAX_LEVEL]``).
    """
    scaled = (level / _MAX_LEVEL) * 0.45
    return (_randomly_negate(scaled),)
@@ -190,9 +203,12 @@ def level_to_arg(hparams):
190203 'Equalize' : lambda level : (),
191204 'Invert' : lambda level : (),
192205 'Rotate' : _rotate_level_to_arg ,
193- 'Posterize' : lambda level : (int ((level / _MAX_LEVEL ) * 4 ),),
194- 'Solarize' : lambda level : (int ((level / _MAX_LEVEL ) * 256 ),),
195- 'SolarizeAdd' : lambda level : (int ((level / _MAX_LEVEL ) * 110 ),),
206+ # FIXME these are both different from original impl as I believe there is a bug,
207+ # not sure what is the correct alternative, hence 2 options that look better
208+ 'Posterize' : lambda level : (int ((level / _MAX_LEVEL ) * 4 ) + 4 ,), # range [4, 8]
209+ 'Posterize2' : lambda level : (4 - int ((level / _MAX_LEVEL ) * 4 ),), # range [4, 0]
210+ 'Solarize' : lambda level : (int ((level / _MAX_LEVEL ) * 256 ),), # range [0, 256]
211+ 'SolarizeAdd' : lambda level : (int ((level / _MAX_LEVEL ) * 110 ),), # range [0, 110]
196212 'Color' : _enhance_level_to_arg ,
197213 'Contrast' : _enhance_level_to_arg ,
198214 'Brightness' : _enhance_level_to_arg ,
@@ -212,6 +228,7 @@ def level_to_arg(hparams):
212228 'Invert' : invert ,
213229 'Rotate' : rotate ,
214230 'Posterize' : posterize ,
231+ 'Posterize2' : posterize ,
215232 'Solarize' : solarize ,
216233 'SolarizeAdd' : solarize_add ,
217234 'Color' : color ,
@@ -252,10 +269,8 @@ def __call__(self, img):
252269
253270
254271def auto_augment_policy_v0 (hparams = _HPARAMS_DEFAULT ):
255- """Autoaugment policy that was used in AutoAugment Paper."""
256- # Each tuple is an augmentation operation of the form
257- # (operation, probability, magnitude). Each element in policy is a
258- # sub-policy that will be applied sequentially on the image.
272+ # ImageNet policy from TPU EfficientNet impl, cannot find
273+ # a paper reference.
259274 policy = [
260275 [('Equalize' , 0.8 , 1 ), ('ShearY' , 0.8 , 4 )],
261276 [('Color' , 0.4 , 9 ), ('Equalize' , 0.6 , 3 )],
@@ -287,6 +302,48 @@ def auto_augment_policy_v0(hparams=_HPARAMS_DEFAULT):
287302 return pc
288303
289304
def auto_augment_policy_original(hparams=_HPARAMS_DEFAULT):
    """Build the ImageNet policy from https://arxiv.org/abs/1805.09501.

    Each sub-policy is a list of ``(op name, probability, magnitude)``
    tuples; every tuple is turned into an ``AutoAugmentOp`` that also
    receives ``hparams``.

    Returns:
        A list of sub-policies, each a list of ``AutoAugmentOp`` objects.
    """
    sub_policies = [
        [('Posterize', 0.4, 8), ('Rotate', 0.6, 9)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
        [('Posterize', 0.6, 7), ('Posterize', 0.6, 6)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
        [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
        [('Posterize', 0.8, 5), ('Equalize', 1.0, 2)],
        [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
        [('Equalize', 0.6, 8), ('Posterize', 0.4, 6)],
        [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
        [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
        [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
        [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
        [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
        [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
        [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
    ]
    built = []
    for sub_policy in sub_policies:
        ops = []
        for op_name, prob, magnitude in sub_policy:
            ops.append(AutoAugmentOp(op_name, prob, magnitude, hparams))
        built.append(ops)
    return built
336+
337+
def auto_augment_policy(name='v0', hparams=_HPARAMS_DEFAULT):
    """Build the AutoAugment policy selected by ``name``.

    Args:
        name: ``'original'`` or ``'v0'``.
        hparams: hyper-parameters forwarded to the policy builder.

    Returns:
        The value returned by ``auto_augment_policy_original`` or
        ``auto_augment_policy_v0``.

    Raises:
        AssertionError: if ``name`` is not a known policy name.
    """
    if name == 'original':
        return auto_augment_policy_original(hparams)
    if name == 'v0':
        return auto_augment_policy_v0(hparams)
    # Raise explicitly instead of `assert False`: assert statements are
    # removed under `python -O`, which would make an unknown name silently
    # return None. The exception type and message are unchanged.
    raise AssertionError('Unknown AA policy (%s)' % name)
345+
346+
290347class AutoAugment :
291348
292349 def __init__ (self , policy ):
0 commit comments