Skip to content

Commit 9fb9d98

Browse files
authored
4587 mri transforms (#4591)
1 parent ec5376e commit 9fb9d98

File tree

8 files changed

+895
-2
lines changed

8 files changed

+895
-2
lines changed

docs/source/transforms.rst

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -798,6 +798,23 @@ Smooth Field
798798
:members:
799799
:special-members: __call__
800800

801+
802+
MRI Transforms
803+
^^^^^^^^^^^^^^
804+
805+
`Kspace under-sampling`
806+
"""""""""""""""""""""""
807+
.. autoclass:: monai.apps.reconstruction.transforms.array.KspaceMask
808+
:members:
809+
:special-members: __call__
810+
811+
.. autoclass:: monai.apps.reconstruction.transforms.array.RandomKspaceMask
812+
:special-members: __call__
813+
814+
.. autoclass:: monai.apps.reconstruction.transforms.array.EquispacedKspaceMask
815+
:special-members: __call__
816+
817+
801818
Utility
802819
^^^^^^^
803820

@@ -1683,6 +1700,34 @@ Smooth Field (Dict)
16831700
:members:
16841701
:special-members: __call__
16851702

1703+
1704+
`MRI transforms (Dict)`
1705+
^^^^^^^^^^^^^^^^^^^^^^^
1706+
1707+
`Kspace under-sampling (Dict)`
1708+
""""""""""""""""""""""""""""""
1709+
.. autoclass:: monai.apps.reconstruction.transforms.dictionary.RandomKspaceMaskd
1710+
:special-members: __call__
1711+
1712+
.. autoclass:: monai.apps.reconstruction.transforms.dictionary.EquispacedKspaceMaskd
1713+
:special-members: __call__
1714+
1715+
`ExtractDataKeyFromMetaKeyd`
1716+
""""""""""""""""""""""""""""
1717+
.. autoclass:: monai.apps.reconstruction.transforms.dictionary.ExtractDataKeyFromMetaKeyd
1718+
:special-members: __call__
1719+
1720+
`ReferenceBasedSpatialCropd`
1721+
""""""""""""""""""""""""""""
1722+
.. autoclass:: monai.apps.reconstruction.transforms.dictionary.ReferenceBasedSpatialCropd
1723+
:special-members: __call__
1724+
1725+
`ReferenceBasedNormalizeIntensityd`
1726+
"""""""""""""""""""""""""""""""""""
1727+
.. autoclass:: monai.apps.reconstruction.transforms.dictionary.ReferenceBasedNormalizeIntensityd
1728+
:special-members: __call__
1729+
1730+
16861731
Utility (Dict)
16871732
^^^^^^^^^^^^^^
16881733

monai/apps/reconstruction/fastmri_reader.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,12 @@ class FastMRIReader(ImageReader):
3131
are HDF5 dictionary-like datasets. The keys are:
3232
3333
- kspace: contains the fully-sampled kspace
34-
- reconstruction_rss: contains the root sume of squares of ifft of kspace
34+
- reconstruction_rss: contains the root sum of squares of ifft of kspace. This
35+
is the ground-truth image.
3536
3637
It also has several attributes with the following keys:
3738
38-
- acquisition (str): acquisition mode of the data
39+
- acquisition (str): acquisition mode of the data (e.g., AXT2 denotes T2 brain MRI scans)
3940
- max (float): dynamic range of the data
4041
- norm (float): norm of the kspace
4142
- patient_id (str): the patient's id whose measurements were recorded
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
Lines changed: 288 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,288 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
13+
from abc import abstractmethod
14+
from typing import Sequence
15+
16+
import numpy as np
17+
from torch import Tensor
18+
19+
from monai.apps.reconstruction.complex_utils import complex_abs, convert_to_tensor_complex
20+
from monai.apps.reconstruction.mri_utils import root_sum_of_squares
21+
from monai.config.type_definitions import NdarrayOrTensor
22+
from monai.data.fft_utils import ifftn_centered
23+
from monai.transforms.transform import RandomizableTransform
24+
from monai.utils.enums import TransformBackends
25+
from monai.utils.type_conversion import convert_to_tensor
26+
27+
28+
class KspaceMask(RandomizableTransform):
29+
"""
30+
A basic class for under-sampling mask setup. It provides common
31+
features for under-sampling mask generators.
32+
For example, RandomMaskFunc and EquispacedMaskFunc (two mask
33+
transform objects defined right after this module)
34+
both inherit MaskFunc to properly setup properties like the
35+
acceleration factor.
36+
"""
37+
38+
def __init__(
39+
self,
40+
center_fractions: Sequence[float],
41+
accelerations: Sequence[float],
42+
spatial_dims: int = 2,
43+
is_complex: bool = True,
44+
):
45+
"""
46+
Args:
47+
center_fractions: Fraction of low-frequency columns to be retained.
48+
If multiple values are provided, then one of these numbers
49+
is chosen uniformly each time.
50+
accelerations: Amount of under-sampling. This should have the
51+
same length as center_fractions. If multiple values are
52+
provided, then one of these is chosen uniformly each time.
53+
spatial_dims: Number of spatial dims (e.g., it's 2 for a 2D data;
54+
it's also 2 for psuedo-3D datasets like the fastMRI dataset).
55+
The last spatial dim is selected for sampling. For the fastMRI
56+
dataset, k-space has the form (...,num_slices,num_coils,H,W)
57+
and sampling is done along W. For a general 3D data with the
58+
shape (...,num_coils,H,W,D), sampling is done along D.
59+
is_complex: if True, then the last dimension will be reserved for
60+
real/imaginary parts.
61+
"""
62+
if len(center_fractions) != len(accelerations):
63+
raise ValueError(
64+
"Number of center fractions \
65+
should match number of accelerations"
66+
)
67+
68+
self.center_fractions = center_fractions
69+
self.accelerations = accelerations
70+
self.spatial_dims = spatial_dims
71+
self.is_complex = is_complex
72+
73+
@abstractmethod
74+
def __call__(self, kspace: NdarrayOrTensor):
75+
"""
76+
This is an extra instance to allow for defining new mask generators.
77+
For creating other mask transforms, define a new class and simply
78+
override __call__. See an example of this in
79+
:py:class:`monai.apps.reconstruction.transforms.array.RandomKspacemask`.
80+
81+
Args:
82+
kspace: The input k-space data. The shape is (...,num_coils,H,W,2)
83+
for complex 2D inputs and (...,num_coils,H,W,D) for real 3D
84+
data.
85+
"""
86+
raise NotImplementedError
87+
88+
def randomize_choose_acceleration(self) -> Sequence[float]:
89+
"""
90+
If multiple values are provided for center_fractions and
91+
accelerations, this function selects one value uniformly
92+
for each training/test sample.
93+
94+
Returns:
95+
A tuple containing
96+
(1) center_fraction: chosen fraction of center kspace
97+
lines to exclude from under-sampling
98+
(2) acceleration: chosen acceleration factor
99+
"""
100+
choice = self.R.randint(0, len(self.accelerations))
101+
center_fraction = self.center_fractions[choice]
102+
acceleration = self.accelerations[choice]
103+
return center_fraction, acceleration
104+
105+
106+
class RandomKspaceMask(KspaceMask):
107+
"""
108+
This k-space mask transform under-samples the k-space according to a
109+
random sampling pattern. Precisely, it uniformly selects a subset of
110+
columns from the input k-space data. If the k-space data has N columns,
111+
the mask picks out:
112+
113+
1. N_low_freqs = (N * center_fraction) columns in the center
114+
corresponding to low-frequencies
115+
116+
2. The other columns are selected uniformly at random with a probability
117+
equal to:
118+
prob = (N / acceleration - N_low_freqs) / (N - N_low_freqs).
119+
This ensures that the expected number of columns selected is equal to
120+
(N / acceleration)
121+
122+
It is possible to use multiple center_fractions and accelerations,
123+
in which case one possible (center_fraction, acceleration) is chosen
124+
uniformly at random each time the transform is called.
125+
126+
Example:
127+
If accelerations = [4, 8] and center_fractions = [0.08, 0.04],
128+
then there is a 50% probability that 4-fold acceleration with 8%
129+
center fraction is selected and a 50% probability that 8-fold
130+
acceleration with 4% center fraction is selected.
131+
132+
Modified and adopted from:
133+
https://github.com/facebookresearch/fastMRI/tree/master/fastmri
134+
"""
135+
136+
backend = [TransformBackends.TORCH]
137+
138+
def __call__(self, kspace: NdarrayOrTensor) -> Sequence[Tensor]:
139+
"""
140+
Args:
141+
kspace: The input k-space data. The shape is (...,num_coils,H,W,2)
142+
for complex 2D inputs and (...,num_coils,H,W,D) for real 3D
143+
data. The last spatial dim is selected for sampling. For the
144+
fastMRI dataset, k-space has the form
145+
(...,num_slices,num_coils,H,W) and sampling is done along W.
146+
For a general 3D data with the shape (...,num_coils,H,W,D),
147+
sampling is done along D.
148+
149+
Returns:
150+
A tuple containing
151+
(1) the under-sampled kspace
152+
(2) absolute value of the inverse fourier of the under-sampled kspace
153+
"""
154+
kspace_t = convert_to_tensor_complex(kspace)
155+
spatial_size = kspace_t.shape
156+
num_cols = spatial_size[-1]
157+
if self.is_complex: # for complex data
158+
num_cols = spatial_size[-2]
159+
160+
center_fraction, acceleration = self.randomize_choose_acceleration()
161+
162+
# Create the mask
163+
num_low_freqs = int(round(num_cols * center_fraction))
164+
prob = (num_cols / acceleration - num_low_freqs) / (num_cols - num_low_freqs)
165+
mask = self.R.uniform(size=num_cols) < prob
166+
pad = (num_cols - num_low_freqs + 1) // 2
167+
mask[pad : pad + num_low_freqs] = True
168+
169+
# Reshape the mask
170+
mask_shape = [1 for _ in spatial_size]
171+
if self.is_complex:
172+
mask_shape[-2] = num_cols
173+
else:
174+
mask_shape[-1] = num_cols
175+
176+
mask = convert_to_tensor(mask.reshape(*mask_shape).astype(np.float32))
177+
178+
# under-sample the ksapce
179+
masked = mask * kspace_t
180+
masked_kspace: Tensor = convert_to_tensor(masked)
181+
self.mask = mask
182+
183+
# compute inverse fourier of the masked kspace
184+
masked_kspace_ifft: Tensor = convert_to_tensor(
185+
complex_abs(ifftn_centered(masked_kspace, spatial_dims=self.spatial_dims, is_complex=self.is_complex))
186+
)
187+
# combine coil images (it is assumed that the coil dimension is
188+
# the first dimension before spatial dimensions)
189+
masked_kspace_ifft_rss: Tensor = convert_to_tensor(
190+
root_sum_of_squares(masked_kspace_ifft, spatial_dim=-self.spatial_dims - 1)
191+
)
192+
return masked_kspace, masked_kspace_ifft_rss
193+
194+
195+
class EquispacedKspaceMask(KspaceMask):
196+
"""
197+
This k-space mask transform under-samples the k-space according to an
198+
equi-distant sampling pattern. Precisely, it selects an equi-distant
199+
subset of columns from the input k-space data. If the k-space data has N
200+
columns, the mask picks out:
201+
202+
1. N_low_freqs = (N * center_fraction) columns in the center corresponding
203+
to low-frequencies
204+
205+
2. The other columns are selected with equal spacing at a proportion that
206+
reaches the desired acceleration rate taking into consideration the number
207+
of low frequencies. This ensures that the expected number of columns
208+
selected is equal to (N / acceleration)
209+
210+
It is possible to use multiple center_fractions and accelerations, in
211+
which case one possible (center_fraction, acceleration) is chosen
212+
uniformly at random each time the EquispacedMaskFunc object is called.
213+
214+
Example:
215+
If accelerations = [4, 8] and center_fractions = [0.08, 0.04],
216+
then there is a 50% probability that 4-fold acceleration with 8%
217+
center fraction is selected and a 50% probability that 8-fold
218+
acceleration with 4% center fraction is selected.
219+
220+
Modified and adopted from:
221+
https://github.com/facebookresearch/fastMRI/tree/master/fastmri
222+
"""
223+
224+
backend = [TransformBackends.TORCH]
225+
226+
def __call__(self, kspace: NdarrayOrTensor) -> Sequence[Tensor]:
227+
"""
228+
Args:
229+
kspace: The input k-space data. The shape is (...,num_coils,H,W,2)
230+
for complex 2D inputs and (...,num_coils,H,W,D) for real 3D
231+
data. The last spatial dim is selected for sampling. For the
232+
fastMRI multi-coil dataset, k-space has the form
233+
(...,num_slices,num_coils,H,W) and sampling is done along W.
234+
For a general 3D data with the shape (...,num_coils,H,W,D),
235+
sampling is done along D.
236+
237+
Returns:
238+
A tuple containing
239+
(1) the under-sampled kspace
240+
(2) absolute value of the inverse fourier of the under-sampled kspace
241+
"""
242+
kspace_t = convert_to_tensor_complex(kspace)
243+
spatial_size = kspace_t.shape
244+
num_cols = spatial_size[-1]
245+
if self.is_complex: # for complex data
246+
num_cols = spatial_size[-2]
247+
248+
center_fraction, acceleration = self.randomize_choose_acceleration()
249+
num_low_freqs = int(round(num_cols * center_fraction))
250+
251+
# Create the mask
252+
mask = np.zeros(num_cols, dtype=np.float32)
253+
pad = (num_cols - num_low_freqs + 1) // 2
254+
mask[pad : pad + num_low_freqs] = True
255+
256+
# Determine acceleration rate by adjusting for the
257+
# number of low frequencies
258+
adjusted_accel = (acceleration * (num_low_freqs - num_cols)) / (num_low_freqs * acceleration - num_cols)
259+
offset = self.R.randint(0, round(adjusted_accel))
260+
261+
accel_samples = np.arange(offset, num_cols - 1, adjusted_accel)
262+
accel_samples = np.around(accel_samples).astype(np.uint)
263+
mask[accel_samples] = True
264+
265+
# Reshape the mask
266+
mask_shape = [1 for _ in spatial_size]
267+
if self.is_complex:
268+
mask_shape[-2] = num_cols
269+
else:
270+
mask_shape[-1] = num_cols
271+
272+
mask = convert_to_tensor(mask.reshape(*mask_shape).astype(np.float32))
273+
274+
# under-sample the ksapce
275+
masked = mask * kspace_t
276+
masked_kspace: Tensor = convert_to_tensor(masked)
277+
self.mask = mask
278+
279+
# compute inverse fourier of the masked kspace
280+
masked_kspace_ifft: Tensor = convert_to_tensor(
281+
complex_abs(ifftn_centered(masked_kspace, spatial_dims=self.spatial_dims, is_complex=self.is_complex))
282+
)
283+
# combine coil images (it is assumed that the coil dimension is
284+
# the first dimension before spatial dimensions)
285+
masked_kspace_ifft_rss: Tensor = convert_to_tensor(
286+
root_sum_of_squares(masked_kspace_ifft, spatial_dim=-self.spatial_dims - 1)
287+
)
288+
return masked_kspace, masked_kspace_ifft_rss

0 commit comments

Comments
 (0)