diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py index cbe18e3..34216b1 100644 --- a/torchsummary/torchsummary.py +++ b/torchsummary/torchsummary.py @@ -32,6 +32,10 @@ def hook(module, input, output): summary[m_key]["trainable"] = module.weight.requires_grad if hasattr(module, "bias") and hasattr(module.bias, "size"): params += torch.prod(torch.LongTensor(list(module.bias.size()))) + if hasattr(module, "running_mean") and hasattr(module.running_mean, "size") and hasattr(module, "track_running_stats") and module.track_running_stats: + params += torch.prod(torch.LongTensor(list(module.running_mean.size()))) + if hasattr(module, "running_var") and hasattr(module.running_var, "size") and hasattr(module, "track_running_stats") and module.track_running_stats: + params += torch.prod(torch.LongTensor(list(module.running_var.size()))) summary[m_key]["nb_params"] = params if (