diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py index 1ed065f..23cf6ba 100644 --- a/torchsummary/torchsummary.py +++ b/torchsummary/torchsummary.py @@ -6,15 +6,19 @@ import numpy as np -def summary(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None): +def summary(model, input_size, batch_size=-1, dtypes=None): result, params_info = summary_string( - model, input_size, batch_size, device, dtypes) + model, input_size, batch_size, dtypes) print(result) return params_info -def summary_string(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None): +def summary_string(model, input_size, batch_size=-1, dtypes=None): + + # Take the device of the first model parameter + device = next(model.parameters()).device + if dtypes == None: dtypes = [torch.FloatTensor]*len(input_size)