diff --git a/.gitignore b/.gitignore index 76c6ab0d12..1470ae8375 100644 --- a/.gitignore +++ b/.gitignore @@ -165,3 +165,5 @@ runs *.pth *zarr/* + +monai-dev/ diff --git a/CHANGELOG.md b/CHANGELOG.md index 7c6ddd262a..8c741caf0c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,8 @@ All notable changes to MONAI are documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [Unreleased] +### Added +* Added `RandNonCentralChiNoise` and `RandNonCentralChiNoised` for generalized Rician noise simulation in MRI. ## [1.5.1] - 2025-09-22 diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index d15042181b..6a12541875 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -117,6 +117,7 @@ RandHistogramShift, RandIntensityRemap, RandKSpaceSpikeNoise, + RandNonCentralChiNoise, RandRicianNoise, RandScaleIntensity, RandScaleIntensityFixedMean, @@ -199,6 +200,9 @@ RandKSpaceSpikeNoised, RandKSpaceSpikeNoiseD, RandKSpaceSpikeNoiseDict, + RandNonCentralChiNoised, + RandNonCentralChiNoiseD, + RandNonCentralChiNoiseDict, RandRicianNoised, RandRicianNoiseD, RandRicianNoiseDict, diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 0421d34492..4e9e9b80a2 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -41,6 +41,7 @@ __all__ = [ "RandGaussianNoise", + "RandNonCentralChiNoise", "RandRicianNoise", "ShiftIntensity", "RandShiftIntensity", @@ -140,6 +141,149 @@ def __call__(self, img: NdarrayOrTensor, mean: float | None = None, randomize: b return img + noise +class RandNonCentralChiNoise(RandomizableTransform): + """ + Add non-central chi noise to an image. + This distribution is the square root of the sum of squares of k independent + Gaussian random variables, where one of the variables has a non-zero mean + (the signal). + This is a generalization of Rician noise. `degrees_of_freedom=2` is Rician noise. + See: https://en.wikipedia.org/wiki/Noncentral_chi_distribution and https://archive.ismrm.org/2024/3123_NZkvJdQat.html + + Args: + prob: Probability to add noise. + mean: Mean or "centre" of the Gaussian noise distributions. + std: Standard deviation (spread) of the Gaussian noise distributions. + degrees_of_freedom: Number of Gaussian distributions (degrees of freedom). + `degrees_of_freedom=2` is Rician noise. + channel_wise: If True, treats each channel of the image separately. + relative: If True, the spread of the sampled Gaussian distributions will + be std times the standard deviation of the image or channel's intensity + histogram. + sample_std: If True, sample the spread of the Gaussian distributions + uniformly from 0 to std. + dtype: output data type, if None, same as input image. defaults to float32. + + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__( + self, + prob: float = 0.1, + mean: Sequence[float] | float = 0.0, + std: Sequence[float] | float = 1.0, + degrees_of_freedom: int = 64, # 64 default because typical modern brain MRI is 32 quadrature coils + channel_wise: bool = False, + relative: bool = False, + sample_std: bool = True, + dtype: DtypeLike = np.float32, + ) -> None: + """ + Initializes the transform. + + Args: + prob: Probability to add noise. + mean: Mean of the Gaussian noise distributions. + std: Standard deviation (spread) of the Gaussian noise distributions. + degrees_of_freedom: Number of Gaussian distributions (degrees of freedom). + `degrees_of_freedom=2` is Rician noise. Defaults to 64 (32 quadrature coils). + channel_wise: If True, treats each channel of the image separately. + relative: If True, the spread of the sampled Gaussian distributions will + be std times the standard deviation of the image or channel's intensity + histogram. + sample_std: If True, sample the spread of the Gaussian distributions + uniformly from 0 to std. + dtype: output data type, if None, same as input image. defaults to float32. + + Raises: + ValueError: If `degrees_of_freedom` is not an integer or is less than 1. + """ + RandomizableTransform.__init__(self, prob) + self.prob = prob + self.mean = mean + self.std = std + if not isinstance(degrees_of_freedom, int) or degrees_of_freedom < 1: + raise ValueError("degrees_of_freedom must be an integer >= 1.") + self.degrees_of_freedom = degrees_of_freedom + self.channel_wise = channel_wise + self.relative = relative + self.sample_std = sample_std + self.dtype = dtype + + def _add_noise(self, img: NdarrayOrTensor, mean: float, std: float, k: int): + """ + Applies non-central chi noise to a single image or channel. + + This method generates `k` Gaussian noise arrays, adds the input `img` + to the first one (as the non-centrality component), and then computes + the square root of the sum of squares. + + Args: + img: Input image array. + mean: Mean for the Gaussian noise distributions. + std: Standard deviation for the Gaussian noise distributions. + k: Degrees of freedom (number of noise arrays). + + Returns: + Image with non-central chi noise applied, with the same + backend (Numpy/Torch) as the input. + """ + dtype_np = get_equivalent_dtype(img.dtype, np.ndarray) + im_shape = img.shape + _std = self.R.uniform(0, std) if self.sample_std else std + + # Create a stack of k noise arrays + noise_shape = (k, *im_shape) + all_noises_np = self.R.normal(mean, _std, size=noise_shape).astype(dtype_np, copy=False) + + if isinstance(img, torch.Tensor): + all_noises = torch.tensor(all_noises_np, device=img.device) + all_noises[0] = all_noises[0] + img + sum_sq = torch.sum(all_noises**2, dim=0) + return torch.sqrt(sum_sq) + + all_noises_np[0] = all_noises_np[0] + img + sum_sq = np.sum(all_noises_np**2, axis=0) + return np.sqrt(sum_sq) + + def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: + """ + Apply the transform to `img`. + """ + src = img + img = convert_to_tensor(img, track_meta=get_track_meta(), dtype=self.dtype) + if randomize: + super().randomize(None) + + if not self._do_transform: + img, *_ = convert_to_dst_type(img, dst=src, dtype=self.dtype) + return img + + if self.channel_wise: + _mean = ensure_tuple_rep(self.mean, len(img)) + _std = ensure_tuple_rep(self.std, len(img)) + for i, d in enumerate(img): + img[i] = self._add_noise( + d, + mean=_mean[i], + std=_std[i] * d.std() if self.relative else _std[i], + k=self.degrees_of_freedom, + ) + else: + if not isinstance(self.mean, (int, float)): + raise RuntimeError(f"If channel_wise is False, mean must be a float or int, got {type(self.mean)}.") + if not isinstance(self.std, (int, float)): + raise RuntimeError(f"If channel_wise is False, std must be a float or int, got {type(self.std)}.") + std = self.std * img.std().item() if self.relative else self.std + if not isinstance(std, (int, float)): + raise RuntimeError(f"std must be a float or int number, got {type(std)}.") + img = self._add_noise(img, mean=self.mean, std=std, k=self.degrees_of_freedom) + + img, *_ = convert_to_dst_type(img, dst=src, dtype=self.dtype) + return img + + class RandRicianNoise(RandomizableTransform): """ Add Rician noise to image. diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 3d29b3031d..c2cfaf2707 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -48,6 +48,7 @@ RandGibbsNoise, RandHistogramShift, RandKSpaceSpikeNoise, + RandNonCentralChiNoise, RandRicianNoise, RandScaleIntensity, RandScaleIntensityFixedMean, @@ -69,6 +70,9 @@ __all__ = [ "RandGaussianNoised", "RandRicianNoised", + "RandNonCentralChiNoised", + "RandNonCentralChiNoiseD", + "RandNonCentralChiNoiseDict", "ShiftIntensityd", "RandShiftIntensityd", "ScaleIntensityd", @@ -236,6 +240,81 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N return d +class RandNonCentralChiNoised(RandomizableTransform, MapTransform): + """ + Dictionary-based version :py:class:`monai.transforms.RandNonCentralChiNoise`. + Add non-central chi noise to image. This transform assumes all the expected fields have same shape, if want to add + different noise for every field, please use this transform separately. + This is a generalization of Rician noise. `degrees_of_freedom=2` is Rician noise. + + Args: + keys: Keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + prob: Probability to add non-central chi noise to the dictionary. + mean: Mean or "centre" of the Gaussian distributions sampled to make up + the noise. + std: Standard deviation (spread) of the Gaussian distributions sampled + to make up the noise. + degrees_of_freedom: Number of Gaussian distributions (degrees of freedom). + `degrees_of_freedom=2` is Rician noise. + channel_wise: If True, treats each channel of the image separately. + relative: If True, the spread of the sampled Gaussian distributions will + be std times the standard deviation of the image or channel's intensity + histogram. + sample_std: If True, sample the spread of the Gaussian distributions + uniformly from 0 to std. + dtype: output data type, if None, same as input image. defaults to float32. + allow_missing_keys: Don't raise exception if key is missing. + """ + + backend = RandNonCentralChiNoise.backend + + def __init__( + self, + keys: KeysCollection, + prob: float = 0.1, + mean: Sequence[float] | float = 0.0, + std: Sequence[float] | float = 1.0, + degrees_of_freedom: int = 64, + channel_wise: bool = False, + relative: bool = False, + sample_std: bool = True, + dtype: DtypeLike = np.float32, + allow_missing_keys: bool = False, + ) -> None: + MapTransform.__init__(self, keys, allow_missing_keys) + RandomizableTransform.__init__(self, prob) + self.rand_non_central_chi_noise = RandNonCentralChiNoise( + prob=1.0, + mean=mean, + std=std, + degrees_of_freedom=degrees_of_freedom, + channel_wise=channel_wise, + relative=relative, + sample_std=sample_std, + dtype=dtype, + ) + + def set_random_state( + self, seed: int | None = None, state: np.random.RandomState | None = None + ) -> RandNonCentralChiNoised: + super().set_random_state(seed, state) + self.rand_non_central_chi_noise.set_random_state(seed, state) + return self + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: + d = dict(data) + self.randomize(None) + if not self._do_transform: + for key in self.key_iterator(d): + d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) + return d + + for key in self.key_iterator(d): + d[key] = self.rand_non_central_chi_noise(d[key], randomize=True) + return d + + class RandRicianNoised(RandomizableTransform, MapTransform): """ Dictionary-based version :py:class:`monai.transforms.RandRicianNoise`. @@ -1953,6 +2032,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N RandGaussianNoiseD = RandGaussianNoiseDict = RandGaussianNoised RandRicianNoiseD = RandRicianNoiseDict = RandRicianNoised +RandNonCentralChiNoiseD = RandNonCentralChiNoiseDict = RandNonCentralChiNoised ShiftIntensityD = ShiftIntensityDict = ShiftIntensityd RandShiftIntensityD = RandShiftIntensityDict = RandShiftIntensityd StdShiftIntensityD = StdShiftIntensityDict = StdShiftIntensityd diff --git a/tests/transforms/test_rand_noncentralchi_noise.py b/tests/transforms/test_rand_noncentralchi_noise.py new file mode 100644 index 0000000000..41efb6b8de --- /dev/null +++ b/tests/transforms/test_rand_noncentralchi_noise.py @@ -0,0 +1,83 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms import RandNonCentralChiNoise +from tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append(("test_zero_mean", p, 0, 0.1)) + TESTS.append(("test_non_zero_mean", p, 1, 0.5)) + + +class TestRandNonCentralChiNoise(NumpyImageTestCase2D): + @parameterized.expand(TESTS) + def test_correct_results(self, _, in_type, mean, std): + seed = 0 + degrees_of_freedom = 64 # 64 is common due to 32 channel head coil + noise_fn = RandNonCentralChiNoise(prob=1.0, mean=mean, std=std, degrees_of_freedom=degrees_of_freedom) + noise_fn.set_random_state(seed) + im = in_type(self.imt) + noised = noise_fn(im) + if isinstance(im, torch.Tensor): + self.assertEqual(im.dtype, noised.dtype) + np.random.seed(seed) + np.random.random() + _std = np.random.uniform(0, std) + + noise_shape = (degrees_of_freedom, *self.imt.shape) + all_noises = np.random.normal(mean, _std, size=noise_shape).astype(np.float32) + all_noises[0] += self.imt + sum_sq = np.sum(all_noises**2, axis=0) + expected = np.sqrt(sum_sq) + + if isinstance(noised, torch.Tensor): + noised = noised.cpu() + np.testing.assert_allclose(expected, noised, atol=1e-5) + + @parameterized.expand(TESTS) + def test_correct_results_dof2(self, _, in_type, mean, std): + """ + Test with k=2 (the Rician case) + """ + seed = 0 + degrees_of_freedom = 2 + noise_fn = RandNonCentralChiNoise(prob=1.0, mean=mean, std=std, degrees_of_freedom=degrees_of_freedom) + noise_fn.set_random_state(seed) + im = in_type(self.imt) + noised = noise_fn(im) + if isinstance(im, torch.Tensor): + self.assertEqual(im.dtype, noised.dtype) + + np.random.seed(seed) + np.random.random() # for prob + _std = np.random.uniform(0, std) # for sample_std + noise_shape = (degrees_of_freedom, *self.imt.shape) + all_noises = np.random.normal(mean, _std, size=noise_shape).astype(np.float32) + all_noises[0] += self.imt + sum_sq = np.sum(all_noises**2, axis=0) + expected = np.sqrt(sum_sq) + + if isinstance(noised, torch.Tensor): + noised = noised.cpu() + np.testing.assert_allclose(expected, noised, atol=1e-5, rtol=1e-5) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/transforms/test_rand_noncentralchi_noised.py b/tests/transforms/test_rand_noncentralchi_noised.py new file mode 100644 index 0000000000..6bf50dc721 --- /dev/null +++ b/tests/transforms/test_rand_noncentralchi_noised.py @@ -0,0 +1,89 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms import RandNonCentralChiNoised +from tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append(["test_zero_mean", p, ["img1", "img2"], 0, 0.1]) + TESTS.append(["test_non_zero_mean", p, ["img1", "img2"], 1, 0.5]) + +seed = 0 + + +class TestRandNonCentralChiNoised(NumpyImageTestCase2D): + @parameterized.expand(TESTS) + def test_correct_results(self, _, in_type, keys, mean, std): + degrees_of_freedom = 64 + noise_fn = RandNonCentralChiNoised( + keys=keys, + prob=1.0, + mean=mean, + std=std, + degrees_of_freedom=degrees_of_freedom, + dtype=np.float64, + ) + noise_fn.set_random_state(seed) + noised = noise_fn({k: in_type(self.imt) for k in keys}) + np.random.seed(seed) + for k in keys: + # simulate the `randomize` function of transform + np.random.random() + _std = np.random.uniform(0, std) + noise_shape = (degrees_of_freedom, *self.imt.shape) + all_noises = np.random.normal(mean, _std, size=noise_shape).astype(np.float32) + all_noises[0] += self.imt + sum_sq = np.sum(all_noises**2, axis=0) + expected = np.sqrt(sum_sq) + if isinstance(noised[k], torch.Tensor): + noised[k] = noised[k].cpu() + np.testing.assert_allclose(expected, noised[k], atol=1e-5, rtol=1e-5) + + @parameterized.expand(TESTS) + def test_correct_results_k2(self, _, in_type, keys, mean, std): + degrees_of_freedom = 2 + noise_fn = RandNonCentralChiNoised( + keys=keys, + prob=1.0, + mean=mean, + std=std, + degrees_of_freedom=degrees_of_freedom, + dtype=np.float64, + ) + noise_fn.set_random_state(seed) + noised = noise_fn({k: in_type(self.imt) for k in keys}) + np.random.seed(seed) + for k in keys: + np.random.random() + _std = np.random.uniform(0, std) + + noise_shape = (degrees_of_freedom, *self.imt.shape) + all_noises = np.random.normal(mean, _std, size=noise_shape).astype(np.float32) + all_noises[0] += self.imt + sum_sq = np.sum(all_noises**2, axis=0) + expected = np.sqrt(sum_sq) + + if isinstance(noised[k], torch.Tensor): + noised[k] = noised[k].cpu() + np.testing.assert_allclose(expected, noised[k], atol=1e-5, rtol=1e-5) + + +if __name__ == "__main__": + unittest.main()