Merged
25 commits
de2b6bd
feat: add optional gradient checkpointing to unet
Sep 3, 2025
66edcb5
fix: small ruff issue
Sep 3, 2025
e66e357
Update monai/networks/nets/unet.py
ferreirafabio80 Sep 4, 2025
feefcaa
docs: update docstrings
Sep 4, 2025
e112457
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 4, 2025
f673ca1
fix: avoid BatchNorm subblocks
Sep 4, 2025
69540ff
fix: revert batch norm changes
Sep 4, 2025
42ec757
refactor: creates a subclass of UNet and overrides the get connection…
Oct 1, 2025
a2e8474
chore: remove use checkpointing from doc string
Oct 1, 2025
4c4782e
fix: linting issues
Oct 2, 2025
515c659
feat: add activation checkpointing to down and up paths to be more ef…
Oct 8, 2025
da5a3a4
refactor: move activation checkpointing wrapper to blocks
Nov 4, 2025
43dec88
chore: add docstrings to checkpointed unet
Nov 4, 2025
84c0f48
test: add checkpoint unet test
Nov 7, 2025
5805515
fix: change test name
Nov 7, 2025
1aa8e3c
fix: simplify test and make sure that checkpoint unet runs well in tr…
Nov 7, 2025
447d9f2
fix: set seed
Nov 7, 2025
b20a19e
fix: fix testing bugs
Nov 7, 2025
41f000f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 7, 2025
a068c0e
chore: add test docstrings
Nov 10, 2025
26668cd
DCO Remediation Commit for Fabio Ferreira <f.ferreira@qureight.com>
Nov 10, 2025
814fa80
fix: remove test script save
Nov 13, 2025
c45ee48
fix: tighten tolerance for numerical equivalence
Nov 13, 2025
4349d3f
chore: update doc strings
Nov 14, 2025
885993b
Merge branch 'dev' into feat/add_activation_checkpointing_to_unet
KumoLiu Nov 14, 2025
41 changes: 41 additions & 0 deletions monai/networks/blocks/activation_checkpointing.py
@@ -0,0 +1,41 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import cast

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ActivationCheckpointWrapper(nn.Module):
"""Wrapper applying activation checkpointing to a module during training.

Args:
module: The module to wrap with activation checkpointing.
"""

def __init__(self, module: nn.Module) -> None:
super().__init__()
self.module = module

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass with optional activation checkpointing.

Args:
x: Input tensor.

Returns:
Output tensor from the wrapped module.
"""
return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
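For readers unfamiliar with the block, a minimal usage sketch (not part of the diff; the wrapped module and tensor shapes below are illustrative):

```python
import torch
import torch.nn as nn

from monai.networks.blocks.activation_checkpointing import ActivationCheckpointWrapper

# Any nn.Module can be wrapped; a small conv block is used here for illustration.
block = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 8, 3, padding=1))
wrapped = ActivationCheckpointWrapper(block)

x = torch.randn(2, 1, 32, 32, requires_grad=True)
y = wrapped(x)       # forward runs normally; intermediate activations are not stored
y.mean().backward()  # activations inside `block` are recomputed here, then gradients flow
```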
28 changes: 27 additions & 1 deletion monai/networks/nets/unet.py
@@ -17,11 +17,12 @@
import torch
import torch.nn as nn

from monai.networks.blocks.activation_checkpointing import ActivationCheckpointWrapper
from monai.networks.blocks.convolutions import Convolution, ResidualUnit
from monai.networks.layers.factories import Act, Norm
from monai.networks.layers.simplelayers import SkipConnection

__all__ = ["UNet", "Unet"]
__all__ = ["UNet", "Unet", "CheckpointUNet"]


class UNet(nn.Module):
@@ -298,4 +299,29 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


class CheckpointUNet(UNet):
"""UNet variant that wraps internal connection blocks with activation checkpointing.

See `UNet` for constructor arguments. During training with gradients enabled,
intermediate activations inside encoder-decoder connections are recomputed in
the backward pass to reduce peak memory usage at the cost of extra compute.
"""

def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
"""Returns connection block with activation checkpointing applied to all components.

Args:
down_path: encoding half of the layer (will be wrapped with checkpointing).
up_path: decoding half of the layer (will be wrapped with checkpointing).
subblock: block defining the next layer (will be wrapped with checkpointing).

Returns:
Connection block with all components wrapped for activation checkpointing.
"""
subblock = ActivationCheckpointWrapper(subblock)
down_path = ActivationCheckpointWrapper(down_path)
up_path = ActivationCheckpointWrapper(up_path)
return super()._get_connection_block(down_path, up_path, subblock)


Unet = UNet
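As a quick sketch of intended use (illustrative, not part of the diff): `CheckpointUNet` is a drop-in replacement for `UNet`, and only peak training memory should differ.

```python
import torch

from monai.networks.nets.unet import CheckpointUNet

net = CheckpointUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64),
    strides=(2, 2),
    num_res_units=2,
).train()

x = torch.randn(4, 1, 64, 64)
loss = net(x).mean()
loss.backward()  # connection-block activations are recomputed here instead of stored
```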
186 changes: 186 additions & 0 deletions tests/networks/nets/test_checkpointunet.py
@@ -0,0 +1,186 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.layers import Act, Norm
from monai.networks.nets.unet import CheckpointUNet, UNet

device = "cuda" if torch.cuda.is_available() else "cpu"

TEST_CASE_0 = [ # single channel 2D, batch 16, no residual
{
"spatial_dims": 2,
"in_channels": 1,
"out_channels": 3,
"channels": (16, 32, 64),
"strides": (2, 2),
"num_res_units": 0,
},
(16, 1, 32, 32),
(16, 3, 32, 32),
]

TEST_CASE_1 = [ # single channel 2D, batch 16
{
"spatial_dims": 2,
"in_channels": 1,
"out_channels": 3,
"channels": (16, 32, 64),
"strides": (2, 2),
"num_res_units": 1,
},
(16, 1, 32, 32),
(16, 3, 32, 32),
]

TEST_CASE_2 = [ # single channel 3D, batch 16
{
"spatial_dims": 3,
"in_channels": 1,
"out_channels": 3,
"channels": (16, 32, 64),
"strides": (2, 2),
"num_res_units": 1,
},
(16, 1, 32, 24, 48),
(16, 3, 32, 24, 48),
]

TEST_CASE_3 = [ # 4-channel 3D, batch 16
{
"spatial_dims": 3,
"in_channels": 4,
"out_channels": 3,
"channels": (16, 32, 64),
"strides": (2, 2),
"num_res_units": 1,
},
(16, 4, 32, 64, 48),
(16, 3, 32, 64, 48),
]

TEST_CASE_4 = [ # 4-channel 3D, batch 16, batch normalization
{
"spatial_dims": 3,
"in_channels": 4,
"out_channels": 3,
"channels": (16, 32, 64),
"strides": (2, 2),
"num_res_units": 1,
"norm": Norm.BATCH,
},
(16, 4, 32, 64, 48),
(16, 3, 32, 64, 48),
]

TEST_CASE_5 = [ # 4-channel 3D, batch 16, LeakyReLU activation
{
"spatial_dims": 3,
"in_channels": 4,
"out_channels": 3,
"channels": (16, 32, 64),
"strides": (2, 2),
"num_res_units": 1,
"act": (Act.LEAKYRELU, {"negative_slope": 0.2}),
"adn_ordering": "NA",
},
(16, 4, 32, 64, 48),
(16, 3, 32, 64, 48),
]

TEST_CASE_6 = [ # 4-channel 3D, batch 16, LeakyReLU activation explicit
{
"spatial_dims": 3,
"in_channels": 4,
"out_channels": 3,
"channels": (16, 32, 64),
"strides": (2, 2),
"num_res_units": 1,
"act": (torch.nn.LeakyReLU, {"negative_slope": 0.2}),
},
(16, 4, 32, 64, 48),
(16, 3, 32, 64, 48),
]

CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]


class TestCheckpointUNet(unittest.TestCase):
@parameterized.expand(CASES)
def test_shape(self, input_param, input_shape, expected_shape):
"""Validate CheckpointUNet output shapes across configurations.

Args:
input_param: Dictionary of UNet constructor arguments.
input_shape: Tuple specifying input tensor dimensions.
expected_shape: Tuple specifying expected output tensor dimensions.
"""
net = CheckpointUNet(**input_param).to(device)
with eval_mode(net):
result = net.forward(torch.randn(input_shape).to(device))
self.assertEqual(result.shape, expected_shape)

def test_checkpointing_equivalence_eval(self):
"""Confirm eval parity when checkpointing is inactive."""
params = dict(
spatial_dims=2, in_channels=1, out_channels=2, channels=(8, 16, 32), strides=(2, 2), num_res_units=1
)

x = torch.randn(2, 1, 32, 32, device=device)

torch.manual_seed(42)
net_plain = UNet(**params).to(device)

torch.manual_seed(42)
net_ckpt = CheckpointUNet(**params).to(device)

        # With both nets in eval mode (gradients disabled), checkpointing is a
        # no-op, so the two models should produce identical outputs.
with eval_mode(net_ckpt), eval_mode(net_plain):
y_ckpt = net_ckpt(x)
y_plain = net_plain(x)

# Check shape equality
self.assertEqual(y_ckpt.shape, y_plain.shape)

# Check numerical equivalence
self.assertTrue(
torch.allclose(y_ckpt, y_plain, atol=1e-6, rtol=1e-5),
f"Eval-mode outputs differ: max abs diff={torch.max(torch.abs(y_ckpt - y_plain)).item():.2e}",
)

def test_checkpointing_activates_training(self):
"""Verify checkpointing recomputes activations during training."""
params = dict(
spatial_dims=2, in_channels=1, out_channels=1, channels=(8, 16, 32), strides=(2, 2), num_res_units=1
)

net = CheckpointUNet(**params).to(device)
net.train()

x = torch.randn(2, 1, 32, 32, device=device, requires_grad=True)
y = net(x)
loss = y.mean()
loss.backward()

# gradient flow check
grad_norm = sum(p.grad.abs().sum() for p in net.parameters() if p.grad is not None)
self.assertGreater(grad_norm.item(), 0.0)


if __name__ == "__main__":
unittest.main()
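The tests assert shape and numerical parity but do not measure the memory savings themselves. A rough way to observe the effect on a CUDA machine (illustrative sketch; numbers vary with model and input sizes):

```python
import torch

from monai.networks.nets.unet import CheckpointUNet, UNet


def peak_train_memory_bytes(net_cls) -> int:
    """Run one training step and return the peak CUDA memory allocated."""
    torch.manual_seed(0)
    net = net_cls(
        spatial_dims=2, in_channels=1, out_channels=2, channels=(32, 64, 128), strides=(2, 2), num_res_units=2
    ).cuda().train()
    x = torch.randn(8, 1, 128, 128, device="cuda")
    torch.cuda.reset_peak_memory_stats()
    net(x).mean().backward()
    return torch.cuda.max_memory_allocated()


print("UNet:          ", peak_train_memory_bytes(UNet))
print("CheckpointUNet:", peak_train_memory_bytes(CheckpointUNet))  # expected to be lower
```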