Skip to content

Commit 66f42c1

Browse files
Add freeze_layers (#6970)
Part of #6552. ### Description Add `freeze_layers`. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: KumoLiu <yunl@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 5fd23b8 commit 66f42c1

File tree

2 files changed

+105
-0
lines changed

2 files changed

+105
-0
lines changed

monai/networks/utils.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1111,3 +1111,47 @@ def replace_modules_temp(
11111111
# revert
11121112
for name, module in replaced:
11131113
_replace_modules(parent, name, module, [], strict_match=True, match_device=match_device)
1114+
1115+
1116+
def freeze_layers(model: nn.Module, freeze_vars=None, exclude_vars=None):
    """
    A utility function to help freeze specific layers.

    Args:
        model: a source PyTorch model whose parameters are (un)frozen in place.
        freeze_vars: a regular expression to match the `model` variable names;
            matched parameters get `requires_grad = False`. Unmatched parameters
            that were already frozen are re-enabled with a warning.
        exclude_vars: a regular expression to match the `model` variable names;
            every parameter EXCEPT the matched ones gets `requires_grad = False`.
            Matched parameters that were already frozen are re-enabled with a warning.

    Raises:
        ValueError: when freeze_vars and exclude_vars are both specified.

    """
    if freeze_vars is not None and exclude_vars is not None:
        raise ValueError("Incompatible values: freeze_vars and exclude_vars are both specified.")
    src_dict = get_state_dict(model)

    frozen_keys = []
    if freeze_vars is not None:
        # compile once, outside the comprehension; an empty pattern string
        # deliberately matches nothing (preserves the original truthiness check)
        pattern = re.compile(freeze_vars) if freeze_vars else None
        to_freeze = {s_key for s_key in src_dict if pattern is not None and pattern.search(s_key)}
        for name, param in model.named_parameters():
            if name in to_freeze:
                param.requires_grad = False
                frozen_keys.append(name)
            elif not param.requires_grad:
                param.requires_grad = True
                # report the parameter *name*, not the tensor itself, to keep the warning readable
                warnings.warn(
                    f"The freeze_vars does not include {name}, but requires_grad is False, change it to True."
                )
    if exclude_vars is not None:
        pattern = re.compile(exclude_vars) if exclude_vars else None
        to_exclude = {s_key for s_key in src_dict if pattern is not None and pattern.search(s_key)}
        for name, param in model.named_parameters():
            if name not in to_exclude:
                param.requires_grad = False
                frozen_keys.append(name)
            elif not param.requires_grad:
                param.requires_grad = True
                warnings.warn(f"The exclude_vars includes {name}, but requires_grad is False, change it to True.")

    logger.info(f"{len(frozen_keys)} of {len(src_dict)} variables frozen.")

tests/test_freeze_layers.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import unittest
15+
16+
import torch
17+
from parameterized import parameterized
18+
19+
from monai.networks.utils import freeze_layers
20+
from monai.utils import set_determinism
21+
from tests.test_copy_model_state import _TestModelOne, _TestModelTwo
22+
23+
# Run each test on CPU, and additionally on CUDA when a GPU is available.
__devices = ("cpu", "cuda") if torch.cuda.is_available() else ("cpu",)
TEST_CASES = list(__devices)
27+
28+
29+
class TestModuleState(unittest.TestCase):
    """Unit tests for `monai.networks.utils.freeze_layers`."""

    def tearDown(self):
        # reset the global determinism state so later tests are unaffected
        set_determinism(None)

    @parameterized.expand(TEST_CASES)
    def test_freeze_vars(self, device):
        set_determinism(0)
        net = _TestModelOne(10, 20, 3)
        net.to(device)
        freeze_layers(net, "class")

        # only the parameters matched by the "class" pattern should be frozen
        for p_name, p in net.named_parameters():
            should_be_frozen = "class_layer" in p_name
            self.assertEqual(p.requires_grad, not should_be_frozen)

    @parameterized.expand(TEST_CASES)
    def test_exclude_vars(self, device):
        set_determinism(0)
        net = _TestModelTwo(10, 20, 10, 4)
        net.to(device)
        freeze_layers(net, exclude_vars="class")

        # everything EXCEPT the parameters matched by "class" should be frozen
        for p_name, p in net.named_parameters():
            should_stay_trainable = "class_layer" in p_name
            self.assertEqual(p.requires_grad, should_stay_trainable)
60+
# Allow running this test module directly: `python test_freeze_layers.py`.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)