"""
Utility definitions to wrap a model with losses, metrics and logs
"""
import copy
import torch.nn.functional as F
from collections import OrderedDict
from mixmo.networks import get_network
from mixmo.core import (
loss, optimizer, temperature_scaling, scheduler,
metrics_wrapper)
from mixmo.utils import logger, misc, torchsummary
from mixmo.utils.config import cfg
# Module-level logger, named after this module and created at DEBUG level.
LOGGER = logger.get_logger(__name__, level="DEBUG")
def get_predictions(logits):
    """
    Turn raw logits into softmax-based predictions.

    Returns a ``(confidence, pred, probs)`` triple:
      * ``probs`` is the softmax of ``logits`` over dimension 1;
      * ``confidence`` / ``pred`` are the largest probability and its index
        along dimension 1, with that dimension kept (size 1).
    """
    probs = F.softmax(logits, dim=1)
    best = probs.max(dim=1, keepdim=True)
    return best.values, best.indices, probs
class ModelWrapper:
    """Augment a model with losses, metrics, internal logs and other things.

    The wrapper owns the network together with its loss, optimizer, learning
    rate scheduler(s) and a metrics accumulator, and switches them
    consistently between train and eval modes.
    """

    def __init__(self, config, config_args, device):
        """
        Args:
            config: model configuration mapping; this class reads the keys
                "name", "network", "loss", "optimizer", "lr_schedule",
                "warmup_period", "metrics" and "metrics_only_test".
            config_args: configuration arguments forwarded to the network and
                loss factories ("num_members" is also read directly).
            device: device the network is moved to.
        """
        self.config = config
        self.name = config["name"]
        self.config_args = config_args
        self.device = device
        # Stays "notinit" until to_train_mode / to_eval_mode is called.
        self.mode = "notinit"
        self._init_main()

    def _init_main(self):
        """Build the network, the loss and the optimizer from the config."""
        self.network = get_network(
            config_network=self.config["network"], config_args=self.config_args
        ).to(self.device)
        # Set by calibrate_via_tempscale; None means "use the raw network".
        self._scaled_network = None
        # Schedulers are built lazily on the first to_train_mode call because
        # they depend on the starting epoch; both start as None so they can be
        # read safely before that.
        self.scheduler = None
        self.warmup_scheduler = None
        self._scheduler_initialized = False
        self.loss = loss.get_loss(
            config_loss=self.config.get("loss"),
            config_args=self.config_args, device=self.device
        )
        # Losses exposing this hook get a handle on the network
        # (presumably for regularization terms).
        if hasattr(self.loss, "set_regularized_network"):
            self.loss.set_regularized_network(self.network)
        self.optimizer = optimizer.get_optimizer(
            optimizer=self.config["optimizer"],
            list_param_groups=[{"params": list(self.network.parameters())}]
        )

    def to_eval_mode(self):
        """
        Switch model to eval mode: network.eval(), fresh loss accumulator and
        fresh metrics (including "metrics_only_test").
        """
        self.mode = "eval"
        self.network.eval()
        self.loss.start_accumulator()
        self._init_metrics()

    def to_train_mode(self, epoch):
        """
        Switch model to train mode.

        On the first call, also builds the LR scheduler(s) for ``epoch``.
        """
        self.mode = "train"
        if not self._scheduler_initialized:
            self._init_scheduler(epoch)
        self.network.train()
        self.loss.start_accumulator()
        self._init_metrics()

    def _init_scheduler(self, epoch):
        """
        Create the LR scheduler, plus a warmup scheduler when training starts
        at epoch 1 with a positive "warmup_period".
        """
        self.scheduler = scheduler.get_scheduler(
            lr_schedule=self.config["lr_schedule"],
            optimizer=self.optimizer,
            start_epoch=epoch - 2,
        )
        # NOTE(review): start_epoch=epoch - 2 followed by an immediate step()
        # presumably aligns the schedule with the current epoch -- confirm
        # against scheduler.get_scheduler.
        self.scheduler.step()
        # ``or 0`` also treats an explicit ``warmup_period: None`` as no warmup
        # (comparing None > 0 would raise TypeError).
        warmup_period = self.config.get("warmup_period") or 0
        if epoch == 1 and warmup_period > 0:
            LOGGER.warning("Warmup period")
            self.warmup_scheduler = scheduler.get_warmup_scheduler(
                optimizer=self.optimizer,
                warmup_period=warmup_period)
        else:
            self.warmup_scheduler = None
        self._scheduler_initialized = True

    def _init_metrics(self):
        """Reset the metrics accumulator; eval mode adds "metrics_only_test"."""
        # Always build a fresh list so the config's own list is never handed
        # out (and any iterable is accepted for "metrics_only_test").
        metrics = list(self.config["metrics"])
        if self.mode == "eval":
            metrics.extend(self.config.get("metrics_only_test", []))
        self._metrics = metrics_wrapper.MetricsWrapper(metrics=metrics)

    def print_summary(self, pixels_size=32):
        """
        Summarize the network with torchsummary for an input of shape
        ``(3 * num_members, pixels_size, pixels_size)``.

        Best effort: a failure is logged as a warning, never raised.
        """
        summary_input = (3 * self.config_args["num_members"], pixels_size, pixels_size)
        try:
            torchsummary.summary(self.network, summary_input, list_dtype=None)
        except Exception:
            # Not a bare ``except:`` -- that would also swallow
            # KeyboardInterrupt and SystemExit.
            LOGGER.warning("Torch summary failed", exc_info=True)

    def step(self, output, target, backprop=False):
        """
        Compute loss, backward step and metrics if required by config.

        Updates internal records: the loss accumulator (through the loss),
        the metrics, and, outside train mode, the diversity records.

        Args:
            output: dict of network outputs; "logits_0" is scored in train
                mode and "logits" otherwise ("logits_<i>" per member are read
                for diversity).
            target: dict of targets; only "target_0" feeds the metrics. A 2-D
                target (presumably soft / one-hot labels) is reduced by argmax
                over dimension 1.
            backprop: whether to call ``backward`` on the loss.
        """
        current_loss = self.loss(output, target)
        if backprop:
            current_loss.backward(retain_graph=False)

        logits = output["logits_0" if self.mode == "train" else "logits"]
        confidence, pred, probs = get_predictions(logits)

        target = target["target_0"]
        if target.dim() == 2:
            target = target.argmax(dim=1)
        self._metrics.update(pred, target, confidence, probs)

        if self.mode != "train":
            self._compute_diversity(output, target)

    def _compute_diversity(self, output, target):
        """
        Record each member's hard predictions for the diversity metrics.

        Only meaningful with several members; a no-op when "num_members" <= 1.
        """
        num_members = self.config_args["num_members"]
        if num_members <= 1:
            return
        # Per-member argmax class over dimension 1, as CPU numpy arrays.
        predictions = [
            output["logits_" + str(head)].max(dim=1, keepdim=False)[1].detach().to("cpu").numpy()
            for head in range(num_members)
        ]
        self._metrics.update_diversity(
            target=[int(t) for t in target.detach().to("cpu").numpy()],
            predictions=predictions,
        )

    def get_short_logs(self):
        """
        Return the short-format summary of the loss accumulator statistics.
        """
        return self.loss.get_accumulator_stats(format="short", split=None)

    def get_dict_to_scores(self, split,):
        """
        Format logs into an ordered dictionary.

        Contains the current learning rate (train split only), the long-format
        loss statistics for ``split`` and, in eval mode, the metric scores.
        """
        logs_dict = OrderedDict()
        if split == "train":
            lr_value = self.optimizer.param_groups[0]["lr"]
            logs_dict[f"general/{self.name}_lr"] = {
                "value": lr_value,
                "string": f"{lr_value:05.5}",
            }
        misc.clean_update(logs_dict, self.loss.get_accumulator_stats(format="long", split=split))
        if self.mode == "eval":
            LOGGER.info(f"Compute metrics for {self.name} at split: {split}")
            scores = self._metrics.get_scores(split=split)
            for s in scores:
                logs_dict[s] = scores[s]
        return logs_dict

    def predict(self, data):
        """
        Forward ``data`` through the model (temperature-scaled once
        calibrate_via_tempscale has run) and return the output.
        """
        return self.scaled_network(data)

    @property
    def scaled_network(self):
        """
        The temperature-scaled network if calibrate_via_tempscale has been
        run, otherwise the raw network.
        """
        if self._scaled_network is None:
            return self.network
        return self._scaled_network

    def calibrate_via_tempscale(self, tempscale_loader):
        """
        Learn a softmax temperature on ``tempscale_loader`` by grid search.

        Switches to eval mode, wraps the network in a temperature-scaled
        network that ``predict`` uses from then on, and returns the learned
        temperature (first element, as a numpy scalar).
        """
        self.to_eval_mode()
        self._scaled_network = temperature_scaling.NetworkWithTemperature(
            network=self.network, device=self.device
        )
        self._scaled_network.learn_temperature_gridsearch(
            valid_loader=tempscale_loader,
            lrs=cfg.CALIBRATION.LRS,
            max_iters=cfg.CALIBRATION.MAX_ITERS
        )
        return self._scaled_network.temperature.cpu().detach().numpy()[0]