
Commit 0737a33

KumoLiu and wyli authored
Add safe_dtype_convert (#5620)
Fixes #5621.

### Description

As stated in the [issue](#5621), a warning is currently added to remind the user of the intensity overflow.

- Add `safe_dtype_convert`
- Add `safe` flag in `convert_data_type`, `convert_to_tensor`, `convert_to_numpy`, `convert_to_cupy`

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: KumoLiu <yunl@nvidia.com>
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
Co-authored-by: Wenqi Li <wenqil@nvidia.com>
1 parent b533a53 commit 0737a33

File tree

4 files changed (+241, -27 lines)

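Before the per-file diffs, a minimal usage sketch of the new `safe` flag (illustration only, not part of the commit), assuming a MONAI build that includes this change; the expected values follow the `[256, -12]` example used in the updated docstrings:

import numpy as np
import torch

from monai.utils.type_conversion import convert_data_type

img = np.array([256, -12])

# default behaviour: the cast to uint8 wraps around (per the docstring example)
wrapped, *_ = convert_data_type(img, torch.Tensor, dtype=torch.uint8)
# wrapped: tensor([  0, 244], dtype=torch.uint8)

# safe=True clips the data to the uint8 value range before converting
clipped, *_ = convert_data_type(img, torch.Tensor, dtype=torch.uint8, safe=True)
# clipped: tensor([255,   0], dtype=torch.uint8)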

monai/utils/type_conversion.py

Lines changed: 99 additions & 9 deletions
@@ -105,6 +105,7 @@ def convert_to_tensor(
     device: Union[None, str, torch.device] = None,
     wrap_sequence: bool = False,
     track_meta: bool = False,
+    safe: bool = False,
 ):
     """
     Utility to convert the input data to a PyTorch Tensor, if `track_meta` is True, the output will be a `MetaTensor`,
@@ -121,6 +122,9 @@ def convert_to_tensor(
             E.g., `[1, 2]` -> `[tensor(1), tensor(2)]`. If `True`, then `[1, 2]` -> `tensor([1, 2])`.
         track_meta: whether to track the meta information, if `True`, will convert to `MetaTensor`.
             default to `False`.
+        safe: if `True`, then do safe dtype convert when intensity overflow. default to `False`.
+            E.g., `[256, -12]` -> `[tensor(0), tensor(244)]`.
+            If `True`, then `[256, -12]` -> `[tensor(255), tensor(0)]`.

     """

@@ -138,6 +142,8 @@ def _convert_tensor(tensor, **kwargs):
             return tensor.as_tensor()
         return tensor

+    if safe:
+        data = safe_dtype_range(data, dtype)
     dtype = get_equivalent_dtype(dtype, torch.Tensor)
     if isinstance(data, torch.Tensor):
         return _convert_tensor(data).to(dtype=dtype, device=device, memory_format=torch.contiguous_format)
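As an illustration (not part of the diff), the clipping added above changes `convert_to_tensor` as in the docstring example; a short sketch assuming this commit is installed:

import torch

from monai.utils.type_conversion import convert_to_tensor

# safe=True routes the input through safe_dtype_range before the dtype cast,
# so out-of-range values are clipped instead of wrapping around
convert_to_tensor([256, -12], dtype=torch.uint8, safe=True)
# -> [tensor(255, dtype=torch.uint8), tensor(0, dtype=torch.uint8)]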
@@ -164,7 +170,7 @@ def _convert_tensor(tensor, **kwargs):
     return data


-def convert_to_numpy(data, dtype: DtypeLike = None, wrap_sequence: bool = False):
+def convert_to_numpy(data, dtype: DtypeLike = None, wrap_sequence: bool = False, safe: bool = False):
     """
     Utility to convert the input data to a numpy array. If passing a dictionary, list or tuple,
     recursively check every item and convert it to numpy array.
@@ -176,7 +182,11 @@ def convert_to_numpy(data, dtype: DtypeLike = None, wrap_sequence: bool = False)
         dtype: target data type when converting to numpy array.
         wrap_sequence: if `False`, then lists will recursively call this function.
             E.g., `[1, 2]` -> `[array(1), array(2)]`. If `True`, then `[1, 2]` -> `array([1, 2])`.
+        safe: if `True`, then do safe dtype convert when intensity overflow. default to `False`.
+            E.g., `[256, -12]` -> `[array(0), array(244)]`. If `True`, then `[256, -12]` -> `[array(255), array(0)]`.
     """
+    if safe:
+        data = safe_dtype_range(data, dtype)
     if isinstance(data, torch.Tensor):
         data = np.asarray(data.detach().to(device="cpu").numpy(), dtype=get_equivalent_dtype(dtype, np.ndarray))
     elif has_cp and isinstance(data, cp_ndarray):
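The numpy path behaves the same way; a short sketch (illustration only), reusing the docstring values:

import numpy as np

from monai.utils.type_conversion import convert_to_numpy

# with wrap_sequence=True the list is gathered into a single array; safe=True first
# clips 256 / -12 to the uint8 range, giving array([255, 0], dtype=uint8)
convert_to_numpy([256, -12], dtype=np.uint8, wrap_sequence=True, safe=True)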
@@ -205,7 +215,7 @@ def convert_to_numpy(data, dtype: DtypeLike = None, wrap_sequence: bool = False)
     return data


-def convert_to_cupy(data, dtype: Optional[np.dtype] = None, wrap_sequence: bool = False):
+def convert_to_cupy(data, dtype: Optional[np.dtype] = None, wrap_sequence: bool = False, safe: bool = False):
     """
     Utility to convert the input data to a cupy array. If passing a dictionary, list or tuple,
     recursively check every item and convert it to cupy array.
@@ -218,8 +228,11 @@ def convert_to_cupy(data, dtype: Optional[np.dtype] = None, wrap_sequence: bool
             for more details: https://docs.cupy.dev/en/stable/reference/generated/cupy.array.html.
         wrap_sequence: if `False`, then lists will recursively call this function.
             E.g., `[1, 2]` -> `[array(1), array(2)]`. If `True`, then `[1, 2]` -> `array([1, 2])`.
+        safe: if `True`, then do safe dtype convert when intensity overflow. default to `False`.
+            E.g., `[256, -12]` -> `[array(0), array(244)]`. If `True`, then `[256, -12]` -> `[array(255), array(0)]`.
     """
-
+    if safe:
+        data = safe_dtype_range(data, dtype)
     # direct calls
     if isinstance(data, (cp_ndarray, np.ndarray, torch.Tensor, float, int, bool)):
         data = cp.asarray(data, dtype)
@@ -246,6 +259,7 @@ def convert_data_type(
     device: Union[None, str, torch.device] = None,
     dtype: Union[DtypeLike, torch.dtype] = None,
     wrap_sequence: bool = False,
+    safe: bool = False,
 ) -> Tuple[NdarrayTensor, type, Optional[torch.device]]:
     """
     Convert to `MetaTensor`, `torch.Tensor` or `np.ndarray` from `MetaTensor`, `torch.Tensor`,
@@ -260,6 +274,8 @@ def convert_data_type(
             If left blank, it remains unchanged.
         wrap_sequence: if `False`, then lists will recursively call this function.
             E.g., `[1, 2]` -> `[array(1), array(2)]`. If `True`, then `[1, 2]` -> `array([1, 2])`.
+        safe: if `True`, then do safe dtype convert when intensity overflow. default to `False`.
+            E.g., `[256, -12]` -> `[array(0), array(244)]`. If `True`, then `[256, -12]` -> `[array(255), array(0)]`.

     Returns:
         modified data, orig_type, orig_device
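As the next hunk shows, `safe` is forwarded to whichever backend converter is selected; an illustrative call through the `np.ndarray` path (not part of the diff):

import numpy as np

from monai.utils.type_conversion import convert_data_type

out, orig_type, orig_device = convert_data_type(
    [256, -12], np.ndarray, dtype=np.uint8, wrap_sequence=True, safe=True
)
# out: array([255, 0], dtype=uint8); orig_type: list; orig_device: None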
@@ -288,20 +304,20 @@
     orig_device = data.device if isinstance(data, torch.Tensor) else None

     output_type = output_type or orig_type
-
     dtype_ = get_equivalent_dtype(dtype, output_type)

     data_: NdarrayTensor
-
     if issubclass(output_type, torch.Tensor):
         track_meta = issubclass(output_type, monai.data.MetaTensor)
-        data_ = convert_to_tensor(data, dtype=dtype_, device=device, wrap_sequence=wrap_sequence, track_meta=track_meta)
+        data_ = convert_to_tensor(
+            data, dtype=dtype_, device=device, wrap_sequence=wrap_sequence, track_meta=track_meta, safe=safe
+        )
         return data_, orig_type, orig_device
     if issubclass(output_type, np.ndarray):
-        data_ = convert_to_numpy(data, dtype=dtype_, wrap_sequence=wrap_sequence)
+        data_ = convert_to_numpy(data, dtype=dtype_, wrap_sequence=wrap_sequence, safe=safe)
         return data_, orig_type, orig_device
     elif has_cp and issubclass(output_type, cp.ndarray):
-        data_ = convert_to_cupy(data, dtype=dtype_, wrap_sequence=wrap_sequence)
+        data_ = convert_to_cupy(data, dtype=dtype_, wrap_sequence=wrap_sequence, safe=safe)
         return data_, orig_type, orig_device
     raise ValueError(f"Unsupported output type: {output_type}")

@@ -312,6 +328,7 @@ def convert_to_dst_type(
     dtype: Union[DtypeLike, torch.dtype, None] = None,
     wrap_sequence: bool = False,
     device: Union[None, str, torch.device] = None,
+    safe: bool = False,
 ) -> Tuple[NdarrayTensor, type, Optional[torch.device]]:
     """
     Convert source data to the same data type and device as the destination data.
@@ -326,6 +343,8 @@ def convert_to_dst_type(
         wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[array(1), array(2)]`.
             If `True`, then `[1, 2]` -> `array([1, 2])`.
         device: target device to put the converted Tensor data. If unspecified, `dst.device` will be used if possible.
+        safe: if `True`, then do safe dtype convert when intensity overflow. default to `False`.
+            E.g., `[256, -12]` -> `[array(0), array(244)]`. If `True`, then `[256, -12]` -> `[array(255), array(0)]`.

     See Also:
         :func:`convert_data_type`
@@ -349,7 +368,7 @@ def convert_to_dst_type(
     output_type = type(dst)
     output: NdarrayTensor
     output, _type, _device = convert_data_type(
-        data=src, output_type=output_type, device=device, dtype=dtype, wrap_sequence=wrap_sequence
+        data=src, output_type=output_type, device=device, dtype=dtype, wrap_sequence=wrap_sequence, safe=safe
     )
     if copy_meta and isinstance(output, monai.data.MetaTensor):
         output.copy_meta_from(dst)
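`convert_to_dst_type` gains the same flag; a small sketch (illustration only, values assumed from the clipping behaviour above):

import numpy as np
import torch

from monai.utils.type_conversion import convert_to_dst_type

src = np.array([300.0, -5.0])
dst = torch.zeros(2, dtype=torch.uint8)

# convert src to dst's type; safe=True clips to the uint8 range before the cast
out, *_ = convert_to_dst_type(src, dst, dtype=torch.uint8, safe=True)
# out: tensor([255,   0], dtype=torch.uint8)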
@@ -366,3 +385,74 @@ def convert_to_list(data: Union[Sequence, torch.Tensor, np.ndarray]) -> list:

     """
     return data.tolist() if isinstance(data, (torch.Tensor, np.ndarray)) else list(data)
+
+
+def get_dtype_bound_value(dtype: Union[DtypeLike, torch.dtype]):
+    """
+    Get dtype bound value
+    Args:
+        dtype: dtype to get bound value
+    Returns:
+        (bound_min_value, bound_max_value)
+    """
+    if dtype in UNSUPPORTED_TYPES:
+        is_floating_point = False
+    else:
+        is_floating_point = get_equivalent_dtype(dtype, torch.Tensor).is_floating_point
+    dtype = get_equivalent_dtype(dtype, np.array)
+    if is_floating_point:
+        return (np.finfo(dtype).min, np.finfo(dtype).max)  # type: ignore
+    else:
+        return (np.iinfo(dtype).min, np.iinfo(dtype).max)
+
+
+def safe_dtype_range(data: Any, dtype: Union[DtypeLike, torch.dtype] = None):
+    """
+    Utility to safely convert the input data to target dtype.
+
+    Args:
+        data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc.
+            will convert to target dtype and keep the original type.
+            for dictionary, list or tuple, convert every item.
+        dtype: target data type to convert.
+    """
+
+    def _safe_dtype_range(data, dtype):
+        output_dtype = dtype if dtype is not None else data.dtype
+        dtype_bound_value = get_dtype_bound_value(output_dtype)
+        if data.ndim == 0:
+            data_bound = (data, data)
+        else:
+            data_bound = (min(data), max(data))
+        if (data_bound[1] > dtype_bound_value[1]) or (data_bound[0] < dtype_bound_value[0]):
+            if isinstance(data, torch.Tensor):
+                return torch.clamp(data, dtype_bound_value[0], dtype_bound_value[1])
+            elif isinstance(data, np.ndarray):
+                return np.clip(data, dtype_bound_value[0], dtype_bound_value[1])
+            elif has_cp and isinstance(data, cp_ndarray):
+                return cp.clip(data, dtype_bound_value[0], dtype_bound_value[1])
+        else:
+            return data
+
+    if has_cp and isinstance(data, cp_ndarray):
+        return cp.asarray(_safe_dtype_range(data, dtype))
+    elif isinstance(data, np.ndarray):
+        return np.asarray(_safe_dtype_range(data, dtype))
+    elif isinstance(data, torch.Tensor):
+        return _safe_dtype_range(data, dtype)
+    elif isinstance(data, (float, int, bool)) and dtype is None:
+        return data
+    elif isinstance(data, (float, int, bool)) and dtype is not None:
+        output_dtype = dtype
+        dtype_bound_value = get_dtype_bound_value(output_dtype)
+        data = dtype_bound_value[1] if data > dtype_bound_value[1] else data
+        data = dtype_bound_value[0] if data < dtype_bound_value[0] else data
+        return data
+
+    elif isinstance(data, list):
+        return [safe_dtype_range(i, dtype=dtype) for i in data]
+    elif isinstance(data, tuple):
+        return tuple(safe_dtype_range(i, dtype=dtype) for i in data)
+    elif isinstance(data, dict):
+        return {k: safe_dtype_range(v, dtype=dtype) for k, v in data.items()}
+    return data
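To make the two new helpers concrete, a usage sketch (illustration only, assuming the behaviour implemented above): `safe_dtype_range` clips values to the target dtype's range but keeps the original container and dtype, while the actual cast still happens in the `convert_to_*` functions.

import numpy as np
import torch

from monai.utils.type_conversion import get_dtype_bound_value, safe_dtype_range

get_dtype_bound_value(np.uint8)       # (0, 255)
get_dtype_bound_value(torch.float16)  # (-65504.0, 65504.0)

safe_dtype_range(np.array([300, -20]), np.uint8)         # array([255,   0])
safe_dtype_range({"img": [300, -20]}, np.uint8)          # {'img': [255, 0]}
safe_dtype_range(torch.tensor([1.5, 2.5]), torch.uint8)  # unchanged, already within range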

tests/test_convert_data_type.py

Lines changed: 45 additions & 17 deletions
@@ -17,27 +17,45 @@
 from parameterized import parameterized

 from monai.data import MetaTensor
-from monai.utils.type_conversion import convert_data_type, convert_to_dst_type
-from tests.utils import TEST_NDARRAYS_ALL
+from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, get_equivalent_dtype
+from tests.utils import TEST_NDARRAYS_ALL, assert_allclose

 TESTS: List[Tuple] = []
 for in_type in TEST_NDARRAYS_ALL + (int, float):
     for out_type in TEST_NDARRAYS_ALL:
-        TESTS.append((in_type(np.array(1.0)), out_type(np.array(1.0))))  # type: ignore
+        TESTS.append((in_type(np.array(1.0)), out_type(np.array(1.0)), None, False))  # type: ignore
+        if in_type is not float:
+            TESTS.append((in_type(np.array(256)), out_type(np.array(255)), np.uint8, True))  # type: ignore

 TESTS_LIST: List[Tuple] = []
 for in_type in TEST_NDARRAYS_ALL + (int, float):
     for out_type in TEST_NDARRAYS_ALL:
         TESTS_LIST.append(
-            ([in_type(np.array(1.0)), in_type(np.array(1.0))], out_type(np.array([1.0, 1.0])), True)  # type: ignore
+            ([in_type(np.array(1.0)), in_type(np.array(1.0))], out_type(np.array([1.0, 1.0])), True, None, False)  # type: ignore
         )
         TESTS_LIST.append(
             (
                 [in_type(np.array(1.0)), in_type(np.array(1.0))],  # type: ignore
                 [out_type(np.array(1.0)), out_type(np.array(1.0))],
                 False,
+                None,
+                False,
             )
         )
+        if in_type is not float:
+            TESTS_LIST.append(
+                ([in_type(np.array(257)), in_type(np.array(1))], out_type(np.array([255, 1])), True, np.uint8, True)  # type: ignore
+            )
+            TESTS_LIST.append(
+                (
+                    [in_type(np.array(257)), in_type(np.array(-12))],  # type: ignore
+                    [out_type(np.array(255)), out_type(np.array(0))],
+                    False,
+                    np.uint8,
+                    True,
+                )
+            )
+

 UNSUPPORTED_TYPES = {np.dtype("uint16"): torch.int32, np.dtype("uint32"): torch.int64, np.dtype("uint64"): torch.int64}

@@ -48,17 +66,20 @@ class TestTensor(torch.Tensor):

 class TestConvertDataType(unittest.TestCase):
     @parameterized.expand(TESTS)
-    def test_convert_data_type(self, in_image, im_out):
-        converted_im, orig_type, orig_device = convert_data_type(in_image, type(im_out))
+    def test_convert_data_type(self, in_image, im_out, out_dtype, safe):
+        converted_im, orig_type, orig_device = convert_data_type(in_image, type(im_out), dtype=out_dtype, safe=safe)
         # check input is unchanged
         self.assertEqual(type(in_image), orig_type)
         if isinstance(in_image, torch.Tensor):
             self.assertEqual(in_image.device, orig_device)
         # check output is desired type
         self.assertEqual(type(converted_im), type(im_out))
+        # check data has been clipped
+        assert_allclose(converted_im, im_out)
         # check dtype is unchanged
-        if isinstance(in_type, (np.ndarray, torch.Tensor)):
-            self.assertEqual(converted_im.dtype, im_out.dtype)
+        if out_dtype is None:
+            if isinstance(in_image, (np.ndarray, torch.Tensor)):
+                self.assertEqual(converted_im.dtype, im_out.dtype)

     def test_neg_stride(self):
         _ = convert_data_type(np.array((1, 2))[::-1], torch.Tensor)
@@ -71,26 +92,32 @@ def test_unsupported_np_types(self, np_type, pt_type):
         self.assertEqual(converted_im.dtype, pt_type)

     @parameterized.expand(TESTS_LIST)
-    def test_convert_list(self, in_image, im_out, wrap):
+    def test_convert_list(self, in_image, im_out, wrap, out_dtype, safe):
         output_type = type(im_out) if wrap else type(im_out[0])
-        converted_im, *_ = convert_data_type(in_image, output_type, wrap_sequence=wrap)
+        converted_im, *_ = convert_data_type(in_image, output_type, wrap_sequence=wrap, dtype=out_dtype, safe=safe)
         # check output is desired type
         if not wrap:
             converted_im = converted_im[0]
             im_out = im_out[0]
         self.assertEqual(type(converted_im), type(im_out))
+        assert_allclose(converted_im, im_out)
         # check dtype is unchanged
-        if isinstance(in_type, (np.ndarray, torch.Tensor)):
-            self.assertEqual(converted_im.dtype, im_out.dtype)
+        if isinstance(in_image[0], (np.ndarray, torch.Tensor)):
+            if out_dtype is None:
+                self.assertEqual(converted_im.dtype, im_out.dtype)
+            else:
+                _out_dtype = get_equivalent_dtype(out_dtype, output_type)
+                self.assertEqual(converted_im.dtype, _out_dtype)


 class TestConvertDataSame(unittest.TestCase):
     # add test for subclass of Tensor
-    @parameterized.expand(TESTS + [(np.array(1.0), TestTensor(np.array(1.0)))])
-    def test_convert_data_type(self, in_image, im_out):
-        converted_im, orig_type, orig_device = convert_to_dst_type(in_image, im_out)
+    @parameterized.expand(TESTS + [(np.array(256), TestTensor(np.array([255])), torch.uint8, True)])
+    def test_convert_data_type(self, in_image, im_out, out_dtype, safe):
+        converted_im, orig_type, orig_device = convert_to_dst_type(in_image, im_out, dtype=out_dtype, safe=safe)
         # check input is unchanged
         self.assertEqual(type(in_image), orig_type)
+        assert_allclose(converted_im, im_out)
         if isinstance(in_image, torch.Tensor):
             self.assertEqual(in_image.device, orig_device)

@@ -103,8 +130,9 @@ def test_convert_data_type(self, in_image, im_out):
             output_type = np.ndarray
         self.assertEqual(type(converted_im), output_type)
         # check dtype is unchanged
-        if isinstance(in_type, (np.ndarray, torch.Tensor, MetaTensor)):
-            self.assertEqual(converted_im.dtype, im_out.dtype)
+        if out_dtype is None:
+            if isinstance(in_image, (np.ndarray, torch.Tensor, MetaTensor)):
+                self.assertEqual(converted_im.dtype, im_out.dtype)


 if __name__ == "__main__":
