|
| 1 | +# Copyright (c) MONAI Consortium |
| 2 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 3 | +# you may not use this file except in compliance with the License. |
| 4 | +# You may obtain a copy of the License at |
| 5 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 6 | +# Unless required by applicable law or agreed to in writing, software |
| 7 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 8 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 9 | +# See the License for the specific language governing permissions and |
| 10 | +# limitations under the License. |
| 11 | + |
| 12 | + |
| 13 | +from abc import abstractmethod |
| 14 | +from typing import Sequence |
| 15 | + |
| 16 | +import numpy as np |
| 17 | +from torch import Tensor |
| 18 | + |
| 19 | +from monai.apps.reconstruction.complex_utils import complex_abs, convert_to_tensor_complex |
| 20 | +from monai.apps.reconstruction.mri_utils import root_sum_of_squares |
| 21 | +from monai.config.type_definitions import NdarrayOrTensor |
| 22 | +from monai.data.fft_utils import ifftn_centered |
| 23 | +from monai.transforms.transform import RandomizableTransform |
| 24 | +from monai.utils.enums import TransformBackends |
| 25 | +from monai.utils.type_conversion import convert_to_tensor |
| 26 | + |
| 27 | + |
| 28 | +class KspaceMask(RandomizableTransform): |
| 29 | + """ |
| 30 | + A basic class for under-sampling mask setup. It provides common |
| 31 | + features for under-sampling mask generators. |
| 32 | + For example, RandomMaskFunc and EquispacedMaskFunc (two mask |
| 33 | + transform objects defined right after this module) |
| 34 | + both inherit MaskFunc to properly setup properties like the |
| 35 | + acceleration factor. |
| 36 | + """ |
| 37 | + |
| 38 | + def __init__( |
| 39 | + self, |
| 40 | + center_fractions: Sequence[float], |
| 41 | + accelerations: Sequence[float], |
| 42 | + spatial_dims: int = 2, |
| 43 | + is_complex: bool = True, |
| 44 | + ): |
| 45 | + """ |
| 46 | + Args: |
| 47 | + center_fractions: Fraction of low-frequency columns to be retained. |
| 48 | + If multiple values are provided, then one of these numbers |
| 49 | + is chosen uniformly each time. |
| 50 | + accelerations: Amount of under-sampling. This should have the |
| 51 | + same length as center_fractions. If multiple values are |
| 52 | + provided, then one of these is chosen uniformly each time. |
| 53 | + spatial_dims: Number of spatial dims (e.g., it's 2 for a 2D data; |
| 54 | + it's also 2 for psuedo-3D datasets like the fastMRI dataset). |
| 55 | + The last spatial dim is selected for sampling. For the fastMRI |
| 56 | + dataset, k-space has the form (...,num_slices,num_coils,H,W) |
| 57 | + and sampling is done along W. For a general 3D data with the |
| 58 | + shape (...,num_coils,H,W,D), sampling is done along D. |
| 59 | + is_complex: if True, then the last dimension will be reserved for |
| 60 | + real/imaginary parts. |
| 61 | + """ |
| 62 | + if len(center_fractions) != len(accelerations): |
| 63 | + raise ValueError( |
| 64 | + "Number of center fractions \ |
| 65 | + should match number of accelerations" |
| 66 | + ) |
| 67 | + |
| 68 | + self.center_fractions = center_fractions |
| 69 | + self.accelerations = accelerations |
| 70 | + self.spatial_dims = spatial_dims |
| 71 | + self.is_complex = is_complex |
| 72 | + |
| 73 | + @abstractmethod |
| 74 | + def __call__(self, kspace: NdarrayOrTensor): |
| 75 | + """ |
| 76 | + This is an extra instance to allow for defining new mask generators. |
| 77 | + For creating other mask transforms, define a new class and simply |
| 78 | + override __call__. See an example of this in |
| 79 | + :py:class:`monai.apps.reconstruction.transforms.array.RandomKspacemask`. |
| 80 | +
|
| 81 | + Args: |
| 82 | + kspace: The input k-space data. The shape is (...,num_coils,H,W,2) |
| 83 | + for complex 2D inputs and (...,num_coils,H,W,D) for real 3D |
| 84 | + data. |
| 85 | + """ |
| 86 | + raise NotImplementedError |
| 87 | + |
| 88 | + def randomize_choose_acceleration(self) -> Sequence[float]: |
| 89 | + """ |
| 90 | + If multiple values are provided for center_fractions and |
| 91 | + accelerations, this function selects one value uniformly |
| 92 | + for each training/test sample. |
| 93 | +
|
| 94 | + Returns: |
| 95 | + A tuple containing |
| 96 | + (1) center_fraction: chosen fraction of center kspace |
| 97 | + lines to exclude from under-sampling |
| 98 | + (2) acceleration: chosen acceleration factor |
| 99 | + """ |
| 100 | + choice = self.R.randint(0, len(self.accelerations)) |
| 101 | + center_fraction = self.center_fractions[choice] |
| 102 | + acceleration = self.accelerations[choice] |
| 103 | + return center_fraction, acceleration |
| 104 | + |
| 105 | + |
| 106 | +class RandomKspaceMask(KspaceMask): |
| 107 | + """ |
| 108 | + This k-space mask transform under-samples the k-space according to a |
| 109 | + random sampling pattern. Precisely, it uniformly selects a subset of |
| 110 | + columns from the input k-space data. If the k-space data has N columns, |
| 111 | + the mask picks out: |
| 112 | +
|
| 113 | + 1. N_low_freqs = (N * center_fraction) columns in the center |
| 114 | + corresponding to low-frequencies |
| 115 | +
|
| 116 | + 2. The other columns are selected uniformly at random with a probability |
| 117 | + equal to: |
| 118 | + prob = (N / acceleration - N_low_freqs) / (N - N_low_freqs). |
| 119 | + This ensures that the expected number of columns selected is equal to |
| 120 | + (N / acceleration) |
| 121 | +
|
| 122 | + It is possible to use multiple center_fractions and accelerations, |
| 123 | + in which case one possible (center_fraction, acceleration) is chosen |
| 124 | + uniformly at random each time the transform is called. |
| 125 | +
|
| 126 | + Example: |
| 127 | + If accelerations = [4, 8] and center_fractions = [0.08, 0.04], |
| 128 | + then there is a 50% probability that 4-fold acceleration with 8% |
| 129 | + center fraction is selected and a 50% probability that 8-fold |
| 130 | + acceleration with 4% center fraction is selected. |
| 131 | +
|
| 132 | + Modified and adopted from: |
| 133 | + https://github.com/facebookresearch/fastMRI/tree/master/fastmri |
| 134 | + """ |
| 135 | + |
| 136 | + backend = [TransformBackends.TORCH] |
| 137 | + |
| 138 | + def __call__(self, kspace: NdarrayOrTensor) -> Sequence[Tensor]: |
| 139 | + """ |
| 140 | + Args: |
| 141 | + kspace: The input k-space data. The shape is (...,num_coils,H,W,2) |
| 142 | + for complex 2D inputs and (...,num_coils,H,W,D) for real 3D |
| 143 | + data. The last spatial dim is selected for sampling. For the |
| 144 | + fastMRI dataset, k-space has the form |
| 145 | + (...,num_slices,num_coils,H,W) and sampling is done along W. |
| 146 | + For a general 3D data with the shape (...,num_coils,H,W,D), |
| 147 | + sampling is done along D. |
| 148 | +
|
| 149 | + Returns: |
| 150 | + A tuple containing |
| 151 | + (1) the under-sampled kspace |
| 152 | + (2) absolute value of the inverse fourier of the under-sampled kspace |
| 153 | + """ |
| 154 | + kspace_t = convert_to_tensor_complex(kspace) |
| 155 | + spatial_size = kspace_t.shape |
| 156 | + num_cols = spatial_size[-1] |
| 157 | + if self.is_complex: # for complex data |
| 158 | + num_cols = spatial_size[-2] |
| 159 | + |
| 160 | + center_fraction, acceleration = self.randomize_choose_acceleration() |
| 161 | + |
| 162 | + # Create the mask |
| 163 | + num_low_freqs = int(round(num_cols * center_fraction)) |
| 164 | + prob = (num_cols / acceleration - num_low_freqs) / (num_cols - num_low_freqs) |
| 165 | + mask = self.R.uniform(size=num_cols) < prob |
| 166 | + pad = (num_cols - num_low_freqs + 1) // 2 |
| 167 | + mask[pad : pad + num_low_freqs] = True |
| 168 | + |
| 169 | + # Reshape the mask |
| 170 | + mask_shape = [1 for _ in spatial_size] |
| 171 | + if self.is_complex: |
| 172 | + mask_shape[-2] = num_cols |
| 173 | + else: |
| 174 | + mask_shape[-1] = num_cols |
| 175 | + |
| 176 | + mask = convert_to_tensor(mask.reshape(*mask_shape).astype(np.float32)) |
| 177 | + |
| 178 | + # under-sample the ksapce |
| 179 | + masked = mask * kspace_t |
| 180 | + masked_kspace: Tensor = convert_to_tensor(masked) |
| 181 | + self.mask = mask |
| 182 | + |
| 183 | + # compute inverse fourier of the masked kspace |
| 184 | + masked_kspace_ifft: Tensor = convert_to_tensor( |
| 185 | + complex_abs(ifftn_centered(masked_kspace, spatial_dims=self.spatial_dims, is_complex=self.is_complex)) |
| 186 | + ) |
| 187 | + # combine coil images (it is assumed that the coil dimension is |
| 188 | + # the first dimension before spatial dimensions) |
| 189 | + masked_kspace_ifft_rss: Tensor = convert_to_tensor( |
| 190 | + root_sum_of_squares(masked_kspace_ifft, spatial_dim=-self.spatial_dims - 1) |
| 191 | + ) |
| 192 | + return masked_kspace, masked_kspace_ifft_rss |
| 193 | + |
| 194 | + |
| 195 | +class EquispacedKspaceMask(KspaceMask): |
| 196 | + """ |
| 197 | + This k-space mask transform under-samples the k-space according to an |
| 198 | + equi-distant sampling pattern. Precisely, it selects an equi-distant |
| 199 | + subset of columns from the input k-space data. If the k-space data has N |
| 200 | + columns, the mask picks out: |
| 201 | +
|
| 202 | + 1. N_low_freqs = (N * center_fraction) columns in the center corresponding |
| 203 | + to low-frequencies |
| 204 | +
|
| 205 | + 2. The other columns are selected with equal spacing at a proportion that |
| 206 | + reaches the desired acceleration rate taking into consideration the number |
| 207 | + of low frequencies. This ensures that the expected number of columns |
| 208 | + selected is equal to (N / acceleration) |
| 209 | +
|
| 210 | + It is possible to use multiple center_fractions and accelerations, in |
| 211 | + which case one possible (center_fraction, acceleration) is chosen |
| 212 | + uniformly at random each time the EquispacedMaskFunc object is called. |
| 213 | +
|
| 214 | + Example: |
| 215 | + If accelerations = [4, 8] and center_fractions = [0.08, 0.04], |
| 216 | + then there is a 50% probability that 4-fold acceleration with 8% |
| 217 | + center fraction is selected and a 50% probability that 8-fold |
| 218 | + acceleration with 4% center fraction is selected. |
| 219 | +
|
| 220 | + Modified and adopted from: |
| 221 | + https://github.com/facebookresearch/fastMRI/tree/master/fastmri |
| 222 | + """ |
| 223 | + |
| 224 | + backend = [TransformBackends.TORCH] |
| 225 | + |
| 226 | + def __call__(self, kspace: NdarrayOrTensor) -> Sequence[Tensor]: |
| 227 | + """ |
| 228 | + Args: |
| 229 | + kspace: The input k-space data. The shape is (...,num_coils,H,W,2) |
| 230 | + for complex 2D inputs and (...,num_coils,H,W,D) for real 3D |
| 231 | + data. The last spatial dim is selected for sampling. For the |
| 232 | + fastMRI multi-coil dataset, k-space has the form |
| 233 | + (...,num_slices,num_coils,H,W) and sampling is done along W. |
| 234 | + For a general 3D data with the shape (...,num_coils,H,W,D), |
| 235 | + sampling is done along D. |
| 236 | +
|
| 237 | + Returns: |
| 238 | + A tuple containing |
| 239 | + (1) the under-sampled kspace |
| 240 | + (2) absolute value of the inverse fourier of the under-sampled kspace |
| 241 | + """ |
| 242 | + kspace_t = convert_to_tensor_complex(kspace) |
| 243 | + spatial_size = kspace_t.shape |
| 244 | + num_cols = spatial_size[-1] |
| 245 | + if self.is_complex: # for complex data |
| 246 | + num_cols = spatial_size[-2] |
| 247 | + |
| 248 | + center_fraction, acceleration = self.randomize_choose_acceleration() |
| 249 | + num_low_freqs = int(round(num_cols * center_fraction)) |
| 250 | + |
| 251 | + # Create the mask |
| 252 | + mask = np.zeros(num_cols, dtype=np.float32) |
| 253 | + pad = (num_cols - num_low_freqs + 1) // 2 |
| 254 | + mask[pad : pad + num_low_freqs] = True |
| 255 | + |
| 256 | + # Determine acceleration rate by adjusting for the |
| 257 | + # number of low frequencies |
| 258 | + adjusted_accel = (acceleration * (num_low_freqs - num_cols)) / (num_low_freqs * acceleration - num_cols) |
| 259 | + offset = self.R.randint(0, round(adjusted_accel)) |
| 260 | + |
| 261 | + accel_samples = np.arange(offset, num_cols - 1, adjusted_accel) |
| 262 | + accel_samples = np.around(accel_samples).astype(np.uint) |
| 263 | + mask[accel_samples] = True |
| 264 | + |
| 265 | + # Reshape the mask |
| 266 | + mask_shape = [1 for _ in spatial_size] |
| 267 | + if self.is_complex: |
| 268 | + mask_shape[-2] = num_cols |
| 269 | + else: |
| 270 | + mask_shape[-1] = num_cols |
| 271 | + |
| 272 | + mask = convert_to_tensor(mask.reshape(*mask_shape).astype(np.float32)) |
| 273 | + |
| 274 | + # under-sample the ksapce |
| 275 | + masked = mask * kspace_t |
| 276 | + masked_kspace: Tensor = convert_to_tensor(masked) |
| 277 | + self.mask = mask |
| 278 | + |
| 279 | + # compute inverse fourier of the masked kspace |
| 280 | + masked_kspace_ifft: Tensor = convert_to_tensor( |
| 281 | + complex_abs(ifftn_centered(masked_kspace, spatial_dims=self.spatial_dims, is_complex=self.is_complex)) |
| 282 | + ) |
| 283 | + # combine coil images (it is assumed that the coil dimension is |
| 284 | + # the first dimension before spatial dimensions) |
| 285 | + masked_kspace_ifft_rss: Tensor = convert_to_tensor( |
| 286 | + root_sum_of_squares(masked_kspace_ifft, spatial_dim=-self.spatial_dims - 1) |
| 287 | + ) |
| 288 | + return masked_kspace, masked_kspace_ifft_rss |
0 commit comments