Source code for mixmo.augmentations.mixing_blocks

"""
Mixing blocks inspired by several standard mixed-sample data augmentation (MSDA) methods
"""

import numpy as np
import torch
import torch.nn.functional as F
import numpy as np
import math

from mixmo.utils import misc


def mix_manifolds(list_lfeats, metadata):
    """
    Merge the latent features of several inputs into one representation.

    :param list_lfeats: list of latent features (torch tensors) living in the
        shared feature space
    :param metadata: metadata information passed in network; ``"mode"`` is
        always read, ``"mixmo_masks"`` is read when present
    :type metadata: dict
    :return: merged representation, rescaled by the number of inputs
    :rtype: torch tensor
    :raises ValueError: if the (mode, masks) combination is neither
        inference-without-masks nor train-with-masks
    """
    count = len(list_lfeats)

    # Nothing to merge when there is a single input: hand it back untouched
    if count == 1:
        return list_lfeats[0]

    mode = metadata["mode"]
    has_masks = "mixmo_masks" in metadata

    # Inference without masks: plain average of the features
    if mode == "inference" and not has_masks:
        merged = torch.stack(list_lfeats, dim=0).mean(dim=0)
        return count * merged

    # Any combination other than train-with-masks is unsupported
    if not (mode == "train" and has_masks):
        raise ValueError(metadata)

    # Bring the masks onto the features' device, as float tensors
    target_device = list_lfeats[0].device
    masks = [
        raw_mask.to(target_device).to(torch.float)
        for raw_mask in metadata["mixmo_masks"]
    ]

    if len(masks) == 1:
        # Two-input shortcut: the second mask is the complement of the first
        first = masks[0]
        merged = first * list_lfeats[0] + (1. - first) * list_lfeats[1]
    else:
        # General case: one mask per input, then sum the weighted features
        assert count > 2
        weighted = [masks[idx] * feats for idx, feats in enumerate(list_lfeats)]
        merged = torch.stack(weighted, dim=0).sum(dim=0)

    return count * merged
######################################### Mask computing functions for mixing #########################################
def mix(method, lams, input_size, config_mix=None):
    """
    Front facing function that computes the masks/lams for any number of inputs

    Inputs:
    -------
    method: string, name of the mixing method to use
    lams: sequence of mixing ratios, one per input
    input_size: size of the masks to generate
    config_mix: optional dict of method-specific settings

    Returns:
    --------
    masks: list of torch.Tensor. The masks used for mixing.
    lams: list of torch.Tensor. The lams ratio on the final masks.
    """
    # Two inputs: a single ratio fully determines both masks
    if len(lams) == 2:
        return _single_mix(method, lams[0], input_size, config_mix)
    # Any other count: forward the caller's configuration instead of
    # dropping it, so that settings such as the seed are honoured here too
    return _n_mix(method, lams, input_size, config_mix)
def _single_mix(method, lam, input_size, config_mix):
    """
    Build the pair of complementary masks used to mix two inputs
    (traditional MSDA methods)

    Inputs:
    -------
    method: string, key of the mask generator in MASK_MIX_DICT
    lam: float, requested mixing ratio
    input_size: size of the mask to generate
    config_mix: dict of method-specific settings (None or empty allowed)

    Returns:
    --------
    masks: list of torch.Tensor. The mask and its complement.
    lams: list of torch.Tensor. The ratios measured on the final masks.
    """
    settings = config_mix or {}
    build_mask = MASK_MIX_DICT[method]
    first_mask = build_mask(input_size, lam, settings)

    # The generated mask may not cover exactly `lam`: measure the real ratio
    effective_lam = first_mask.mean()

    masks = [first_mask, 1 - first_mask]
    ratios = [effective_lam, 1 - effective_lam]
    return masks, ratios
def _n_mix(method, lams, input_size, config_mix):
    """
    Build one mask per input for the M>2 case (requires lam tuples)

    Inputs:
    -------
    method: string, key of the mask generator in MASK_N_MIX_DICT
    lams: sequence of requested mixing ratios, one per input
    input_size: size of the masks to generate
    config_mix: dict of method-specific settings (None or empty allowed)

    Returns:
    --------
    masks: list of torch.Tensor. The masks used for mixing.
    lams: list of torch.Tensor. The ratios measured on the final masks.
    """
    settings = config_mix or {}
    build_masks = MASK_N_MIX_DICT[method]
    masks = build_masks(input_size, lams, settings)

    # Re-measure each ratio on the generated masks to absorb any drift
    measured = [one_mask.mean() for one_mask in masks]
    return masks, measured
