"""
Dataset wrappers for multi-input multi-output models with data augmentation
"""
import math
import torch
from torch.utils.data.dataset import Dataset
from mixmo.utils import misc, config, torchutils
from mixmo.augmentations import augmix, mixing_blocks
class DADataset(Dataset):
    """
    Dataset wrapper with outputs formatted as dictionaries and AugMix augmentation.

    ``__getitem__`` returns ``{"pixels_0": tensor, "target_0": one_hot_target}``
    for the sample selected by ``index["index_0"]``.
    """

    def __init__(self, dataset, num_classes, num_members, dict_config, properties):
        """
        Args:
            dataset: underlying dataset yielding ``(pixels, target)`` pairs.
            num_classes: number of classes, used for one-hot target encoding.
            num_members: number of subnetworks/heads fed by this wrapper.
            dict_config: augmentation configuration; must expose key "da_method".
            properties: callable exposing model properties (used by subclasses).
        """
        self.dataset = dataset
        self.num_classes = num_classes
        self.num_members = num_members
        self.dict_config = dict_config
        self.properties = properties
        self._custom_init()
        self.set_ratio_epoch(0)

    def _custom_init(self):
        # Hook for subclasses to initialize additional state (no-op here).
        pass

    def set_ratio_epoch(self, ratioepoch):
        # Fraction of training completed; used by subclasses for scheduling.
        self.ratio_epoch_current = ratioepoch

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """
        Retrieve target and image, return a dictionary with the two
        """
        pixels_0, target_0 = self.call_dataset(index["index_0"])
        dict_output = {"pixels_0": pixels_0, "target_0": target_0}
        return dict_output

    def call_dataset(self, index, seed=None):
        """
        Get target and image, apply AugMix if necessary and return the pixels
        together with the one-hot encoded target.

        Args:
            index: index into the wrapped dataset.
            seed: optional seed forwarded to AugMix so duplicated samples in a
                batch can receive identical transformations.
        """
        if misc.is_none(self.dict_config["da_method"]):
            pixels, target = self.dataset[index]
        else:
            # Delay postprocessing so the augmentation can operate on the raw
            # image before normalization/tensor conversion.
            dict_pixels_postprocessing, target = self.dataset.__getitem__(
                index, apply_postprocessing=False
            )
            if self.dict_config["da_method"] == "augmix":
                pixels = augmix.AugMix(seed=seed)(
                    image=dict_pixels_postprocessing["pixels"],
                    preprocess=dict_pixels_postprocessing["postprocessing"]
                )
            else:
                raise ValueError(self.dict_config["da_method"])
        return pixels, torchutils.onehot(self.num_classes, target)
