'''
Resnet for cifar dataset.
Adapted from
https://github.com/facebook/fb.resnet.torch
and
https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
(c) YANG, Wei
'''
import torch
import torch.nn as nn
from torch.nn import functional as F
from mixmo.augmentations import mixing_blocks
from mixmo.utils import torchutils
from mixmo.utils.logger import get_logger
LOGGER = get_logger(__name__, level="DEBUG")
# Momentum used by every BatchNorm2d inside the pre-activation blocks.
BATCHNORM_MOMENTUM_PREACT = 0.1
class PreActBlock(nn.Module):
    """Basic residual block in pre-activation order (BN -> ReLU -> conv).

    A 1x1 projection shortcut is registered only when the block changes the
    spatial resolution (stride != 1) or the channel count; otherwise the
    identity input is used as the residual branch.
    """

    # Output channels = expansion * planes (1 for basic blocks).
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, **kwargs):
        super(PreActBlock, self).__init__()
        out_planes = self.expansion * planes
        # Keep module creation order (conv1, bn1, conv2, bn2, shortcut) so the
        # random-initialization sequence matches the reference implementation.
        self.conv1 = nn.Conv2d(
            inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        # bn1 normalizes the block *input* (pre-activation ordering).
        self.bn1 = nn.BatchNorm2d(inplanes, momentum=BATCHNORM_MOMENTUM_PREACT)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BATCHNORM_MOMENTUM_PREACT)
        if stride != 1 or inplanes != out_planes:
            # Projection shortcut to match shape of the residual branch.
            self.shortcut = nn.Sequential(
                nn.Conv2d(inplanes, out_planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        preact = F.relu(self.bn1(x))
        # The projection (when present) consumes the pre-activated tensor;
        # the identity shortcut bypasses BN/ReLU entirely.
        if hasattr(self, 'shortcut'):
            residual = self.shortcut(preact)
        else:
            residual = x
        out = self.conv2(F.relu(self.bn2(self.conv1(preact))))
        out += residual
        return out
class PreActResNet(nn.Module):
    """
    Pre-activated ResNet network.

    Built from two config dicts:
      - config_network: {"depth": int, "widen_factor": number}
      - config_args: {"data": {"num_classes": int}, "num_members": int, ...}
    """

    # Forwarded to torchutils.weights_init_hetruncatednormal for dense layers.
    # Defined at class level so it exists before _init_weights_resnet runs:
    # it was previously a local variable inside _init_final_classifier and was
    # never stored on the instance, so reading self.dense_gaussian raised
    # AttributeError during weight initialization.
    dense_gaussian = True

    def __init__(self, config_network, config_args):
        nn.Module.__init__(self)
        self.config_network = config_network
        self.config_args = config_args
        self._define_config()
        self._init_first_layer()
        self._init_core_network()
        self._init_final_classifier()
        self._init_weights_resnet()
        LOGGER.warning("Features dimension: {features_dim}".format(features_dim=self.features_dim))

    def _define_config(self):
        """
        Initialize network parameters from specified config
        """
        # network config
        self.num_classes = self.config_args["data"]["num_classes"]
        self.depth = self.config_network["depth"]
        self._init_block(widen_factor=self.config_network["widen_factor"])

    def _init_block(self, widen_factor):
        """
        Build list of residual blocks for networks on the CIFAR datasets
        Network type specifies number of layers for CIFAR network
        """
        blocks = {
            18: PreActBlock,
        }
        layers = {
            18: [2, 2, 2, 2],
        }
        # Membership check: indexing an unknown depth would raise a bare
        # KeyError before the informative assertion message could be shown.
        assert self.depth in layers, (
            'invalid depth for ResNet (self.depth should be one of 18, 34, 50, 101, 152, and 200)'
        )
        self._layers = layers[self.depth]
        self._block = blocks[self.depth]
        assert widen_factor in [1., 2., 3.]
        # Channel widths: stem, then one entry per residual stage.
        # Cast to int: widen_factor may be a float (see assert above) and
        # nn.Conv2d / nn.Linear require integer channel counts.
        self._nChannels = [
            64,
            int(64 * widen_factor), int(128 * widen_factor),
            int(256 * widen_factor), int(512 * widen_factor)
        ]

    def _init_first_layer(self):
        # Single-input variant: exactly one encoder head (see the MixMo
        # subclass for the multi-member version).
        assert self.config_args["num_members"] == 1
        self.conv1 = self._make_conv1(nb_input_channel=3)

    def _init_core_network(self, max_layer=4):
        """
        Build the core of the Residual network (residual blocks)
        """
        self.inplanes = self._nChannels[0]
        self.layer1 = self._make_layer(self._block, planes=self._nChannels[1],
                                       blocks=self._layers[0], stride=1)
        self.layer2 = self._make_layer(self._block, planes=self._nChannels[2],
                                       blocks=self._layers[1], stride=2)
        self.layer3 = self._make_layer(self._block, planes=self._nChannels[3],
                                       blocks=self._layers[2], stride=2)
        if max_layer == 4:
            self.layer4 = self._make_layer(
                self._block, self._nChannels[4], blocks=self._layers[3], stride=2)
        self.features_dim = self._nChannels[-1] * self._block.expansion

    def _make_conv1(self, nb_input_channel):
        """Build the stem convolution mapping input channels to _nChannels[0]."""
        # stride=2 halves the spatial resolution up front; per the factory
        # comment below, this variant targets TinyImageNet (64x64 inputs).
        conv1 = nn.Conv2d(
            nb_input_channel, self._nChannels[0], kernel_size=3, stride=2, padding=1, bias=False
        )
        return conv1

    def _make_layer(
        self,
        block,
        planes,
        blocks,
        stride=1,
    ):
        """
        Build a layer of successive (residual) blocks
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = []
        # NOTE(review): PreActBlock builds its own shortcut and swallows
        # `downsample` via **kwargs; kept for compatibility with other block
        # types that may be registered in the blocks dict.
        layers.append(block(
            inplanes=self.inplanes,
            planes=planes,
            stride=stride,
            downsample=downsample)
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def _init_final_classifier(self):
        """
        Build linear classification head
        """
        self.fc = nn.Linear(self.features_dim, self.num_classes)

    def _init_weights_resnet(self):
        """
        Apply specified random initializations to all modules of the network
        """
        for m in self.modules():
            torchutils.weights_init_hetruncatednormal(m, dense_gaussian=self.dense_gaussian)

    def forward(self, x):
        """Run the network; accepts raw pixels or a {"pixels", "metadata"} dict."""
        if isinstance(x, dict):
            metadata = x["metadata"] or {}
            pixels = x["pixels"]
        else:
            # Bare-tensor input is treated as inference.
            metadata = {"mode": "inference"}
            pixels = x
        merged_representation = self._forward_first_layer(pixels, metadata)
        extracted_features = self._forward_core_network(merged_representation)
        dict_output = self._forward_final_classifier(extracted_features)
        return dict_output

    def _forward_first_layer(self, pixels, metadata=None):
        return self.conv1(pixels)

    def _forward_core_network(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # assumes the final feature map is 4x4 so pooling yields 1x1 — TODO confirm
        x_avg = F.avg_pool2d(x, 4)
        return x_avg.view(x_avg.size(0), -1)

    def _forward_final_classifier(self, extracted_features):
        x = self.fc(extracted_features)
        # "logits_0" duplicates "logits" so single-member networks expose the
        # same output keys as the multi-member MixMo variant.
        dict_output = {"logits": x, "logits_0": x}
        return dict_output
class PreActResNetMixMo(PreActResNet):
    """
    Multi-Input Multi-Output ResNet network.

    Replaces the single encoder/classifier of PreActResNet with M parallel
    input heads and M parallel output heads (M = config_args["num_members"]),
    mixing the M embeddings in the shared representation space.
    """

    def _init_first_layer(self):
        """
        Initialize the M input heads/encoders
        """
        num_members = self.config_args["num_members"]
        self.list_conv1 = nn.ModuleList(
            [self._make_conv1(nb_input_channel=3) for _ in range(num_members)]
        )

    def _init_final_classifier(self):
        """
        Initialize the M output heads/classifiers
        """
        num_members = self.config_args["num_members"]
        self.list_fc = nn.ModuleList(
            [nn.Linear(self.features_dim, self.num_classes) for _ in range(num_members)]
        )

    def _forward_first_layer(self, pixels, metadata):
        if not metadata:
            metadata = {}
        # Embed the M inputs into the shared space.
        embeddings = []
        for member in range(self.config_args["num_members"]):
            if pixels.size(1) == 3:
                # One shared image for every member.
                member_pixels = pixels
            else:
                # Channel-stacked inputs: 3 channels per member.
                member_pixels = pixels[:, 3 * member:3 * (member + 1)]
            embeddings.append(self.list_conv1[member](member_pixels))
        # Mix the M embeddings in the shared space.
        return mixing_blocks.mix_manifolds(embeddings, metadata=metadata)

    def _forward_final_classifier(self, extracted_features):
        dict_output = {}
        member_logits = []
        # Per-member logits from the shared extracted features.
        for member in range(self.config_args["num_members"]):
            logits = self.list_fc[member](extracted_features)
            dict_output["logits_" + str(member)] = logits
            member_logits.append(logits)
        # Ensemble prediction: plain average over members.
        dict_output["logits"] = torch.stack(member_logits, dim=0).mean(dim=0)
        return dict_output
# Maps network-name strings (from the experiment config) to model classes.
resnet_network_factory = {
    # For TinyImageNet
    "resnet": PreActResNet,
    "resnetmixmo": PreActResNetMixMo,
}