Source code for mixmo.learners.learner

"""
Training and evaluation loop definitions for the Learner objects
"""

from tqdm import tqdm

import torch

from mixmo.utils import logger, config
from mixmo.learners import abstract_learner

LOGGER = logger.get_logger(__name__, level="DEBUG")

class Learner(abstract_learner.AbstractLearner):
    """
    Learner object that defines the specific train and test loops for the model
    """

    def _subloop(self, dict_tensors, backprop):
        """
        Basic subloop for a step/batch (without optimization)
        """
        # Format input
        input_model = {"pixels": dict_tensors["pixels"]}
        if "metadata" in dict_tensors:
            input_model["metadata"] = dict_tensors["metadata"]

        # Forward pass
        output_network = self.model_wrapper.predict(input_model)

        # Compute loss, backward and metrics
        self.model_wrapper.step(
            output=output_network,
            target=dict_tensors["target"],
            backprop=backprop,
        )
        return self.model_wrapper.get_short_logs()
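
    # Note: _subloop is shared by training and evaluation; the backprop flag is
    # forwarded to model_wrapper.step, which computes the loss and metrics and
    # only runs the backward pass when backprop is True.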

    def _train_subloop(self, dict_tensors):
        """
        Complete training step for a batch, return summary logs
        """
        # Reset optimizers
        self.model_wrapper.optimizer.zero_grad()

        # Backprop
        dict_to_log = self._subloop(dict_tensors, backprop=True)

        # Optimizer step
        self.model_wrapper.optimizer.step()
        return dict_to_log

    def train_loop(self, epoch):
        """
        Training loop for one epoch
        """
        # Set loop counter for the epoch
        loop = tqdm(self.dloader.train_loader, dynamic_ncols=True)

        # Loop over all samples in the train set
        for batch_id, data in enumerate(loop):
            loop.set_description(f"Epoch {epoch}")

            # Prepare the batch
            dict_tensors = self._prepare_batch_train(data)

            # Perform the training step for the batch
            dict_to_log = self._train_subloop(dict_tensors=dict_tensors)
            del dict_tensors

            # Tie up end of step details
            loop.set_postfix(dict_to_log)
            loop.update()

            if config.cfg.DEBUG >= 2 and batch_id >= 10:
                break

        if self.model_wrapper.warmup_scheduler is not None:
            self.model_wrapper.warmup_scheduler.step()

    def evaluate_loop(self, inference_loader):
        """
        Evaluation loop over the dataset specified by the loader
        """
        # Set loop counter for the loader/dataset
        loop = tqdm(inference_loader, disable=False, dynamic_ncols=True)

        # Loop over all samples in the evaluated dataset
        for batch_id, data in enumerate(loop):
            loop.set_description("Evaluation")

            # Prepare the batch
            dict_tensors = self._prepare_batch_test(data)

            # Forward over the batch, stats are logged internally
            with torch.no_grad():
                _ = self._subloop(dict_tensors, backprop=False)

            if config.cfg.DEBUG >= 2 and batch_id >= 10:
                break

    def _prepare_batch_train(self, data):
        """
        Prepare the train batch: set up the input dictionary and move tensors to the device
        """
        dict_tensors = {"pixels": [], "target": {}}

        # Concatenate inputs along the channel dimension and collect targets
        for num_member in range(self.config_args["num_members"]):
            dict_tensors["pixels"].append(data["pixels_" + str(num_member)])
            dict_tensors["target"]["target_" + str(num_member)] = data[
                "target_" + str(num_member)].to(self.device)
        dict_tensors["pixels"] = torch.cat(dict_tensors["pixels"], dim=1).to(self.device)

        # Pass along batch metadata
        dict_tensors["metadata"] = data.get("metadata", {})
        dict_tensors["metadata"]["mode"] = "train"
        return dict_tensors
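
    # Illustrative shapes (an assumption; e.g. CIFAR-sized inputs with
    # num_members=2 and batch size 8): data["pixels_0"] and data["pixels_1"]
    # each have shape [8, 3, 32, 32], so after torch.cat(..., dim=1) the
    # network input "pixels" has shape [8, 6, 32, 32], one input per member
    # stacked along the channel dimension.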

    def _prepare_batch_test(self, data):
        """
        Prepare the test batch: set up the input dictionary and move tensors to the device
        """
        (pixels, target) = data
        dict_tensors = {
            "pixels": pixels.to(self.device),
            "target": {
                "target_" + str(num_member): target.to(self.device)
                for num_member in range(self.config_args["num_members"])
            },
            "metadata": {
                "mode": "inference"
            },
        }
        return dict_tensors
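
A minimal usage sketch, not part of the module: the driver names below (learner, num_epochs, inference_loader) are hypothetical, and constructing the Learner is elided since the constructor is defined in abstract_learner.AbstractLearner. The batch layouts, however, follow directly from the two prepare methods above: train batches are dicts with per-member keys, while test batches are plain (pixels, target) tuples.

    import torch

    # Hypothetical driver loop
    # learner = ...  # a fully configured Learner instance
    # for epoch in range(num_epochs):
    #     learner.train_loop(epoch)
    # learner.evaluate_loop(inference_loader)

    # Batch formats the two loops expect, assuming num_members=2 and
    # CIFAR-sized inputs (batch size 8, 10 classes)
    train_batch = {
        "pixels_0": torch.randn(8, 3, 32, 32),   # input for subnetwork 0
        "target_0": torch.randint(0, 10, (8,)),  # labels for subnetwork 0
        "pixels_1": torch.randn(8, 3, 32, 32),   # input for subnetwork 1
        "target_1": torch.randint(0, 10, (8,)),  # labels for subnetwork 1
    }
    test_batch = (torch.randn(8, 3, 32, 32), torch.randint(0, 10, (8,)))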