######################################### Mask generating functions #########################################
def _mixup_mask(input_size, lam, config_mix):
    """
    Compute the MixUp mask: a constant tensor filled with lam

    config_mix is accepted for signature parity with the other mask
    generators and is not read.
    """
    all_ones = torch.ones(input_size)
    return lam * all_ones
def _cutmix_mask(input_size, lam, config_mix=None):
    """
    Compute the CutMix mask: ones inside a random box, zeros elsewhere

    The box is drawn by _rand_bbox_of_area_lam, seeded with
    config_mix["seed"] when that key is present.
    """
    options = config_mix or {}
    box = _rand_bbox_of_area_lam(input_size, lam, seed=options.get("seed"))
    x_lo, y_lo, x_hi, y_hi = box

    # Start from an empty mask and switch on the box, across the first axis
    mask = torch.zeros(input_size)
    mask[:, x_lo:x_hi, y_lo:y_hi] = 1
    return mask
def _cow_mask(input_size, lam, config_mix):
    """
    Compute masks for CowMix by thresholding Gaussian-blurred noise

    lam is not used: the mask ratio is drawn from the Cowmask parameters
    https://github.com/google-research/google-research/tree/master/milking_cowmask/masking
    """
    # Default CowMix config (presumably only fills keys that are missing;
    # see misc.ifnotfound_update)
    cow_defaults = {
        "cow_p_max": 0.8,
        "cow_p_min": 0.2,
        "cow_sigma_max": 16.0,
        "cow_sigma_min": 4.0,
    }
    misc.ifnotfound_update(config_mix, cow_defaults)

    # Draw the mask ratio uniformly in [p_min, p_max]
    p_max = config_mix["cow_p_max"]
    p_min = config_mix["cow_p_min"]
    proba = torch.tensor(p_min + np.random.rand(1) * (p_max - p_min))

    # Draw the blur width log-uniformly in [sigma_min, sigma_max]
    sigma_max = config_mix["cow_sigma_max"]
    sigma_min = config_mix["cow_sigma_min"]
    log_lo = math.log(sigma_min)
    log_hi = math.log(sigma_max)
    sigma = np.exp(log_lo + np.random.rand(1) * (log_hi - log_lo))

    # 2D kernel as the outer product of the 1D kernel with itself
    column = _gaussian_blur_kernel(sigma, sigma_max).unsqueeze(-1)
    kernel = column.T * column

    # conv2d weight layout: one output channel per entry of input_size[0]
    kernel = kernel.unsqueeze(0).unsqueeze(0)
    kernel = kernel.repeat((input_size[0], 1, 1, 1)).float()

    # Blur one single-channel white-noise image with every kernel copy
    noise = torch.randn(1, 1, input_size[1], input_size[2])
    half_kernel = kernel.size()[-1] // 2
    blurred = F.conv2d(noise, kernel, padding=half_kernel)

    # Threshold at the `proba` quantile of a Gaussian fitted to the blurred noise
    mean = blurred.mean()
    std = blurred.std()
    threshold = mean + math.sqrt(2) * torch.erfinv(2*proba - 1) * std

    below = blurred <= threshold
    return below.squeeze(0).float()
def _stack_mask(input_size, lam, config_mix):
    """
    Compute masks for Channel/Horizontal/Vertical concat

    The mask is split in two along config_mix["stack_dim"]: a block of ones
    and a block of zeros. When config_mix["stack_rdflip"] is set, a random
    draw decides whether to swap the ratio and the order of the two blocks.
    """
    # Default config
    misc.ifnotfound_update(config_mix, {
        "stack_dim": 1,
        "stack_rdflip": True,
    })
    axis = config_mix["stack_dim"]
    allow_flip = config_mix["stack_rdflip"]

    # The random draw only happens when flipping is enabled
    flip = allow_flip and misc.random_lower_than(
        prob=0.5, seed=None, r=None)
    ratio = 1 - lam if flip else lam

    # Position of the frontier between the two blocks along `axis`
    extent = input_size[axis]
    border = int(ratio * extent)

    def _shape_with(length):
        # Same shape as the input, except for the split axis
        shape = list(input_size)
        shape[axis] = length
        return shape

    ones_part = torch.ones(_shape_with(border))
    zeros_part = torch.zeros(_shape_with(extent - border))

    # Flipped masks put the zeros first
    ordered = [zeros_part, ones_part] if flip else [ones_part, zeros_part]
    return torch.cat(ordered, dim=axis)
def _stack2_mask(input_size, lam, config_mix):
    """
    Vertical concat mixing: _stack_mask with a default split on dimension 2
    """
    vertical_defaults = {"stack_dim": 2, "stack_rdflip": True}
    misc.ifnotfound_update(config_mix, vertical_defaults)
    return _stack_mask(input_size, lam, config_mix)
