From f731bd5717cd0bcc7a858e388d2a72560e7a7248 Mon Sep 17 00:00:00 2001 From: Maximilian Berr Date: Mon, 12 Aug 2019 09:26:56 +0300 Subject: [PATCH] Extensions: * added 'dtype' to torchsummary input variables * added 'input_initializer' to torchsummary input variables Bugfix: * total_input return TypeError: File "/home/developer/AmI/pytorch-summary/torchsummary/torchsummary.py", line 96, in summary total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.)) File "/conda/envs/rapids/lib/python3.6/site-packages/numpy/core/fromnumeric.py", line 2772, in prod initial=initial) File "/conda/envs/rapids/lib/python3.6/site-packages/numpy/core/fromnumeric.py", line 86, in _wrapreduction return ufunc.reduce(obj, axis, dtype, out, **passkwargs) TypeError: can't multiply sequence by non-int of type 'tuple' --- torchsummary/torchsummary.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py index cbe18e3..b7321b9 100644 --- a/torchsummary/torchsummary.py +++ b/torchsummary/torchsummary.py @@ -6,7 +6,7 @@ import numpy as np -def summary(model, input_size, batch_size=-1, device="cuda"): +def summary(model, input_size, batch_size=2, input_initializer=torch.rand, dtype=torch.float64, device="cuda"): def register_hook(module): @@ -47,17 +47,14 @@ def hook(module, input, output): "cpu", ], "Input device is not valid, please specify 'cuda' or 'cpu'" - if device == "cuda" and torch.cuda.is_available(): - dtype = torch.cuda.FloatTensor - else: - dtype = torch.FloatTensor + device = torch.device(device) if torch.cuda.is_available() else torch.device('cpu') # multiple inputs to the network if isinstance(input_size, tuple): input_size = [input_size] # batch_size of 2 for batchnorm - x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size] + x = [input_initializer((batch_size, *in_size)).type(dtype).to(device) for in_size in input_size] # print(type(x[0])) # create properties @@ -97,7 +94,7 @@ def hook(module, input, output): print(line_new) # assume 4 bytes/number (float on cuda). - total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.)) + total_input_size = abs(np.prod([dimension for tensor_size in input_size for dimension in tensor_size]) * batch_size * 4. / (1024 ** 2.)) total_output_size = abs(2. * total_output * 4. / (1024 ** 2.)) # x2 for gradients total_params_size = abs(total_params.numpy() * 4. / (1024 ** 2.)) total_size = total_params_size + total_output_size + total_input_size