# Source code for mixmo.loaders.abstract_loader

"""
Base DataLoader definitions
"""
import os
import numpy as np
import torch

from mixmo.utils.logger import get_logger
from mixmo.loaders import dataset_wrapper, batch_repetition_sampler
from mixmo.utils.config import cfg

# Module-level logger; DEBUG level so loader-construction details are visible.
LOGGER = get_logger(__name__, level="DEBUG")


class AbstractDataLoader:
    """
    General dataloader that defines how loaders are built.

    Subclasses must implement ``_init_dataaugmentations`` and ``_init_dataset``;
    the latter is expected to set ``self.train_dataset``, ``self.test_dataset``
    and ``self.properties``.  NOTE(review): ``_init_valtest_loader`` also reads
    ``self.data_dir``, which is never assigned in this class — presumably a
    subclass sets it; confirm.
    """

    def __init__(self, config_args, dataplace, split_test_val=False, corruptions=False):
        """
        Args:
            config_args: dict-like experiment configuration.  This class reads
                ``["training"]["batch_size"]``, ``["num_members"]``,
                ``["data"]["num_classes"]``, ``["training"]["dataset_wrapper"]``
                and ``["training"]["batch_sampler"]``.
            dataplace: stored as-is; not used directly in this base class.
            split_test_val: if True, carve a validation set out of the test set.
            corruptions: forwarded to ``_init_dataset`` (subclass-defined meaning).
        """
        self.config_args = config_args
        self.dataplace = dataplace
        self.batch_size = int(config_args['training']['batch_size'])
        # No worker subprocesses in debug mode (easier stepping / stack traces).
        self.num_workers = 10 if not cfg.DEBUG else 0
        # Order matters: augmentations before datasets, datasets before loaders.
        self._init_dataaugmentations()
        self._init_dataset(corruptions)
        self._init_train_loader()
        self._init_valtest_loader(split_test_val)

    def _init_dataaugmentations(self):
        # Subclass hook: define the train/test data augmentations.
        raise NotImplementedError

    def _init_dataset(self, corruptions=False):
        # Subclass hook: must set
        # self.train_dataset = None
        # self.test_dataset = None
        raise NotImplementedError

    def _init_train_loader(self):
        """
        Build the train loader with the proper sampler and data augmentations
        """
        # Choose the right dataset type: MixMoDataset feeds multiple members,
        # MSDADataset handles the single-member case.
        if self.config_args["num_members"] > 1:
            class_dataset_wrapper = dataset_wrapper.MixMoDataset
        else:
            class_dataset_wrapper = dataset_wrapper.MSDADataset

        # Load augmentations
        self.traindatasetwrapper = class_dataset_wrapper(
            dataset=self.train_dataset,
            num_classes=int(self.config_args["data"]["num_classes"]),
            num_members=self.config_args["num_members"],
            dict_config=self.config_args["training"]["dataset_wrapper"],
            properties=self.properties
        )

        # Build standard sampler
        _train_sampler = torch.utils.data.sampler.RandomSampler(
            data_source=self.traindatasetwrapper,  ## only needed for its length
            num_samples=None,
            replacement=False,
        )

        # Wrap it with the repeating sampler used for multi-input models
        batch_sampler = batch_repetition_sampler.BatchRepetitionSampler(
            sampler=_train_sampler,
            batch_size=self.batch_size,
            num_members=self.config_args["num_members"],
            drop_last=True,
            config_batch_sampler=self.config_args["training"]["batch_sampler"]
        )

        # batch_sampler controls batching entirely, so DataLoader's own
        # batch_size/shuffle/sampler/drop_last must stay at their defaults.
        self.train_loader = torch.utils.data.DataLoader(
            self.traindatasetwrapper,
            batch_sampler=batch_sampler,
            num_workers=self.num_workers,
            batch_size=1,
            shuffle=False,
            sampler=None,
            drop_last=False,
            pin_memory=True,
        )

    def _init_valtest_loader(self, split_test_val):
        """
        Build the test (and possibly val) loader with the proper sampler and data augmentations
        """
        if not split_test_val:
            LOGGER.warning("No validation loader")
            self.val_loader = None
            self.test_loader = self.make_standard_loader(self.test_dataset)
        else:
            # Split half of the test set off as a validation set.
            split_ratio = 0.5
            LOGGER.warning("Validation size={split_ratio} taken from test".format(
                split_ratio=split_ratio))
            num_test = len(self.test_dataset)
            indices = list(range(num_test))
            # NOTE(review): self.data_dir is not defined in this class —
            # presumably set by a subclass's _init_dataset; confirm.
            test_idx_npy = os.path.join(self.data_dir, "test_idx.npy")
            val_idx_npy = os.path.join(self.data_dir, "val_idx.npy")
            if os.path.exists(test_idx_npy):
                # Reuse a previously saved split so results stay comparable
                # across runs.
                LOGGER.warning("Loading existing test-val split indices")
                test_idx = np.load(test_idx_npy)
                val_idx = np.load(val_idx_npy)
            else:
                # First run: create the split with a fixed seed and persist it.
                split = int(np.floor(split_ratio * num_test))
                np.random.seed(cfg.RANDOM.SEED_TESTVAL)
                np.random.shuffle(indices)
                val_idx, test_idx = indices[:split], indices[split:]
                np.save(test_idx_npy, test_idx)
                np.save(val_idx_npy, val_idx)

            # _init samplers
            test_dataset = torch.utils.data.Subset(self.test_dataset, test_idx)
            val_dataset = torch.utils.data.Subset(self.test_dataset, val_idx)

            # _init loaders
            self.val_loader = self.make_standard_loader(val_dataset)
            self.test_loader = self.make_standard_loader(test_dataset)

    def make_standard_loader(self, dataset):
        """
        Build a dataloader from a dataset (wrapper on torch.utils)
        """
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            drop_last=False,
            # Pinning is skipped in debug mode.
            pin_memory=not (cfg.DEBUG > 0),
            num_workers=self.num_workers,
        )

    def make_corruptions_test_dataset(self):
        """
        Make robustness test dataset à la CIFAR10-C

        Prototype function (redefined for specific datasets)
        """
        # Base class has no corruption dataset; subclasses override as needed.
        corruptions_test_dataset = None
        return corruptions_test_dataset