Skip to content

Commit 3f20b06

Browse files
authored
Merge pull request #106 from sksq96/issue_9
can now handle single layers, for #9
2 parents 2818454 + f0c0db6 commit 3f20b06

File tree

3 files changed

+19
-3
lines changed

3 files changed

+19
-3
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
 __pycache__
+*.pyc

torchsummary/tests/unit_tests/torchsummary_test.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
 import unittest
 from torchsummary import summary
 from torchsummary.tests.test_models.test_model import SingleInputNet, MultipleInputNet
-
+import torch

 class torchsummaryTests(unittest.TestCase):
     def test_single_input(self):
def test_single_input(self):
@@ -19,5 +19,20 @@ def test_multiple_input(self):
         self.assertEqual(total_params, 31120)
         self.assertEqual(trainable_params, 31120)

+    def test_single_layer_network(self):
+        model = torch.nn.Linear(2, 5)
+        input = (1, 2)
+        total_params, trainable_params = summary(model, input, device="cpu")
+        self.assertEqual(total_params, 15)
+        self.assertEqual(trainable_params, 15)
+
+    def test_single_layer_network_on_gpu(self):
+        model = torch.nn.Linear(2, 5)
+        model.cuda()
+        input = (1, 2)
+        total_params, trainable_params = summary(model, input, device="cuda")
+        self.assertEqual(total_params, 15)
+        self.assertEqual(trainable_params, 15)
+
 if __name__ == '__main__':
     unittest.main(buffer=True)

torchsummary/torchsummary.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ def hook(module, input, output):
         if (
             not isinstance(module, nn.Sequential)
             and not isinstance(module, nn.ModuleList)
-            and not (module == model)
         ):
             hooks.append(module.register_forward_hook(hook))

@@ -90,6 +89,7 @@ def hook(module, input, output):
             "{0:,}".format(summary[layer]["nb_params"]),
         )
         total_params += summary[layer]["nb_params"]
+
         total_output += np.prod(summary[layer]["output_shape"])
         if "trainable" in summary[layer]:
             if summary[layer]["trainable"] == True:
@@ -99,7 +99,7 @@ def hook(module, input, output):
     # assume 4 bytes/number (float on cuda).
     total_input_size = abs(np.prod(sum(input_size, ())) * batch_size * 4. / (1024 ** 2.))
     total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))  # x2 for gradients
-    total_params_size = abs(total_params.numpy() * 4. / (1024 ** 2.))
+    total_params_size = abs(total_params * 4. / (1024 ** 2.))
     total_size = total_params_size + total_output_size + total_input_size

     print("================================================================")

0 commit comments

Comments
 (0)