diff --git a/captum/attr/_utils/interpretable_input.py b/captum/attr/_utils/interpretable_input.py index 0d6a55a7f..6b0ee8cb7 100644 --- a/captum/attr/_utils/interpretable_input.py +++ b/captum/attr/_utils/interpretable_input.py @@ -572,7 +572,7 @@ class MMImageMaskInput(InterpretableInput): n_itp_features: int original_model_inputs: Any mask_id_to_idx: Dict[int, int] - values: List[str] = [] # no use for now + values: List[str] def __init__( self, @@ -606,6 +606,10 @@ def __init__( self.original_model_inputs = processor_fn(image) + # temporarily for compatibility with AttributionResult + # which use the values for plot legends + self.values = [f"image_feature_{mid}" for mid in mask_ids] + def to_tensor(self) -> Tensor: return torch.tensor([[1.0] * self.n_itp_features])