Source code for mixmo.augmentations.augmix

"""
Implementation following setups from PuzzleMix authors
According to the seminal code: https://github.com/google-research/augmix/blob/master/cifar.py
This code structure is borrowed from:
https://github.com/ildoonet/pytorch-randaugment/blob/616ef12a5176169b4e1e645728f3dedd1a5a148e/RandAugment/augmentations.py
"""

import random

import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw
import numpy as np
import torch

from mixmo.utils import misc


def ShearX(img, v, myrandom=None):  # [-0.3, 0.3]
    """Shear ``img`` along the x axis by ``v``, with a randomly chosen sign.

    Args:
        img: image exposing ``size`` and ``transform`` (a PIL image in practice).
        v: shear magnitude; must lie in [-0.3, 0.3].
        myrandom: object providing ``random()`` used for the sign flip;
            falls back to the global :mod:`random` module when ``None``.

    Returns:
        The affine-transformed image.
    """
    rng = random if myrandom is None else myrandom
    assert -0.3 <= v <= 0.3
    # Flip the direction roughly half of the time.
    shear = -v if rng.random() > 0.5 else v
    return img.transform(img.size, PIL.Image.AFFINE, (1, shear, 0, 0, 1, 0))
def ShearY(img, v, myrandom=None):  # [-0.3, 0.3]
    """Shear ``img`` along the y axis by ``v``, with a randomly chosen sign.

    Args:
        img: image exposing ``size`` and ``transform`` (a PIL image in practice).
        v: shear magnitude; must lie in [-0.3, 0.3].
        myrandom: object providing ``random()`` used for the sign flip;
            falls back to the global :mod:`random` module when ``None``.

    Returns:
        The affine-transformed image.
    """
    rng = random if myrandom is None else myrandom
    assert -0.3 <= v <= 0.3
    # Flip the direction roughly half of the time.
    shear = -v if rng.random() > 0.5 else v
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, shear, 1, 0))
def TranslateX(img, v, myrandom=None):  # [-150, 150] => percentage: [-0.45, 0.45]
    """Translate ``img`` horizontally by a fraction ``v`` of its width.

    Args:
        img: image exposing ``size`` and ``transform`` (a PIL image in practice).
        v: offset as a fraction of ``img.size[0]``; must lie in [-0.45, 0.45].
        myrandom: object providing ``random()`` used for the sign flip;
            falls back to the global :mod:`random` module when ``None``.

    Returns:
        The affine-transformed image.
    """
    rng = random if myrandom is None else myrandom
    assert -0.45 <= v <= 0.45
    fraction = -v if rng.random() > 0.5 else v
    # Convert the relative offset into pixels along the width.
    offset = fraction * img.size[0]
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, offset, 0, 1, 0))
def TranslateY(img, v, myrandom=None):  # [-150, 150] => percentage: [-0.45, 0.45]
    """Translate ``img`` vertically by a fraction ``v`` of its height.

    Args:
        img: image exposing ``size`` and ``transform`` (a PIL image in practice).
        v: offset as a fraction of ``img.size[1]``; must lie in [-0.45, 0.45].
        myrandom: object providing ``random()`` used for the sign flip;
            falls back to the global :mod:`random` module when ``None``.

    Returns:
        The affine-transformed image.
    """
    rng = random if myrandom is None else myrandom
    assert -0.45 <= v <= 0.45
    fraction = -v if rng.random() > 0.5 else v
    # Convert the relative offset into pixels along the height.
    offset = fraction * img.size[1]
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, offset))
def Rotate(img, v, myrandom=None):  # [-30, 30]
    """Rotate ``img`` by ``v`` degrees, with a randomly chosen direction.

    Args:
        img: image exposing ``rotate`` (a PIL image in practice).
        v: rotation angle in degrees; must lie in [-30, 30].
        myrandom: object providing ``random()`` used for the sign flip;
            falls back to the global :mod:`random` module when ``None``.

    Returns:
        Whatever ``img.rotate`` returns for the chosen angle.
    """
    rng = random if myrandom is None else myrandom
    assert -30 <= v <= 30
    angle = -v if rng.random() > 0.5 else v
    return img.rotate(angle)
def AutoContrast(img, *_args, **_kwargs):
    """Apply ``PIL.ImageOps.autocontrast`` to ``img``.

    Extra positional/keyword arguments (magnitude, ``myrandom``) are accepted
    only to share the common op signature, and are ignored.
    """
    result = PIL.ImageOps.autocontrast(img)
    return result
def Invert(img, *_args, **_kwargs):
    """Apply ``PIL.ImageOps.invert`` to ``img``.

    Extra positional/keyword arguments (magnitude, ``myrandom``) are accepted
    only to share the common op signature, and are ignored.
    """
    result = PIL.ImageOps.invert(img)
    return result
