"""
Temperature scaling functions and network wrapper module.
Adapted from https://github.com/gpleiss/temperature_scaling/blob/master/temperature_scaling.py
"""
import torch
from torch import nn, optim
from mixmo.utils.logger import get_logger
# Module-level logger used by the temperature-scaling routines in this module.
LOGGER = get_logger(__name__, level="INFO")
def apply_temperature_on_logits(logits, temperature):
    """
    Divide ``logits`` by ``temperature`` (temperature relaxation).

    ``temperature`` is unsqueezed along dim 1 and expanded to
    ``(logits.size(0), logits.size(1))`` before the division. It may therefore
    hold either a single value shared by every entry, or one value per row of
    ``logits``.
    """
    n_rows, n_cols = logits.size(0), logits.size(1)
    per_entry_temperature = temperature.unsqueeze(dim=1).expand(n_rows, n_cols)
    return torch.div(logits, per_entry_temperature)
class NetworkWithTemperature(nn.Module):
    """
    A thin decorator which wraps a network with temperature scaling.

    The wrapped network is expected to return a dict with a "logits" entry.
    ``forward`` divides those logits by a single learnable temperature and keeps
    the unscaled values under "logits_prescaled".

    Args:
        network (nn.Module): wrapped network; its output dict must contain "logits".
        temperature (float, optional): initial temperature. Falsy values
            (None, 0) fall back to ``default_temperature``.
        device (optional): device the whole wrapper (network and temperature)
            is moved to every time the temperature is (re)set.
    """

    default_temperature = 1.0

    def __init__(self, network, temperature=None, device=None):
        nn.Module.__init__(self)
        self.network = network
        # `or` makes both None and 0 fall back to the default temperature.
        self._init_temperature = temperature or self.default_temperature
        self.device = device
        self.set_temperature(self._init_temperature)

    def set_temperature(self, temperature):
        """
        Replace ``self.temperature`` with a fresh 1-element learnable parameter.

        A new ``nn.Parameter`` is created on every call. The whole module,
        including the wrapped network, is then moved to ``self.device``.
        """
        self.temperature = nn.Parameter(
            torch.ones(1) * temperature,
            requires_grad=True
        )
        self.to(self.device)

    def forward(self, input):
        """Run the wrapped network and temperature-scale its "logits" output."""
        out = self.network(input)
        out = self.apply_temperature(out, self.temperature)
        return out

    @staticmethod
    def apply_temperature(output, temperature):
        """
        Apply temperature scaling to outputs.

        Mutates and returns ``output``: the original logits are kept under
        "logits_prescaled" and "logits" is replaced by the scaled version.
        """
        output["logits_prescaled"] = output["logits"]
        output["logits"] = apply_temperature_on_logits(
            logits=output["logits_prescaled"], temperature=temperature
        )
        return output

    def learn_temperature_gridsearch(self, valid_loader, lrs, max_iters):
        """
        Grid-search LBFGS settings and keep the temperature with the lowest NLL.

        Validation logits are computed once and reused for every
        ``(lr, max_iter)`` pair. After each trial the temperature is reset to
        the initial temperature. At the end, the temperature is set to the best
        one found. If no trial lowered the unscaled NLL, it is set to
        ``default_temperature``.

        Args:
            valid_loader: iterable of ``(input, target)`` batches.
            lrs (iterable of float): LBFGS learning rates to try.
            max_iters (iterable of int): LBFGS ``max_iter`` values to try.
        """
        best_temperature = self.default_temperature
        valid_loader_processed = self._prepare_training(valid_loader)
        before_temperature_nll = valid_loader_processed[2]
        best_nll = before_temperature_nll
        # max_iters is iterated once per lr; materialize it so a generator is
        # not exhausted after the first lr.
        max_iters = list(max_iters)
        for lr in lrs:
            for max_iter in max_iters:
                temperature, nll = self._learn_temperature(
                    valid_loader=None,
                    valid_loader_processed=valid_loader_processed,
                    lr=lr,
                    max_iter=max_iter
                )
                if best_nll > nll:
                    best_nll = nll
                    best_temperature = temperature
                # Start the next trial from the initial temperature again.
                self.set_temperature(self._init_temperature)
        assert best_nll <= before_temperature_nll, "temperature scaling failed because nll increased"
        LOGGER.warning(f"Selecting temperature: {best_temperature:.5f} - nll : {best_nll:.5f}")
        self.set_temperature(temperature=best_temperature)

    def _prepare_training(self, valid_loader):
        """
        Run the wrapped network once over ``valid_loader`` and cache its logits.

        Also (re)creates ``self.nll_criterion`` (cross-entropy), which
        ``_learn_temperature`` relies on.

        Returns:
            tuple: ``(logits, targets, before_temperature_nll)``. ``logits`` and
            ``targets`` are concatenated over all batches and moved to
            ``self.device``; ``before_temperature_nll`` is the cross-entropy
            of the unscaled logits, as a Python float.
        """
        self.nll_criterion = nn.CrossEntropyLoss().to(self.device)
        # First: collect all the logits and targets for the validation set.
        # NOTE(review): the network's train/eval mode is not changed here;
        # callers presumably put it in eval mode beforehand.
        logits_list = []
        targets_list = []
        with torch.no_grad():
            for batch_input, target in valid_loader:
                logits = self.network(batch_input.to(self.device))["logits"]
                logits_list.append(logits)
                targets_list.append(target)
        logits = torch.cat(logits_list).to(self.device)
        targets = torch.cat(targets_list).to(self.device)
        # Calculate NLL before temperature scaling
        before_temperature_nll = self.nll_criterion(logits, targets).item()
        LOGGER.debug(f'Before temperature - nll: {before_temperature_nll:.5f}')
        return logits, targets, before_temperature_nll

    def _learn_temperature(self, valid_loader, lr, max_iter, valid_loader_processed=None):
        """
        Tune the temperature on the validation set by minimizing NLL with LBFGS.

        Optimization starts from the current ``self.temperature`` and updates it
        in place.

        Args:
            valid_loader (DataLoader): validation set loader; ignored when
                ``valid_loader_processed`` is given.
            lr (float): LBFGS learning rate.
            max_iter (int): LBFGS ``max_iter`` for the single optimizer step.
            valid_loader_processed (tuple, optional): cached
                ``(logits, targets, before_temperature_nll)`` as returned by
                ``_prepare_training``.

        Returns:
            tuple: ``(temperature, nll)``. If the optimized NLL is not lower than
            or equal to the unscaled NLL (including a NaN NLL from a diverged
            run), ``(default_temperature, before_temperature_nll)`` is returned
            instead. In that case ``self.temperature`` is still left at the
            optimized value.
        """
        if valid_loader_processed is not None:
            logits, targets, before_temperature_nll = valid_loader_processed
        else:
            logits, targets, before_temperature_nll = self._prepare_training(valid_loader)

        # Next: optimize the temperature w.r.t. NLL
        optimizer = optim.LBFGS([self.temperature], lr=lr, max_iter=max_iter)

        def closure():
            # LBFGS re-evaluates the closure several times within one step;
            # gradients must be cleared each time or they accumulate.
            optimizer.zero_grad()
            loss = self.nll_criterion(
                apply_temperature_on_logits(logits, self.temperature),
                targets)
            loss.backward()
            return loss

        optimizer.step(closure)

        # Calculate NLL after temperature scaling (evaluation only, no graph).
        with torch.no_grad():
            after_temperature_nll = self.nll_criterion(
                apply_temperature_on_logits(logits, self.temperature),
                targets).item()
        LOGGER.debug(
            f'With lr: {lr:.6f}, temperature {self.temperature.item():.5f} - nll: {after_temperature_nll:.5f}'
        )
        # `not <=` (rather than `>`) also rejects a NaN NLL.
        if not after_temperature_nll <= before_temperature_nll:
            LOGGER.error(f"Temperature scaling failed for lr: {lr:.6f}")
            return self.default_temperature, before_temperature_nll
        temperature = self.temperature.detach().to("cpu").numpy()[0]
        return temperature, after_temperature_nll