@@ -537,20 +537,20 @@ class MMImageMaskInput(InterpretableInput):
537537 >>>
538538 >>> prompt = processor.apply_chat_template(
539539 >>> messages, add_generation_prompt=True
540- >>>)
540+ >>> )
541541 >>>
542542 >>> return processor(
543543 >>> text=prompt,
544544 >>> images=image,
545545 >>> return_tensors="pt",
546546 >>> ).to(model.device)
547-
547+ >>>
548548 >>> image = Image.open("test.jpg")
549-
549+ >>>
550550 >>> # Split horizontally: left half = 0, right half = 1
551551 >>> mask = torch.zeros(image.size[::-1], dtype=torch.int32)
552552 >>> mask[:, image.size[0] // 2:] = 1
553-
553+ >>>
554554 >>> image_mask_inp = MMImageMaskInput(
555555 >>> processor_fn=processor_fn,
556556 >>> image=image,
@@ -567,7 +567,7 @@ class MMImageMaskInput(InterpretableInput):
567567
568568 processor_fn : Callable [[PIL .Image .Image ], Any ]
569569 image : PIL .Image .Image
570- mask : Optional [ Tensor ]
570+ mask : Tensor
571571 baselines : Tuple [int , int , int ]
572572 n_itp_features : int
573573 original_model_inputs : Any
@@ -585,22 +585,24 @@ def __init__(
585585
586586 self .processor_fn = processor_fn
587587 self .image = image
588- self .mask = mask
589588 self .baselines = baselines
590589
590+ # Create a dummy mask if None is provided
591591 if mask is None :
592- self .n_itp_features = 1
593- self .mask_id_to_idx = {}
592+ # Create a mask with all zeros (entire image as one segment)
593+ image_shape = (image .size [1 ], image .size [0 ]) # (height, width)
594+ mask = torch .zeros (image_shape , dtype = torch .int32 )
594595 else :
595596 # Validate that mask size matches image size
596597 image_shape = (image .size [1 ], image .size [0 ]) # (height, width)
597598 assert (
598599 mask .shape == image_shape
599600 ), f"mask shape { mask .shape } must match image shape { image_shape } "
600601
601- mask_ids = torch .unique (mask )
602- self .n_itp_features = len (mask_ids )
603- self .mask_id_to_idx = {int (mid ): i for i , mid in enumerate (mask_ids )}
602+ self .mask = mask
603+ mask_ids = torch .unique (mask )
604+ self .n_itp_features = len (mask_ids )
605+ self .mask_id_to_idx = {int (mid ): i for i , mid in enumerate (mask_ids )}
604606
605607 self .original_model_inputs = processor_fn (image )
606608
@@ -613,14 +615,10 @@ def to_model_input(self, perturbed_tensor: Optional[Tensor] = None) -> Any:
613615
614616 img_array = np .array (self .image )
615617
616- if self .mask is None :
617- if perturbed_tensor [0 ][0 ] == 0 :
618- img_array [:, :] = self .baselines
619- else :
620- for mask_id , itp_idx in self .mask_id_to_idx .items ():
621- if perturbed_tensor [0 ][itp_idx ] == 0 :
622- mask_positions = self .mask == mask_id
623- img_array [mask_positions ] = self .baselines
618+ for mask_id , itp_idx in self .mask_id_to_idx .items ():
619+ if perturbed_tensor [0 ][itp_idx ] == 0 :
620+ mask_positions = self .mask == mask_id
621+ img_array [mask_positions ] = self .baselines
624622
625623 perturbed_image = PIL .Image .fromarray (img_array .astype ("uint8" ))
626624
@@ -629,18 +627,11 @@ def to_model_input(self, perturbed_tensor: Optional[Tensor] = None) -> Any:
629627 def format_attr (self , itp_attr : Tensor ) -> Tensor :
630628 device = itp_attr .device
631629
632- if self .mask is None :
633- # When mask is None, treat entire image as one segment
634- # Create a uniform mask of all zeros to broadcast the single attribution
635- img_array = np .array (self .image )
636- image_shape = img_array .shape [:2 ] # (height, width)
637- formatted_mask = torch .zeros (image_shape , dtype = torch .long , device = device )
638- else :
639- # Map mask IDs to continuous indices
640- image_shape = self .mask .shape
641- formatted_mask = torch .zeros_like (self .mask , device = device )
642- for mask_id , itp_idx in self .mask_id_to_idx .items ():
643- formatted_mask [self .mask == mask_id ] = itp_idx
630+ # Map mask IDs to continuous indices
631+ image_shape = self .mask .shape
632+ formatted_mask = torch .zeros_like (self .mask , device = device )
633+ for mask_id , itp_idx in self .mask_id_to_idx .items ():
634+ formatted_mask [self .mask == mask_id ] = itp_idx
644635
645636 formatted_attr = _scatter_itp_attr_by_mask (
646637 itp_attr ,
0 commit comments