def Equalize(img, *_args, **_kwargs):
    """Apply ``PIL.ImageOps.equalize`` to ``img``.

    Extra positional/keyword arguments (magnitude, ``myrandom``) are accepted
    only to share the common op signature, and are ignored.
    """
    result = PIL.ImageOps.equalize(img)
    return result
def Solarize(img, v, **_kwargs):  # [0, 256]
    """Apply ``PIL.ImageOps.solarize`` to ``img`` with threshold ``v``.

    Args:
        img: image to solarize (a PIL image in practice).
        v: threshold passed to PIL; must lie in [0, 256].
        **_kwargs: ignored (e.g. ``myrandom``); accepted for the common op signature.
    """
    assert 0 <= v <= 256
    threshold = v
    return PIL.ImageOps.solarize(img, threshold)
def Posterize(img, v, **_kwargs):  # [4, 8]
    """Apply ``PIL.ImageOps.posterize`` to ``img`` keeping ``int(v)`` bits.

    Args:
        img: image to posterize (a PIL image in practice).
        v: number of bits (may be a float; truncated with ``int``); must lie in [4, 8].
        **_kwargs: ignored (e.g. ``myrandom``); accepted for the common op signature.
    """
    assert 4 <= v <= 8
    bits = int(v)
    return PIL.ImageOps.posterize(img, bits)
def Contrast(img, v, **_kwargs):  # [0.1,1.9]
    """Adjust the contrast of ``img`` by factor ``v`` via ``PIL.ImageEnhance.Contrast``.

    Args:
        img: image to enhance (a PIL image in practice).
        v: enhancement factor; must lie in [0.1, 1.9].
        **_kwargs: ignored (e.g. ``myrandom``); accepted for the common op signature.
    """
    assert 0.1 <= v <= 1.9
    enhancer = PIL.ImageEnhance.Contrast(img)
    return enhancer.enhance(v)
def Color(img, v, **_kwargs):  # [0.1,1.9]
    """Adjust the color balance of ``img`` by factor ``v`` via ``PIL.ImageEnhance.Color``.

    Args:
        img: image to enhance (a PIL image in practice).
        v: enhancement factor; must lie in [0.1, 1.9].
        **_kwargs: ignored (e.g. ``myrandom``); accepted for the common op signature.
    """
    assert 0.1 <= v <= 1.9
    enhancer = PIL.ImageEnhance.Color(img)
    return enhancer.enhance(v)
def Brightness(img, v, **_kwargs):  # [0.1,1.9]
    """Adjust the brightness of ``img`` by factor ``v`` via ``PIL.ImageEnhance.Brightness``.

    Args:
        img: image to enhance (a PIL image in practice).
        v: enhancement factor; must lie in [0.1, 1.9].
        **_kwargs: ignored (e.g. ``myrandom``); accepted for the common op signature.
    """
    assert 0.1 <= v <= 1.9
    enhancer = PIL.ImageEnhance.Brightness(img)
    return enhancer.enhance(v)
def Sharpness(img, v, **_kwargs):  # [0.1,1.9]
    """Adjust the sharpness of ``img`` by factor ``v`` via ``PIL.ImageEnhance.Sharpness``.

    Args:
        img: image to enhance (a PIL image in practice).
        v: enhancement factor; must lie in [0.1, 1.9].
        **_kwargs: ignored (e.g. ``myrandom``); accepted for the common op signature.
    """
    assert 0.1 <= v <= 1.9
    enhancer = PIL.ImageEnhance.Sharpness(img)
    return enhancer.enhance(v)
def Identity(img, v, **_kwargs):
    """No-op augmentation: hand back ``img`` itself (not a copy).

    ``v`` and any keyword arguments (e.g. ``myrandom``) are accepted only to
    share the common op signature, and are ignored.
    """
    unchanged = img
    return unchanged
def augment_list(include_auto_contrast=False):
    """Build the list of ``(op, minval, maxval)`` triples to sample augmentations from.

    ``minval``/``maxval`` bound the magnitude handed to each op.

    Args:
        include_auto_contrast: when true, ``AutoContrast`` is appended at the end.

    Returns:
        A freshly built list of ``(callable, minval, maxval)`` tuples.
    """
    # Accepted: equalize, posterize, rotate, solarize, shear_x, shear_y, translate_x, translate_y
    # Rejected: autocontrast?, brightness, contrast, color, sharpness
    # NOTE(review): Identity and Invert are also included below although not
    # listed as "Accepted". The trailing numbers appear to be indices in an
    # external op numbering (presumably the RandAugment list this structure
    # is borrowed from) — confirm.
    base_ops = [
        (Identity, 0., 1.0),
        (Equalize, 0, 1),  # 7
        (Posterize, 4, 8),  # 9
        (Rotate, 0, 30),  # 4
        (Solarize, 0, 256),  # 8
        (ShearX, 0., 0.3),  # 0
        (ShearY, 0., 0.3),  # 1
        (TranslateX, 0., 0.45),  # 2
        (TranslateY, 0., 0.45),  # 3
        (Invert, 0, 1),  # 6,
    ]
    optional_ops = [(AutoContrast, 0, 1)] if include_auto_contrast else []  # 5
    return base_ops + optional_ops
