From f0ce2e8b7570b85d90e34b2a98e59e1c99b3f8fc Mon Sep 17 00:00:00 2001 From: Muhammad Hamdan Date: Mon, 14 Mar 2022 16:57:59 -0700 Subject: [PATCH] Return model size info This is useful information for dynamically computing the max/recommended batch size. --- torchsummary/torchsummary.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py index 1ed065f..a090a0d 100644 --- a/torchsummary/torchsummary.py +++ b/torchsummary/torchsummary.py @@ -7,11 +7,11 @@ def summary(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None): - result, params_info = summary_string( + result, summary_info = summary_string( model, input_size, batch_size, device, dtypes) print(result) - return params_info + return summary_info def summary_string(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None): @@ -19,7 +19,7 @@ def summary_string(model, input_size, batch_size=-1, device=torch.device('cuda:0 dtypes = [torch.FloatTensor]*len(input_size) summary_str = '' - + summary_info = {"params_info": tuple(), "size_info": tuple()} def register_hook(module): def hook(module, input, output): class_name = str(module.__class__).split(".")[-1].split("'")[0] @@ -117,4 +117,8 @@ def hook(module, input, output): summary_str += "Estimated Total Size (MB): %0.2f" % total_size + "\n" summary_str += "----------------------------------------------------------------" + "\n" # return summary - return summary_str, (total_params, trainable_params) + + summary_info['params_info'] = (total_params, trainable_params) + summary_info['size_info'] = (total_input_size, total_output_size, total_params_size, total_size) + + return summary_str, summary_info