From 9ec7980804639c4b5b382da61148caf9a75a3549 Mon Sep 17 00:00:00 2001 From: ncble Date: Sat, 27 Apr 2019 12:57:26 +0200 Subject: [PATCH 1/5] solve device issue (default to 'cpu'); support torch 0.4.1 or newer --- torchsummary/torchsummary.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py index cbe18e3..fff5e36 100644 --- a/torchsummary/torchsummary.py +++ b/torchsummary/torchsummary.py @@ -6,7 +6,7 @@ import numpy as np -def summary(model, input_size, batch_size=-1, device="cuda"): +def summary(model, input_size, batch_size=-1, device="cpu"): def register_hook(module): @@ -41,23 +41,26 @@ def hook(module, input, output): ): hooks.append(module.register_forward_hook(hook)) - device = device.lower() - assert device in [ - "cuda", - "cpu", - ], "Input device is not valid, please specify 'cuda' or 'cpu'" + if isinstance(device, str): + device = device.lower() + assert device in [ + "cuda", + "cuda:0", + "cuda:1", + "cpu", + ], "Input device is not valid, please specify 'cuda' or 'cpu'" - if device == "cuda" and torch.cuda.is_available(): - dtype = torch.cuda.FloatTensor - else: - dtype = torch.FloatTensor + # if device == "cuda" and torch.cuda.is_available(): + # dtype = torch.cuda.FloatTensor + # else: + # dtype = torch.FloatTensor # multiple inputs to the network if isinstance(input_size, tuple): input_size = [input_size] # batch_size of 2 for batchnorm - x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size] + x = [torch.rand(2, *in_size).to(device) for in_size in input_size] # print(type(x[0])) # create properties From e557239f28e76964c5df45441d998723cead5f07 Mon Sep 17 00:00:00 2001 From: ncble Date: Sat, 27 Apr 2019 13:02:30 +0200 Subject: [PATCH 2/5] edit setup --- setup.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index 35f13cd..4e2265a 100644 --- a/setup.py +++ b/setup.py @@ -2,10 +2,10 @@ 
setup( name="torchsummary", - version="1.5.1", + version="1.5.2", description="Model summary in PyTorch similar to `model.summary()` in Keras", - url="https://github.com/sksq96/pytorch-summary", - author="Shubham Chandel @sksq96", - author_email="shubham.zeez@gmail.com", + url="https://github.com/ncble/pytorch-summary", + author="Lu Lin @ncble", + author_email="ncble17@gmail.com", packages=["torchsummary"], ) From b5d2f8d226733c84e6755e9f175d1fb6862afb53 Mon Sep 17 00:00:00 2001 From: lulin Date: Fri, 24 May 2019 17:22:25 +0200 Subject: [PATCH 3/5] new feature: output (return) image shape --- setup.py | 2 +- torchsummary/torchsummary.py | 39 ++++++++++++++++++------------------ 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/setup.py b/setup.py index 4e2265a..05a006e 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="torchsummary", - version="1.5.2", + version="1.5.3", description="Model summary in PyTorch similar to `model.summary()` in Keras", url="https://github.com/ncble/pytorch-summary", author="Lu Lin @ncble", diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py index fff5e36..2348198 100644 --- a/torchsummary/torchsummary.py +++ b/torchsummary/torchsummary.py @@ -6,7 +6,7 @@ import numpy as np -def summary(model, input_size, batch_size=-1, device="cpu"): +def summary(model, input_size, batch_size=-1, device="cpu", show=True): def register_hook(module): @@ -77,11 +77,11 @@ def hook(module, input, output): # remove these hooks for h in hooks: h.remove() - - print("----------------------------------------------------------------") - line_new = "{:>20} {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #") - print(line_new) - print("================================================================") + if show: + print("----------------------------------------------------------------") + line_new = "{:>20} {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #") + print(line_new) + 
print("================================================================") total_params = 0 total_output = 0 trainable_params = 0 @@ -97,22 +97,23 @@ def hook(module, input, output): if "trainable" in summary[layer]: if summary[layer]["trainable"] == True: trainable_params += summary[layer]["nb_params"] - print(line_new) + if show: + print(line_new) # assume 4 bytes/number (float on cuda). total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.)) total_output_size = abs(2. * total_output * 4. / (1024 ** 2.)) # x2 for gradients total_params_size = abs(total_params.numpy() * 4. / (1024 ** 2.)) total_size = total_params_size + total_output_size + total_input_size - - print("================================================================") - print("Total params: {0:,}".format(total_params)) - print("Trainable params: {0:,}".format(trainable_params)) - print("Non-trainable params: {0:,}".format(total_params - trainable_params)) - print("----------------------------------------------------------------") - print("Input size (MB): %0.2f" % total_input_size) - print("Forward/backward pass size (MB): %0.2f" % total_output_size) - print("Params size (MB): %0.2f" % total_params_size) - print("Estimated Total Size (MB): %0.2f" % total_size) - print("----------------------------------------------------------------") - # return summary + if show: + print("================================================================") + print("Total params: {0:,}".format(total_params)) + print("Trainable params: {0:,}".format(trainable_params)) + print("Non-trainable params: {0:,}".format(total_params - trainable_params)) + print("----------------------------------------------------------------") + print("Input size (MB): %0.2f" % total_input_size) + print("Forward/backward pass size (MB): %0.2f" % total_output_size) + print("Params size (MB): %0.2f" % total_params_size) + print("Estimated Total Size (MB): %0.2f" % total_size) + 
print("----------------------------------------------------------------") + return summary[layer]["output_shape"] From fc4ead46d0776060358febeef3e123070d907152 Mon Sep 17 00:00:00 2001 From: lulin Date: Thu, 4 Jul 2019 14:44:43 +0200 Subject: [PATCH 4/5] update README.md (new functionality) and parse device str cuda:n --- README.md | 96 +++++++++++++++++++++++++++++++++++- setup.py | 2 +- torchsummary/torchsummary.py | 23 +++++---- 3 files changed, 110 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 5c6df2a..a13ed34 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ ## Keras style `model.summary()` in PyTorch [![PyPI version](https://badge.fury.io/py/torchsummary.svg)](https://badge.fury.io/py/torchsummary) -Keras has a neat API to view the visualization of the model which is very helpful while debugging your network. Here is a barebone code to try and mimic the same in PyTorch. The aim is to provide information complementary to, what is not provided by `print(your_model)` in PyTorch. +Keras has a neat API to view the visualization of the model which is very helpful while debugging your network. Here is a barebone code to try and mimic the same in PyTorch. The aim is to provide information complementary to, what is not provided by `print(your_model)` in PyTorch. (**New functionality**) The main function `summary` (`from torchsummary import summary`) can also be used to infer the output shape of a PyTorch model. It thus provides a way to build a PyTorch model that supports any input shape, as in Keras (see the [example](#build-pytorch-model-with-scalable-input-shape-like-keras) below). 
### Usage @@ -191,6 +191,100 @@ Estimated Total Size (MB): 0.78 ---------------------------------------------------------------- ``` +### Build pytorch model with scalable input shape (like Keras) + +```python + +import torch +import torch.nn as nn +from torchsummary import summary + +class AutoEncoder(nn.Module): + """ + ResNet autoencoder network that supports any input shape, like a model in Keras + :param img_shape: (tuple, channel first) image input shape; any size is supported + :param state_dim: (int) latent state dimension + """ + + def __init__(self, img_shape=(3, 224, 224), state_dim=3): + super().__init__() + self.state_dim = state_dim + self.img_shape = img_shape + + self.encoder_conv = nn.Sequential( + nn.Conv2d(self.img_shape[0], 64, kernel_size=7, stride=2, padding=3, bias=False), + nn.BatchNorm2d(64), + nn.ReLU(), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + + nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(64), + nn.ReLU(), + nn.MaxPool2d(kernel_size=3, stride=2), + + nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(64), + nn.ReLU(), + nn.MaxPool2d(kernel_size=3, stride=2) + ) + ## Without torchsummary here, the model could not be built with a scalable input shape as in Keras. 
+ outshape = summary(self.encoder_conv, img_shape, show=False) # [-1, channels, high, width] + self.img_height, self.img_width = outshape[-2:] + self.encoder_fc = nn.Sequential( + nn.Linear(self.img_height * self.img_width * 64, state_dim) + ) + + self.decoder_fc = nn.Sequential( + nn.Linear(state_dim, self.img_height * self.img_width * 64) + ) + + self.decoder_conv = nn.Sequential( + nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2), + nn.BatchNorm2d(64), + nn.ReLU(), + + nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2), + nn.BatchNorm2d(64), + nn.ReLU(), + + nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2), + nn.BatchNorm2d(64), + nn.ReLU(), + + nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2), + nn.BatchNorm2d(64), + nn.ReLU(), + + nn.ConvTranspose2d(64, self.img_shape[0], kernel_size=4, stride=2), + nn.Tanh() + ) + + def encode(self, x): + """ + Encode image to latent state + """ + encoded = self.encoder_conv(x) + encoded = encoded.view(encoded.size(0), -1) + return self.encoder_fc(encoded) + + def decode(self, x): + """ + Decode latent state to image + """ + decoded = self.decoder_fc(x) + decoded = decoded.view(x.size(0), 64, self.img_height, self.img_width) + return self.decoder_conv(decoded) + + def forward(self, x): + reconstruct = self.decode(self.encode(x)) + return reconstruct + + +img_shape = (3,128,128) +model = AutoEncoder(img_shape=img_shape, state_dim=100) +summary(model, img_shape) + +``` ### References diff --git a/setup.py b/setup.py index 05a006e..4e2265a 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="torchsummary", - version="1.5.3", + version="1.5.2", description="Model summary in PyTorch similar to `model.summary()` in Keras", url="https://github.com/ncble/pytorch-summary", author="Lu Lin @ncble", diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py index 2348198..dcd5cd4 100644 --- a/torchsummary/torchsummary.py +++ b/torchsummary/torchsummary.py @@ -40,20 +40,25 @@ def hook(module, input, 
output): and not (module == model) ): hooks.append(module.register_forward_hook(hook)) - + def cuda_device_valid(device_str): + valid = device_str.startswith("cuda") + try: + device_index = int(device_str.split(":")[-1]) + total_gpu_num = torch.cuda.device_count() + if (device_index < total_gpu_num): + return valid + else: + print("Cuda device '{}' dosen't exist. Find {} GPU(s)".format(device_str, total_gpu_num)) + return False + except: + print("CUDA device should have form like 'cuda:0', 'cuda:n', etc. (n is an integer)") + return False if isinstance(device, str): device = device.lower() assert device in [ "cuda", - "cuda:0", - "cuda:1", "cpu", - ], "Input device is not valid, please specify 'cuda' or 'cpu'" - - # if device == "cuda" and torch.cuda.is_available(): - # dtype = torch.cuda.FloatTensor - # else: - # dtype = torch.FloatTensor + ] or cuda_device_valid(device), "Input device is not valid, please specify 'cpu' or 'cuda' or 'cuda:n'" # multiple inputs to the network if isinstance(input_size, tuple): From 8decb9336862807e90938f4d6c11dfb00f3cc65b Mon Sep 17 00:00:00 2001 From: lulin Date: Thu, 4 Jul 2019 15:19:38 +0200 Subject: [PATCH 5/5] update version --- setup.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 4e2265a..17037df 100644 --- a/setup.py +++ b/setup.py @@ -4,8 +4,8 @@ name="torchsummary", version="1.5.2", description="Model summary in PyTorch similar to `model.summary()` in Keras", - url="https://github.com/ncble/pytorch-summary", - author="Lu Lin @ncble", - author_email="ncble17@gmail.com", + url="https://github.com/sksq96/pytorch-summary", + author="Shubham Chandel @sksq96", + author_email="shubham.zeez@gmail.com", packages=["torchsummary"], )