Skip to content

Commit 4852f2e

Browse files
authored
Fix handling input_size with multi-input
1 parent 011b2bd commit 4852f2e

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

torchsummary/torchsummary.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,9 @@ def hook(module, input, output):
9898
summary_str += line_new + "\n"
9999

100100
# assume 4 bytes/number (float on cuda).
101-
total_input_size = abs(np.prod(sum(input_size, ()))
101+
# to handle the case of multi-input: prod(input1) + prod(input2) + ...
102+
n_input_size = np.array([np.prod(i) for i in input_size]).sum() if isinstance(input_size, list) else np.prod(input_size)
103+
total_input_size = abs(n_input_size
102104
* batch_size * 4. / (1024 ** 2.))
103105
total_output_size = abs(2. * total_output * 4. /
104106
(1024 ** 2.)) # x2 for gradients

0 commit comments

Comments
 (0)