Source code for mixmo.utils.torchsummary

"""
Model printing utility
"""

from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn


def summary(model, input_size, batch_size=2, input_initializer=torch.rand,
            list_dtype=None, device="cuda"):
    """
    Print a layer-by-layer summary of `model` (output shapes and parameter
    counts) by running a forward pass on dummy inputs of shape `input_size`
    """

    def register_hook(module):
        """
        Register a hook that writes model information to an existing summary
        dictionary (in scope)
        """

        def hook(module, input, output):
            class_name = str(module.__class__).split(".")[-1].split("'")[0]
            module_idx = len(summary)

            m_key = "%s-%i" % (class_name, module_idx + 1)
            summary[m_key] = OrderedDict()
            summary[m_key]["input_shape"] = list(input[0].size())
            summary[m_key]["input_shape"][0] = batch_size
            if isinstance(output, (list, tuple)):
                summary[m_key]["output_shape"] = [
                    [-1] + list(o.size())[1:] for o in output
                ]
            else:
                summary[m_key]["output_shape"] = list(output.size())
                summary[m_key]["output_shape"][0] = batch_size

            params = 0
            if hasattr(module, "weight") and hasattr(module.weight, "size"):
                params += torch.prod(torch.LongTensor(list(module.weight.size())))
                summary[m_key]["trainable"] = module.weight.requires_grad
            if hasattr(module, "bias") and hasattr(module.bias, "size"):
                params += torch.prod(torch.LongTensor(list(module.bias.size())))
            summary[m_key]["nb_params"] = params

        # skip containers and the model itself: only leaf modules get hooks
        if (
            not isinstance(module, nn.Sequential)
            and not isinstance(module, nn.ModuleList)
            and not (module == model)
        ):
            hooks.append(module.register_forward_hook(hook))

    device = device.lower()
    assert device in [
        "cuda",
        "cpu",
    ], "Input device is not valid, please specify 'cuda' or 'cpu'"
    device = torch.device(device) if torch.cuda.is_available() else torch.device("cpu")

    # multiple inputs to the network
    if isinstance(input_size, tuple):
        input_size = [input_size]

    if list_dtype is None:
        list_dtype = [torch.float for _ in input_size]

    # batch_size of 2 for batchnorm
    x = [
        input_initializer((batch_size, *in_size)).type(dtype).to(device)
        for (in_size, dtype) in zip(input_size, list_dtype)
    ]

    # create properties
    summary = OrderedDict()
    hooks = []

    # register hook
    model.apply(register_hook)

    # make a forward pass
    model(*x)

    # remove these hooks
    for h in hooks:
        h.remove()

    print("----------------------------------------------------------------")
    line_new = "{:>20} {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #")
    print(line_new)
    print("================================================================")
    total_params = 0
    total_output = 0
    trainable_params = 0
    for layer in summary:
        # input_shape, output_shape, trainable, nb_params
        line_new = "{:>20} {:>25} {:>15}".format(
            layer,
            str(summary[layer]["output_shape"]),
            "{0:,}".format(summary[layer]["nb_params"]),
        )
        total_params += summary[layer]["nb_params"]
        total_output += np.prod(summary[layer]["output_shape"])
        if "trainable" in summary[layer] and summary[layer]["trainable"]:
            trainable_params += summary[layer]["nb_params"]
        print(line_new)

    # assume 4 bytes/number (float on cuda)
    _total_params_for_size = (
        total_params.numpy() if hasattr(total_params, "numpy") else total_params
    )
    # sum the element counts per input tensor (not the product across tensors)
    total_input_size = abs(
        np.sum([np.prod(tensor_size) for tensor_size in input_size])
        * batch_size * 4. / (1024 ** 2.)
    )
    total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))  # x2 for gradients
    total_params_size = abs(_total_params_for_size * 4. / (1024 ** 2.))
    total_size = total_params_size + total_output_size + total_input_size

    print("================================================================")
    print("Total params: {0:,}".format(total_params))
    print("Trainable params: {0:,}".format(trainable_params))
    print("Non-trainable params: {0:,}".format(total_params - trainable_params))
    print("----------------------------------------------------------------")
    print("Input size (MB): %0.2f" % total_input_size)
    print("Forward/backward pass size (MB): %0.2f" % total_output_size)
    print("Params size (MB): %0.2f" % total_params_size)
    print("Estimated Total Size (MB): %0.2f" % total_size)
    print("----------------------------------------------------------------")
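

# Usage sketch (illustrative, not part of the mixmo source): the toy
# `SmallNet` below and its (3, 32, 32) input shape are assumptions chosen
# only to demonstrate the call; any `nn.Module` whose forward() accepts the
# dummy inputs works the same way.
if __name__ == "__main__":

    class SmallNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
            self.bn = nn.BatchNorm2d(8)
            self.pool = nn.AdaptiveAvgPool2d(1)
            self.fc = nn.Linear(8, 10)

        def forward(self, x):
            x = self.pool(torch.relu(self.bn(self.conv(x))))
            return self.fc(x.flatten(1))

    # A single input is given as one (C, H, W) tuple; summary() wraps it in a
    # list internally. batch_size=2 keeps batchnorm statistics well-defined.
    # This prints one row per leaf module (Conv2d-1, BatchNorm2d-2, ...)
    # followed by the parameter and memory totals.
    summary(SmallNet(), input_size=(3, 32, 32), batch_size=2, device="cpu")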