|
| 1 | +import random |
| 2 | +import math |
| 3 | +from torchvision import transforms |
| 4 | +from PIL import Image, ImageOps, ImageEnhance, ImageChops, ImageDraw |
| 5 | +import PIL |
| 6 | +import numpy as np |
| 7 | + |
| 8 | + |
# (major, minor) of the installed PIL package; compared against version tuples
# below to decide which keyword arguments the image methods are given.
_PIL_VER = tuple(int(part) for part in PIL.__version__.split('.')[:2])

# Fallback fill colour used when the hyper-parameters carry no 'img_mean'.
_FILL = (128, 128, 128)

# This signifies the max integer that the controller RNN could predict for the
# augmentation scheme. Magnitudes are divided by this value before scaling.
_MAX_LEVEL = 10.0

# Hyper-parameters used by auto_augment_policy_v0 when the caller passes none.
_HPARAMS_DEFAULT = {
    'translate_const': 250,
    'img_mean': _FILL,
}

# Candidate resampling filters; _interpolation picks one at random per call
# when a list/tuple such as this one is supplied as 'resample'.
_RANDOM_INTERPOLATION = (Image.NEAREST, Image.BILINEAR, Image.BICUBIC)
| 23 | + |
| 24 | + |
def _interpolation(kwargs):
    """Remove and resolve the 'resample' entry of ``kwargs``.

    Args:
        kwargs: Mutable mapping of transform options; its 'resample' key
            (if any) is popped.

    Returns:
        A random element when the popped value is a list or tuple, the
        popped value itself otherwise, or ``Image.NEAREST`` when the key
        was absent.
    """
    chosen = kwargs.pop('resample', Image.NEAREST)
    return random.choice(chosen) if isinstance(chosen, (list, tuple)) else chosen
| 31 | + |
| 32 | + |
def _check_args_tf(kwargs):
    """Normalise transform keyword arguments in place.

    Drops 'fillcolor' when the installed PIL is older than 5.0, then replaces
    'resample' with the single value resolved by ``_interpolation``.

    Args:
        kwargs: Mutable mapping of options that will be forwarded to a PIL
            image method; modified in place.
    """
    if _PIL_VER < (5, 0) and 'fillcolor' in kwargs:
        del kwargs['fillcolor']
    resolved = _interpolation(kwargs)
    kwargs['resample'] = resolved
| 37 | + |
| 38 | + |
def shear_x(img, factor, **kwargs):
    """Apply an affine transform with ``factor`` as the x-row shear term.

    Args:
        img: Object exposing ``size`` and ``transform`` (presumably a PIL
            image -- confirm against callers).
        factor: Value placed in the second slot of the affine coefficients.
        **kwargs: Options normalised by ``_check_args_tf`` and forwarded to
            ``img.transform``.

    Returns:
        Whatever ``img.transform`` returns.
    """
    _check_args_tf(kwargs)
    coefficients = (1, factor, 0, 0, 1, 0)
    return img.transform(img.size, Image.AFFINE, coefficients, **kwargs)
| 42 | + |
| 43 | + |
def shear_y(img, factor, **kwargs):
    """Apply an affine transform with ``factor`` as the y-row shear term.

    Args:
        img: Object exposing ``size`` and ``transform`` (presumably a PIL
            image -- confirm against callers).
        factor: Value placed in the fourth slot of the affine coefficients.
        **kwargs: Options normalised by ``_check_args_tf`` and forwarded to
            ``img.transform``.

    Returns:
        Whatever ``img.transform`` returns.
    """
    _check_args_tf(kwargs)
    coefficients = (1, 0, 0, factor, 1, 0)
    return img.transform(img.size, Image.AFFINE, coefficients, **kwargs)
