Source code for mixmo.loaders.cifar_dataset

"""
CIFAR 10 and 100 dataset wrappers
"""

import os
from PIL import Image
import numpy as np
from torchvision import datasets

from mixmo.utils.logger import get_logger

LOGGER = get_logger(__name__, level="DEBUG")


[docs]class CustomCIFAR10(datasets.CIFAR10): """ Torchvision's CIFAR10 dataset class augmented with a custom __getitem__ method """ def __getitem__(self, index, apply_postprocessing=True): """ Args: index (int): Index Returns: tuple: (image, target) where target is index of the target class. """ img, target = self.data[index], self.targets[index] # doing this so that it is consistent with all other datasets # to return a PIL Image img = Image.fromarray(img) if self.transform is not None: img = self.transform(img, apply_postprocessing=apply_postprocessing) if self.target_transform is not None: target = self.target_transform(target) return img, target
[docs]class CustomCIFAR100(datasets.CIFAR100): """ Torchvision's CIFAR100 dataset class augmented with a custom __getitem__ method """ def __getitem__(self, *args, **kwargs): return CustomCIFAR10.__getitem__(self, *args, **kwargs)
CIFAR_ROBUSTNESS_FILENAMES = [ "speckle_noise.npy", "shot_noise.npy", "impulse_noise.npy", "defocus_blur.npy", "gaussian_blur.npy", "glass_blur.npy", "zoom_blur.npy", "fog.npy", "brightness.npy", "contrast.npy", "elastic_transform.npy", "pixelate.npy", "jpeg_compression.npy", "motion_blur.npy", "snow.npy", 'frost.npy', 'gaussian_noise.npy', 'saturate.npy', 'spatter.npy' ]
[docs]class CIFARCorruptions(CustomCIFAR10):
[docs] def __init__(self, root, transform=None, train=False): super(datasets.CIFAR10, self).__init__(root) self.transform = transform self.train = train # training set or test set self.data = [] self.targets = [] labels_path = os.path.join(root, "labels.npy") list_labels = np.load(labels_path).tolist() ## iterate over CIFAR_ROBUSTNESS_FILENAMES for i, filename in enumerate(CIFAR_ROBUSTNESS_FILENAMES): filepath = os.path.join(root, filename) data = np.load(filepath) if i == 0: ## initialize with first filename self.data = data self.targets = list_labels[:] else: self.data = np.concatenate( (self.data, data), axis=0) self.targets.extend(list_labels) nb_files = len(CIFAR_ROBUSTNESS_FILENAMES) LOGGER.warning(f"Robustness corruptions of len: {nb_files}") LOGGER.warning(f"Pixels of shape: {self.data.shape}") LOGGER.warning(f"Targets of len: {len(self.targets)}")