Skip to content

Commit fa1c1af

Browse files
Fix RandomWeightedCrop for Integer Weightmap Handling (#8097)
Fixes #7949 . ### Description Regardless of the type of `weight map`, random numbers should be kept as floating-point numbers for calculating the sampling location. However, `searchsorted` requires matching data structures. I have modified `convert_to_dst_type` to control converting only the data structure while maintaining the original data type. Additionally, I have included an example with integer weight maps in the test file. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Han123su <popsmall212@gmail.com> Signed-off-by: Han123su <107395380+Han123su@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent d2d492e commit fa1c1af

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

monai/transforms/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -582,7 +582,8 @@ def weighted_patch_samples(
582582
if not v[-1] or not isfinite(v[-1]) or v[-1] < 0: # uniform sampling
583583
idx = r_state.randint(0, len(v), size=n_samples)
584584
else:
585-
r, *_ = convert_to_dst_type(r_state.random(n_samples), v)
585+
r_samples = r_state.random(n_samples)
586+
r, *_ = convert_to_dst_type(r_samples, v, dtype=r_samples.dtype)
586587
idx = searchsorted(v, r * v[-1], right=True) # type: ignore
587588
idx, *_ = convert_to_dst_type(idx, v, dtype=torch.int) # type: ignore
588589
# compensate 'valid' mode

tests/test_rand_weighted_crop.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,21 @@ def get_data(ndim):
9090
[[63, 37], [31, 43], [66, 20]],
9191
]
9292
)
93+
im = SEG1_2D
94+
weight_map = np.zeros_like(im, dtype=np.int32)
95+
weight_map[0, 30, 20] = 3
96+
weight_map[0, 45, 44] = 1
97+
weight_map[0, 60, 50] = 2
98+
TESTS.append(
99+
[
100+
"int w 2d",
101+
dict(spatial_size=(10, 12), num_samples=3),
102+
p(im),
103+
q(weight_map),
104+
(1, 10, 12),
105+
[[60, 50], [30, 20], [45, 44]],
106+
]
107+
)
93108
im = SEG1_3D
94109
weight = np.zeros_like(im)
95110
weight[0, 5, 30, 17] = 1.1
@@ -149,6 +164,21 @@ def get_data(ndim):
149164
[[32, 24, 40], [32, 24, 40], [32, 24, 40]],
150165
]
151166
)
167+
im = SEG1_3D
168+
weight_map = np.zeros_like(im, dtype=np.int32)
169+
weight_map[0, 6, 22, 19] = 4
170+
weight_map[0, 8, 40, 31] = 2
171+
weight_map[0, 13, 20, 24] = 3
172+
TESTS.append(
173+
[
174+
"int w 3d",
175+
dict(spatial_size=(8, 10, 12), num_samples=3),
176+
p(im),
177+
q(weight_map),
178+
(1, 8, 10, 12),
179+
[[13, 20, 24], [6, 22, 19], [8, 40, 31]],
180+
]
181+
)
152182

153183

154184
class TestRandWeightedCrop(CropTest):

0 commit comments

Comments
 (0)