| 47 | + |
| 48 | + |
def translate_x_rel(img, pct, **kwargs):
    """Translate along x by a fraction of the image width.

    Args:
        img: Object exposing ``size`` and ``transform`` (presumably a PIL
            image -- confirm against callers).
        pct: Multiplier applied to ``img.size[0]`` to obtain the offset.
        **kwargs: Options normalised by ``_check_args_tf`` and forwarded to
            ``img.transform``.

    Returns:
        Whatever ``img.transform`` returns.
    """
    offset = pct * img.size[0]
    _check_args_tf(kwargs)
    coefficients = (1, 0, offset, 0, 1, 0)
    return img.transform(img.size, Image.AFFINE, coefficients, **kwargs)
| 53 | + |
| 54 | + |
def translate_y_rel(img, pct, **kwargs):
    """Translate along y by a fraction of the image height.

    Args:
        img: Object exposing ``size`` and ``transform`` (presumably a PIL
            image -- confirm against callers).
        pct: Multiplier applied to ``img.size[1]`` to obtain the offset.
        **kwargs: Options normalised by ``_check_args_tf`` and forwarded to
            ``img.transform``.

    Returns:
        Whatever ``img.transform`` returns.
    """
    offset = pct * img.size[1]
    _check_args_tf(kwargs)
    coefficients = (1, 0, 0, 0, 1, offset)
    return img.transform(img.size, Image.AFFINE, coefficients, **kwargs)
| 59 | + |
| 60 | + |
def translate_x_abs(img, pixels, **kwargs):
    """Translate along x by an absolute offset.

    Args:
        img: Object exposing ``size`` and ``transform`` (presumably a PIL
            image -- confirm against callers).
        pixels: Value placed in the x-offset slot of the affine coefficients.
        **kwargs: Options normalised by ``_check_args_tf`` and forwarded to
            ``img.transform``.

    Returns:
        Whatever ``img.transform`` returns.
    """
    _check_args_tf(kwargs)
    coefficients = (1, 0, pixels, 0, 1, 0)
    return img.transform(img.size, Image.AFFINE, coefficients, **kwargs)
| 64 | + |
| 65 | + |
def translate_y_abs(img, pixels, **kwargs):
    """Translate along y by an absolute offset.

    Args:
        img: Object exposing ``size`` and ``transform`` (presumably a PIL
            image -- confirm against callers).
        pixels: Value placed in the y-offset slot of the affine coefficients.
        **kwargs: Options normalised by ``_check_args_tf`` and forwarded to
            ``img.transform``.

    Returns:
        Whatever ``img.transform`` returns.
    """
    _check_args_tf(kwargs)
    coefficients = (1, 0, 0, 0, 1, pixels)
    return img.transform(img.size, Image.AFFINE, coefficients, **kwargs)
| 69 | + |
| 70 | + |
def rotate(img, degrees, **kwargs):
    """Rotate ``img`` by ``degrees``, choosing a code path by PIL version.

    * PIL >= 5.2: ``img.rotate`` receives every normalised keyword argument.
    * 5.0 <= PIL < 5.2: the rotation about the image centre is built by hand
      as an affine matrix and applied with ``img.transform``.
    * PIL < 5.0: ``img.rotate`` receives only the resolved 'resample' value.

    Args:
        img: Object exposing ``size``, ``rotate`` and ``transform``
            (presumably a PIL image -- confirm against callers).
        degrees: Rotation angle in degrees.
        **kwargs: Options normalised by ``_check_args_tf``.

    Returns:
        Whatever the selected PIL call returns.
    """
    _check_args_tf(kwargs)

    if _PIL_VER >= (5, 2):
        return img.rotate(degrees, **kwargs)
    if _PIL_VER < (5, 0):
        return img.rotate(degrees, resample=kwargs['resample'])

    # Remaining case: 5.0 <= PIL < 5.2.
    width, height = img.size
    shift_x, shift_y = 0, 0  # no post-rotation translation
    centre_x, centre_y = width / 2.0, height / 2.0
    theta = -math.radians(degrees)
    affine = [
        round(math.cos(theta), 15),
        round(math.sin(theta), 15),
        0.0,
        round(-math.sin(theta), 15),
        round(math.cos(theta), 15),
        0.0,
    ]

    # Map the negated centre through the rotation, then shift back by the
    # centre so the rotation pivots on the middle of the image.
    a, b, c, d, e, f = affine
    src_x = -centre_x - shift_x
    src_y = -centre_y - shift_y
    affine[2] = a * src_x + b * src_y + c
    affine[5] = d * src_x + e * src_y + f
    affine[2] += centre_x
    affine[5] += centre_y
    return img.transform(img.size, Image.AFFINE, affine, **kwargs)
