2323
2424from monai .networks .blocks .encoder import BaseEncoder
2525from monai .networks .layers .factories import Conv , Norm , Pool
26- from monai .networks .layers .utils import get_pool_layer
26+ from monai .networks .layers .utils import get_act_layer , get_pool_layer
2727from monai .utils import ensure_tuple_rep
2828from monai .utils .module import look_up_option , optional_import
2929
@@ -78,6 +78,7 @@ def __init__(
7878 spatial_dims : int = 3 ,
7979 stride : int = 1 ,
8080 downsample : nn .Module | partial | None = None ,
81+ act : str | tuple = ("relu" , {"inplace" : True }),
8182 ) -> None :
8283 """
8384 Args:
@@ -86,6 +87,7 @@ def __init__(
8687 spatial_dims: number of spatial dimensions of the input image.
8788 stride: stride to use for first conv layer.
8889 downsample: which downsample layer to use.
90+ act: activation type and arguments. Defaults to relu.
8991 """
9092 super ().__init__ ()
9193
@@ -94,7 +96,7 @@ def __init__(
9496
9597 self .conv1 = conv_type (in_planes , planes , kernel_size = 3 , padding = 1 , stride = stride , bias = False )
9698 self .bn1 = norm_type (planes )
97- self .relu = nn . ReLU ( inplace = True )
99+ self .act = get_act_layer ( name = act )
98100 self .conv2 = conv_type (planes , planes , kernel_size = 3 , padding = 1 , bias = False )
99101 self .bn2 = norm_type (planes )
100102 self .downsample = downsample
@@ -105,7 +107,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
105107
106108 out : torch .Tensor = self .conv1 (x )
107109 out = self .bn1 (out )
108- out = self .relu (out )
110+ out = self .act (out )
109111
110112 out = self .conv2 (out )
111113 out = self .bn2 (out )
@@ -114,7 +116,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
114116 residual = self .downsample (x )
115117
116118 out += residual
117- out = self .relu (out )
119+ out = self .act (out )
118120
119121 return out
120122
@@ -129,6 +131,7 @@ def __init__(
129131 spatial_dims : int = 3 ,
130132 stride : int = 1 ,
131133 downsample : nn .Module | partial | None = None ,
134+ act : str | tuple = ("relu" , {"inplace" : True }),
132135 ) -> None :
133136 """
134137 Args:
@@ -137,6 +140,7 @@ def __init__(
137140 spatial_dims: number of spatial dimensions of the input image.
138141 stride: stride to use for second conv layer.
139142 downsample: which downsample layer to use.
143+ act: activation type and arguments. Defaults to relu.
140144 """
141145
142146 super ().__init__ ()
@@ -150,7 +154,7 @@ def __init__(
150154 self .bn2 = norm_type (planes )
151155 self .conv3 = conv_type (planes , planes * self .expansion , kernel_size = 1 , bias = False )
152156 self .bn3 = norm_type (planes * self .expansion )
153- self .relu = nn . ReLU ( inplace = True )
157+ self .act = get_act_layer ( name = act )
154158 self .downsample = downsample
155159 self .stride = stride
156160
@@ -159,11 +163,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
159163
160164 out : torch .Tensor = self .conv1 (x )
161165 out = self .bn1 (out )
162- out = self .relu (out )
166+ out = self .act (out )
163167
164168 out = self .conv2 (out )
165169 out = self .bn2 (out )
166- out = self .relu (out )
170+ out = self .act (out )
167171
168172 out = self .conv3 (out )
169173 out = self .bn3 (out )
@@ -172,7 +176,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
172176 residual = self .downsample (x )
173177
174178 out += residual
175- out = self .relu (out )
179+ out = self .act (out )
176180
177181 return out
178182
@@ -202,6 +206,7 @@ class ResNet(nn.Module):
202206 num_classes: number of output (classifications).
203207 feed_forward: whether to add the FC layer for the output, default to `True`.
204208 bias_downsample: whether to use bias term in the downsampling block when `shortcut_type` is 'B', default to `True`.
209+ act: activation type and arguments. Defaults to relu.
205210
206211 """
207212
@@ -220,6 +225,7 @@ def __init__(
220225 num_classes : int = 400 ,
221226 feed_forward : bool = True ,
222227 bias_downsample : bool = True , # for backwards compatibility (also see PR #5477)
228+ act : str | tuple = ("relu" , {"inplace" : True }),
223229 ) -> None :
224230 super ().__init__ ()
225231
@@ -257,7 +263,7 @@ def __init__(
257263 bias = False ,
258264 )
259265 self .bn1 = norm_type (self .in_planes )
260- self .relu = nn . ReLU ( inplace = True )
266+ self .act = get_act_layer ( name = act )
261267 self .maxpool = pool_type (kernel_size = 3 , stride = 2 , padding = 1 )
262268 self .layer1 = self ._make_layer (block , block_inplanes [0 ], layers [0 ], spatial_dims , shortcut_type )
263269 self .layer2 = self ._make_layer (block , block_inplanes [1 ], layers [1 ], spatial_dims , shortcut_type , stride = 2 )
@@ -329,7 +335,7 @@ def _make_layer(
329335 def forward (self , x : torch .Tensor ) -> torch .Tensor :
330336 x = self .conv1 (x )
331337 x = self .bn1 (x )
332- x = self .relu (x )
338+ x = self .act (x )
333339 if not self .no_max_pool :
334340 x = self .maxpool (x )
335341
@@ -396,7 +402,7 @@ def forward(self, inputs: torch.Tensor):
396402 """
397403 x = self .conv1 (inputs )
398404 x = self .bn1 (x )
399- x = self .relu (x )
405+ x = self .act (x )
400406
401407 features = []
402408 features .append (x )
0 commit comments