Commit ff0eaf3

aobo-ymeta-codesync[bot] authored and committed
beautify the handling of mask of None in MMImageMaskInput (#1674)
Summary:
Pull Request resolved: #1674

If `mask` is `None`, the whole `image` is treated as one interpretable feature. A dummy `mask` whose pixels all belong to feature id `0` is initialized, so the downstream code can be simplified by assuming a `mask` is always given.

Reviewed By: craymichael

Differential Revision: D87890030

fbshipit-source-id: 8743f76464433c2bc4c9a8d7b4a193dcded48920
1 parent d8d84e5 commit ff0eaf3
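
A minimal usage sketch of the new behavior, pieced together from the diff below; the constructor keyword names and the need to pass `baselines` explicitly are assumptions, and `processor_fn` is a trivial stand-in:

import PIL.Image
import torch

from captum.attr._utils.interpretable_input import MMImageMaskInput


def processor_fn(img):
    # Stand-in only: a real processor_fn would run the model's preprocessing.
    return {"image": img}


image = PIL.Image.new("RGB", (8, 6))

# With mask=None, the constructor now builds a dummy all-zero mask internally,
# so the whole image is treated as a single interpretable feature (id 0).
inp = MMImageMaskInput(
    processor_fn=processor_fn, image=image, mask=None, baselines=(0, 0, 0)
)
assert inp.n_itp_features == 1
assert inp.mask_id_to_idx == {0: 0}
assert torch.all(inp.mask == 0)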

File tree: 2 files changed (+30, -34 lines)

captum/attr/_utils/interpretable_input.py

Lines changed: 22 additions & 31 deletions
@@ -537,20 +537,20 @@ class MMImageMaskInput(InterpretableInput):
 >>>
 >>> prompt = processor.apply_chat_template(
 >>> messages, add_generation_prompt=True
->>>)
+>>> )
 >>>
 >>> return processor(
 >>> text=prompt,
 >>> images=image,
 >>> return_tensors="pt",
 >>> ).to(model.device)
-
+>>>
 >>> image = Image.open("test.jpg")
-
+>>>
 >>> # Split horizontally: left half = 0, right half = 1
 >>> mask = torch.zeros(image.size[::-1], dtype=torch.int32)
 >>> mask[:, image.size[0] // 2:] = 1
-
+>>>
 >>> image_mask_inp = MMImageMaskInput(
 >>> processor_fn=processor_fn,
 >>> image=image,
@@ -567,7 +567,7 @@ class MMImageMaskInput(InterpretableInput):
 
     processor_fn: Callable[[PIL.Image.Image], Any]
    image: PIL.Image.Image
-    mask: Optional[Tensor]
+    mask: Tensor
     baselines: Tuple[int, int, int]
     n_itp_features: int
     original_model_inputs: Any
@@ -585,22 +585,24 @@ def __init__(
 
         self.processor_fn = processor_fn
         self.image = image
-        self.mask = mask
         self.baselines = baselines
 
+        # Create a dummy mask if None is provided
         if mask is None:
-            self.n_itp_features = 1
-            self.mask_id_to_idx = {}
+            # Create a mask with all zeros (entire image as one segment)
+            image_shape = (image.size[1], image.size[0])  # (height, width)
+            mask = torch.zeros(image_shape, dtype=torch.int32)
         else:
             # Validate that mask size matches image size
             image_shape = (image.size[1], image.size[0])  # (height, width)
             assert (
                 mask.shape == image_shape
             ), f"mask shape {mask.shape} must match image shape {image_shape}"
 
-            mask_ids = torch.unique(mask)
-            self.n_itp_features = len(mask_ids)
-            self.mask_id_to_idx = {int(mid): i for i, mid in enumerate(mask_ids)}
+        self.mask = mask
+        mask_ids = torch.unique(mask)
+        self.n_itp_features = len(mask_ids)
+        self.mask_id_to_idx = {int(mid): i for i, mid in enumerate(mask_ids)}
 
         self.original_model_inputs = processor_fn(image)
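
As a reading aid for the `__init__` hunk above: once a mask tensor always exists, the same three lines derive the feature bookkeeping for both the dummy and the user-provided case. A standalone sketch with made-up shapes and ids:

import torch

mask = torch.zeros((4, 6), dtype=torch.int32)  # dummy mask: one feature, id 0
mask[:, 3:] = 7                                # or a user mask with arbitrary ids

mask_ids = torch.unique(mask)                  # tensor([0, 7])
n_itp_features = len(mask_ids)                 # 2
mask_id_to_idx = {int(mid): i for i, mid in enumerate(mask_ids)}  # {0: 0, 7: 1}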

@@ -613,14 +615,10 @@ def to_model_input(self, perturbed_tensor: Optional[Tensor] = None) -> Any:
 
         img_array = np.array(self.image)
 
-        if self.mask is None:
-            if perturbed_tensor[0][0] == 0:
-                img_array[:, :] = self.baselines
-        else:
-            for mask_id, itp_idx in self.mask_id_to_idx.items():
-                if perturbed_tensor[0][itp_idx] == 0:
-                    mask_positions = self.mask == mask_id
-                    img_array[mask_positions] = self.baselines
+        for mask_id, itp_idx in self.mask_id_to_idx.items():
+            if perturbed_tensor[0][itp_idx] == 0:
+                mask_positions = self.mask == mask_id
+                img_array[mask_positions] = self.baselines
 
         perturbed_image = PIL.Image.fromarray(img_array.astype("uint8"))
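
A standalone sketch of the perturbation step in `to_model_input` above, with made-up shapes, baselines, and a hypothetical `perturbed_tensor`; the real method indexes `np.array(self.image)` with the stored mask:

import numpy as np
import torch

img_array = np.full((4, 6, 3), 255, dtype=np.uint8)  # stand-in for np.array(self.image)
mask = torch.zeros((4, 6), dtype=torch.int32)
mask[:, 3:] = 1                                       # two features: ids 0 and 1
mask_id_to_idx = {0: 0, 1: 1}
baselines = (0, 0, 0)

perturbed_tensor = torch.tensor([[1, 0]])             # keep feature 0, drop feature 1

for mask_id, itp_idx in mask_id_to_idx.items():
    if perturbed_tensor[0][itp_idx] == 0:
        mask_positions = (mask == mask_id).numpy()    # boolean (H, W) selector
        img_array[mask_positions] = baselines         # blank out the dropped feature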

@@ -629,18 +627,11 @@ def to_model_input(self, perturbed_tensor: Optional[Tensor] = None) -> Any:
     def format_attr(self, itp_attr: Tensor) -> Tensor:
         device = itp_attr.device
 
-        if self.mask is None:
-            # When mask is None, treat entire image as one segment
-            # Create a uniform mask of all zeros to broadcast the single attribution
-            img_array = np.array(self.image)
-            image_shape = img_array.shape[:2]  # (height, width)
-            formatted_mask = torch.zeros(image_shape, dtype=torch.long, device=device)
-        else:
-            # Map mask IDs to continuous indices
-            image_shape = self.mask.shape
-            formatted_mask = torch.zeros_like(self.mask, device=device)
-            for mask_id, itp_idx in self.mask_id_to_idx.items():
-                formatted_mask[self.mask == mask_id] = itp_idx
+        # Map mask IDs to continuous indices
+        image_shape = self.mask.shape
+        formatted_mask = torch.zeros_like(self.mask, device=device)
+        for mask_id, itp_idx in self.mask_id_to_idx.items():
+            formatted_mask[self.mask == mask_id] = itp_idx
 
         formatted_attr = _scatter_itp_attr_by_mask(
             itp_attr,
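
Similarly, a standalone sketch of the simplified `format_attr` path: mask ids are remapped to contiguous interpretable-feature indices before the per-feature attributions are scattered back onto pixels. `_scatter_itp_attr_by_mask` is not reproduced here; a plain gather stands in for it:

import torch

mask = torch.tensor([[0, 0, 7], [0, 7, 7]], dtype=torch.int32)
mask_id_to_idx = {0: 0, 7: 1}
itp_attr = torch.tensor([[0.25, -0.75]])       # one attribution per feature

formatted_mask = torch.zeros_like(mask)
for mask_id, itp_idx in mask_id_to_idx.items():
    formatted_mask[mask == mask_id] = itp_idx  # contiguous indices 0..n-1

# Broadcasting the per-feature scores by index yields a per-pixel attribution map.
per_pixel_attr = itp_attr[0][formatted_mask.long()]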

tests/attr/test_interpretable_input.py

Lines changed: 8 additions & 3 deletions
@@ -2,7 +2,7 @@
 
 # pyre-unsafe
 
-from typing import Any, Dict, List, Literal, Optional, overload, Union
+from typing import Dict, List, Literal, Optional, overload, Union
 
 import numpy as np
 import PIL.Image
@@ -260,9 +260,14 @@ def test_init_without_mask(self) -> None:
         )
 
         # Assert: verify n_itp_features is 1 when no mask provided
+        # When mask is None, a dummy mask with all zeros is created
         self.assertEqual(mm_input.n_itp_features, 1)
-        self.assertEqual(len(mm_input.mask_id_to_idx), 0)
-        self.assertIsNone(mm_input.mask)
+        self.assertEqual(mm_input.mask_id_to_idx, {0: 0})
+        self.assertIsNotNone(mm_input.mask)
+        # Verify dummy mask has all zeros
+        self.assertTrue(torch.all(mm_input.mask == 0))
+        # Verify dummy mask shape matches image size (height, width)
+        self.assertEqual(mm_input.mask.shape, (image.size[1], image.size[0]))
 
     def test_init_with_mask(self) -> None:
         # Setup: create test image and mask with 2 segments