# Source code for mixmo.utils.torchutils

"""
General tensor manipulation utility functions (initializations, permutations, one hot)
"""

import math
import torch

import torch.nn as nn

from mixmo.utils.logger import get_logger

LOGGER = get_logger(__name__, level="DEBUG")


def onehot(size, target):
    """Encode a scalar class index `target` as a float32 one-hot vector of length `size`."""
    encoding = torch.zeros(size, dtype=torch.float32)
    encoding[target] = 1.
    return encoding
# permutation tools
def randperm_static(batch_size, proba_static=0):
    """Random permutation of `range(batch_size)` where the first
    `proba_static` fraction of indices stays in place.

    With `proba_static <= 0` this is a plain `torch.randperm`; otherwise the
    leading `int(batch_size * proba_static)` positions map to themselves and
    only the remaining tail is shuffled (among itself).
    """
    if proba_static <= 0:
        return torch.randperm(batch_size)
    num_fixed = int(batch_size * proba_static)
    fixed_part = torch.arange(0, num_fixed).long()
    # shuffle only the tail, offset so indices stay in [num_fixed, batch_size)
    shuffled_part = num_fixed + torch.randperm(int(batch_size) - num_fixed)
    return torch.cat([fixed_part, shuffled_part])
# initializations like in tensorflow
def truncated_normal_(tensor, mean=0, std=1):
    """Fill `tensor` in place from a truncated normal distribution.

    Mimics TensorFlow's truncated normal init: draws 4 candidate N(0, 1)
    samples per element, keeps the first candidate falling inside (-2, 2),
    then rescales by `std` and shifts by `mean`.

    Args:
        tensor: an n-dimensional `torch.Tensor` to fill in place.
        mean: mean of the resulting distribution.
        std: standard deviation before truncation correction.

    Returns:
        The input `tensor` (filled in place), following the `torch.nn.init`
        convention — previously this returned `None`, which made the
        `return truncated_normal_(...)` in `kaiming_normal_truncated`
        return `None` as well.
    """
    size = tensor.shape
    # 4 candidate samples per element; `valid` marks those inside (-2, 2)
    tmp = tensor.new_empty(size + (4,)).normal_()
    valid = (tmp < 2) & (tmp > -2)
    # index of the first valid candidate along the last axis
    ind = valid.max(-1, keepdim=True)[1]
    tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
    tensor.data.mul_(std).add_(mean)
    return tensor
[docs]def _calculate_fan_in_and_fan_out(tensor): """ Compute the minimal input and output sizes for the weight tensor """ dimensions = tensor.dim() if dimensions < 2: raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions") if dimensions == 2: # Linear fan_in = tensor.size(1) fan_out = tensor.size(0) else: num_input_fmaps = tensor.size(1) num_output_fmaps = tensor.size(0) receptive_field_size = 1 if tensor.dim() > 2: receptive_field_size = tensor[0][0].numel() fan_in = num_input_fmaps * receptive_field_size fan_out = num_output_fmaps * receptive_field_size return fan_in, fan_out
[docs]def _calculate_correct_fan(tensor, mode): """ Return the minimal input or output sizes for the weight tensor depending on which is needed """ mode = mode.lower() valid_modes = ['fan_in', 'fan_out'] if mode not in valid_modes: raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes)) fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) return fan_in if mode == 'fan_in' else fan_out
def kaiming_normal_truncated(tensor, a=0, mode='fan_in', nonlinearity='relu'):
    r"""He (Kaiming) initialization using a truncated normal distribution.

    Fills `tensor` with values from :math:`\mathcal{N}(0, \text{std})`
    truncated at two standard deviations, where

    .. math::
        \text{std} = \sqrt{\frac{2}{(1 + a^2) \times \text{fan\_in}}}

    See `Delving deep into rectifiers: Surpassing human-level performance on
    ImageNet classification` - He, K. et al. (2015).

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: negative slope of the rectifier used after this layer
            (0 for ReLU by default)
        mode: ``'fan_in'`` (default) preserves the variance magnitude in the
            forward pass; ``'fan_out'`` preserves it in the backward pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended only with ``'relu'`` or ``'leaky_relu'``.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
    """
    fan = _calculate_correct_fan(tensor, mode)
    gain = nn.init.calculate_gain(nonlinearity, a)
    # divide by the std of a (-2, 2)-truncated standard normal so the
    # truncated samples end up with the intended He std
    std = (gain / math.sqrt(fan)) / .87962566103423978
    with torch.no_grad():
        return truncated_normal_(tensor, 0, std)
def weights_init_hetruncatednormal(m, dense_gaussian=False):
    """Initialize a module's weights with truncated He init (TF-style).

    Conv layers get truncated Kaiming-normal weights; Linear layers get
    either a small Gaussian (when `dense_gaussian` is True) or the same
    truncated Kaiming-normal init; biases are zeroed; BatchNorm2d weights
    are set to 1 and biases to 0. Intended for use with `module.apply`.

    Args:
        m: the module to initialize.
        dense_gaussian: if True, initialize Linear weights with N(0, 0.01)
            instead of truncated He init.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        kaiming_normal_truncated(m.weight.data, a=0, mode='fan_in', nonlinearity='relu')
        if m.bias is not None:
            # was `nn.init.constant` (no trailing underscore): deprecated and
            # removed in modern PyTorch; use the in-place variant like the
            # Linear branch below
            nn.init.constant_(m.bias, 0)
    elif classname.find('Linear') != -1:
        if dense_gaussian:
            nn.init.normal_(m.weight.data, mean=0, std=0.01)
        else:
            kaiming_normal_truncated(
                m.weight.data, a=0, mode='fan_in', nonlinearity='relu'
            )
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        if m.weight is not None:
            m.weight.data.fill_(1)
        m.bias.data.zero_()