Skip to content

Commit b940bd9

Browse files
committed
Add unit test for summary_string
1 parent f9c58b4 commit b940bd9

File tree

3 files changed

+24
-6
lines changed

3 files changed

+24
-6
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
__pycache__
22
*.pyc
3+
.vscode/

torchsummary/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .torchsummary import summary
1+
from .torchsummary import summary, summary_string

torchsummary/tests/unit_tests/torchsummary_test.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import unittest
2-
from torchsummary import summary
2+
from torchsummary import summary, summary_string
33
from torchsummary.tests.test_models.test_model import SingleInputNet, MultipleInputNet, MultipleInputNetDifferentDtypes
44
import torch
55

6+
gpu_if_available = "cuda:0" if torch.cuda.is_available() else "cpu"
7+
68
class torchsummaryTests(unittest.TestCase):
79
def test_single_input(self):
810
model = SingleInputNet()
@@ -15,7 +17,8 @@ def test_multiple_input(self):
1517
model = MultipleInputNet()
1618
input1 = (1, 300)
1719
input2 = (1, 300)
18-
total_params, trainable_params = summary(model, [input1, input2], device="cpu")
20+
total_params, trainable_params = summary(
21+
model, [input1, input2], device="cpu")
1922
self.assertEqual(total_params, 31120)
2023
self.assertEqual(trainable_params, 31120)
2124

@@ -28,9 +31,10 @@ def test_single_layer_network(self):
2831

2932
def test_single_layer_network_on_gpu(self):
3033
model = torch.nn.Linear(2, 5)
31-
model.cuda()
34+
if torch.cuda.is_available():
35+
model.cuda()
3236
input = (1, 2)
33-
total_params, trainable_params = summary(model, input, device="cuda:0")
37+
total_params, trainable_params = summary(model, input, device=gpu_if_available)
3438
self.assertEqual(total_params, 15)
3539
self.assertEqual(trainable_params, 15)
3640

@@ -39,9 +43,22 @@ def test_multiple_input_types(self):
3943
input1 = (1, 300)
4044
input2 = (1, 300)
4145
dtypes = [torch.FloatTensor, torch.LongTensor]
42-
total_params, trainable_params = summary(model, [input1, input2], device="cpu", dtypes=dtypes)
46+
total_params, trainable_params = summary(
47+
model, [input1, input2], device="cpu", dtypes=dtypes)
4348
self.assertEqual(total_params, 31120)
4449
self.assertEqual(trainable_params, 31120)
4550

51+
52+
class torchsummarystringTests(unittest.TestCase):
53+
def test_single_input(self):
54+
model = SingleInputNet()
55+
input = (1, 28, 28)
56+
result, (total_params, trainable_params) = summary_string(
57+
model, input, device="cpu")
58+
self.assertEqual(type(result), str)
59+
self.assertEqual(total_params, 21840)
60+
self.assertEqual(trainable_params, 21840)
61+
62+
4663
if __name__ == '__main__':
4764
unittest.main(buffer=True)

0 commit comments

Comments
 (0)