We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 91118b6 commit d8d84e5Copy full SHA for d8d84e5
captum/attr/_utils/interpretable_input.py
@@ -592,6 +592,12 @@ def __init__(
592
self.n_itp_features = 1
593
self.mask_id_to_idx = {}
594
else:
595
+ # Validate that mask size matches image size
596
+ image_shape = (image.size[1], image.size[0]) # (height, width)
597
+ assert (
598
+ mask.shape == image_shape
599
+ ), f"mask shape {mask.shape} must match image shape {image_shape}"
600
+
601
mask_ids = torch.unique(mask)
602
self.n_itp_features = len(mask_ids)
603
self.mask_id_to_idx = {int(mid): i for i, mid in enumerate(mask_ids)}
0 commit comments