Skip to content

Commit 4ee5ac5

Browse files
authored
Merge pull request #100 from greenmonn/string-summary
Add interface for returning summary string
2 parents 8af9926 + b940bd9 commit 4ee5ac5

File tree

4 files changed

+58
-27
lines changed

4 files changed

+58
-27
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
__pycache__
22
*.pyc
3+
.vscode/

torchsummary/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .torchsummary import summary
1+
from .torchsummary import summary, summary_string

torchsummary/tests/unit_tests/torchsummary_test.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import unittest
2-
from torchsummary import summary
2+
from torchsummary import summary, summary_string
33
from torchsummary.tests.test_models.test_model import SingleInputNet, MultipleInputNet, MultipleInputNetDifferentDtypes
44
import torch
55

# Device string used by the GPU tests: first CUDA device when one is
# present, otherwise fall back to the CPU so the suite still runs.
if torch.cuda.is_available():
    gpu_if_available = "cuda:0"
else:
    gpu_if_available = "cpu"
7+
68
class torchsummaryTests(unittest.TestCase):
79
def test_single_input(self):
810
model = SingleInputNet()
@@ -15,7 +17,8 @@ def test_multiple_input(self):
1517
model = MultipleInputNet()
1618
input1 = (1, 300)
1719
input2 = (1, 300)
18-
total_params, trainable_params = summary(model, [input1, input2], device="cpu")
20+
total_params, trainable_params = summary(
21+
model, [input1, input2], device="cpu")
1922
self.assertEqual(total_params, 31120)
2023
self.assertEqual(trainable_params, 31120)
2124

@@ -28,9 +31,10 @@ def test_single_layer_network(self):
2831

2932
def test_single_layer_network_on_gpu(self):
    """A bare nn.Linear is summarized correctly, on GPU when available.

    The model is only moved to CUDA when a device is present, and the
    summary device comes from the module-level `gpu_if_available`, so the
    test degrades gracefully to CPU on machines without CUDA.
    """
    model = torch.nn.Linear(2, 5)
    if torch.cuda.is_available():
        model.cuda()
    # Renamed from `input` so the builtin of the same name is not shadowed.
    input_size = (1, 2)
    total_params, trainable_params = summary(
        model, input_size, device=gpu_if_available)
    # Linear(2, 5) has 2*5 weights + 5 biases = 15 parameters, all trainable.
    self.assertEqual(total_params, 15)
    self.assertEqual(trainable_params, 15)
3640

@@ -39,9 +43,22 @@ def test_multiple_input_types(self):
3943
input1 = (1, 300)
4044
input2 = (1, 300)
4145
dtypes = [torch.FloatTensor, torch.LongTensor]
42-
total_params, trainable_params = summary(model, [input1, input2], device="cpu", dtypes=dtypes)
46+
total_params, trainable_params = summary(
47+
model, [input1, input2], device="cpu", dtypes=dtypes)
4348
self.assertEqual(total_params, 31120)
4449
self.assertEqual(trainable_params, 31120)
4550

class torchsummarystringTests(unittest.TestCase):
    """Tests for the string-returning interface, summary_string."""

    def test_single_input(self):
        """summary_string returns the formatted table plus parameter counts."""
        model = SingleInputNet()
        # Renamed from `input` so the builtin of the same name is not shadowed.
        input_size = (1, 28, 28)
        result, (total_params, trainable_params) = summary_string(
            model, input_size, device="cpu")
        # assertIsInstance is the idiomatic check; `type(x) == str` rejects
        # subclasses and reads as the type-equality anti-pattern.
        self.assertIsInstance(result, str)
        self.assertEqual(total_params, 21840)
        self.assertEqual(trainable_params, 21840)
61+
62+
4663
if __name__ == '__main__':
4764
unittest.main(buffer=True)

torchsummary/torchsummary.py

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,20 @@
77

88

99
def summary(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None):
    """Print a layer-by-layer summary of *model* and return parameter counts.

    Thin wrapper around summary_string: prints the formatted table it
    produces and returns its (total_params, trainable_params) tuple.
    All arguments are forwarded to summary_string unchanged.
    """
    table_text, param_counts = summary_string(
        model, input_size, batch_size, device, dtypes)
    print(table_text)
    return param_counts