| 101 | + |
| 102 | + |
def auto_contrast(img, **__):
    """Return ``ImageOps.autocontrast(img)``; extra keyword arguments are ignored."""
    result = ImageOps.autocontrast(img)
    return result
| 105 | + |
| 106 | + |
def invert(img, **__):
    """Return ``ImageOps.invert(img)``; extra keyword arguments are ignored."""
    result = ImageOps.invert(img)
    return result
| 109 | + |
| 110 | + |
def equalize(img, **__):
    """Return ``ImageOps.equalize(img)``; extra keyword arguments are ignored."""
    result = ImageOps.equalize(img)
    return result
| 113 | + |
| 114 | + |
def solarize(img, thresh, **__):
    """Return ``ImageOps.solarize(img, thresh)``.

    Args:
        img: Image passed straight to ``ImageOps.solarize``.
        thresh: Threshold passed as the second positional argument.
        **__: Ignored.
    """
    result = ImageOps.solarize(img, thresh)
    return result
| 117 | + |
| 118 | + |
def solarize_add(img, add, thresh=128, **__):
    """Add ``add`` to every value below ``thresh``, capped at 255.

    Args:
        img: Object exposing ``mode`` and ``point`` (presumably a PIL image
            -- confirm against callers).
        add: Amount added to values strictly below ``thresh``.
        thresh: Values at or above this are left unchanged.
        **__: Ignored.

    Returns:
        ``img.point(lut)`` for modes "L" and "RGB" (the 256-entry table is
        repeated three times for "RGB"); ``img`` itself for any other mode.
    """
    # The table is built before the mode check, as in the original flow.
    table = [min(255, value + add) if value < thresh else value for value in range(256)]
    if img.mode not in ("L", "RGB"):
        return img
    if img.mode == "RGB":
        table = table * 3  # one copy of the table per band
    return img.point(table)
| 132 | + |
| 133 | + |
def posterize(img, bits, **__):
    """Return ``ImageOps.posterize(img, 4 - bits)``.

    Args:
        img: Image passed straight to ``ImageOps.posterize``.
        bits: Subtracted from 4 to give the bit count handed to PIL.
        **__: Ignored.
    """
    # NOTE(review): bits == 4 passes 0 to ImageOps.posterize -- confirm that
    # this extreme is intended.
    keep_bits = 4 - bits
    return ImageOps.posterize(img, keep_bits)
| 136 | + |
| 137 | + |
def contrast(img, factor, **__):
    """Return ``ImageEnhance.Contrast(img).enhance(factor)``; extra kwargs are ignored."""
    enhancer = ImageEnhance.Contrast(img)
    return enhancer.enhance(factor)
| 140 | + |
| 141 | + |
def color(img, factor, **__):
    """Return ``ImageEnhance.Color(img).enhance(factor)``; extra kwargs are ignored."""
    enhancer = ImageEnhance.Color(img)
    return enhancer.enhance(factor)
| 144 | + |
| 145 | + |
def brightness(img, factor, **__):
    """Return ``ImageEnhance.Brightness(img).enhance(factor)``; extra kwargs are ignored."""
    enhancer = ImageEnhance.Brightness(img)
    return enhancer.enhance(factor)
| 148 | + |
| 149 | + |
def sharpness(img, factor, **__):
    """Return ``ImageEnhance.Sharpness(img).enhance(factor)``; extra kwargs are ignored."""
    enhancer = ImageEnhance.Sharpness(img)
    return enhancer.enhance(factor)
