diff --git a/pyrit/scenario/core/dataset_configuration.py b/pyrit/scenario/core/dataset_configuration.py index 25cd9162c..904a78355 100644 --- a/pyrit/scenario/core/dataset_configuration.py +++ b/pyrit/scenario/core/dataset_configuration.py @@ -41,6 +41,12 @@ class DatasetConfiguration: scenario_strategies (Optional[Sequence[ScenarioStrategy]]): The scenario strategies being executed. Subclasses can use this to filter or customize which seed groups are loaded based on the selected strategies. + + Subclassing notes: + Memoization lives in ``get_seed_groups()`` and ``get_all_seeds()`` — + the two methods that call ``random.sample``. Overrides of those, or + new resolution methods that introduce their own randomness, must + memoize explicitly to preserve lifetime-stable sampling. """ def __init__( @@ -75,14 +81,37 @@ def __init__( "or 'dataset_names' to load from memory." ) - if max_dataset_size is not None and max_dataset_size < 1: - raise ValueError("'max_dataset_size' must be a positive integer (>= 1).") - - # Store private attributes + # Caches must exist before the max_dataset_size setter runs. self._seed_groups = list(seed_groups) if seed_groups is not None else None - self.max_dataset_size = max_dataset_size self._dataset_names = list(dataset_names) if dataset_names is not None else None self._scenario_strategies = scenario_strategies + self._resolved_groups_cache: Optional[dict[str, list[SeedGroup]]] = None + self._resolved_seeds_cache: Optional[list[Seed]] = None + self._max_dataset_size: Optional[int] = None + self.max_dataset_size = max_dataset_size # validates via setter + + @property + def max_dataset_size(self) -> Optional[int]: + """ + Maximum number of SeedGroups to sample per dataset. + + When set, the configuration samples a stable random subset on first + resolution and reuses that subset for the lifetime of the + configuration object (or until this attribute is reassigned). + Reassigning invalidates the cached sample so the next resolution + produces a fresh subset. + """ + return self._max_dataset_size + + @max_dataset_size.setter + def max_dataset_size(self, value: Optional[int]) -> None: + if value is not None and value < 1: + raise ValueError("'max_dataset_size' must be a positive integer (>= 1).") + self._max_dataset_size = value + # Invalidate any previously resolved sample so the next call + # re-samples against the new cap. + self._resolved_groups_cache = None + self._resolved_seeds_cache = None def get_seed_groups(self) -> dict[str, list[SeedGroup]]: """ @@ -94,6 +123,11 @@ def get_seed_groups(self) -> dict[str, list[SeedGroup]]: In all cases, max_dataset_size is applied **per dataset** if set. + The resolved sample is cached for the lifetime of the configuration + (until ``max_dataset_size`` is reassigned). A defensive container + copy is returned on each call so the cache survives caller-side + mutation of the dict or per-dataset lists. + Subclasses can override this to filter or customize which seed groups are loaded based on the stored scenario_composites. @@ -106,6 +140,9 @@ def get_seed_groups(self) -> dict[str, list[SeedGroup]]: Raises: ValueError: If no seed groups could be resolved from the configuration. """ + if self._resolved_groups_cache is not None: + return {name: list(groups) for name, groups in self._resolved_groups_cache.items()} + result: dict[str, list[SeedGroup]] = {} if self._seed_groups is not None: @@ -129,7 +166,9 @@ def get_seed_groups(self) -> dict[str, list[SeedGroup]]: if not result: raise ValueError("DatasetConfiguration has no seed_groups. Set seed_groups or dataset_names.") - return result + self._resolved_groups_cache = result + # Defensive copy: caller must not be able to mutate the cache. + return {name: list(groups) for name, groups in result.items()} def _load_seed_groups_for_dataset(self, *, dataset_name: str) -> list[SeedGroup]: """ @@ -256,6 +295,11 @@ def get_all_seeds(self) -> list[Seed]: from memory for all configured datasets. If max_dataset_size is set, randomly samples up to that many prompts per dataset (without replacement). + The resolved sample is cached for the lifetime of the configuration + (until ``max_dataset_size`` is reassigned). A defensive list copy + is returned on each call so the cache survives caller-side + mutation. + Returns: List[SeedPrompt]: List of SeedPrompt objects from all configured datasets. Returns an empty list if no prompts are found. @@ -266,6 +310,9 @@ def get_all_seeds(self) -> list[Seed]: if self._dataset_names is None: raise ValueError("No dataset names configured. Set dataset_names to use get_all_seed_prompts.") + if self._resolved_seeds_cache is not None: + return list(self._resolved_seeds_cache) + memory = CentralMemory.get_memory_instance() all_seeds: list[Seed] = [] @@ -277,4 +324,5 @@ def get_all_seeds(self) -> list[Seed]: seeds = random.sample(seeds, self.max_dataset_size) all_seeds.extend(seeds) - return all_seeds + self._resolved_seeds_cache = all_seeds + return list(all_seeds) diff --git a/tests/unit/scenario/test_dataset_configuration.py b/tests/unit/scenario/test_dataset_configuration.py index e1b5c6872..5a3557e4b 100644 --- a/tests/unit/scenario/test_dataset_configuration.py +++ b/tests/unit/scenario/test_dataset_configuration.py @@ -515,3 +515,197 @@ def test_get_all_seeds_returns_empty_list_when_no_seeds_in_memory(self) -> None: result = config.get_all_seeds() assert result == [] + + +@pytest.mark.usefixtures("patch_central_database") +class TestDatasetConfigurationMemoization: + """Tests for memoization of resolved seed groups and seeds. + + Pins the contract that the random subset selected when ``max_dataset_size`` + is set is stable for the lifetime of the configuration object. ADO 9012 + regression tests live here; flakiness is avoided by patching + ``random.sample`` rather than relying on RNG seeds. + """ + + def _make_seed_groups(self, count: int) -> list[SeedGroup]: + return [SeedGroup(seeds=[SeedObjective(value=f"obj{i}")]) for i in range(count)] + + def test_get_seed_groups_is_stable_across_calls_with_max_dataset_size(self) -> None: + seed_groups = self._make_seed_groups(10) + config = DatasetConfiguration(seed_groups=seed_groups, max_dataset_size=3) + + first_sample = seed_groups[:3] + second_sample = seed_groups[3:6] + with patch( + "pyrit.scenario.core.dataset_configuration.random.sample", + side_effect=[first_sample, second_sample], + ) as mock_sample: + first = config.get_seed_groups() + second = config.get_seed_groups() + + assert first[EXPLICIT_SEED_GROUPS_KEY] == first_sample + assert second[EXPLICIT_SEED_GROUPS_KEY] == first_sample + assert mock_sample.call_count == 1 + + def test_get_seed_groups_is_stable_across_multi_dataset(self) -> None: + ds1 = self._make_seed_groups(10) + ds2 = self._make_seed_groups(10) + + def mock_load(*, dataset_name: str) -> list[SeedGroup]: + return ds1 if dataset_name == "ds1" else ds2 + + config = DatasetConfiguration(dataset_names=["ds1", "ds2"], max_dataset_size=3) + + ds1_sample = ds1[:3] + ds2_sample = ds2[:3] + with ( + patch.object(config, "_load_seed_groups_for_dataset", side_effect=mock_load), + patch( + "pyrit.scenario.core.dataset_configuration.random.sample", + side_effect=[ds1_sample, ds2_sample, ds1[3:6], ds2[3:6]], + ) as mock_sample, + ): + first = config.get_seed_groups() + second = config.get_seed_groups() + + assert first["ds1"] == ds1_sample + assert first["ds2"] == ds2_sample + assert second["ds1"] == ds1_sample + assert second["ds2"] == ds2_sample + assert mock_sample.call_count == 2 # one per dataset, on the first call only + + def test_get_all_seed_attack_groups_is_stable_across_calls(self) -> None: + seed_groups = self._make_seed_groups(10) + config = DatasetConfiguration(seed_groups=seed_groups, max_dataset_size=3) + + with patch( + "pyrit.scenario.core.dataset_configuration.random.sample", + side_effect=[seed_groups[:3], seed_groups[3:6]], + ): + first = config.get_all_seed_attack_groups() + second = config.get_all_seed_attack_groups() + + first_objectives = [g.objective.value for g in first] + second_objectives = [g.objective.value for g in second] + assert first_objectives == second_objectives + + def test_get_all_seeds_is_stable_across_calls(self) -> None: + seeds = [SeedPrompt(value=f"seed{i}", data_type="text") for i in range(10)] + + with patch("pyrit.scenario.core.dataset_configuration.CentralMemory") as mock_memory_class: + mock_memory = MagicMock() + mock_memory.get_seeds.return_value = seeds + mock_memory_class.get_memory_instance.return_value = mock_memory + + config = DatasetConfiguration(dataset_names=["d1"], max_dataset_size=3) + + with patch( + "pyrit.scenario.core.dataset_configuration.random.sample", + side_effect=[seeds[:3], seeds[3:6]], + ) as mock_sample: + first = config.get_all_seeds() + second = config.get_all_seeds() + + assert first == seeds[:3] + assert second == seeds[:3] + assert mock_sample.call_count == 1 + + def test_returned_dict_can_be_mutated_without_poisoning_cache(self) -> None: + seed_groups = self._make_seed_groups(10) + config = DatasetConfiguration(seed_groups=seed_groups, max_dataset_size=3) + + with patch( + "pyrit.scenario.core.dataset_configuration.random.sample", + return_value=seed_groups[:3], + ): + first = config.get_seed_groups() + first[EXPLICIT_SEED_GROUPS_KEY].clear() + first.pop(EXPLICIT_SEED_GROUPS_KEY, None) + second = config.get_seed_groups() + + assert second[EXPLICIT_SEED_GROUPS_KEY] == seed_groups[:3] + + def test_returned_seeds_list_can_be_mutated_without_poisoning_cache(self) -> None: + seeds = [SeedPrompt(value=f"seed{i}", data_type="text") for i in range(10)] + + with patch("pyrit.scenario.core.dataset_configuration.CentralMemory") as mock_memory_class: + mock_memory = MagicMock() + mock_memory.get_seeds.return_value = seeds + mock_memory_class.get_memory_instance.return_value = mock_memory + + config = DatasetConfiguration(dataset_names=["d1"], max_dataset_size=3) + + with patch( + "pyrit.scenario.core.dataset_configuration.random.sample", + return_value=seeds[:3], + ): + first = config.get_all_seeds() + first.clear() + second = config.get_all_seeds() + + assert second == seeds[:3] + + +@pytest.mark.usefixtures("patch_central_database") +class TestDatasetConfigurationMaxDatasetSizeSetter: + """Tests for the ``max_dataset_size`` property setter.""" + + def test_setter_invalidates_groups_cache(self) -> None: + seed_groups = [SeedGroup(seeds=[SeedObjective(value=f"obj{i}")]) for i in range(10)] + config = DatasetConfiguration(seed_groups=seed_groups, max_dataset_size=3) + + first_sample = seed_groups[:3] + second_sample = seed_groups[5:8] + with patch( + "pyrit.scenario.core.dataset_configuration.random.sample", + side_effect=[first_sample, second_sample], + ): + first = config.get_seed_groups() + config.max_dataset_size = 3 # reassign (same value triggers invalidation) + second = config.get_seed_groups() + + assert first[EXPLICIT_SEED_GROUPS_KEY] == first_sample + assert second[EXPLICIT_SEED_GROUPS_KEY] == second_sample + + def test_setter_invalidates_seeds_cache(self) -> None: + seeds = [SeedPrompt(value=f"seed{i}", data_type="text") for i in range(10)] + + with patch("pyrit.scenario.core.dataset_configuration.CentralMemory") as mock_memory_class: + mock_memory = MagicMock() + mock_memory.get_seeds.return_value = seeds + mock_memory_class.get_memory_instance.return_value = mock_memory + + config = DatasetConfiguration(dataset_names=["d1"], max_dataset_size=3) + + with patch( + "pyrit.scenario.core.dataset_configuration.random.sample", + side_effect=[seeds[:3], seeds[5:8]], + ): + first = config.get_all_seeds() + config.max_dataset_size = 3 + second = config.get_all_seeds() + + assert first == seeds[:3] + assert second == seeds[5:8] + + def test_setter_rejects_zero(self) -> None: + config = DatasetConfiguration(seed_groups=[SeedGroup(seeds=[SeedObjective(value="obj")])]) + + with pytest.raises(ValueError, match="must be a positive integer"): + config.max_dataset_size = 0 + + def test_setter_rejects_negative(self) -> None: + config = DatasetConfiguration(seed_groups=[SeedGroup(seeds=[SeedObjective(value="obj")])]) + + with pytest.raises(ValueError, match="must be a positive integer"): + config.max_dataset_size = -1 + + def test_setter_accepts_none(self) -> None: + config = DatasetConfiguration( + seed_groups=[SeedGroup(seeds=[SeedObjective(value="obj")])], + max_dataset_size=5, + ) + + config.max_dataset_size = None + + assert config.max_dataset_size is None diff --git a/tests/unit/scenario/test_encoding.py b/tests/unit/scenario/test_encoding.py index 0df8435a8..e09b83884 100644 --- a/tests/unit/scenario/test_encoding.py +++ b/tests/unit/scenario/test_encoding.py @@ -399,3 +399,34 @@ def test_encoding_dataset_config_can_be_initialized_with_dataset_names(self): assert config._dataset_names == ["garak_slur_terms_en", "garak_web_html_js"] assert config.max_dataset_size == 5 + + def test_get_all_seed_attack_groups_is_stable_across_calls_with_max_dataset_size(self): + """Regression test for ADO 9012 (Path 2). + + EncodingDatasetConfiguration.get_all_seed_attack_groups overrides the + base method and routes through get_all_seeds, which has its own + random.sample. Memoizing only get_seed_groups would not catch this + path; this test pins that the override is stable across calls. + """ + from unittest.mock import patch + + seeds = [SeedPrompt(value=f"seed{i}", data_type="text") for i in range(10)] + + with patch("pyrit.scenario.core.dataset_configuration.CentralMemory") as mock_memory_class: + mock_memory = MagicMock() + mock_memory.get_seeds.return_value = seeds + mock_memory_class.get_memory_instance.return_value = mock_memory + + config = EncodingDatasetConfiguration(dataset_names=["d1"], max_dataset_size=3) + + with patch( + "pyrit.scenario.core.dataset_configuration.random.sample", + side_effect=[seeds[:3], seeds[3:6]], + ) as mock_sample: + first = config.get_all_seed_attack_groups() + second = config.get_all_seed_attack_groups() + + first_objectives = [g.objective.value for g in first] + second_objectives = [g.objective.value for g in second] + assert first_objectives == second_objectives + assert mock_sample.call_count == 1 diff --git a/tests/unit/scenario/test_scenario.py b/tests/unit/scenario/test_scenario.py index bbead3840..b042fd11a 100644 --- a/tests/unit/scenario/test_scenario.py +++ b/tests/unit/scenario/test_scenario.py @@ -884,6 +884,67 @@ async def test_execute_scenario_raises_when_scenario_result_id_is_none(): await scenario._execute_scenario_async() +@pytest.mark.usefixtures("patch_central_database") +class TestScenarioBaselineUniformObjectives: + """ADO 9012 regression: baseline and strategy atomic attacks share objectives. + + Without memoization in DatasetConfiguration, ``_get_atomic_attacks_async`` + and ``_get_baseline_data`` each call ``get_all_seed_attack_groups()`` + independently and ``random.sample`` produces a different subset for + each. With memoization, both calls converge on the same subset. + """ + + async def test_baseline_objectives_match_atomic_attacks_under_max_dataset_size( + self, + mock_objective_target, + ): + from pyrit.models import SeedGroup, SeedObjective + from pyrit.scenario.core.attack_technique import AttackTechnique + + seed_groups = [SeedGroup(seeds=[SeedObjective(value=f"obj{i}")]) for i in range(10)] + + class StrategyScenario(ConcreteScenarioWithTrueFalseScorer): + async def _get_atomic_attacks_async(self): + groups = self._dataset_config.get_all_seed_attack_groups() + return [ + AtomicAttack( + atomic_attack_name="strategy", + attack_technique=AttackTechnique(attack=MagicMock()), + seed_groups=groups, + ) + ] + + scenario = StrategyScenario( + name="ADO 9012 regression", + version=1, + include_default_baseline=True, + ) + + config = DatasetConfiguration(seed_groups=seed_groups, max_dataset_size=3) + + # Two distinct samples: a non-memoized implementation would consume + # both (one for the strategy call, one for the baseline call) and + # the assertion below would fail. Memoization consumes only the first. + first_sample = seed_groups[:3] + second_sample = seed_groups[5:8] + with patch( + "pyrit.scenario.core.dataset_configuration.random.sample", + side_effect=[first_sample, second_sample], + ): + await scenario.initialize_async( + objective_target=mock_objective_target, + scenario_strategies=None, + dataset_config=config, + ) + + baseline = scenario._atomic_attacks[0] + strategy = scenario._atomic_attacks[1] + assert baseline.atomic_attack_name == "baseline" + assert strategy.atomic_attack_name == "strategy" + assert set(baseline.objectives) == set(strategy.objectives) + assert len(baseline.objectives) == 3 + + @pytest.mark.usefixtures("patch_central_database") class TestValidateStoredScenario: """Tests for Scenario._validate_stored_scenario."""