Skip to content

Commit d8d84e5

Browse files
aobo-ymeta-codesync[bot]
authored andcommitted
ensure the mask and image size match in MMImageMaskInput (#1673)
Summary: Pull Request resolved: #1673 as title Reviewed By: craymichael Differential Revision: D87877392 fbshipit-source-id: 4a843d0fc7a0775d65d53c7735a747b3c90e00e1
1 parent 91118b6 commit d8d84e5

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

captum/attr/_utils/interpretable_input.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -592,6 +592,12 @@ def __init__(
592592
self.n_itp_features = 1
593593
self.mask_id_to_idx = {}
594594
else:
595+
# Validate that mask size matches image size
596+
image_shape = (image.size[1], image.size[0]) # (height, width)
597+
assert (
598+
mask.shape == image_shape
599+
), f"mask shape {mask.shape} must match image shape {image_shape}"
600+
595601
mask_ids = torch.unique(mask)
596602
self.n_itp_features = len(mask_ids)
597603
self.mask_id_to_idx = {int(mid): i for i, mid in enumerate(mask_ids)}

0 commit comments

Comments
 (0)