
Commit e44f3b9

Generalize SobelGradients to 3D and Any Axis (#5189)
Fixes #5188

### Description

This PR reimplements `SobelGradients` and `SobelGradientsd` using separable kernels, generalizes them to images with any number of spatial dimensions (2D, 3D, etc.), and adds an option to calculate the gradient along any given axis.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.

Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
1 parent fe60832 commit e44f3b9
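Below is a minimal usage sketch of the generalized transform (not part of the commit). It assumes a MONAI build that contains this change; the tensor sizes are illustrative only, and for a single-channel image each requested spatial axis contributes one output channel.

```python
import torch

from monai.transforms.post.array import SobelGradients

# A 3D single-channel image: (C, H, W, D)
image = torch.rand(1, 32, 32, 32)

# Default: gradients along every spatial axis, stacked on the channel dimension
grads_all = SobelGradients(kernel_size=3)(image)  # expected shape: (3, 32, 32, 32)

# A single axis, a larger separable kernel, and gradients normalized to [0, 1]
grad_single = SobelGradients(kernel_size=5, spatial_axes=0, normalize_gradients=True)(image)

print(grads_all.shape, grad_single.shape)
```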

File tree

5 files changed: +400 -115 lines changed


monai/transforms/post/array.py

Lines changed: 90 additions & 30 deletions
@@ -14,16 +14,17 @@
 """
 
 import warnings
-from typing import Callable, Iterable, Optional, Sequence, Union
+from typing import Callable, Iterable, Optional, Sequence, Tuple, Union
 
 import numpy as np
 import torch
+import torch.nn.functional as F
 
 from monai.config.type_definitions import NdarrayOrTensor
 from monai.data.meta_obj import get_track_meta
 from monai.data.meta_tensor import MetaTensor
 from monai.networks import one_hot
-from monai.networks.layers import GaussianFilter, apply_filter
+from monai.networks.layers import GaussianFilter, apply_filter, separable_filtering
 from monai.transforms.inverse import InvertibleTransform
 from monai.transforms.transform import Transform
 from monai.transforms.utils import (
@@ -821,13 +822,18 @@ def __call__(self, data):
 
 
 class SobelGradients(Transform):
-    """Calculate Sobel horizontal and vertical gradients
+    """Calculate Sobel gradients of a grayscale image with the shape of (CxH[xWxDx...]).
 
     Args:
         kernel_size: the size of the Sobel kernel. Defaults to 3.
-        padding: the padding for the convolution to apply the kernel. Defaults to `"same"`.
+        spatial_axes: the axes that define the direction of the gradient to be calculated. It calculates the
+            gradient along each of the provided axes. By default it calculates the gradient for all spatial axes.
+        normalize_kernels: whether to normalize the Sobel kernels to provide proper gradients. Defaults to True.
+        normalize_gradients: whether to normalize the output gradients to the range of 0 and 1. Defaults to False.
+        padding_mode: the padding mode of the image when convolving with Sobel kernels. Defaults to `"reflect"`.
+            Acceptable values are ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+            See ``torch.nn.Conv1d()`` for more information.
         dtype: kernel data type (torch.dtype). Defaults to `torch.float32`.
-        device: the device to create the kernel on. Defaults to `"cpu"`.
 
     """
 
@@ -836,36 +842,90 @@ class SobelGradients(Transform):
     def __init__(
         self,
         kernel_size: int = 3,
-        padding: Union[int, str] = "same",
+        spatial_axes: Optional[Union[Sequence[int], int]] = None,
+        normalize_kernels: bool = True,
+        normalize_gradients: bool = False,
+        padding_mode: str = "reflect",
         dtype: torch.dtype = torch.float32,
-        device: Union[torch.device, int, str] = "cpu",
     ) -> None:
         super().__init__()
-        self.kernel: torch.Tensor = self._get_kernel(kernel_size, dtype, device)
-        self.padding = padding
-
-    def _get_kernel(self, size, dtype, device) -> torch.Tensor:
+        self.padding = padding_mode
+        self.spatial_axes = spatial_axes
+        self.normalize_kernels = normalize_kernels
+        self.normalize_gradients = normalize_gradients
+        self.kernel_diff, self.kernel_smooth = self._get_kernel(kernel_size, dtype)
+
+    def _get_kernel(self, size, dtype) -> Tuple[torch.Tensor, torch.Tensor]:
+        if size < 3:
+            raise ValueError(f"Sobel kernel size should be at least three. {size} was given.")
         if size % 2 == 0:
             raise ValueError(f"Sobel kernel size should be an odd number. {size} was given.")
-        if not dtype.is_floating_point:
-            raise ValueError(f"`dtype` for Sobel kernel should be floating point. {dtype} was given.")
-
-        numerator: torch.Tensor = torch.arange(
-            -size // 2 + 1, size // 2 + 1, dtype=dtype, device=device, requires_grad=False
-        ).expand(size, size)
-        denominator = numerator * numerator
-        denominator = denominator + denominator.T
-        denominator[:, size // 2] = 1.0  # to avoid division by zero
-        kernel = numerator / denominator
-        return kernel
+
+        kernel_diff = torch.tensor([[[-1, 0, 1]]], dtype=dtype)
+        kernel_smooth = torch.tensor([[[1, 2, 1]]], dtype=dtype)
+        kernel_expansion = torch.tensor([[[1, 2, 1]]], dtype=dtype)
+
+        if self.normalize_kernels:
+            if not dtype.is_floating_point:
+                raise ValueError(
+                    f"`dtype` for Sobel kernel should be floating point when `normalize_kernel==True`. {dtype} was given."
+                )
+            kernel_diff /= 2.0
+            kernel_smooth /= 4.0
+            kernel_expansion /= 4.0
+
+        # Expand the kernel to larger size than 3
+        expand = (size - 3) // 2
+        for _ in range(expand):
+            kernel_diff = F.conv1d(kernel_diff, kernel_expansion, padding=2)
+            kernel_smooth = F.conv1d(kernel_smooth, kernel_expansion, padding=2)
+
+        return kernel_diff.squeeze(), kernel_smooth.squeeze()
 
     def __call__(self, image: NdarrayOrTensor) -> torch.Tensor:
         image_tensor = convert_to_tensor(image, track_meta=get_track_meta())
-        kernel_v = self.kernel.to(image_tensor.device)
-        kernel_h = kernel_v.T
-        image_tensor = image_tensor.unsqueeze(0)  # adds a batch dim
-        grad_v = apply_filter(image_tensor, kernel_v, padding=self.padding)
-        grad_h = apply_filter(image_tensor, kernel_h, padding=self.padding)
-        grad = torch.cat([grad_h, grad_v], dim=1)
-        grad, *_ = convert_to_dst_type(grad.squeeze(0), image_tensor)
-        return grad
+
+        # Check/set spatial axes
+        n_spatial_dims = image_tensor.ndim - 1  # excluding the channel dimension
+        valid_spatial_axes = list(range(n_spatial_dims)) + list(range(-n_spatial_dims, 0))
+
+        # Check gradient axes to be valid
+        if self.spatial_axes is None:
+            spatial_axes = list(range(n_spatial_dims))
+        else:
+            invalid_axis = set(ensure_tuple(self.spatial_axes)) - set(valid_spatial_axes)
+            if invalid_axis:
+                raise ValueError(
+                    f"The provide axes to calculate gradient is not valid: {invalid_axis}. "
+                    f"The image has {n_spatial_dims} spatial dimensions so it should be: {valid_spatial_axes}."
+                )
+            spatial_axes = [ax % n_spatial_dims if ax < 0 else ax for ax in ensure_tuple(self.spatial_axes)]
+
+        # Add batch dimension for separable_filtering
+        image_tensor = image_tensor.unsqueeze(0)
+
+        # Get the Sobel kernels
+        kernel_diff = self.kernel_diff.to(image_tensor.device)
+        kernel_smooth = self.kernel_smooth.to(image_tensor.device)
+
+        # Calculate gradient
+        grad_list = []
+        for ax in spatial_axes:
+            kernels = [kernel_smooth] * n_spatial_dims
+            kernels[ax - 1] = kernel_diff
+            grad = separable_filtering(image_tensor, kernels, mode=self.padding)
+            if self.normalize_gradients:
+                grad_min = grad.min()
+                if grad_min != grad.max():
+                    grad -= grad_min
+                grad_max = grad.max()
+                if grad_max > 0:
+                    grad /= grad_max
+            grad_list.append(grad)
+
+        grads = torch.cat(grad_list, dim=1)
+
+        # Remove batch dimension and convert the gradient type to be the same as input image
+        grads = convert_to_dst_type(grads.squeeze(0), image_tensor)[0]
+
+        return grads
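As an aside (not part of the commit), the separable construction in `_get_kernel` above can be sanity-checked in a few lines: the 1-D derivative and smoothing kernels are grown by repeated convolution with `[1, 2, 1]`, and the outer product of the smoothing and derivative kernels reproduces the familiar dense Sobel kernel. A rough sketch with unnormalized kernels:

```python
import torch
import torch.nn.functional as F


def sobel_1d_kernels(size: int = 3):
    # Mirrors the kernel-building logic of `_get_kernel`, without normalization
    kernel_diff = torch.tensor([[[-1.0, 0.0, 1.0]]])
    kernel_smooth = torch.tensor([[[1.0, 2.0, 1.0]]])
    kernel_expansion = torch.tensor([[[1.0, 2.0, 1.0]]])
    for _ in range((size - 3) // 2):
        # each pass with padding=2 grows the kernels by two taps
        kernel_diff = F.conv1d(kernel_diff, kernel_expansion, padding=2)
        kernel_smooth = F.conv1d(kernel_smooth, kernel_expansion, padding=2)
    return kernel_diff.squeeze(), kernel_smooth.squeeze()


diff_k, smooth_k = sobel_1d_kernels(3)
print(torch.outer(smooth_k, diff_k))  # classic 3x3 Sobel: [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]

diff_k5, _ = sobel_1d_kernels(5)
print(diff_k5)  # expanded derivative kernel: [-1, -2, 0, 2, 1]
```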

monai/transforms/post/dictionary.py

Lines changed: 22 additions & 6 deletions
@@ -794,14 +794,19 @@ def get_saver(self):
 
 
 class SobelGradientsd(MapTransform):
-    """Calculate Sobel horizontal and vertical gradients.
+    """Calculate Sobel horizontal and vertical gradients of a grayscale image.
 
     Args:
         keys: keys of the corresponding items to model output.
         kernel_size: the size of the Sobel kernel. Defaults to 3.
-        padding: the padding for the convolution to apply the kernel. Defaults to `"same"`.
+        spatial_axes: the axes that define the direction of the gradient to be calculated. It calculates the
+            gradient along each of the provided axes. By default it calculates the gradient for all spatial axes.
+        normalize_kernels: whether to normalize the Sobel kernels to provide proper gradients. Defaults to True.
+        normalize_gradients: whether to normalize the output gradients to the range of 0 and 1. Defaults to False.
+        padding_mode: the padding mode of the image when convolving with Sobel kernels. Defaults to `"reflect"`.
+            Acceptable values are ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+            See ``torch.nn.Conv1d()`` for more information.
         dtype: kernel data type (torch.dtype). Defaults to `torch.float32`.
-        device: the device to create the kernel on. Defaults to `"cpu"`.
         new_key_prefix: this prefix be prepended to the key to create a new key for the output and keep the value of
             key intact. By default not prefix is set and the corresponding array to the key will be replaced.
         allow_missing_keys: don't raise exception if key is missing.
@@ -814,15 +819,26 @@ def __init__(
         self,
         keys: KeysCollection,
         kernel_size: int = 3,
-        padding: Union[int, str] = "same",
+        spatial_axes: Optional[Union[Sequence[int], int]] = None,
+        normalize_kernels: bool = True,
+        normalize_gradients: bool = False,
+        padding_mode: str = "reflect",
         dtype: torch.dtype = torch.float32,
-        device: Union[torch.device, int, str] = "cpu",
         new_key_prefix: Optional[str] = None,
         allow_missing_keys: bool = False,
     ) -> None:
         super().__init__(keys, allow_missing_keys)
-        self.transform = SobelGradients(kernel_size=kernel_size, padding=padding, dtype=dtype, device=device)
+        self.transform = SobelGradients(
+            kernel_size=kernel_size,
+            spatial_axes=spatial_axes,
+            normalize_kernels=normalize_kernels,
+            normalize_gradients=normalize_gradients,
+            padding_mode=padding_mode,
+            dtype=dtype,
+        )
         self.new_key_prefix = new_key_prefix
+        self.kernel_diff = self.transform.kernel_diff
+        self.kernel_smooth = self.transform.kernel_smooth
 
     def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         d = dict(data)
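For the dictionary version, a minimal sketch of the new arguments (not part of the commit; the `"mask"` key and shapes are hypothetical, and the output key name follows the `new_key_prefix` behaviour described in the docstring above):

```python
import torch

from monai.transforms.post.dictionary import SobelGradientsd

data = {"mask": torch.rand(1, 64, 64)}  # 2D single-channel image: (C, H, W)
transform = SobelGradientsd(keys="mask", kernel_size=3, normalize_gradients=True, new_key_prefix="sobel_")
out = transform(data)

# "mask" stays intact; the gradients (one channel per spatial axis) land under the prefixed key
print(out["mask"].shape, out["sobel_mask"].shape)
```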

tests/test_hovernet_loss.py

Lines changed: 3 additions & 3 deletions
@@ -141,17 +141,17 @@ def test_shape_generator(num_classes=1, num_objects=3, batch_size=1, height=5, w
 
 TEST_CASE_3 = [  # batch size of 2, 3 classes with minor rotation of nuclear prediction
     {"prediction": inputs_test[3].inputs, "target": inputs_test[3].targets},
-    6.5777,
+    3.6169,
 ]
 
 TEST_CASE_4 = [  # batch size of 2, 3 classes with medium rotation of nuclear prediction
     {"prediction": inputs_test[4].inputs, "target": inputs_test[4].targets},
-    8.5143,
+    4.5079,
 ]
 
 TEST_CASE_5 = [  # batch size of 2, 3 classes with medium rotation of nuclear prediction
     {"prediction": inputs_test[5].inputs, "target": inputs_test[5].targets},
-    10.1705,
+    5.4663,
 ]
 
 CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]