class MSDADataset(DADataset):
    """
    Dataset wrapper that returns dictionaries and applies MSDA augmentations
    (mixed sample data augmentation, e.g. MixUp or CutMix style mixing).
    """

    # Whether to swap the roles of the two mixed samples when the first ratio
    # falls below 0.5 (used to symmetrize asymmetric mixings such as CutMix).
    reverse_if_first_minor = False

    def _custom_init(self):
        self._custom_init_msda()

    def _custom_init_msda(self):
        # MSDA settings: mixing method name, Beta parameter for ratio sampling,
        # and the probability of actually applying the mixing.
        self.msda_mix_method = self.dict_config["msda"]["mix_method"]
        self.msda_beta = self.dict_config["msda"]["beta"]
        self.msda_prob = self.dict_config["msda"]["prob"]

    def call_msda(self, index_0, mixmo_mask=None, seed_da=None):
        """
        Get two samples and mix them. Return the mixed sample and label.

        Args:
            index_0: index of the main sample.
            mixmo_mask: optional MixMo mask that will later be applied to this
                sample; when given, the MSDA ratios are re-estimated on the
                region kept by that mask.
            seed_da: optional seed so duplicated samples in a batch receive the
                same data augmentation (useful for AugMix).
        """
        # Gather the two image/label pairs used by the augmentation
        pixels_0, target_0 = self.call_dataset(index_0, seed=seed_da)
        skip_msda = (
            self.msda_mix_method is None or not misc.random_lower_than(self.msda_prob)
        )
        if skip_msda:  # Early exit if we are not mixing
            return pixels_0, target_0
        # The second sample is drawn uniformly at random from the whole dataset
        index_1 = misc.get_random(seed=None).choice(range(len(self)))
        pixels_1, target_1 = self.call_dataset(index_1, seed=seed_da)
        targets = [target_0, target_1]
        # Get mixing masks
        msda_lams = misc.sample_lams(self.msda_beta, n=2)
        msda_masks, msda_lams = mixing_blocks.mix(
            method=self.msda_mix_method,
            lams=msda_lams,
            input_size=pixels_0.size(),
        )
        # Adjust the lams to account for later mixmo mixing that might alter masks
        if mixmo_mask is not None:
            ## approx for computational issues: mask should be symmetrical in channels
            mixmo_mask_0 = mixmo_mask[0, :, :]
            if self.properties("conv1_is_half_size"):
                # The network halves the resolution at conv1: pool the MSDA mask
                # down to the mixmo mask resolution before computing overlaps.
                _msda_mask_0 = torch.nn.AvgPool2d(kernel_size=(2, 2))(msda_masks[0][:1, :, :])
                msda_masks_for_lam = [_msda_mask_0.to(torch.float16)]
            else:
                mixmo_mask_0 = mixmo_mask_0.to(torch.float32)
                msda_masks_for_lam = msda_masks
            ## Compute the adjusted ratios after mixmo mixing
            mean_mixmo_mask_0 = mixmo_mask_0.mean()
            msda_lams = [
                (msda_mask[0, :, :] * mixmo_mask_0).mean() / (mean_mixmo_mask_0 + 1e-8)
                for msda_mask in msda_masks_for_lam
            ]
            if self.properties("conv1_is_half_size"):
                # Only the first ratio was estimated at half resolution; the
                # complementary ratio is recovered analytically.
                lam = msda_lams[0].to(torch.float32)
                msda_lams = [lam, 1 - lam]
        # Randomly reverse the roles of mixed samples (important to symmetrize CutMix, Patch-Up, ...)
        if self.reverse_if_first_minor and msda_lams[0] < 0.5:
            msda_pixels = msda_masks[1] * pixels_0 + msda_masks[0] * pixels_1
            msda_lams = [msda_lams[1], msda_lams[0]]
        else:
            msda_pixels = msda_masks[0] * pixels_0 + msda_masks[1] * pixels_1
        # Standard MSDA label interpolation
        msda_targets = sum([
            lam * target for lam, target
            in zip(msda_lams, targets)])
        return msda_pixels, msda_targets

    def __getitem__(self, index):
        """
        Return a dictionary with the relevant sample and target, possibly mixed with another
        """
        if self.msda_mix_method is None:
            return DADataset.__getitem__(self, index)
        pixels_0, target_0 = self.call_msda(index_0=index["index_0"])
        dict_output = {"pixels_0": pixels_0, "target_0": target_0}
        return dict_output
