We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 697b6e8 commit e56a7a0Copy full SHA for e56a7a0
captum/attr/_utils/interpretable_input.py
@@ -594,6 +594,12 @@ def __init__(
594
self.n_itp_features = 1
595
self.mask_id_to_idx = {}
596
else:
597
+ # Validate that mask size matches image size
598
+ image_shape = (image.size[1], image.size[0]) # (height, width)
599
+ assert (
600
+ mask.shape == image_shape
601
+ ), f"mask shape {mask.shape} must match image shape {image_shape}"
602
+
603
mask_ids = torch.unique(mask)
604
self.n_itp_features = len(mask_ids)
605
self.mask_id_to_idx = {int(mid): i for i, mid in enumerate(mask_ids)}
0 commit comments