Skip to content

Commit bb088ec

Browse files
authored
SlidingWindowInfererAdapt fixes (#6440)
- auto adjust buffer_dim for images with small last time, such as 768x768x128 CTs - fixes a small bug when buffered mode is not attempted ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: myron <amyronenko@nvidia.com>
1 parent 5f344cc commit bb088ec

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

monai/inferers/inferer.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,7 @@ def __call__(
461461

462462
device = kwargs.pop("device", self.device)
463463
buffer_steps = kwargs.pop("buffer_steps", self.buffer_steps)
464+
buffer_dim = kwargs.pop("buffer_dim", self.buffer_dim)
464465

465466
if device is None and self.cpu_thresh is not None and inputs.shape[2:].numel() > self.cpu_thresh:
466467
device = "cpu" # stitch in cpu memory if image is too large
@@ -481,7 +482,7 @@ def __call__(
481482
self.roi_weight_map,
482483
None,
483484
buffer_steps,
484-
self.buffer_dim,
485+
buffer_dim,
485486
*args,
486487
**kwargs,
487488
)
@@ -524,6 +525,12 @@ def __call__(
524525
gpu_stitching = inputs.is_cuda and not cpu_cond
525526
buffered_stitching = inputs.is_cuda and cpu_cond and not skip_buffer
526527
buffer_steps = max(1, self.buffer_steps) if self.buffer_steps is not None else 1
528+
buffer_dim = -1
529+
530+
sh = list(inputs.shape[2:])
531+
max_dim = sh.index(max(sh))
532+
if inputs.shape[max_dim + 2] / inputs.shape[-1] >= 2:
533+
buffer_dim = max_dim
527534

528535
for _ in range(10): # at most 10 trials
529536
try:
@@ -532,6 +539,7 @@ def __call__(
532539
network,
533540
device=inputs.device if gpu_stitching else torch.device("cpu"),
534541
buffer_steps=buffer_steps if buffered_stitching else None,
542+
buffer_dim=buffer_dim,
535543
*args,
536544
**kwargs,
537545
)
@@ -547,24 +555,23 @@ def __call__(
547555

548556
if skip_buffer:
549557
buffered_stitching = False
550-
logger.warning(f"GPU stitching failed, attempting on CPU, image dim {inputs.shape}..")
558+
logger.warning(f"GPU stitching failed, attempting on CPU, image dim {inputs.shape}.")
551559

552560
else:
553561
buffered_stitching = True
554562
self.buffer_steps = buffer_steps
555563
logger.warning(
556-
f"GPU stitching failed, attempting with buffer {buffer_steps}, image dim {inputs.shape}.."
564+
f"GPU stitching failed, buffer {buffer_steps} dim {buffer_dim}, image dim {inputs.shape}."
557565
)
558566
elif buffer_steps > 1:
559567
buffer_steps = max(1, buffer_steps // 2)
560568
self.buffer_steps = buffer_steps
561569
logger.warning(
562-
f"GPU buffered stitching failed, image dim {inputs.shape} reducing buffer to {buffer_steps}"
570+
f"GPU buffered stitching failed, image dim {inputs.shape} reducing buffer to {buffer_steps}."
563571
)
564572
else:
565573
buffered_stitching = False
566-
self.buffer_steps = 0 # disable future buffer attempts
567-
logger.warning(f"GPU buffered stitching failed, attempting on CPU, image dim {inputs.shape}")
574+
logger.warning(f"GPU buffered stitching failed, attempting on CPU, image dim {inputs.shape}.")
568575
raise RuntimeError( # not possible to finish after the trials
569576
f"SlidingWindowInfererAdapt {skip_buffer} {cpu_cond} {gpu_stitching} {buffered_stitching} {buffer_steps}"
570577
)

0 commit comments

Comments
 (0)