Skip to content

Commit 7a760e6

Browse files
authored
6766 data analyser label argmax (#6852)
Fixes #6766 ### Description the label might have been processed during the GPU transform, in the retry on CPU, the argmax should be skipped in this case. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Wenqi Li <wenqil@nvidia.com>
1 parent 3990cd4 commit 7a760e6

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

monai/apps/auto3dseg/data_analyzer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,9 +332,11 @@ def _get_all_case_stats(
332332
batch_data = batch_data[0]
333333
try:
334334
batch_data[self.image_key] = batch_data[self.image_key].to(device)
335+
_label_argmax = False
335336
if self.label_key is not None:
336337
label = batch_data[self.label_key]
337338
label = torch.argmax(label, dim=0) if label.shape[0] > 1 else label[0]
339+
_label_argmax = True # track if label is argmaxed
338340
batch_data[self.label_key] = label.to(device)
339341
d = summarizer(batch_data)
340342
except BaseException as err:
@@ -348,7 +350,8 @@ def _get_all_case_stats(
348350
batch_data[self.image_key] = batch_data[self.image_key].to("cpu")
349351
if self.label_key is not None:
350352
label = batch_data[self.label_key]
351-
label = torch.argmax(label, dim=0) if label.shape[0] > 1 else label[0]
353+
if not _label_argmax:
354+
label = torch.argmax(label, dim=0) if label.shape[0] > 1 else label[0]
352355
batch_data[self.label_key] = label.to("cpu")
353356
d = summarizer(batch_data)
354357

0 commit comments

Comments
 (0)