Source code for mixmo.core.metrics_wrapper

"""
Wrapper functions for metric tracking
Mostly taken from https://github.com/bayesgroup/pytorch-ensembles/blob/master/metrics.py
"""

import numpy as np
from sklearn.metrics import roc_auc_score

from mixmo.utils import visualize
from mixmo.utils.logger import get_logger
from mixmo.core import metrics_ensemble

LOGGER = get_logger(__name__, level="INFO")



def merge_scores(scores_test, scores_val):
    """ Aggregate test and val scores into final/* scores: test/* entries are averaged with their val/* counterparts, other entries are copied """
    scores_valtest = {}
    for key in scores_test:
        key_valtest = "final/" + key.split("/")[1]
        if key.startswith("test/"):
            keyval = "val/" + key.split("/")[1]
            value = 0.5 * (scores_test[key]["value"] + scores_val[keyval]["value"])
            if scores_test[key]["string"].endswith("%"):
                value_str = f"{value:05.2%}"
            else:
                value_str = f"{value:.6}"
            stats = {"value": value, "string": value_str}
            scores_valtest[key_valtest] = stats
        else:
            scores_valtest[key_valtest] = scores_test[key]
    return scores_valtest


def _clean_metrics(metrics, output_format="float"):
    """ Reformat metrics dictionary """
    new_dict = {}
    for k, v in metrics.items():
        if isinstance(v, dict):
            v = v["string"]
        if isinstance(v, str):
            if v.endswith("%"):
                v = v[:-1]
            if output_format == "float":
                v = float(v)
        new_dict[k] = v
    return new_dict


def show_metrics(scores_test):
    """ Print selected final/* scores as a "&"-separated row """
    keys = [
        "final/accuracy",
        "final/accuracytop5",
        "final/nll",
        "final/ece",
    ]
    clean_scores_test = _clean_metrics(scores_test, output_format="str")
    our_results = [clean_scores_test[key] for key in keys]
    print(" & ".join(keys))
    print(" & ".join(our_results))


def get_ece(proba_pred, accurate, n_bins=15, min_pred=0, write_file=None, **args):
    """ Compute ECE and write to file """
    if min_pred == "minpred":
        min_pred = min(proba_pred)
    else:
        assert min_pred >= 0

    bin_boundaries = np.linspace(min_pred, 1., n_bins + 1)
    bin_lowers = bin_boundaries[:-1]
    bin_uppers = bin_boundaries[1:]

    acc_in_bin_list = []
    avg_confs_in_bins = []
    list_prop_bin = []
    ece = 0.0

    for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
        in_bin = np.logical_and(proba_pred > bin_lower, proba_pred <= bin_upper)
        prop_in_bin = np.mean(in_bin)
        list_prop_bin.append(prop_in_bin)
        if prop_in_bin > 0:
            accuracy_in_bin = np.mean(accurate[in_bin])
            avg_confidence_in_bin = np.mean(proba_pred[in_bin])
            delta = avg_confidence_in_bin - accuracy_in_bin
            acc_in_bin_list.append(accuracy_in_bin)
            avg_confs_in_bins.append(avg_confidence_in_bin)
            ece += np.abs(delta) * prop_in_bin
            LOGGER.debug(
                f"From {bin_lower:4.5} to {bin_upper:4.5} and mean {avg_confidence_in_bin:3.5}, "
                f"{(prop_in_bin * 100):4.5} % samples with accuracy {accuracy_in_bin:4.5}"
            )
        else:
            avg_confs_in_bins.append(None)
            acc_in_bin_list.append(None)

    if write_file is not None:
        visualize.write_calibration(
            avg_confs_in_bins=avg_confs_in_bins,
            acc_in_bin_list=acc_in_bin_list,
            prop_bin=list_prop_bin,
            min_bin=bin_lowers,
            max_bin=bin_uppers,
            min_pred=min_pred,
            suffix=f"{ece:.4}",
            write_file=write_file)

    # For reliability diagrams, also need to return these:
    # return ece, bin_lowers, avg_confs_in_bins
    return ece, avg_confs_in_bins, acc_in_bin_list


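# Note on the quantity computed above: with equal-width confidence bins
# B_1..B_n over [min_pred, 1], get_ece accumulates the standard expected
# calibration error
#     ECE = sum_b (|B_b| / N) * |acc(B_b) - conf(B_b)|,
# where acc(B_b) is the mean of `accurate` and conf(B_b) the mean of
# `proba_pred` within bin b. The two extra return values are the per-bin
# confidences and accuracies used for reliability diagrams.

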
def get_tace_bayesgroup(preds, targets, n_bins=15, threshold=1e-3, write_file=None, **args):
    """ Compute TACE and write to file """
    n_objects, n_classes = preds.shape
    res = 0.0
    for cur_class in range(n_classes):
        cur_class_conf = preds[:, cur_class]
        targets_sorted = targets[cur_class_conf.argsort()]
        cur_class_conf_sorted = np.sort(cur_class_conf)
        targets_sorted = targets_sorted[cur_class_conf_sorted > threshold]
        cur_class_conf_sorted = cur_class_conf_sorted[cur_class_conf_sorted > threshold]

        bin_size = len(cur_class_conf_sorted) // n_bins
        for bin_i in range(n_bins):
            bin_start_ind = bin_i * bin_size
            if bin_i < n_bins - 1:
                bin_end_ind = bin_start_ind + bin_size
            else:
                # extend last bin until the end of prediction array
                bin_end_ind = len(targets_sorted)
                bin_size = bin_end_ind - bin_start_ind
            bin_acc = (targets_sorted[bin_start_ind:bin_end_ind] == cur_class)
            bin_conf = cur_class_conf_sorted[bin_start_ind:bin_end_ind]
            avg_confidence_in_bin = np.mean(bin_conf)
            avg_accuracy_in_bin = np.mean(bin_acc)
            delta = np.abs(avg_confidence_in_bin - avg_accuracy_in_bin)
            res += delta * bin_size / (n_objects * n_classes)
    return res


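# Note on get_tace_bayesgroup: TACE (thresholded adaptive calibration error)
# differs from ECE in two ways visible above: per-class probabilities below
# `threshold` are discarded, and the remaining predictions are split into
# `n_bins` equal-mass (adaptive) bins rather than equal-width ones. Each bin
# contributes |conf - acc| weighted by its size, normalized by
# n_objects * n_classes.

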
def get_ll(preds, targets, **args):
    """ Compute log likelihood """
    preds_target = preds[np.arange(len(targets)), targets]
    return np.log(1e-12 + preds_target).sum()


def get_brier(preds, targets, **args):
    """ Compute brier score """
    one_hot_targets = np.zeros(preds.shape)
    one_hot_targets[np.arange(len(targets)), targets] = 1.0
    return np.mean((preds - one_hot_targets) ** 2) * len(targets)


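# Note on get_ll and get_brier: both return batch-level sums rather than
# averages (get_ll sums the log-likelihood of the target class over the
# batch; get_brier averages the squared error over classes and sums it over
# samples). MetricsWrapper below accumulates these sums in update() and
# divides by the dataset size in get_scores() to report per-sample values.

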
class MetricsWrapper:
    """ Metric storing object """

    def __init__(self, metrics):
        self.metrics = metrics
        self.accurate_or_wrong, self.list_confidences = [], []

        # metrics
        if "nll" in self.metrics:
            self.nll = 0
        if "brier" in self.metrics:
            self.brier = 0
        if "accuracytop5" in self.metrics:
            self.num_accurate_top5 = 0
        if "tace" in self.metrics:
            LOGGER.debug("Keep all predictions to compute TACE. Can be heavy, do not use in training")
            self._list_np_probs = []
            self._list_np_targets = []
        if "diversity" in self.metrics:
            self._list_target_diversity = []
            self._list_matrix_predictions_diversity = []

    def update(self, pred, target, confidence, probs):
        """ Compute tracked metrics and update records """
        self.accurate_or_wrong.extend(pred.eq(target.view_as(pred)).detach().to("cpu").numpy())
        self.list_confidences.extend(confidence.detach().to("cpu").numpy())

        np_probs = probs.detach().to("cpu").numpy()
        np_targets = target.detach().to("cpu").numpy()

        if "nll" in self.metrics:
            self.nll -= get_ll(np_probs, np_targets)

        if "brier" in self.metrics:
            brier = get_brier(np_probs, np_targets)
            self.brier += brier

        if "accuracytop5" in self.metrics:
            _, pred5 = probs.topk(5, 1, True, True)
            pred5 = pred5.t()
            correct5 = pred5.eq(target.view(1, -1).expand_as(pred5))
            # reshape rather than view: the transposed tensor may be non-contiguous
            correct5 = correct5[:5].reshape(-1).float().sum(0, keepdim=True).detach().to("cpu").numpy()[0]
            self.num_accurate_top5 += correct5

        if "tace" in self.metrics:
            self._list_np_probs.append(np_probs)
            self._list_np_targets.append(np_targets)

    def update_diversity(self, target, predictions):
        """ Update records used to compute diversity metrics """
        self._list_target_diversity.extend(target)
        self._list_matrix_predictions_diversity.append(predictions)

    def get_scores(self, split="train"):
        """
        Compute the tracked metrics from the stored records and return them as a
        dict of {"value", "string"} stats keyed by "{split}/{metric}"
        """
        if not len(self.list_confidences):
            LOGGER.warning("No predictions so far")
            return {}

        accurate_or_wrong = np.reshape(
            self.accurate_or_wrong, newshape=(len(self.accurate_or_wrong), -1)).flatten()
        list_confidences = np.reshape(
            self.list_confidences, newshape=(len(self.list_confidences), -1)).flatten()
        len_dataset = len(list_confidences)
        scores = {}

        if "diversity" in self.metrics:
            diversity_stats = self._get_diversity_stats()
            for diversity_key, diversity_stat in diversity_stats.items():
                scores[f"{split}/" + diversity_key] = diversity_stat

        if "accuracy" in self.metrics:
            accuracy = np.mean(accurate_or_wrong)
            scores[f"{split}/accuracy"] = {"value": accuracy, "string": f"{accuracy:05.2%}"}

        if "nll" in self.metrics:
            nll = self.nll / len_dataset
            scores[f"{split}/nll"] = {"value": nll, "string": f"{nll:.6}"}

        if "accuracytop5" in self.metrics:
            accuracytop5 = self.num_accurate_top5 / len_dataset
            scores[f"{split}/accuracytop5"] = {"value": accuracytop5, "string": f"{accuracytop5:05.2%}"}

        if "auc" in self.metrics:
            auc = roc_auc_score(accurate_or_wrong, list_confidences)
            scores[f"{split}/auc"] = {"value": auc, "string": f"{auc:.6}"}

        if "brier" in self.metrics:
            brier = self.brier / len_dataset
            scores[f"{split}/brier"] = {"value": brier, "string": f"{brier:.6}"}

        if "ece" in self.metrics:
            ece, _, _ = get_ece(
                list_confidences, accurate_or_wrong, min_pred=0, write_file=None)
            scores[f"{split}/ece"] = {"value": ece, "string": f"{ece:.6}"}

        if "tace" in self.metrics:
            preds = np.concatenate(self._list_np_probs, axis=0)
            targets = np.concatenate(self._list_np_targets)
            tace = get_tace_bayesgroup(preds=preds, targets=targets)
            scores[f"{split}/tace"] = {"value": tace, "string": f"{tace:.6}"}

        return scores

    def _get_diversity_stats(self):
        """ Compute diversity metrics from the stored predictions """
        if not len(self._list_matrix_predictions_diversity):
            return {}
        num_members = len(self._list_matrix_predictions_diversity[0])
        list_predictions_diversity = [[] for _ in range(num_members)]
        for matrix_predictions_diversity in self._list_matrix_predictions_diversity:
            for num_member in range(num_members):
                list_predictions_diversity[num_member].extend(
                    matrix_predictions_diversity[num_member]
                )
        stats = metrics_ensemble.MetricsEnsemble(
            true=self._list_target_diversity,
            predictions=list_predictions_diversity,
            names=[str(i) for i in range(num_members)]
        ).get_report()
        return stats
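

# ----------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): a minimal
# example of how MetricsWrapper is meant to be fed per-batch predictions,
# assuming torch is available and that `probs` is a (batch, classes)
# softmax output and `target` a (batch,) tensor of class indices.
# ----------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    torch.manual_seed(0)
    probs = torch.softmax(torch.randn(8, 10), dim=1)   # fake softmax outputs
    target = torch.randint(0, 10, (8,))                # fake labels
    confidence, pred = probs.max(dim=1)                # top-1 confidence and prediction

    wrapper = MetricsWrapper(metrics=["accuracy", "accuracytop5", "nll", "brier", "ece"])
    wrapper.update(pred=pred, target=target, confidence=confidence, probs=probs)
    for key, stat in wrapper.get_scores(split="test").items():
        print(key, stat["string"])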