From 115de2515ab48aacc94c48393cedfc9b0609c772 Mon Sep 17 00:00:00 2001 From: Yingxiao Ye Date: Mon, 24 Nov 2025 14:07:26 -0800 Subject: [PATCH] Finalize Example-Level Shuffling (for non-GDT EBFs) Implementation and Update Test Cases 3/n Summary: This diff completes the implementation of example-level shuffling for non-GDT format EBF and updates feature comparison and test utilities to support the new data structure. Key changes include: # Core Implementation - Finalize `group_events_by_example` and `flatten_events `in sigrid.py - These methods now fully support grouping and flattening features by example, handling lists of tensors as required. # Test Utilities - Update `find_perturbed_features` in `sigrid.py` - The function now supports comparison of both tensors and lists of tensors, ensuring correct detection of perturbed features regardless of feature structure. - Add `assertTupleOfListOfTensorsAlmostEqual` in `basic.py` - Enhanced to assert equality for tuples of lists of tensors, supporting the new output format from example-level shuffling. - Add `extracted_features_equal` in `basic.py` - Now recursively checks equality for tensors, lists of tensors, and tuples containing either, ensuring robust feature comparison. # Test Cases - Update Test Cases in test_feature_transform.py - When test_ebf_example_shuffling is enabled, uses assertTupleOfListOfTensorsAlmostEqual for feature comparison. - Otherwise, falls back to assertTensorTuplesAlmostEqual. - Update/Expand Test Cases in `feature_transform_tests.py` - Test cases for EBF example-level shuffling are updated to reflect the finalized grouping/flattening behavior. # Summary This diff ensures that the EBF example-level shuffling logic is fully implemented and that all relevant comparison and test utilities are compatible with the new extracted feature structure. All related test cases are updated to validate the new behavior, providing a robust foundation for future development and maintenance. Differential Revision: D87590987 --- captum/testing/helpers/basic.py | 56 +++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/captum/testing/helpers/basic.py b/captum/testing/helpers/basic.py index 129d322aa..a766de7b6 100644 --- a/captum/testing/helpers/basic.py +++ b/captum/testing/helpers/basic.py @@ -91,6 +91,35 @@ def assertTensorTuplesAlmostEqual( assertTensorAlmostEqual(test, actual, expected, delta, mode) +def assertTupleOfListOfTensorsAlmostEqual( + test: unittest.TestCase, + # pyre-fixme[2]: Parameter must be annotated. + actual, + # pyre-fixme[2]: Parameter must be annotated. + expected, + delta: float = 0.0001, + mode: str = "sum", +) -> None: + if isinstance(expected, tuple): + assert isinstance(actual, tuple) and isinstance( + expected, tuple + ), f"Both actual and expected must be tuples, got {type(actual)} and {type(expected)}" + assert len(actual) == len( + expected + ), f"Tuple lengths differ: {len(actual)} != {len(expected)}" + for i, (actual_list, expected_list) in enumerate(zip(actual, expected)): + assert isinstance(actual_list, list) and isinstance( + expected_list, list + ), f"Elements at index {i} must be lists, got {type(actual_list)} and {type(expected_list)}" + assert len(actual_list) == len( + expected_list + ), f"List lengths at tuple index {i} differ: {len(actual_list)} != {len(expected_list)}" + for j, (a_tensor, e_tensor) in enumerate(zip(actual_list, expected_list)): + assertTensorAlmostEqual(test, a_tensor, e_tensor, delta, mode) + else: + assertTensorAlmostEqual(test, actual, expected, delta, mode) + + def assertAttributionComparision( test: unittest.TestCase, attributions1: Union[Tensor, Tuple[Tensor, ...]], @@ -149,3 +178,30 @@ class BaseTest(unittest.TestCase): def setUp(self) -> None: set_all_random_seeds(1234) patch_methods(self) + + +def extracted_features_equal(a: Any, b: Any) -> bool: + """ + Recursively checks if two extracted feature structures are equal. + The structures can be: + - torch.Tensor + - list of torch.Tensor + - tuple of (torch.Tensor or list of torch.Tensor) + Args: + a: First extracted feature (tensor, list, or tuple). + b: Second extracted feature (tensor, list, or tuple). + Returns: + bool: True if the structures are equal, False otherwise. + """ + if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor): + return torch.equal(a, b) + elif isinstance(a, list) and isinstance(b, list): + if len(a) != len(b): + return False + return all(torch.equal(x, y) for x, y in zip(a, b)) + elif isinstance(a, tuple) and isinstance(b, tuple): + if len(a) != len(b): + return False + return all(extracted_features_equal(x, y) for x, y in zip(a, b)) + else: + return False