Source code for mixmo.augmentations.standard_augmentations

"""
Basic augmentation procedures
"""

from torchvision import transforms
from mixmo.utils.logger import get_logger

# Module-level logger, obtained from the project's get_logger helper with level "DEBUG".
LOGGER = get_logger(__name__, level="DEBUG")


class CustomCompose(transforms.Compose):
    """
    Compose variant whose last transform can be deferred to the caller.

    Called with ``apply_postprocessing=True`` (the default), every transform is
    applied in order and the final result is returned. Called with
    ``apply_postprocessing=False``, every transform except the last is applied
    and a dict is returned holding the intermediate image under ``"pixels"``
    and the not-yet-applied last transform under ``"postprocessing"``.
    """

    def __call__(self, img, apply_postprocessing=True):
        # Full pipeline requested, or nothing to hold back: run every step.
        if apply_postprocessing or not self.transforms:
            for step in self.transforms:
                img = step(img)
            return img

        # Run all steps but the last one and hand that one back untouched.
        *leading_steps, final_step = self.transforms
        for step in leading_steps:
            img = step(img)
        return {"pixels": img, "postprocessing": final_step}
# Normalization statistics passed to transforms.Normalize for dataset names
# starting with "cifar" (see get_default_composed_augmentations).
# NOTE(review): three values each, presumably one per image channel; confirm the
# channel order against the data loader.
cifar_mean = (0.4913725490196078, 0.4823529411764706, 0.4466666666666667)
cifar_std = (0.2023, 0.1994, 0.2010)
def get_default_composed_augmentations(dataset_name):
    """
    Build the default train and test transform pipelines for a dataset.

    Every returned pipeline is a ``CustomCompose`` whose last step is the
    postprocessing (``ToTensor`` followed by ``Normalize``), so callers can
    pass ``apply_postprocessing=False`` to either pipeline.

    Args:
        dataset_name: Name of the dataset. Names starting with ``"cifar"`` or
            ``"tinyimagenet"`` are supported.

    Returns:
        A ``(train_transformer, test_transformer)`` tuple.

    Raises:
        ValueError: If ``dataset_name`` matches none of the supported prefixes.
    """
    if dataset_name.startswith("cifar"):
        normalize = transforms.Normalize(cifar_mean, cifar_std)
        # Transformer for train set: random crops and horizontal flip
        train_transformer = CustomCompose(
            [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),  # randomly flip image horizontally
                # postprocessing
                CustomCompose([transforms.ToTensor(), normalize]),
            ]
        )
        test_transformer = CustomCompose([transforms.ToTensor(), normalize])

    elif dataset_name.startswith("tinyimagenet"):
        normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        # Use CustomCompose (not plain transforms.Compose) so this pipeline
        # accepts apply_postprocessing like the cifar one does.
        train_transformer = CustomCompose(
            [
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(64, padding=4),
                # postprocessing
                CustomCompose([transforms.ToTensor(), normalize]),
            ]
        )
        test_transformer = CustomCompose([transforms.ToTensor(), normalize])

    else:
        raise ValueError(dataset_name)

    return train_transformer, test_transformer