| 152 | + |
| 153 | + |
| 154 | +def _randomly_negate(v): |
| 155 | + """With 50% prob, negate the value""" |
| 156 | + return -v if random.random() > 0.5 else v |
| 157 | + |
| 158 | + |
def _rotate_level_to_arg(level):
    """Convert a magnitude into a one-element tuple holding a rotation angle.

    The level is scaled so that ``_MAX_LEVEL`` maps to 30, then its sign is
    flipped at random by ``_randomly_negate``.
    """
    degrees = _randomly_negate((level / _MAX_LEVEL) * 30.)
    return (degrees,)
| 163 | + |
| 164 | + |
def _enhance_level_to_arg(level):
    """Convert a magnitude into a one-element tuple holding an enhancement factor.

    The factor is ``(level / _MAX_LEVEL) * 1.8 + 0.1``.
    """
    factor = (level / _MAX_LEVEL) * 1.8 + 0.1
    return (factor,)
| 167 | + |
| 168 | + |
def _shear_level_to_arg(level):
    """Convert a magnitude into a one-element tuple holding a shear factor.

    The level is scaled so that ``_MAX_LEVEL`` maps to 0.3, then its sign is
    flipped at random by ``_randomly_negate``.
    """
    shear = _randomly_negate((level / _MAX_LEVEL) * 0.3)
    return (shear,)
| 173 | + |
| 174 | + |
def _translate_abs_level_to_arg(level, translate_const):
    """Convert a magnitude into a one-element tuple holding an absolute offset.

    Args:
        level: Magnitude; divided by ``_MAX_LEVEL`` before scaling.
        translate_const: Offset corresponding to ``_MAX_LEVEL``; converted
            with ``float`` before use.

    Returns:
        ``(offset,)`` with the sign flipped at random by ``_randomly_negate``.
    """
    offset = _randomly_negate((level / _MAX_LEVEL) * float(translate_const))
    return (offset,)
| 179 | + |
| 180 | + |
def _translate_rel_level_to_arg(level):
    """Convert a magnitude into a one-element tuple holding a relative offset.

    The level is scaled so that ``_MAX_LEVEL`` maps to 0.45, then its sign is
    flipped at random by ``_randomly_negate``.
    """
    fraction = _randomly_negate((level / _MAX_LEVEL) * 0.45)
    return (fraction,)
| 185 | + |
| 186 | + |
def level_to_arg(hparams):
    """Build the table mapping each op name to its magnitude converter.

    Args:
        hparams: Mapping read for 'translate_const' by the 'TranslateX' and
            'TranslateY' converters. The lookup happens when a converter is
            called, not when this table is built.

    Returns:
        Dict from op name to a callable taking ``level`` and returning the
        tuple of positional arguments for the matching op.
    """

    def _no_args(level):
        return ()

    def _posterize_args(level):
        return (int((level / _MAX_LEVEL) * 4),)

    def _solarize_args(level):
        return (int((level / _MAX_LEVEL) * 256),)

    def _solarize_add_args(level):
        return (int((level / _MAX_LEVEL) * 110),)

    def _translate_abs_args(level):
        return _translate_abs_level_to_arg(level, hparams['translate_const'])

    def _translate_rel_args(level):
        return _translate_rel_level_to_arg(level)

    return {
        'AutoContrast': _no_args,
        'Equalize': _no_args,
        'Invert': _no_args,
        'Rotate': _rotate_level_to_arg,
        'Posterize': _posterize_args,
        'Solarize': _solarize_args,
        'SolarizeAdd': _solarize_add_args,
        'Color': _enhance_level_to_arg,
        'Contrast': _enhance_level_to_arg,
        'Brightness': _enhance_level_to_arg,
        'Sharpness': _enhance_level_to_arg,
        'ShearX': _shear_level_to_arg,
        'ShearY': _shear_level_to_arg,
        'TranslateX': _translate_abs_args,
        'TranslateY': _translate_abs_args,
        'TranslateXRel': _translate_rel_args,
        'TranslateYRel': _translate_rel_args,
    }
