4848 "RandBiasField" ,
4949 "ScaleIntensity" ,
5050 "RandScaleIntensity" ,
51+ "ScaleIntensityFixedMean" ,
52+ "RandScaleIntensityFixedMean" ,
5153 "NormalizeIntensity" ,
5254 "ThresholdIntensity" ,
5355 "ScaleIntensityRange" ,
@@ -466,6 +468,161 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
466468 return ret
467469
468470
471+ class ScaleIntensityFixedMean (Transform ):
472+ """
473+ Scale the intensity of input image by ``v = v * (1 + factor)``, then shift the output so that the output image has the
474+ same mean as the input.
475+ """
476+
477+ backend = [TransformBackends .TORCH , TransformBackends .NUMPY ]
478+
479+ def __init__ (
480+ self ,
481+ factor : float = 0 ,
482+ preserve_range : bool = False ,
483+ fixed_mean : bool = True ,
484+ channel_wise : bool = False ,
485+ dtype : DtypeLike = np .float32 ,
486+ ) -> None :
487+ """
488+ Args:
489+ factor: factor scale by ``v = v * (1 + factor)``.
490+ preserve_range: clips the output array/tensor to the range of the input array/tensor
491+ fixed_mean: subtract the mean intensity before scaling with `factor`, then add the same value after scaling
492+ to ensure that the output has the same mean as the input.
493+ channel_wise: if True, scale on each channel separately. `preserve_range` and `fixed_mean` are also applied
494+ on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the
495+ channel of the image if True.
496+ dtype: output data type, if None, same as input image. defaults to float32.
497+ """
498+ self .factor = factor
499+ self .preserve_range = preserve_range
500+ self .fixed_mean = fixed_mean
501+ self .channel_wise = channel_wise
502+ self .dtype = dtype
503+
504+ def __call__ (self , img : NdarrayOrTensor , factor = None ) -> NdarrayOrTensor :
505+ """
506+ Apply the transform to `img`.
507+ Args:
508+ img: the input tensor/array
509+ factor: factor scale by ``v = v * (1 + factor)``
510+
511+ """
512+
513+ factor = factor if factor is not None else self .factor
514+
515+ img = convert_to_tensor (img , track_meta = get_track_meta ())
516+ img_t = convert_to_tensor (img , track_meta = False )
517+ ret : NdarrayOrTensor
518+ if self .channel_wise :
519+ out = []
520+ for d in img_t :
521+ if self .preserve_range :
522+ clip_min = d .min ()
523+ clip_max = d .max ()
524+
525+ if self .fixed_mean :
526+ mn = d .mean ()
527+ d = d - mn
528+
529+ out_channel = d * (1 + factor )
530+
531+ if self .fixed_mean :
532+ out_channel = out_channel + mn
533+
534+ if self .preserve_range :
535+ out_channel = clip (out_channel , clip_min , clip_max )
536+
537+ out .append (out_channel )
538+ ret = torch .stack (out ) # type: ignore
539+ else :
540+ if self .preserve_range :
541+ clip_min = img_t .min ()
542+ clip_max = img_t .max ()
543+
544+ if self .fixed_mean :
545+ mn = img_t .mean ()
546+ img_t = img_t - mn
547+
548+ ret = img_t * (1 + factor )
549+
550+ if self .fixed_mean :
551+ ret = ret + mn
552+
553+ if self .preserve_range :
554+ ret = clip (ret , clip_min , clip_max )
555+
556+ ret = convert_to_dst_type (ret , dst = img , dtype = self .dtype or img_t .dtype )[0 ]
557+ return ret
558+
559+
560+ class RandScaleIntensityFixedMean (RandomizableTransform ):
561+ """
562+ Randomly scale the intensity of input image by ``v = v * (1 + factor)`` where the `factor`
563+ is randomly picked. Subtract the mean intensity before scaling with `factor`, then add the same value after scaling
564+ to ensure that the output has the same mean as the input.
565+ """
566+
567+ backend = ScaleIntensityFixedMean .backend
568+
569+ def __init__ (
570+ self ,
571+ prob : float = 0.1 ,
572+ factors : Sequence [float ] | float = 0 ,
573+ fixed_mean : bool = True ,
574+ preserve_range : bool = False ,
575+ dtype : DtypeLike = np .float32 ,
576+ ) -> None :
577+ """
578+ Args:
579+ factors: factor range to randomly scale by ``v = v * (1 + factor)``.
580+ if single number, factor value is picked from (-factors, factors).
581+ preserve_range: clips the output array/tensor to the range of the input array/tensor
582+ fixed_mean: subtract the mean intensity before scaling with `factor`, then add the same value after scaling
583+ to ensure that the output has the same mean as the input.
584+ channel_wise: if True, scale on each channel separately. `preserve_range` and `fixed_mean` are also applied
585+ on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the
586+ channel of the image if True.
587+ dtype: output data type, if None, same as input image. defaults to float32.
588+
589+ """
590+ RandomizableTransform .__init__ (self , prob )
591+ if isinstance (factors , (int , float )):
592+ self .factors = (min (- factors , factors ), max (- factors , factors ))
593+ elif len (factors ) != 2 :
594+ raise ValueError ("factors should be a number or pair of numbers." )
595+ else :
596+ self .factors = (min (factors ), max (factors ))
597+ self .factor = self .factors [0 ]
598+ self .fixed_mean = fixed_mean
599+ self .preserve_range = preserve_range
600+ self .dtype = dtype
601+
602+ self .scaler = ScaleIntensityFixedMean (
603+ factor = self .factor , fixed_mean = self .fixed_mean , preserve_range = self .preserve_range , dtype = self .dtype
604+ )
605+
606+ def randomize (self , data : Any | None = None ) -> None :
607+ super ().randomize (None )
608+ if not self ._do_transform :
609+ return None
610+ self .factor = self .R .uniform (low = self .factors [0 ], high = self .factors [1 ])
611+
612+ def __call__ (self , img : NdarrayOrTensor , randomize : bool = True ) -> NdarrayOrTensor :
613+ """
614+ Apply the transform to `img`.
615+ """
616+ img = convert_to_tensor (img , track_meta = get_track_meta ())
617+ if randomize :
618+ self .randomize ()
619+
620+ if not self ._do_transform :
621+ return convert_data_type (img , dtype = self .dtype )[0 ]
622+
623+ return self .scaler (img , self .factor )
624+
625+
469626class RandScaleIntensity (RandomizableTransform ):
470627 """
471628 Randomly scale the intensity of input image by ``v = v * (1 + factor)`` where the `factor`
@@ -799,48 +956,99 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
799956
800957class AdjustContrast (Transform ):
801958 """
802- Changes image intensity by gamma. Each pixel/voxel intensity is updated as::
959+ Changes image intensity with gamma transform . Each pixel/voxel intensity is updated as::
803960
804961 x = ((x - min) / intensity_range) ^ gamma * intensity_range + min
805962
806963 Args:
807964 gamma: gamma value to adjust the contrast as function.
965+ invert_image: whether to invert the image before applying gamma augmentation. If True, multiply all intensity
966+ values with -1 before the gamma transform and again after the gamma transform. This behaviour is mimicked
967+ from `nnU-Net <https://www.nature.com/articles/s41592-020-01008-z>`_, specifically `this
968+ <https://github.com/MIC-DKFZ/batchgenerators/blob/7fb802b28b045b21346b197735d64f12fbb070aa/batchgenerators/augmentations/color_augmentations.py#L107>`_
969+ function.
970+ retain_stats: if True, applies a scaling factor and an offset to all intensity values after gamma transform to
971+ ensure that the output intensity distribution has the same mean and standard deviation as the intensity
972+ distribution of the input. This behaviour is mimicked from `nnU-Net
973+ <https://www.nature.com/articles/s41592-020-01008-z>`_, specifically `this
974+ <https://github.com/MIC-DKFZ/batchgenerators/blob/7fb802b28b045b21346b197735d64f12fbb070aa/batchgenerators/augmentations/color_augmentations.py#L107>`_
975+ function.
808976 """
809977
810978 backend = [TransformBackends .TORCH , TransformBackends .NUMPY ]
811979
812- def __init__ (self , gamma : float ) -> None :
980+ def __init__ (self , gamma : float , invert_image : bool = False , retain_stats : bool = False ) -> None :
813981 if not isinstance (gamma , (int , float )):
814982 raise ValueError (f"gamma must be a float or int number, got { type (gamma )} { gamma } ." )
815983 self .gamma = gamma
984+ self .invert_image = invert_image
985+ self .retain_stats = retain_stats
816986
817- def __call__ (self , img : NdarrayOrTensor ) -> NdarrayOrTensor :
987+ def __call__ (self , img : NdarrayOrTensor , gamma = None ) -> NdarrayOrTensor :
818988 """
819989 Apply the transform to `img`.
990+ gamma: gamma value to adjust the contrast as function.
820991 """
821992 img = convert_to_tensor (img , track_meta = get_track_meta ())
993+ gamma = gamma if gamma is not None else self .gamma
994+
995+ if self .invert_image :
996+ img = - img
997+
998+ if self .retain_stats :
999+ mn = img .mean ()
1000+ sd = img .std ()
1001+
8221002 epsilon = 1e-7
8231003 img_min = img .min ()
8241004 img_range = img .max () - img_min
825- ret : NdarrayOrTensor = ((img - img_min ) / float (img_range + epsilon )) ** self .gamma * img_range + img_min
1005+ ret : NdarrayOrTensor = ((img - img_min ) / float (img_range + epsilon )) ** gamma * img_range + img_min
1006+
1007+ if self .retain_stats :
1008+ # zero mean and normalize
1009+ ret = ret - ret .mean ()
1010+ ret = ret / (ret .std () + 1e-8 )
1011+ # restore old mean and standard deviation
1012+ ret = sd * ret + mn
1013+
1014+ if self .invert_image :
1015+ ret = - ret
1016+
8261017 return ret
8271018
8281019
8291020class RandAdjustContrast (RandomizableTransform ):
8301021 """
831- Randomly changes image intensity by gamma. Each pixel/voxel intensity is updated as: :
1022+ Randomly changes image intensity with gamma transform . Each pixel/voxel intensity is updated as:
8321023
8331024 x = ((x - min) / intensity_range) ^ gamma * intensity_range + min
8341025
8351026 Args:
8361027 prob: Probability of adjustment.
8371028 gamma: Range of gamma values.
8381029 If single number, value is picked from (0.5, gamma), default is (0.5, 4.5).
1030+ invert_image: whether to invert the image before applying gamma augmentation. If True, multiply all intensity
1031+ values with -1 before the gamma transform and again after the gamma transform. This behaviour is mimicked
1032+ from `nnU-Net <https://www.nature.com/articles/s41592-020-01008-z>`_, specifically `this
1033+ <https://github.com/MIC-DKFZ/batchgenerators/blob/7fb802b28b045b21346b197735d64f12fbb070aa/batchgenerators/augmentations/color_augmentations.py#L107>`_
1034+ function.
1035+ retain_stats: if True, applies a scaling factor and an offset to all intensity values after gamma transform to
1036+ ensure that the output intensity distribution has the same mean and standard deviation as the intensity
1037+ distribution of the input. This behaviour is mimicked from `nnU-Net
1038+ <https://www.nature.com/articles/s41592-020-01008-z>`_, specifically `this
1039+ <https://github.com/MIC-DKFZ/batchgenerators/blob/7fb802b28b045b21346b197735d64f12fbb070aa/batchgenerators/augmentations/color_augmentations.py#L107>`_
1040+ function.
8391041 """
8401042
8411043 backend = AdjustContrast .backend
8421044
843- def __init__ (self , prob : float = 0.1 , gamma : Sequence [float ] | float = (0.5 , 4.5 )) -> None :
1045+ def __init__ (
1046+ self ,
1047+ prob : float = 0.1 ,
1048+ gamma : Sequence [float ] | float = (0.5 , 4.5 ),
1049+ invert_image : bool = False ,
1050+ retain_stats : bool = False ,
1051+ ) -> None :
8441052 RandomizableTransform .__init__ (self , prob )
8451053
8461054 if isinstance (gamma , (int , float )):
@@ -854,7 +1062,13 @@ def __init__(self, prob: float = 0.1, gamma: Sequence[float] | float = (0.5, 4.5
8541062 else :
8551063 self .gamma = (min (gamma ), max (gamma ))
8561064
857- self .gamma_value : float | None = None
1065+ self .gamma_value : float = 1.0
1066+ self .invert_image : bool = invert_image
1067+ self .retain_stats : bool = retain_stats
1068+
1069+ self .adjust_contrast = AdjustContrast (
1070+ self .gamma_value , invert_image = self .invert_image , retain_stats = self .retain_stats
1071+ )
8581072
8591073 def randomize (self , data : Any | None = None ) -> None :
8601074 super ().randomize (None )
@@ -875,7 +1089,8 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen
8751089
8761090 if self .gamma_value is None :
8771091 raise RuntimeError ("gamma_value is not set, please call `randomize` function first." )
878- return AdjustContrast (self .gamma_value )(img )
1092+
1093+ return self .adjust_contrast (img , self .gamma_value )
8791094
8801095
8811096class ScaleIntensityRangePercentiles (Transform ):
0 commit comments