diff --git a/README.md b/README.md index 5c6df2a..a13ed34 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ ## Keras style `model.summary()` in PyTorch [![PyPI version](https://badge.fury.io/py/torchsummary.svg)](https://badge.fury.io/py/torchsummary) -Keras has a neat API to view the visualization of the model which is very helpful while debugging your network. Here is a barebone code to try and mimic the same in PyTorch. The aim is to provide information complementary to, what is not provided by `print(your_model)` in PyTorch. +Keras has a neat API to view the visualization of the model which is very helpful while debugging your network. Here is a barebone code to try and mimic the same in PyTorch. The aim is to provide information complementary to, what is not provided by `print(your_model)` in PyTorch. (**New functionality**) The main function `summary` (`from torchsummary import summary`) can also be used to infer the output shape of a PyTorch model. Thus, it provides a way to build a PyTorch model that supports any input shape, as in Keras (see the [example](#build-pytorch-model-with-scalable-input-shape-like-keras) below). 
### Usage @@ -191,6 +191,100 @@ Estimated Total Size (MB): 0.78 ---------------------------------------------------------------- ``` +### Build pytorch model with scalable input shape (like Keras) + +```python + +import torch +import torch.nn as nn +from torchsummary import summary + +class AutoEncoder(nn.Module): + """ + ResNet autoencoder network that support any input shape as model in Keras + :param img_shape (tuple, channel last): support any image input shape + :param state_dim: (int) latent state dimension + """ + + def __init__(self, img_shape=(3, 224, 224), state_dim=3): + super().__init__() + self.state_dim = state_dim + self.img_shape = img_shape + + self.encoder_conv = nn.Sequential( + nn.Conv2d(self.img_shape[0], 64, kernel_size=7, stride=2, padding=3, bias=False), + nn.BatchNorm2d(64), + nn.ReLU(), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + + nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(64), + nn.ReLU(), + nn.MaxPool2d(kernel_size=3, stride=2), + + nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(64), + nn.ReLU(), + nn.MaxPool2d(kernel_size=3, stride=2) + ) + ## Without torchsummary here, it's impossible to build model with scalable input shape as Keras. 
+ outshape = summary(self.encoder_conv, img_shape, show=False) # [-1, channels, high, width] + self.img_height, self.img_width = outshape[-2:] + self.encoder_fc = nn.Sequential( + nn.Linear(self.img_height * self.img_width * 64, state_dim) + ) + + self.decoder_fc = nn.Sequential( + nn.Linear(state_dim, self.img_height * self.img_width * 64) + ) + + self.decoder_conv = nn.Sequential( + nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2), + nn.BatchNorm2d(64), + nn.ReLU(), + + nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2), + nn.BatchNorm2d(64), + nn.ReLU(), + + nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2), + nn.BatchNorm2d(64), + nn.ReLU(), + + nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2), + nn.BatchNorm2d(64), + nn.ReLU(), + + nn.ConvTranspose2d(64, self.img_shape[0], kernel_size=4, stride=2), + nn.Tanh() + ) + + def encode(self, x): + """ + Encode image to latent state + """ + encoded = self.encoder_conv(x) + encoded = encoded.view(encoded.size(0), -1) + return self.encoder_fc(encoded) + + def decode(self, x): + """ + Decode latent state to image + """ + decoded = self.decoder_fc(x) + decoded = decoded.view(x.size(0), 64, self.img_height, self.img_width) + return self.decoder_conv(decoded) + + def forward(self, x): + reconstruct = self.decode(self.encode(x)) + return reconstruct + + +img_shape = (3,128,128) +model = AutoEncoder(img_shape=img_shape, state_dim=100) +summary(model, img_shape) + +``` ### References diff --git a/setup.py b/setup.py index 35f13cd..17037df 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="torchsummary", - version="1.5.1", + version="1.5.2", description="Model summary in PyTorch similar to `model.summary()` in Keras", url="https://github.com/sksq96/pytorch-summary", author="Shubham Chandel @sksq96", diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py index cbe18e3..dcd5cd4 100644 --- a/torchsummary/torchsummary.py +++ b/torchsummary/torchsummary.py @@ -6,7 +6,7 @@ import numpy as np 
-def summary(model, input_size, batch_size=-1, device="cuda"): +def summary(model, input_size, batch_size=-1, device="cpu", show=True): def register_hook(module): @@ -40,24 +40,32 @@ def hook(module, input, output): and not (module == model) ): hooks.append(module.register_forward_hook(hook)) - - device = device.lower() - assert device in [ - "cuda", - "cpu", - ], "Input device is not valid, please specify 'cuda' or 'cpu'" - - if device == "cuda" and torch.cuda.is_available(): - dtype = torch.cuda.FloatTensor - else: - dtype = torch.FloatTensor + def cuda_device_valid(device_str): + valid = device_str.startswith("cuda") + try: + device_index = int(device_str.split(":")[-1]) + total_gpu_num = torch.cuda.device_count() + if (device_index < total_gpu_num): + return valid + else: + print("Cuda device '{}' dosen't exist. Find {} GPU(s)".format(device_str, total_gpu_num)) + return False + except: + print("CUDA device should have form like 'cuda:0', 'cuda:n', etc. (n is an integer)") + return False + if isinstance(device, str): + device = device.lower() + assert device in [ + "cuda", + "cpu", + ] or cuda_device_valid(device), "Input device is not valid, please specify 'cpu' or 'cuda' or 'cuda:n'" # multiple inputs to the network if isinstance(input_size, tuple): input_size = [input_size] # batch_size of 2 for batchnorm - x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size] + x = [torch.rand(2, *in_size).to(device) for in_size in input_size] # print(type(x[0])) # create properties @@ -74,11 +82,11 @@ def hook(module, input, output): # remove these hooks for h in hooks: h.remove() - - print("----------------------------------------------------------------") - line_new = "{:>20} {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #") - print(line_new) - print("================================================================") + if show: + print("----------------------------------------------------------------") + line_new = "{:>20} {:>25} 
{:>15}".format("Layer (type)", "Output Shape", "Param #") + print(line_new) + print("================================================================") total_params = 0 total_output = 0 trainable_params = 0 @@ -94,22 +102,23 @@ def hook(module, input, output): if "trainable" in summary[layer]: if summary[layer]["trainable"] == True: trainable_params += summary[layer]["nb_params"] - print(line_new) + if show: + print(line_new) # assume 4 bytes/number (float on cuda). total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.)) total_output_size = abs(2. * total_output * 4. / (1024 ** 2.)) # x2 for gradients total_params_size = abs(total_params.numpy() * 4. / (1024 ** 2.)) total_size = total_params_size + total_output_size + total_input_size - - print("================================================================") - print("Total params: {0:,}".format(total_params)) - print("Trainable params: {0:,}".format(trainable_params)) - print("Non-trainable params: {0:,}".format(total_params - trainable_params)) - print("----------------------------------------------------------------") - print("Input size (MB): %0.2f" % total_input_size) - print("Forward/backward pass size (MB): %0.2f" % total_output_size) - print("Params size (MB): %0.2f" % total_params_size) - print("Estimated Total Size (MB): %0.2f" % total_size) - print("----------------------------------------------------------------") - # return summary + if show: + print("================================================================") + print("Total params: {0:,}".format(total_params)) + print("Trainable params: {0:,}".format(trainable_params)) + print("Non-trainable params: {0:,}".format(total_params - trainable_params)) + print("----------------------------------------------------------------") + print("Input size (MB): %0.2f" % total_input_size) + print("Forward/backward pass size (MB): %0.2f" % total_output_size) + print("Params size (MB): %0.2f" % total_params_size) + print("Estimated Total 
Size (MB): %0.2f" % total_size) + print("----------------------------------------------------------------") + return summary[layer]["output_shape"]