Skip to content

Commit e56a7a0

Browse files
aobo-yfacebook-github-bot
authored andcommitted
ensure the mask and image size match in MMImageMaskInput (#1673)
Summary: as title Differential Revision: D87877392
1 parent 697b6e8 commit e56a7a0

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

captum/attr/_utils/interpretable_input.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,12 @@ def __init__(
594594
self.n_itp_features = 1
595595
self.mask_id_to_idx = {}
596596
else:
597+
# Validate that mask size matches image size
598+
image_shape = (image.size[1], image.size[0]) # (height, width)
599+
assert (
600+
mask.shape == image_shape
601+
), f"mask shape {mask.shape} must match image shape {image_shape}"
602+
597603
mask_ids = torch.unique(mask)
598604
self.n_itp_features = len(mask_ids)
599605
self.mask_id_to_idx = {int(mid): i for i, mid in enumerate(mask_ids)}

0 commit comments

Comments
 (0)