77
88
def summary(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None):
    """Print a Keras-style per-layer summary of *model* and return parameter counts.

    Thin wrapper around ``summary_string``: prints the formatted table to
    stdout and returns only the numeric summary.

    Args:
        model: the ``torch.nn.Module`` to inspect.
        input_size: shape (or list of shapes) of a single input sample,
            excluding the batch dimension, e.g. ``(3, 224, 224)``.
        batch_size: batch size to display in the table; ``-1`` means unknown.
        device: device the probe tensors are created on.
            NOTE(review): the default is CUDA-only; pass
            ``torch.device('cpu')`` on machines without a GPU.
        dtypes: optional list of tensor types, one per input; defaults to
            ``torch.FloatTensor`` for every input.

    Returns:
        ``(total_params, trainable_params)`` as produced by ``summary_string``.
    """
    result, params_info = summary_string(
        model, input_size, batch_size, device, dtypes)
    print(result)

    return params_info
def summary_string(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None):
    """Build a Keras-style layer-by-layer summary of *model* as a string.

    Runs one forward pass with random probe tensors, recording each leaf
    module's output shape and parameter count via forward hooks.

    Args:
        model: the ``torch.nn.Module`` to inspect.
        input_size: shape of one input sample without the batch dim
            (a tuple), or a list of such tuples for multi-input models.
        batch_size: batch size shown in the table; ``-1`` means unknown.
        device: device the probe tensors are created on (default is
            CUDA-only; pass ``torch.device('cpu')`` on CPU machines).
        dtypes: optional list of tensor types, one per input; defaults to
            ``torch.FloatTensor`` for every input.

    Returns:
        ``(summary_str, (total_params, trainable_params))`` where
        ``summary_str`` is the printable table.
    """
    if dtypes is None:
        dtypes = [torch.FloatTensor] * len(input_size)

    summary_str = ''

    def register_hook(module):
        # Record output shape and parameter count for one leaf module.
        def hook(module, input, output):
            class_name = str(module.__class__).split(".")[-1].split("'")[0]
            module_idx = len(summary)

            m_key = "%s-%i" % (class_name, module_idx + 1)
            summary[m_key] = OrderedDict()
            summary[m_key]["input_shape"] = list(input[0].size())
            summary[m_key]["input_shape"][0] = batch_size
            if isinstance(output, (list, tuple)):
                # Multi-output module: keep every output's shape, with -1
                # standing in for the batch dimension.
                summary[m_key]["output_shape"] = [
                    [-1] + list(o.size())[1:] for o in output
                ]
            else:
                summary[m_key]["output_shape"] = list(output.size())
                summary[m_key]["output_shape"][0] = batch_size

            params = 0
            if hasattr(module, "weight") and hasattr(module.weight, "size"):
                params += torch.prod(torch.LongTensor(list(module.weight.size())))
                summary[m_key]["trainable"] = module.weight.requires_grad
            if hasattr(module, "bias") and hasattr(module.bias, "size"):
                params += torch.prod(torch.LongTensor(list(module.bias.size())))
            summary[m_key]["nb_params"] = params

        # Skip pure containers so only leaf layers appear in the table.
        if (not isinstance(module, nn.Sequential)
                and not isinstance(module, nn.ModuleList)):
            hooks.append(module.register_forward_hook(hook))

    # Normalize the single-input case to a list of input shapes.
    if isinstance(input_size, tuple):
        input_size = [input_size]

    # batch_size of 2 for batchnorm
    x = [torch.rand(2, *in_size).type(dtype).to(device=device)
         for in_size, dtype in zip(input_size, dtypes)]

    # create properties
    summary = OrderedDict()
    hooks = []

    # register hook on every submodule
    model.apply(register_hook)

    # make a forward pass
    model(*x)

    # remove these hooks
    for h in hooks:
        h.remove()

    summary_str += "----------------------------------------------------------------" + "\n"
    line_new = "{:>20}  {:>25} {:>15}".format(
        "Layer (type)", "Output Shape", "Param #")
    summary_str += line_new + "\n"
    summary_str += "================================================================" + "\n"
    total_params = 0
    total_output = 0
    trainable_params = 0
    for layer in summary:
        # input_shape, output_shape, trainable, nb_params
        line_new = "{:>20}  {:>25} {:>15}".format(
            layer,
            str(summary[layer]["output_shape"]),
            "{0:,}".format(summary[layer]["nb_params"]),
        )
        total_params += summary[layer]["nb_params"]

        total_output += np.prod(summary[layer]["output_shape"])
        if "trainable" in summary[layer]:
            if summary[layer]["trainable"] == True:
                trainable_params += summary[layer]["nb_params"]
        summary_str += line_new + "\n"

    # assume 4 bytes/number (float on cuda).
    total_input_size = abs(np.prod(sum(input_size, ()))
                           * batch_size * 4. / (1024 ** 2.))
    total_output_size = abs(2. * total_output * 4. /
                            (1024 ** 2.))  # x2 for gradients
    total_params_size = abs(total_params * 4. / (1024 ** 2.))
    total_size = total_params_size + total_output_size + total_input_size

    summary_str += "================================================================" + "\n"
    summary_str += "Total params: {0:,}".format(total_params) + "\n"
    summary_str += "Trainable params: {0:,}".format(trainable_params) + "\n"
    summary_str += "Non-trainable params: {0:,}".format(total_params -
                                                        trainable_params) + "\n"
    summary_str += "----------------------------------------------------------------" + "\n"
    summary_str += "Input size (MB): %0.2f" % total_input_size + "\n"
    summary_str += "Forward/backward pass size (MB): %0.2f" % total_output_size + "\n"
    summary_str += "Params size (MB): %0.2f" % total_params_size + "\n"
    summary_str += "Estimated Total Size (MB): %0.2f" % total_size + "\n"
    summary_str += "----------------------------------------------------------------" + "\n"
    # return summary
    return summary_str, (total_params, trainable_params)
0 commit comments