|
2 | 2 |
|
3 | 3 | # pyre-unsafe |
4 | 4 |
|
5 | | -from typing import List, Literal, Optional, overload, Union |
| 5 | +from typing import Any, Dict, List, Literal, Optional, overload, Union |
6 | 6 |
|
| 7 | +import numpy as np |
| 8 | +import PIL.Image |
7 | 9 | import torch |
8 | 10 | from captum._utils.typing import BatchEncodingType |
9 | | -from captum.attr._utils.interpretable_input import TextTemplateInput, TextTokenInput |
| 11 | +from captum.attr._utils.interpretable_input import ( |
| 12 | + MMImageMaskInput, |
| 13 | + TextTemplateInput, |
| 14 | + TextTokenInput, |
| 15 | +) |
10 | 16 | from captum.testing.helpers import BaseTest |
11 | 17 | from captum.testing.helpers.basic import assertTensorAlmostEqual |
12 | 18 | from parameterized import parameterized |
@@ -228,3 +234,282 @@ def test_input_with_skip_tokens(self) -> None: |
228 | 234 | assertTensorAlmostEqual( |
229 | 235 | self, tt_input.to_model_input(perturbed_tensor), expected_perturbed_inp |
230 | 236 | ) |
| 237 | + |
| 238 | + |
| 239 | +class TestMMImageMaskInput(BaseTest): |
| 240 | + def _create_test_image( |
| 241 | + self, width: int = 10, height: int = 10, color: tuple = (255, 0, 0) |
| 242 | + ) -> PIL.Image.Image: |
| 243 | + """Helper method to create a test PIL image.""" |
| 244 | + img_array = np.full((height, width, 3), color, dtype=np.uint8) |
| 245 | + return PIL.Image.fromarray(img_array) |
| 246 | + |
| 247 | + def _simple_processor(self, image: PIL.Image.Image) -> Dict[str, Tensor]: |
| 248 | + """Simple test processor that converts image to tensor.""" |
| 249 | + img_array = np.array(image) |
| 250 | + return {"pixel_values": torch.from_numpy(img_array).float()} |
| 251 | + |
| 252 | + def test_init_without_mask(self) -> None: |
| 253 | + # Setup: create test image (the processor is the _simple_processor helper) |
| 254 | + image = self._create_test_image() |
| 255 | + |
| 256 | + # Execute: create MMImageMaskInput without mask |
| 257 | + mm_input = MMImageMaskInput( |
| 258 | + processor_fn=self._simple_processor, |
| 259 | + image=image, |
| 260 | + ) |
| 261 | + |
| 262 | + # Assert: without a mask there is 1 feature, an empty ID mapping, and mask is None |
| 263 | + self.assertEqual(mm_input.n_itp_features, 1) |
| 264 | + self.assertEqual(len(mm_input.mask_id_to_idx), 0) |
| 265 | + self.assertIsNone(mm_input.mask) |
| 266 | + |
| 267 | + def test_init_with_mask(self) -> None: |
| 268 | + # Setup: create test image and mask with 2 segments |
| 269 | + image = self._create_test_image() |
| 270 | + mask = torch.zeros((10, 10), dtype=torch.int32) |
| 271 | + mask[:, 5:] = 1 # Columns 5+ become segment 1: left/right split into 2 segments |
| 272 | + |
| 273 | + # Execute: create MMImageMaskInput with mask |
| 274 | + mm_input = MMImageMaskInput( |
| 275 | + processor_fn=self._simple_processor, |
| 276 | + image=image, |
| 277 | + mask=mask, |
| 278 | + ) |
| 279 | + |
| 280 | + # Assert: verify n_itp_features matches number of unique mask values |
| 281 | + self.assertEqual(mm_input.n_itp_features, 2) |
| 282 | + self.assertEqual(len(mm_input.mask_id_to_idx), 2) |
| 283 | + self.assertIn(0, mm_input.mask_id_to_idx) |
| 284 | + self.assertIn(1, mm_input.mask_id_to_idx) |
| 285 | + |
| 286 | + def test_init_with_non_continuous_mask_ids(self) -> None: |
| 287 | + # Setup: create mask with non-continuous IDs (0, 5, 10) |
| 288 | + image = self._create_test_image(width=15, height=10) |
| 289 | + mask = torch.zeros((10, 15), dtype=torch.int32) |
| 290 | + mask[:, 5:10] = 5 |
| 291 | + mask[:, 10:] = 10 |
| 292 | + |
| 293 | + # Execute: create MMImageMaskInput |
| 294 | + mm_input = MMImageMaskInput( |
| 295 | + processor_fn=self._simple_processor, |
| 296 | + image=image, |
| 297 | + mask=mask, |
| 298 | + ) |
| 299 | + |
| 300 | + # Assert: verify mask_id_to_idx creates continuous mapping |
| 301 | + self.assertEqual(mm_input.n_itp_features, 3) |
| 302 | + self.assertEqual(len(mm_input.mask_id_to_idx), 3) |
| 303 | + # Verify all mask IDs are mapped to continuous indices 0, 1, 2 |
| 304 | + mapped_indices = set(mm_input.mask_id_to_idx.values()) |
| 305 | + self.assertEqual(mapped_indices, {0, 1, 2}) |
| 306 | + |
| 307 | + def test_to_tensor_without_mask(self) -> None: |
| 308 | + # Setup: create MMImageMaskInput without mask |
| 309 | + image = self._create_test_image() |
| 310 | + mm_input = MMImageMaskInput( |
| 311 | + processor_fn=self._simple_processor, |
| 312 | + image=image, |
| 313 | + ) |
| 314 | + |
| 315 | + # Execute: convert to tensor |
| 316 | + result = mm_input.to_tensor() |
| 317 | + |
| 318 | + # Assert: verify tensor shape and values for single feature |
| 319 | + expected = torch.tensor([[1.0]]) |
| 320 | + assertTensorAlmostEqual(self, result, expected) |
| 321 | + |
| 322 | + def test_to_tensor_with_mask(self) -> None: |
| 323 | + # Setup: create MMImageMaskInput with 3 segments |
| 324 | + image = self._create_test_image(width=15) |
| 325 | + mask = torch.zeros((10, 15), dtype=torch.int32) |
| 326 | + mask[:, 5:10] = 1 |
| 327 | + mask[:, 10:] = 2 |
| 328 | + |
| 329 | + mm_input = MMImageMaskInput( |
| 330 | + processor_fn=self._simple_processor, |
| 331 | + image=image, |
| 332 | + mask=mask, |
| 333 | + ) |
| 334 | + |
| 335 | + # Execute: convert to tensor |
| 336 | + result = mm_input.to_tensor() |
| 337 | + |
| 338 | + # Assert: verify tensor is all ones with one entry per mask segment |
| 339 | + expected = torch.tensor([[1.0, 1.0, 1.0]]) |
| 340 | + assertTensorAlmostEqual(self, result, expected) |
| 341 | + |
| 342 | + def test_to_model_input_without_perturbation(self) -> None: |
| 343 | + # Setup: create MMImageMaskInput |
| 344 | + image = self._create_test_image() |
| 345 | + mm_input = MMImageMaskInput( |
| 346 | + processor_fn=self._simple_processor, |
| 347 | + image=image, |
| 348 | + ) |
| 349 | + |
| 350 | + # Execute: get model input without perturbation |
| 351 | + result = mm_input.to_model_input() |
| 352 | + |
| 353 | + # Assert: verify returns original model inputs |
| 354 | + self.assertIn("pixel_values", result) |
| 355 | + assertTensorAlmostEqual( |
| 356 | + self, result["pixel_values"], mm_input.original_model_inputs["pixel_values"] |
| 357 | + ) |
| 358 | + |
| 359 | + def test_to_model_input_with_perturbation_no_mask_present(self) -> None: |
| 360 | + # Setup: create red image without mask |
| 361 | + image = self._create_test_image(color=(255, 0, 0)) |
| 362 | + mm_input = MMImageMaskInput( |
| 363 | + processor_fn=self._simple_processor, |
| 364 | + image=image, |
| 365 | + baselines=(255, 255, 255), # white baseline |
| 366 | + ) |
| 367 | + |
| 368 | + # Execute: perturb with feature present (value 1) |
| 369 | + perturbed_tensor = torch.tensor([[1.0]]) |
| 370 | + result = mm_input.to_model_input(perturbed_tensor) |
| 371 | + |
| 372 | + # Assert: image should remain red (unchanged) |
| 373 | + img_array = result["pixel_values"].numpy().astype(np.uint8) |
| 374 | + self.assertTrue(np.all(img_array[:, :, 0] == 255)) |
| 375 | + self.assertTrue(np.all(img_array[:, :, 1] == 0)) |
| 376 | + self.assertTrue(np.all(img_array[:, :, 2] == 0)) |
| 377 | + |
| 378 | + def test_to_model_input_with_perturbation_no_mask_absent(self) -> None: |
| 379 | + # Setup: create red image without mask |
| 380 | + image = self._create_test_image(color=(255, 0, 0)) |
| 381 | + mm_input = MMImageMaskInput( |
| 382 | + processor_fn=self._simple_processor, |
| 383 | + image=image, |
| 384 | + baselines=(255, 255, 255), # white baseline |
| 385 | + ) |
| 386 | + |
| 387 | + # Execute: perturb with feature absent (value 0) |
| 388 | + perturbed_tensor = torch.tensor([[0.0]]) |
| 389 | + result = mm_input.to_model_input(perturbed_tensor) |
| 390 | + |
| 391 | + # Assert: entire image should be white (baseline) |
| 392 | + img_array = result["pixel_values"].numpy().astype(np.uint8) |
| 393 | + self.assertTrue(np.all(img_array == 255)) |
| 394 | + |
| 395 | + def test_to_model_input_with_mask_partial_perturbation(self) -> None: |
| 396 | + # Setup: create image with 2 segments (left red, right green) |
| 397 | + img_array = np.zeros((10, 10, 3), dtype=np.uint8) |
| 398 | + img_array[:, :5] = [255, 0, 0] # Left half red |
| 399 | + img_array[:, 5:] = [0, 255, 0] # Right half green |
| 400 | + image = PIL.Image.fromarray(img_array) |
| 401 | + |
| 402 | + mask = torch.zeros((10, 10), dtype=torch.int32) |
| 403 | + mask[:, 5:] = 1 # Right half is segment 1 |
| 404 | + |
| 405 | + mm_input = MMImageMaskInput( |
| 406 | + processor_fn=self._simple_processor, |
| 407 | + image=image, |
| 408 | + mask=mask, |
| 409 | + baselines=(255, 255, 255), # white baseline |
| 410 | + ) |
| 411 | + |
| 412 | + # Execute: perturb to keep left segment (0) but remove right segment (1) |
| 413 | + perturbed_tensor = torch.tensor([[1.0, 0.0]]) |
| 414 | + result = mm_input.to_model_input(perturbed_tensor) |
| 415 | + |
| 416 | + # Assert: left half should be red, right half should be white |
| 417 | + img_array = result["pixel_values"].numpy().astype(np.uint8) |
| 418 | + # Left half should be red |
| 419 | + self.assertTrue(np.all(img_array[:, :5, 0] == 255)) |
| 420 | + self.assertTrue(np.all(img_array[:, :5, 1] == 0)) |
| 421 | + # Right half should be white (baseline) |
| 422 | + self.assertTrue(np.all(img_array[:, 5:] == 255)) |
| 423 | + |
| 424 | + def test_to_model_input_with_custom_baselines(self) -> None: |
| 425 | + # Setup: create red image and an input with a custom baseline color |
| 426 | + image = self._create_test_image(color=(255, 0, 0)) |
| 427 | + mm_input = MMImageMaskInput( |
| 428 | + processor_fn=self._simple_processor, |
| 429 | + image=image, |
| 430 | + baselines=(0, 128, 255), # Custom blue-ish baseline |
| 431 | + ) |
| 432 | + |
| 433 | + # Execute: perturb to remove feature |
| 434 | + perturbed_tensor = torch.tensor([[0.0]]) |
| 435 | + result = mm_input.to_model_input(perturbed_tensor) |
| 436 | + |
| 437 | + # Assert: image should have custom baseline color |
| 438 | + img_array = result["pixel_values"].numpy().astype(np.uint8) |
| 439 | + self.assertTrue(np.all(img_array[:, :, 0] == 0)) |
| 440 | + self.assertTrue(np.all(img_array[:, :, 1] == 128)) |
| 441 | + self.assertTrue(np.all(img_array[:, :, 2] == 255)) |
| 442 | + |
| 443 | + def test_format_attr_without_mask(self) -> None: |
| 444 | + # Setup: create MMImageMaskInput without mask |
| 445 | + image = self._create_test_image(width=5, height=5) |
| 446 | + mm_input = MMImageMaskInput( |
| 447 | + processor_fn=self._simple_processor, |
| 448 | + image=image, |
| 449 | + ) |
| 450 | + |
| 451 | + # Execute: format attribution for single feature |
| 452 | + attr = torch.tensor([[0.5]]) |
| 453 | + result = mm_input.format_attr(attr) |
| 454 | + |
| 455 | + # Assert: attribution should be broadcast to all pixels |
| 456 | + self.assertEqual(result.shape, (1, 5, 5)) |
| 457 | + self.assertTrue(torch.all(result == 0.5)) |
| 458 | + |
| 459 | + def test_format_attr_with_mask(self) -> None: |
| 460 | + # Setup: create MMImageMaskInput with 2 segments |
| 461 | + image = self._create_test_image(width=10, height=5) |
| 462 | + mask = torch.zeros((5, 10), dtype=torch.int32) |
| 463 | + mask[:, 5:] = 1 # Right half (columns 5+) is segment 1 |
| 464 | + |
| 465 | + mm_input = MMImageMaskInput( |
| 466 | + processor_fn=self._simple_processor, |
| 467 | + image=image, |
| 468 | + mask=mask, |
| 469 | + ) |
| 470 | + |
| 471 | + # Execute: format attribution with different values for each segment |
| 472 | + attr = torch.tensor([[0.3, 0.7]]) |
| 473 | + result = mm_input.format_attr(attr) |
| 474 | + |
| 475 | + # Assert: left half should have 0.3, right half should have 0.7 |
| 476 | + self.assertEqual(result.shape, (1, 5, 10)) |
| 477 | + assertTensorAlmostEqual( |
| 478 | + self, result[0, :, :5], torch.full((5, 5), 0.3) |
| 479 | + ) # Left half |
| 480 | + assertTensorAlmostEqual( |
| 481 | + self, result[0, :, 5:], torch.full((5, 5), 0.7) |
| 482 | + ) # Right half |
| 483 | + |
| 484 | + def test_format_attr_with_non_continuous_mask(self) -> None: |
| 485 | + # Setup: create mask with non-continuous IDs |
| 486 | + image = self._create_test_image(width=15, height=5) |
| 487 | + mask = torch.zeros((5, 15), dtype=torch.int32) |
| 488 | + mask[:, 5:10] = 10 |
| 489 | + mask[:, 10:] = 20 |
| 490 | + |
| 491 | + mm_input = MMImageMaskInput( |
| 492 | + processor_fn=self._simple_processor, |
| 493 | + image=image, |
| 494 | + mask=mask, |
| 495 | + ) |
| 496 | + |
| 497 | + # Execute: format attribution |
| 498 | + attr = torch.tensor([[0.1, 0.2, 0.3]]) |
| 499 | + result = mm_input.format_attr(attr) |
| 500 | + |
| 501 | + # Assert: verify correct attribution values for each segment |
| 502 | + self.assertEqual(result.shape, (1, 5, 15)) |
| 503 | + # Find which continuous index maps to which mask ID |
| 504 | + idx_0 = mm_input.mask_id_to_idx[0] |
| 505 | + idx_10 = mm_input.mask_id_to_idx[10] |
| 506 | + idx_20 = mm_input.mask_id_to_idx[20] |
| 507 | + |
| 508 | + # Verify each segment has its corresponding attribution |
| 509 | + segment_0_value = attr[0, idx_0].item() |
| 510 | + segment_10_value = attr[0, idx_10].item() |
| 511 | + segment_20_value = attr[0, idx_20].item() |
| 512 | + |
| 513 | + self.assertTrue(torch.all(result[0, :, :5] == segment_0_value)) |
| 514 | + self.assertTrue(torch.all(result[0, :, 5:10] == segment_10_value)) |
| 515 | + self.assertTrue(torch.all(result[0, :, 10:] == segment_20_value)) |
0 commit comments