From 27df4ca9d122cf313dc87f370def5ec977d982ea Mon Sep 17 00:00:00 2001 From: fish Date: Tue, 7 Apr 2020 20:59:54 +0800 Subject: [PATCH] modify --- torchsummary/torchsummary.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py index 1ed065f..890346c 100644 --- a/torchsummary/torchsummary.py +++ b/torchsummary/torchsummary.py @@ -38,11 +38,11 @@ def hook(module, input, output): summary[m_key]["output_shape"][0] = batch_size params = 0 - if hasattr(module, "weight") and hasattr(module.weight, "size"): - params += torch.prod(torch.LongTensor(list(module.weight.size()))) - summary[m_key]["trainable"] = module.weight.requires_grad - if hasattr(module, "bias") and hasattr(module.bias, "size"): - params += torch.prod(torch.LongTensor(list(module.bias.size()))) + + for p in module.parameters(recurse=False): + param = torch.tensor(p.size()).prod() + summary[m_key]["trainable"] = p.requires_grad + params += param summary[m_key]["nb_params"] = params if (