mixmo.learners.model_wrapper.ModelWrapper

class mixmo.learners.model_wrapper.ModelWrapper(config, config_args, device)

Bases: object

Augment a model with losses, metrics, internal logs, and related utilities such as temperature-scaling calibration and AMP support

__init__(config, config_args, device)

Initialize the wrapper from config, config_args and device.

Methods

__init__(config, config_args, device)

Initialize self.

calibrate_via_tempscale(tempscale_loader)

Returns calibrated temperature on val/test set

get_dict_to_scores(split)

Format logs into a dictionary

get_short_logs()

Return summary of internal records

predict(data)

Perform a forward pass through the model and return the output

print_summary([pixels_size])

step(output, target[, backprop])

Compute loss, backward step and metrics if required by config, then update internal records

to_eval_mode()

Switch model to eval mode

to_train_mode(epoch)

Switch model to train mode

Attributes

scaled_network

Returns scaled_model if necessary for AMP (automatic mixed precision)

_compute_diversity(output, target)

Compute diversity and update internal records
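This page does not say how diversity is measured. As a minimal sketch, assuming diversity means the average pairwise disagreement between the subnetworks' predicted classes (a common ensemble diversity measure), the computation could look like the following. The function pairwise_disagreement and its list-of-lists input are hypothetical; the documented method also receives target, which error-based diversity measures would use but this sketch ignores.

```python
from itertools import combinations
from typing import List, Sequence


def pairwise_disagreement(predictions: Sequence[Sequence[int]]) -> float:
    """Mean fraction of samples on which two subnetworks predict different classes.

    Hypothetical helper: predictions[m][i] is the class predicted by
    subnetwork m for sample i (every subnetwork sees the same samples).
    """
    num_members = len(predictions)
    num_samples = len(predictions[0]) if num_members else 0
    if num_members < 2 or num_samples == 0:
        return 0.0  # no pair to compare; report zero diversity
    rates: List[float] = []
    for a, b in combinations(range(num_members), 2):
        # Count the samples where subnetworks a and b disagree.
        disagreements = sum(pa != pb for pa, pb in zip(predictions[a], predictions[b]))
        rates.append(disagreements / num_samples)
    return sum(rates) / len(rates)


# Two subnetworks, four samples: they disagree on samples 1 and 3.
print(pairwise_disagreement([[0, 1, 2, 2], [0, 2, 2, 1]]))  # 0.5
```

Averaging over all pairs keeps the score in [0, 1] regardless of how many subnetworks the model has.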

calibrate_via_tempscale(tempscale_loader)

Returns calibrated temperature on val/test set
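Temperature scaling divides the logits by a single scalar T > 0 chosen to minimize the negative log-likelihood (NLL) on held-out data; T > 1 softens overconfident predictions. The sketch below is a hypothetical pure-Python illustration, not this method's implementation: nll_at_temperature, fit_temperature and the candidate grid are assumptions, and a grid search stands in for the gradient-based optimizer often used in practice.

```python
import math
from typing import Optional, Sequence


def nll_at_temperature(logits: Sequence[Sequence[float]],
                       labels: Sequence[int],
                       temperature: float) -> float:
    """Mean negative log-likelihood of labels under softmax(logits / T)."""
    total = 0.0
    for row, label in zip(logits, labels):
        scaled = [z / temperature for z in row]
        max_z = max(scaled)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = max_z + math.log(sum(math.exp(z - max_z) for z in scaled))
        total += log_sum_exp - scaled[label]
    return total / len(labels)


def fit_temperature(logits: Sequence[Sequence[float]],
                    labels: Sequence[int],
                    candidates: Optional[Sequence[float]] = None) -> float:
    """Return the candidate temperature with the lowest NLL on held-out data."""
    if candidates is None:
        candidates = [0.05 * k for k in range(1, 101)]  # T from 0.05 to 5.0
    return min(candidates, key=lambda t: nll_at_temperature(logits, labels, t))


# Confident logits that are wrong half the time: the largest T minimizes NLL.
logits = [[4.0, 0.0], [4.0, 0.0], [0.0, 4.0], [0.0, 4.0]]
labels = [0, 1, 1, 0]
print(fit_temperature(logits, labels))  # 5.0
```

Calibrated probabilities would then be softmax(logits / T) with the fitted T. The predicted classes are unchanged, because dividing by a positive scalar preserves the ordering of the logits.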

get_dict_to_scores(split)

Format logs into a dictionary

get_short_logs()

Return summary of internal records
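How the internal records are stored is not documented here. A minimal sketch, assuming each record is a sample-weighted running average of a per-batch metric: RunningAverage, short_logs and scores_dict (including the "<split>/<name>" key format) are hypothetical stand-ins that show the kind of output get_short_logs and get_dict_to_scores might return.

```python
from typing import Dict


class RunningAverage:
    """Sample-weighted running mean of a per-batch metric (hypothetical)."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, value: float, n: int = 1) -> None:
        # value is a batch mean, so weight it by the batch size n
        self.total += value * n
        self.count += n

    @property
    def avg(self) -> float:
        return self.total / self.count if self.count else 0.0


def short_logs(records: Dict[str, RunningAverage]) -> Dict[str, float]:
    """Condense the records into their current averages, rounded for display."""
    return {name: round(meter.avg, 4) for name, meter in records.items()}


def scores_dict(records: Dict[str, RunningAverage], split: str) -> Dict[str, float]:
    """Flatten the records into one dictionary keyed by '<split>/<name>'."""
    return {f"{split}/{name}": meter.avg for name, meter in records.items()}


records = {"loss": RunningAverage(), "accuracy": RunningAverage()}
records["loss"].update(1.25, n=32)
records["loss"].update(0.75, n=32)
records["accuracy"].update(0.625, n=64)
print(short_logs(records))          # {'loss': 1.0, 'accuracy': 0.625}
print(scores_dict(records, "val"))  # {'val/loss': 1.0, 'val/accuracy': 0.625}
```

Prefixing keys with the split lets train, val and test scores share one flat dictionary without name clashes.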

predict(data)

Perform a forward pass through the model and return the output

property scaled_network

Returns scaled_model if necessary for AMP (automatic mixed precision)

step(output, target, backprop=False)

Compute loss, backward step and metrics if required by config, then update internal records
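The page does not specify which loss or metrics step computes. As a minimal sketch, assuming softmax cross-entropy, top-1 accuracy, and records stored as [weighted sum, sample count] pairs, the control flow might be: compute the loss, run the backward pass only when backprop is true, then compute metrics and update the records. The function step here is a hypothetical stand-in, and backward_fn stands in for what, in PyTorch, would typically be loss.backward() followed by an optimizer step.

```python
import math
from typing import Callable, Dict, List, Optional, Sequence

Records = Dict[str, List[float]]  # name -> [weighted sum, sample count]


def cross_entropy(logits: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Mean softmax cross-entropy, computed with a stable log-sum-exp."""
    total = 0.0
    for row, t in zip(logits, targets):
        m = max(row)
        total += m + math.log(sum(math.exp(z - m) for z in row)) - row[t]
    return total / len(targets)


def step(output: Sequence[Sequence[float]],
         target: Sequence[int],
         records: Records,
         backprop: bool = False,
         backward_fn: Optional[Callable[[float], None]] = None) -> float:
    """Hypothetical step: loss, optional backward pass, metrics, record update."""
    loss = cross_entropy(output, target)
    if backprop:
        if backward_fn is None:
            raise ValueError("backprop=True requires a backward_fn")
        backward_fn(loss)  # stand-in for loss.backward() + optimizer.step()
    # Top-1 accuracy: index of the largest logit per sample.
    predictions = [max(range(len(row)), key=row.__getitem__) for row in output]
    accuracy = sum(p == t for p, t in zip(predictions, target)) / len(target)
    n = len(target)
    for name, value in (("loss", loss), ("accuracy", accuracy)):
        entry = records.setdefault(name, [0.0, 0])
        entry[0] += value * n  # weight the batch mean by the batch size
        entry[1] += n
    return loss


records: Records = {}
backward_calls: List[float] = []
# One batch of two samples; the first prediction is correct, the second is not.
step([[2.0, 0.0], [0.0, 2.0]], [0, 0], records,
     backprop=True, backward_fn=backward_calls.append)
print(records["accuracy"])  # [1.0, 2]
```

Keeping sums and counts rather than averages lets batches of different sizes be combined exactly later.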

to_eval_mode()

Switch model to eval mode

to_train_mode(epoch)

Switch model to train mode