def _stack0_mask(input_size, lam, config_mix):
    """
    Channel concat mixing: _stack_mask with a default split on dimension 0
    """
    channel_defaults = {"stack_dim": 0, "stack_rdflip": True}
    misc.ifnotfound_update(config_mix, channel_defaults)
    return _stack_mask(input_size, lam, config_mix)
def _noise_mask(input_size, lam, config_mix):
    """
    Random mask pixels drawn from uniform distribution

    Inputs:
    -------
    input_size: size of the mask; the first entry is the number of channels
    lam: float, value onto which the centre (0.5) of the noise is mapped
    config_mix: dict or None. When config_mix["noise_2d"] is truthy, a
        single 2D noise map is drawn and shared by all channels; when the
        key is missing (or config_mix is None) every entry is independent.

    Returns:
    --------
    mask: torch.Tensor of size input_size
    """
    # A missing key or a None config falls back to fully independent noise
    noise_2d = bool((config_mix or {}).get("noise_2d", False))

    if noise_2d:
        mask = torch.rand(input_size[1:])
    else:
        mask = torch.rand(input_size)

    # rescale around the center with a piecewise linear function:
    # [0, 0.5] is mapped onto [0, lam] and [0.5, 1] onto [lam, 1]
    centered = mask - 0.5
    zeros = torch.zeros_like(mask)
    mask_below = torch.min(centered, zeros)
    mask_above = torch.max(centered, zeros)
    mask = 2 * (lam * mask_below + (1 - lam) * mask_above) + lam

    if noise_2d:
        # Share the 2D map across every channel of the requested size,
        # not across a fixed count of 3
        mask = mask.repeat((input_size[0], 1, 1))
    return mask
def _patchup_mask(input_size, lam, config_mix):
    """
    Compute masks for PatchUp mixing: Bernoulli seeds grown into square
    blocks by a max-pooling
    https://github.com/chandar-lab/PatchUp
    """
    # Default config
    patchup_defaults = {
        "patchup_gamma": None,
        "patchup_block_size": 7,
        "patchup_soft": False,
        "patchup_2d": False
    }
    misc.ifnotfound_update(config_mix, patchup_defaults)

    # An explicit gamma takes precedence over lam
    configured_gamma = config_mix.get("patchup_gamma", None)
    gamma = configured_gamma if configured_gamma is not None else lam

    block_size = config_mix["patchup_block_size"]
    half_block = block_size // 2
    kernel_size = (block_size, block_size)
    padding = (half_block, half_block)
    stride = (1, 1)

    # As per the official patchup_hard implementation
    side = input_size[-1]
    rescale = (side ** 2 / (
        block_size ** 2 * (side - block_size + 1) ** 2
    ))
    gamma *= rescale

    # Draw the block seeds, either on one 2D map or on the full size
    shared_across_channels = config_mix["patchup_2d"]
    seed_shape = input_size[1:] if shared_across_channels else input_size
    seeds = torch.bernoulli(gamma * torch.ones(seed_shape))
    if config_mix["patchup_2d"]:
        seeds = seeds.repeat((input_size[0], 1, 1))

    # The pooling spreads each seed into a continuous block
    # (the holes altered by PatchUp)
    mask = F.max_pool2d(seeds, kernel_size, stride, padding)

    # Soft variant: outside the blocks the mask takes the value lam
    if config_mix["patchup_soft"]:
        mask = mask + (1 - mask) * lam
    return mask
def _patchuphard2d_mask(input_size, lam, config_mix):
    """
    PatchUp hard masking, 2d variant: _patchup_mask with "patchup_2d"
    defaulting to True
    """
    flat_defaults = {"patchup_2d": True}
    misc.ifnotfound_update(config_mix, flat_defaults)
    return _patchup_mask(input_size, lam, config_mix)
def _patchupsoft_mask(input_size, lam, config_mix):
    """
    PatchUp soft masking: _patchup_mask with "patchup_soft" defaulting to True
    """
    soft_defaults = {"patchup_soft": True}
    misc.ifnotfound_update(config_mix, soft_defaults)
    # Guards against a config that explicitly requested the hard variant
    assert config_mix["patchup_soft"]
    return _patchup_mask(input_size, lam, config_mix)
def _channel_mask(input_size, lam, config_mix):
    """
    Compute masks that toggle entire channels on and off

    Each channel (dimension 0) is kept with probability lam; config_mix is
    not read.
    """
    num_channels = input_size[0]
    keep_proba = lam * torch.ones((num_channels, 1, 1))
    # One Bernoulli draw per channel, broadcast over the other dimensions
    per_channel = torch.bernoulli(keep_proba)
    return per_channel.expand(input_size)
