1414"""
1515
1616import warnings
17- from typing import Callable , Iterable , Optional , Sequence , Union
17+ from typing import Callable , Iterable , Optional , Sequence , Tuple , Union
1818
1919import numpy as np
2020import torch
21+ import torch .nn .functional as F
2122
2223from monai .config .type_definitions import NdarrayOrTensor
2324from monai .data .meta_obj import get_track_meta
2425from monai .data .meta_tensor import MetaTensor
2526from monai .networks import one_hot
26- from monai .networks .layers import GaussianFilter , apply_filter
27+ from monai .networks .layers import GaussianFilter , apply_filter , separable_filtering
2728from monai .transforms .inverse import InvertibleTransform
2829from monai .transforms .transform import Transform
2930from monai .transforms .utils import (
@@ -821,13 +822,18 @@ def __call__(self, data):
821822
822823
823824class SobelGradients (Transform ):
824- """Calculate Sobel horizontal and vertical gradients
825+ """Calculate Sobel gradients of a grayscale image with the shape of (CxH[xWxDx...]).
825826
826827 Args:
827828 kernel_size: the size of the Sobel kernel. Defaults to 3.
828- padding: the padding for the convolution to apply the kernel. Defaults to `"same"`.
829+ spatial_axes: the axes that define the direction of the gradient to be calculated. It calculate the gradient
830+ along each of the provide axis. By default it calculate the gradient for all spatial axes.
831+ normalize_kernels: if normalize the Sobel kernel to provide proper gradients. Defaults to True.
832+ normalize_gradients: if normalize the output gradient to 0 and 1. Defaults to False.
833+ padding_mode: the padding mode of the image when convolving with Sobel kernels. Defaults to `"reflect"`.
834+ Acceptable values are ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
835+ See ``torch.nn.Conv1d()`` for more information.
829836 dtype: kernel data type (torch.dtype). Defaults to `torch.float32`.
830- device: the device to create the kernel on. Defaults to `"cpu"`.
831837
832838 """
833839
@@ -836,36 +842,90 @@ class SobelGradients(Transform):
836842 def __init__ (
837843 self ,
838844 kernel_size : int = 3 ,
839- padding : Union [int , str ] = "same" ,
845+ spatial_axes : Optional [Union [Sequence [int ], int ]] = None ,
846+ normalize_kernels : bool = True ,
847+ normalize_gradients : bool = False ,
848+ padding_mode : str = "reflect" ,
840849 dtype : torch .dtype = torch .float32 ,
841- device : Union [torch .device , int , str ] = "cpu" ,
842850 ) -> None :
843851 super ().__init__ ()
844- self .kernel : torch .Tensor = self ._get_kernel (kernel_size , dtype , device )
845- self .padding = padding
846-
847- def _get_kernel (self , size , dtype , device ) -> torch .Tensor :
852+ self .padding = padding_mode
853+ self .spatial_axes = spatial_axes
854+ self .normalize_kernels = normalize_kernels
855+ self .normalize_gradients = normalize_gradients
856+ self .kernel_diff , self .kernel_smooth = self ._get_kernel (kernel_size , dtype )
857+
858+ def _get_kernel (self , size , dtype ) -> Tuple [torch .Tensor , torch .Tensor ]:
859+ if size < 3 :
860+ raise ValueError (f"Sobel kernel size should be at least three. { size } was given." )
848861 if size % 2 == 0 :
849862 raise ValueError (f"Sobel kernel size should be an odd number. { size } was given." )
850- if not dtype .is_floating_point :
851- raise ValueError (f"`dtype` for Sobel kernel should be floating point. { dtype } was given." )
852-
853- numerator : torch .Tensor = torch .arange (
854- - size // 2 + 1 , size // 2 + 1 , dtype = dtype , device = device , requires_grad = False
855- ).expand (size , size )
856- denominator = numerator * numerator
857- denominator = denominator + denominator .T
858- denominator [:, size // 2 ] = 1.0 # to avoid division by zero
859- kernel = numerator / denominator
860- return kernel
863+
864+ kernel_diff = torch .tensor ([[[- 1 , 0 , 1 ]]], dtype = dtype )
865+ kernel_smooth = torch .tensor ([[[1 , 2 , 1 ]]], dtype = dtype )
866+ kernel_expansion = torch .tensor ([[[1 , 2 , 1 ]]], dtype = dtype )
867+
868+ if self .normalize_kernels :
869+ if not dtype .is_floating_point :
870+ raise ValueError (
871+ f"`dtype` for Sobel kernel should be floating point when `normalize_kernel==True`. { dtype } was given."
872+ )
873+ kernel_diff /= 2.0
874+ kernel_smooth /= 4.0
875+ kernel_expansion /= 4.0
876+
877+ # Expand the kernel to larger size than 3
878+ expand = (size - 3 ) // 2
879+ for _ in range (expand ):
880+ kernel_diff = F .conv1d (kernel_diff , kernel_expansion , padding = 2 )
881+ kernel_smooth = F .conv1d (kernel_smooth , kernel_expansion , padding = 2 )
882+
883+ return kernel_diff .squeeze (), kernel_smooth .squeeze ()
861884
862885 def __call__ (self , image : NdarrayOrTensor ) -> torch .Tensor :
863886 image_tensor = convert_to_tensor (image , track_meta = get_track_meta ())
864- kernel_v = self .kernel .to (image_tensor .device )
865- kernel_h = kernel_v .T
866- image_tensor = image_tensor .unsqueeze (0 ) # adds a batch dim
867- grad_v = apply_filter (image_tensor , kernel_v , padding = self .padding )
868- grad_h = apply_filter (image_tensor , kernel_h , padding = self .padding )
869- grad = torch .cat ([grad_h , grad_v ], dim = 1 )
870- grad , * _ = convert_to_dst_type (grad .squeeze (0 ), image_tensor )
871- return grad
887+
888+ # Check/set spatial axes
889+ n_spatial_dims = image_tensor .ndim - 1 # excluding the channel dimension
890+ valid_spatial_axes = list (range (n_spatial_dims )) + list (range (- n_spatial_dims , 0 ))
891+
892+ # Check gradient axes to be valid
893+ if self .spatial_axes is None :
894+ spatial_axes = list (range (n_spatial_dims ))
895+ else :
896+ invalid_axis = set (ensure_tuple (self .spatial_axes )) - set (valid_spatial_axes )
897+ if invalid_axis :
898+ raise ValueError (
899+ f"The provide axes to calculate gradient is not valid: { invalid_axis } . "
900+ f"The image has { n_spatial_dims } spatial dimensions so it should be: { valid_spatial_axes } ."
901+ )
902+ spatial_axes = [ax % n_spatial_dims if ax < 0 else ax for ax in ensure_tuple (self .spatial_axes )]
903+
904+ # Add batch dimension for separable_filtering
905+ image_tensor = image_tensor .unsqueeze (0 )
906+
907+ # Get the Sobel kernels
908+ kernel_diff = self .kernel_diff .to (image_tensor .device )
909+ kernel_smooth = self .kernel_smooth .to (image_tensor .device )
910+
911+ # Calculate gradient
912+ grad_list = []
913+ for ax in spatial_axes :
914+ kernels = [kernel_smooth ] * n_spatial_dims
915+ kernels [ax - 1 ] = kernel_diff
916+ grad = separable_filtering (image_tensor , kernels , mode = self .padding )
917+ if self .normalize_gradients :
918+ grad_min = grad .min ()
919+ if grad_min != grad .max ():
920+ grad -= grad_min
921+ grad_max = grad .max ()
922+ if grad_max > 0 :
923+ grad /= grad_max
924+ grad_list .append (grad )
925+
926+ grads = torch .cat (grad_list , dim = 1 )
927+
928+ # Remove batch dimension and convert the gradient type to be the same as input image
929+ grads = convert_to_dst_type (grads .squeeze (0 ), image_tensor )[0 ]
930+
931+ return grads
0 commit comments