# Source code for mixmo.networks.wrn

"""
WideResNet network definition
ported from https://github.com/meliketoy/wide-resnet.pytorch/blob/master/networks/wide_resnet.py
"""

import torch.nn as nn
import torch.nn.functional as F

from mixmo.utils.logger import get_logger
from mixmo.networks.resnet import (PreActResNet, PreActResNetMixMo)

# Module-level logger for this file.
LOGGER = get_logger(__name__, level="DEBUG")

# Momentum for the BatchNorm layers inside the residual blocks.
BATCHNORM_MOMENTUM = 0.1
# Tensorflow setup from edward2
# Momentum for the final BatchNorm (bn1) applied after the residual stages.
BATCHNORM_MOMENTUM_END = 0.1
# All convolutions in this file are created without a bias term.
USE_BIAS = False


class WideBasic(nn.Module):
    """
    Basic wide residual block module (pre-activation style).
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1):
        """Build the two-conv residual branch and its shortcut.

        Args:
            inplanes: number of input channels.
            planes: number of output channels.
            stride: stride of the second convolution (1 keeps resolution).
        """
        super(WideBasic, self).__init__()
        # Pre-activation branch: BN -> ReLU -> conv1 -> BN -> ReLU -> conv2.
        self.bn1 = nn.BatchNorm2d(inplanes, momentum=BATCHNORM_MOMENTUM)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, padding=1, bias=USE_BIAS)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BATCHNORM_MOMENTUM)
        self.stride = stride

        # When downsampling, conv2 uses padding=0 together with an asymmetric
        # (right/bottom) zero pad applied in forward() — TF-style "same" padding.
        conv2_padding = 1 if stride == 1 else 0
        self.pad1 = nn.ZeroPad2d((0, 1, 0, 1))
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride,
            padding=conv2_padding, bias=USE_BIAS)

        self.equalInOut = (inplanes == planes)
        # Identity shortcut when shape is preserved; 1x1 projection otherwise.
        if stride == 1 and self.equalInOut:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride,
                          padding=0, bias=USE_BIAS),
            )

    def forward(self, x):
        """Residual forward pass: branch(x) + shortcut(x)."""
        h = self.conv1(F.relu(self.bn1(x)))
        h = F.relu(self.bn2(h))
        if self.stride != 1 and not self.equalInOut:
            h = self.pad1(h)
        h = self.conv2(h)
        return h + self.shortcut(x)
class WideResNet(PreActResNet):
    """
    Standard WideResNet network

    Mostly re-using PreActResNet API and defining WideResNet specific
    parameters/architecture.
    """

    # Final classifier is a plain dense layer (no Gaussian output layer);
    # flag read by the PreActResNet API.
    dense_gaussian = False

    def _init_block(self, widen_factor):
        """Set the residual block type, per-stage depths and channel widths.

        Args:
            widen_factor: channel multiplier of the WRN (the "k" in WRN-d-k).
        """
        # Standard wide resnet depth check: depth must be 6n + 4.
        assert ((self.depth - 4) % 6 == 0), 'Wide-resnet self.depth should be 6n+4'
        # Integer division: n is a block count. The previous `/` produced a
        # float in Python 3 and silently relied on an int() cast downstream.
        n = (self.depth - 4) // 6
        self._block = WideBasic
        # Three stages of n blocks each.
        self._layers = [n for _ in range(3)]
        self._nChannels = [
            16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor
        ]

    def _make_conv1(self, nb_input_channel):
        """Initial 3x3 stride-1 convolution applied before the residual stages."""
        return nn.Conv2d(
            nb_input_channel, self._nChannels[0],
            kernel_size=3, stride=1, padding=1, bias=USE_BIAS)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack `blocks` residual blocks; only the first one may downsample.

        Args:
            block: block class to instantiate (WideBasic).
            planes: output channel count for every block in the stage.
            blocks: number of blocks in the stage.
            stride: stride of the first block (1 keeps resolution).

        Returns:
            nn.Sequential of the instantiated blocks.
        """
        # int() kept for backward compatibility in case a float count is passed.
        strides = [stride] + [1] * (int(blocks) - 1)
        layers = []
        for current_stride in strides:
            layers.append(block(self.inplanes, planes, current_stride))
            # Every block after the first sees `planes` input channels.
            self.inplanes = planes
        return nn.Sequential(*layers)

    def _init_core_network(self):
        """Build the 3 residual stages plus the final BN and global pooling."""
        PreActResNet._init_core_network(self, max_layer=3)
        self.bn1 = nn.BatchNorm2d(self._nChannels[3], momentum=BATCHNORM_MOMENTUM_END)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def _forward_core_network(self, x):
        """Run the residual stages and return flattened pooled features."""
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = F.relu(self.bn1(x))
        x_avg = self.avgpool(x)
        return x_avg.view(x_avg.size(0), -1)
class WideResNetMixMo(WideResNet):
    """
    Multi-Input Multi-Output WideResNet network

    Mostly re-using ResNetMixMo API and WideResNet architecture/parameters
    """

    def _init_first_layer(self):
        # Delegate first-layer construction to the MixMo ResNet implementation.
        PreActResNetMixMo._init_first_layer(self)

    def _init_final_classifier(self):
        # Delegate classifier construction to the MixMo ResNet implementation.
        PreActResNetMixMo._init_final_classifier(self)

    def _forward_first_layer(self, *inputs, **named_inputs):
        # Pass-through to the MixMo ResNet first-layer forward.
        return PreActResNetMixMo._forward_first_layer(self, *inputs, **named_inputs)

    def _forward_final_classifier(self, *inputs, **named_inputs):
        # Pass-through to the MixMo ResNet classifier forward.
        return PreActResNetMixMo._forward_final_classifier(self, *inputs, **named_inputs)
# Factory mapping configuration names to WideResNet variants (for CIFAR).
wrn_network_factory = {
    "wideresnet": WideResNet,
    "wideresnetmixmo": WideResNetMixMo,
}