From a1070bb3b79bd1ad9e12c661832d422bc05ab84d Mon Sep 17 00:00:00 2001 From: F-G Fernandez Date: Thu, 5 Mar 2020 14:54:12 +0100 Subject: [PATCH] feat: Default device is set to model device Avoids specifying device since the input tensor needs to be on the same on as the model. This is useful in multi-GPUs environment or to freely use the function on CPU --- torchsummary/torchsummary.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py index 1ed065f..23cf6ba 100644 --- a/torchsummary/torchsummary.py +++ b/torchsummary/torchsummary.py @@ -6,15 +6,19 @@ import numpy as np -def summary(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None): +def summary(model, input_size, batch_size=-1, dtypes=None): result, params_info = summary_string( - model, input_size, batch_size, device, dtypes) + model, input_size, batch_size, dtypes) print(result) return params_info -def summary_string(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None): +def summary_string(model, input_size, batch_size=-1, dtypes=None): + + # Take the device of the first model parameter + device = next(model.parameters()).device + if dtypes == None: dtypes = [torch.FloatTensor]*len(input_size)