# Source code for mixmo.core.loss

"""
Base loss definitions
"""
from collections import OrderedDict
import copy
import torch
import torch.nn as nn

from mixmo.utils import misc, logger

# Module-level logger for all loss classes in this file (verbosity fixed to DEBUG).
LOGGER = logger.get_logger(__name__, level="DEBUG")


class AbstractLoss(nn.modules.loss._Loss):
    """
    Base loss class defining printing and logging utilities.

    Subclasses implement `_forward`; this class tracks a running mean of the
    loss values seen since the last `start_accumulator` call.
    """

    def __init__(self, config_args, device, config_loss=None):
        # NOTE(review): plain attributes are set before nn.Module.__init__;
        # this works for non-Parameter values but relies on that ordering.
        self.device = device
        self.config_args = config_args or {}
        self.config_loss = config_loss or {}
        self.name = self.config_loss["display_name"]
        nn.modules.loss._Loss.__init__(self)

    def print_details(self):
        LOGGER.info(f"Using loss: {self.config_loss} with name: {self.name}")

    def start_accumulator(self):
        # Reset the running sum and count used to compute the mean loss.
        self._accumulator_loss = 0
        self._accumulator_len = 0

    def get_accumulator_stats(self, format="short", split=None):
        """
        Gather tracked stats into a dictionary as formatted strings.
        """
        if not self._accumulator_len:
            return {}
        mean_loss = self._accumulator_loss / self._accumulator_len
        stats = OrderedDict()
        if format == "long":
            # long format is keyed by split for e.g. tensorboard logging
            assert split is not None
            stats[f"{split}/{self.name}"] = {
                "value": mean_loss,
                "string": f"{mean_loss:.5}",
            }
        else:
            # make it as short as possible to fit on one line of tqdm postfix
            stats[self.name] = f"{mean_loss:.3}".replace("e-0", "-").replace("e-", "-")
        return stats

    def forward(self, input, target):
        # Delegate to the subclass loss, then fold the detached value
        # into the running accumulator.
        loss = self._forward(input, target)
        self._accumulator_loss += loss.detach().to("cpu").numpy()
        self._accumulator_len += 1
        return loss

    def _forward(self, input, target):
        raise NotImplementedError
class SoftCrossEntropyLoss(AbstractLoss):
    """
    Soft CrossEntropy loss that specifies the proper forward function for AbstractLoss
    """

    def _forward(self, input, target):
        """Cross entropy that accepts soft targets.

        Args:
            input: raw (pre-softmax) network predictions, shape (batch, num_classes)
            target: either hard integer labels of shape (batch,) or soft class
                distributions of shape (batch, num_classes)

        Returns:
            Scalar tensor: mean over the batch of -sum(target * log_softmax(input)).

        Examples::

            input = torch.FloatTensor([[1.1, 2.8, 1.3], [1.1, 2.1, 4.8]])
            target = torch.FloatTensor([[0.05, 0.9, 0.05], [0.05, 0.05, 0.9]])
            loss = SoftCrossEntropyLoss._forward(self, input, target)
            loss.backward()
        """
        if len(target.size()) == 1:
            # hard integer labels -> one-hot soft targets
            target = torch.nn.functional.one_hot(target, num_classes=input.size(-1))
        target = target.to(torch.float).to(self.device)
        # Use the functional API instead of instantiating a fresh
        # nn.LogSoftmax module on every forward call (same result, no
        # per-call module allocation).
        log_probs = torch.nn.functional.log_softmax(input, dim=1)
        return torch.mean(torch.sum(-target * log_probs, dim=1))
