
Commit 0066971

feat: add mixed precision inference
1 parent: e494dc9 · commit: 0066971

4 files changed: +40 −15 lines


cellseg_models_pytorch/inference/_base_inferer.py

Lines changed: 21 additions & 8 deletions

@@ -203,19 +203,32 @@ def from_yaml(cls, model: nn.Module, yaml_path: str):
     def _infer_batch(self):
         raise NotImplementedError
 
-    def infer(self) -> None:
+    def infer(self, mixed_precision: bool = False) -> None:
         """Run inference and post-processing for the images.
 
         NOTE:
-        - Saves outputs in `self.out_masks` or to disk (.mat/.json) files.
-        - If `save_intermediate` is set to True, also intermediate model outputs are
-          saved to `self.soft_masks`
-        - `self.out_masks` and `self.soft_masks` are nested dicts: E.g.
-          {"sample1": {"inst": [H, W], "type": [H, W], "sem": [H, W]}}
-        - If masks are saved to geojson .json files, more keyword arguments
+        - Saves outputs in class attributes or to disk (.mat/.json) files.
+        - If masks are saved to .json (geojson) files, more keyword arguments
           need to be given at class initialization. Namely: `geo_format`,
           `classes_type`, `classes_sem`, `offsets`. See more in the
           `FileHandler.save_masks` docs.
+
+        Attributes
+        ----------
+        - out_masks : Dict[str, Dict[str, np.ndarray]]
+            The output masks for each image. The keys are the image names and the
+            values are dictionaries of the masks. E.g.
+            {"sample1": {"inst": [H, W], "type": [H, W], "sem": [H, W]}}
+        - soft_masks : Dict[str, Dict[str, np.ndarray]]
+            NOTE: This attribute is set only if `save_intermediate = True`.
+            The soft masks for each image, i.e. the soft predictions of the trained
+            model. The keys are the image names and the values are dictionaries of
+            the soft masks. E.g. {"sample1": {"type": [H, W], "aux": [C, H, W]}}
+
+        Parameters
+        ----------
+        mixed_precision : bool, default=False
+            If True, inference is performed with mixed precision.
         """
         self.soft_masks = {}
         self.out_masks = {}
@@ -227,7 +240,7 @@ def infer(self) -> None:
             names = data["file"]
             loader.set_description("Running inference")
             loader.set_postfix_str("Forward pass")
-            soft_masks = self._infer_batch(data["im"])
+            soft_masks = self._infer_batch(data["im"], mixed_precision)
             loader.set_postfix_str("post-processing")
             soft_masks = self._prepare_mask_list(names, soft_masks)
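
The net effect of this file's change is a single opt-in flag on the public `infer()` call. Below is a minimal usage sketch, assuming `ResizeInferer` is exported from `cellseg_models_pytorch.inference` as the file layout suggests; the model and the constructor arguments shown are placeholders that depend on your own setup and are not defined by this commit, only `mixed_precision=True` is what this change introduces.

# Sketch only: `model` and the constructor arguments are placeholders.
from cellseg_models_pytorch.inference import ResizeInferer

model = ...  # a trained segmentation nn.Module (placeholder)
inferer = ResizeInferer(
    model=model,
    input_path="/path/to/images",  # placeholder argument
    resize=(256, 256),             # placeholder argument
)

# New in this commit: wrap every forward pass in torch.autocast (fp16).
inferer.infer(mixed_precision=True)

# Outputs land in the attributes documented in the new docstring:
# inferer.out_masks, and inferer.soft_masks if save_intermediate=True.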

cellseg_models_pytorch/inference/predictor.py

Lines changed: 11 additions & 3 deletions

@@ -108,7 +108,10 @@ def __init__(
         )
 
     def forward_pass(
-        self, patch: Union[np.ndarray, torch.Tensor], in_dim_format: str = "HWC"
+        self,
+        patch: Union[np.ndarray, torch.Tensor],
+        in_dim_format: str = "HWC",
+        mixed_precision: bool = False,
     ) -> Dict[str, torch.Tensor]:
         """Input an image patch or batch of patches to the network and return logits.
 
@@ -119,7 +122,8 @@ def forward_pass(
         in_dim_format : str, default="HWC"
             The order of the dimensions in the input array.
             One of: "HWC", "BHWC"
-
+        mixed_precision : bool, default=False
+            Use mixed precision for inference.
 
         Returns
         -------
@@ -145,7 +149,11 @@ def forward_pass(
         patch = patch.float()
 
         with torch.no_grad():
-            out = self.model(patch)
+            if mixed_precision:
+                with torch.autocast(device_type="cuda", dtype=torch.float16):
+                    out = self.model(patch)
+            else:
+                out = self.model(patch)
 
         return out
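
The new branch in `forward_pass` is the standard PyTorch autocast pattern. The self-contained sketch below reproduces it with a throwaway conv layer (a stand-in, not part of the library) to show what the flag changes: on a CUDA device the forward pass runs under autocast and returns float16 activations, while without a GPU it simply stays in full precision, mirroring the else-branch above.

import torch
import torch.nn as nn

# throwaway stand-in for self.model; not part of cellseg_models_pytorch
model = nn.Conv2d(3, 8, kernel_size=3, padding=1).eval()
patch = torch.rand(1, 3, 64, 64)

if torch.cuda.is_available():
    model, patch = model.cuda(), patch.cuda()
    with torch.no_grad():
        # same construct the commit adds: fp16 autocast around the forward pass
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            out = model(patch)
    print(out.dtype)  # torch.float16
else:
    with torch.no_grad():
        out = model(patch)
    print(out.dtype)  # torch.float32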

cellseg_models_pytorch/inference/resize_inferer.py

Lines changed: 4 additions & 2 deletions

@@ -139,7 +139,9 @@ def __init__(
             **kwargs,
         )
 
-    def _infer_batch(self, input_batch: torch.Tensor) -> Dict[str, torch.Tensor]:
+    def _infer_batch(
+        self, input_batch: torch.Tensor, mixed_precision: bool = False
+    ) -> Dict[str, torch.Tensor]:
         """Infer one batch of images."""
         inp_shape = tuple(input_batch.shape[2:])
 
@@ -154,7 +156,7 @@ def _infer_batch(self, input_batch: torch.Tensor) -> Dict[str, torch.Tensor]:
             input_batch = F.interpolate(input_batch, self.patch_size)
 
         batch = input_batch.to(self.device).float()
-        logits = self.predictor.forward_pass(batch)
+        logits = self.predictor.forward_pass(batch, mixed_precision=mixed_precision)
 
         probs = {}
         for k, logit in logits.items():

cellseg_models_pytorch/inference/sliding_window_inferer.py

Lines changed: 4 additions & 2 deletions

@@ -193,7 +193,9 @@ def _get_slices(
 
         return xyslices, pady, padx
 
-    def _infer_batch(self, input_batch: torch.Tensor) -> Dict[str, torch.Tensor]:
+    def _infer_batch(
+        self, input_batch: torch.Tensor, mixed_precision: bool = False
+    ) -> Dict[str, torch.Tensor]:
         """Infer one batch of images."""
         slices, pady, padx = self._get_slices(
             self.stride, self.patch_size, tuple(input_batch.shape[2:]), self.padding
@@ -235,7 +237,7 @@ def _infer_batch(self, input_batch: torch.Tensor) -> Dict[str, torch.Tensor]:
         # run inference with the slices
         for k, (yslice, xslice) in slices.items():
            batch = input_batch[..., yslice, xslice].to(self.device).float()
-            logits = self.predictor.forward_pass(batch)
+            logits = self.predictor.forward_pass(batch, mixed_precision=mixed_precision)
 
             probs = {}
             for k, logit in logits.items():
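
In both inferers the flag is simply threaded through to `Predictor.forward_pass`, so in the sliding-window case autocast is entered once per tile. The generic, self-contained sketch below mimics that per-tile loop with a toy model and a hard-coded two-tile split (neither is the library's actual slicing logic) to make the call pattern concrete.

import torch
import torch.nn as nn

# toy stand-ins; the real inferer computes its tiles via _get_slices
model = nn.Conv2d(3, 4, kernel_size=3, padding=1).eval()
image = torch.rand(1, 3, 64, 128)
tiles = [(slice(0, 64), slice(0, 64)), (slice(0, 64), slice(64, 128))]
mixed_precision = torch.cuda.is_available()
if mixed_precision:
    model = model.cuda()

outputs = []
for yslice, xslice in tiles:
    batch = image[..., yslice, xslice]
    if mixed_precision:
        batch = batch.cuda()
        # autocast wraps each tile's forward pass, as in the changed code above
        with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.float16):
            outputs.append(model(batch))
    else:
        with torch.no_grad():
            outputs.append(model(batch))

print([o.dtype for o in outputs])  # float16 per tile on GPU, float32 on CPU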
