@@ -105,6 +105,7 @@ def convert_to_tensor(
105105 device : Union [None , str , torch .device ] = None ,
106106 wrap_sequence : bool = False ,
107107 track_meta : bool = False ,
108+ safe : bool = False ,
108109):
109110 """
110111 Utility to convert the input data to a PyTorch Tensor, if `track_meta` is True, the output will be a `MetaTensor`,
@@ -121,6 +122,9 @@ def convert_to_tensor(
121122 E.g., `[1, 2]` -> `[tensor(1), tensor(2)]`. If `True`, then `[1, 2]` -> `tensor([1, 2])`.
122123 track_meta: whether to track the meta information, if `True`, will convert to `MetaTensor`.
123124 default to `False`.
125+ safe: if `True`, then do safe dtype convert when intensity overflow. default to `False`.
126+ E.g., `[256, -12]` -> `[tensor(0), tensor(244)]`.
127+ If `True`, then `[256, -12]` -> `[tensor(255), tensor(0)]`.
124128
125129 """
126130
@@ -138,6 +142,8 @@ def _convert_tensor(tensor, **kwargs):
138142 return tensor .as_tensor ()
139143 return tensor
140144
145+ if safe :
146+ data = safe_dtype_range (data , dtype )
141147 dtype = get_equivalent_dtype (dtype , torch .Tensor )
142148 if isinstance (data , torch .Tensor ):
143149 return _convert_tensor (data ).to (dtype = dtype , device = device , memory_format = torch .contiguous_format )
@@ -164,7 +170,7 @@ def _convert_tensor(tensor, **kwargs):
164170 return data
165171
166172
167- def convert_to_numpy (data , dtype : DtypeLike = None , wrap_sequence : bool = False ):
173+ def convert_to_numpy (data , dtype : DtypeLike = None , wrap_sequence : bool = False , safe : bool = False ):
168174 """
169175 Utility to convert the input data to a numpy array. If passing a dictionary, list or tuple,
170176 recursively check every item and convert it to numpy array.
@@ -176,7 +182,11 @@ def convert_to_numpy(data, dtype: DtypeLike = None, wrap_sequence: bool = False)
176182 dtype: target data type when converting to numpy array.
177183 wrap_sequence: if `False`, then lists will recursively call this function.
178184 E.g., `[1, 2]` -> `[array(1), array(2)]`. If `True`, then `[1, 2]` -> `array([1, 2])`.
185+ safe: if `True`, then do safe dtype convert when intensity overflow. default to `False`.
186+ E.g., `[256, -12]` -> `[array(0), array(244)]`. If `True`, then `[256, -12]` -> `[array(255), array(0)]`.
179187 """
188+ if safe :
189+ data = safe_dtype_range (data , dtype )
180190 if isinstance (data , torch .Tensor ):
181191 data = np .asarray (data .detach ().to (device = "cpu" ).numpy (), dtype = get_equivalent_dtype (dtype , np .ndarray ))
182192 elif has_cp and isinstance (data , cp_ndarray ):
@@ -205,7 +215,7 @@ def convert_to_numpy(data, dtype: DtypeLike = None, wrap_sequence: bool = False)
205215 return data
206216
207217
208- def convert_to_cupy (data , dtype : Optional [np .dtype ] = None , wrap_sequence : bool = False ):
218+ def convert_to_cupy (data , dtype : Optional [np .dtype ] = None , wrap_sequence : bool = False , safe : bool = False ):
209219 """
210220 Utility to convert the input data to a cupy array. If passing a dictionary, list or tuple,
211221 recursively check every item and convert it to cupy array.
@@ -218,8 +228,11 @@ def convert_to_cupy(data, dtype: Optional[np.dtype] = None, wrap_sequence: bool
218228 for more details: https://docs.cupy.dev/en/stable/reference/generated/cupy.array.html.
219229 wrap_sequence: if `False`, then lists will recursively call this function.
220230 E.g., `[1, 2]` -> `[array(1), array(2)]`. If `True`, then `[1, 2]` -> `array([1, 2])`.
231+ safe: if `True`, then do safe dtype convert when intensity overflow. default to `False`.
232+ E.g., `[256, -12]` -> `[array(0), array(244)]`. If `True`, then `[256, -12]` -> `[array(255), array(0)]`.
221233 """
222-
234+ if safe :
235+ data = safe_dtype_range (data , dtype )
223236 # direct calls
224237 if isinstance (data , (cp_ndarray , np .ndarray , torch .Tensor , float , int , bool )):
225238 data = cp .asarray (data , dtype )
@@ -246,6 +259,7 @@ def convert_data_type(
246259 device : Union [None , str , torch .device ] = None ,
247260 dtype : Union [DtypeLike , torch .dtype ] = None ,
248261 wrap_sequence : bool = False ,
262+ safe : bool = False ,
249263) -> Tuple [NdarrayTensor , type , Optional [torch .device ]]:
250264 """
251265 Convert to `MetaTensor`, `torch.Tensor` or `np.ndarray` from `MetaTensor`, `torch.Tensor`,
@@ -260,6 +274,8 @@ def convert_data_type(
260274 If left blank, it remains unchanged.
261275 wrap_sequence: if `False`, then lists will recursively call this function.
262276 E.g., `[1, 2]` -> `[array(1), array(2)]`. If `True`, then `[1, 2]` -> `array([1, 2])`.
277+ safe: if `True`, then do safe dtype convert when intensity overflow. default to `False`.
278+ E.g., `[256, -12]` -> `[array(0), array(244)]`. If `True`, then `[256, -12]` -> `[array(255), array(0)]`.
263279
264280 Returns:
265281 modified data, orig_type, orig_device
@@ -288,20 +304,20 @@ def convert_data_type(
288304 orig_device = data .device if isinstance (data , torch .Tensor ) else None
289305
290306 output_type = output_type or orig_type
291-
292307 dtype_ = get_equivalent_dtype (dtype , output_type )
293308
294309 data_ : NdarrayTensor
295-
296310 if issubclass (output_type , torch .Tensor ):
297311 track_meta = issubclass (output_type , monai .data .MetaTensor )
298- data_ = convert_to_tensor (data , dtype = dtype_ , device = device , wrap_sequence = wrap_sequence , track_meta = track_meta )
312+ data_ = convert_to_tensor (
313+ data , dtype = dtype_ , device = device , wrap_sequence = wrap_sequence , track_meta = track_meta , safe = safe
314+ )
299315 return data_ , orig_type , orig_device
300316 if issubclass (output_type , np .ndarray ):
301- data_ = convert_to_numpy (data , dtype = dtype_ , wrap_sequence = wrap_sequence )
317+ data_ = convert_to_numpy (data , dtype = dtype_ , wrap_sequence = wrap_sequence , safe = safe )
302318 return data_ , orig_type , orig_device
303319 elif has_cp and issubclass (output_type , cp .ndarray ):
304- data_ = convert_to_cupy (data , dtype = dtype_ , wrap_sequence = wrap_sequence )
320+ data_ = convert_to_cupy (data , dtype = dtype_ , wrap_sequence = wrap_sequence , safe = safe )
305321 return data_ , orig_type , orig_device
306322 raise ValueError (f"Unsupported output type: { output_type } " )
307323
@@ -312,6 +328,7 @@ def convert_to_dst_type(
312328 dtype : Union [DtypeLike , torch .dtype , None ] = None ,
313329 wrap_sequence : bool = False ,
314330 device : Union [None , str , torch .device ] = None ,
331+ safe : bool = False ,
315332) -> Tuple [NdarrayTensor , type , Optional [torch .device ]]:
316333 """
317334 Convert source data to the same data type and device as the destination data.
@@ -326,6 +343,8 @@ def convert_to_dst_type(
326343 wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[array(1), array(2)]`.
327344 If `True`, then `[1, 2]` -> `array([1, 2])`.
328345 device: target device to put the converted Tensor data. If unspecified, `dst.device` will be used if possible.
346+ safe: if `True`, then do safe dtype convert when intensity overflow. default to `False`.
347+ E.g., `[256, -12]` -> `[array(0), array(244)]`. If `True`, then `[256, -12]` -> `[array(255), array(0)]`.
329348
330349 See Also:
331350 :func:`convert_data_type`
@@ -349,7 +368,7 @@ def convert_to_dst_type(
349368 output_type = type (dst )
350369 output : NdarrayTensor
351370 output , _type , _device = convert_data_type (
352- data = src , output_type = output_type , device = device , dtype = dtype , wrap_sequence = wrap_sequence
371+ data = src , output_type = output_type , device = device , dtype = dtype , wrap_sequence = wrap_sequence , safe = safe
353372 )
354373 if copy_meta and isinstance (output , monai .data .MetaTensor ):
355374 output .copy_meta_from (dst )
@@ -366,3 +385,74 @@ def convert_to_list(data: Union[Sequence, torch.Tensor, np.ndarray]) -> list:
366385
367386 """
368387 return data .tolist () if isinstance (data , (torch .Tensor , np .ndarray )) else list (data )
388+
389+
390+ def get_dtype_bound_value (dtype : Union [DtypeLike , torch .dtype ]):
391+ """
392+ Get dtype bound value
393+ Args:
394+ dtype: dtype to get bound value
395+ Returns:
396+ (bound_min_value, bound_max_value)
397+ """
398+ if dtype in UNSUPPORTED_TYPES :
399+ is_floating_point = False
400+ else :
401+ is_floating_point = get_equivalent_dtype (dtype , torch .Tensor ).is_floating_point
402+ dtype = get_equivalent_dtype (dtype , np .array )
403+ if is_floating_point :
404+ return (np .finfo (dtype ).min , np .finfo (dtype ).max ) # type: ignore
405+ else :
406+ return (np .iinfo (dtype ).min , np .iinfo (dtype ).max )
407+
408+
409+ def safe_dtype_range (data : Any , dtype : Union [DtypeLike , torch .dtype ] = None ):
410+ """
411+ Utility to safely convert the input data to target dtype.
412+
413+ Args:
414+ data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc.
415+ will convert to target dtype and keep the original type.
416+ for dictionary, list or tuple, convert every item.
417+ dtype: target data type to convert.
418+ """
419+
420+ def _safe_dtype_range (data , dtype ):
421+ output_dtype = dtype if dtype is not None else data .dtype
422+ dtype_bound_value = get_dtype_bound_value (output_dtype )
423+ if data .ndim == 0 :
424+ data_bound = (data , data )
425+ else :
426+ data_bound = (min (data ), max (data ))
427+ if (data_bound [1 ] > dtype_bound_value [1 ]) or (data_bound [0 ] < dtype_bound_value [0 ]):
428+ if isinstance (data , torch .Tensor ):
429+ return torch .clamp (data , dtype_bound_value [0 ], dtype_bound_value [1 ])
430+ elif isinstance (data , np .ndarray ):
431+ return np .clip (data , dtype_bound_value [0 ], dtype_bound_value [1 ])
432+ elif has_cp and isinstance (data , cp_ndarray ):
433+ return cp .clip (data , dtype_bound_value [0 ], dtype_bound_value [1 ])
434+ else :
435+ return data
436+
437+ if has_cp and isinstance (data , cp_ndarray ):
438+ return cp .asarray (_safe_dtype_range (data , dtype ))
439+ elif isinstance (data , np .ndarray ):
440+ return np .asarray (_safe_dtype_range (data , dtype ))
441+ elif isinstance (data , torch .Tensor ):
442+ return _safe_dtype_range (data , dtype )
443+ elif isinstance (data , (float , int , bool )) and dtype is None :
444+ return data
445+ elif isinstance (data , (float , int , bool )) and dtype is not None :
446+ output_dtype = dtype
447+ dtype_bound_value = get_dtype_bound_value (output_dtype )
448+ data = dtype_bound_value [1 ] if data > dtype_bound_value [1 ] else data
449+ data = dtype_bound_value [0 ] if data < dtype_bound_value [0 ] else data
450+ return data
451+
452+ elif isinstance (data , list ):
453+ return [safe_dtype_range (i , dtype = dtype ) for i in data ]
454+ elif isinstance (data , tuple ):
455+ return tuple (safe_dtype_range (i , dtype = dtype ) for i in data )
456+ elif isinstance (data , dict ):
457+ return {k : safe_dtype_range (v , dtype = dtype ) for k , v in data .items ()}
458+ return data
0 commit comments