class MixMoDataset(MSDADataset):
    """
    Dataset wrapper that returns dictionaries of multiple samples, and applies MSDA augmentations.

    ``__getitem__`` produces one (possibly MSDA-mixed) sample/target pair per
    member, plus metadata describing the MixMo mixing masks and ratios.
    """

    # Symmetrize asymmetric MSDA mixings (CutMix, Patch-Up, ...) by swapping
    # the two mixed samples whenever the first ratio is below 0.5.
    reverse_if_first_minor = True

    def _custom_init(self):
        self._custom_init_msda()
        self._custom_init_mixmo()

    def _custom_init_mixmo(self):
        self.dict_mixmo_mix_method = self.dict_config["mixmo"]["mix_method"]
        # dict with key 'method_name', 'prob' and 'replacement_method_name'
        self.mixmo_alpha = float(self.dict_config["mixmo"]["alpha"])
        self.mixmo_weight_root = self.dict_config["mixmo"]["weight_root"]

    def get_mixmo_mix_method_at_ratio_epoch(self, batch_seed=None):
        """
        Select which mixing method should be used according to training scheduling.

        Procedure:
            Select self.dict_mixmo_mix_method["method_name"] with proba
            self.dict_mixmo_mix_method["prob"] that is linearly decreased
            towards 0 after config.cfg.RATIO_EPOCH_DECREASE of the training
            process. Otherwise, use
            self.dict_mixmo_mix_method["replacement_method_name"] (in general mixup).
        """
        method = self.dict_mixmo_mix_method["method_name"]
        replacement_method = self.dict_mixmo_mix_method["replacement_method_name"]
        if method == replacement_method:
            return method
        # Check the actual switch probability according to scheduler and current epoch
        default_prob = self.dict_mixmo_mix_method["prob"]
        if self.ratio_epoch_current < config.cfg.RATIO_EPOCH_DECREASE:
            prob = default_prob
        else:
            # Linear decay of the probability towards 0 at the end of training
            eta = max(0, (1 - self.ratio_epoch_current) / (1 - config.cfg.RATIO_EPOCH_DECREASE))
            prob = default_prob * eta
        # Choose the method depending on draw result
        if misc.random_lower_than(prob, seed=batch_seed):
            return method
        return replacement_method

    def _init_dict_output_mixmo(self, batch_seed):
        """
        Compute MixMo block variables (masks, lams) and prepare it as a dictionary output
        """
        # Get MixMo mixing method and the corresponding masks/lams
        mixmo_mix_method = self.get_mixmo_mix_method_at_ratio_epoch(
            batch_seed=batch_seed
        )
        mixmo_lams = misc.sample_lams(self.mixmo_alpha, n=self.num_members)
        mixmo_masks, mixmo_lams = mixing_blocks.mix(
            method=mixmo_mix_method,
            lams=mixmo_lams,
            input_size=self.properties("conv1_input_size"),
        )
        # Shuffle the roles of the inputs (same for every sample in the batch)
        # Mostly useful for asymmetrical mixing (CutMix, ...)
        assert batch_seed is not None
        myrandom = misc.get_random(seed=batch_seed + config.cfg.RANDOM.SEED_OFFSET_MIXMO)
        zipped_masking = list(zip(mixmo_lams, mixmo_masks))
        myrandom.shuffle(zipped_masking)
        # Format everything nicely in dictionaries
        dict_output = {"metadata": {"mixmo_lams": [el[0] for el in zipped_masking], "mixmo_masks": [el[1] for el in zipped_masking]}}
        if mixmo_mix_method not in mixing_blocks.LIST_METHODS_NOT_INVARIANT_CHANNELS:
            # Masks identical across channels: keep a single half-precision
            # channel to reduce the memory footprint.
            dict_output["metadata"]["mixmo_masks"] = [
                mimo_mix_mask[:1, :, :].to(torch.float16)
                for mimo_mix_mask in dict_output["metadata"]["mixmo_masks"]]
        return dict_output

    def __getitem__(self, index):
        """
        Get a (mixed) sample/label pair for each head and output it in a dictionary
        """
        # Initialize output with mixing block descriptors
        dict_output = self._init_dict_output_mixmo(
            batch_seed=index["batch_seed"])
        # Compute sample/label pairs for each head
        for num_member in range(0, self.num_members):
            member_index = index["index_" + str(num_member)]
            # useful only for augmix: force same transfos on batch duplicated samples
            seed_da = index["batch_seed"] + config.cfg.RANDOM.SEED_DA * member_index + num_member
            ## Retrieve and compute mixed sample/label, accounting for mixmo mixing
            pixels_member, target_member = self.call_msda(
                index_0=member_index,
                mixmo_mask=dict_output["metadata"]["mixmo_masks"][num_member],
                seed_da=seed_da
            )
            ## Format output
            dict_output.update({
                "pixels_" + str(num_member): pixels_member,
                "target_" + str(num_member): target_member
            })
        dict_output = self._target_balancing(dict_output)
        if self.num_members == 2:
            # only keep first to reduce memory footprint
            # as the second can be obtained by 1 - mask
            dict_output["metadata"]["mixmo_masks"] = dict_output["metadata"]["mixmo_masks"][:1]
        return dict_output

    def _target_balancing(self, dict_output):
        """
        Final formatting of outputs with mixmo balancing: reweight each head's
        target by a normalized root of its mixing ratio.
        """
        def apply_root(a):
            # Root flattens the weighting (weight_root=1 keeps it proportional)
            return math.pow(a, (1 / self.mixmo_weight_root))
        # Get balancing weights, normalized to sum to num_members
        _list_weights_not_normalized = [apply_root(lam) for lam in dict_output["metadata"]["mixmo_lams"]]
        norm = sum(_list_weights_not_normalized)
        list_weights = [self.num_members * weight / norm for weight in _list_weights_not_normalized]
        # Apply weights: as we use categorical cross entropy, we can simply multiply the target
        for i in range(self.num_members):
            dict_output["target_{}".format(i)] *= list_weights[i]
        return dict_output