1- import warnings
21from abc import ABC , abstractmethod
32from collections import OrderedDict
43from itertools import chain
@@ -33,9 +32,6 @@ def __init__(
3332 normalization : str = None ,
3433 device : str = "cuda" ,
3534 n_devices : int = 1 ,
36- save_intermediate : bool = False ,
37- save_dir : Union [Path , str ] = None ,
38- save_format : str = ".mat" ,
3935 checkpoint_path : Union [Path , str ] = None ,
4036 n_images : int = None ,
4137 type_post_proc : Callable = None ,
@@ -46,57 +42,49 @@ def __init__(
4642
4743 Parameters
4844 ----------
49- model : nn.Module
50- A segmentation model.
51- input_path : Path | str
52- Path to a folder of images or to hdf5 db.
53- out_activations : Dict[str, str]
54- Dictionary of head names mapped to a string value that specifies the
55- activation applied at the head. E.g. {"type": "tanh", "cellpose": None}
56- Allowed values: "softmax", "sigmoid", "tanh", None.
57- out_boundary_weights : Dict[str, bool]
58- Dictionary of head names mapped to a boolean value. If the value is
59- True, after a prediction, a weight matrix is applied that assigns bigger
60- weight on pixels in the center and less weight to pixels on the tile
61- boundaries. helps dealing with prediction artefacts on the boundaries.
62- E.g. {"type": False, "cellpose": True}
63- patch_size : Tuple[int, int]:
64- The size of the input patches that are fed to the segmentation model.
65- instance_postproc : str
66- The post-processing method for the instance segmentation mask. One of:
67- "cellpose", "omnipose", "stardist", "hovernet", "dcan", "drfns", "dran"
68- padding : int, optional
69- The amount of reflection padding for the input images.
70- batch_size : int, default=8
71- Number of images loaded from the folder at every batch.
72- normalization : str, optional
73- Apply img normalization at forward pass (Same as during training).
74- One of: "dataset", "minmax", "norm", "percentile", None.
75- device : str, default="cuda"
76- The device of the input and model. One of: "cuda", "cpu"
77- n_devices : int, default=1
78- Number of devices (cpus/gpus) used for inference.
79- The model will be copied into these devices.
80- save_dir : bool, optional
81- Path to save directory. If None, no masks will be saved to disk as .mat
82- or .json files. Instead the masks will be saved in `self.out_masks`.
83- save_intermediate : bool, default=False
84- If True, intermediate soft masks will be saved into `soft_masks` var.
85- save_format : str, default=".mat"
86- The file format for the saved output masks. One of (".mat", ".json").
87- The ".json" option will save masks into geojson format.
88- checkpoint_path : Path | str, optional
89- Path to the model weight checkpoints.
90- n_images : int, optional
91- First n-number of images used from the `input_path`.
92- type_post_proc : Callable, optional
93- A post-processing function for the type maps. If not None, overrides
94- the default.
95- sem_post_proc : Callable, optional
96- A post-processing function for the semantc seg maps. If not None,
97- overrides the default.
98- **kwargs:
99- Arbitrary keyword arguments expecially for post-processing and saving.
45+ model : nn.Module
46+ A segmentation model.
47+ input_path : Path | str
48+ Path to a folder of images or to hdf5 db.
49+ out_activations : Dict[str, str]
50+ Dictionary of head names mapped to a string value that specifies the
51+ activation applied at the head. E.g. {"type": "tanh", "cellpose": None}
52+ Allowed values: "softmax", "sigmoid", "tanh", None.
53+ out_boundary_weights : Dict[str, bool]
54+ Dictionary of head names mapped to a boolean value. If the value is
55+ True, after a prediction, a weight matrix is applied that assigns bigger
56+ weight on pixels in the center and less weight to pixels on the tile
57+ boundaries. Helps deal with prediction artefacts on the boundaries.
58+ E.g. {"type": False, "cellpose": True}
59+ patch_size : Tuple[int, int]:
60+ The size of the input patches that are fed to the segmentation model.
61+ instance_postproc : str
62+ The post-processing method for the instance segmentation mask. One of:
63+ "cellpose", "omnipose", "stardist", "hovernet", "dcan", "drfns", "dran"
64+ padding : int, optional
65+ The amount of reflection padding for the input images.
66+ batch_size : int, default=8
67+ Number of images loaded from the folder at every batch.
68+ normalization : str, optional
69+ Apply img normalization at forward pass (Same as during training).
70+ One of: "dataset", "minmax", "norm", "percentile", None.
71+ device : str, default="cuda"
72+ The device of the input and model. One of: "cuda", "cpu"
73+ n_devices : int, default=1
74+ Number of devices (cpus/gpus) used for inference.
75+ The model will be copied into these devices.
76+ checkpoint_path : Path | str, optional
77+ Path to the model weight checkpoints.
78+ n_images : int, optional
79+ First n-number of images used from the `input_path`.
80+ type_post_proc : Callable, optional
81+ A post-processing function for the type maps. If not None, overrides
82+ the default.
83+ sem_post_proc : Callable, optional
84+ A post-processing function for the semantic seg maps. If not None,
85+ overrides the default.
86+ **kwargs:
87+ Arbitrary keyword arguments for post-processing.
10088 """
10189 # basic inits
10290 self .model = model
@@ -109,22 +97,10 @@ def __init__(
10997 self .head_kwargs = self ._check_and_set_head_args ()
11098 self .kwargs = kwargs
11199
112- self .save_dir = Path (save_dir ) if save_dir is not None else None
113- self .save_intermediate = save_intermediate
114- self .save_format = save_format
115-
116100 # dataset & dataloader
117101 self .path = Path (input_path )
118102 if self .path .is_dir ():
119103 ds = FolderDatasetInfer (self .path , n_images = n_images )
120- if self .save_dir is None and len (ds .fnames ) > 40 and n_images is None :
121- warnings .warn (
122- "`save_dir` is None. Thus, the outputs are be saved in `out_masks` "
123- "class attribute. If the input folder contains many images, running"
124- " inference will likely flood the memory depending on the size and "
125- "number of the images. Consider saving outputs to disk by providing"
126- " `save_dir` argument."
127- )
128104 elif self .path .is_file () and self .path .suffix in (".h5" , ".hdf5" ):
129105 from .hdf5_dataset_infer import HDF5DatasetInfer
130106
@@ -167,10 +143,10 @@ def __init__(
167143
168144 # try loading the weights to the model
169145 try :
170- msg = self .model .load_state_dict (state_dict , strict = True )
146+ msg = self .model .load_state_dict (state_dict , strict = False )
171147 except RuntimeError :
172148 new_ckpt = self ._strip_state_dict (state_dict )
173- msg = self .model .load_state_dict (new_ckpt , strict = True )
149+ msg = self .model .load_state_dict (new_ckpt , strict = False )
174150 except BaseException as e :
175151 raise RuntimeError (f"Error when loading checkpoint: { e } " )
176152
@@ -218,34 +194,74 @@ def from_yaml(cls, model: nn.Module, yaml_path: str):
218194 def _infer_batch (self ):
219195 raise NotImplementedError
220196
221- def infer (self , mixed_precision : bool = False ) -> None :
222- """Run inference and post-processing for the images.
197+ def infer (
198+ self ,
199+ save_dir : Union [Path , str ] = None ,
200+ save_format : str = ".mat" ,
201+ save_intermediate : bool = False ,
202+ classes_type : Dict [str , int ] = None ,
203+ classes_sem : Dict [str , int ] = None ,
204+ offsets : bool = False ,
205+ mixed_precision : bool = False ,
206+ ) -> None :
207+ """Run inference and post-processing for the image(s) inside `input_path`.
223208
224- NOTE:
225- - Saves outputs in class attributes or to disk (.mat/.json) files.
226- - If masks are saved to .json (geojson) files, more key word arguments
227- need to be given at class initialization. Namely: `geo_format`,
228- `classes_type`, `classes_sem`, `offsets`. See more in the
229- `FileHandler.save_masks` docs.
209+ NOTE: If `save_dir` is None, the output masks will be cached in a class
210+ attribute `self.out_masks`. Otherwise the masks will be saved to disk.
230211
212+ WARNING: Running inference without setting `save_dir` can take a lot of memory
213+ if the input directory contains many images.
231214
232215 Parameters
233216 ----------
234- mixed_precision : bool, default=False
235- If True, inference is performed with mixed precision.
217+ save_dir : Path | str, optional
218+ Path to save directory. If None, no masks will be saved to disk.
219+ Instead the masks will be cached in a class attribute `self.out_masks`.
220+ save_format : str, default=".mat"
221+ The file format for the saved output masks. One of ".mat", ".geojson",
222+ ".feather", ".parquet".
223+ save_intermediate : bool, default=False
224+ If True, intermediate soft masks will be saved into `self.soft_masks`
225+ class attribute. WARNING: This can take a lot of memory if the input
226+ directory contains many images.
227+ classes_type : Dict[str, int], optional
228+ Cell type dictionary. e.g. {"inflam":1, "epithelial":2, "connec":3}.
229+ This is required only if `save_format` is one of the following formats:
230+ ".geojson", ".parquet", ".feather".
231+ classes_sem : Dict[str, int], optional
232+ Tissue type dictionary. e.g. {"tissue1":1, "tissue2":2, "tissue3":3}
233+ This is required only if `save_format` is one of the following formats:
234+ ".geojson", ".parquet", ".feather".
235+ offsets : bool, default=False
236+ If True, geojson coords are shifted by the offsets that are encoded in
237+ the filenames (e.g. "x-1000_y-4000.png"). Ignored if `save_format` is ".mat".
238+ mixed_precision : bool, default=False
239+ If True, inference is performed with mixed precision.
236240
237241 Attributes
238242 ----------
239- - out_masks : Dict[str, Dict[str, np.ndarray]]
240- The output masks for each image. The keys are the image names and the
241- values are dictionaries of the masks. E.g.
242- {"sample1": {"inst": [H, W], "type": [H, W], "sem": [H, W]}}
243- - soft_masks : Dict[str, Dict[str, np.ndarray]]
244- NOTE: This attribute is set only if `save_intermediate = True`.
245- The soft masks for each image. I.e. the soft predictions of the trained
246- model The keys are the image names and the values are dictionaries of
247- the soft masks. E.g. {"sample1": {"type": [H, W], "aux": [C, H, W]}}
243+ - out_masks : Dict[str, Dict[str, np.ndarray]]
244+ The output masks for each image. The keys are the image names and the
245+ values are dictionaries of the masks. E.g.
246+ {"sample1": {"inst": [H, W], "type": [H, W], "sem": [H, W]}}
247+ - soft_masks : Dict[str, Dict[str, np.ndarray]]
248+ NOTE: This attribute is set only if `save_intermediate = True`.
249+ The soft masks for each image. I.e. the soft predictions of the trained
250+ model. The keys are the image names and the values are dictionaries of
251+ the soft masks. E.g. {"sample1": {"type": [H, W], "aux": [C, H, W]}}
248252 """
253+ # check save_dir and save_format
254+ save_dir = Path (save_dir ) if save_dir is not None else None
255+ save_intermediate = save_intermediate
256+ save_format = save_format
257+ if save_dir is not None :
258+ allowed_formats = (".mat" , ".geojson" , ".feather" , ".parquet" )
259+ if save_format not in allowed_formats :
260+ raise ValueError (
261+ f"Given `save_format`: { save_format } is not one of the allowed "
262+ f"formats: { allowed_formats } "
263+ )
264+
249265 self .soft_masks = {}
250266 self .out_masks = {}
251267 self .elapsed = []
@@ -271,7 +287,7 @@ def infer(self, mixed_precision: bool = False) -> None:
271287 self .elapsed .append (loader .format_dict ["elapsed" ])
272288 self .rate .append (loader .format_dict ["rate" ])
273289
274- if self . save_intermediate :
290+ if save_intermediate :
275291 for n , m in zip (names , soft_masks ):
276292 self .soft_masks [n ] = m
277293
@@ -283,25 +299,33 @@ def infer(self, mixed_precision: bool = False) -> None:
283299 seg ["soft_sem" ] = soft ["sem" ]
284300
285301 # save to cache or disk
286- if self . save_dir is None :
302+ if save_dir is None :
287303 for n , m in zip (names , seg_results ):
288304 self .out_masks [n ] = m
289305 else :
290306 loader .set_postfix_str ("Saving results to disk" )
291307 if self .batch_size > 1 :
292- fnames = [Path (self . save_dir ) / n for n in names ]
308+ fnames = [Path (save_dir ) / n for n in names ]
293309 FileHandler .save_masks_parallel (
294310 maps = seg_results ,
295311 fnames = fnames ,
296- ** {** self .kwargs , "format" : self .save_format },
312+ format = save_format ,
313+ classes_type = classes_type ,
314+ classes_sem = classes_sem ,
315+ offsets = offsets ,
316+ pooltype = "thread" ,
317+ maptype = "amap" ,
297318 )
298319 else :
299320 for n , m in zip (names , seg_results ):
300- fname = Path (self . save_dir ) / n
321+ fname = Path (save_dir ) / n
301322 FileHandler .save_masks (
302323 fname = fname ,
303324 maps = m ,
304- ** {** self .kwargs , "format" : self .save_format },
325+ format = save_format ,
326+ classes_type = classes_type ,
327+ classes_sem = classes_sem ,
328+ offsets = offsets ,
305329 )
306330
307331 def _strip_state_dict (self , ckpt : Dict ) -> OrderedDict :
0 commit comments