def summary_string(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None):
1018
if dtypes == None:
1119
dtypes = [torch.FloatTensor]*len(input_size)
1220

13-
def register_hook(module):
21+
summary_str = ''
1422

23+
def register_hook(module):
1524
def hook(module, input, output):
1625
class_name = str(module.__class__).split(".")[-1].split("'")[0]
1726
module_idx = len(summary)
@@ -46,9 +55,9 @@ def hook(module, input, output):
4655
if isinstance(input_size, tuple):
4756
input_size = [input_size]
4857

49-
5058
# batch_size of 2 for batchnorm
51-
x = [ torch.rand(2, *in_size).type(dtype).to(device=device) for in_size, dtype in zip(input_size, dtypes)]
59+
x = [torch.rand(2, *in_size).type(dtype).to(device=device)
60+
for in_size, dtype in zip(input_size, dtypes)]
5261

5362
# create properties
5463
summary = OrderedDict()
@@ -65,10 +74,11 @@ def hook(module, input, output):
6574
for h in hooks:
6675
h.remove()
6776

68-
print("----------------------------------------------------------------")
69-
line_new = "{:>20} {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #")
70-
print(line_new)
71-
print("================================================================")
77+
summary_str += "----------------------------------------------------------------" + "\n"
78+
line_new = "{:>20} {:>25} {:>15}".format(
79+
"Layer (type)", "Output Shape", "Param #")
80+
summary_str += line_new + "\n"
81+
summary_str += "================================================================" + "\n"
7282
total_params = 0
7383
total_output = 0
7484
trainable_params = 0
@@ -85,23 +95,26 @@ def hook(module, input, output):
8595
if "trainable" in summary[layer]:
8696
if summary[layer]["trainable"] == True:
8797
trainable_params += summary[layer]["nb_params"]
88-
print(line_new)
98+
summary_str += line_new + "\n"
8999

90100
# assume 4 bytes/number (float on cuda).
91-
total_input_size = abs(np.prod(sum(input_size, ())) * batch_size * 4. / (1024 ** 2.))
92-
total_output_size = abs(2. * total_output * 4. / (1024 ** 2.)) # x2 for gradients
101+
total_input_size = abs(np.prod(sum(input_size, ()))
102+
* batch_size * 4. / (1024 ** 2.))
103+
total_output_size = abs(2. * total_output * 4. /
104+
(1024 ** 2.)) # x2 for gradients
93105
total_params_size = abs(total_params * 4. / (1024 ** 2.))
94106
total_size = total_params_size + total_output_size + total_input_size
95107

96-
print("================================================================")
97-
print("Total params: {0:,}".format(total_params))
98-
print("Trainable params: {0:,}".format(trainable_params))
99-
print("Non-trainable params: {0:,}".format(total_params - trainable_params))
100-
print("----------------------------------------------------------------")
101-
print("Input size (MB): %0.2f" % total_input_size)
102-
print("Forward/backward pass size (MB): %0.2f" % total_output_size)
103-
print("Params size (MB): %0.2f" % total_params_size)
104-
print("Estimated Total Size (MB): %0.2f" % total_size)
105-
print("----------------------------------------------------------------")
108+
summary_str += "================================================================" + "\n"
109+
summary_str += "Total params: {0:,}".format(total_params) + "\n"
110+
summary_str += "Trainable params: {0:,}".format(trainable_params) + "\n"
111+
summary_str += "Non-trainable params: {0:,}".format(total_params -
112+
trainable_params) + "\n"
113+
summary_str += "----------------------------------------------------------------" + "\n"
114+
summary_str += "Input size (MB): %0.2f" % total_input_size + "\n"
115+
summary_str += "Forward/backward pass size (MB): %0.2f" % total_output_size + "\n"
116+
summary_str += "Params size (MB): %0.2f" % total_params_size + "\n"
117+
summary_str += "Estimated Total Size (MB): %0.2f" % total_size + "\n"
118+
summary_str += "----------------------------------------------------------------" + "\n"
106119
# return summary
107-
return total_params, trainable_params
120+
return summary_str, (total_params, trainable_params)

0 commit comments

Comments
 (0)