# Registry mapping config loss names to standard (non-wrapped) loss classes,
# consumed by get_loss below.
DICT_LOSS_STANDARD = {
    "soft_cross_entropy": SoftCrossEntropyLoss,
}
class WrapperLoss(AbstractLoss):
    """
    Wrapper around the multiple losses. Initialized from listloss.
    """

    def __init__(self, config_loss, config_args, device):
        AbstractLoss.__init__(
            self, config_args=config_args, config_loss=config_loss, device=device,
        )
        self.losses = self._init_get_losses()
        self.regularized_network = None

    def _init_get_losses(self):
        """
        Initialize and gather losses from listloss.
        """
        active = []
        for index, subconfig in enumerate(self.config_loss["listloss"]):
            # "<num_members" is a placeholder coeff: enable the loss only for
            # the first `num_members` entries.
            if subconfig["coeff"] == "<num_members":
                subconfig["coeff"] = (1. if index < self.config_args["num_members"] else 0)
            if subconfig["coeff"] == 0:
                LOGGER.debug(f"Skip loss: {subconfig}")
                continue
            built = get_loss(subconfig, device=self.device, config_args=self.config_args)
            entry = copy.deepcopy(subconfig)
            entry["callable"] = built
            active.append(entry)
        return active

    def print_details(self):
        # Sublosses print their own details when constructed.
        return

    def start_accumulator(self):
        AbstractLoss.start_accumulator(self)
        for subloss in self.losses:
            subloss["callable"].start_accumulator()

    def get_accumulator_stats(self, format="short", split=None):
        """
        Gather tracked stats into a dictionary as formatted strings.
        """
        if not self._accumulator_len:
            return {}
        stats = AbstractLoss.get_accumulator_stats(self, format=format, split=split)
        if format == "long":
            # tensorboard logs
            if self.config_loss.get("l2_reg"):
                reg_value = self.l2_reg().detach().to("cpu").numpy()
                stats["general/l2_reg"] = {
                    "value": reg_value,
                    "string": f"{reg_value:.4}",
                }
        for subloss in self.losses:
            substats = subloss["callable"].get_accumulator_stats(format=format, split=split)
            misc.clean_update(stats, substats)
        return stats

    def _forward(self, input, target):
        """
        Perform loss forwards for each subloss and l2 reg.
        """
        total = torch.stack(
            [self._forward_subloss(subloss, input, target) for subloss in self.losses]
        ).sum()
        l2_coeff = self.config_loss.get("l2_reg")
        if l2_coeff:
            total = total + self.l2_reg() * float(l2_coeff)
        return total

    def _forward_subloss(self, loss, input, target):
        """
        Standard loss forward for one of the sublosses.
        """
        subloss_input = self._match_item(loss["input"], dict_tensors=input)
        subloss_target = self._match_item(loss["target"], dict_tensors=target)
        computed = loss["callable"](input=subloss_input, target=subloss_target)
        return computed * float(loss["coeff"])

    @staticmethod
    def _match_item(name, dict_tensors):
        # None-like names mean "no tensor for this slot".
        if misc.is_none(name):
            return None
        if name in dict_tensors:
            return dict_tensors[str(name)]
        raise ValueError(name)

    def set_regularized_network(self, network):
        # Only track the network when l2 regularization is actually enabled.
        if self.config_loss.get("l2_reg"):
            self.regularized_network = network
            LOGGER.warning(f"Set l2 regularization on {network.__class__.__name__}")

    def l2_reg(self):
        """
        Compute l2 regularization/weight decay over the non-excluded parameters.
        """
        assert self.regularized_network is not None
        penalty = None
        for weight in self.regularized_network.parameters():
            squared = torch.sum(torch.pow(weight, 2))
            penalty = squared if penalty is None else penalty + squared
        # A network with no parameters would make the penalty meaningless.
        assert penalty is not None
        return penalty
def get_loss(config_loss, device=None, config_args=None):
    """
    Construct loss object, wrapped if there are multiple losses.
    """
    name = config_loss["name"]
    if name == "multitask":
        # A multitask config carries a listloss; wrap the sublosses.
        built = WrapperLoss(config_args=config_args, device=device, config_loss=config_loss)
    elif name in DICT_LOSS_STANDARD:
        loss_cls = DICT_LOSS_STANDARD[name]
        built = loss_cls(config_loss=config_loss, config_args=config_args, device=device)
    else:
        raise Exception(f"Loss {name} not implemented")
    built.print_details()
    return built