99"""
1010import math
1111from functools import partial
12- from typing import Any , Dict , List , Optional , Tuple , Type
12+ from typing import Any , Dict , List , Optional , Tuple , Type , Union
1313
1414import torch
1515import torch .nn as nn
1616import torch .nn .functional as F
1717from torch import Tensor
1818
1919from timm .data import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD
20- from timm .layers import DropBlock2d , DropPath , AvgPool2dSame , BlurPool2d , GroupNorm , create_attn , get_attn , \
21- get_act_layer , get_norm_layer , create_classifier
20+ from timm .layers import DropBlock2d , DropPath , AvgPool2dSame , BlurPool2d , GroupNorm , LayerType , create_attn , \
21+ get_attn , get_act_layer , get_norm_layer , create_classifier
2222from ._builder import build_model_with_cfg
2323from ._manipulate import checkpoint_seq
2424from ._registry import register_model , generate_default_cfgs , register_model_deprecations
25- from ._typing import LayerType
2625
2726__all__ = ['ResNet' , 'BasicBlock' , 'Bottleneck' ] # model_registry will add each entrypoint fn to this
2827
2928
def get_padding(kernel_size: int, stride: int, dilation: int = 1) -> int:
    """Compute the symmetric padding that keeps spatial size consistent for a conv.

    Uses the standard "same-ish" padding formula for a convolution with the
    given kernel size, stride, and dilation.

    Args:
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        dilation: Convolution dilation (default 1).

    Returns:
        Padding (in pixels) to apply on each side.
    """
    return ((stride - 1) + dilation * (kernel_size - 1)) // 2
3332
3433
35- def create_aa (aa_layer , channels , stride = 2 , enable = True ):
34+ def create_aa (aa_layer : Type [ nn . Module ] , channels : int , stride : int = 2 , enable : bool = True ) -> nn . Module :
3635 if not aa_layer or not enable :
3736 return nn .Identity ()
3837 if issubclass (aa_layer , nn .AvgPool2d ):
@@ -55,11 +54,11 @@ def __init__(
5554 reduce_first : int = 1 ,
5655 dilation : int = 1 ,
5756 first_dilation : Optional [int ] = None ,
58- act_layer : nn .Module = nn .ReLU ,
59- norm_layer : nn .Module = nn .BatchNorm2d ,
60- attn_layer : Optional [nn .Module ] = None ,
61- aa_layer : Optional [nn .Module ] = None ,
62- drop_block : Type [nn .Module ] = None ,
57+ act_layer : Type [ nn .Module ] = nn .ReLU ,
58+ norm_layer : Type [ nn .Module ] = nn .BatchNorm2d ,
59+ attn_layer : Optional [Type [ nn .Module ] ] = None ,
60+ aa_layer : Optional [Type [ nn .Module ] ] = None ,
61+ drop_block : Optional [ Type [nn .Module ] ] = None ,
6362 drop_path : Optional [nn .Module ] = None ,
6463 ):
6564 """
@@ -153,11 +152,11 @@ def __init__(
153152 reduce_first : int = 1 ,
154153 dilation : int = 1 ,
155154 first_dilation : Optional [int ] = None ,
156- act_layer : nn .Module = nn .ReLU ,
157- norm_layer : nn .Module = nn .BatchNorm2d ,
158- attn_layer : Optional [nn .Module ] = None ,
159- aa_layer : Optional [nn .Module ] = None ,
160- drop_block : Type [nn .Module ] = None ,
155+ act_layer : Type [ nn .Module ] = nn .ReLU ,
156+ norm_layer : Type [ nn .Module ] = nn .BatchNorm2d ,
157+ attn_layer : Optional [Type [ nn .Module ] ] = None ,
158+ aa_layer : Optional [Type [ nn .Module ] ] = None ,
159+ drop_block : Optional [ Type [nn .Module ] ] = None ,
161160 drop_path : Optional [nn .Module ] = None ,
162161 ):
163162 """
@@ -296,7 +295,7 @@ def drop_blocks(drop_prob: float = 0.):
296295
297296
298297def make_blocks (
299- block_fn : nn . Module ,
298+ block_fn : Union [ BasicBlock , Bottleneck ] ,
300299 channels : List [int ],
301300 block_repeats : List [int ],
302301 inplanes : int ,
@@ -395,7 +394,7 @@ class ResNet(nn.Module):
395394
396395 def __init__ (
397396 self ,
398- block : nn . Module ,
397+ block : Union [ BasicBlock , Bottleneck ] ,
399398 layers : List [int ],
400399 num_classes : int = 1000 ,
401400 in_chans : int = 3 ,
@@ -411,7 +410,7 @@ def __init__(
411410 avg_down : bool = False ,
412411 act_layer : LayerType = nn .ReLU ,
413412 norm_layer : LayerType = nn .BatchNorm2d ,
414- aa_layer : Optional [nn .Module ] = None ,
413+ aa_layer : Optional [Type [ nn .Module ] ] = None ,
415414 drop_rate : float = 0.0 ,
416415 drop_path_rate : float = 0. ,
417416 drop_block_rate : float = 0. ,
0 commit comments