def get_value_when_none(value, default_value):
    """Return ``default_value`` when ``value`` is ``None``; otherwise ``value`` unchanged.

    Only ``None`` triggers the fallback: falsy values such as ``0`` or ``""``
    are returned as-is.
    """
    return default_value if value is None else value
class AugMix:
    """AugMix: convex blend of a clean image with a mixture of augmentation chains.

    On each call, ``mixture_width`` chains of ops are sampled from
    ``augment_list``. Each augmented image is preprocessed, and the chains are
    mixed with Dirichlet(1, ..., 1) weights. The mixture is then blended with
    the preprocessed clean image using a Beta(1, 1) coefficient ``m``.

    Attributes:
        mixture_depth: ops per chain; a non-positive value means the depth is
            drawn uniformly from ``[1, _max_random_mixture_depth]`` per chain.
        mixture_width: number of chains mixed together.
        aug_severity: op strength on a ``[0, _max_severity]`` scale; each op
            receives ``minval + (aug_severity / _max_severity) * (maxval - minval)``.
        augment_list: ``(op, minval, maxval)`` triples to sample from.
        seed: forwarded to ``misc.get_random`` / ``misc.get_nprandom`` on every
            call. NOTE(review): if those helpers build a fresh generator from a
            fixed seed, every call repeats the same draws — confirm intent.
    """

    _default_mixture_depth = -1
    _default_mixture_width = 3
    _default_severity = 3  # [0, 30]
    # Severity scale: ops use the fraction aug_severity / _max_severity of their range.
    _max_severity = 30
    # Inclusive upper bound of the random chain depth (used when mixture_depth <= 0).
    # The reference AugMix code draws np.random.randint(1, 4), whose upper bound is
    # exclusive, i.e. a depth in {1, 2, 3}. Python's random.randint is inclusive,
    # so the bound must be 3 here, not 4.
    _max_random_mixture_depth = 3

    def __init__(self, seed=None, mixture_depth=None, mixture_width=None,
                 aug_severity=None, include_auto_contrast=False):
        """Store the mixing hyper-parameters; ``None`` selects the class default.

        Raises:
            ValueError: if ``aug_severity`` lies outside ``[0, _max_severity]``.
                Such values would otherwise only fail later, and only randomly,
                when an op's range check trips (or silently under ``python -O``).
        """
        self.mixture_depth = get_value_when_none(
            mixture_depth, self._default_mixture_depth)
        self.mixture_width = get_value_when_none(
            mixture_width, self._default_mixture_width)
        self.aug_severity = get_value_when_none(
            aug_severity, self._default_severity)
        if not 0 <= self.aug_severity <= self._max_severity:
            raise ValueError(
                f"aug_severity must be in [0, {self._max_severity}], "
                f"got {self.aug_severity}")
        self.augment_list = augment_list(
            include_auto_contrast=include_auto_contrast)
        self.seed = seed

    def __call__(self, image, preprocess):
        """Return the AugMix-augmented, preprocessed version of ``image``.

        Args:
            image: source image; ``.copy()`` is called on it and it is fed to
                the PIL-based ops (presumably a ``PIL.Image`` — confirm with callers).
            preprocess: callable turning an image into a tensor. The clean image
                is preprocessed only once, which assumes ``preprocess`` is
                deterministic (e.g. ToTensor + Normalize) — TODO confirm.

        Returns:
            ``(1 - m) * preprocess(image) + m * sum_i ws[i] * preprocess(chain_i(image))``.
        """
        myrandom = misc.get_random(self.seed)
        mynprandom = misc.get_nprandom(self.seed)

        # Chain mixing weights and the clean/mixture blending coefficient.
        ws = np.float32(mynprandom.dirichlet([1] * self.mixture_width))
        m = np.float32(mynprandom.beta(1, 1))

        clean = preprocess(image)
        mix = torch.zeros_like(clean)
        for weight in ws:
            image_aug = image.copy()
            if self.mixture_depth > 0:
                depth = self.mixture_depth
            else:
                depth = myrandom.randint(1, self._max_random_mixture_depth)
            for _ in range(depth):
                x = mynprandom.choice(range(len(self.augment_list)))
                op, minval, maxval = self.augment_list[x]
                val = ((float(self.aug_severity) / self._max_severity)
                       * float(maxval - minval) + minval)
                image_aug = op(image_aug, val, myrandom=myrandom)
            # Preprocessing commutes since all coefficients are convex
            mix += weight * preprocess(image_aug)
        return (1 - m) * clean + m * mix