Skip to content

Commit 0d56ee5

Browse files
aobo-yfacebook-github-bot
authored andcommitted
set values in MMImageMaskInput
Summary: previously `InterpretableInput` is only used for text, the `values` are the actual text segments, which are used in `AttributionResult` for plot legends. This design no longer makes sense for image. Will redesign the `values`. But for now, just set `values` for plot legends Differential Revision: D87891825
1 parent f55cd3d commit 0d56ee5

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

captum/attr/_utils/interpretable_input.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,7 @@ class MMImageMaskInput(InterpretableInput):
574574
n_itp_features: int
575575
original_model_inputs: Any
576576
mask_id_to_idx: Dict[int, int]
577-
values: List[str] = [] # no use for now
577+
values: List[str]
578578

579579
def __init__(
580580
self,
@@ -608,6 +608,10 @@ def __init__(
608608

609609
self.original_model_inputs = processor_fn(image)
610610

611+
# temporarily for compatibility with AttributionResult
612+
# which use the values for plot legends
613+
self.values = [f"image_feature_{mid}" for mid in mask_ids]
614+
611615
def to_tensor(self) -> Tensor:
612616
return torch.tensor([[1.0] * self.n_itp_features])
613617

0 commit comments

Comments
 (0)