| 207 | + |
| 208 | + |
# Op name -> function implementing it. AutoAugmentOp looks names up here, and
# the keys match those of the table returned by level_to_arg.
NAME_TO_OP = dict(
    AutoContrast=auto_contrast,
    Equalize=equalize,
    Invert=invert,
    Rotate=rotate,
    Posterize=posterize,
    Solarize=solarize,
    SolarizeAdd=solarize_add,
    Color=color,
    Contrast=contrast,
    Brightness=brightness,
    Sharpness=sharpness,
    ShearX=shear_x,
    ShearY=shear_y,
    TranslateX=translate_x_abs,
    TranslateY=translate_y_abs,
    TranslateXRel=translate_x_rel,
    TranslateYRel=translate_y_rel,
)
| 228 | + |
| 229 | + |
class AutoAugmentOp:
    """One augmentation op applied with a given probability and magnitude.

    Args:
        name: Key into ``NAME_TO_OP`` and into the table from ``level_to_arg``.
        prob: Probability of applying the op on a call.
        magnitude: Nominal magnitude; clamped to ``[0, _MAX_LEVEL]`` at call
            time.
        hparams: Optional mapping. 'img_mean' sets the fill colour and
            'interpolation' sets the resample choice; both fall back to
            module defaults. ``None`` is treated as an empty mapping.

    Raises:
        KeyError: If ``name`` is not a known op.
    """

    def __init__(self, name, prob, magnitude, hparams=None):
        # None replaces the former mutable ``{}`` default argument.
        hparams = {} if hparams is None else hparams
        self.aug_fn = NAME_TO_OP[name]
        self.level_fn = level_to_arg(hparams)[name]
        self.prob = prob
        self.magnitude = magnitude
        self.kwargs = {
            'fillcolor': hparams.get('img_mean', _FILL),
            'resample': hparams.get('interpolation', _RANDOM_INTERPOLATION),
        }
        self.rand_magnitude = True

    def __call__(self, img):
        """Return ``img`` unchanged, or the op applied to it.

        The op is skipped when a uniform draw is at or above ``self.prob``,
        so ``prob == 0.0`` never applies it and ``prob == 1.0`` always does.
        """
        # `>=` closes the edge case where a draw of exactly 0.0 used to apply
        # an op whose probability is 0.0.
        if random.random() >= self.prob:
            return img
        magnitude = self.magnitude
        if self.rand_magnitude:
            # Jitter the magnitude around its nominal value.
            magnitude = random.normalvariate(magnitude, 0.5)
        # Keep the magnitude inside the range the level converters expect.
        magnitude = min(_MAX_LEVEL, max(0, magnitude))
        level_args = self.level_fn(magnitude)
        return self.aug_fn(img, *level_args, **self.kwargs)
| 252 | + |
| 253 | + |
def auto_augment_policy_v0(hparams=_HPARAMS_DEFAULT):
    """Autoaugment policy that was used in AutoAugment Paper.

    Args:
        hparams: Hyper-parameter mapping handed to every ``AutoAugmentOp``.

    Returns:
        A list of sub-policies, each a list of ``AutoAugmentOp`` instances.
    """
    # Every tuple describes one augmentation operation as
    # (operation, probability, magnitude). Every inner list is a sub-policy
    # whose operations are applied to the image one after another.
    policy = [
        [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
        [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
        [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
        [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
        [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
        [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
        [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
        [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
        [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
        [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
        [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
        [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
        [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
        [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
        [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
        [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
        [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
        [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
        [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
        [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
        [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
        [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
        [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],
        [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
        [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
    ]
    compiled = []
    for sub_policy in policy:
        ops = []
        for name, prob, magnitude in sub_policy:
            ops.append(AutoAugmentOp(name, prob, magnitude, hparams))
        compiled.append(ops)
    return compiled
| 288 | + |
| 289 | + |
class AutoAugment:
    """Callable that applies one randomly chosen sub-policy to an image.

    Args:
        policy: Sequence of sub-policies; each sub-policy is an iterable of
            callables that take an image and return an image.
    """

    def __init__(self, policy):
        self.policy = policy

    def __call__(self, img):
        """Pick a sub-policy with ``random.choice`` and run its ops in order."""
        chosen = random.choice(self.policy)
        result = img
        for transform_op in chosen:
            result = transform_op(result)
        return result
0 commit comments