Commit f9c58b4

Add interface for returning summary string
Add an interface that returns the summary as a concatenated string instead of printing it directly. This is useful for logging. - Issue: #99
1 parent: 8af9926
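
For context, a minimal usage sketch of the new interface (the model, input size, and logger setup below are illustrative assumptions, not part of the commit; summary_string is assumed to be re-exported at package level, otherwise import it from torchsummary.torchsummary):

import logging

import torch
import torch.nn as nn
from torchsummary import summary_string  # assumes a package-level re-export

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("model_summary")  # illustrative logger name

# Any nn.Module works; a tiny throwaway model for illustration.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())

# summary_string() builds the same table as summary() but returns it
# instead of printing, so it can be routed to a logger as the commit
# message suggests.
result, (total_params, trainable_params) = summary_string(
    model, (3, 32, 32), device=torch.device("cpu"))
logger.info("\n%s", result)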

File tree

1 file changed (+34, -21 lines)


torchsummary/torchsummary.py

Lines changed: 34 additions & 21 deletions
@@ -7,11 +7,20 @@
 
 
 def summary(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None):
+    result, params_info = summary_string(
+        model, input_size, batch_size, device, dtypes)
+    print(result)
+
+    return params_info
+
+
+def summary_string(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None):
     if dtypes == None:
         dtypes = [torch.FloatTensor]*len(input_size)
 
-    def register_hook(module):
+    summary_str = ''
 
+    def register_hook(module):
         def hook(module, input, output):
             class_name = str(module.__class__).split(".")[-1].split("'")[0]
             module_idx = len(summary)
@@ -46,9 +55,9 @@ def hook(module, input, output):
     if isinstance(input_size, tuple):
         input_size = [input_size]
 
-
     # batch_size of 2 for batchnorm
-    x = [ torch.rand(2, *in_size).type(dtype).to(device=device) for in_size, dtype in zip(input_size, dtypes)]
+    x = [torch.rand(2, *in_size).type(dtype).to(device=device)
+         for in_size, dtype in zip(input_size, dtypes)]
 
     # create properties
     summary = OrderedDict()
@@ -65,10 +74,11 @@ def hook(module, input, output):
     for h in hooks:
         h.remove()
 
-    print("----------------------------------------------------------------")
-    line_new = "{:>20} {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #")
-    print(line_new)
-    print("================================================================")
+    summary_str += "----------------------------------------------------------------" + "\n"
+    line_new = "{:>20} {:>25} {:>15}".format(
+        "Layer (type)", "Output Shape", "Param #")
+    summary_str += line_new + "\n"
+    summary_str += "================================================================" + "\n"
     total_params = 0
     total_output = 0
     trainable_params = 0
@@ -85,23 +95,26 @@ def hook(module, input, output):
         if "trainable" in summary[layer]:
             if summary[layer]["trainable"] == True:
                 trainable_params += summary[layer]["nb_params"]
-        print(line_new)
+        summary_str += line_new + "\n"
 
     # assume 4 bytes/number (float on cuda).
-    total_input_size = abs(np.prod(sum(input_size, ())) * batch_size * 4. / (1024 ** 2.))
-    total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))  # x2 for gradients
+    total_input_size = abs(np.prod(sum(input_size, ()))
+                           * batch_size * 4. / (1024 ** 2.))
+    total_output_size = abs(2. * total_output * 4. /
+                            (1024 ** 2.))  # x2 for gradients
     total_params_size = abs(total_params * 4. / (1024 ** 2.))
     total_size = total_params_size + total_output_size + total_input_size
 
-    print("================================================================")
-    print("Total params: {0:,}".format(total_params))
-    print("Trainable params: {0:,}".format(trainable_params))
-    print("Non-trainable params: {0:,}".format(total_params - trainable_params))
-    print("----------------------------------------------------------------")
-    print("Input size (MB): %0.2f" % total_input_size)
-    print("Forward/backward pass size (MB): %0.2f" % total_output_size)
-    print("Params size (MB): %0.2f" % total_params_size)
-    print("Estimated Total Size (MB): %0.2f" % total_size)
-    print("----------------------------------------------------------------")
+    summary_str += "================================================================" + "\n"
+    summary_str += "Total params: {0:,}".format(total_params) + "\n"
+    summary_str += "Trainable params: {0:,}".format(trainable_params) + "\n"
+    summary_str += "Non-trainable params: {0:,}".format(total_params -
+                                                        trainable_params) + "\n"
+    summary_str += "----------------------------------------------------------------" + "\n"
+    summary_str += "Input size (MB): %0.2f" % total_input_size + "\n"
+    summary_str += "Forward/backward pass size (MB): %0.2f" % total_output_size + "\n"
+    summary_str += "Params size (MB): %0.2f" % total_params_size + "\n"
+    summary_str += "Estimated Total Size (MB): %0.2f" % total_size + "\n"
+    summary_str += "----------------------------------------------------------------" + "\n"
     # return summary
-    return total_params, trainable_params
+    return summary_str, (total_params, trainable_params)
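
As a sanity check on the "4 bytes/number" arithmetic that the last hunk only reflows (a hand-worked sketch with illustrative values, not output from this commit):

import numpy as np

input_size = [(3, 224, 224)]  # illustrative; already list-wrapped as in the code
batch_size = -1               # the default; abs() keeps the estimate positive

# Same expression as in summary_string(): element count * 4 bytes -> MB.
total_input_size = abs(np.prod(sum(input_size, ()))
                       * batch_size * 4. / (1024 ** 2.))
print("Input size (MB): %0.2f" % total_input_size)  # 3*224*224*4/1024**2 -> 0.57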