# Registry of the two-input mask generators, keyed by method name
MASK_MIX_DICT = {
    # PatchUp family
    "patchuphard": _patchup_mask,
    "patchupsoft": _patchupsoft_mask,
    "patchuphard2d": _patchuphard2d_mask,
    # Classic mixed-sample augmentations
    "mixup": _mixup_mask,
    "cutmix": _cutmix_mask,
    "noise": _noise_mask,
    # Concatenation-like masks
    "stackchannel": _stack0_mask,  # split on dimension 0
    "stackhorizontal": _stack_mask,  # split on dimension 1
    "stackvertical": _stack2_mask,  # split on dimension 2
    # Channel toggling and CowMix
    "channel": _channel_mask,
    "cow": _cow_mask,
}

# Methods whose masks differ from one channel to the next
LIST_METHODS_NOT_INVARIANT_CHANNELS = [
    "channel",
    "stackchannel",
]
def _n_mixup_mask(input_size, lams, config_mix):
    """
    Multivariate MixUp generalization

    lams is a tuple here (simplex) that gives the proportion between n
    inputs; one constant mask is returned per entry. config_mix is not read.
    """
    return [ratio * torch.ones(input_size) for ratio in lams]
def _n_cutinmix_mask(input_size, lams, config_mix):
    """
    Multivariate CutMix generalization (see paper)

    lams is a tuple here (simplex) that gives the proportion between n inputs.
    The first input gets a CutMix box, the others share the rest of the
    image through MixUp: CutMix(A, MixUp(B,C,...))
    """
    # Renormalize the ratios of the inputs that are mixed outside the box
    outside = lams[1:]
    outside_total = sum(outside)
    outside_ratios = [ratio / outside_total for ratio in outside]

    # Compute the base masks
    outside_masks = _n_mixup_mask(input_size, outside_ratios, config_mix)
    box_mask = _cutmix_mask(input_size, lams[0], config_mix)

    # Inside the box only the first input; outside, the MixUp of the others
    complement = [(1 - box_mask) * mixup_mask for mixup_mask in outside_masks]
    return [box_mask] + complement
# Registry of the mask generators supporting more than two inputs
MASK_N_MIX_DICT = {
    "mixup": _n_mixup_mask,
    "cutmix": _n_cutinmix_mask,
}


######################################### General utility functions for masking #########################################
def _rand_bbox_of_area_lam(size, lam, seed=None):
    """
    Compute the corner coordinates of a random rectangular box whose
    unclipped area satisfies area_box/area_image=lam (up to integer
    truncation); the box is then clipped to the image borders.

    Inputs:
    -------
    size: image size, with 3 entries (the last two are used) or 4 entries
        (the last two are used)
    lam: float, target area ratio
    seed: forwarded to misc.get_nprandom to obtain the random generator

    Returns:
    --------
    bbx1, bby1, bbx2, bby2: box bounds along the two spatial dimensions

    Raises:
    -------
    ValueError: if size has neither 3 nor 4 entries
    """
    # Retrieving H and W depending on image format
    if len(size) == 4:
        W, H = size[2], size[3]
    elif len(size) == 3:
        W, H = size[1], size[2]
    else:
        raise ValueError(
            "size must have 3 or 4 dimensions, got {}".format(len(size)))

    # Compute box width and height. The builtin int is used because the
    # np.int alias no longer exists in recent NumPy releases.
    cut_rat = np.sqrt(lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # Draw coordinates of the center of the box. Depending on what
    # misc.get_nprandom returns, the integer sampler is exposed either as
    # `randint` or as `integers`.
    mynprandom = misc.get_nprandom(seed=seed)
    try:
        rg_integers = mynprandom.randint
    except AttributeError:
        rg_integers = mynprandom.integers
    cx = rg_integers(W)
    cy = rg_integers(H)

    # Box corners naturally follow, clipped to stay inside the image
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2
def _gaussian_blur_kernel(sigma, sigma_max, sym=True):
    """
    Compute a 1D Gaussian kernel, as per the scipy.signal implementation

    The support only depends on sigma_max (three sigma_max on each side of
    the centre), so kernels drawn with different sigma share the same
    length. sym is accepted but not read.
    """
    # Keep up to 99.7 of the Gaussian for the kernel
    half_width = math.ceil(sigma_max * 3)
    size = half_width * 2 + 1

    # Sample positions centred on zero
    offsets = torch.arange(0, size).float() - (size - 1.0) / 2.0

    two_variance = 2 * sigma * sigma
    unnormalized = np.exp(-offsets**2 / two_variance)
    normalizer = math.sqrt(math.pi * two_variance)
    return unnormalized / normalizer