
Commit 258f56d

Pkaps25 and Peter Kaplinsky authored

Add activation parameter to ResNet (#7749)

Fixes #7653.

### Description
Adds an `act` parameter to `ResNet` and its submodules to allow passing the `inplace` argument to the activation layer.

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: Peter Kaplinsky <peterkaplinsky@gmail.com>
Co-authored-by: Peter Kaplinsky <peterkaplinsky@gmail.com>
1 parent d83fa56 commit 258f56d
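
For context, a minimal usage sketch of the new parameter, assuming a MONAI build that includes this commit; the network configuration and tensor shapes below are illustrative, only the `act` argument itself comes from this change:

```python
import torch

from monai.networks.nets import ResNet

# Illustrative configuration: a small 3D ResNet whose activations are built by
# get_act_layer from the new `act` argument instead of a hard-coded nn.ReLU.
net = ResNet(
    block="basic",
    layers=[1, 1, 1, 1],
    block_inplanes=[64, 128, 256, 512],
    spatial_dims=3,
    n_input_channels=1,
    num_classes=2,
    act=("relu", {"inplace": False}),  # non-inplace ReLU, as exercised by the new tests
)
out = net(torch.randn(1, 1, 32, 32, 32))
print(out.shape)  # expected: torch.Size([1, 2])
```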

File tree (2 files changed: +35, -12 lines)

monai/networks/nets/resnet.py
tests/test_resnet.py

monai/networks/nets/resnet.py

Lines changed: 17 additions & 11 deletions

@@ -23,7 +23,7 @@
 
 from monai.networks.blocks.encoder import BaseEncoder
 from monai.networks.layers.factories import Conv, Norm, Pool
-from monai.networks.layers.utils import get_pool_layer
+from monai.networks.layers.utils import get_act_layer, get_pool_layer
 from monai.utils import ensure_tuple_rep
 from monai.utils.module import look_up_option, optional_import
 
@@ -78,6 +78,7 @@ def __init__(
         spatial_dims: int = 3,
         stride: int = 1,
         downsample: nn.Module | partial | None = None,
+        act: str | tuple = ("relu", {"inplace": True}),
     ) -> None:
         """
         Args:
@@ -86,6 +87,7 @@ def __init__(
             spatial_dims: number of spatial dimensions of the input image.
             stride: stride to use for first conv layer.
             downsample: which downsample layer to use.
+            act: activation type and arguments. Defaults to relu.
         """
         super().__init__()
 
@@ -94,7 +96,7 @@ def __init__(
 
         self.conv1 = conv_type(in_planes, planes, kernel_size=3, padding=1, stride=stride, bias=False)
         self.bn1 = norm_type(planes)
-        self.relu = nn.ReLU(inplace=True)
+        self.act = get_act_layer(name=act)
         self.conv2 = conv_type(planes, planes, kernel_size=3, padding=1, bias=False)
         self.bn2 = norm_type(planes)
         self.downsample = downsample
@@ -105,7 +107,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         out: torch.Tensor = self.conv1(x)
         out = self.bn1(out)
-        out = self.relu(out)
+        out = self.act(out)
 
         out = self.conv2(out)
         out = self.bn2(out)
@@ -114,7 +116,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             residual = self.downsample(x)
 
         out += residual
-        out = self.relu(out)
+        out = self.act(out)
 
         return out
 
@@ -129,6 +131,7 @@ def __init__(
         spatial_dims: int = 3,
         stride: int = 1,
         downsample: nn.Module | partial | None = None,
+        act: str | tuple = ("relu", {"inplace": True}),
     ) -> None:
         """
         Args:
@@ -137,6 +140,7 @@ def __init__(
             spatial_dims: number of spatial dimensions of the input image.
            stride: stride to use for second conv layer.
             downsample: which downsample layer to use.
+            act: activation type and arguments. Defaults to relu.
         """
 
         super().__init__()
@@ -150,7 +154,7 @@ def __init__(
         self.bn2 = norm_type(planes)
         self.conv3 = conv_type(planes, planes * self.expansion, kernel_size=1, bias=False)
         self.bn3 = norm_type(planes * self.expansion)
-        self.relu = nn.ReLU(inplace=True)
+        self.act = get_act_layer(name=act)
         self.downsample = downsample
         self.stride = stride
 
@@ -159,11 +163,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         out: torch.Tensor = self.conv1(x)
         out = self.bn1(out)
-        out = self.relu(out)
+        out = self.act(out)
 
         out = self.conv2(out)
         out = self.bn2(out)
-        out = self.relu(out)
+        out = self.act(out)
 
         out = self.conv3(out)
         out = self.bn3(out)
@@ -172,7 +176,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             residual = self.downsample(x)
 
         out += residual
-        out = self.relu(out)
+        out = self.act(out)
 
         return out
 
@@ -202,6 +206,7 @@ class ResNet(nn.Module):
         num_classes: number of output (classifications).
         feed_forward: whether to add the FC layer for the output, default to `True`.
         bias_downsample: whether to use bias term in the downsampling block when `shortcut_type` is 'B', default to `True`.
+        act: activation type and arguments. Defaults to relu.
 
     """
 
@@ -220,6 +225,7 @@ def __init__(
         num_classes: int = 400,
         feed_forward: bool = True,
         bias_downsample: bool = True,  # for backwards compatibility (also see PR #5477)
+        act: str | tuple = ("relu", {"inplace": True}),
     ) -> None:
         super().__init__()
 
@@ -257,7 +263,7 @@ def __init__(
             bias=False,
         )
         self.bn1 = norm_type(self.in_planes)
-        self.relu = nn.ReLU(inplace=True)
+        self.act = get_act_layer(name=act)
         self.maxpool = pool_type(kernel_size=3, stride=2, padding=1)
         self.layer1 = self._make_layer(block, block_inplanes[0], layers[0], spatial_dims, shortcut_type)
         self.layer2 = self._make_layer(block, block_inplanes[1], layers[1], spatial_dims, shortcut_type, stride=2)
@@ -329,7 +335,7 @@ def _make_layer(
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.conv1(x)
         x = self.bn1(x)
-        x = self.relu(x)
+        x = self.act(x)
         if not self.no_max_pool:
             x = self.maxpool(x)
 
@@ -396,7 +402,7 @@ def forward(self, inputs: torch.Tensor):
         """
         x = self.conv1(inputs)
         x = self.bn1(x)
-        x = self.relu(x)
+        x = self.act(x)
 
         features = []
         features.append(x)
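
A short sketch of what the swap from `nn.ReLU(inplace=True)` to `get_act_layer(name=act)` means in practice, assuming MONAI's activation factory; the LeakyReLU line is only an illustration of an alternative name/argument pair, not something this commit adds:

```python
import torch.nn as nn

from monai.networks.layers.utils import get_act_layer

# The default act value used throughout this commit rebuilds the old behaviour:
default_act = get_act_layer(name=("relu", {"inplace": True}))
assert isinstance(default_act, nn.ReLU) and default_act.inplace

# Any other activation known to MONAI's Act factory can now be passed instead,
# e.g. a LeakyReLU with a custom slope (illustrative, not part of the commit).
alt_act = get_act_layer(name=("leakyrelu", {"negative_slope": 0.1, "inplace": True}))
```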

tests/test_resnet.py

Lines changed: 18 additions & 1 deletion

@@ -107,6 +107,7 @@
         "num_classes": 3,
         "conv1_t_size": [3],
         "conv1_t_stride": 1,
+        "act": ("relu", {"inplace": False}),
     },
     (1, 2, 32),
     (1, 3),
@@ -185,13 +186,29 @@
     (1, 3),
 ]
 
+TEST_CASE_8 = [
+    {
+        "block": "bottleneck",
+        "layers": [3, 4, 6, 3],
+        "block_inplanes": [64, 128, 256, 512],
+        "spatial_dims": 1,
+        "n_input_channels": 2,
+        "num_classes": 3,
+        "conv1_t_size": [3],
+        "conv1_t_stride": 1,
+        "act": ("relu", {"inplace": False}),
+    },
+    (1, 2, 32),
+    (1, 3),
+]
+
 TEST_CASES = []
 PRETRAINED_TEST_CASES = []
 for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]:
     for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]:
         TEST_CASES.append([model, *case])
         PRETRAINED_TEST_CASES.append([model, *case])
-for case in [TEST_CASE_5, TEST_CASE_5_A, TEST_CASE_6, TEST_CASE_7]:
+for case in [TEST_CASE_5, TEST_CASE_5_A, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8]:
     TEST_CASES.append([ResNet, *case])
 
 TEST_SCRIPT_CASES = [
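
For reference, a standalone sketch mirroring what the parameterized tests exercise with TEST_CASE_8; the constructor arguments and tensor shapes are taken from the case above, while the script form itself is illustrative:

```python
import torch

from monai.networks.nets import ResNet

# Mirror of TEST_CASE_8: a 1D bottleneck ResNet with a non-inplace ReLU.
net = ResNet(
    block="bottleneck",
    layers=[3, 4, 6, 3],
    block_inplanes=[64, 128, 256, 512],
    spatial_dims=1,
    n_input_channels=2,
    num_classes=3,
    conv1_t_size=[3],
    conv1_t_stride=1,
    act=("relu", {"inplace": False}),
)
net.eval()
with torch.no_grad():
    out = net(torch.randn(1, 2, 32))  # input shape from the test case
print(out.shape)  # expected: torch.Size([1, 3])
```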

0 commit comments
