11import unittest
2- from torchsummary import summary
2+ from torchsummary import summary , summary_string
33from torchsummary .tests .test_models .test_model import SingleInputNet , MultipleInputNet , MultipleInputNetDifferentDtypes
44import torch
55
6+ gpu_if_available = "cuda:0" if torch .cuda .is_available () else "cpu"
7+
68class torchsummaryTests (unittest .TestCase ):
79 def test_single_input (self ):
810 model = SingleInputNet ()
@@ -15,7 +17,8 @@ def test_multiple_input(self):
1517 model = MultipleInputNet ()
1618 input1 = (1 , 300 )
1719 input2 = (1 , 300 )
18- total_params , trainable_params = summary (model , [input1 , input2 ], device = "cpu" )
20+ total_params , trainable_params = summary (
21+ model , [input1 , input2 ], device = "cpu" )
1922 self .assertEqual (total_params , 31120 )
2023 self .assertEqual (trainable_params , 31120 )
2124
@@ -28,9 +31,10 @@ def test_single_layer_network(self):
2831
2932 def test_single_layer_network_on_gpu (self ):
3033 model = torch .nn .Linear (2 , 5 )
31- model .cuda ()
34+ if torch .cuda .is_available ():
35+ model .cuda ()
3236 input = (1 , 2 )
33- total_params , trainable_params = summary (model , input , device = "cuda:0" )
37+ total_params , trainable_params = summary (model , input , device = gpu_if_available )
3438 self .assertEqual (total_params , 15 )
3539 self .assertEqual (trainable_params , 15 )
3640
@@ -39,9 +43,22 @@ def test_multiple_input_types(self):
3943 input1 = (1 , 300 )
4044 input2 = (1 , 300 )
4145 dtypes = [torch .FloatTensor , torch .LongTensor ]
42- total_params , trainable_params = summary (model , [input1 , input2 ], device = "cpu" , dtypes = dtypes )
46+ total_params , trainable_params = summary (
47+ model , [input1 , input2 ], device = "cpu" , dtypes = dtypes )
4348 self .assertEqual (total_params , 31120 )
4449 self .assertEqual (trainable_params , 31120 )
4550
51+
52+ class torchsummarystringTests (unittest .TestCase ):
53+ def test_single_input (self ):
54+ model = SingleInputNet ()
55+ input = (1 , 28 , 28 )
56+ result , (total_params , trainable_params ) = summary_string (
57+ model , input , device = "cpu" )
58+ self .assertEqual (type (result ), str )
59+ self .assertEqual (total_params , 21840 )
60+ self .assertEqual (trainable_params , 21840 )
61+
62+
4663if __name__ == '__main__' :
4764 unittest .main (buffer = True )
0 commit comments