# Source code for mixmo.learners.abstract_learner

"""
Base Learner wrapper definitions for logging, training and evaluating models
"""

import torch
from collections import OrderedDict
from torch.utils.tensorboard import SummaryWriter

from mixmo.utils import misc, logger, config
from mixmo.learners import model_wrapper


# Module-level logger for this file (INFO level); used to report checkpoint saves.
LOGGER = logger.get_logger(__name__, level="INFO")


class AbstractLearner:
    """
    Base learner class that groups models, optimizers and loggers.

    Performs the entire model building, training and evaluating process.
    """

    def __init__(self, config_args, dloader, device):
        """
        Args:
            config_args: configuration dict; expects at least the keys
                "model_wrapper" and "training" (with "output_folder").
            dloader: data loader bundle exposing ``properties``,
                ``val_loader`` and ``test_loader``.
            device: torch device used for the model and checkpoint loading.
        """
        self.config_args = config_args
        self.device = device
        self.dloader = dloader
        # Tensorboard writer is created lazily (see the `tb_logger` property)
        self._tb_logger = None
        self._create_model_wrapper()
        # Best ensemble test accuracy seen so far, used for checkpointing
        self._best_acc = 0
        self._best_epoch = 0

    def _create_model_wrapper(self):
        """
        Initialize the model along with other elements through a ModelWrapper
        """
        self.model_wrapper = model_wrapper.ModelWrapper(
            config=self.config_args["model_wrapper"],
            config_args=self.config_args,
            device=self.device
        )
        self.model_wrapper.to_eval_mode()
        self.model_wrapper.print_summary(
            pixels_size=self.dloader.properties("pixels_size")
        )

    @property
    def tb_logger(self):
        """
        Get (or lazily initialize) the Tensorboard SummaryWriter
        """
        if self._tb_logger is None:
            self._tb_logger = SummaryWriter(
                log_dir=self.config_args["training"]["output_folder"]
            )
        return self._tb_logger

    def save_tb(self, logs_dict, epoch):
        """
        Write stats from logs_dict at epoch to the Tensorboard summary writer
        """
        for tag in logs_dict:
            self.tb_logger.add_scalar(tag, logs_dict[tag]["value"], epoch)
        # Mirror the plain accuracy under the diversity tag so that runs
        # without ensemble diversity metrics still populate that curve
        if "test/diversity_accuracy_mean" not in logs_dict:
            self.tb_logger.add_scalar(
                "test/diversity_accuracy_mean",
                logs_dict["test/accuracy"]["value"],
                epoch
            )

    def load_checkpoint(self, checkpoint, include_optimizer=True, return_epoch=False):
        """
        Load checkpoint (and optimizer if included) to the wrapped model.

        Args:
            checkpoint: path to the checkpoint file.
            include_optimizer: if True, also restore the optimizer state.
            return_epoch: if True, return the epoch stored in the checkpoint.
        """
        checkpoint = torch.load(checkpoint, map_location=self.device)
        self.model_wrapper.network.load_state_dict(
            checkpoint[self.model_wrapper.name + "_state_dict"], strict=True)
        if include_optimizer:
            if self.model_wrapper.optimizer is not None:
                self.model_wrapper.optimizer.load_state_dict(
                    checkpoint[self.model_wrapper.name + "_optimizer_state_dict"])
            else:
                # Without an optimizer instance, the checkpoint must not carry
                # optimizer state: otherwise it would be silently dropped
                assert self.model_wrapper.name + "_optimizer_state_dict" not in checkpoint
        if return_epoch:
            return checkpoint["epoch"]

    def save_checkpoint(self, epoch, save_path=None):
        """
        Save model (and optimizer) state dict.

        Args:
            epoch: epoch number stored in the checkpoint; may be None, in
                which case an explicit save_path is required.
            save_path: destination file; derived from the output folder and
                epoch when omitted.
        """
        # Bug fix: initialize dict_to_save unconditionally so that calling
        # with epoch=None no longer raises NameError further down
        dict_to_save = {} if epoch is None else {"epoch": epoch}
        if save_path is None:
            assert epoch is not None
            save_path = misc.get_model_path(
                self.config_args["training"]["output_folder"], epoch=epoch
            )

        # NOTE(review): the original conditional had byte-identical branches
        # for DataParallel and plain networks; it presumably intended
        # `network.module.state_dict()` in the DataParallel case. Collapsed to
        # a single call to preserve the saved key layout -- confirm before
        # unwrapping `.module`, since load_checkpoint loads with strict=True.
        dict_to_save[self.model_wrapper.name + "_state_dict"] = (
            self.model_wrapper.network.state_dict())
        if self.model_wrapper.optimizer is not None:
            dict_to_save[self.model_wrapper.name + "_optimizer_state_dict"] = (
                self.model_wrapper.optimizer.state_dict())

        # final save
        torch.save(dict_to_save, save_path)

    def train_loop(self, epoch):
        """Run one epoch of training over the train loader (subclass hook)."""
        raise NotImplementedError

    def train(self, epoch):
        """
        Train for one epoch: run the training loop, evaluate on val/test,
        log metrics (stdout, CSV, Tensorboard) and checkpoint on improvement
        """
        self.model_wrapper.to_train_mode(epoch=epoch)

        # Train over the entire epoch
        self.train_loop(epoch)

        # Eval on epoch end
        logs_dict = OrderedDict(
            {
                "epoch": {"value": epoch, "string": f"{epoch}"},
            }
        )
        scores = self.model_wrapper.get_dict_to_scores(split="train")
        for s in scores:
            logs_dict[s] = scores[s]

        ## Val scores
        if self.dloader.val_loader is not None:
            val_scores = self.evaluate(
                inference_loader=self.dloader.val_loader, split="val")
            for val_score in val_scores:
                logs_dict[val_score] = val_scores[val_score]

        ## Test scores
        test_scores = self.evaluate(
            inference_loader=self.dloader.test_loader, split="test")
        for test_score in test_scores:
            logs_dict[test_score] = test_scores[test_score]

        ## Print metrics
        misc.print_dict(logs_dict)

        ## Check if best epoch (ties count as improvement: >=)
        is_best_epoch = False
        ens_acc = float(logs_dict["test/accuracy"]["value"])
        if ens_acc >= self._best_acc:
            self._best_acc = ens_acc
            self._best_epoch = epoch
            is_best_epoch = True

        ## Save the model checkpoint only when the test accuracy improved
        if is_best_epoch:
            logs_dict["general/checkpoint_saved"] = {"value": 1.0, "string": "1.0"}
            self.save_checkpoint(epoch)
            LOGGER.warning(f"Epoch: {epoch} was saved")
        else:
            logs_dict["general/checkpoint_saved"] = {"value": 0.0, "string": "0.0"}

        ## CSV logging of a short subset of the metrics
        short_logs_dict = OrderedDict(
            {k: v for k, v in logs_dict.items()
             if any(regex in k for regex in [
                 "test/accuracy", "train/accuracy", "epoch", "checkpoint_saved"
             ])
             })
        misc.csv_writter(
            path=misc.get_logs_path(self.config_args["training"]["output_folder"]),
            dic=short_logs_dict
        )

        # Tensorboard logging (skipped in debug mode)
        if not config.cfg.DEBUG:
            self.save_tb(logs_dict, epoch=epoch)

        # Perform end of step procedure like scheduler update
        self.model_wrapper.scheduler.step()

    def evaluate_loop(self, dloader, verbose=False, **kwargs):
        """Run the evaluation over dloader (subclass hook).

        `verbose` defaults to False so the single-argument call in
        `evaluate` satisfies this signature.
        """
        raise NotImplementedError

    def evaluate(self, inference_loader, split="test"):
        """
        Perform an evaluation of the model
        """
        # Restart stats
        self.model_wrapper.to_eval_mode()

        # Evaluation over the dataset properly speaking
        self.evaluate_loop(inference_loader)

        # Gather scores
        scores = self.model_wrapper.get_dict_to_scores(split=split)
        return scores