
Commit d23221f

6442 fix thread safe issue of DynUNet (#6444)
Fixes #6442.

### Description

This PR fixes a thread-safety issue in DynUNet: the Python list created in the `forward` function is replaced by a pre-allocated tensor.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: Yiheng Wang <vennw@nvidia.com>
1 parent: bb088ec
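As a minimal, self-contained sketch of the pattern this PR applies (the `heads` list, tensor shapes, and variable names below are hypothetical stand-ins; only the accumulation logic mirrors the DynUNet change): the list-plus-`torch.stack` accumulation is replaced by writing into one pre-allocated tensor, which yields the same result.

```python
import torch
from torch.nn.functional import interpolate

# Hypothetical stand-ins for DynUNet's pieces: a full-resolution output and two
# lower-resolution deep-supervision feature maps (N=2, C=3).
out = torch.randn(2, 3, 16, 16)
heads = [torch.randn(2, 3, 8, 8), torch.randn(2, 3, 4, 4)]

# Previous pattern: accumulate into a Python list, then stack along dim 1.
out_all_list = [out]
for feature_map in heads:
    out_all_list.append(interpolate(feature_map, out.shape[2:]))
stacked = torch.stack(out_all_list, dim=1)  # (N, 1 + len(heads), C, H, W)

# Pattern introduced by this PR: pre-allocate one tensor and fill it in place.
out_all = torch.zeros(out.shape[0], len(heads) + 1, *out.shape[1:], device=out.device, dtype=out.dtype)
out_all[:, 0] = out
for idx, feature_map in enumerate(heads):
    out_all[:, idx + 1] = interpolate(feature_map, out.shape[2:])

# Both approaches produce the same values.
assert torch.allclose(stacked, out_all)
```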

File tree

1 file changed (+5, −4 lines)

monai/networks/nets/dynunet.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -269,10 +269,11 @@ def forward(self, x):
         out = self.skip_layers(x)
         out = self.output_block(out)
         if self.training and self.deep_supervision:
-            out_all = [out]
-            for feature_map in self.heads:
-                out_all.append(interpolate(feature_map, out.shape[2:]))
-            return torch.stack(out_all, dim=1)
+            out_all = torch.zeros(out.shape[0], len(self.heads) + 1, *out.shape[1:], device=out.device, dtype=out.dtype)
+            out_all[:, 0] = out
+            for idx, feature_map in enumerate(self.heads):
+                out_all[:, idx + 1] = interpolate(feature_map, out.shape[2:])
+            return out_all
         return out

     def get_input_block(self):
```
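For context, a usage sketch (not part of this commit): when `deep_supervision` is enabled and the model is in training mode, `forward` now returns a single tensor of shape `(N, len(self.heads) + 1, C, *spatial)`, numerically identical to the `torch.stack(out_all, dim=1)` result it replaces. The helper below, including the name `deep_supervision_loss` and the `0.5 ** i` weight decay, is a hypothetical illustration of how such an output can be consumed, not MONAI's API.

```python
import torch

def deep_supervision_loss(out_all: torch.Tensor, target: torch.Tensor, loss_fn) -> torch.Tensor:
    # out_all[:, 0] is the full-resolution prediction; out_all[:, 1:] are the
    # deep-supervision heads, already interpolated to the same spatial size.
    preds = torch.unbind(out_all, dim=1)
    # Illustrative exponential decay of the per-head weights (an assumption,
    # not a scheme mandated by MONAI or this PR).
    weights = [0.5 ** i for i in range(len(preds))]
    losses = [w * loss_fn(p, target) for w, p in zip(weights, preds)]
    return sum(losses) / sum(weights)
```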
