diff --git a/.pyrit_conf_example b/.pyrit_conf_example index c45bb390ce..9d9e66305d 100644 --- a/.pyrit_conf_example +++ b/.pyrit_conf_example @@ -111,6 +111,12 @@ operation: op_trash_panda # - /path/to/.env # - /path/to/.env.local +# Max Concurrent Scenario Runs +# ---------------------------- +# Maximum number of scenario runs that can execute concurrently in the backend. +# Applies only to the pyrit_backend server. +max_concurrent_scenario_runs: 3 + # Silent Mode # ----------- # If true, suppresses print statements during initialization. diff --git a/pyrit/backend/models/scenarios.py b/pyrit/backend/models/scenarios.py index a47e431805..494972ddfb 100644 --- a/pyrit/backend/models/scenarios.py +++ b/pyrit/backend/models/scenarios.py @@ -5,10 +5,12 @@ Scenario API response models. Scenarios are multi-attack security testing campaigns. These models represent -the metadata about available scenarios (listing), not scenario execution results. +the metadata about available scenarios (listing) and scenario execution (runs). """ -from typing import Optional +from datetime import datetime +from enum import StrEnum +from typing import Any, Optional from pydantic import BaseModel, Field @@ -35,3 +37,126 @@ class ScenarioListResponse(BaseModel): items: list[ScenarioSummary] = Field(..., description="List of scenario summaries") pagination: PaginationInfo = Field(..., description="Pagination metadata") + + +# ============================================================================ +# Scenario Run Models +# ============================================================================ + + +class ScenarioRunStatus(StrEnum): + """Status of a scenario run.""" + + PENDING = "pending" + INITIALIZING = "initializing" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + + +class RunScenarioRequest(BaseModel): + """Request body for starting a scenario run.""" + + scenario_name: str = Field(..., description="Registry key of the scenario to run") + target_name: str = Field(..., description="Name of a registered target from the TargetRegistry") + initializers: list[str] | None = Field( + None, description="Initializer names to run before scenario (e.g., ['target', 'load_default_datasets'])" + ) + strategies: list[str] | None = Field(None, description="Strategy names to use (uses scenario default if omitted)") + dataset_names: list[str] | None = Field(None, description="Dataset names to use (uses scenario default if omitted)") + max_dataset_size: int | None = Field(None, ge=1, description="Maximum items per dataset") + max_concurrency: int = Field(10, ge=1, le=100, description="Maximum concurrent operations") + max_retries: int = Field(0, ge=0, le=20, description="Maximum retry attempts on failure") + memory_labels: dict[str, str] | None = Field(None, description="Labels to attach to memory entries") + scenario_params: dict[str, Any] | None = Field( + None, + description="Custom parameters for the scenario (passed to scenario.set_params_from_args). " + "Keys are parameter names declared by the scenario's supported_parameters().", + ) + initializer_args: dict[str, dict[str, Any]] | None = Field( + None, + description="Per-initializer arguments keyed by initializer name. " + "Each value is a dict of args passed to that initializer's set_params_from_args(). " + "Example: {'target': {'endpoint': 'https://...'}}.", + ) + scenario_result_id: str | None = Field( + None, + description="Optional ID of an existing ScenarioResult to resume. 
" + "If provided, the scenario will resume from prior progress instead of starting fresh.", + ) + + +class ScenarioRunResult(BaseModel): + """Summary of a completed scenario run's results.""" + + scenario_result_id: str = Field(..., description="UUID of the ScenarioResult in memory") + run_state: str = Field(..., description="Final scenario run state (COMPLETED, FAILED)") + strategies_used: list[str] = Field(..., description="Strategy names that were executed") + total_attacks: int = Field(..., ge=0, description="Total number of atomic attacks") + completed_attacks: int = Field(..., ge=0, description="Number of attacks that completed") + number_tries: int = Field(..., ge=0, description="Number of execution attempts") + completion_time: datetime | None = Field(None, description="When the scenario finished") + + +class ScenarioRunResponse(BaseModel): + """Response for a scenario run (status + optional result).""" + + run_id: str = Field(..., description="Unique identifier for this run") + scenario_name: str = Field(..., description="Registry key of the scenario being run") + status: ScenarioRunStatus = Field(..., description="Current run status") + created_at: datetime = Field(..., description="When the run was created") + updated_at: datetime = Field(..., description="When the run status last changed") + error: str | None = Field(None, description="Error message if status is FAILED") + result: ScenarioRunResult | None = Field(None, description="Result details if status is COMPLETED") + + +class ScenarioRunListResponse(BaseModel): + """Response for listing scenario runs.""" + + items: list[ScenarioRunResponse] = Field(..., description="List of scenario runs") + + +# ============================================================================ +# Scenario Results Detail Models +# ============================================================================ + + +class AttackResultDetail(BaseModel): + """Detailed result of a single attack within a scenario.""" + + attack_result_id: str = Field(..., description="Unique ID of this attack result") + conversation_id: str = Field(..., description="Conversation ID that produced this result") + objective: str = Field(..., description="Natural-language description of the attacker's objective") + outcome: str = Field(..., description="Attack outcome: success, failure, or undetermined") + outcome_reason: str | None = Field(None, description="Reason for the outcome") + last_response: str | None = Field(None, description="Model response from the final turn") + score_value: str | None = Field(None, description="Score value from the objective scorer") + executed_turns: int = Field(0, ge=0, description="Number of turns executed") + execution_time_ms: int = Field(0, ge=0, description="Execution time in milliseconds") + timestamp: datetime | None = Field(None, description="When the result was created") + + +class AtomicAttackResults(BaseModel): + """Results grouped by atomic attack name.""" + + atomic_attack_name: str = Field(..., description="Name of the atomic attack (strategy)") + display_group: str | None = Field(None, description="Display group label for UI grouping") + results: list[AttackResultDetail] = Field(..., description="Individual attack results") + success_count: int = Field(0, ge=0, description="Number of successful attacks") + failure_count: int = Field(0, ge=0, description="Number of failed attacks") + total_count: int = Field(0, ge=0, description="Total number of attack results") + + +class ScenarioResultDetailResponse(BaseModel): + """Full 
detailed results of a scenario run.""" + + scenario_result_id: str = Field(..., description="UUID of the ScenarioResult") + scenario_name: str = Field(..., description="Name of the scenario") + scenario_version: int = Field(..., description="Version of the scenario") + run_state: str = Field(..., description="Final run state (COMPLETED, FAILED, etc.)") + objective_achieved_rate: int = Field(..., ge=0, le=100, description="Success rate as percentage (0-100)") + number_tries: int = Field(..., ge=0, description="Number of execution attempts") + completion_time: datetime | None = Field(None, description="When the scenario finished") + labels: dict[str, str] = Field(default_factory=dict, description="Labels attached to this run") + attacks: list[AtomicAttackResults] = Field(..., description="Results grouped by atomic attack") diff --git a/pyrit/backend/routes/scenarios.py b/pyrit/backend/routes/scenarios.py index 9cd3e2ef43..77f73a38c5 100644 --- a/pyrit/backend/routes/scenarios.py +++ b/pyrit/backend/routes/scenarios.py @@ -4,7 +4,12 @@ """ Scenario API routes. -Provides endpoints for listing available scenarios and their metadata. +Provides endpoints for listing available scenarios, their metadata, +and managing scenario runs. + +Route structure: + /api/scenarios/catalog — scenario catalog (list + detail) + /api/scenarios/runs — scenario execution lifecycle """ from typing import Optional @@ -12,14 +17,27 @@ from fastapi import APIRouter, HTTPException, Query, status from pyrit.backend.models.common import ProblemDetail -from pyrit.backend.models.scenarios import ScenarioListResponse, ScenarioSummary +from pyrit.backend.models.scenarios import ( + RunScenarioRequest, + ScenarioListResponse, + ScenarioResultDetailResponse, + ScenarioRunListResponse, + ScenarioRunResponse, + ScenarioSummary, +) +from pyrit.backend.services.scenario_run_service import get_scenario_run_service from pyrit.backend.services.scenario_service import get_scenario_service router = APIRouter(prefix="/scenarios", tags=["scenarios"]) +# ============================================================================ +# Scenario Catalog +# ============================================================================ + + @router.get( - "", + "/catalog", response_model=ScenarioListResponse, ) async def list_scenarios( @@ -30,7 +48,7 @@ async def list_scenarios( List all available scenarios. Returns scenario metadata including strategies, datasets, and defaults. - Use GET /api/scenarios/{scenario_name} for full details on a specific scenario. + Use GET /api/scenarios/catalog/{scenario_name} for full details on a specific scenario. Returns: ScenarioListResponse: Paginated list of scenario summaries. 
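+
+    Example (illustrative; scenario names depend on the registry contents):
+        GET /api/scenarios/catalog lists scenario summaries;
+        GET /api/scenarios/catalog/foundry.red_team_agent returns one entry.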
@@ -40,7 +58,7 @@ async def list_scenarios(
 
 
 @router.get(
-    "/{scenario_name:path}",
+    "/catalog/{scenario_name:path}",
     response_model=ScenarioSummary,
     responses={
         404: {"model": ProblemDetail, "description": "Scenario not found"},
@@ -66,3 +84,147 @@ async def get_scenario(scenario_name: str) -> ScenarioSummary:
         )
 
     return scenario
+
+
+# ============================================================================
+# Scenario Runs
+# ============================================================================
+
+
+@router.post(
+    "/runs",
+    response_model=ScenarioRunResponse,
+    status_code=status.HTTP_202_ACCEPTED,
+    responses={
+        400: {"model": ProblemDetail, "description": "Invalid request (bad scenario/target/strategy)"},
+    },
+)
+async def start_scenario_run(request: RunScenarioRequest) -> ScenarioRunResponse:
+    """
+    Start a new scenario run as a background task.
+
+    Returns immediately with a run_id that can be polled for status.
+
+    Args:
+        request: Scenario run configuration.
+
+    Returns:
+        ScenarioRunResponse: Run metadata with the initial run status.
+    """
+    service = get_scenario_run_service()
+    try:
+        return await service.start_run_async(request=request)
+    except ValueError as e:
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from None
+
+
+@router.get(
+    "/runs",
+    response_model=ScenarioRunListResponse,
+)
+async def list_scenario_runs(limit: int = Query(100, ge=1)) -> ScenarioRunListResponse:
+    """
+    List tracked scenario runs (most recent first).
+
+    Args:
+        limit (int): Maximum number of runs to return. Defaults to 100.
+
+    Returns:
+        ScenarioRunListResponse: Runs, most recent first.
+    """
+    service = get_scenario_run_service()
+    return service.list_runs(limit=limit)
+
+
+@router.get(
+    "/runs/{run_id}",
+    response_model=ScenarioRunResponse,
+    responses={
+        404: {"model": ProblemDetail, "description": "Run not found"},
+    },
+)
+async def get_scenario_run(run_id: str) -> ScenarioRunResponse:
+    """
+    Get the current status and result of a scenario run.
+
+    Args:
+        run_id: The unique run identifier returned by POST /runs.
+
+    Returns:
+        ScenarioRunResponse: Current run status (and result if completed).
+    """
+    service = get_scenario_run_service()
+    run = service.get_run(run_id=run_id)
+    if run is None:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=f"Scenario run '{run_id}' not found",
+        )
+    return run
+
+
+@router.delete(
+    "/runs/{run_id}",
+    response_model=ScenarioRunResponse,
+    responses={
+        404: {"model": ProblemDetail, "description": "Run not found"},
+        409: {"model": ProblemDetail, "description": "Run already in terminal state"},
+    },
+)
+async def cancel_scenario_run(run_id: str) -> ScenarioRunResponse:
+    """
+    Cancel a running scenario.
+
+    Args:
+        run_id: The unique run identifier to cancel.
+
+    Returns:
+        ScenarioRunResponse: Updated run with CANCELLED status.
+ """ + service = get_scenario_run_service() + try: + result = await service.cancel_run_async(run_id=run_id) + except ValueError as e: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from None + + if result is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Scenario run '{run_id}' not found", + ) + return result + + +@router.get( + "/runs/{run_id}/results", + response_model=ScenarioResultDetailResponse, + responses={ + 404: {"model": ProblemDetail, "description": "Run not found"}, + 409: {"model": ProblemDetail, "description": "Run not yet completed"}, + }, +) +async def get_scenario_run_results(run_id: str) -> ScenarioResultDetailResponse: + """ + Get detailed results for a completed scenario run. + + Returns per-attack outcomes including objectives, responses, scores, + and success/failure counts. + + Args: + run_id: The unique run identifier. + + Returns: + ScenarioResultDetailResponse: Full attack-level results. + """ + service = get_scenario_run_service() + try: + result = service.get_run_results(run_id=run_id) + except ValueError as e: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from None + + if result is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Scenario run '{run_id}' not found", + ) + return result diff --git a/pyrit/backend/services/__init__.py b/pyrit/backend/services/__init__.py index 29807150ae..d36f69a830 100644 --- a/pyrit/backend/services/__init__.py +++ b/pyrit/backend/services/__init__.py @@ -15,6 +15,10 @@ ConverterService, get_converter_service, ) +from pyrit.backend.services.scenario_run_service import ( + ScenarioRunService, + get_scenario_run_service, +) from pyrit.backend.services.scenario_service import ( ScenarioService, get_scenario_service, @@ -31,6 +35,8 @@ "get_converter_service", "ScenarioService", "get_scenario_service", + "ScenarioRunService", + "get_scenario_run_service", "TargetService", "get_target_service", ] diff --git a/pyrit/backend/services/scenario_run_service.py b/pyrit/backend/services/scenario_run_service.py new file mode 100644 index 0000000000..85fe619fa4 --- /dev/null +++ b/pyrit/backend/services/scenario_run_service.py @@ -0,0 +1,484 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Scenario run service for executing scenarios as background tasks. + +Manages the lifecycle of scenario runs: starting, tracking status, +retrieving results, and cancellation. 
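+
+Typical flow (illustrative sketch; request fields are abbreviated):
+
+    service = get_scenario_run_service()
+    run = await service.start_run_async(request=RunScenarioRequest(...))
+    current = service.get_run(run_id=run.run_id)            # poll status
+    details = service.get_run_results(run_id=run.run_id)    # once completed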
+""" + +import asyncio +import logging +from dataclasses import dataclass +from typing import Any + +from pyrit.backend.models.scenarios import ( + AtomicAttackResults, + AttackResultDetail, + RunScenarioRequest, + ScenarioResultDetailResponse, + ScenarioRunListResponse, + ScenarioRunResponse, + ScenarioRunResult, + ScenarioRunStatus, +) +from pyrit.memory import CentralMemory +from pyrit.models import AttackOutcome, ScenarioResult +from pyrit.registry import InitializerRegistry, ScenarioRegistry, TargetRegistry +from pyrit.scenario import Scenario +from pyrit.scenario.core import DatasetConfiguration + +logger = logging.getLogger(__name__) + +_DEFAULT_MAX_CONCURRENT_RUNS = 3 + +# Maps DB ScenarioRunState values to API ScenarioRunStatus +_STATE_TO_STATUS: dict[str, ScenarioRunStatus] = { + "CREATED": ScenarioRunStatus.INITIALIZING, + "IN_PROGRESS": ScenarioRunStatus.RUNNING, + "COMPLETED": ScenarioRunStatus.COMPLETED, + "FAILED": ScenarioRunStatus.FAILED, + "CANCELLED": ScenarioRunStatus.CANCELLED, +} + + +@dataclass +class _ActiveTask: + """Tracks an in-flight scenario run's asyncio task.""" + + scenario_result_id: str + task: asyncio.Task[None] | None = None + scenario: Scenario | None = None + error: str | None = None + + +class ScenarioRunService: + """ + Service for managing scenario run lifecycle. + + Uses CentralMemory (database) as the source of truth for run state. + Keeps an in-memory dict only for active asyncio tasks (cancellation support). + """ + + def __init__(self, *, max_concurrent_runs: int = _DEFAULT_MAX_CONCURRENT_RUNS) -> None: + """Initialize the scenario run service.""" + self._max_concurrent_runs = max_concurrent_runs + self._active_tasks: dict[str, _ActiveTask] = {} + + async def start_run_async(self, *, request: RunScenarioRequest) -> ScenarioRunResponse: + """ + Start a new scenario run as a background task. + + Performs all validation and initialization eagerly (initializers, target + resolution, strategy validation, scenario.initialize_async) so errors are + returned immediately. On success, spawns a background task that only + executes scenario.run_async. + + Args: + request: The run request with scenario name, target, and options. + + Returns: + ScenarioRunResponse with run_id and RUNNING status. + + Raises: + ValueError: If scenario, target, initializer, or strategy cannot be found, + or concurrent limit exceeded. + """ + if sum(1 for a in self._active_tasks.values() if a.task is not None and not a.task.done()) >= self._max_concurrent_runs: + raise ValueError( + f"Maximum concurrent runs ({self._max_concurrent_runs}) reached. " + "Wait for an existing run to complete or cancel one." 
+ ) + + # Perform all initialization eagerly — errors propagate to caller + scenario = await self._initialize_run_async(request=request) + + # scenario_result_id is set during initialize_async + scenario_result_id = scenario._scenario_result_id + if scenario_result_id is None: + raise ValueError("Scenario did not produce a scenario_result_id during initialization.") + + # Track active task + active = _ActiveTask(scenario_result_id=scenario_result_id, scenario=scenario) + self._active_tasks[scenario_result_id] = active + + # Spawn background task (only runs scenario.run_async) + task = asyncio.create_task(self._execute_run_async(scenario_result_id=scenario_result_id)) + active.task = task + + response = self._build_response(scenario_result_id=scenario_result_id) + assert response is not None # guaranteed: we just inserted into DB via initialize_async + return response + + def get_run(self, *, run_id: str) -> ScenarioRunResponse | None: + """ + Get the current status of a scenario run by querying the database. + + Args: + run_id: The scenario result ID (run identifier). + + Returns: + ScenarioRunResponse if found, None otherwise. + """ + return self._build_response(scenario_result_id=run_id) + + def list_runs(self, *, limit: int = 100) -> ScenarioRunListResponse: + """ + List scenario runs by querying the database (most recent first). + + Args: + limit (int): Maximum number of runs to return. Defaults to 100. + + Returns: + ScenarioRunListResponse with runs. + """ + memory = CentralMemory.get_memory_instance() + + # This is expensive, and we don't need all the data. At some point + # we may want to add a lightweight "list" query to the DB layer that only + results = memory.get_scenario_results(limit=limit) + items = [self._build_response_from_db(scenario_result=sr) for sr in results] + return ScenarioRunListResponse(items=items) + + async def cancel_run_async(self, *, run_id: str) -> ScenarioRunResponse | None: + """ + Cancel a running scenario. + + Args: + run_id: The scenario result ID (run identifier). + + Returns: + Updated ScenarioRunResponse if found, None if run_id not found. + + Raises: + ValueError: If the run is already in a terminal state or not active. + """ + # Verify run exists in DB + memory = CentralMemory.get_memory_instance() + results = memory.get_scenario_results(scenario_result_ids=[run_id]) + if not results: + return None + + scenario_result = results[0] + db_status = _STATE_TO_STATUS.get(scenario_result.scenario_run_state, ScenarioRunStatus.FAILED) + + if db_status in (ScenarioRunStatus.COMPLETED, ScenarioRunStatus.FAILED, ScenarioRunStatus.CANCELLED): + raise ValueError(f"Cannot cancel run in '{db_status}' state.") + + # Cancel the asyncio task if active + active = self._active_tasks.get(run_id) + if active is not None: + if active.task is not None and not active.task.done(): + active.task.cancel() + + # Persist cancelled state to DB + memory.update_scenario_run_state(scenario_result_id=run_id, scenario_run_state="CANCELLED") + + return self._build_response(scenario_result_id=run_id) + + async def _initialize_run_async(self, *, request: RunScenarioRequest) -> Scenario: + """ + Validate inputs and initialize the scenario eagerly. + + Performs all validation (scenario, initializers, target, strategies) and + calls scenario.initialize_async so that any errors are raised immediately + to the caller. Running initialization on creation simplifies error handling and ensures + that the scenario is fully ready to run when we spawn the background task. 
+ + Args: + request: The run request with scenario name, target, and options. + + Returns: + The fully initialized Scenario instance ready for run_async. + + Raises: + ValueError: If any validation fails (bad scenario name, missing target, + invalid strategy, unknown initializer, etc.). + """ + # Validate scenario exists + scenario_registry = ScenarioRegistry.get_registry_singleton() + try: + scenario_class = scenario_registry.get_class(request.scenario_name) + except KeyError as e: + raise ValueError(str(e)) from None + + # Validate and run initializers + if request.initializers: + initializer_registry = InitializerRegistry.get_registry_singleton() + for initializer_name in request.initializers: + try: + initializer_class = initializer_registry.get_class(initializer_name) + except KeyError as e: + raise ValueError(f"Initializer not found: {e}") from None + instance = initializer_class() + if request.initializer_args and initializer_name in request.initializer_args: + instance.set_params_from_args(args=request.initializer_args[initializer_name]) + await instance.initialize_async() + + # Resolve target + target_registry = TargetRegistry.get_registry_singleton() + objective_target = target_registry.get_instance_by_name(request.target_name) + if objective_target is None: + available_names = target_registry.get_names() + if not available_names: + raise ValueError( + f"Target '{request.target_name}' not found. The target registry is empty. " + "Make sure to include an initializer that registers targets " + "(e.g., initializers: ['target'])." + ) + raise ValueError( + f"Target '{request.target_name}' not found in registry. Available targets: {', '.join(available_names)}" + ) + + # Build init kwargs + init_kwargs: dict[str, Any] = { + "objective_target": objective_target, + "max_concurrency": request.max_concurrency, + "max_retries": request.max_retries, + } + + if request.memory_labels: + init_kwargs["memory_labels"] = request.memory_labels + + # Validate and resolve strategies + if request.strategies: + strategy_class = scenario_class.get_strategy_class() + strategy_enums = [] + for name in request.strategies: + try: + strategy_enums.append(strategy_class(name)) + except ValueError: + available_strategies = [s.value for s in strategy_class] + raise ValueError( + f"Strategy '{name}' not found for scenario '{request.scenario_name}'. " + f"Available: {', '.join(available_strategies)}" + ) from None + init_kwargs["scenario_strategies"] = strategy_enums + + # Build dataset config + if request.dataset_names: + init_kwargs["dataset_config"] = DatasetConfiguration( + dataset_names=request.dataset_names, + max_dataset_size=request.max_dataset_size, + ) + elif request.max_dataset_size is not None: + default_config = scenario_class.default_dataset_config() + default_config.max_dataset_size = request.max_dataset_size + init_kwargs["dataset_config"] = default_config + + # Instantiate and initialize scenario + constructor_kwargs: dict[str, Any] = {} + if request.scenario_result_id: + constructor_kwargs["scenario_result_id"] = request.scenario_result_id + scenario = scenario_class(**constructor_kwargs) # type: ignore[call-arg] + scenario.set_params_from_args(args=request.scenario_params or {}) + await scenario.initialize_async(**init_kwargs) + return scenario + + async def _execute_run_async(self, *, scenario_result_id: str) -> None: + """ + Execute a scenario run (background task entry point). + + Only calls scenario.run_async on the already-initialized scenario. 
+        Failures are recorded on the active task; finished tasks are pruned
+        from _active_tasks the next time a response is built.
+
+        Args:
+            scenario_result_id: The scenario result ID for this run.
+        """
+        active = self._active_tasks[scenario_result_id]
+        assert active.scenario is not None
+
+        try:
+            await active.scenario.run_async()
+
+        except asyncio.CancelledError:
+            logger.info(f"Scenario run {scenario_result_id} was cancelled.")
+
+        except Exception as e:
+            active.error = str(e)
+            logger.exception(f"Scenario run {scenario_result_id} failed: {e}")
+
+    def _build_response(self, *, scenario_result_id: str) -> ScenarioRunResponse | None:
+        """
+        Build a ScenarioRunResponse by querying the database and merging active task state.
+
+        Args:
+            scenario_result_id: The scenario result ID.
+
+        Returns:
+            ScenarioRunResponse if found in the database, None otherwise.
+        """
+        memory = CentralMemory.get_memory_instance()
+        results = memory.get_scenario_results(scenario_result_ids=[scenario_result_id])
+        if not results:
+            return None
+        return self._build_response_from_db(scenario_result=results[0])
+
+    def _build_response_from_db(self, *, scenario_result: ScenarioResult) -> ScenarioRunResponse:
+        """
+        Build a ScenarioRunResponse from a database ScenarioResult, merged with active task info.
+
+        Args:
+            scenario_result: A ScenarioResult retrieved from CentralMemory.
+
+        Returns:
+            The API response model.
+        """
+        scenario_result_id = str(scenario_result.id)
+        active = self._active_tasks.get(scenario_result_id)
+
+        # Clean up finished active tasks after reading the error
+        error = None
+        if active is not None:
+            error = active.error
+            if active.task is not None and active.task.done():
+                del self._active_tasks[scenario_result_id]
+
+        status = _STATE_TO_STATUS.get(scenario_result.scenario_run_state, ScenarioRunStatus.FAILED)
+
+        # Build result summary for completed runs
+        result = None
+        if status == ScenarioRunStatus.COMPLETED:
+            completed_attacks = sum(
+                1
+                for results in scenario_result.attack_results.values()
+                for ar in results
+                if ar.outcome in (AttackOutcome.SUCCESS, AttackOutcome.FAILURE)
+            )
+            total_attacks = sum(len(results) for results in scenario_result.attack_results.values())
+            result = ScenarioRunResult(
+                scenario_result_id=scenario_result_id,
+                run_state=scenario_result.scenario_run_state,
+                strategies_used=scenario_result.get_strategies_used(),
+                total_attacks=total_attacks,
+                completed_attacks=completed_attacks,
+                number_tries=scenario_result.number_tries,
+                completion_time=scenario_result.completion_time,
+            )
+
+        return ScenarioRunResponse(
+            run_id=scenario_result_id,
+            scenario_name=scenario_result.scenario_identifier.name,
+            status=status,
+            created_at=scenario_result.created_at,
+            updated_at=scenario_result.completion_time,
+            error=error,
+            result=result,
+        )
+
+    def get_run_results(self, *, run_id: str) -> ScenarioResultDetailResponse | None:
+        """
+        Get detailed results for a completed scenario run.
+
+        Retrieves the full ScenarioResult from CentralMemory and maps it
+        to a detailed response model with per-attack outcomes.
+
+        Args:
+            run_id: The scenario result ID (run identifier).
+
+        Returns:
+            ScenarioResultDetailResponse if the run is completed and results exist,
+            None if the run is not found.
+
+        Raises:
+            ValueError: If the run is not in a completed state.
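+
+        Example (illustrative):
+            detail = service.get_run_results(run_id=run_id)
+            for attack in detail.attacks:
+                print(attack.atomic_attack_name, attack.success_count, attack.total_count)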
+ """ + memory = CentralMemory.get_memory_instance() + results = memory.get_scenario_results(scenario_result_ids=[run_id]) + if not results: + return None + + scenario_result = results[0] + + if scenario_result.scenario_run_state != "COMPLETED": + status = _STATE_TO_STATUS.get(scenario_result.scenario_run_state, ScenarioRunStatus.FAILED) + raise ValueError(f"Results are only available for completed runs. Current status: '{status}'.") + + # Build per-attack detail + attacks: list[AtomicAttackResults] = [] + display_group_map = scenario_result.display_group_map + for attack_name, attack_results in scenario_result.attack_results.items(): + details: list[AttackResultDetail] = [] + success_count = 0 + failure_count = 0 + + for ar in attack_results: + score_value = None + if ar.last_score is not None: + score_value = str(ar.last_score.get_value()) + + last_response_text = None + if ar.last_response is not None: + last_response_text = str(ar.last_response) + + details.append( + AttackResultDetail( + attack_result_id=ar.attack_result_id, + conversation_id=ar.conversation_id, + objective=ar.objective, + outcome=ar.outcome.value, + outcome_reason=ar.outcome_reason, + last_response=last_response_text, + score_value=score_value, + executed_turns=ar.executed_turns, + execution_time_ms=ar.execution_time_ms, + timestamp=ar.timestamp, + ) + ) + + if ar.outcome == AttackOutcome.SUCCESS: + success_count += 1 + elif ar.outcome == AttackOutcome.FAILURE: + failure_count += 1 + + attacks.append( + AtomicAttackResults( + atomic_attack_name=attack_name, + display_group=display_group_map.get(attack_name), + results=details, + success_count=success_count, + failure_count=failure_count, + total_count=len(details), + ) + ) + + return ScenarioResultDetailResponse( + scenario_result_id=str(scenario_result.id), + scenario_name=scenario_result.scenario_identifier.name, + scenario_version=scenario_result.scenario_identifier.version, + run_state=scenario_result.scenario_run_state, + objective_achieved_rate=scenario_result.objective_achieved_rate(), + number_tries=scenario_result.number_tries, + completion_time=scenario_result.completion_time, + labels=scenario_result.labels, + attacks=attacks, + ) + + +_service_instance: ScenarioRunService | None = None + + +def get_scenario_run_service() -> ScenarioRunService: + """ + Get the global scenario run service instance. + + On first call, reads ``max_concurrent_scenario_runs`` from ``app.state`` + (set by ``pyrit_backend`` CLI) if available, otherwise uses the default. + + Returns: + The singleton ScenarioRunService instance. 
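+
+    Note:
+        Unit tests reset this module's ``_service_instance`` to None between
+        tests so the next call re-reads ``app.state``.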
+ """ + global _service_instance + if _service_instance is not None: + return _service_instance + + max_runs = _DEFAULT_MAX_CONCURRENT_RUNS + try: + from pyrit.backend.main import app + + max_runs = getattr(app.state, "max_concurrent_scenario_runs", _DEFAULT_MAX_CONCURRENT_RUNS) + except Exception: + pass + + _service_instance = ScenarioRunService(max_concurrent_runs=max_runs) + return _service_instance diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index b75634891a..c17eb83b54 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -146,6 +146,7 @@ def __init__( self._env_files = config._resolve_env_files() self._operator = config.operator self._operation = config.operation + self._max_concurrent_scenario_runs = config.max_concurrent_scenario_runs # Lazy-loaded registries self._scenario_registry: Optional[ScenarioRegistry] = None @@ -221,6 +222,7 @@ def with_overrides( derived._env_files = self._env_files derived._operator = self._operator derived._operation = self._operation + derived._max_concurrent_scenario_runs = self._max_concurrent_scenario_runs derived._scenario_config = self._scenario_config # Apply overrides or inherit diff --git a/pyrit/cli/pyrit_backend.py b/pyrit/cli/pyrit_backend.py index f45cc0c448..8eed2cc929 100644 --- a/pyrit/cli/pyrit_backend.py +++ b/pyrit/cli/pyrit_backend.py @@ -198,6 +198,7 @@ async def initialize_and_run_async(*, parsed_args: Namespace) -> int: if context._operation: default_labels["operation"] = context._operation app.state.default_labels = default_labels + app.state.max_concurrent_scenario_runs = context._max_concurrent_scenario_runs display_host = parsed_args.host print(f"🚀 Starting PyRIT backend on http://{display_host}:{parsed_args.port}") diff --git a/pyrit/executor/attack/multi_turn/tree_of_attacks.py b/pyrit/executor/attack/multi_turn/tree_of_attacks.py index e23af1eabf..eebc771295 100644 --- a/pyrit/executor/attack/multi_turn/tree_of_attacks.py +++ b/pyrit/executor/attack/multi_turn/tree_of_attacks.py @@ -1354,7 +1354,7 @@ class TreeOfAttacksWithPruningAttack(AttackStrategy[TAPAttackContext, TAPAttackR def __init__( self, *, - objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-assignment, ty:invalid-parameter-default] + objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] attack_adversarial_config: AttackAdversarialConfig, attack_converter_config: Optional[AttackConverterConfig] = None, attack_scoring_config: Optional[AttackScoringConfig] = None, diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index dbb228b435..753236968d 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -786,6 +786,8 @@ def _query_entries( conditions: Optional[Any] = None, distinct: bool = False, join_scores: bool = False, + order_by: Optional[Any] = None, + limit: int | None = None, ) -> MutableSequence[Model]: """ Fetch data from the specified table model with optional conditions. @@ -795,6 +797,8 @@ def _query_entries( conditions: SQLAlchemy filter conditions (Optional). distinct: Flag to return distinct rows (defaults to False). join_scores: Flag to join the scores table with entries (defaults to False). + order_by: SQLAlchemy order_by clause (Optional). + limit (int | None): Maximum number of rows to return. Defaults to None (no limit). Returns: List of model instances representing the rows fetched from the table. 
@@ -814,8 +818,12 @@ def _query_entries( ) if conditions is not None: query = query.filter(conditions) + if order_by is not None: + query = query.order_by(order_by) if distinct: - return query.distinct().all() + query = query.distinct() + if limit is not None: + query = query.limit(limit) return query.all() except SQLAlchemyError as e: logger.exception(f"Error fetching data from table {model_class.__tablename__}: {e}") # type: ignore[ty:unresolved-attribute] @@ -846,7 +854,7 @@ def _update_entries(self, *, entries: MutableSequence[Base], update_fields: dict # attributes from the (potentially stale) detached object # and silently overwrite concurrent updates to columns # that are NOT in update_fields. - entry_in_session = session.get(type(entry), entry.id) # type: ignore[ty:unresolved-attribute] + entry_in_session = session.get(type(entry), entry.id) if entry_in_session is None: entry_in_session = session.merge(entry) for field, value in update_fields.items(): diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 0c3310c0ee..b28d05976e 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -354,6 +354,8 @@ def _query_entries( conditions: Optional[Any] = None, distinct: bool = False, join_scores: bool = False, + order_by: Optional[Any] = None, + limit: int | None = None, ) -> MutableSequence[Model]: """ Fetch data from the specified table model with optional conditions. @@ -363,6 +365,8 @@ def _query_entries( conditions: SQLAlchemy filter conditions (Optional). distinct: Whether to return distinct rows only. Defaults to False. join_scores: Whether to join the scores table. Defaults to False. + order_by: SQLAlchemy order_by clause (Optional). + limit (int | None): Maximum number of rows to return. Defaults to None (no limit). Returns: List of model instances representing the rows fetched from the table. @@ -378,6 +382,8 @@ def _execute_batched_query( distinct: bool = False, join_scores: bool = False, batch_size: int | None = None, + order_by: Optional[Any] = None, + limit: int | None = None, ) -> MutableSequence[Model]: """ Execute queries in batches to avoid exceeding database bind variable limits. @@ -394,6 +400,8 @@ def _execute_batched_query( join_scores: Whether to join the scores table. batch_size: Override for the number of values per batch. Defaults to ``_MAX_BIND_VARS`` when not specified. + order_by: SQLAlchemy order_by clause (Optional). + limit (int | None): Maximum number of rows to return. Defaults to None (no limit). Returns: MutableSequence[Model]: Merged and deduplicated results from all batched queries. @@ -411,6 +419,8 @@ def _execute_batched_query( conditions=and_(*conditions) if conditions else None, distinct=distinct, join_scores=join_scores, + order_by=order_by, + limit=limit, ) # Execute multiple separate queries and merge results @@ -426,6 +436,7 @@ def _execute_batched_query( conditions=and_(*conditions) if conditions else None, distinct=distinct, join_scores=join_scores, + order_by=order_by, ) # Deduplicate by primary key (id) @@ -2062,10 +2073,13 @@ def get_scenario_results( objective_target_endpoint: Optional[str] = None, objective_target_model_name: Optional[str] = None, identifier_filters: Optional[Sequence[IdentifierFilter]] = None, + limit: int | None = None, ) -> Sequence[ScenarioResult]: """ Retrieve a list of ScenarioResult objects based on the specified filters. + Results are always ordered by completion_time descending (most recent first). 
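+
+        Example (illustrative):
+            recent = memory.get_scenario_results(limit=10)  # ten most recent runs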
+ Args: scenario_result_ids (Optional[Sequence[str]], optional): A list of scenario result IDs. Defaults to None. @@ -2088,9 +2102,11 @@ def get_scenario_results( identifier_filters (Optional[Sequence[IdentifierFilter]], optional): A sequence of IdentifierFilter objects that allows filtering by identifier JSON properties. Defaults to None. + limit (int | None): Maximum number of results to return. Defaults to None (no limit). Returns: - Sequence[ScenarioResult]: A list of ScenarioResult objects that match the specified filters. + Sequence[ScenarioResult]: A list of ScenarioResult objects that match the specified filters, + ordered by completion_time descending. """ if scenario_result_ids is not None and len(scenario_result_ids) == 0: return [] @@ -2149,6 +2165,8 @@ def get_scenario_results( ) try: + order_by_clause = ScenarioResultEntry.completion_time.desc() + # Handle scenario_result_ids with batched queries if needed if scenario_result_ids: entries = self._execute_batched_query( @@ -2156,9 +2174,16 @@ def get_scenario_results( batch_column=ScenarioResultEntry.id, batch_values=list(scenario_result_ids), other_conditions=conditions, + order_by=order_by_clause, + limit=limit, ) else: - entries = self._query_entries(ScenarioResultEntry, conditions=and_(*conditions) if conditions else None) + entries = self._query_entries( + ScenarioResultEntry, + conditions=and_(*conditions) if conditions else None, + order_by=order_by_clause, + limit=limit, + ) # Convert entries to ScenarioResults and populate attack_results efficiently scenario_results = [] diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 6b89313ba3..e9b6b6659e 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -949,7 +949,7 @@ class ScenarioResultEntry(Base): scenario_init_data: Mapped[Optional[dict[str, Any]]] = mapped_column(JSON, nullable=True) objective_target_identifier: Mapped[dict[str, str]] = mapped_column(JSON, nullable=False) objective_scorer_identifier: Mapped[Optional[dict[str, str]]] = mapped_column(JSON, nullable=True) - scenario_run_state: Mapped[Literal["CREATED", "IN_PROGRESS", "COMPLETED", "FAILED"]] = mapped_column( + scenario_run_state: Mapped[Literal["CREATED", "IN_PROGRESS", "COMPLETED", "FAILED", "CANCELLED"]] = mapped_column( String, nullable=False, default="CREATED" ) attack_results_json: Mapped[str] = mapped_column(Unicode, nullable=False) @@ -1053,6 +1053,7 @@ def get_scenario_result(self) -> ScenarioResult: objective_scorer_identifier=scorer_identifier, # type: ignore[ty:invalid-argument-type] scenario_run_state=self.scenario_run_state, labels=self.labels, + created_at=self.timestamp, number_tries=self.number_tries, completion_time=self.completion_time, display_group_map=display_group_map, diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index 53f6ce9134..0874428878 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -326,6 +326,8 @@ def _query_entries( conditions: Optional[Any] = None, distinct: bool = False, join_scores: bool = False, + order_by: Optional[Any] = None, + limit: int | None = None, ) -> MutableSequence[Model]: """ Fetch data from the specified table model with optional conditions. @@ -335,6 +337,8 @@ def _query_entries( conditions: SQLAlchemy filter conditions (Optional). distinct: Flag to return distinct rows (default is False). join_scores: Flag to join the scores table (default is False). + order_by: SQLAlchemy order_by clause (Optional). 
+ limit (int | None): Maximum number of rows to return. Defaults to None (no limit). Returns: List of model instances representing the rows fetched from the table. @@ -354,8 +358,12 @@ def _query_entries( ) if conditions is not None: query = query.filter(conditions) + if order_by is not None: + query = query.order_by(order_by) if distinct: - return query.distinct().all() + query = query.distinct() + if limit is not None: + query = query.limit(limit) return query.all() except SQLAlchemyError as e: logger.exception(f"Error fetching data from table {model_class.__tablename__}: {e}") # type: ignore[ty:unresolved-attribute] diff --git a/pyrit/models/scenario_result.py b/pyrit/models/scenario_result.py index 3b159846d2..7f237e8c9f 100644 --- a/pyrit/models/scenario_result.py +++ b/pyrit/models/scenario_result.py @@ -47,7 +47,7 @@ def __init__( self.init_data = init_data -ScenarioRunState = Literal["CREATED", "IN_PROGRESS", "COMPLETED", "FAILED"] +ScenarioRunState = Literal["CREATED", "IN_PROGRESS", "COMPLETED", "FAILED", "CANCELLED"] class ScenarioResult: @@ -64,6 +64,7 @@ def __init__( objective_scorer_identifier: "ComponentIdentifier", scenario_run_state: ScenarioRunState = "CREATED", labels: Optional[dict[str, str]] = None, + created_at: Optional[datetime] = None, completion_time: Optional[datetime] = None, number_tries: int = 0, id: Optional[uuid.UUID] = None, # noqa: A002 @@ -79,6 +80,7 @@ def __init__( objective_scorer_identifier (ComponentIdentifier): Objective scorer identifier. scenario_run_state (ScenarioRunState): Current scenario run state. labels (Optional[dict[str, str]]): Optional labels. + created_at (Optional[datetime]): When the scenario result was created. completion_time (Optional[datetime]): Optional completion timestamp. number_tries (int): Number of run attempts. id (Optional[uuid.UUID]): Optional scenario result ID. @@ -97,10 +99,16 @@ def __init__( self.scenario_run_state = scenario_run_state self.attack_results = attack_results self.labels = labels if labels is not None else {} + self.created_at = created_at if created_at is not None else datetime.now(timezone.utc) self.completion_time = completion_time if completion_time is not None else datetime.now(timezone.utc) self.number_tries = number_tries self._display_group_map = display_group_map or {} + @property + def display_group_map(self) -> dict[str, str]: + """Mapping of atomic_attack_name → display group label.""" + return self._display_group_map + def get_strategies_used(self) -> list[str]: """ Get the list of strategies used in this scenario. 
diff --git a/pyrit/prompt_converter/random_translation_converter.py b/pyrit/prompt_converter/random_translation_converter.py index 4b81d2041b..769cb51611 100644 --- a/pyrit/prompt_converter/random_translation_converter.py +++ b/pyrit/prompt_converter/random_translation_converter.py @@ -35,7 +35,7 @@ class RandomTranslationConverter(LLMGenericTextConverter, WordLevelConverter): def __init__( self, *, - converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-assignment, ty:invalid-parameter-default] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] system_prompt_template: Optional[SeedPrompt] = None, languages: Optional[list[str]] = None, word_selection_strategy: Optional[WordSelectionStrategy] = None, diff --git a/pyrit/score/float_scale/azure_content_filter_scorer.py b/pyrit/score/float_scale/azure_content_filter_scorer.py index 754a0269ce..34168b41ae 100644 --- a/pyrit/score/float_scale/azure_content_filter_scorer.py +++ b/pyrit/score/float_scale/azure_content_filter_scorer.py @@ -149,7 +149,7 @@ def __init__( if callable(self._api_key): # Token provider - create an AsyncTokenCredential wrapper credential = AsyncTokenProviderCredential(self._api_key) # type: ignore[ty:invalid-argument-type] - self._azure_cf_client = ContentSafetyClient(self._endpoint, credential=credential) # type: ignore[ty:invalid-argument-type] + self._azure_cf_client = ContentSafetyClient(self._endpoint, credential=credential) else: # String API key if not isinstance(self._api_key, str): diff --git a/pyrit/setup/configuration_loader.py b/pyrit/setup/configuration_loader.py index 184a235f65..e2e18e2350 100644 --- a/pyrit/setup/configuration_loader.py +++ b/pyrit/setup/configuration_loader.py @@ -132,6 +132,7 @@ class ConfigurationLoader(YamlLoadable): operator: Optional[str] = None operation: Optional[str] = None scenario: Optional[Union[str, dict[str, Any]]] = None + max_concurrent_scenario_runs: int = 3 def __post_init__(self) -> None: """Validate and normalize the configuration after loading.""" diff --git a/tests/unit/backend/test_scenario_run_routes.py b/tests/unit/backend/test_scenario_run_routes.py new file mode 100644 index 0000000000..9099214e58 --- /dev/null +++ b/tests/unit/backend/test_scenario_run_routes.py @@ -0,0 +1,313 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tests for scenario run API routes. 
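+
+All tests patch get_scenario_run_service at the route module, so no real
+scenario executes; responses come from mocked service return values.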
+""" + +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import status +from fastapi.testclient import TestClient + +from pyrit.backend.main import app +from pyrit.backend.models.scenarios import ( + ScenarioRunListResponse, + ScenarioRunResponse, + ScenarioRunStatus, +) +import pyrit.backend.services.scenario_run_service as _svc_mod + + +@pytest.fixture +def client() -> TestClient: + """Create a test client for the FastAPI app.""" + return TestClient(app) + + +@pytest.fixture(autouse=True) +def clear_service_cache(): + """Clear the service singleton between tests.""" + _svc_mod._service_instance = None + yield + _svc_mod._service_instance = None + + +def _mock_run_response( + *, + run_id: str = "test-run-id", + scenario_name: str = "foundry.red_team_agent", + run_status: ScenarioRunStatus = ScenarioRunStatus.PENDING, +) -> ScenarioRunResponse: + """Create a mock ScenarioRunResponse.""" + return ScenarioRunResponse( + run_id=run_id, + scenario_name=scenario_name, + status=run_status, + created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), + updated_at=datetime(2025, 1, 1, tzinfo=timezone.utc), + error=None, + result=None, + ) + + +class TestStartScenarioRunRoute: + """Tests for POST /api/scenarios/runs.""" + + def test_start_run_returns_202(self, client: TestClient) -> None: + """Test that a valid request returns 202 Accepted.""" + mock_response = _mock_run_response() + + with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + mock_service.start_run_async = AsyncMock(return_value=mock_response) + mock_get.return_value = mock_service + + response = client.post( + "/api/scenarios/runs", + json={"scenario_name": "foundry.red_team_agent", "target_name": "my_target"}, + ) + + assert response.status_code == status.HTTP_202_ACCEPTED + data = response.json() + assert data["run_id"] == "test-run-id" + assert data["status"] == "pending" + + def test_start_run_invalid_scenario_returns_400(self, client: TestClient) -> None: + """Test that an invalid scenario returns 400.""" + with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + mock_service.start_run_async = AsyncMock(side_effect=ValueError("'bad.scenario' not found in registry.")) + mock_get.return_value = mock_service + + response = client.post( + "/api/scenarios/runs", + json={"scenario_name": "bad.scenario", "target_name": "my_target"}, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert "not found" in response.json()["detail"] + + def test_start_run_missing_required_fields_returns_422(self, client: TestClient) -> None: + """Test that missing required fields returns 422.""" + response = client.post("/api/scenarios/runs", json={}) + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + def test_start_run_with_all_options(self, client: TestClient) -> None: + """Test that all optional fields are accepted.""" + mock_response = _mock_run_response() + + with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + mock_service.start_run_async = AsyncMock(return_value=mock_response) + mock_get.return_value = mock_service + + response = client.post( + "/api/scenarios/runs", + json={ + "scenario_name": "foundry.red_team_agent", + "target_name": "my_target", + "initializers": ["target", "load_default_datasets"], + "strategies": ["base64", "rot13"], + "dataset_names": 
["harmful_content"], + "max_dataset_size": 50, + "max_concurrency": 5, + "max_retries": 2, + "memory_labels": {"team": "red"}, + "scenario_params": {"max_turns": 10, "threshold": 0.8}, + "initializer_args": {"target": {"endpoint": "https://example.com"}}, + }, + ) + + assert response.status_code == status.HTTP_202_ACCEPTED + + +class TestListScenarioRunsRoute: + """Tests for GET /api/scenarios/runs.""" + + def test_list_runs_returns_200(self, client: TestClient) -> None: + """Test that list runs returns 200 with empty list.""" + with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + mock_service.list_runs.return_value = ScenarioRunListResponse(items=[]) + mock_get.return_value = mock_service + + response = client.get("/api/scenarios/runs") + + assert response.status_code == status.HTTP_200_OK + assert response.json()["items"] == [] + + def test_list_runs_returns_multiple_runs(self, client: TestClient) -> None: + """Test that list runs returns all tracked runs.""" + runs = [ + _mock_run_response(run_id="run-1"), + _mock_run_response(run_id="run-2", run_status=ScenarioRunStatus.RUNNING), + ] + + with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + mock_service.list_runs.return_value = ScenarioRunListResponse(items=runs) + mock_get.return_value = mock_service + + response = client.get("/api/scenarios/runs") + + assert response.status_code == status.HTTP_200_OK + assert len(response.json()["items"]) == 2 + + +class TestGetScenarioRunRoute: + """Tests for GET /api/scenarios/runs/{run_id}.""" + + def test_get_run_returns_200(self, client: TestClient) -> None: + """Test that getting an existing run returns 200.""" + mock_response = _mock_run_response(run_status=ScenarioRunStatus.RUNNING) + + with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + mock_service.get_run.return_value = mock_response + mock_get.return_value = mock_service + + response = client.get("/api/scenarios/runs/test-run-id") + + assert response.status_code == status.HTTP_200_OK + assert response.json()["status"] == "running" + + def test_get_run_not_found_returns_404(self, client: TestClient) -> None: + """Test that getting a non-existent run returns 404.""" + with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + mock_service.get_run.return_value = None + mock_get.return_value = mock_service + + response = client.get("/api/scenarios/runs/nonexistent") + + assert response.status_code == status.HTTP_404_NOT_FOUND + + +class TestCancelScenarioRunRoute: + """Tests for DELETE /api/scenarios/runs/{run_id}.""" + + def test_cancel_run_returns_200(self, client: TestClient) -> None: + """Test that cancelling a running scenario returns 200.""" + mock_response = _mock_run_response(run_status=ScenarioRunStatus.CANCELLED) + + with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + mock_service.cancel_run_async = AsyncMock(return_value=mock_response) + mock_get.return_value = mock_service + + response = client.delete("/api/scenarios/runs/test-run-id") + + assert response.status_code == status.HTTP_200_OK + assert response.json()["status"] == "cancelled" + + def test_cancel_run_not_found_returns_404(self, client: TestClient) -> None: + """Test that cancelling a non-existent run returns 404.""" + with 
patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + mock_service.cancel_run_async = AsyncMock(return_value=None) + mock_get.return_value = mock_service + + response = client.delete("/api/scenarios/runs/nonexistent") + + assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_cancel_completed_run_returns_409(self, client: TestClient) -> None: + """Test that cancelling a completed run returns 409 Conflict.""" + with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + mock_service.cancel_run_async = AsyncMock(side_effect=ValueError("Cannot cancel run in 'completed' state.")) + mock_get.return_value = mock_service + + response = client.delete("/api/scenarios/runs/test-run-id") + + assert response.status_code == status.HTTP_409_CONFLICT + assert "Cannot cancel" in response.json()["detail"] + + +class TestGetScenarioRunResultsRoute: + """Tests for GET /api/scenarios/runs/{run_id}/results.""" + + def test_get_results_returns_200(self, client: TestClient) -> None: + """Test that getting results of a completed run returns 200.""" + from pyrit.backend.models.scenarios import ( + AtomicAttackResults, + AttackResultDetail, + ScenarioResultDetailResponse, + ) + + mock_result = ScenarioResultDetailResponse( + scenario_result_id="result-uuid", + scenario_name="foundry.red_team_agent", + scenario_version=1, + run_state="COMPLETED", + objective_achieved_rate=50, + number_tries=1, + completion_time=datetime(2025, 1, 1, tzinfo=timezone.utc), + labels={"team": "red"}, + attacks=[ + AtomicAttackResults( + atomic_attack_name="base64_attack", + display_group="encoding", + results=[ + AttackResultDetail( + attack_result_id="ar-1", + conversation_id="conv-1", + objective="Extract sensitive info", + outcome="success", + outcome_reason="Model revealed data", + last_response="Here is the data...", + score_value="1.0", + executed_turns=3, + execution_time_ms=1500, + timestamp=datetime(2025, 1, 1, tzinfo=timezone.utc), + ), + ], + success_count=1, + failure_count=0, + total_count=1, + ), + ], + ) + + with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + mock_service.get_run_results.return_value = mock_result + mock_get.return_value = mock_service + + response = client.get("/api/scenarios/runs/test-run-id/results") + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["scenario_result_id"] == "result-uuid" + assert data["objective_achieved_rate"] == 50 + assert len(data["attacks"]) == 1 + assert data["attacks"][0]["atomic_attack_name"] == "base64_attack" + assert data["attacks"][0]["results"][0]["outcome"] == "success" + + def test_get_results_not_found_returns_404(self, client: TestClient) -> None: + """Test that getting results of a non-existent run returns 404.""" + with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + mock_service.get_run_results.return_value = None + mock_get.return_value = mock_service + + response = client.get("/api/scenarios/runs/nonexistent/results") + + assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_get_results_not_completed_returns_409(self, client: TestClient) -> None: + """Test that getting results of a non-completed run returns 409.""" + with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + 
mock_service.get_run_results.side_effect = ValueError( + "Results are only available for completed runs. Current status: 'running'." + ) + mock_get.return_value = mock_service + + response = client.get("/api/scenarios/runs/test-run-id/results") + + assert response.status_code == status.HTTP_409_CONFLICT + assert "only available for completed runs" in response.json()["detail"] diff --git a/tests/unit/backend/test_scenario_run_service.py b/tests/unit/backend/test_scenario_run_service.py new file mode 100644 index 0000000000..d1e653332d --- /dev/null +++ b/tests/unit/backend/test_scenario_run_service.py @@ -0,0 +1,492 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tests for ScenarioRunService. +""" + +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from pyrit.backend.models.scenarios import ( + RunScenarioRequest, + ScenarioRunStatus, +) +from pyrit.backend.services.scenario_run_service import ( + _DEFAULT_MAX_CONCURRENT_RUNS, + ScenarioRunService, +) +import pyrit.backend.services.scenario_run_service as _svc_mod + +_REGISTRY_PATCH_BASE = "pyrit.registry" +_MEMORY_PATCH = "pyrit.memory.CentralMemory.get_memory_instance" + + +@pytest.fixture(autouse=True) +def clear_service_cache(): + """Clear the singleton instance between tests.""" + _svc_mod._service_instance = None + yield + _svc_mod._service_instance = None + + +def _make_request( + *, + scenario_name: str = "foundry.red_team_agent", + target_name: str = "my_target", + initializers: list[str] | None = None, + strategies: list[str] | None = None, + scenario_result_id: str | None = None, +) -> RunScenarioRequest: + """Create a RunScenarioRequest for testing.""" + return RunScenarioRequest( + scenario_name=scenario_name, + target_name=target_name, + initializers=initializers, + strategies=strategies, + scenario_result_id=scenario_result_id, + ) + + +def _make_db_scenario_result( + *, + result_id: str = "sr-uuid-1", + scenario_name: str = "foundry.red_team_agent", + run_state: str = "IN_PROGRESS", + attack_results: dict | None = None, +) -> MagicMock: + """Create a mock ScenarioResult as returned by CentralMemory.""" + sr = MagicMock() + sr.id = result_id + sr.scenario_identifier.name = scenario_name + sr.scenario_identifier.version = 1 + sr.scenario_run_state = run_state + sr.get_strategies_used.return_value = [] + sr.attack_results = attack_results or {} + sr.number_tries = 1 + sr.created_at = datetime(2025, 1, 1, tzinfo=timezone.utc) + sr.completion_time = datetime(2025, 1, 1, 0, 5, tzinfo=timezone.utc) + sr.labels = {} + sr.objective_achieved_rate.return_value = 0 + sr.get_display_groups.return_value = {} + sr.display_group_map = {} + return sr + + +@pytest.fixture +def mock_memory(): + """Patch CentralMemory.get_memory_instance to return a mock.""" + mock = MagicMock() + mock.get_scenario_results.return_value = [] + with patch(_MEMORY_PATCH, return_value=mock): + yield mock + + +@pytest.fixture +def mock_all_registries(mock_memory): + """Patch all registries and CentralMemory with valid defaults.""" + mock_scenario_instance = MagicMock() + mock_scenario_instance.initialize_async = AsyncMock() + mock_scenario_instance.run_async = AsyncMock() + mock_scenario_instance._scenario_result_id = "sr-uuid-1" + + mock_scenario_class = MagicMock(return_value=mock_scenario_instance) + mock_scenario_class.get_strategy_class.return_value = MagicMock() + mock_scenario_class.default_dataset_config.return_value = MagicMock() + + mock_sr = MagicMock() 
+@pytest.fixture
+def mock_all_registries(mock_memory):
+    """Patch all registries and CentralMemory with valid defaults."""
+    mock_scenario_instance = MagicMock()
+    mock_scenario_instance.initialize_async = AsyncMock()
+    mock_scenario_instance.run_async = AsyncMock()
+    mock_scenario_instance._scenario_result_id = "sr-uuid-1"
+
+    mock_scenario_class = MagicMock(return_value=mock_scenario_instance)
+    mock_scenario_class.get_strategy_class.return_value = MagicMock()
+    mock_scenario_class.default_dataset_config.return_value = MagicMock()
+
+    mock_sr = MagicMock()
+    mock_sr.get_class.return_value = mock_scenario_class
+
+    mock_tr = MagicMock()
+    mock_tr.get_instance_by_name.return_value = MagicMock()
+    mock_tr.get_names.return_value = ["my_target"]
+
+    mock_ir = MagicMock()
+    mock_ir.get_class.return_value = MagicMock(return_value=MagicMock(initialize_async=AsyncMock()))
+
+    # By default, return a matching DB result for get_run / list_runs queries
+    db_result = _make_db_scenario_result()
+    mock_memory.get_scenario_results.return_value = [db_result]
+
+    with (
+        patch(f"{_REGISTRY_PATCH_BASE}.ScenarioRegistry.get_registry_singleton", return_value=mock_sr),
+        patch(f"{_REGISTRY_PATCH_BASE}.TargetRegistry.get_registry_singleton", return_value=mock_tr),
+        patch(f"{_REGISTRY_PATCH_BASE}.InitializerRegistry.get_registry_singleton", return_value=mock_ir),
+    ):
+        yield {
+            "scenario_registry": mock_sr,
+            "target_registry": mock_tr,
+            "initializer_registry": mock_ir,
+            "scenario_class": mock_scenario_class,
+            "scenario_instance": mock_scenario_instance,
+            "memory": mock_memory,
+            "db_result": db_result,
+        }
+
+
+class TestScenarioRunServiceStartRun:
+    """Tests for ScenarioRunService.start_run_async."""
+
+    async def test_start_run_returns_running_status(self, mock_all_registries) -> None:
+        """Test that starting a run returns RUNNING status with run_id = scenario_result_id."""
+        service = ScenarioRunService()
+        response = await service.start_run_async(request=_make_request())
+
+        assert response.run_id == "sr-uuid-1"
+        assert response.status == ScenarioRunStatus.RUNNING
+        assert response.scenario_name == "foundry.red_team_agent"
+        assert response.error is None
+
+    async def test_start_run_invalid_scenario_raises_value_error(self, mock_memory) -> None:
+        """Test that an invalid scenario name raises ValueError immediately."""
+        service = ScenarioRunService()
+
+        mock_sr = MagicMock()
+        mock_sr.get_class.side_effect = KeyError("'bad.scenario' not found in registry. Available: foo")
+        with (
+            patch(f"{_REGISTRY_PATCH_BASE}.ScenarioRegistry.get_registry_singleton", return_value=mock_sr),
+            patch(f"{_REGISTRY_PATCH_BASE}.TargetRegistry.get_registry_singleton"),
+            patch(f"{_REGISTRY_PATCH_BASE}.InitializerRegistry.get_registry_singleton"),
+        ):
+            with pytest.raises(ValueError, match="not found in registry"):
+                await service.start_run_async(request=_make_request(scenario_name="bad.scenario"))
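The invalid-name tests here rely on the service translating registry `KeyError`s into `ValueError`s before any background task is scheduled, so callers fail synchronously. A minimal sketch of that fail-fast lookup, with a hypothetical helper name (`ScenarioRegistry.get_registry_singleton` and `get_class` are the real attributes the tests patch):

```python
from pyrit.registry import ScenarioRegistry


def _resolve_scenario_class(scenario_name: str):
    """Hypothetical helper: look up a scenario class, failing fast."""
    registry = ScenarioRegistry.get_registry_singleton()
    try:
        return registry.get_class(scenario_name)
    except KeyError as exc:
        # Preserve the registry's message so pytest.raises(..., match=...) sees it.
        raise ValueError(str(exc)) from exc
```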
+    async def test_start_run_invalid_target_raises_value_error(self, mock_memory) -> None:
+        """Test that an invalid target name raises ValueError immediately."""
+        service = ScenarioRunService()
+
+        mock_sr = MagicMock()
+        mock_sr.get_class.return_value = MagicMock()
+
+        mock_tr = MagicMock()
+        mock_tr.get_instance_by_name.return_value = None
+        mock_tr.get_names.return_value = ["other_target"]
+
+        with (
+            patch(f"{_REGISTRY_PATCH_BASE}.ScenarioRegistry.get_registry_singleton", return_value=mock_sr),
+            patch(f"{_REGISTRY_PATCH_BASE}.TargetRegistry.get_registry_singleton", return_value=mock_tr),
+            patch(f"{_REGISTRY_PATCH_BASE}.InitializerRegistry.get_registry_singleton"),
+        ):
+            with pytest.raises(ValueError, match="my_target.*not found in registry"):
+                await service.start_run_async(request=_make_request())
+
+    async def test_start_run_invalid_initializer_raises_value_error(self, mock_memory) -> None:
+        """Test that an invalid initializer name raises ValueError immediately."""
+        service = ScenarioRunService()
+
+        mock_sr = MagicMock()
+        mock_sr.get_class.return_value = MagicMock()
+
+        mock_ir = MagicMock()
+        mock_ir.get_class.side_effect = KeyError("'bad_init' not found")
+
+        with (
+            patch(f"{_REGISTRY_PATCH_BASE}.ScenarioRegistry.get_registry_singleton", return_value=mock_sr),
+            patch(f"{_REGISTRY_PATCH_BASE}.TargetRegistry.get_registry_singleton"),
+            patch(f"{_REGISTRY_PATCH_BASE}.InitializerRegistry.get_registry_singleton", return_value=mock_ir),
+        ):
+            with pytest.raises(ValueError, match="Initializer not found"):
+                await service.start_run_async(request=_make_request(initializers=["bad_init"]))
+
+    async def test_start_run_invalid_strategy_raises_value_error(self, mock_memory) -> None:
+        """Test that an invalid strategy name raises ValueError immediately."""
+        service = ScenarioRunService()
+
+        mock_strategy_class = MagicMock(side_effect=ValueError("not a valid strategy"))
+        mock_strategy_class.__iter__ = MagicMock(return_value=iter([MagicMock(value="valid_strat")]))
+
+        mock_scenario_class = MagicMock()
+        mock_scenario_class.get_strategy_class.return_value = mock_strategy_class
+
+        mock_sr = MagicMock()
+        mock_sr.get_class.return_value = mock_scenario_class
+
+        mock_tr = MagicMock()
+        mock_tr.get_instance_by_name.return_value = MagicMock()
+
+        with (
+            patch(f"{_REGISTRY_PATCH_BASE}.ScenarioRegistry.get_registry_singleton", return_value=mock_sr),
+            patch(f"{_REGISTRY_PATCH_BASE}.TargetRegistry.get_registry_singleton", return_value=mock_tr),
+            patch(f"{_REGISTRY_PATCH_BASE}.InitializerRegistry.get_registry_singleton"),
+        ):
+            with pytest.raises(ValueError, match="Strategy.*not found for scenario"):
+                await service.start_run_async(request=_make_request(strategies=["bad_strategy"]))
+
+    async def test_start_run_exceeds_concurrent_limit(self, mock_all_registries) -> None:
+        """Test that exceeding the concurrent run limit raises ValueError."""
+        service = ScenarioRunService()
+        scenario_instance = mock_all_registries["scenario_instance"]
+
+        # Each call needs a unique scenario_result_id
+        call_count = 0
+
+        async def _set_unique_id(**kwargs: object) -> None:
+            nonlocal call_count
+            call_count += 1
+            scenario_instance._scenario_result_id = f"sr-uuid-{call_count}"
+
+        scenario_instance.initialize_async = AsyncMock(side_effect=_set_unique_id)
+
+        # Fill up to the limit
+        for _ in range(_DEFAULT_MAX_CONCURRENT_RUNS):
+            await service.start_run_async(request=_make_request())
+
+        # Next one should fail
+        with pytest.raises(ValueError, match="Maximum concurrent runs"):
+            await service.start_run_async(request=_make_request())
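The concurrency test pins only the `_DEFAULT_MAX_CONCURRENT_RUNS` constant and the "Maximum concurrent runs" message. A sketch of the kind of guard that would satisfy it; the function name and signature are assumptions:

```python
import asyncio


def _ensure_capacity(active_tasks: dict[str, asyncio.Task], limit: int) -> None:
    """Raise if starting another run would exceed the concurrency limit."""
    # Count only tasks that have not finished yet.
    live = sum(1 for task in active_tasks.values() if not task.done())
    if live >= limit:
        raise ValueError(f"Maximum concurrent runs ({limit}) reached; try again later.")
```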
+    async def test_start_run_runs_initializers(self, mock_all_registries) -> None:
+        """Test that initializers are run during start_run_async."""
+        service = ScenarioRunService()
+        mock_ir = mock_all_registries["initializer_registry"]
+        mock_init_instance = mock_ir.get_class.return_value.return_value
+
+        response = await service.start_run_async(
+            request=_make_request(initializers=["target", "load_default_datasets"])
+        )
+
+        assert response.status == ScenarioRunStatus.RUNNING
+        assert mock_init_instance.initialize_async.await_count == 2
+
+    async def test_start_run_passes_scenario_result_id_for_resume(self, mock_all_registries) -> None:
+        """Test that scenario_result_id is passed to the scenario constructor for resumption."""
+        service = ScenarioRunService()
+        mock_scenario_class = mock_all_registries["scenario_class"]
+
+        response = await service.start_run_async(request=_make_request(scenario_result_id="existing-result-uuid"))
+
+        assert response.status == ScenarioRunStatus.RUNNING
+        mock_scenario_class.assert_called_once_with(scenario_result_id="existing-result-uuid")
+
+    async def test_start_run_omits_scenario_result_id_when_none(self, mock_all_registries) -> None:
+        """Test that scenario_result_id is not passed to the constructor when not provided."""
+        service = ScenarioRunService()
+        mock_scenario_class = mock_all_registries["scenario_class"]
+
+        await service.start_run_async(request=_make_request())
+
+        mock_scenario_class.assert_called_once_with()
+
+
+class TestScenarioRunServiceGetRun:
+    """Tests for ScenarioRunService.get_run."""
+
+    def test_get_run_returns_none_for_unknown_id(self, mock_memory) -> None:
+        """Test that get_run returns None for non-existent run_id."""
+        mock_memory.get_scenario_results.return_value = []
+        service = ScenarioRunService()
+        result = service.get_run(run_id="nonexistent-id")
+        assert result is None
+
+    def test_get_run_returns_existing_run(self, mock_memory) -> None:
+        """Test that get_run returns a run from the database."""
+        db_result = _make_db_scenario_result(result_id="sr-123", run_state="IN_PROGRESS")
+        mock_memory.get_scenario_results.return_value = [db_result]
+
+        service = ScenarioRunService()
+        fetched = service.get_run(run_id="sr-123")
+
+        assert fetched is not None
+        assert fetched.run_id == "sr-123"
+        assert fetched.scenario_name == "foundry.red_team_agent"
+        assert fetched.status == ScenarioRunStatus.RUNNING
+
+
+class TestScenarioRunServiceListRuns:
+    """Tests for ScenarioRunService.list_runs."""
+
+    def test_list_runs_empty(self, mock_memory) -> None:
+        """Test that list_runs returns an empty list when the DB has no results."""
+        mock_memory.get_scenario_results.return_value = []
+        service = ScenarioRunService()
+        result = service.list_runs()
+        assert result.items == []
+        mock_memory.get_scenario_results.assert_called_once_with(limit=100)
+
+    def test_list_runs_returns_all_runs(self, mock_memory) -> None:
+        """Test that list_runs returns all runs from the database."""
+        db_results = [
+            _make_db_scenario_result(result_id="sr-1", run_state="COMPLETED"),
+            _make_db_scenario_result(result_id="sr-2", run_state="IN_PROGRESS"),
+        ]
+        mock_memory.get_scenario_results.return_value = db_results
+
+        service = ScenarioRunService()
+        result = service.list_runs()
+        assert len(result.items) == 2
+        mock_memory.get_scenario_results.assert_called_once_with(limit=100)
+
+    def test_list_runs_passes_custom_limit(self, mock_memory) -> None:
+        """Test that list_runs passes a custom limit to the memory query."""
+        mock_memory.get_scenario_results.return_value = []
+        service = ScenarioRunService()
+        service.list_runs(limit=10)
+        mock_memory.get_scenario_results.assert_called_once_with(limit=10)
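`get_run` and `list_runs` read straight from `CentralMemory.get_scenario_results`, and the tests pin that a persisted `IN_PROGRESS` state surfaces as `ScenarioRunStatus.RUNNING`. A plausible state-mapping sketch; the table name and fallback are assumptions, while the states and enum members come from the tests:

```python
from pyrit.backend.models.scenarios import ScenarioRunStatus

# IN_PROGRESS -> RUNNING, COMPLETED, and CANCELLED are exercised by the
# tests; FAILED and the PENDING fallback are assumed for completeness.
_STATE_TO_STATUS = {
    "IN_PROGRESS": ScenarioRunStatus.RUNNING,
    "COMPLETED": ScenarioRunStatus.COMPLETED,
    "FAILED": ScenarioRunStatus.FAILED,
    "CANCELLED": ScenarioRunStatus.CANCELLED,
}


def _status_from_db_state(run_state: str) -> ScenarioRunStatus:
    return _STATE_TO_STATUS.get(run_state, ScenarioRunStatus.PENDING)
```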
run_state="COMPLETED"), + _make_db_scenario_result(result_id="sr-2", run_state="IN_PROGRESS"), + ] + mock_memory.get_scenario_results.return_value = db_results + + service = ScenarioRunService() + result = service.list_runs() + assert len(result.items) == 2 + mock_memory.get_scenario_results.assert_called_once_with(limit=100) + + def test_list_runs_passes_custom_limit(self, mock_memory) -> None: + """Test that list_runs passes a custom limit to the memory query.""" + mock_memory.get_scenario_results.return_value = [] + service = ScenarioRunService() + service.list_runs(limit=10) + mock_memory.get_scenario_results.assert_called_once_with(limit=10) + + +class TestScenarioRunServiceCancelRun: + """Tests for ScenarioRunService.cancel_run_async.""" + + async def test_cancel_run_returns_none_for_unknown_id(self, mock_memory) -> None: + """Test that cancel returns None for non-existent run_id.""" + mock_memory.get_scenario_results.return_value = [] + service = ScenarioRunService() + result = await service.cancel_run_async(run_id="nonexistent-id") + assert result is None + + async def test_cancel_run_sets_cancelled_status(self, mock_all_registries) -> None: + """Test that cancelling a running scenario persists CANCELLED to DB.""" + service = ScenarioRunService() + mock_memory = mock_all_registries["memory"] + response = await service.start_run_async(request=_make_request()) + + # After update_scenario_run_state, the next DB query should return CANCELLED + running_result = mock_all_registries["db_result"] + cancelled_result = _make_db_scenario_result(result_id=response.run_id, run_state="CANCELLED") + mock_memory.get_scenario_results.side_effect = [[running_result], [cancelled_result]] + + result = await service.cancel_run_async(run_id=response.run_id) + + mock_memory.update_scenario_run_state.assert_called_once_with( + scenario_result_id=response.run_id, scenario_run_state="CANCELLED" + ) + assert result is not None + assert result.status == ScenarioRunStatus.CANCELLED + + async def test_cancel_completed_run_raises_value_error(self, mock_memory) -> None: + """Test that cancelling a completed run raises ValueError.""" + db_result = _make_db_scenario_result(result_id="sr-done", run_state="COMPLETED") + mock_memory.get_scenario_results.return_value = [db_result] + + service = ScenarioRunService() + with pytest.raises(ValueError, match="Cannot cancel run"): + await service.cancel_run_async(run_id="sr-done") + + async def test_cancel_already_cancelled_run_raises_value_error(self, mock_memory) -> None: + """Test that cancelling an already-cancelled run raises ValueError.""" + db_result = _make_db_scenario_result(result_id="sr-cancelled", run_state="CANCELLED") + mock_memory.get_scenario_results.return_value = [db_result] + + service = ScenarioRunService() + with pytest.raises(ValueError, match="Cannot cancel run"): + await service.cancel_run_async(run_id="sr-cancelled") + + +class TestScenarioRunServiceExecution: + """Tests for the background execution logic.""" + + async def test_execute_run_completes_successfully(self, mock_all_registries) -> None: + """Test that a successful execution removes active task and DB reflects COMPLETED.""" + service = ScenarioRunService() + mock_instance = mock_all_registries["scenario_instance"] + mock_memory = mock_all_registries["memory"] + + mock_scenario_result = MagicMock() + mock_scenario_result.id = "sr-uuid-1" + mock_scenario_result.scenario_run_state = "COMPLETED" + mock_scenario_result.get_strategies_used.return_value = ["base64"] + 
+    async def test_execute_run_fails_with_error(self, mock_all_registries) -> None:
+        """Test that a run_async failure stores the error and surfaces it via get_run."""
+        service = ScenarioRunService()
+        mock_instance = mock_all_registries["scenario_instance"]
+
+        mock_instance.run_async = AsyncMock(side_effect=RuntimeError("scenario exploded"))
+
+        response = await service.start_run_async(request=_make_request())
+
+        # Wait for the background task
+        active = service._active_tasks.get(response.run_id)
+        assert active is not None
+        assert active.task is not None
+        await active.task
+
+        # Error is stored on the active task until get_run reads it
+        assert active.error == "scenario exploded"
+        assert response.run_id in service._active_tasks
+
+        # get_run should surface the error and clean up
+        fetched = service.get_run(run_id=response.run_id)
+        assert fetched is not None
+        assert fetched.error == "scenario exploded"
+        assert response.run_id not in service._active_tasks
+
+
+class TestScenarioRunServiceGetResults:
+    """Tests for ScenarioRunService.get_run_results."""
+
+    def test_get_results_returns_none_for_unknown_id(self, mock_memory) -> None:
+        """Test that get_run_results returns None for non-existent run_id."""
+        mock_memory.get_scenario_results.return_value = []
+        service = ScenarioRunService()
+        result = service.get_run_results(run_id="nonexistent-id")
+        assert result is None
+
+    def test_get_results_raises_if_not_completed(self, mock_memory) -> None:
+        """Test that get_run_results raises ValueError if the run is not completed."""
+        db_result = _make_db_scenario_result(result_id="sr-running", run_state="IN_PROGRESS")
+        mock_memory.get_scenario_results.return_value = [db_result]
+
+        service = ScenarioRunService()
+        with pytest.raises(ValueError, match="only available for completed runs"):
+            service.get_run_results(run_id="sr-running")
+
+    def test_get_results_returns_details_for_completed_run(self, mock_memory) -> None:
+        """Test that get_run_results returns full details for a completed run."""
+        from pyrit.models import AttackOutcome
+
+        mock_attack_result = MagicMock()
+        mock_attack_result.attack_result_id = "ar-1"
+        mock_attack_result.conversation_id = "conv-1"
+        mock_attack_result.objective = "Extract info"
+        mock_attack_result.outcome = AttackOutcome.SUCCESS
+        mock_attack_result.outcome_reason = "Model complied"
+        mock_attack_result.last_response = MagicMock(value="Here is the data")
+        mock_attack_result.last_score = MagicMock()
+        mock_attack_result.last_score.get_value.return_value = "1.0"
+        mock_attack_result.executed_turns = 3
+        mock_attack_result.execution_time_ms = 1500
+        mock_attack_result.timestamp = None
+
+        db_result = _make_db_scenario_result(
+            result_id="sr-123",
+            run_state="COMPLETED",
+            attack_results={"base64_attack": [mock_attack_result]},
+        )
+        db_result.objective_achieved_rate.return_value = 100
+        mock_memory.get_scenario_results.return_value = [db_result]
+
+        service = ScenarioRunService()
+        detail = service.get_run_results(run_id="sr-123")
+
+        assert detail is not None
+        assert detail.scenario_result_id == "sr-123"
+        assert detail.objective_achieved_rate == 100
+        assert len(detail.attacks) == 1
+        assert detail.attacks[0].atomic_attack_name == "base64_attack"
+        assert detail.attacks[0].success_count == 1
+        assert detail.attacks[0].results[0].objective == "Extract info"
+        assert detail.attacks[0].results[0].outcome == "success"
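The final test pins per-attack aggregation: `success_count`, lowercase outcome strings, and grouping by atomic attack name. A small aggregation sketch; the helper name is hypothetical, and `AttackOutcome.FAILURE` is assumed to exist alongside the `SUCCESS` member the test imports:

```python
from pyrit.models import AttackOutcome


def _summarize_outcomes(results: list) -> dict[str, int]:
    """Hypothetical helper mirroring the counts asserted in the test."""
    success = sum(1 for r in results if r.outcome == AttackOutcome.SUCCESS)
    failure = sum(1 for r in results if r.outcome == AttackOutcome.FAILURE)
    return {"success_count": success, "failure_count": failure, "total_count": len(results)}
```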
result_id="sr-123", + run_state="COMPLETED", + attack_results={"base64_attack": [mock_attack_result]}, + ) + db_result.objective_achieved_rate.return_value = 100 + mock_memory.get_scenario_results.return_value = [db_result] + + service = ScenarioRunService() + detail = service.get_run_results(run_id="sr-123") + + assert detail is not None + assert detail.scenario_result_id == "sr-123" + assert detail.objective_achieved_rate == 100 + assert len(detail.attacks) == 1 + assert detail.attacks[0].atomic_attack_name == "base64_attack" + assert detail.attacks[0].success_count == 1 + assert detail.attacks[0].results[0].objective == "Extract info" + assert detail.attacks[0].results[0].outcome == "success" diff --git a/tests/unit/backend/test_scenario_service.py b/tests/unit/backend/test_scenario_service.py index 7f435d76a5..1a56086c25 100644 --- a/tests/unit/backend/test_scenario_service.py +++ b/tests/unit/backend/test_scenario_service.py @@ -210,7 +210,7 @@ class TestScenarioRoutes: """Tests for scenario API routes.""" def test_list_scenarios_returns_200(self, client: TestClient) -> None: - """Test that GET /api/scenarios returns 200.""" + """Test that GET /api/scenarios/catalog returns 200.""" with patch("pyrit.backend.routes.scenarios.get_scenario_service") as mock_get_service: mock_service = MagicMock() mock_service.list_scenarios_async = AsyncMock( @@ -221,7 +221,7 @@ def test_list_scenarios_returns_200(self, client: TestClient) -> None: ) mock_get_service.return_value = mock_service - response = client.get("/api/scenarios") + response = client.get("/api/scenarios/catalog") assert response.status_code == status.HTTP_200_OK data = response.json() @@ -229,7 +229,7 @@ def test_list_scenarios_returns_200(self, client: TestClient) -> None: assert data["pagination"]["has_more"] is False def test_list_scenarios_with_items(self, client: TestClient) -> None: - """Test that GET /api/scenarios returns scenario data.""" + """Test that GET /api/scenarios/catalog returns scenario data.""" summary = ScenarioSummary( scenario_name="foundry.red_team_agent", scenario_type="RedTeamAgentScenario", @@ -251,7 +251,7 @@ def test_list_scenarios_with_items(self, client: TestClient) -> None: ) mock_get_service.return_value = mock_service - response = client.get("/api/scenarios") + response = client.get("/api/scenarios/catalog") assert response.status_code == status.HTTP_200_OK data = response.json() @@ -277,13 +277,13 @@ def test_list_scenarios_passes_pagination_params(self, client: TestClient) -> No ) mock_get_service.return_value = mock_service - response = client.get("/api/scenarios?limit=10&cursor=test.scenario_1") + response = client.get("/api/scenarios/catalog?limit=10&cursor=test.scenario_1") assert response.status_code == status.HTTP_200_OK mock_service.list_scenarios_async.assert_called_once_with(limit=10, cursor="test.scenario_1") def test_get_scenario_returns_200(self, client: TestClient) -> None: - """Test that GET /api/scenarios/{name} returns 200 when found.""" + """Test that GET /api/scenarios/catalog/{name} returns 200 when found.""" summary = ScenarioSummary( scenario_name="foundry.red_team_agent", scenario_type="RedTeamAgentScenario", @@ -300,20 +300,20 @@ def test_get_scenario_returns_200(self, client: TestClient) -> None: mock_service.get_scenario_async = AsyncMock(return_value=summary) mock_get_service.return_value = mock_service - response = client.get("/api/scenarios/foundry.red_team_agent") + response = client.get("/api/scenarios/catalog/foundry.red_team_agent") assert response.status_code == 
@@ -335,7 +335,7 @@ def test_get_scenario_with_dotted_name(self, client: TestClient) -> None:
             mock_service.get_scenario_async = AsyncMock(return_value=summary)
             mock_get_service.return_value = mock_service
 
-            response = client.get("/api/scenarios/garak.encoding")
+            response = client.get("/api/scenarios/catalog/garak.encoding")
 
             assert response.status_code == status.HTTP_200_OK
             mock_service.get_scenario_async.assert_called_once_with(scenario_name="garak.encoding")
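For orientation, the renames above move the catalog endpoints under `/api/scenarios/catalog`, leaving `/api/scenarios/runs/...` (exercised earlier) for run management. A hypothetical FastAPI sketch consistent with what the tests call; the real wiring lives in `pyrit.backend.routes.scenarios`, and the handler names here are assumptions:

```python
from fastapi import APIRouter, HTTPException

from pyrit.backend.routes.scenarios import get_scenario_service  # patched in the tests

router = APIRouter(prefix="/api/scenarios")


@router.get("/catalog")
async def list_catalog(limit: int = 100, cursor: str | None = None):
    service = get_scenario_service()
    return await service.list_scenarios_async(limit=limit, cursor=cursor)


@router.get("/catalog/{scenario_name}")
async def get_catalog_entry(scenario_name: str):
    service = get_scenario_service()
    summary = await service.get_scenario_async(scenario_name=scenario_name)
    if summary is None:
        raise HTTPException(status_code=404, detail="Scenario not found")
    return summary
```

Keeping the catalog under its own sub-path avoids the ambiguity the old layout had, where `/api/scenarios/{name}` and `/api/scenarios/runs` competed for the same path segment.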