2222import torch .nn as nn
2323
2424from monai .networks .blocks .encoder import BaseEncoder
25- from monai .networks .layers .factories import Conv , Norm , Pool
26- from monai .networks .layers .utils import get_act_layer , get_pool_layer
25+ from monai .networks .layers .factories import Conv , Pool
26+ from monai .networks .layers .utils import get_act_layer , get_norm_layer , get_pool_layer
2727from monai .utils import ensure_tuple_rep
2828from monai .utils .module import look_up_option , optional_import
2929
@@ -79,6 +79,7 @@ def __init__(
7979 stride : int = 1 ,
8080 downsample : nn .Module | partial | None = None ,
8181 act : str | tuple = ("relu" , {"inplace" : True }),
82+ norm : str | tuple = "batch" ,
8283 ) -> None :
8384 """
8485 Args:
@@ -88,17 +89,18 @@ def __init__(
8889 stride: stride to use for first conv layer.
8990 downsample: which downsample layer to use.
9091 act: activation type and arguments. Defaults to relu.
92+ norm: feature normalization type and arguments. Defaults to batch norm.
9193 """
9294 super ().__init__ ()
9395
9496 conv_type : Callable = Conv [Conv .CONV , spatial_dims ]
95- norm_type : Callable = Norm [ Norm . BATCH , spatial_dims ]
97+ norm_layer = get_norm_layer ( name = norm , spatial_dims = spatial_dims , channels = planes )
9698
9799 self .conv1 = conv_type (in_planes , planes , kernel_size = 3 , padding = 1 , stride = stride , bias = False )
98- self .bn1 = norm_type ( planes )
100+ self .bn1 = norm_layer
99101 self .act = get_act_layer (name = act )
100102 self .conv2 = conv_type (planes , planes , kernel_size = 3 , padding = 1 , bias = False )
101- self .bn2 = norm_type ( planes )
103+ self .bn2 = norm_layer
102104 self .downsample = downsample
103105 self .stride = stride
104106
@@ -132,6 +134,7 @@ def __init__(
132134 stride : int = 1 ,
133135 downsample : nn .Module | partial | None = None ,
134136 act : str | tuple = ("relu" , {"inplace" : True }),
137+ norm : str | tuple = "batch" ,
135138 ) -> None :
136139 """
137140 Args:
@@ -141,19 +144,20 @@ def __init__(
141144 stride: stride to use for second conv layer.
142145 downsample: which downsample layer to use.
143146 act: activation type and arguments. Defaults to relu.
147+ norm: feature normalization type and arguments. Defaults to batch norm.
144148 """
145149
146150 super ().__init__ ()
147151
148152 conv_type : Callable = Conv [Conv .CONV , spatial_dims ]
149- norm_type : Callable = Norm [ Norm . BATCH , spatial_dims ]
153+ norm_layer = partial ( get_norm_layer , name = norm , spatial_dims = spatial_dims )
150154
151155 self .conv1 = conv_type (in_planes , planes , kernel_size = 1 , bias = False )
152- self .bn1 = norm_type ( planes )
156+ self .bn1 = norm_layer ( channels = planes )
153157 self .conv2 = conv_type (planes , planes , kernel_size = 3 , stride = stride , padding = 1 , bias = False )
154- self .bn2 = norm_type ( planes )
158+ self .bn2 = norm_layer ( channels = planes )
155159 self .conv3 = conv_type (planes , planes * self .expansion , kernel_size = 1 , bias = False )
156- self .bn3 = norm_type ( planes * self .expansion )
160+ self .bn3 = norm_layer ( channels = planes * self .expansion )
157161 self .act = get_act_layer (name = act )
158162 self .downsample = downsample
159163 self .stride = stride
@@ -207,6 +211,7 @@ class ResNet(nn.Module):
207211 feed_forward: whether to add the FC layer for the output, default to `True`.
208212 bias_downsample: whether to use bias term in the downsampling block when `shortcut_type` is 'B', default to `True`.
209213 act: activation type and arguments. Defaults to relu.
214+ norm: feature normalization type and arguments. Defaults to batch norm.
210215
211216 """
212217
@@ -226,6 +231,7 @@ def __init__(
226231 feed_forward : bool = True ,
227232 bias_downsample : bool = True , # for backwards compatibility (also see PR #5477)
228233 act : str | tuple = ("relu" , {"inplace" : True }),
234+ norm : str | tuple = "batch" ,
229235 ) -> None :
230236 super ().__init__ ()
231237
@@ -238,7 +244,6 @@ def __init__(
238244 raise ValueError ("Unknown block '%s', use basic or bottleneck" % block )
239245
240246 conv_type : type [nn .Conv1d | nn .Conv2d | nn .Conv3d ] = Conv [Conv .CONV , spatial_dims ]
241- norm_type : type [nn .BatchNorm1d | nn .BatchNorm2d | nn .BatchNorm3d ] = Norm [Norm .BATCH , spatial_dims ]
242247 pool_type : type [nn .MaxPool1d | nn .MaxPool2d | nn .MaxPool3d ] = Pool [Pool .MAX , spatial_dims ]
243248 avgp_type : type [nn .AdaptiveAvgPool1d | nn .AdaptiveAvgPool2d | nn .AdaptiveAvgPool3d ] = Pool [
244249 Pool .ADAPTIVEAVG , spatial_dims
@@ -262,7 +267,9 @@ def __init__(
262267 padding = tuple (k // 2 for k in conv1_kernel_size ),
263268 bias = False ,
264269 )
265- self .bn1 = norm_type (self .in_planes )
270+
271+ norm_layer = get_norm_layer (name = norm , spatial_dims = spatial_dims , channels = self .in_planes )
272+ self .bn1 = norm_layer
266273 self .act = get_act_layer (name = act )
267274 self .maxpool = pool_type (kernel_size = 3 , stride = 2 , padding = 1 )
268275 self .layer1 = self ._make_layer (block , block_inplanes [0 ], layers [0 ], spatial_dims , shortcut_type )
@@ -275,7 +282,7 @@ def __init__(
275282 for m in self .modules ():
276283 if isinstance (m , conv_type ):
277284 nn .init .kaiming_normal_ (torch .as_tensor (m .weight ), mode = "fan_out" , nonlinearity = "relu" )
278- elif isinstance (m , norm_type ):
285+ elif isinstance (m , type ( norm_layer ) ):
279286 nn .init .constant_ (torch .as_tensor (m .weight ), 1 )
280287 nn .init .constant_ (torch .as_tensor (m .bias ), 0 )
281288 elif isinstance (m , nn .Linear ):
@@ -295,9 +302,9 @@ def _make_layer(
295302 spatial_dims : int ,
296303 shortcut_type : str ,
297304 stride : int = 1 ,
305+ norm : str | tuple = "batch" ,
298306 ) -> nn .Sequential :
299307 conv_type : Callable = Conv [Conv .CONV , spatial_dims ]
300- norm_type : Callable = Norm [Norm .BATCH , spatial_dims ]
301308
302309 downsample : nn .Module | partial | None = None
303310 if stride != 1 or self .in_planes != planes * block .expansion :
@@ -317,18 +324,23 @@ def _make_layer(
317324 stride = stride ,
318325 bias = self .bias_downsample ,
319326 ),
320- norm_type ( planes * block .expansion ),
327+ get_norm_layer ( name = norm , spatial_dims = spatial_dims , channels = planes * block .expansion ),
321328 )
322329
323330 layers = [
324331 block (
325- in_planes = self .in_planes , planes = planes , spatial_dims = spatial_dims , stride = stride , downsample = downsample
332+ in_planes = self .in_planes ,
333+ planes = planes ,
334+ spatial_dims = spatial_dims ,
335+ stride = stride ,
336+ downsample = downsample ,
337+ norm = norm ,
326338 )
327339 ]
328340
329341 self .in_planes = planes * block .expansion
330342 for _i in range (1 , blocks ):
331- layers .append (block (self .in_planes , planes , spatial_dims = spatial_dims ))
343+ layers .append (block (self .in_planes , planes , spatial_dims = spatial_dims , norm = norm ))
332344
333345 return nn .Sequential (* layers )
334346