Skip to content

Commit 91118b6

Browse files
aobo-ymeta-codesync[bot]
authored andcommitted
Add tests for MMImageMaskInput (#1672)
Summary: Pull Request resolved: #1672 add tests for `MMImageMaskInput` Reviewed By: craymichael Differential Revision: D87822368 fbshipit-source-id: 61e48f7c778846c77abaecce1fdfac726a8277a8
1 parent 3b43fb4 commit 91118b6

File tree

1 file changed

+287
-2
lines changed

1 file changed

+287
-2
lines changed

tests/attr/test_interpretable_input.py

Lines changed: 287 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,17 @@
22

33
# pyre-unsafe
44

5-
from typing import List, Literal, Optional, overload, Union
5+
from typing import Any, Dict, List, Literal, Optional, overload, Union
66

7+
import numpy as np
8+
import PIL.Image
79
import torch
810
from captum._utils.typing import BatchEncodingType
9-
from captum.attr._utils.interpretable_input import TextTemplateInput, TextTokenInput
11+
from captum.attr._utils.interpretable_input import (
12+
MMImageMaskInput,
13+
TextTemplateInput,
14+
TextTokenInput,
15+
)
1016
from captum.testing.helpers import BaseTest
1117
from captum.testing.helpers.basic import assertTensorAlmostEqual
1218
from parameterized import parameterized
@@ -228,3 +234,282 @@ def test_input_with_skip_tokens(self) -> None:
228234
assertTensorAlmostEqual(
229235
self, tt_input.to_model_input(perturbed_tensor), expected_perturbed_inp
230236
)
237+
238+
239+
class TestMMImageMaskInput(BaseTest):
240+
def _create_test_image(
241+
self, width: int = 10, height: int = 10, color: tuple = (255, 0, 0)
242+
) -> PIL.Image.Image:
243+
"""Helper method to create a test PIL image."""
244+
img_array = np.full((height, width, 3), color, dtype=np.uint8)
245+
return PIL.Image.fromarray(img_array)
246+
247+
def _simple_processor(self, image: PIL.Image.Image) -> Dict[str, Tensor]:
248+
"""Simple test processor that converts image to tensor."""
249+
img_array = np.array(image)
250+
return {"pixel_values": torch.from_numpy(img_array).float()}
251+
252+
def test_init_without_mask(self) -> None:
253+
# Setup: create test image and processor
254+
image = self._create_test_image()
255+
256+
# Execute: create MMImageMaskInput without mask
257+
mm_input = MMImageMaskInput(
258+
processor_fn=self._simple_processor,
259+
image=image,
260+
)
261+
262+
# Assert: verify n_itp_features is 1 when no mask provided
263+
self.assertEqual(mm_input.n_itp_features, 1)
264+
self.assertEqual(len(mm_input.mask_id_to_idx), 0)
265+
self.assertIsNone(mm_input.mask)
266+
267+
def test_init_with_mask(self) -> None:
268+
# Setup: create test image and mask with 2 segments
269+
image = self._create_test_image()
270+
mask = torch.zeros((10, 10), dtype=torch.int32)
271+
mask[:, 5:] = 1 # Split horizontally into 2 segments
272+
273+
# Execute: create MMImageMaskInput with mask
274+
mm_input = MMImageMaskInput(
275+
processor_fn=self._simple_processor,
276+
image=image,
277+
mask=mask,
278+
)
279+
280+
# Assert: verify n_itp_features matches number of unique mask values
281+
self.assertEqual(mm_input.n_itp_features, 2)
282+
self.assertEqual(len(mm_input.mask_id_to_idx), 2)
283+
self.assertIn(0, mm_input.mask_id_to_idx)
284+
self.assertIn(1, mm_input.mask_id_to_idx)
285+
286+
def test_init_with_non_continuous_mask_ids(self) -> None:
287+
# Setup: create mask with non-continuous IDs (e.g., 5, 10, 15)
288+
image = self._create_test_image(width=15, height=10)
289+
mask = torch.zeros((10, 15), dtype=torch.int32)
290+
mask[:, 5:10] = 5
291+
mask[:, 10:] = 10
292+
293+
# Execute: create MMImageMaskInput
294+
mm_input = MMImageMaskInput(
295+
processor_fn=self._simple_processor,
296+
image=image,
297+
mask=mask,
298+
)
299+
300+
# Assert: verify mask_id_to_idx creates continuous mapping
301+
self.assertEqual(mm_input.n_itp_features, 3)
302+
self.assertEqual(len(mm_input.mask_id_to_idx), 3)
303+
# Verify all mask IDs are mapped to continuous indices 0, 1, 2
304+
mapped_indices = set(mm_input.mask_id_to_idx.values())
305+
self.assertEqual(mapped_indices, {0, 1, 2})
306+
307+
def test_to_tensor_without_mask(self) -> None:
308+
# Setup: create MMImageMaskInput without mask
309+
image = self._create_test_image()
310+
mm_input = MMImageMaskInput(
311+
processor_fn=self._simple_processor,
312+
image=image,
313+
)
314+
315+
# Execute: convert to tensor
316+
result = mm_input.to_tensor()
317+
318+
# Assert: verify tensor shape and values for single feature
319+
expected = torch.tensor([[1.0]])
320+
assertTensorAlmostEqual(self, result, expected)
321+
322+
def test_to_tensor_with_mask(self) -> None:
323+
# Setup: create MMImageMaskInput with 3 segments
324+
image = self._create_test_image(width=15)
325+
mask = torch.zeros((10, 15), dtype=torch.int32)
326+
mask[:, 5:10] = 1
327+
mask[:, 10:] = 2
328+
329+
mm_input = MMImageMaskInput(
330+
processor_fn=self._simple_processor,
331+
image=image,
332+
mask=mask,
333+
)
334+
335+
# Execute: convert to tensor
336+
result = mm_input.to_tensor()
337+
338+
# Assert: verify tensor has correct number of features
339+
expected = torch.tensor([[1.0, 1.0, 1.0]])
340+
assertTensorAlmostEqual(self, result, expected)
341+
342+
def test_to_model_input_without_perturbation(self) -> None:
343+
# Setup: create MMImageMaskInput
344+
image = self._create_test_image()
345+
mm_input = MMImageMaskInput(
346+
processor_fn=self._simple_processor,
347+
image=image,
348+
)
349+
350+
# Execute: get model input without perturbation
351+
result = mm_input.to_model_input()
352+
353+
# Assert: verify returns original model inputs
354+
self.assertIn("pixel_values", result)
355+
assertTensorAlmostEqual(
356+
self, result["pixel_values"], mm_input.original_model_inputs["pixel_values"]
357+
)
358+
359+
def test_to_model_input_with_perturbation_no_mask_present(self) -> None:
360+
# Setup: create red image without mask
361+
image = self._create_test_image(color=(255, 0, 0))
362+
mm_input = MMImageMaskInput(
363+
processor_fn=self._simple_processor,
364+
image=image,
365+
baselines=(255, 255, 255), # white baseline
366+
)
367+
368+
# Execute: perturb with feature present (value 1)
369+
perturbed_tensor = torch.tensor([[1.0]])
370+
result = mm_input.to_model_input(perturbed_tensor)
371+
372+
# Assert: image should remain red (unchanged)
373+
img_array = result["pixel_values"].numpy().astype(np.uint8)
374+
self.assertTrue(np.all(img_array[:, :, 0] == 255))
375+
self.assertTrue(np.all(img_array[:, :, 1] == 0))
376+
self.assertTrue(np.all(img_array[:, :, 2] == 0))
377+
378+
def test_to_model_input_with_perturbation_no_mask_absent(self) -> None:
379+
# Setup: create red image without mask
380+
image = self._create_test_image(color=(255, 0, 0))
381+
mm_input = MMImageMaskInput(
382+
processor_fn=self._simple_processor,
383+
image=image,
384+
baselines=(255, 255, 255), # white baseline
385+
)
386+
387+
# Execute: perturb with feature absent (value 0)
388+
perturbed_tensor = torch.tensor([[0.0]])
389+
result = mm_input.to_model_input(perturbed_tensor)
390+
391+
# Assert: entire image should be white (baseline)
392+
img_array = result["pixel_values"].numpy().astype(np.uint8)
393+
self.assertTrue(np.all(img_array == 255))
394+
395+
def test_to_model_input_with_mask_partial_perturbation(self) -> None:
396+
# Setup: create image with 2 segments (left red, right green)
397+
img_array = np.zeros((10, 10, 3), dtype=np.uint8)
398+
img_array[:, :5] = [255, 0, 0] # Left half red
399+
img_array[:, 5:] = [0, 255, 0] # Right half green
400+
image = PIL.Image.fromarray(img_array)
401+
402+
mask = torch.zeros((10, 10), dtype=torch.int32)
403+
mask[:, 5:] = 1 # Right half is segment 1
404+
405+
mm_input = MMImageMaskInput(
406+
processor_fn=self._simple_processor,
407+
image=image,
408+
mask=mask,
409+
baselines=(255, 255, 255), # white baseline
410+
)
411+
412+
# Execute: perturb to keep left segment (0) but remove right segment (1)
413+
perturbed_tensor = torch.tensor([[1.0, 0.0]])
414+
result = mm_input.to_model_input(perturbed_tensor)
415+
416+
# Assert: left half should be red, right half should be white
417+
img_array = result["pixel_values"].numpy().astype(np.uint8)
418+
# Left half should be red
419+
self.assertTrue(np.all(img_array[:, :5, 0] == 255))
420+
self.assertTrue(np.all(img_array[:, :5, 1] == 0))
421+
# Right half should be white (baseline)
422+
self.assertTrue(np.all(img_array[:, 5:] == 255))
423+
424+
def test_to_model_input_with_custom_baselines(self) -> None:
425+
# Setup: create image with custom baseline color
426+
image = self._create_test_image(color=(255, 0, 0))
427+
mm_input = MMImageMaskInput(
428+
processor_fn=self._simple_processor,
429+
image=image,
430+
baselines=(0, 128, 255), # Custom blue-ish baseline
431+
)
432+
433+
# Execute: perturb to remove feature
434+
perturbed_tensor = torch.tensor([[0.0]])
435+
result = mm_input.to_model_input(perturbed_tensor)
436+
437+
# Assert: image should have custom baseline color
438+
img_array = result["pixel_values"].numpy().astype(np.uint8)
439+
self.assertTrue(np.all(img_array[:, :, 0] == 0))
440+
self.assertTrue(np.all(img_array[:, :, 1] == 128))
441+
self.assertTrue(np.all(img_array[:, :, 2] == 255))
442+
443+
def test_format_attr_without_mask(self) -> None:
444+
# Setup: create MMImageMaskInput without mask
445+
image = self._create_test_image(width=5, height=5)
446+
mm_input = MMImageMaskInput(
447+
processor_fn=self._simple_processor,
448+
image=image,
449+
)
450+
451+
# Execute: format attribution for single feature
452+
attr = torch.tensor([[0.5]])
453+
result = mm_input.format_attr(attr)
454+
455+
# Assert: attribution should be broadcast to all pixels
456+
self.assertEqual(result.shape, (1, 5, 5))
457+
self.assertTrue(torch.all(result == 0.5))
458+
459+
def test_format_attr_with_mask(self) -> None:
460+
# Setup: create MMImageMaskInput with 2 segments
461+
image = self._create_test_image(width=10, height=5)
462+
mask = torch.zeros((5, 10), dtype=torch.int32)
463+
mask[:, 5:] = 1 # Split horizontally
464+
465+
mm_input = MMImageMaskInput(
466+
processor_fn=self._simple_processor,
467+
image=image,
468+
mask=mask,
469+
)
470+
471+
# Execute: format attribution with different values for each segment
472+
attr = torch.tensor([[0.3, 0.7]])
473+
result = mm_input.format_attr(attr)
474+
475+
# Assert: left half should have 0.3, right half should have 0.7
476+
self.assertEqual(result.shape, (1, 5, 10))
477+
assertTensorAlmostEqual(
478+
self, result[0, :, :5], torch.full((5, 5), 0.3)
479+
) # Left half
480+
assertTensorAlmostEqual(
481+
self, result[0, :, 5:], torch.full((5, 5), 0.7)
482+
) # Right half
483+
484+
def test_format_attr_with_non_continuous_mask(self) -> None:
485+
# Setup: create mask with non-continuous IDs
486+
image = self._create_test_image(width=15, height=5)
487+
mask = torch.zeros((5, 15), dtype=torch.int32)
488+
mask[:, 5:10] = 10
489+
mask[:, 10:] = 20
490+
491+
mm_input = MMImageMaskInput(
492+
processor_fn=self._simple_processor,
493+
image=image,
494+
mask=mask,
495+
)
496+
497+
# Execute: format attribution
498+
attr = torch.tensor([[0.1, 0.2, 0.3]])
499+
result = mm_input.format_attr(attr)
500+
501+
# Assert: verify correct attribution values for each segment
502+
self.assertEqual(result.shape, (1, 5, 15))
503+
# Find which continuous index maps to which mask ID
504+
idx_0 = mm_input.mask_id_to_idx[0]
505+
idx_10 = mm_input.mask_id_to_idx[10]
506+
idx_20 = mm_input.mask_id_to_idx[20]
507+
508+
# Verify each segment has its corresponding attribution
509+
segment_0_value = attr[0, idx_0].item()
510+
segment_10_value = attr[0, idx_10].item()
511+
segment_20_value = attr[0, idx_20].item()
512+
513+
self.assertTrue(torch.all(result[0, :, :5] == segment_0_value))
514+
self.assertTrue(torch.all(result[0, :, 5:10] == segment_10_value))
515+
self.assertTrue(torch.all(result[0, :, 10:] == segment_20_value))

0 commit comments

Comments
 (0)