Source code for mixmo.loaders.loader

"""
Custom Dataloaders for each of the considered datasets
"""

import os

from torchvision import datasets

from mixmo.augmentations.standard_augmentations import get_default_composed_augmentations
from mixmo.loaders import cifar_dataset, abstract_loader
from mixmo.utils.logger import get_logger

LOGGER = get_logger(__name__, level="DEBUG")


class CIFAR10Loader(abstract_loader.AbstractDataLoader):
    """
    Data loader for CIFAR-10.

    Implements the hooks of ``abstract_loader.AbstractDataLoader``: it builds
    the "cifar" augmentation pipelines, the train/test datasets (optionally
    swapping the test set for its corrupted counterpart) and exposes a few
    dataset-specific constants through ``properties``.
    """

    def _init_dataaugmentations(self):
        """Create the train and test transforms from the ``"cifar"`` preset."""
        train_and_test = get_default_composed_augmentations(
            dataset_name="cifar",
        )
        self.augmentations_train, self.augmentations_test = train_and_test

    def _init_dataset(self, corruptions=False):
        """
        Populate ``self.train_dataset`` and ``self.test_dataset``.

        The training set is always ``cifar_dataset.CustomCIFAR10`` read from
        ``data_dir``. When ``corruptions`` is truthy, the test set is a
        ``cifar_dataset.CIFARCorruptions`` read from ``corruptions_data_dir``;
        otherwise it is the ``CustomCIFAR10`` test split from ``data_dir``.
        """
        self.train_dataset = cifar_dataset.CustomCIFAR10(
            root=self.data_dir,
            train=True,
            download=True,
            transform=self.augmentations_train,
        )

        if corruptions:
            # Corrupted test set lives in its own directory and is never downloaded.
            self.test_dataset = cifar_dataset.CIFARCorruptions(
                root=self.corruptions_data_dir,
                train=False,
                transform=self.augmentations_test,
            )
            return

        self.test_dataset = cifar_dataset.CustomCIFAR10(
            root=self.data_dir,
            train=False,
            download=True,
            transform=self.augmentations_test,
        )

    @property
    def data_dir(self):
        """Root directory given to ``CustomCIFAR10``: ``<dataplace>/cifar10-data``."""
        return os.path.join(self.dataplace, "cifar10-data")

    @property
    def corruptions_data_dir(self):
        """Root directory given to ``CIFARCorruptions``: ``<dataplace>/CIFAR-10-C``."""
        return os.path.join(self.dataplace, "CIFAR-10-C")

    @staticmethod
    def properties(key):
        """
        Look up a dataset-specific constant by name.

        Known keys are ``"conv1_input_size"``, ``"conv1_is_half_size"`` and
        ``"pixels_size"``; any other key raises ``KeyError``.
        """
        return {
            "conv1_input_size": (16, 32, 32),
            "conv1_is_half_size": False,
            "pixels_size": 32,
        }[key]


class CIFAR100Loader(CIFAR10Loader):
    """
    Data loader for CIFAR-100.

    Reuses the CIFAR-10 loader's augmentations and ``properties`` unchanged,
    and only overrides which dataset classes and directories are used.
    """

    def _init_dataset(self, corruptions=False):
        """
        Populate ``self.train_dataset`` and ``self.test_dataset``.

        The training set is always ``cifar_dataset.CustomCIFAR100`` read from
        ``data_dir``. When ``corruptions`` is truthy, the test set is a
        ``cifar_dataset.CIFARCorruptions`` read from ``corruptions_data_dir``;
        otherwise it is the ``CustomCIFAR100`` test split from ``data_dir``.
        """
        self.train_dataset = cifar_dataset.CustomCIFAR100(
            root=self.data_dir,
            train=True,
            download=True,
            transform=self.augmentations_train,
        )

        if corruptions:
            # Corrupted test set lives in its own directory and is never downloaded.
            self.test_dataset = cifar_dataset.CIFARCorruptions(
                root=self.corruptions_data_dir,
                train=False,
                transform=self.augmentations_test,
            )
            return

        self.test_dataset = cifar_dataset.CustomCIFAR100(
            root=self.data_dir,
            train=False,
            download=True,
            transform=self.augmentations_test,
        )

    @property
    def data_dir(self):
        """Root directory given to ``CustomCIFAR100``: ``<dataplace>/cifar100-data``."""
        return os.path.join(self.dataplace, "cifar100-data")

    @property
    def corruptions_data_dir(self):
        """Root directory given to ``CIFARCorruptions``: ``<dataplace>/CIFAR-100-C``."""
        return os.path.join(self.dataplace, "CIFAR-100-C")


class TinyImagenet200Loader(abstract_loader.AbstractDataLoader):
    """
    Loader for the TinyImageNet (200 classes) dataset that inherits the
    abstract_loader.AbstractDataLoader dataloading API and defines the proper
    augmentations and datasets.

    Both splits are read with ``torchvision.datasets.ImageFolder``, so
    ``<data_dir>/train`` and ``<data_dir>/val/images`` must each contain one
    sub-folder per class. No corrupted test set is available for this dataset.
    """

    def _init_dataaugmentations(self):
        """Create the train and test transforms from the ``"tinyimagenet"`` preset."""
        (self.augmentations_train, self.augmentations_test) = get_default_composed_augmentations(
            dataset_name="tinyimagenet",
        )

    @property
    def data_dir(self):
        """Dataset root: ``<dataplace>/tinyimagenet200-data``."""
        return os.path.join(self.dataplace, "tinyimagenet200-data")

    def _init_dataset(self, corruptions=False):
        """
        Populate ``self.train_dataset`` and ``self.test_dataset``.

        :param corruptions: accepted for API compatibility with the CIFAR
            loaders. No corrupted variant exists for TinyImageNet, so a
            truthy value only triggers a warning and the clean validation
            split is used as the test set.
        """
        if corruptions:
            # Previously this flag was silently ignored, so a "corruptions"
            # evaluation actually ran on clean data without any notice.
            LOGGER.warning(
                "Corruptions are not available for TinyImageNet; "
                "falling back to the clean validation set."
            )
        traindir = os.path.join(self.data_dir, "train")
        # Join path components separately instead of hard-coding '/'.
        valdir = os.path.join(self.data_dir, "val", "images")
        self.train_dataset = datasets.ImageFolder(traindir, self.augmentations_train)
        self.test_dataset = datasets.ImageFolder(valdir, self.augmentations_test)

    @staticmethod
    def properties(key):
        """
        Look up a dataset-specific constant by name.

        Known keys are ``"conv1_input_size"``, ``"conv1_is_half_size"`` and
        ``"pixels_size"``; any other key raises ``KeyError``.
        """
        dict_key_to_values = {
            "conv1_input_size": (64, 32, 32),
            "conv1_is_half_size": True,
            "pixels_size": 64,
        }
        return dict_key_to_values[key]