@@ -214,8 +214,11 @@ class SpatialCrop(Transform):
214214 """
215215 General purpose cropper to produce sub-volume region of interest (ROI).
216216 It can support to crop ND spatial (channel-first) data.
217- Either a spatial center and size must be provided, or alternatively,
218- if center and size are not provided, the start and end coordinates of the ROI must be provided.
217+
218+ The cropped region can be parameterised in various ways:
219+ - a list of slices for each spatial dimension (allows for use of -ve indexing and `None`)
220+ - a spatial center and size
221+ - the start and end coordinates of the ROI
219222 """
220223
221224 def __init__ (
@@ -224,35 +227,44 @@ def __init__(
224227 roi_size : Union [Sequence [int ], np .ndarray , None ] = None ,
225228 roi_start : Union [Sequence [int ], np .ndarray , None ] = None ,
226229 roi_end : Union [Sequence [int ], np .ndarray , None ] = None ,
230+ roi_slices : Optional [Sequence [slice ]] = None ,
227231 ) -> None :
228232 """
229233 Args:
230234 roi_center: voxel coordinates for center of the crop ROI.
231235 roi_size: size of the crop ROI.
232236 roi_start: voxel coordinates for start of the crop ROI.
233237 roi_end: voxel coordinates for end of the crop ROI.
238+ roi_slices: list of slices for each of the spatial dimensions.
234239 """
235- if roi_center is not None and roi_size is not None :
236- roi_center = np .asarray (roi_center , dtype = np .int16 )
237- roi_size = np .asarray (roi_size , dtype = np .int16 )
238- self .roi_start = np .maximum (roi_center - np .floor_divide (roi_size , 2 ), 0 )
239- self .roi_end = np .maximum (self .roi_start + roi_size , self .roi_start )
240+ if roi_slices :
241+ if not all (s .step is None or s .step == 1 for s in roi_slices ):
242+ raise ValueError ("Only slice steps of 1/None are currently supported" )
243+ self .slices = list (roi_slices )
240244 else :
241- if roi_start is None or roi_end is None :
242- raise ValueError ("Please specify either roi_center, roi_size or roi_start, roi_end." )
243- self .roi_start = np .maximum (np .asarray (roi_start , dtype = np .int16 ), 0 )
244- self .roi_end = np .maximum (np .asarray (roi_end , dtype = np .int16 ), self .roi_start )
245- # Allow for 1D by converting back to np.array (since np.maximum will convert to int)
246- self .roi_start = self .roi_start if isinstance (self .roi_start , np .ndarray ) else np .array ([self .roi_start ])
247- self .roi_end = self .roi_end if isinstance (self .roi_end , np .ndarray ) else np .array ([self .roi_end ])
245+ if roi_center is not None and roi_size is not None :
246+ roi_center = np .asarray (roi_center , dtype = np .int16 )
247+ roi_size = np .asarray (roi_size , dtype = np .int16 )
248+ roi_start_np = np .maximum (roi_center - np .floor_divide (roi_size , 2 ), 0 )
249+ roi_end_np = np .maximum (roi_start_np + roi_size , roi_start_np )
250+ else :
251+ if roi_start is None or roi_end is None :
252+ raise ValueError ("Please specify either roi_center, roi_size or roi_start, roi_end." )
253+ roi_start_np = np .maximum (np .asarray (roi_start , dtype = np .int16 ), 0 )
254+ roi_end_np = np .maximum (np .asarray (roi_end , dtype = np .int16 ), roi_start_np )
255+ # Allow for 1D by converting back to np.array (since np.maximum will convert to int)
256+ roi_start_np = roi_start_np if isinstance (roi_start_np , np .ndarray ) else np .array ([roi_start_np ])
257+ roi_end_np = roi_end_np if isinstance (roi_end_np , np .ndarray ) else np .array ([roi_end_np ])
258+ # convert to slices
259+ self .slices = [slice (s , e ) for s , e in zip (roi_start_np , roi_end_np )]
248260
249261 def __call__ (self , img : Union [np .ndarray , torch .Tensor ]):
250262 """
251263 Apply the transform to `img`, assuming `img` is channel-first and
252264 slicing doesn't apply to the channel dim.
253265 """
254- sd = min (self .roi_start . size , self . roi_end . size , len (img .shape [1 :])) # spatial dims
255- slices = [slice (None )] + [ slice ( s , e ) for s , e in zip ( self .roi_start [:sd ], self . roi_end [: sd ]) ]
266+ sd = min (len ( self .slices ) , len (img .shape [1 :])) # spatial dims
267+ slices = [slice (None )] + self .slices [:sd ]
256268 return img [tuple (slices )]
257269
258270
0 commit comments