From 9947fd15065ebb2459a4c66a11dce2fa1ae738aa Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Sun, 25 Jan 2026 06:15:50 -0800 Subject: [PATCH 01/35] backend with APIs that will be needed in the short-term. --- pyproject.toml | 2 + pyrit/backend/main.py | 25 +- pyrit/backend/middleware/__init__.py | 8 + pyrit/backend/middleware/error_handlers.py | 182 +++++++ pyrit/backend/models/__init__.py | 103 ++++ pyrit/backend/models/common.py | 143 +++++ pyrit/backend/models/conversations.py | 201 +++++++ pyrit/backend/models/converters.py | 60 +++ pyrit/backend/models/memory.py | 125 +++++ pyrit/backend/models/registry.py | 140 +++++ pyrit/backend/routes/__init__.py | 11 + pyrit/backend/routes/conversations.py | 338 ++++++++++++ pyrit/backend/routes/converters.py | 104 ++++ pyrit/backend/routes/memory.py | 206 +++++++ pyrit/backend/routes/registry.py | 125 +++++ pyrit/backend/services/__init__.py | 28 +- .../backend/services/conversation_service.py | 508 ++++++++++++++++++ pyrit/backend/services/memory_service.py | 444 +++++++++++++++ pyrit/backend/services/registry_service.py | 307 +++++++++++ tests/unit/backend/__init__.py | 2 + tests/unit/backend/test_common_models.py | 405 ++++++++++++++ .../unit/backend/test_conversation_service.py | 243 +++++++++ tests/unit/backend/test_error_handlers.py | 200 +++++++ tests/unit/backend/test_memory_service.py | 267 +++++++++ tests/unit/backend/test_registry_service.py | 156 ++++++ tests/unit/backend/test_routes.py | 351 ++++++++++++ 26 files changed, 4680 insertions(+), 4 deletions(-) create mode 100644 pyrit/backend/middleware/__init__.py create mode 100644 pyrit/backend/middleware/error_handlers.py create mode 100644 pyrit/backend/models/common.py create mode 100644 pyrit/backend/models/conversations.py create mode 100644 pyrit/backend/models/converters.py create mode 100644 pyrit/backend/models/memory.py create mode 100644 pyrit/backend/models/registry.py create mode 100644 pyrit/backend/routes/conversations.py create mode 100644 pyrit/backend/routes/converters.py create mode 100644 pyrit/backend/routes/memory.py create mode 100644 pyrit/backend/routes/registry.py create mode 100644 pyrit/backend/services/conversation_service.py create mode 100644 pyrit/backend/services/memory_service.py create mode 100644 pyrit/backend/services/registry_service.py create mode 100644 tests/unit/backend/__init__.py create mode 100644 tests/unit/backend/test_common_models.py create mode 100644 tests/unit/backend/test_conversation_service.py create mode 100644 tests/unit/backend/test_error_handlers.py create mode 100644 tests/unit/backend/test_memory_service.py create mode 100644 tests/unit/backend/test_registry_service.py create mode 100644 tests/unit/backend/test_routes.py diff --git a/pyproject.toml b/pyproject.toml index 6e51277d4b..abf251d8aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -280,6 +280,8 @@ notice-rgx = "Copyright \\(c\\) Microsoft Corporation\\.\\s*\\n.*Licensed under # https://github.com/Azure/PyRIT/issues/1176 is fully resolved # TODO: Remove these ignores once the issues are fixed "pyrit/{auxiliary_attacks,exceptions,models,ui}/**/*.py" = ["D101", "D102", "D103", "D104", "D105", "D106", "D107", "D401", "D404", "D417", "D418", "DOC102", "DOC201", "DOC202", "DOC402", "DOC501"] +# Backend API routes raise HTTPException handled by FastAPI, not true exceptions +"pyrit/backend/**/*.py" = ["DOC501"] "pyrit/__init__.py" = ["D104"] [tool.ruff.lint.pydocstyle] diff --git a/pyrit/backend/main.py b/pyrit/backend/main.py index c703f592cc..e4dc422c8f 100644 --- a/pyrit/backend/main.py +++ b/pyrit/backend/main.py @@ -9,13 +9,14 @@ import sys from pathlib import Path -from fastapi import FastAPI +from fastapi import APIRouter, FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from fastapi.staticfiles import StaticFiles import pyrit -from pyrit.backend.routes import health, version +from pyrit.backend.middleware import register_error_handlers +from pyrit.backend.routes import conversations, converters, health, memory, registry, version from pyrit.setup.initialization import initialize_pyrit_async # Check for development mode from environment variable @@ -27,6 +28,9 @@ version=pyrit.__version__, ) +# Register RFC 7807 error handlers +register_error_handlers(app) + # Initialize PyRIT on startup to load .env and .env.local files @app.on_event("startup") @@ -46,7 +50,19 @@ async def startup_event_async() -> None: ) -# Include routers +# Create versioned API router +api_v1 = APIRouter(prefix="/api/v1") + +# Include v1 routes +api_v1.include_router(conversations.router) +api_v1.include_router(converters.router) +api_v1.include_router(memory.router) +api_v1.include_router(registry.router) + +# Mount versioned API +app.include_router(api_v1) + +# Include legacy/non-versioned routes app.include_router(health.router, prefix="/api", tags=["health"]) app.include_router(version.router, tags=["version"]) @@ -76,6 +92,9 @@ async def global_exception_handler_async(request: object, exc: Exception) -> JSO """ Handle all unhandled exceptions globally. + Note: This is a fallback handler. Most exceptions are handled by + the RFC 7807 error handlers in middleware/error_handlers.py. + Returns: JSONResponse: Error response with 500 status code. """ diff --git a/pyrit/backend/middleware/__init__.py b/pyrit/backend/middleware/__init__.py new file mode 100644 index 0000000000..8b97c3937a --- /dev/null +++ b/pyrit/backend/middleware/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Middleware module for backend.""" + +from pyrit.backend.middleware.error_handlers import register_error_handlers + +__all__ = ["register_error_handlers"] diff --git a/pyrit/backend/middleware/error_handlers.py b/pyrit/backend/middleware/error_handlers.py new file mode 100644 index 0000000000..446db678e0 --- /dev/null +++ b/pyrit/backend/middleware/error_handlers.py @@ -0,0 +1,182 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Error handling middleware for RFC 7807 compliant responses. +""" + +import logging + +from fastapi import FastAPI, Request, status +from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse + +from pyrit.backend.models.common import FieldError, ProblemDetail + +logger = logging.getLogger(__name__) + + +def register_error_handlers(app: FastAPI) -> None: + """Register all error handlers with the FastAPI app.""" + + @app.exception_handler(RequestValidationError) + async def validation_exception_handler( + request: Request, + exc: RequestValidationError, + ) -> JSONResponse: + """ + Handle Pydantic validation errors with RFC 7807 format. + + Returns: + JSONResponse: RFC 7807 problem detail response with validation errors. + """ + errors = [] + for error in exc.errors(): + field_path = ".".join(str(loc) for loc in error["loc"]) + errors.append( + FieldError( + field=field_path, + message=error["msg"], + code=error["type"], + ) # type: ignore[call-arg] + ) + + problem = ProblemDetail( + type="/errors/validation-error", + title="Validation Error", + status=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="Request validation failed", + instance=str(request.url.path), + errors=errors, + ) + + return JSONResponse( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + content=problem.model_dump(exclude_none=True), + ) + + @app.exception_handler(ValueError) + async def value_error_handler( + request: Request, + exc: ValueError, + ) -> JSONResponse: + """ + Handle ValueError as 400 Bad Request. + + Returns: + JSONResponse: RFC 7807 problem detail response with 400 status. + """ + problem = ProblemDetail( + type="/errors/bad-request", + title="Bad Request", + status=status.HTTP_400_BAD_REQUEST, + detail=str(exc), + instance=str(request.url.path), + ) # type: ignore[call-arg] + + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content=problem.model_dump(exclude_none=True), + ) + + @app.exception_handler(FileNotFoundError) + async def not_found_handler( + request: Request, + exc: FileNotFoundError, + ) -> JSONResponse: + """ + Handle FileNotFoundError as 404 Not Found. + + Returns: + JSONResponse: RFC 7807 problem detail response with 404 status. + """ + problem = ProblemDetail( + type="/errors/not-found", + title="Not Found", + status=status.HTTP_404_NOT_FOUND, + detail=str(exc), + instance=str(request.url.path), + ) # type: ignore[call-arg] + + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content=problem.model_dump(exclude_none=True), + ) + + @app.exception_handler(PermissionError) + async def permission_error_handler( + request: Request, + exc: PermissionError, + ) -> JSONResponse: + """ + Handle PermissionError as 403 Forbidden. + + Returns: + JSONResponse: RFC 7807 problem detail response with 403 status. + """ + problem = ProblemDetail( + type="/errors/forbidden", + title="Forbidden", + status=status.HTTP_403_FORBIDDEN, + detail=str(exc), + instance=str(request.url.path), + ) # type: ignore[call-arg] + + return JSONResponse( + status_code=status.HTTP_403_FORBIDDEN, + content=problem.model_dump(exclude_none=True), + ) + + @app.exception_handler(NotImplementedError) + async def not_implemented_handler( + request: Request, + exc: NotImplementedError, + ) -> JSONResponse: + """ + Handle NotImplementedError as 501 Not Implemented. + + Returns: + JSONResponse: RFC 7807 problem detail response with 501 status. + """ + problem = ProblemDetail( + type="/errors/not-implemented", + title="Not Implemented", + status=status.HTTP_501_NOT_IMPLEMENTED, + detail=str(exc) or "This feature is not yet implemented", + instance=str(request.url.path), + ) # type: ignore[call-arg] + + return JSONResponse( + status_code=status.HTTP_501_NOT_IMPLEMENTED, + content=problem.model_dump(exclude_none=True), + ) + + @app.exception_handler(Exception) + async def generic_exception_handler( + request: Request, + exc: Exception, + ) -> JSONResponse: + """ + Handle unexpected exceptions with RFC 7807 format. + + Returns: + JSONResponse: RFC 7807 problem detail response with 500 status. + """ + # Log the full exception for debugging + logger.error( + f"Unhandled exception on {request.method} {request.url.path}: {exc}", + exc_info=True, + ) + + problem = ProblemDetail( + type="/errors/internal-error", + title="Internal Server Error", + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="An unexpected error occurred", + instance=str(request.url.path), + ) # type: ignore[call-arg] + + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content=problem.model_dump(exclude_none=True), + ) diff --git a/pyrit/backend/models/__init__.py b/pyrit/backend/models/__init__.py index 6e8bc8cb78..17554d2781 100644 --- a/pyrit/backend/models/__init__.py +++ b/pyrit/backend/models/__init__.py @@ -2,5 +2,108 @@ # Licensed under the MIT license. """ +Backend models package. + Pydantic models for API requests and responses. """ + +from pyrit.backend.models.common import ( + ALLOWED_IDENTIFIER_FIELDS, + SENSITIVE_FIELD_PATTERNS, + FieldError, + IdentifierDict, + PaginatedResponse, + PaginationInfo, + ProblemDetail, + filter_sensitive_fields, +) +from pyrit.backend.models.conversations import ( + BranchConversationRequest, + BranchConversationResponse, + ConversationResponse, + ConverterConfig, + ConvertersResponse, + CreateConversationRequest, + CreateConversationResponse, + MessagePieceInput, + MessagePieceResponse, + MessageResponse, + SendMessageRequest, + SendMessageResponse, + SetConvertersRequest, + SetSystemPromptRequest, + SystemPromptResponse, +) +from pyrit.backend.models.converters import ( + ConversionStep, + ConverterListResponse, + ConverterMetadataResponse, + PreviewConverterRequest, + PreviewConverterResponse, +) +from pyrit.backend.models.memory import ( + AttackResultQueryResponse, + MessageQueryResponse, + ScenarioResultQueryResponse, + ScoreQueryResponse, + SeedQueryResponse, +) +from pyrit.backend.models.registry import ( + InitializerListResponse, + InitializerMetadataResponse, + ScenarioListResponse, + ScenarioMetadataResponse, + ScorerListResponse, + ScorerMetadataResponse, + TargetListResponse, + TargetMetadataResponse, +) + +__all__ = [ + # Common + "ALLOWED_IDENTIFIER_FIELDS", + "SENSITIVE_FIELD_PATTERNS", + "FieldError", + "filter_sensitive_fields", + "IdentifierDict", + "PaginatedResponse", + "PaginationInfo", + "ProblemDetail", + # Conversations + "BranchConversationRequest", + "BranchConversationResponse", + "ConversationResponse", + "ConverterConfig", + "ConvertersResponse", + "CreateConversationRequest", + "CreateConversationResponse", + "MessagePieceInput", + "MessagePieceResponse", + "MessageResponse", + "SendMessageRequest", + "SendMessageResponse", + "SetConvertersRequest", + "SetSystemPromptRequest", + "SystemPromptResponse", + # Converters + "ConversionStep", + "ConverterListResponse", + "ConverterMetadataResponse", + "PreviewConverterRequest", + "PreviewConverterResponse", + # Memory + "AttackResultQueryResponse", + "MessageQueryResponse", + "ScenarioResultQueryResponse", + "ScoreQueryResponse", + "SeedQueryResponse", + # Registry + "InitializerListResponse", + "InitializerMetadataResponse", + "ScenarioListResponse", + "ScenarioMetadataResponse", + "ScorerListResponse", + "ScorerMetadataResponse", + "TargetListResponse", + "TargetMetadataResponse", +] diff --git a/pyrit/backend/models/common.py b/pyrit/backend/models/common.py new file mode 100644 index 0000000000..3aadc8157a --- /dev/null +++ b/pyrit/backend/models/common.py @@ -0,0 +1,143 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Common response models for the PyRIT API. + +Includes pagination, error handling (RFC 7807), and shared base models. +""" + +from datetime import datetime +from typing import Any, Generic, List, Optional, TypeVar + +from pydantic import BaseModel, Field + +T = TypeVar("T") + + +class PaginationInfo(BaseModel): + """Pagination metadata for list responses.""" + + limit: int = Field(..., description="Maximum items per page") + has_more: bool = Field(..., description="Whether more items exist") + next_cursor: Optional[str] = Field(None, description="Cursor for next page") + prev_cursor: Optional[str] = Field(None, description="Cursor for previous page") + + +class PaginatedResponse(BaseModel, Generic[T]): + """Generic paginated response wrapper.""" + + items: List[T] = Field(..., description="List of items") + pagination: PaginationInfo = Field(..., description="Pagination metadata") + + +class FieldError(BaseModel): + """Individual field validation error.""" + + field: str = Field(..., description="Field name with path (e.g., 'pieces[0].data_type')") + message: str = Field(..., description="Error message") + code: Optional[str] = Field(None, description="Error code") + value: Optional[Any] = Field(None, description="The invalid value") + + +class ProblemDetail(BaseModel): + """ + RFC 7807 Problem Details response. + + Used for all error responses to provide consistent error formatting. + """ + + type: str = Field(..., description="Error type URI (e.g., '/errors/validation-error')") + title: str = Field(..., description="Short human-readable summary") + status: int = Field(..., description="HTTP status code") + detail: str = Field(..., description="Human-readable explanation") + instance: Optional[str] = Field(None, description="URI of the specific occurrence") + errors: Optional[List[FieldError]] = Field(None, description="Field-level errors for validation") + + +class IdentifierDict(BaseModel): + """ + Represents a filtered identifier dictionary. + + Only contains safe fields (no API keys, tokens, etc.). + Uses 'type_' and 'module_' as field names but serializes to '__type__' and '__module__'. + """ + + type_: str = Field(..., alias="__type__", description="Class name") + module_: Optional[str] = Field(None, alias="__module__", description="Module path") + + model_config = { + "extra": "allow", # Allow additional fields like endpoint, model_name, etc. + "populate_by_name": True, + } + + +class TimestampMixin(BaseModel): + """Mixin for models with timestamps.""" + + timestamp: datetime = Field(..., description="Creation/event timestamp") + created_at: Optional[datetime] = Field(None, description="Resource creation time") + + +# Sensitive field patterns to filter from identifiers +SENSITIVE_FIELD_PATTERNS = frozenset( + [ + "api_key", + "_api_key", + "token", + "secret", + "password", + "credential", + "auth", + "key", + ] +) + +# Fields allowed in identifier responses +ALLOWED_IDENTIFIER_FIELDS = frozenset( + [ + "__type__", + "__module__", + "endpoint", + "model_name", + "deployment_name", + "underlying_model", + "temperature", + "top_p", + "language", + "tone", + ] +) + + +def filter_sensitive_fields(data: dict[str, Any]) -> dict[str, Any]: + """ + Recursively filter sensitive fields from a dictionary. + + Args: + data: Dictionary potentially containing sensitive fields. + + Returns: + dict[str, Any]: Dictionary with sensitive fields removed. + """ + if not isinstance(data, dict): + return data + + filtered: dict[str, Any] = {} + for key, value in data.items(): + # Check if key matches sensitive patterns + key_lower = key.lower() + is_sensitive = any(pattern in key_lower for pattern in SENSITIVE_FIELD_PATTERNS) + + if is_sensitive: + continue + + # Recursively filter nested dicts + if isinstance(value, dict): + filtered[key] = filter_sensitive_fields(value) + elif isinstance(value, list): + filtered[key] = [filter_sensitive_fields(item) if isinstance(item, dict) else item for item in value] + else: + filtered[key] = value + + return filtered diff --git a/pyrit/backend/models/conversations.py b/pyrit/backend/models/conversations.py new file mode 100644 index 0000000000..f151148c36 --- /dev/null +++ b/pyrit/backend/models/conversations.py @@ -0,0 +1,201 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Conversation-related request and response models. + +These models align with PyRIT's MessagePiece and Message structures. +""" + +from datetime import datetime +from typing import Any, Dict, List, Literal, Optional + +from pydantic import BaseModel, Field + +from pyrit.models import PromptDataType, PromptResponseError + + +class ConverterConfig(BaseModel): + """Configuration for a single converter.""" + + class_name: str = Field(..., description="Converter class name (e.g., 'TranslationConverter')") + module: str = Field( + default="pyrit.prompt_converter", + description="Module containing the converter class", + ) + params: Optional[Dict[str, Any]] = Field(default=None, description="Constructor parameters") + + +# ============================================================================ +# Conversation Creation +# ============================================================================ + + +class CreateConversationRequest(BaseModel): + """Request to create a new conversation.""" + + target_class: str = Field( + ..., + description="Target class name (e.g., 'TextTarget', 'AzureOpenAIGPT4OChatTarget')", + ) + target_params: Optional[Dict[str, Any]] = Field( + None, + description="Constructor parameters for the target", + ) + labels: Optional[Dict[str, str]] = Field(None, description="Key-value labels for filtering") + + +class CreateConversationResponse(BaseModel): + """Response after creating a conversation.""" + + conversation_id: str = Field(..., description="Unique conversation identifier") + target_identifier: Dict[str, Any] = Field(..., description="Target identifier (filtered)") + labels: Optional[Dict[str, str]] = Field(None, description="Applied labels") + created_at: datetime = Field(..., description="Creation timestamp") + + +# ============================================================================ +# System Prompt +# ============================================================================ + + +class SetSystemPromptRequest(BaseModel): + """Request to set the system prompt for a conversation.""" + + system_prompt: str = Field(..., description="The system prompt text") + + +class SystemPromptResponse(BaseModel): + """Response containing the system prompt.""" + + system_prompt: Optional[str] = Field(None, description="Current system prompt") + piece_id: Optional[str] = Field(None, description="ID of the system prompt message piece") + + +# ============================================================================ +# Converter Configuration +# ============================================================================ + + +class SetConvertersRequest(BaseModel): + """Request to set the converter chain for a conversation.""" + + converters: List[ConverterConfig] = Field(..., description="Ordered list of converters") + + +class ConvertersResponse(BaseModel): + """Response containing the converter chain.""" + + converters: List[ConverterConfig] = Field(default_factory=list, description="Current converter chain") + + +# ============================================================================ +# Message Pieces (aligned with MessagePiece) +# ============================================================================ + + +class MessagePieceInput(BaseModel): + """ + Input for a single message piece. + + Aligned with pyrit.models.MessagePiece fields. + """ + + original_value: Optional[str] = Field(None, description="Text content (for text type)") + original_value_data_type: PromptDataType = Field(..., description="Data type of the content") + file_name: Optional[str] = Field(None, description="Filename in multipart request (for file types)") + converted_value: Optional[str] = Field(None, description="Pre-converted content (if pre_converted=true)") + converted_value_data_type: Optional[PromptDataType] = Field(None, description="Data type after conversion") + converter_identifiers: Optional[List[Dict[str, Any]]] = Field( + None, description="Converters already applied (if pre_converted=true)" + ) + + +class MessagePieceResponse(BaseModel): + """ + Response model for a single message piece. + + Aligned with pyrit.models.MessagePiece fields. + """ + + id: str = Field(..., description="Unique piece identifier (UUID)") + original_value: str = Field(..., description="Original content or file path") + original_value_data_type: PromptDataType = Field(..., description="Original data type") + converted_value: str = Field(..., description="Converted content or file path") + converted_value_data_type: PromptDataType = Field(..., description="Converted data type") + converter_identifiers: List[Dict[str, Any]] = Field( + default_factory=list, description="Applied converters with params" + ) + response_error: Optional[PromptResponseError] = Field(None, description="Error type if any") + timestamp: Optional[datetime] = Field(None, description="Piece timestamp") + + +# ============================================================================ +# Messages +# ============================================================================ + +ChatMessageRole = Literal["system", "user", "assistant", "simulated_assistant", "tool", "developer"] + + +class MessageResponse(BaseModel): + """Response model for a message (group of pieces with same sequence).""" + + sequence: int = Field(..., description="Sequence number in conversation") + role: ChatMessageRole = Field(..., description="Message role") + pieces: List[MessagePieceResponse] = Field(..., description="Message content pieces") + timestamp: datetime = Field(..., description="Message timestamp") + + +class SendMessageRequest(BaseModel): + """ + Request to send a message. + + Note: For file uploads, use multipart/form-data with 'pieces' as JSON + and files attached with their filenames. + """ + + pieces: List[MessagePieceInput] = Field(..., description="Message content pieces") + pre_converted: bool = Field(False, description="If true, skip converter chain") + + +class SendMessageResponse(BaseModel): + """Response after sending a message.""" + + user_message: MessageResponse = Field(..., description="The sent user message") + assistant_message: Optional[MessageResponse] = Field(None, description="The assistant's response") + + +# ============================================================================ +# Branch +# ============================================================================ + + +class BranchConversationRequest(BaseModel): + """Request to branch a conversation.""" + + last_included_sequence: int = Field(..., description="Copy messages with sequence <= this value") + + +class BranchConversationResponse(BaseModel): + """Response after branching a conversation.""" + + conversation_id: str = Field(..., description="New conversation ID") + branched_from: Dict[str, Any] = Field(..., description="Source conversation info") + message_count: int = Field(..., description="Number of messages copied") + created_at: datetime = Field(..., description="Branch creation timestamp") + + +# ============================================================================ +# Full Conversation +# ============================================================================ + + +class ConversationResponse(BaseModel): + """Full conversation with all messages.""" + + conversation_id: str = Field(..., description="Unique conversation identifier") + target_identifier: Dict[str, Any] = Field(..., description="Target identifier (filtered)") + labels: Optional[Dict[str, str]] = Field(None, description="Applied labels") + converters: List[ConverterConfig] = Field(default_factory=list, description="Configured converters") + created_at: datetime = Field(..., description="Creation timestamp") + messages: List[MessageResponse] = Field(default_factory=list, description="All messages in order") diff --git a/pyrit/backend/models/converters.py b/pyrit/backend/models/converters.py new file mode 100644 index 0000000000..2bd56a3c52 --- /dev/null +++ b/pyrit/backend/models/converters.py @@ -0,0 +1,60 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Converter-related request and response models. +""" + +from typing import Any, Dict, List + +from pydantic import BaseModel, Field + +from pyrit.backend.models.conversations import ConverterConfig +from pyrit.backend.models.registry import ConverterMetadataResponse +from pyrit.models import PromptDataType + +# Re-export for convenience +__all__ = [ + "ConverterMetadataResponse", + "ConverterListResponse", + "ConverterConfig", + "ConversionStep", + "PreviewConverterRequest", + "PreviewConverterResponse", +] + + +class ConverterListResponse(BaseModel): + """Response containing list of available converters.""" + + converters: List[ConverterMetadataResponse] = Field(..., description="Available converter types") + + +class ConversionStep(BaseModel): + """Single step in a conversion chain.""" + + converter_class: str = Field(..., description="Converter class that was applied") + input: str = Field(..., description="Input to this converter") + input_data_type: PromptDataType = Field(..., description="Input data type") + output: str = Field(..., description="Output from this converter") + output_data_type: PromptDataType = Field(..., description="Output data type") + + +class PreviewConverterRequest(BaseModel): + """Request to preview converter output.""" + + content: str = Field(..., description="Original content to convert") + data_type: PromptDataType = Field("text", description="Content data type") + converters: List[ConverterConfig] = Field(..., description="Ordered list of converters to apply") + + +class PreviewConverterResponse(BaseModel): + """Response with converter preview results.""" + + original_content: str = Field(..., description="Original input content") + converted_content: str = Field(..., description="Final converted content") + converted_data_type: PromptDataType = Field(..., description="Final data type") + conversion_chain: List[ConversionStep] = Field(..., description="Step-by-step conversion results") + converter_identifiers: List[Dict[str, Any]] = Field( + ..., description="Converter identifiers for use in pre_converted requests" + ) diff --git a/pyrit/backend/models/memory.py b/pyrit/backend/models/memory.py new file mode 100644 index 0000000000..98741f6eeb --- /dev/null +++ b/pyrit/backend/models/memory.py @@ -0,0 +1,125 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Memory query response models. + +Models for messages, scores, attack results, scenario results, and seeds. +""" + +from datetime import datetime +from typing import Any, Dict, List, Literal, Optional + +from pydantic import BaseModel, Field + +from pyrit.models import PromptDataType, PromptResponseError + +# ============================================================================ +# Message Queries +# ============================================================================ + + +class MessageQueryResponse(BaseModel): + """Response model for message piece queries.""" + + id: str = Field(..., description="Message piece ID") + conversation_id: str = Field(..., description="Parent conversation ID") + sequence: int = Field(..., description="Sequence in conversation") + role: str = Field(..., description="Message role") + original_value: str = Field(..., description="Original content") + original_value_data_type: PromptDataType = Field(..., description="Original data type") + converted_value: str = Field(..., description="Converted content") + converted_value_data_type: PromptDataType = Field(..., description="Converted data type") + converter_identifiers: List[Dict[str, Any]] = Field(default_factory=list, description="Applied converters") + target_identifier: Dict[str, Any] = Field(..., description="Target identifier (filtered)") + labels: Optional[Dict[str, str]] = Field(None, description="Message labels") + response_error: Optional[PromptResponseError] = Field(None, description="Error type if any") + timestamp: datetime = Field(..., description="Message timestamp") + + +# ============================================================================ +# Score Queries +# ============================================================================ + +ScoreType = Literal["true_false", "float_scale", "unknown"] + + +class ScoreQueryResponse(BaseModel): + """Response model for score queries.""" + + id: str = Field(..., description="Score ID") + message_piece_id: str = Field(..., description="Associated message piece ID") + score_value: str = Field(..., description="Score value ('true'/'false' or numeric)") + score_value_description: str = Field(..., description="Human-readable score description") + score_type: ScoreType = Field(..., description="Type of score") + score_category: Optional[List[str]] = Field(None, description="Score categories") + score_rationale: str = Field(..., description="Explanation for the score") + scorer_identifier: Dict[str, Any] = Field(..., description="Scorer identifier (filtered)") + objective: Optional[str] = Field(None, description="Scoring objective") + timestamp: datetime = Field(..., description="Score timestamp") + + +# ============================================================================ +# Attack Results +# ============================================================================ + +AttackOutcome = Literal["success", "failure", "undetermined"] + + +class AttackResultQueryResponse(BaseModel): + """Response model for attack result queries.""" + + id: str = Field(..., description="Attack result ID") + conversation_id: str = Field(..., description="Associated conversation ID") + objective: str = Field(..., description="Attack objective") + attack_identifier: Dict[str, Any] = Field(..., description="Attack identifier (filtered)") + outcome: Optional[str] = Field(None, description="Attack outcome (success, failure, undetermined)") + outcome_reason: Optional[str] = Field(None, description="Explanation for outcome") + executed_turns: int = Field(..., description="Number of turns executed") + execution_time_ms: int = Field(..., description="Execution time in milliseconds") + timestamp: Optional[datetime] = Field(None, description="Result timestamp") + + +# ============================================================================ +# Scenario Results +# ============================================================================ + +ScenarioRunState = Literal["CREATED", "IN_PROGRESS", "COMPLETED", "FAILED"] + + +class ScenarioResultQueryResponse(BaseModel): + """Response model for scenario result queries.""" + + id: str = Field(..., description="Scenario result ID") + scenario_name: str = Field(..., description="Scenario name") + scenario_description: Optional[str] = Field(None, description="Scenario description") + scenario_version: int = Field(..., description="Scenario version") + pyrit_version: str = Field(..., description="PyRIT version used") + run_state: ScenarioRunState = Field(..., description="Current run state") + objective_target_identifier: Dict[str, Any] = Field(..., description="Target identifier (filtered)") + labels: Optional[Dict[str, str]] = Field(None, description="Scenario labels") + number_tries: int = Field(..., description="Number of objectives attempted") + completion_time: Optional[datetime] = Field(None, description="Completion timestamp") + timestamp: datetime = Field(..., description="Creation timestamp") + + +# ============================================================================ +# Seeds +# ============================================================================ + +SeedType = Literal["prompt", "objective", "simulated_conversation"] + + +class SeedQueryResponse(BaseModel): + """Response model for seed queries.""" + + id: str = Field(..., description="Seed ID") + value: str = Field(..., description="Seed content") + data_type: PromptDataType = Field(..., description="Content data type") + name: Optional[str] = Field(None, description="Seed name") + dataset_name: Optional[str] = Field(None, description="Dataset name") + seed_type: SeedType = Field(..., description="Type of seed") + harm_categories: Optional[List[str]] = Field(None, description="Harm categories") + description: Optional[str] = Field(None, description="Seed description") + source: Optional[str] = Field(None, description="Seed source") + date_added: Optional[datetime] = Field(None, description="Date added") diff --git a/pyrit/backend/models/registry.py b/pyrit/backend/models/registry.py new file mode 100644 index 0000000000..4c2cfffa8b --- /dev/null +++ b/pyrit/backend/models/registry.py @@ -0,0 +1,140 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Registry response models. + +Models for targets, scenarios, scorers, converters, and initializers. +""" + +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field + +from pyrit.models import PromptDataType + +# ============================================================================ +# Common +# ============================================================================ + + +class ParameterInfo(BaseModel): + """Information about a constructor parameter.""" + + name: str = Field(..., description="Parameter name") + type_hint: Optional[str] = Field(None, description="Type hint as string") + required: bool = Field(..., description="Whether parameter is required") + default: Optional[Any] = Field(None, description="Default value if not required") + + +# ============================================================================ +# Targets +# ============================================================================ + + +class TargetMetadataResponse(BaseModel): + """Metadata for a target type.""" + + name: str = Field(..., description="Registry name") + class_name: str = Field(..., description="Python class name") + description: str = Field(..., description="Target description") + is_chat_target: bool = Field(..., description="Whether target supports chat/system prompts") + supports_json_response: bool = Field(..., description="Whether target supports JSON response format") + supported_data_types: List[PromptDataType] = Field(..., description="Supported input data types") + params_schema: Optional[Dict[str, Any]] = Field(None, description="Parameter schema") + + +class TargetListResponse(BaseModel): + """Response containing list of available targets.""" + + targets: List[TargetMetadataResponse] = Field(..., description="Available target types") + + +# ============================================================================ +# Scenarios +# ============================================================================ + + +class ScenarioMetadataResponse(BaseModel): + """Metadata for a scenario type.""" + + name: str = Field(..., description="Registry name") + class_name: str = Field(..., description="Python class name") + description: str = Field(..., description="Scenario description") + default_strategy: str = Field(..., description="Default strategy name") + all_strategies: List[str] = Field(..., description="All available strategies") + aggregate_strategies: List[str] = Field(..., description="Composite/aggregate strategies") + default_datasets: List[str] = Field(..., description="Default dataset names") + max_dataset_size: Optional[int] = Field(None, description="Maximum dataset size limit") + + +class ScenarioListResponse(BaseModel): + """Response containing list of available scenarios.""" + + scenarios: List[ScenarioMetadataResponse] = Field(..., description="Available scenarios") + + +# ============================================================================ +# Scorers +# ============================================================================ + + +class ScorerMetadataResponse(BaseModel): + """Metadata for a registered scorer instance.""" + + name: str = Field(..., description="Registry name") + class_name: str = Field(..., description="Python class name") + description: str = Field(..., description="Scorer description") + scorer_type: str = Field(..., description="Score type (true_false or float_scale)") + scorer_identifier: Dict[str, Any] = Field(..., description="Scorer identifier (filtered)") + + +class ScorerListResponse(BaseModel): + """Response containing list of registered scorers.""" + + scorers: List[ScorerMetadataResponse] = Field(..., description="Registered scorer instances") + + +# ============================================================================ +# Initializers +# ============================================================================ + + +class InitializerMetadataResponse(BaseModel): + """Metadata for an initializer.""" + + name: str = Field(..., description="Registry name") + class_name: str = Field(..., description="Python class name") + description: str = Field(..., description="Initializer description") + required_env_vars: List[str] = Field(..., description="Required environment variables") + execution_order: int = Field(..., description="Execution order priority") + + +class InitializerListResponse(BaseModel): + """Response containing list of available initializers.""" + + initializers: List[InitializerMetadataResponse] = Field(..., description="Available initializers") + + +# ============================================================================ +# Converters +# ============================================================================ + + +class ConverterMetadataResponse(BaseModel): + """Metadata for a converter type.""" + + name: str = Field(..., description="Registry name (snake_case)") + class_name: str = Field(..., description="Python class name") + description: str = Field(..., description="Converter description") + supported_input_types: List[PromptDataType] = Field(..., description="Supported input data types") + supported_output_types: List[PromptDataType] = Field(..., description="Supported output data types") + is_llm_based: bool = Field(..., description="Whether converter requires LLM calls") + is_deterministic: bool = Field(..., description="Whether same input produces same output") + params_schema: Optional[Dict[str, Any]] = Field(None, description="Parameter schema") + + +class ConverterListResponse(BaseModel): + """Response containing list of available converters.""" + + converters: List[ConverterMetadataResponse] = Field(..., description="Available converter types") diff --git a/pyrit/backend/routes/__init__.py b/pyrit/backend/routes/__init__.py index 49d365abea..78bfbae3f3 100644 --- a/pyrit/backend/routes/__init__.py +++ b/pyrit/backend/routes/__init__.py @@ -4,3 +4,14 @@ """ API route handlers. """ + +from pyrit.backend.routes import conversations, converters, health, memory, registry, version + +__all__ = [ + "conversations", + "converters", + "health", + "memory", + "registry", + "version", +] diff --git a/pyrit/backend/routes/conversations.py b/pyrit/backend/routes/conversations.py new file mode 100644 index 0000000000..83de1cfd57 --- /dev/null +++ b/pyrit/backend/routes/conversations.py @@ -0,0 +1,338 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Conversation API routes. + +Provides endpoints for managing interactive conversation sessions. +""" + +from typing import List + +from fastapi import APIRouter, HTTPException, status + +from pyrit.backend.models.common import ProblemDetail +from pyrit.backend.models.conversations import ( + BranchConversationRequest, + BranchConversationResponse, + ConverterConfig, + ConvertersResponse, + CreateConversationRequest, + CreateConversationResponse, + MessageResponse, + SendMessageRequest, + SendMessageResponse, + SetSystemPromptRequest, + SystemPromptResponse, +) +from pyrit.backend.services import get_conversation_service + +router = APIRouter(prefix="/conversations", tags=["conversations"]) + + +@router.post( + "", + response_model=CreateConversationResponse, + status_code=status.HTTP_201_CREATED, + responses={ + 400: {"model": ProblemDetail, "description": "Invalid request"}, + 422: {"model": ProblemDetail, "description": "Validation error"}, + }, +) +async def create_conversation(request: CreateConversationRequest) -> CreateConversationResponse: + """ + Create a new conversation session. + + Establishes a new conversation with the specified target and optional + system prompt and converters. + + Returns: + CreateConversationResponse: The created conversation details. + """ + service = get_conversation_service() + + try: + return await service.create_conversation(request) + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), + ) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to create conversation: {str(e)}", + ) + + +@router.get( + "/{conversation_id}", + response_model=List[MessageResponse], + responses={ + 404: {"model": ProblemDetail, "description": "Conversation not found"}, + }, +) +async def get_conversation(conversation_id: str) -> List[MessageResponse]: + """ + Get all messages in a conversation. + + Returns messages in sequence order. + + Returns: + List[MessageResponse]: List of messages in the conversation. + """ + service = get_conversation_service() + + state = await service.get_conversation(conversation_id) + if not state: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Conversation {conversation_id} not found", + ) + + return await service.get_conversation_messages(conversation_id) + + +@router.post( + "/{conversation_id}/messages", + response_model=SendMessageResponse, + responses={ + 404: {"model": ProblemDetail, "description": "Conversation not found"}, + 400: {"model": ProblemDetail, "description": "Message send failed"}, + }, +) +async def send_message( + conversation_id: str, + request: SendMessageRequest, +) -> SendMessageResponse: + """ + Send a message in a conversation. + + Sends the user message to the target, applies converters, and returns + both the sent message and assistant response(s). + + Returns: + SendMessageResponse: The sent message and assistant response. + """ + service = get_conversation_service() + + state = await service.get_conversation(conversation_id) + if not state: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Conversation {conversation_id} not found", + ) + + try: + return await service.send_message(conversation_id, request) + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), + ) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to send message: {str(e)}", + ) + + +@router.get( + "/{conversation_id}/system-prompt", + response_model=SystemPromptResponse, + responses={ + 404: {"model": ProblemDetail, "description": "Conversation not found"}, + }, +) +async def get_system_prompt(conversation_id: str) -> SystemPromptResponse: + """ + Get the current system prompt for a conversation. + + Returns: + SystemPromptResponse: The current system prompt. + """ + service = get_conversation_service() + + state = await service.get_conversation(conversation_id) + if not state: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Conversation {conversation_id} not found", + ) + + return SystemPromptResponse( + system_prompt=state.system_prompt, + piece_id=None, # System prompts stored in state, not as MessagePiece + ) + + +@router.put( + "/{conversation_id}/system-prompt", + response_model=SystemPromptResponse, + responses={ + 404: {"model": ProblemDetail, "description": "Conversation not found"}, + }, +) +async def update_system_prompt( + conversation_id: str, + request: SetSystemPromptRequest, +) -> SystemPromptResponse: + """ + Update the system prompt for a conversation. + + Takes effect for subsequent messages. + + Returns: + SystemPromptResponse: The updated system prompt. + """ + service = get_conversation_service() + + state = await service.get_conversation(conversation_id) + if not state: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Conversation {conversation_id} not found", + ) + + await service.update_system_prompt(conversation_id, request.system_prompt) + + return SystemPromptResponse( + system_prompt=request.system_prompt, + piece_id=None, + ) + + +@router.get( + "/{conversation_id}/converters", + response_model=ConvertersResponse, + responses={ + 404: {"model": ProblemDetail, "description": "Conversation not found"}, + }, +) +async def get_converters(conversation_id: str) -> ConvertersResponse: + """ + Get the current converters for a conversation. + + Returns: + ConvertersResponse: The current converter configurations. + """ + service = get_conversation_service() + + state = await service.get_conversation(conversation_id) + if not state: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Conversation {conversation_id} not found", + ) + + return ConvertersResponse( + converters=state.converters, + ) + + +@router.put( + "/{conversation_id}/converters", + response_model=ConvertersResponse, + responses={ + 404: {"model": ProblemDetail, "description": "Conversation not found"}, + 400: {"model": ProblemDetail, "description": "Invalid converter configuration"}, + }, +) +async def update_converters( + conversation_id: str, + converters: List[ConverterConfig], +) -> ConvertersResponse: + """ + Update the converters for a conversation. + + Replaces all current converters with the provided list. + Takes effect for subsequent messages. + + Returns: + ConvertersResponse: The updated converter configurations. + """ + service = get_conversation_service() + + state = await service.get_conversation(conversation_id) + if not state: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Conversation {conversation_id} not found", + ) + + try: + await service.update_converters(conversation_id, converters) + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), + ) + + return ConvertersResponse( + converters=converters, + ) + + +@router.post( + "/{conversation_id}/branch", + response_model=BranchConversationResponse, + status_code=status.HTTP_201_CREATED, + responses={ + 404: {"model": ProblemDetail, "description": "Conversation not found"}, + 400: {"model": ProblemDetail, "description": "Invalid branch request"}, + }, +) +async def branch_conversation( + conversation_id: str, + request: BranchConversationRequest, +) -> BranchConversationResponse: + """ + Branch a conversation from a specific point. + + Creates a new conversation with messages copied up to and including + the specified sequence number. The original conversation is unchanged. + + Returns: + BranchConversationResponse: The new branched conversation details. + """ + service = get_conversation_service() + + state = await service.get_conversation(conversation_id) + if not state: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Conversation {conversation_id} not found", + ) + + try: + return await service.branch_conversation(conversation_id, request) + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), + ) + + +@router.delete( + "/{conversation_id}", + status_code=status.HTTP_204_NO_CONTENT, + responses={ + 404: {"model": ProblemDetail, "description": "Conversation not found"}, + }, +) +async def delete_conversation(conversation_id: str) -> None: + """ + Delete a conversation session. + + Cleans up in-memory resources. Messages remain in memory database. + """ + service = get_conversation_service() + + state = await service.get_conversation(conversation_id) + if not state: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Conversation {conversation_id} not found", + ) + + service.cleanup_conversation(conversation_id) diff --git a/pyrit/backend/routes/converters.py b/pyrit/backend/routes/converters.py new file mode 100644 index 0000000000..2790598ab0 --- /dev/null +++ b/pyrit/backend/routes/converters.py @@ -0,0 +1,104 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Converters API routes. + +Provides endpoints for listing and previewing prompt converters. +""" + +from typing import List, Optional + +from fastapi import APIRouter, HTTPException, Query, status + +from pyrit.backend.models.common import ProblemDetail +from pyrit.backend.models.converters import ( + ConversionStep, + ConverterMetadataResponse, + PreviewConverterRequest, + PreviewConverterResponse, +) +from pyrit.backend.services import get_conversation_service, get_registry_service + +router = APIRouter(prefix="/converters", tags=["converters"]) + + +@router.get( + "", + response_model=List[ConverterMetadataResponse], +) +async def list_converters( + is_llm_based: Optional[bool] = Query(None, description="Filter by LLM-based converters"), + is_deterministic: Optional[bool] = Query(None, description="Filter by deterministic converters"), +) -> List[ConverterMetadataResponse]: + """ + List available converters. + + Returns metadata about all available prompt converters, optionally + filtered by LLM-based status or determinism. + + Returns: + List[ConverterMetadataResponse]: List of converter metadata. + """ + service = get_registry_service() + + return service.get_converters( + is_llm_based=is_llm_based, + is_deterministic=is_deterministic, + ) + + +@router.post( + "/preview", + response_model=PreviewConverterResponse, + responses={ + 400: {"model": ProblemDetail, "description": "Invalid converter configuration"}, + }, +) +async def preview_converters(request: PreviewConverterRequest) -> PreviewConverterResponse: + """ + Preview text through a converter pipeline. + + Applies the specified converters in sequence and returns + intermediate results at each step. Useful for testing converter + configurations before applying to conversations. + + Returns: + PreviewConverterResponse: Original content, converted content, and conversion steps. + """ + service = get_conversation_service() + + try: + steps_data = await service.preview_converters(request.content, request.converters) + + steps = [ + ConversionStep( + converter_class=s["converter_type"], + input=s["input"], + input_data_type=s.get("input_data_type", "text"), + output=s["output"], + output_data_type=s.get("output_type", "text"), + ) + for s in steps_data + ] + + final_output = steps[-1].output if steps else request.content + final_data_type = steps[-1].output_data_type if steps else request.data_type + + return PreviewConverterResponse( + original_content=request.content, + converted_content=final_output, + converted_data_type=final_data_type, + conversion_chain=steps, + converter_identifiers=[{"class_name": s.converter_class} for s in steps], + ) + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), + ) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Converter preview failed: {str(e)}", + ) diff --git a/pyrit/backend/routes/memory.py b/pyrit/backend/routes/memory.py new file mode 100644 index 0000000000..775e34ae22 --- /dev/null +++ b/pyrit/backend/routes/memory.py @@ -0,0 +1,206 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Memory API routes. + +Provides endpoints for querying stored data with pagination. +""" + +from datetime import datetime +from typing import List, Optional + +from fastapi import APIRouter, Query + +from pyrit.backend.models.common import PaginatedResponse +from pyrit.backend.models.memory import ( + AttackResultQueryResponse, + MessageQueryResponse, + ScenarioResultQueryResponse, + ScoreQueryResponse, + SeedQueryResponse, +) +from pyrit.backend.services import get_memory_service + +router = APIRouter(prefix="/memory", tags=["memory"]) + + +@router.get( + "/messages", + response_model=PaginatedResponse[MessageQueryResponse], +) +async def query_messages( + conversation_id: Optional[str] = Query(None, description="Filter by conversation ID"), + role: Optional[str] = Query(None, description="Filter by role (user/assistant/system)"), + data_type: Optional[str] = Query(None, description="Filter by data type (text/image_path/audio_path)"), + harm_category: Optional[List[str]] = Query(None, description="Filter by harm categories"), + response_error: Optional[str] = Query(None, description="Filter by response error type"), + start_time: Optional[datetime] = Query(None, description="Messages after this time"), + end_time: Optional[datetime] = Query(None, description="Messages before this time"), + limit: int = Query(50, ge=1, le=200, description="Maximum results per page"), + cursor: Optional[str] = Query(None, description="Pagination cursor"), +) -> PaginatedResponse[MessageQueryResponse]: + """ + Query message pieces with pagination. + + Returns messages matching the specified filters, ordered by timestamp descending. + Use cursor for pagination through large result sets. + + Returns: + PaginatedResponse[MessageQueryResponse]: Paginated list of messages. + """ + service = get_memory_service() + + return await service.get_messages( + conversation_id=conversation_id, + role=role, + harm_categories=harm_category, + data_type=data_type, + response_error=response_error, + start_time=start_time, + end_time=end_time, + limit=limit, + cursor=cursor, + ) + + +@router.get( + "/scores", + response_model=PaginatedResponse[ScoreQueryResponse], +) +async def query_scores( + message_id: Optional[str] = Query(None, description="Filter by message piece ID"), + conversation_id: Optional[str] = Query(None, description="Filter by conversation ID"), + score_type: Optional[str] = Query(None, description="Filter by score type"), + scorer_type: Optional[str] = Query(None, description="Filter by scorer class name"), + start_time: Optional[datetime] = Query(None, description="Scores after this time"), + end_time: Optional[datetime] = Query(None, description="Scores before this time"), + limit: int = Query(50, ge=1, le=200, description="Maximum results per page"), + cursor: Optional[str] = Query(None, description="Pagination cursor"), +) -> PaginatedResponse[ScoreQueryResponse]: + """ + Query scores with pagination. + + Returns scores matching the specified filters, ordered by timestamp descending. + + Returns: + PaginatedResponse[ScoreQueryResponse]: Paginated list of scores. + """ + service = get_memory_service() + + return await service.get_scores( + message_id=message_id, + conversation_id=conversation_id, + score_type=score_type, + scorer_type=scorer_type, + start_time=start_time, + end_time=end_time, + limit=limit, + cursor=cursor, + ) + + +@router.get( + "/attack-results", + response_model=PaginatedResponse[AttackResultQueryResponse], +) +async def query_attack_results( + conversation_id: Optional[str] = Query(None, description="Filter by conversation ID"), + outcome: Optional[str] = Query(None, description="Filter by outcome"), + attack_type: Optional[str] = Query(None, description="Filter by attack class name"), + objective: Optional[str] = Query(None, description="Search by objective text"), + min_turns: Optional[int] = Query(None, ge=1, description="Minimum executed turns"), + max_turns: Optional[int] = Query(None, ge=1, description="Maximum executed turns"), + start_time: Optional[datetime] = Query(None, description="Results after this time"), + end_time: Optional[datetime] = Query(None, description="Results before this time"), + limit: int = Query(50, ge=1, le=200, description="Maximum results per page"), + cursor: Optional[str] = Query(None, description="Pagination cursor"), +) -> PaginatedResponse[AttackResultQueryResponse]: + """ + Query attack results with pagination. + + Returns attack results matching the specified filters, ordered by timestamp descending. + + Returns: + PaginatedResponse[AttackResultQueryResponse]: Paginated list of attack results. + """ + service = get_memory_service() + + return await service.get_attack_results( + conversation_id=conversation_id, + outcome=outcome, + attack_type=attack_type, + objective=objective, + min_turns=min_turns, + max_turns=max_turns, + start_time=start_time, + end_time=end_time, + limit=limit, + cursor=cursor, + ) + + +@router.get( + "/scenario-results", + response_model=PaginatedResponse[ScenarioResultQueryResponse], +) +async def query_scenario_results( + scenario_name: Optional[str] = Query(None, description="Filter by scenario name"), + run_state: Optional[str] = Query(None, description="Filter by run state"), + start_time: Optional[datetime] = Query(None, description="Results after this time"), + end_time: Optional[datetime] = Query(None, description="Results before this time"), + limit: int = Query(50, ge=1, le=200, description="Maximum results per page"), + cursor: Optional[str] = Query(None, description="Pagination cursor"), +) -> PaginatedResponse[ScenarioResultQueryResponse]: + """ + Query scenario results with pagination. + + Returns scenario results matching the specified filters, ordered by timestamp descending. + + Returns: + PaginatedResponse[ScenarioResultQueryResponse]: Paginated list of scenario results. + """ + service = get_memory_service() + + return await service.get_scenario_results( + scenario_name=scenario_name, + run_state=run_state, + start_time=start_time, + end_time=end_time, + limit=limit, + cursor=cursor, + ) + + +@router.get( + "/seeds", + response_model=PaginatedResponse[SeedQueryResponse], +) +async def query_seeds( + dataset_name: Optional[str] = Query(None, description="Filter by dataset name"), + seed_type: Optional[str] = Query(None, description="Filter by seed type"), + harm_category: Optional[List[str]] = Query(None, description="Filter by harm categories"), + data_type: Optional[str] = Query(None, description="Filter by data type"), + search: Optional[str] = Query(None, description="Search in seed value text"), + limit: int = Query(50, ge=1, le=200, description="Maximum results per page"), + cursor: Optional[str] = Query(None, description="Pagination cursor"), +) -> PaginatedResponse[SeedQueryResponse]: + """ + Query seeds with pagination. + + Returns seeds matching the specified filters, ordered by date_added descending. + + Returns: + PaginatedResponse[SeedQueryResponse]: Paginated list of seeds. + """ + service = get_memory_service() + + return await service.get_seeds( + dataset_name=dataset_name, + seed_type=seed_type, + harm_categories=harm_category, + data_type=data_type, + search=search, + limit=limit, + cursor=cursor, + ) diff --git a/pyrit/backend/routes/registry.py b/pyrit/backend/routes/registry.py new file mode 100644 index 0000000000..dd0659bd33 --- /dev/null +++ b/pyrit/backend/routes/registry.py @@ -0,0 +1,125 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Registry API routes. + +Provides endpoints for querying available components. +""" + +from typing import List, Optional + +from fastapi import APIRouter, Query + +from pyrit.backend.models.registry import ( + ConverterMetadataResponse, + InitializerMetadataResponse, + ScenarioMetadataResponse, + ScorerMetadataResponse, + TargetMetadataResponse, +) +from pyrit.backend.services import get_registry_service + +router = APIRouter(prefix="/registry", tags=["registry"]) + + +@router.get( + "/targets", + response_model=List[TargetMetadataResponse], +) +async def list_targets( + is_chat_target: Optional[bool] = Query(None, description="Filter by chat target support"), +) -> List[TargetMetadataResponse]: + """ + List available targets. + + Returns metadata about all available prompt targets, optionally + filtered by chat target support. + + Returns: + List[TargetMetadataResponse]: List of target metadata. + """ + service = get_registry_service() + + return service.get_targets(is_chat_target=is_chat_target) + + +@router.get( + "/scenarios", + response_model=List[ScenarioMetadataResponse], +) +async def list_scenarios() -> List[ScenarioMetadataResponse]: + """ + List available scenarios. + + Returns metadata about all registered scenarios. + + Returns: + List[ScenarioMetadataResponse]: List of scenario metadata. + """ + service = get_registry_service() + + return service.get_scenarios() + + +@router.get( + "/scorers", + response_model=List[ScorerMetadataResponse], +) +async def list_scorers( + scorer_type: Optional[str] = Query(None, description="Filter by scorer type (true_false or float_scale)"), +) -> List[ScorerMetadataResponse]: + """ + List registered scorers. + + Returns metadata about all registered scorer instances. + + Returns: + List[ScorerMetadataResponse]: List of scorer metadata. + """ + service = get_registry_service() + + return service.get_scorers(scorer_type=scorer_type) + + +@router.get( + "/converters", + response_model=List[ConverterMetadataResponse], +) +async def list_converters( + is_llm_based: Optional[bool] = Query(None, description="Filter by LLM-based converters"), + is_deterministic: Optional[bool] = Query(None, description="Filter by deterministic converters"), +) -> List[ConverterMetadataResponse]: + """ + List available converters. + + Returns metadata about all available prompt converters. + Note: Also available at /converters endpoint. + + Returns: + List[ConverterMetadataResponse]: List of converter metadata. + """ + service = get_registry_service() + + return service.get_converters( + is_llm_based=is_llm_based, + is_deterministic=is_deterministic, + ) + + +@router.get( + "/initializers", + response_model=List[InitializerMetadataResponse], +) +async def list_initializers() -> List[InitializerMetadataResponse]: + """ + List available initializers. + + Returns metadata about all registered initializers. + + Returns: + List[InitializerMetadataResponse]: List of initializer metadata. + """ + service = get_registry_service() + + return service.get_initializers() diff --git a/pyrit/backend/services/__init__.py b/pyrit/backend/services/__init__.py index ee964f027b..8536741133 100644 --- a/pyrit/backend/services/__init__.py +++ b/pyrit/backend/services/__init__.py @@ -2,5 +2,31 @@ # Licensed under the MIT license. """ -Business logic services for backend operations. +Backend services module. + +Provides business logic layer for API routes. """ + +from pyrit.backend.services.conversation_service import ( + ConversationService, + ConversationState, + get_conversation_service, +) +from pyrit.backend.services.memory_service import ( + MemoryService, + get_memory_service, +) +from pyrit.backend.services.registry_service import ( + RegistryService, + get_registry_service, +) + +__all__ = [ + "ConversationService", + "ConversationState", + "get_conversation_service", + "MemoryService", + "get_memory_service", + "RegistryService", + "get_registry_service", +] diff --git a/pyrit/backend/services/conversation_service.py b/pyrit/backend/services/conversation_service.py new file mode 100644 index 0000000000..24438de473 --- /dev/null +++ b/pyrit/backend/services/conversation_service.py @@ -0,0 +1,508 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Conversation service for managing interactive sessions. + +Handles conversation lifecycle, message sending, branching, and converter management. +""" + +import importlib +import uuid +from collections import defaultdict +from datetime import datetime +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel + +from pyrit.backend.models.common import filter_sensitive_fields +from pyrit.backend.models.conversations import ( + BranchConversationRequest, + BranchConversationResponse, + ConverterConfig, + CreateConversationRequest, + CreateConversationResponse, + MessagePieceResponse, + MessageResponse, + SendMessageRequest, + SendMessageResponse, +) +from pyrit.memory import CentralMemory +from pyrit.models import Message, MessagePiece + + +class ConversationState(BaseModel): + """In-memory state for an active conversation.""" + + conversation_id: str + target_class: str + target_identifier: Dict[str, Any] + system_prompt: Optional[str] = None + labels: Optional[Dict[str, str]] = None + converters: List[ConverterConfig] = [] + created_at: datetime + message_count: int = 0 + + +class ConversationService: + """Service for managing conversation sessions.""" + + def __init__(self) -> None: + """Initialize the conversation service.""" + self._memory = CentralMemory.get_memory_instance() + # In-memory conversation state (for active sessions) + self._active_conversations: Dict[str, ConversationState] = {} + # Instantiated converters by conversation + self._converter_instances: Dict[str, List[Any]] = {} + # Instantiated targets by conversation + self._target_instances: Dict[str, Any] = {} + + def _instantiate_target_from_class(self, target_class: str, target_params: Optional[Dict[str, Any]]) -> Any: + """ + Instantiate a target from its class name. + + Args: + target_class: Target class name (e.g., 'TextTarget'). + target_params: Constructor parameters. + + Returns: + Instantiated target object. + """ + # Import the target class dynamically + module = importlib.import_module("pyrit.prompt_target") + cls = getattr(module, target_class, None) + + if cls is None: + raise ValueError(f"Target class '{target_class}' not found in pyrit.prompt_target") + + params = target_params or {} + return cls(**params) + + def _instantiate_converters(self, converter_configs: List[ConverterConfig]) -> List[Any]: + """ + Instantiate converters from their configurations. + + Args: + converter_configs: List of converter configurations. + + Returns: + List of instantiated converter objects. + """ + converters = [] + for config in converter_configs: + module = importlib.import_module(config.module) + converter_class = getattr(module, config.class_name) + params = config.params or {} + converter = converter_class(**params) + converters.append(converter) + + return converters + + async def create_conversation(self, request: CreateConversationRequest) -> CreateConversationResponse: + """ + Create a new conversation session. + + Args: + request: Conversation creation request. + + Returns: + Created conversation response with ID. + """ + conversation_id = str(uuid.uuid4()) + now = datetime.utcnow() + + # Instantiate the target + target = self._instantiate_target_from_class(request.target_class, request.target_params) + self._target_instances[conversation_id] = target + + # Get the target's identifier + target_identifier = target.get_identifier() if hasattr(target, "get_identifier") else {} + + # Store conversation state + state = ConversationState( + conversation_id=conversation_id, + target_class=request.target_class, + target_identifier=filter_sensitive_fields(target_identifier), + labels=request.labels, + converters=[], + created_at=now, + ) + self._active_conversations[conversation_id] = state + + return CreateConversationResponse( + conversation_id=conversation_id, + target_identifier=state.target_identifier, + labels=state.labels, + created_at=now, + ) + + async def get_conversation(self, conversation_id: str) -> Optional[ConversationState]: + """ + Get conversation state by ID. + + Returns: + Optional[ConversationState]: The conversation state or None if not found. + """ + return self._active_conversations.get(conversation_id) + + async def get_conversation_messages(self, conversation_id: str) -> List[MessageResponse]: + """ + Get all messages in a conversation. + + Args: + conversation_id: The conversation ID. + + Returns: + List of messages (grouped pieces) in order. + """ + pieces = self._memory.get_message_pieces(conversation_id=conversation_id) + + # Sort by sequence + pieces = sorted(pieces, key=lambda p: p.sequence) + + # Group pieces by sequence + by_sequence: Dict[int, List[Any]] = defaultdict(list) + for p in pieces: + by_sequence[p.sequence].append(p) + + messages = [] + for seq in sorted(by_sequence.keys()): + seq_pieces = by_sequence[seq] + if not seq_pieces: + continue + + first_piece = seq_pieces[0] + message_pieces = [ + MessagePieceResponse( + id=str(p.id) if hasattr(p, "id") and p.id else str(uuid.uuid4()), + original_value=p.original_value or "", + original_value_data_type=p.original_value_data_type, + converted_value=p.converted_value or "", + converted_value_data_type=p.converted_value_data_type, + converter_identifiers=p.converter_identifiers or [], + response_error=p.response_error if hasattr(p, "response_error") else None, + timestamp=p.timestamp, + ) + for p in seq_pieces + ] + + messages.append( + MessageResponse( + sequence=seq, + role=first_piece.role, + pieces=message_pieces, + timestamp=first_piece.timestamp, + ) + ) + + return messages + + async def send_message( + self, + conversation_id: str, + request: SendMessageRequest, + ) -> SendMessageResponse: + """ + Send a message to the target and get a response. + + This is a simplified stub - real implementation would involve + creating MessagePiece objects, applying converters, and calling target. + + Args: + conversation_id: The conversation ID. + request: Message send request. + + Returns: + Response containing sent and received messages. + """ + state = self._active_conversations.get(conversation_id) + if not state: + raise ValueError(f"Conversation {conversation_id} not found") + + target = self._target_instances.get(conversation_id) + if not target: + raise ValueError(f"Target for conversation {conversation_id} not found") + + now = datetime.utcnow() + state.message_count += 1 + user_seq = state.message_count + + # Get converters if any + converters = self._converter_instances.get(conversation_id, []) + + # Build user message pieces + user_pieces_response = [] + user_piece_objs = [] + + for piece_input in request.pieces: + original_value = piece_input.original_value or "" + original_type = piece_input.original_value_data_type + converted_value = piece_input.converted_value or original_value + converted_type = piece_input.converted_value_data_type or original_type + converter_ids = piece_input.converter_identifiers or [] + + # Apply converters if not pre-converted + if not request.pre_converted and converters: + for converter in converters: + result = await converter.convert_async(prompt=converted_value) + converted_value = result.output_text + converted_type = result.output_type + converter_ids.append(converter.get_identifier()) + + piece_id = str(uuid.uuid4()) + user_pieces_response.append( + MessagePieceResponse( + id=piece_id, + original_value=original_value, + original_value_data_type=original_type, + converted_value=converted_value, + converted_value_data_type=converted_type, + converter_identifiers=converter_ids, + response_error=None, + timestamp=now, + ) + ) + + # Create actual MessagePiece for target + user_piece_objs.append( + MessagePiece( + role="user", + original_value=original_value, + original_value_data_type=original_type, + converted_value=converted_value, + converted_value_data_type=converted_type, + converter_identifiers=converter_ids if converter_ids else None, + prompt_target_identifier=target.get_identifier(), + conversation_id=conversation_id, + sequence=user_seq, + ) + ) + + user_message_response = MessageResponse( + sequence=user_seq, + role="user", + pieces=user_pieces_response, + timestamp=now, + ) + + # Send to target + user_message_obj = Message(user_piece_objs) + response_messages = await target.send_prompt_async(message=user_message_obj) + + # Build assistant response + assistant_message_response = None + if response_messages: + state.message_count += 1 + assistant_seq = state.message_count + + assistant_pieces = [] + for resp_message in response_messages: + for resp_piece in resp_message.message_pieces: + assistant_pieces.append( + MessagePieceResponse( + id=str(resp_piece.id) if hasattr(resp_piece, "id") else str(uuid.uuid4()), + original_value=resp_piece.original_value or "", + original_value_data_type=resp_piece.original_value_data_type, + converted_value=resp_piece.converted_value or "", + converted_value_data_type=resp_piece.converted_value_data_type, + converter_identifiers=resp_piece.converter_identifiers or [], + response_error=getattr(resp_piece, "response_error", None), + timestamp=resp_piece.timestamp, + ) + ) + + if assistant_pieces: + assistant_message_response = MessageResponse( + sequence=assistant_seq, + role="assistant", + pieces=assistant_pieces, + timestamp=now, + ) + + return SendMessageResponse( + user_message=user_message_response, + assistant_message=assistant_message_response, + ) + + async def update_system_prompt(self, conversation_id: str, system_prompt: str) -> None: + """ + Update the system prompt for a conversation. + + Args: + conversation_id: The conversation ID. + system_prompt: New system prompt. + """ + state = self._active_conversations.get(conversation_id) + if not state: + raise ValueError(f"Conversation {conversation_id} not found") + + target = self._target_instances.get(conversation_id) + if not target: + raise ValueError(f"Target for conversation {conversation_id} not found") + + # Update target system prompt + target.set_system_prompt( + system_prompt=system_prompt, + conversation_id=conversation_id, + ) + + # Update state + state.system_prompt = system_prompt + + async def update_converters(self, conversation_id: str, converters: List[ConverterConfig]) -> None: + """ + Update the converters for a conversation. + + Args: + conversation_id: The conversation ID. + converters: New converter configurations. + """ + state = self._active_conversations.get(conversation_id) + if not state: + raise ValueError(f"Conversation {conversation_id} not found") + + # Instantiate new converters + converter_instances = self._instantiate_converters(converters) + self._converter_instances[conversation_id] = converter_instances + + # Update state + state.converters = converters + + async def branch_conversation( + self, + conversation_id: str, + request: BranchConversationRequest, + ) -> BranchConversationResponse: + """ + Branch a conversation from a specific point. + + Args: + conversation_id: The source conversation ID. + request: Branch request with last_included_sequence. + + Returns: + New conversation with copied messages. + """ + state = self._active_conversations.get(conversation_id) + if not state: + raise ValueError(f"Conversation {conversation_id} not found") + + # Get messages up to branch point + all_messages = await self.get_conversation_messages(conversation_id) + messages_to_copy = [m for m in all_messages if m.sequence <= request.last_included_sequence] + + if not messages_to_copy: + raise ValueError(f"No messages found at or before sequence {request.last_included_sequence}") + + # Create new conversation with same target and converters + new_conversation_id = str(uuid.uuid4()) + now = datetime.utcnow() + + # Copy target instance + original_target = self._target_instances.get(conversation_id) + if original_target: + # Create new target instance with same config + new_target = self._instantiate_target_from_class(state.target_class, None) + self._target_instances[new_conversation_id] = new_target + + # Copy converters + if state.converters: + self._converter_instances[new_conversation_id] = self._instantiate_converters(state.converters) + + # Create new state + new_state = ConversationState( + conversation_id=new_conversation_id, + target_class=state.target_class, + target_identifier=state.target_identifier, + labels=state.labels, + converters=state.converters, + created_at=now, + message_count=len(messages_to_copy), + ) + self._active_conversations[new_conversation_id] = new_state + + # Copy messages to memory with new conversation ID + for msg in messages_to_copy: + for piece in msg.pieces: + new_piece = MessagePiece( + role=msg.role, + original_value=piece.original_value, + original_value_data_type=piece.original_value_data_type, + converted_value=piece.converted_value, + converted_value_data_type=piece.converted_value_data_type, + converter_identifiers=piece.converter_identifiers if piece.converter_identifiers else None, + conversation_id=new_conversation_id, + sequence=msg.sequence, + ) + self._memory.add_message_pieces_to_memory(message_pieces=[new_piece]) + + return BranchConversationResponse( + conversation_id=new_conversation_id, + branched_from={ + "conversation_id": conversation_id, + "last_included_sequence": request.last_included_sequence, + }, + message_count=len(messages_to_copy), + created_at=now, + ) + + async def preview_converters( + self, + text: str, + converters: List[ConverterConfig], + ) -> List[Dict[str, Any]]: + """ + Preview text through a converter pipeline. + + Args: + text: Input text to convert. + converters: Converter configurations to apply. + + Returns: + List of conversion steps showing intermediate results. + """ + converter_instances = self._instantiate_converters(converters) + + steps = [] + current_text = text + + for i, converter in enumerate(converter_instances): + config = converters[i] + result = await converter.convert_async(prompt=current_text) + + steps.append( + { + "step": i + 1, + "converter_class": config.class_name, + "input": current_text, + "output": result.output_text, + "output_type": result.output_type, + } + ) + + current_text = result.output_text + + return steps + + def cleanup_conversation(self, conversation_id: str) -> None: + """Clean up resources for a conversation.""" + self._active_conversations.pop(conversation_id, None) + self._converter_instances.pop(conversation_id, None) + self._target_instances.pop(conversation_id, None) + + +# Singleton instance +_conversation_service: Optional[ConversationService] = None + + +def get_conversation_service() -> ConversationService: + """ + Get the conversation service singleton. + + Returns: + ConversationService: The conversation service instance. + """ + global _conversation_service + if _conversation_service is None: + _conversation_service = ConversationService() + return _conversation_service diff --git a/pyrit/backend/services/memory_service.py b/pyrit/backend/services/memory_service.py new file mode 100644 index 0000000000..8d9cbeab02 --- /dev/null +++ b/pyrit/backend/services/memory_service.py @@ -0,0 +1,444 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Memory service for API access to stored data. + +Wraps CentralMemory with pagination and filtering for API endpoints. +""" + +from datetime import datetime +from typing import Any, Dict, List, Optional, Tuple + +from pyrit.backend.models.common import PaginatedResponse, PaginationInfo, filter_sensitive_fields +from pyrit.backend.models.memory import ( + AttackResultQueryResponse, + MessageQueryResponse, + ScenarioResultQueryResponse, + ScoreQueryResponse, + SeedQueryResponse, +) +from pyrit.memory import CentralMemory +from pyrit.models.seeds import SeedObjective, SeedSimulatedConversation + + +def _parse_cursor(cursor: Optional[str]) -> Tuple[Optional[datetime], Optional[str]]: + """ + Parse a cursor string into timestamp and ID components. + + Cursor format: {ISO8601_timestamp}_{record_id} + + Returns: + Tuple[Optional[datetime], Optional[str]]: Parsed timestamp and record ID. + """ + if not cursor: + return None, None + + try: + parts = cursor.rsplit("_", 1) + if len(parts) != 2: + return None, None + timestamp_str, record_id = parts + timestamp = datetime.fromisoformat(timestamp_str.replace("Z", "+00:00")) + return timestamp, record_id + except (ValueError, AttributeError): + return None, None + + +def _build_cursor(timestamp: datetime, record_id: str) -> str: + """ + Build a cursor string from timestamp and ID. + + Returns: + str: Cursor string for pagination. + """ + return f"{timestamp.isoformat()}_{record_id}" + + +class MemoryService: + """Service for querying memory with pagination support.""" + + def __init__(self) -> None: + """Initialize the memory service.""" + self._memory = CentralMemory.get_memory_instance() + + async def get_messages( + self, + *, + conversation_id: Optional[str] = None, + role: Optional[str] = None, + labels: Optional[Dict[str, str]] = None, + harm_categories: Optional[List[str]] = None, + data_type: Optional[str] = None, + response_error: Optional[str] = None, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + limit: int = 50, + cursor: Optional[str] = None, + ) -> PaginatedResponse[MessageQueryResponse]: + """ + Query message pieces with pagination. + + Args: + conversation_id: Filter by conversation. + role: Filter by message role. + labels: Filter by labels. + harm_categories: Filter by harm categories (not supported in current API). + data_type: Filter by data type. + response_error: Filter by response error type (not supported in current API). + start_time: Messages after this time. + end_time: Messages before this time. + limit: Maximum results per page. + cursor: Pagination cursor. + + Returns: + Paginated list of messages. + """ + # Parse cursor for pagination + cursor_time, cursor_id = _parse_cursor(cursor) + + # Query memory - use supported parameters only + pieces = self._memory.get_message_pieces( + conversation_id=conversation_id, + role=role, + labels=labels, + data_type=data_type, + sent_after=cursor_time or start_time, + sent_before=end_time, + ) + + # Apply start_time filter if provided and no cursor + if start_time and not cursor_time: + pieces = [p for p in pieces if p.timestamp and p.timestamp >= start_time] + + # Sort by timestamp descending + pieces = sorted(pieces, key=lambda p: p.timestamp or datetime.min, reverse=True) + + # Apply limit + 1 to check for more + has_more = len(pieces) > limit + pieces = pieces[:limit] + + # Build response items + items = [] + for piece in pieces: + items.append( + MessageQueryResponse( + id=str(piece.id), + conversation_id=piece.conversation_id, + sequence=piece.sequence, + role=piece.role, + original_value=piece.original_value, + original_value_data_type=piece.original_value_data_type, + converted_value=piece.converted_value, + converted_value_data_type=piece.converted_value_data_type, + converter_identifiers=piece.converter_identifiers or [], + target_identifier=filter_sensitive_fields(piece.prompt_target_identifier or {}), + labels=piece.labels, + response_error=piece.response_error, + timestamp=piece.timestamp, + ) + ) + + # Build pagination info + next_cursor = None + if has_more and pieces: + last_piece = pieces[-1] + next_cursor = _build_cursor(last_piece.timestamp, str(last_piece.id)) + + prev_cursor = None + if cursor and pieces: + first_piece = pieces[0] + prev_cursor = _build_cursor(first_piece.timestamp, str(first_piece.id)) + + return PaginatedResponse( + items=items, + pagination=PaginationInfo( + limit=limit, + has_more=has_more, + next_cursor=next_cursor, + prev_cursor=prev_cursor, + ), + ) + + async def get_scores( + self, + *, + message_id: Optional[str] = None, + conversation_id: Optional[str] = None, + score_type: Optional[str] = None, + scorer_type: Optional[str] = None, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + limit: int = 50, + cursor: Optional[str] = None, + ) -> PaginatedResponse[ScoreQueryResponse]: + """ + Query scores with pagination. + + Returns: + PaginatedResponse[ScoreQueryResponse]: Paginated list of scores. + """ + scores = self._memory.get_scores( + score_type=score_type, + ) + + # Apply additional filters + if message_id: + scores = [s for s in scores if str(s.message_piece_id) == message_id] + + if scorer_type: + scores = [ + s + for s in scores + if s.scorer_class_identifier and s.scorer_class_identifier.get("__type__") == scorer_type + ] + + if start_time: + scores = [s for s in scores if s.timestamp and s.timestamp >= start_time] + if end_time: + scores = [s for s in scores if s.timestamp and s.timestamp <= end_time] + + # Sort and paginate + scores = sorted(scores, key=lambda s: s.timestamp or datetime.min, reverse=True) + has_more = len(scores) > limit + scores = scores[:limit] + + items = [] + for score in scores: + items.append( + ScoreQueryResponse( + id=str(score.id), + message_piece_id=str(score.message_piece_id), + score_value=score.score_value, + score_value_description=score.score_value_description or "", + score_type=score.score_type, + score_category=score.score_category, + score_rationale=score.score_rationale or "", + scorer_identifier=filter_sensitive_fields(score.scorer_class_identifier or {}), + objective=score.objective, + timestamp=score.timestamp, + ) + ) + + next_cursor = None + if has_more and scores: + last = scores[-1] + next_cursor = _build_cursor(last.timestamp, str(last.id)) + + return PaginatedResponse( + items=items, + pagination=PaginationInfo(limit=limit, has_more=has_more, next_cursor=next_cursor, prev_cursor=None), + ) + + async def get_attack_results( + self, + *, + conversation_id: Optional[str] = None, + outcome: Optional[str] = None, + attack_type: Optional[str] = None, + objective: Optional[str] = None, + min_turns: Optional[int] = None, + max_turns: Optional[int] = None, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + limit: int = 50, + cursor: Optional[str] = None, + ) -> PaginatedResponse[AttackResultQueryResponse]: + """ + Query attack results with pagination. + + Returns: + PaginatedResponse[AttackResultQueryResponse]: Paginated list of attack results. + """ + results = self._memory.get_attack_results( + conversation_id=conversation_id, + outcome=outcome, + objective=objective, + ) + + # Apply additional filters + if attack_type: + results = [r for r in results if r.attack_identifier and r.attack_identifier.get("__type__") == attack_type] + + if min_turns: + results = [r for r in results if r.executed_turns >= min_turns] + if max_turns: + results = [r for r in results if r.executed_turns <= max_turns] + + # Note: AttackResult doesn't have timestamp field - skip time filtering + # Sort by executed_turns as a proxy for recency + results_list = list(results) + has_more = len(results_list) > limit + results_list = results_list[:limit] + + items = [] + for result in results_list: + items.append( + AttackResultQueryResponse( + id=result.conversation_id, # Use conversation_id as identifier + conversation_id=result.conversation_id, + objective=result.objective, + attack_identifier=filter_sensitive_fields(result.attack_identifier or {}), + outcome=str(result.outcome.value) if result.outcome else None, + outcome_reason=result.outcome_reason, + executed_turns=result.executed_turns, + execution_time_ms=result.execution_time_ms, + timestamp=None, # AttackResult doesn't have timestamp + ) + ) + + # No cursor-based pagination available without timestamps + return PaginatedResponse( + items=items, + pagination=PaginationInfo(limit=limit, has_more=has_more, next_cursor=None, prev_cursor=None), + ) + + async def get_scenario_results( + self, + *, + scenario_name: Optional[str] = None, + run_state: Optional[str] = None, + labels: Optional[Dict[str, str]] = None, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + limit: int = 50, + cursor: Optional[str] = None, + ) -> PaginatedResponse[ScenarioResultQueryResponse]: + """ + Query scenario results with pagination. + + Returns: + PaginatedResponse[ScenarioResultQueryResponse]: Paginated list of scenario results. + """ + results = self._memory.get_scenario_results( + scenario_name=scenario_name, + labels=labels, + added_after=start_time, + added_before=end_time, + ) + + # Apply run_state filter if provided (not directly supported in API) + if run_state: + results = [r for r in results if r.scenario_run_state == run_state] + + # Sort by completion_time descending + results_list = list(results) + results_list = sorted(results_list, key=lambda r: r.completion_time or datetime.min, reverse=True) + has_more = len(results_list) > limit + results_list = results_list[:limit] + + items = [] + for result in results_list: + items.append( + ScenarioResultQueryResponse( + id=str(result.id), + scenario_name=result.scenario_identifier.name if result.scenario_identifier else "", + scenario_description=result.scenario_identifier.description if result.scenario_identifier else "", + scenario_version=result.scenario_identifier.version if result.scenario_identifier else 0, + pyrit_version=result.scenario_identifier.pyrit_version if result.scenario_identifier else "", + run_state=result.scenario_run_state, + objective_target_identifier=filter_sensitive_fields(result.objective_target_identifier or {}), + labels=result.labels, + number_tries=result.number_tries, + completion_time=result.completion_time, + timestamp=result.completion_time, # Use completion_time as timestamp + ) + ) + + next_cursor = None + if has_more and results_list: + last = results_list[-1] + next_cursor = _build_cursor(last.completion_time or datetime.min, str(last.id)) + + return PaginatedResponse( + items=items, + pagination=PaginationInfo(limit=limit, has_more=has_more, next_cursor=next_cursor, prev_cursor=None), + ) + + async def get_seeds( + self, + *, + dataset_name: Optional[str] = None, + seed_type: Optional[str] = None, + harm_categories: Optional[List[str]] = None, + data_type: Optional[str] = None, + search: Optional[str] = None, + limit: int = 50, + cursor: Optional[str] = None, + ) -> PaginatedResponse[SeedQueryResponse]: + """ + Query seeds with pagination. + + Returns: + PaginatedResponse[SeedQueryResponse]: Paginated list of seeds. + """ + # Build query params - seed_type needs conversion to SeedType + query_params: Dict[str, Any] = { + "dataset_name": dataset_name, + "harm_categories": harm_categories, + } + if seed_type: + query_params["seed_type"] = seed_type + if data_type: + query_params["data_types"] = [data_type] + if search: + query_params["value"] = search + + seeds = self._memory.get_seeds(**query_params) + + # Sort by date_added descending + seeds_list = sorted(list(seeds), key=lambda s: s.date_added or datetime.min, reverse=True) + has_more = len(seeds_list) > limit + seeds_list = seeds_list[:limit] + + items = [] + for seed in seeds_list: + # Determine seed_type based on class + if isinstance(seed, SeedObjective): + determined_seed_type = "objective" + elif isinstance(seed, SeedSimulatedConversation): + determined_seed_type = "simulated_conversation" + else: + determined_seed_type = "prompt" + + items.append( + SeedQueryResponse( + id=str(seed.id), + value=seed.value, + data_type=seed.data_type, + name=seed.name, + dataset_name=seed.dataset_name, + seed_type=determined_seed_type, # type: ignore + harm_categories=list(seed.harm_categories) if seed.harm_categories else None, + description=seed.description, + source=seed.source, + date_added=seed.date_added, + ) + ) + + next_cursor = None + if has_more and seeds: + last = seeds[-1] + next_cursor = _build_cursor(last.date_added or datetime.min, str(last.id)) + + return PaginatedResponse( + items=items, + pagination=PaginationInfo(limit=limit, has_more=has_more, next_cursor=next_cursor, prev_cursor=None), + ) + + +# Singleton instance +_memory_service: Optional[MemoryService] = None + + +def get_memory_service() -> MemoryService: + """ + Get the memory service singleton. + + Returns: + MemoryService: The memory service instance. + """ + global _memory_service + if _memory_service is None: + _memory_service = MemoryService() + return _memory_service diff --git a/pyrit/backend/services/registry_service.py b/pyrit/backend/services/registry_service.py new file mode 100644 index 0000000000..2a4109bda3 --- /dev/null +++ b/pyrit/backend/services/registry_service.py @@ -0,0 +1,307 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Registry service for API access to registered components. + +Wraps component registries with filtering and metadata extraction. +Uses class introspection when registries are not available. +""" + +import inspect +import logging +from typing import Any, Dict, List, Optional, Type + +from pyrit.backend.models.common import filter_sensitive_fields +from pyrit.backend.models.registry import ( + ConverterMetadataResponse, + InitializerMetadataResponse, + ScenarioMetadataResponse, + ScorerMetadataResponse, + TargetMetadataResponse, +) +from pyrit.prompt_converter import PromptConverter +from pyrit.prompt_target import PromptChatTarget, PromptTarget +from pyrit.registry import InitializerRegistry, ScenarioRegistry, ScorerRegistry + +logger = logging.getLogger(__name__) + + +def _extract_params_schema(cls: Type[Any]) -> Dict[str, Any]: + """ + Extract parameter schema from a class constructor. + + Returns: + Dict[str, Any]: Dict with 'required' and 'optional' fields. + """ + required = [] + optional = [] + + try: + sig = inspect.signature(cls.__init__) + + for name, param in sig.parameters.items(): + if name in ("self", "cls", "args", "kwargs"): + continue + + if param.default == inspect.Parameter.empty: + required.append(name) + else: + optional.append(name) + except (ValueError, TypeError): + pass + + return {"required": required, "optional": optional} + + +def _get_all_subclasses(cls: type) -> List[type]: + """ + Recursively get all non-abstract subclasses of a class. + + Returns: + List[type]: List of non-abstract subclass types. + """ + subclasses = [] + for subclass in cls.__subclasses__(): + # Skip abstract classes + if hasattr(subclass, "__abstractmethods__") and subclass.__abstractmethods__: + subclasses.extend(_get_all_subclasses(subclass)) + else: + subclasses.append(subclass) + subclasses.extend(_get_all_subclasses(subclass)) + return subclasses + + +class RegistryService: + """Service for querying component registries.""" + + def get_targets( + self, + *, + is_chat_target: Optional[bool] = None, + ) -> List[TargetMetadataResponse]: + """ + Get available target types via introspection. + + Note: TargetRegistry may not exist yet (PR #1320). + Falls back to class introspection. + + Args: + is_chat_target: Filter by chat target support. + + Returns: + List of target metadata. + """ + # Get all concrete target subclasses via introspection + target_classes = _get_all_subclasses(PromptTarget) + + results = [] + for target_class in target_classes: + # Determine if chat target + is_chat = issubclass(target_class, PromptChatTarget) + + if is_chat_target is not None and is_chat != is_chat_target: + continue + + # Check JSON response support + supports_json = False + if is_chat: + try: + supports_json = hasattr(target_class, "is_json_response_supported") + except Exception: + pass + + # Get supported data types from class attribute if available + supported_types = getattr(target_class, "SUPPORTED_DATA_TYPES", ["text"]) + + results.append( + TargetMetadataResponse( + name=target_class.__name__, + class_name=target_class.__name__, + description=(target_class.__doc__ or "").split("\n")[0].strip(), + is_chat_target=is_chat, + supports_json_response=supports_json, + supported_data_types=list(supported_types), + params_schema=_extract_params_schema(target_class), + ) + ) + + return results + + def get_scenarios(self) -> List[ScenarioMetadataResponse]: + """ + Get all available scenarios from the registry. + + Returns: + List of scenario metadata. + """ + try: + registry = ScenarioRegistry.get_registry_singleton() + metadata_list = registry.list_metadata() + + results = [] + for m in metadata_list: + results.append( + ScenarioMetadataResponse( + name=m.name, + class_name=m.class_name, + description=m.class_description or "", + default_strategy=m.default_strategy, + all_strategies=list(m.all_strategies), + aggregate_strategies=list(m.aggregate_strategies), + default_datasets=list(m.default_datasets), + max_dataset_size=m.max_dataset_size, + ) + ) + return results + except Exception as e: + logger.warning(f"Failed to get scenarios from registry: {e}") + return [] + + def get_scorers( + self, + *, + scorer_type: Optional[str] = None, + ) -> List[ScorerMetadataResponse]: + """ + Get registered scorer instances. + + Args: + scorer_type: Filter by scorer type ('true_false' or 'float_scale'). + + Returns: + List of scorer metadata. + """ + try: + registry = ScorerRegistry.get_registry_singleton() + + # Build filter if scorer_type specified + include_filters: dict[str, object] | None = None + if scorer_type: + include_filters = {"scorer_type": scorer_type} + + metadata_list = registry.list_metadata(include_filters=include_filters) + + results = [] + for m in metadata_list: + # Get scorer identifier and filter sensitive fields + scorer_id = m.scorer_identifier.to_compact_dict() if m.scorer_identifier else {} + filtered_id = filter_sensitive_fields(scorer_id) + + results.append( + ScorerMetadataResponse( + name=m.name, + class_name=m.class_name, + description=m.class_description or "", + scorer_type=m.scorer_type, + scorer_identifier=filtered_id, + ) + ) + return results + except Exception as e: + logger.warning(f"Failed to get scorers from registry: {e}") + return [] + + def get_initializers(self) -> List[InitializerMetadataResponse]: + """ + Get all available initializers from the registry. + + Returns: + List of initializer metadata. + """ + try: + registry = InitializerRegistry.get_registry_singleton() + metadata_list = registry.list_metadata() + + results = [] + for m in metadata_list: + results.append( + InitializerMetadataResponse( + name=m.name, + class_name=m.class_name, + description=m.class_description or "", + required_env_vars=list(m.required_env_vars) if m.required_env_vars else [], + execution_order=getattr(m, "execution_order", 0), + ) + ) + return results + except Exception as e: + logger.warning(f"Failed to get initializers from registry: {e}") + return [] + + def get_converters( + self, + *, + is_llm_based: Optional[bool] = None, + is_deterministic: Optional[bool] = None, + ) -> List[ConverterMetadataResponse]: + """ + Get available converters via introspection. + + Note: ConverterRegistry may not exist yet. + Falls back to class introspection. + + Args: + is_llm_based: Filter by LLM-based converters. + is_deterministic: Filter by deterministic converters. + + Returns: + List of converter metadata. + """ + # Get all converter subclasses using the shared helper + converter_classes = _get_all_subclasses(PromptConverter) + + results = [] + for converter_class in converter_classes: + # Get supported types from class attributes + input_types = getattr(converter_class, "SUPPORTED_INPUT_TYPES", ["text"]) + output_types = getattr(converter_class, "SUPPORTED_OUTPUT_TYPES", ["text"]) + + # Determine if LLM-based (has converter_target parameter) + converter_is_llm_based = False + try: + sig = inspect.signature(converter_class) + converter_is_llm_based = "converter_target" in sig.parameters + except Exception: + pass + + if is_llm_based is not None and converter_is_llm_based != is_llm_based: + continue + + # Assume deterministic if not LLM-based + converter_is_deterministic = not converter_is_llm_based + + if is_deterministic is not None and converter_is_deterministic != is_deterministic: + continue + + results.append( + ConverterMetadataResponse( + name=converter_class.__name__, + class_name=converter_class.__name__, + description=(converter_class.__doc__ or "").split("\n")[0].strip(), + supported_input_types=list(input_types), + supported_output_types=list(output_types), + is_llm_based=converter_is_llm_based, + is_deterministic=converter_is_deterministic, + params_schema=_extract_params_schema(converter_class), + ) + ) + + return results + + +# Singleton instance +_registry_service: Optional[RegistryService] = None + + +def get_registry_service() -> RegistryService: + """ + Get the registry service singleton. + + Returns: + RegistryService: The registry service instance. + """ + global _registry_service + if _registry_service is None: + _registry_service = RegistryService() + return _registry_service diff --git a/tests/unit/backend/__init__.py b/tests/unit/backend/__init__.py new file mode 100644 index 0000000000..9a0454564d --- /dev/null +++ b/tests/unit/backend/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. diff --git a/tests/unit/backend/test_common_models.py b/tests/unit/backend/test_common_models.py new file mode 100644 index 0000000000..06739e5c9e --- /dev/null +++ b/tests/unit/backend/test_common_models.py @@ -0,0 +1,405 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tests for backend common models. +""" + + + +from pyrit.backend.models.common import ( + FieldError, + IdentifierDict, + PaginatedResponse, + PaginationInfo, + ProblemDetail, + filter_sensitive_fields, +) + + +class TestPaginationInfo: + """Tests for PaginationInfo model.""" + + def test_pagination_info_creation(self) -> None: + """Test creating a PaginationInfo object.""" + info = PaginationInfo(limit=50, has_more=True, next_cursor="abc123") + + assert info.limit == 50 + assert info.has_more is True + assert info.next_cursor == "abc123" + assert info.prev_cursor is None + + def test_pagination_info_full(self) -> None: + """Test creating a PaginationInfo with all fields.""" + info = PaginationInfo( + limit=100, + has_more=False, + next_cursor="next", + prev_cursor="prev", + ) + + assert info.limit == 100 + assert info.has_more is False + assert info.next_cursor == "next" + assert info.prev_cursor == "prev" + + +class TestPaginatedResponse: + """Tests for PaginatedResponse model.""" + + def test_paginated_response_with_strings(self) -> None: + """Test creating a paginated response with string items.""" + pagination = PaginationInfo(limit=10, has_more=False) + response = PaginatedResponse[str]( + items=["a", "b", "c"], + pagination=pagination, + ) + + assert len(response.items) == 3 + assert response.items[0] == "a" + assert response.pagination.limit == 10 + + +class TestFieldError: + """Tests for FieldError model.""" + + def test_field_error_minimal(self) -> None: + """Test creating a FieldError with minimal fields.""" + error = FieldError(field="name", message="Required field") + + assert error.field == "name" + assert error.message == "Required field" + assert error.code is None + assert error.value is None + + def test_field_error_full(self) -> None: + """Test creating a FieldError with all fields.""" + error = FieldError( + field="pieces[0].data_type", + message="Invalid value", + code="type_error", + value="invalid", + ) + + assert error.field == "pieces[0].data_type" + assert error.message == "Invalid value" + assert error.code == "type_error" + assert error.value == "invalid" + + +class TestProblemDetail: + """Tests for ProblemDetail model.""" + + def test_problem_detail_minimal(self) -> None: + """Test creating a ProblemDetail with minimal fields.""" + problem = ProblemDetail( + type="/errors/test", + title="Test Error", + status=400, + detail="A test error occurred", + ) + + assert problem.type == "/errors/test" + assert problem.title == "Test Error" + assert problem.status == 400 + assert problem.detail == "A test error occurred" + assert problem.instance is None + assert problem.errors is None + + def test_problem_detail_with_errors(self) -> None: + """Test creating a ProblemDetail with field errors.""" + errors = [ + FieldError(field="name", message="Required"), + FieldError(field="age", message="Must be positive"), + ] + problem = ProblemDetail( + type="/errors/validation", + title="Validation Error", + status=422, + detail="Request validation failed", + instance="/api/v1/test", + errors=errors, + ) + + assert len(problem.errors) == 2 + assert problem.instance == "/api/v1/test" + + +class TestIdentifierDict: + """Tests for IdentifierDict model.""" + + def test_identifier_dict_creation(self) -> None: + """Test creating an IdentifierDict.""" + identifier = IdentifierDict(__type__="TestClass", __module__="pyrit.test") + + assert identifier.type_ == "TestClass" + assert identifier.module_ == "pyrit.test" + + +class TestFilterSensitiveFields: + """Tests for filter_sensitive_fields function.""" + + def test_filter_removes_api_key(self) -> None: + """Test that API keys are filtered out.""" + data = { + "name": "test", + "api_key": "secret123", + "endpoint": "https://api.test.com", + } + + result = filter_sensitive_fields(data) + + assert "name" in result + assert "endpoint" in result + assert "api_key" not in result + + def test_filter_removes_password(self) -> None: + """Test that passwords are filtered out.""" + data = { + "username": "user", + "password": "secret", + } + + result = filter_sensitive_fields(data) + + assert "username" in result + assert "password" not in result + + def test_filter_removes_token(self) -> None: + """Test that tokens are filtered out.""" + data = { + "access_token": "abc123", + "refresh_token": "xyz789", + "data": "public", + } + + result = filter_sensitive_fields(data) + + assert "data" in result + assert "access_token" not in result + assert "refresh_token" not in result + + def test_filter_handles_nested_dicts(self) -> None: + """Test that nested dictionaries are recursively filtered.""" + data = { + "config": { + "api_key": "secret", + "endpoint": "https://test.com", + }, + "name": "test", + } + + result = filter_sensitive_fields(data) + + assert result["name"] == "test" + assert "api_key" not in result["config"] + assert result["config"]["endpoint"] == "https://test.com" + + def test_filter_handles_lists(self) -> None: + """Test that lists with dicts are filtered.""" + data = { + "items": [ + {"api_key": "secret1", "id": 1}, + {"api_key": "secret2", "id": 2}, + ], + } + + result = filter_sensitive_fields(data) + + assert len(result["items"]) == 2 + assert "api_key" not in result["items"][0] + assert result["items"][0]["id"] == 1 + + def test_filter_non_dict_returns_as_is(self) -> None: + """Test that non-dict input is returned as-is.""" + result = filter_sensitive_fields("not a dict") # type: ignore[arg-type] + assert result == "not a dict" + + def test_filter_preserves_allowed_fields(self) -> None: + """Test that allowed fields are preserved.""" + data = { + "model_name": "gpt-4", + "temperature": 0.7, + "deployment_name": "my-deployment", + "api_key": "secret", + } + + result = filter_sensitive_fields(data) + + assert result["model_name"] == "gpt-4" + assert result["temperature"] == 0.7 + assert result["deployment_name"] == "my-deployment" + assert "api_key" not in result + + def test_filter_removes_secret_fields(self) -> None: + """Test that secret-related fields are filtered out.""" + data = { + "client_secret": "secret123", + "secret_key": "key456", + "model_name": "gpt-4", + } + + result = filter_sensitive_fields(data) + + assert "client_secret" not in result + assert "secret_key" not in result + assert result["model_name"] == "gpt-4" + + def test_filter_removes_credential_fields(self) -> None: + """Test that credential-related fields are filtered out.""" + data = { + "credentials": "cred123", + "user_credential": "cred456", + "endpoint": "https://api.test.com", + } + + result = filter_sensitive_fields(data) + + assert "credentials" not in result + assert "user_credential" not in result + assert result["endpoint"] == "https://api.test.com" + + def test_filter_removes_auth_fields(self) -> None: + """Test that auth-related fields are filtered out.""" + data = { + "auth_header": "Bearer token", + "authorization": "secret", + "username": "user", + } + + result = filter_sensitive_fields(data) + + assert "auth_header" not in result + assert "authorization" not in result + assert result["username"] == "user" + + def test_filter_empty_dict(self) -> None: + """Test filtering an empty dictionary.""" + result = filter_sensitive_fields({}) + + assert result == {} + + def test_filter_deeply_nested_dicts(self) -> None: + """Test filtering deeply nested dictionaries.""" + data = { + "level1": { + "level2": { + "level3": { + "api_key": "secret", + "data": "public", + } + } + } + } + + result = filter_sensitive_fields(data) + + assert result["level1"]["level2"]["level3"]["data"] == "public" + assert "api_key" not in result["level1"]["level2"]["level3"] + + def test_filter_list_with_non_dict_items(self) -> None: + """Test filtering lists containing non-dict items.""" + data = { + "items": ["string1", 123, True, None], + "api_key": "secret", + } + + result = filter_sensitive_fields(data) + + assert result["items"] == ["string1", 123, True, None] + assert "api_key" not in result + + def test_filter_mixed_list(self) -> None: + """Test filtering lists with mixed dict and non-dict items.""" + data = { + "items": [ + {"api_key": "secret", "id": 1}, + "string", + {"password": "pass", "name": "test"}, + ], + } + + result = filter_sensitive_fields(data) + + assert len(result["items"]) == 3 + assert result["items"][0] == {"id": 1} + assert result["items"][1] == "string" + assert result["items"][2] == {"name": "test"} + + def test_filter_case_insensitive(self) -> None: + """Test that filtering is case-insensitive.""" + data = { + "API_KEY": "secret", + "Api_Key": "secret2", + "apikey": "secret3", + "name": "test", + } + + result = filter_sensitive_fields(data) + + # All variations should be filtered + assert "API_KEY" not in result + assert "Api_Key" not in result + # Note: "apikey" contains "key" so should be filtered + assert "apikey" not in result + assert result["name"] == "test" + + +class TestPaginationInfoEdgeCases: + """Edge case tests for PaginationInfo.""" + + def test_pagination_with_zero_limit(self) -> None: + """Test creating pagination with zero limit.""" + # This tests the model creation, validation should happen at API level + info = PaginationInfo(limit=0, has_more=False) + + assert info.limit == 0 + + def test_pagination_with_large_limit(self) -> None: + """Test creating pagination with large limit.""" + info = PaginationInfo(limit=10000, has_more=True) + + assert info.limit == 10000 + + def test_pagination_with_empty_cursors(self) -> None: + """Test pagination with empty string cursors.""" + info = PaginationInfo( + limit=50, + has_more=False, + next_cursor="", + prev_cursor="", + ) + + assert info.next_cursor == "" + assert info.prev_cursor == "" + + +class TestProblemDetailEdgeCases: + """Edge case tests for ProblemDetail.""" + + def test_problem_detail_with_empty_errors_list(self) -> None: + """Test ProblemDetail with empty errors list.""" + problem = ProblemDetail( + type="/errors/test", + title="Test", + status=400, + detail="Test error", + errors=[], + ) + + assert problem.errors == [] + + def test_problem_detail_serialization(self) -> None: + """Test ProblemDetail JSON serialization.""" + problem = ProblemDetail( + type="/errors/test", + title="Test", + status=400, + detail="Test error", + ) + + data = problem.model_dump(exclude_none=True) + + assert "instance" not in data # None should be excluded + assert data["type"] == "/errors/test" + diff --git a/tests/unit/backend/test_conversation_service.py b/tests/unit/backend/test_conversation_service.py new file mode 100644 index 0000000000..da700c2d59 --- /dev/null +++ b/tests/unit/backend/test_conversation_service.py @@ -0,0 +1,243 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tests for backend conversation service. +""" + +from datetime import datetime +from unittest.mock import MagicMock, patch + +import pytest + +from pyrit.backend.models.conversations import ( + CreateConversationRequest, + ConverterConfig, +) +from pyrit.backend.services.conversation_service import ( + ConversationService, + ConversationState, + get_conversation_service, +) + + +class TestConversationState: + """Tests for ConversationState model.""" + + def test_conversation_state_creation(self) -> None: + """Test creating a conversation state.""" + state = ConversationState( + conversation_id="test-id", + target_class="OpenAIChatTarget", + target_identifier={"endpoint": "test"}, + created_at=datetime.utcnow(), + ) + + assert state.conversation_id == "test-id" + assert state.target_class == "OpenAIChatTarget" + assert state.converters == [] + + def test_conversation_state_with_system_prompt(self) -> None: + """Test conversation state with system prompt.""" + state = ConversationState( + conversation_id="test-id", + target_class="OpenAIChatTarget", + target_identifier={}, + system_prompt="Test prompt", + created_at=datetime.utcnow(), + ) + + assert state.system_prompt == "Test prompt" + + def test_conversation_state_defaults(self) -> None: + """Test conversation state default values.""" + state = ConversationState( + conversation_id="test-id", + target_class="OpenAIChatTarget", + target_identifier={}, + created_at=datetime.utcnow(), + ) + + assert state.system_prompt is None + assert state.converters == [] + assert state.message_count == 0 + assert state.labels is None + + +class TestConversationService: + """Tests for ConversationService.""" + + @pytest.fixture + def service(self, patch_central_database: MagicMock) -> ConversationService: + """Create a conversation service instance. + + Args: + patch_central_database: The patched central database fixture. + + Returns: + ConversationService: The service instance. + """ + return ConversationService() + + @pytest.mark.asyncio + async def test_create_conversation_success( + self, service: ConversationService + ) -> None: + """Test creating a conversation successfully.""" + mock_target = MagicMock() + mock_target.get_identifier.return_value = {"type": "TextTarget"} + + with patch.object( + service, "_instantiate_target_from_class", return_value=mock_target + ): + request = CreateConversationRequest(target_class="TextTarget") + result = await service.create_conversation(request) + + assert result is not None + assert result.conversation_id is not None + + @pytest.mark.asyncio + async def test_create_conversation_with_labels( + self, service: ConversationService + ) -> None: + """Test creating a conversation with labels.""" + mock_target = MagicMock() + mock_target.get_identifier.return_value = {"type": "TextTarget"} + + with patch.object( + service, "_instantiate_target_from_class", return_value=mock_target + ): + request = CreateConversationRequest( + target_class="TextTarget", + labels={"test": "label"}, + ) + result = await service.create_conversation(request) + + assert result.labels == {"test": "label"} + + @pytest.mark.asyncio + async def test_create_conversation_invalid_target_class( + self, service: ConversationService + ) -> None: + """Test creating a conversation with invalid target class.""" + with patch.object( + service, + "_instantiate_target_from_class", + side_effect=ValueError("Target class 'InvalidTarget' not found"), + ): + request = CreateConversationRequest(target_class="InvalidTarget") + + with pytest.raises(ValueError, match="not found"): + await service.create_conversation(request) + + @pytest.mark.asyncio + async def test_get_conversation_existing( + self, service: ConversationService + ) -> None: + """Test getting an existing conversation.""" + mock_target = MagicMock() + mock_target.get_identifier.return_value = {"type": "TextTarget"} + + with patch.object( + service, "_instantiate_target_from_class", return_value=mock_target + ): + request = CreateConversationRequest(target_class="TextTarget") + created = await service.create_conversation(request) + + result = await service.get_conversation(created.conversation_id) + + assert result is not None + assert result.conversation_id == created.conversation_id + + @pytest.mark.asyncio + async def test_get_conversation_nonexistent( + self, service: ConversationService + ) -> None: + """Test getting a nonexistent conversation.""" + result = await service.get_conversation("nonexistent-id") + + assert result is None + + @pytest.mark.asyncio + async def test_get_conversation_messages_returns_list( + self, service: ConversationService + ) -> None: + """Test getting messages from a conversation.""" + messages = await service.get_conversation_messages("any-id") + + assert isinstance(messages, list) + + @pytest.mark.asyncio + async def test_cleanup_conversation_existing( + self, service: ConversationService + ) -> None: + """Test cleaning up an existing conversation.""" + mock_target = MagicMock() + mock_target.get_identifier.return_value = {"type": "TextTarget"} + + with patch.object( + service, "_instantiate_target_from_class", return_value=mock_target + ): + request = CreateConversationRequest(target_class="TextTarget") + created = await service.create_conversation(request) + + service.cleanup_conversation(created.conversation_id) + + result = await service.get_conversation(created.conversation_id) + assert result is None + + @pytest.mark.asyncio + async def test_cleanup_conversation_removes_target_instance( + self, service: ConversationService + ) -> None: + """Test that cleanup removes target instance.""" + mock_target = MagicMock() + mock_target.get_identifier.return_value = {"type": "TextTarget"} + + with patch.object( + service, "_instantiate_target_from_class", return_value=mock_target + ): + request = CreateConversationRequest(target_class="TextTarget") + created = await service.create_conversation(request) + + assert created.conversation_id in service._target_instances + + service.cleanup_conversation(created.conversation_id) + + assert created.conversation_id not in service._target_instances + + def test_cleanup_conversation_nonexistent_no_error( + self, service: ConversationService + ) -> None: + """Test cleaning up nonexistent conversation doesn't raise error.""" + # Should not raise any exception + service.cleanup_conversation("nonexistent-id") + + +class TestGetConversationServiceSingleton: + """Tests for get_conversation_service singleton function.""" + + def test_returns_conversation_service_instance( + self, patch_central_database: MagicMock + ) -> None: + """Test that get_conversation_service returns a ConversationService.""" + import pyrit.backend.services.conversation_service as module + + module._conversation_service = None + + service = get_conversation_service() + + assert isinstance(service, ConversationService) + + def test_returns_same_instance( + self, patch_central_database: MagicMock + ) -> None: + """Test that get_conversation_service returns the same instance.""" + import pyrit.backend.services.conversation_service as module + + module._conversation_service = None + + service1 = get_conversation_service() + service2 = get_conversation_service() + + assert service1 is service2 diff --git a/tests/unit/backend/test_error_handlers.py b/tests/unit/backend/test_error_handlers.py new file mode 100644 index 0000000000..b1afe028c7 --- /dev/null +++ b/tests/unit/backend/test_error_handlers.py @@ -0,0 +1,200 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tests for backend error handler middleware. +""" + +from unittest.mock import MagicMock + +import pytest +from fastapi import FastAPI, HTTPException +from fastapi.testclient import TestClient + +from pyrit.backend.middleware.error_handlers import register_error_handlers + + +class TestErrorHandlers: + """Tests for RFC 7807 error handlers.""" + + @pytest.fixture + def app(self) -> FastAPI: + """Create a FastAPI app with error handlers registered. + + Returns: + FastAPI: The test app. + """ + app = FastAPI() + register_error_handlers(app) + return app + + @pytest.fixture + def client(self, app: FastAPI) -> TestClient: + """Create a test client. + + Args: + app: The FastAPI app. + + Returns: + TestClient: The test client. + """ + return TestClient(app, raise_server_exceptions=False) + + def test_validation_error_returns_422(self, app: FastAPI, client: TestClient) -> None: + """Test that validation errors return 422 with RFC 7807 format.""" + from pydantic import BaseModel + + class TestInput(BaseModel): + name: str + age: int + + @app.post("/test") + async def test_endpoint(data: TestInput) -> dict: + return {"ok": True} + + response = client.post("/test", json={"name": 123}) # Missing age, wrong type + + assert response.status_code == 422 + data = response.json() + assert data["type"] == "/errors/validation-error" + assert data["title"] == "Validation Error" + assert data["status"] == 422 + assert "errors" in data + + def test_validation_error_includes_field_details( + self, app: FastAPI, client: TestClient + ) -> None: + """Test that validation errors include field-level details.""" + from pydantic import BaseModel + + class TestInput(BaseModel): + name: str + + @app.post("/test") + async def test_endpoint(data: TestInput) -> dict: + return {"ok": True} + + response = client.post("/test", json={}) # Missing required field + + data = response.json() + assert "errors" in data + assert len(data["errors"]) > 0 + # Check field error structure + error = data["errors"][0] + assert "field" in error + assert "message" in error + + def test_value_error_returns_400(self, app: FastAPI, client: TestClient) -> None: + """Test that ValueError returns 400 with RFC 7807 format.""" + + @app.get("/test") + async def test_endpoint() -> dict: + raise ValueError("Invalid input value") + + response = client.get("/test") + + assert response.status_code == 400 + data = response.json() + assert data["type"] == "/errors/bad-request" + assert data["title"] == "Bad Request" + assert data["status"] == 400 + assert "Invalid input value" in data["detail"] + + def test_file_not_found_error_returns_404( + self, app: FastAPI, client: TestClient + ) -> None: + """Test that FileNotFoundError returns 404 with RFC 7807 format.""" + + @app.get("/test") + async def test_endpoint() -> dict: + raise FileNotFoundError("Resource not found") + + response = client.get("/test") + + assert response.status_code == 404 + data = response.json() + assert data["type"] == "/errors/not-found" + assert data["title"] == "Not Found" + assert data["status"] == 404 + + def test_permission_error_returns_403( + self, app: FastAPI, client: TestClient + ) -> None: + """Test that PermissionError returns 403 with RFC 7807 format.""" + + @app.get("/test") + async def test_endpoint() -> dict: + raise PermissionError("Access denied") + + response = client.get("/test") + + assert response.status_code == 403 + data = response.json() + assert data["type"] == "/errors/forbidden" + assert data["title"] == "Forbidden" + assert data["status"] == 403 + + def test_not_implemented_error_returns_501( + self, app: FastAPI, client: TestClient + ) -> None: + """Test that NotImplementedError returns 501 with RFC 7807 format.""" + + @app.get("/test") + async def test_endpoint() -> dict: + raise NotImplementedError("Feature not yet implemented") + + response = client.get("/test") + + assert response.status_code == 501 + data = response.json() + assert data["type"] == "/errors/not-implemented" + assert data["title"] == "Not Implemented" + assert data["status"] == 501 + + def test_generic_exception_returns_500( + self, app: FastAPI, client: TestClient + ) -> None: + """Test that unexpected exceptions return 500 with RFC 7807 format.""" + + @app.get("/test") + async def test_endpoint() -> dict: + raise RuntimeError("Something went wrong") + + response = client.get("/test") + + assert response.status_code == 500 + data = response.json() + assert data["type"] == "/errors/internal-error" + assert data["title"] == "Internal Server Error" + assert data["status"] == 500 + # Should not leak internal error details + assert "An unexpected error occurred" in data["detail"] + + def test_error_response_includes_instance( + self, app: FastAPI, client: TestClient + ) -> None: + """Test that error responses include the request path as instance.""" + + @app.get("/api/v1/test/resource") + async def test_endpoint() -> dict: + raise ValueError("Test error") + + response = client.get("/api/v1/test/resource") + + data = response.json() + assert data["instance"] == "/api/v1/test/resource" + + def test_error_excludes_none_fields( + self, app: FastAPI, client: TestClient + ) -> None: + """Test that None fields are excluded from error response.""" + + @app.get("/test") + async def test_endpoint() -> dict: + raise ValueError("Test error") + + response = client.get("/test") + + data = response.json() + # 'errors' should not be present for non-validation errors + assert "errors" not in data diff --git a/tests/unit/backend/test_memory_service.py b/tests/unit/backend/test_memory_service.py new file mode 100644 index 0000000000..95b08f63e3 --- /dev/null +++ b/tests/unit/backend/test_memory_service.py @@ -0,0 +1,267 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tests for backend memory service. +""" + +from datetime import datetime, timezone +from unittest.mock import MagicMock + +import pytest + +from pyrit.backend.models.common import PaginatedResponse +from pyrit.backend.services.memory_service import ( + MemoryService, + get_memory_service, + _parse_cursor, + _build_cursor, +) + + +class TestCursorFunctions: + """Tests for cursor parsing and building functions.""" + + def test_parse_cursor_with_valid_cursor(self) -> None: + """Test parsing a valid cursor string.""" + timestamp = datetime(2024, 1, 15, 10, 30, 0, tzinfo=timezone.utc) + cursor = f"{timestamp.isoformat()}_abc123" + + parsed_time, parsed_id = _parse_cursor(cursor) + + assert parsed_id == "abc123" + assert parsed_time is not None + assert parsed_time.year == 2024 + + def test_parse_cursor_with_none(self) -> None: + """Test parsing None cursor.""" + parsed_time, parsed_id = _parse_cursor(None) + + assert parsed_time is None + assert parsed_id is None + + def test_parse_cursor_with_empty_string(self) -> None: + """Test parsing empty cursor string.""" + parsed_time, parsed_id = _parse_cursor("") + + assert parsed_time is None + assert parsed_id is None + + def test_parse_cursor_with_invalid_format(self) -> None: + """Test parsing cursor with invalid format.""" + parsed_time, parsed_id = _parse_cursor("invalid_cursor_without_timestamp") + + assert parsed_time is None + assert parsed_id is None + + def test_parse_cursor_with_malformed_timestamp(self) -> None: + """Test parsing cursor with malformed timestamp.""" + parsed_time, parsed_id = _parse_cursor("not-a-timestamp_abc123") + + assert parsed_time is None + assert parsed_id is None + + def test_build_cursor_creates_valid_string(self) -> None: + """Test building a cursor string.""" + timestamp = datetime(2024, 1, 15, 10, 30, 0, tzinfo=timezone.utc) + record_id = "test-id-123" + + cursor = _build_cursor(timestamp, record_id) + + assert record_id in cursor + assert timestamp.isoformat() in cursor + + def test_cursor_roundtrip(self) -> None: + """Test that a cursor can be built and parsed back.""" + original_time = datetime(2024, 6, 15, 14, 30, 0, tzinfo=timezone.utc) + original_id = "message-uuid-123" + + cursor = _build_cursor(original_time, original_id) + parsed_time, parsed_id = _parse_cursor(cursor) + + assert parsed_id == original_id + assert parsed_time is not None + + +class TestMemoryService: + """Tests for MemoryService.""" + + @pytest.fixture + def service(self, patch_central_database: MagicMock) -> MemoryService: + """Create a memory service with patched database. + + Args: + patch_central_database: The patched central database fixture. + + Returns: + MemoryService: The service instance. + """ + return MemoryService() + + @pytest.mark.asyncio + async def test_get_messages_returns_paginated_result( + self, service: MemoryService + ) -> None: + """Test that get_messages returns paginated results.""" + result = await service.get_messages() + + assert isinstance(result, PaginatedResponse) + assert isinstance(result.items, list) + assert result.pagination is not None + + @pytest.mark.asyncio + async def test_get_messages_with_conversation_id( + self, service: MemoryService + ) -> None: + """Test filtering messages by conversation ID.""" + result = await service.get_messages(conversation_id="test-conv-id") + + assert isinstance(result, PaginatedResponse) + + @pytest.mark.asyncio + async def test_get_messages_respects_limit( + self, service: MemoryService + ) -> None: + """Test that limit parameter is respected.""" + result = await service.get_messages(limit=10) + + assert len(result.items) <= 10 + + @pytest.mark.asyncio + async def test_get_messages_with_role_filter( + self, service: MemoryService + ) -> None: + """Test filtering messages by role.""" + result = await service.get_messages(role="user") + + assert isinstance(result, PaginatedResponse) + + @pytest.mark.asyncio + async def test_get_messages_with_time_filters( + self, service: MemoryService + ) -> None: + """Test filtering messages by time range.""" + start = datetime(2024, 1, 1, tzinfo=timezone.utc) + end = datetime(2024, 12, 31, tzinfo=timezone.utc) + + result = await service.get_messages(start_time=start, end_time=end) + + assert isinstance(result, PaginatedResponse) + + @pytest.mark.asyncio + async def test_get_messages_pagination_has_more( + self, service: MemoryService + ) -> None: + """Test that pagination correctly reports has_more.""" + result = await service.get_messages(limit=1) + + assert isinstance(result.pagination.has_more, bool) + + @pytest.mark.asyncio + async def test_get_scores_returns_paginated_result( + self, service: MemoryService + ) -> None: + """Test that get_scores returns paginated results.""" + result = await service.get_scores() + + assert isinstance(result, PaginatedResponse) + assert isinstance(result.items, list) + assert result.pagination is not None + + @pytest.mark.asyncio + async def test_get_scores_with_message_id( + self, service: MemoryService + ) -> None: + """Test filtering scores by message ID.""" + result = await service.get_scores(message_id="test-message-id") + + assert isinstance(result, PaginatedResponse) + + @pytest.mark.asyncio + async def test_get_scores_with_score_type( + self, service: MemoryService + ) -> None: + """Test filtering scores by score type.""" + result = await service.get_scores(score_type="true_false") + + assert isinstance(result, PaginatedResponse) + + @pytest.mark.asyncio + async def test_get_attack_results_returns_paginated_result( + self, service: MemoryService + ) -> None: + """Test that get_attack_results returns paginated results.""" + result = await service.get_attack_results() + + assert isinstance(result, PaginatedResponse) + assert isinstance(result.items, list) + assert result.pagination is not None + + @pytest.mark.asyncio + async def test_get_attack_results_with_outcome_filter( + self, service: MemoryService + ) -> None: + """Test filtering attack results by outcome.""" + result = await service.get_attack_results(outcome="success") + + assert isinstance(result, PaginatedResponse) + + @pytest.mark.asyncio + async def test_get_attack_results_with_turn_filters( + self, service: MemoryService + ) -> None: + """Test filtering attack results by turn count.""" + result = await service.get_attack_results(min_turns=1, max_turns=10) + + assert isinstance(result, PaginatedResponse) + + @pytest.mark.asyncio + async def test_get_seeds_returns_paginated_result( + self, service: MemoryService + ) -> None: + """Test that get_seeds returns paginated results.""" + result = await service.get_seeds() + + assert isinstance(result, PaginatedResponse) + assert isinstance(result.items, list) + assert result.pagination is not None + + @pytest.mark.asyncio + async def test_get_scenario_results_returns_paginated_result( + self, service: MemoryService + ) -> None: + """Test that get_scenario_results returns paginated results.""" + result = await service.get_scenario_results() + + assert isinstance(result, PaginatedResponse) + assert isinstance(result.items, list) + assert result.pagination is not None + + +class TestGetMemoryServiceSingleton: + """Tests for get_memory_service singleton function.""" + + def test_returns_memory_service_instance( + self, patch_central_database: MagicMock + ) -> None: + """Test that get_memory_service returns a MemoryService.""" + import pyrit.backend.services.memory_service as module + + module._memory_service = None + + service = get_memory_service() + + assert isinstance(service, MemoryService) + + def test_returns_same_instance( + self, patch_central_database: MagicMock + ) -> None: + """Test that get_memory_service returns the same instance.""" + import pyrit.backend.services.memory_service as module + + module._memory_service = None + + service1 = get_memory_service() + service2 = get_memory_service() + + assert service1 is service2 diff --git a/tests/unit/backend/test_registry_service.py b/tests/unit/backend/test_registry_service.py new file mode 100644 index 0000000000..95e964c280 --- /dev/null +++ b/tests/unit/backend/test_registry_service.py @@ -0,0 +1,156 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tests for backend registry service. +""" + + + +from pyrit.backend.services.registry_service import ( + RegistryService, + _extract_params_schema, + _get_all_subclasses, + get_registry_service, +) + + +class TestExtractParamsSchema: + """Tests for _extract_params_schema helper function.""" + + def test_extract_params_with_required_and_optional(self) -> None: + """Test extracting params from a class with required and optional params.""" + + class TestClass: + def __init__(self, required_param: str, optional_param: str = "default") -> None: + pass + + result = _extract_params_schema(TestClass) + + assert "required_param" in result["required"] + assert "optional_param" in result["optional"] + + def test_extract_params_ignores_self(self) -> None: + """Test that self is ignored in param extraction.""" + + class TestClass: + def __init__(self, param: str) -> None: + pass + + result = _extract_params_schema(TestClass) + + assert "self" not in result["required"] + assert "self" not in result["optional"] + + +class TestGetAllSubclasses: + """Tests for _get_all_subclasses helper function.""" + + def test_get_subclasses_finds_concrete_classes(self) -> None: + """Test that concrete subclasses are found.""" + + class Base: + pass + + class Child1(Base): + pass + + class Child2(Base): + pass + + result = _get_all_subclasses(Base) + + assert Child1 in result + assert Child2 in result + + +class TestRegistryService: + """Tests for RegistryService.""" + + def test_get_targets_returns_list(self) -> None: + """Test that get_targets returns a list.""" + service = RegistryService() + + result = service.get_targets() + + assert isinstance(result, list) + + def test_get_targets_filters_chat_targets(self) -> None: + """Test that get_targets can filter by chat target support.""" + service = RegistryService() + + chat_only = service.get_targets(is_chat_target=True) + non_chat = service.get_targets(is_chat_target=False) + + # Chat targets and non-chat targets should be different + chat_names = {t.name for t in chat_only} + non_chat_names = {t.name for t in non_chat} + # They should be disjoint (no overlap) + assert len(chat_names & non_chat_names) == 0 + + def test_get_converters_returns_list(self) -> None: + """Test that get_converters returns a list.""" + service = RegistryService() + + result = service.get_converters() + + assert isinstance(result, list) + + def test_get_converters_filters_llm_based(self) -> None: + """Test that get_converters can filter by LLM-based status.""" + service = RegistryService() + + llm_based = service.get_converters(is_llm_based=True) + non_llm = service.get_converters(is_llm_based=False) + + # LLM-based and non-LLM converters should be different + llm_names = {c.name for c in llm_based} + non_llm_names = {c.name for c in non_llm} + # They should be disjoint + assert len(llm_names & non_llm_names) == 0 + + def test_get_scenarios_returns_list(self) -> None: + """Test that get_scenarios returns a list.""" + service = RegistryService() + + result = service.get_scenarios() + + assert isinstance(result, list) + + def test_get_scorers_returns_list(self) -> None: + """Test that get_scorers returns a list.""" + service = RegistryService() + + result = service.get_scorers() + + assert isinstance(result, list) + + def test_get_initializers_returns_list(self) -> None: + """Test that get_initializers returns a list.""" + service = RegistryService() + + result = service.get_initializers() + + assert isinstance(result, list) + + +class TestGetRegistryServiceSingleton: + """Tests for get_registry_service singleton function.""" + + def test_returns_same_instance(self) -> None: + """Test that get_registry_service returns the same instance.""" + # Reset singleton for test + import pyrit.backend.services.registry_service as module + + module._registry_service = None + + service1 = get_registry_service() + service2 = get_registry_service() + + assert service1 is service2 + + def test_returns_registry_service_instance(self) -> None: + """Test that get_registry_service returns a RegistryService.""" + service = get_registry_service() + + assert isinstance(service, RegistryService) diff --git a/tests/unit/backend/test_routes.py b/tests/unit/backend/test_routes.py new file mode 100644 index 0000000000..5a907bdedc --- /dev/null +++ b/tests/unit/backend/test_routes.py @@ -0,0 +1,351 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tests for backend API routes. +""" + +from datetime import datetime +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from pyrit.backend.models.common import PaginatedResponse, PaginationInfo +from pyrit.backend.models.conversations import CreateConversationResponse +from pyrit.backend.models.memory import MessageQueryResponse +from pyrit.backend.routes import health, version + + +class TestHealthRoute: + """Tests for health check endpoint.""" + + @pytest.fixture + def client(self) -> TestClient: + """Create a test client for health routes. + + Returns: + TestClient: The test client. + """ + app = FastAPI() + app.include_router(health.router) + return TestClient(app) + + def test_health_returns_200(self, client: TestClient) -> None: + """Test that health endpoint returns 200.""" + response = client.get("/health") + + assert response.status_code == 200 + + def test_health_returns_healthy_status(self, client: TestClient) -> None: + """Test that health endpoint returns healthy status.""" + response = client.get("/health") + data = response.json() + + assert data["status"] == "healthy" + + def test_health_returns_service_name(self, client: TestClient) -> None: + """Test that health endpoint returns service name.""" + response = client.get("/health") + data = response.json() + + assert data["service"] == "pyrit-backend" + + def test_health_returns_timestamp(self, client: TestClient) -> None: + """Test that health endpoint returns timestamp.""" + response = client.get("/health") + data = response.json() + + assert "timestamp" in data + # Verify it's a valid ISO format timestamp + datetime.fromisoformat(data["timestamp"].replace("Z", "+00:00")) + + +class TestVersionRoute: + """Tests for version endpoint.""" + + @pytest.fixture + def client(self) -> TestClient: + """Create a test client for version routes. + + Returns: + TestClient: The test client. + """ + app = FastAPI() + app.include_router(version.router) + return TestClient(app) + + def test_version_returns_200(self, client: TestClient) -> None: + """Test that version endpoint returns 200.""" + response = client.get("/api/version") + + assert response.status_code == 200 + + def test_version_returns_version_string(self, client: TestClient) -> None: + """Test that version endpoint returns version string.""" + response = client.get("/api/version") + data = response.json() + + assert "version" in data + assert isinstance(data["version"], str) + + def test_version_returns_display_string(self, client: TestClient) -> None: + """Test that version endpoint returns display string.""" + response = client.get("/api/version") + data = response.json() + + assert "display" in data + assert isinstance(data["display"], str) + + +class TestRegistryRoutes: + """Tests for registry endpoints.""" + + @pytest.fixture + def mock_service(self) -> MagicMock: + """Create a mock registry service. + + Returns: + MagicMock: The mock service. + """ + return MagicMock() + + @pytest.fixture + def client(self) -> TestClient: + """Create a test client for registry routes. + + Returns: + TestClient: The test client. + """ + from pyrit.backend.routes import registry + + app = FastAPI() + app.include_router(registry.router) + return TestClient(app) + + def test_list_targets_returns_200(self, client: TestClient) -> None: + """Test that list targets returns 200.""" + mock_service = MagicMock() + mock_service.get_targets.return_value = [] + + with patch( + "pyrit.backend.routes.registry.get_registry_service", + return_value=mock_service, + ): + response = client.get("/registry/targets") + + assert response.status_code == 200 + + def test_list_targets_returns_list(self, client: TestClient) -> None: + """Test that list targets returns a list.""" + mock_service = MagicMock() + mock_service.get_targets.return_value = [] + + with patch( + "pyrit.backend.routes.registry.get_registry_service", + return_value=mock_service, + ): + response = client.get("/registry/targets") + data = response.json() + + assert isinstance(data, list) + + def test_list_converters_returns_200(self, client: TestClient) -> None: + """Test that list converters returns 200.""" + mock_service = MagicMock() + mock_service.get_converters.return_value = [] + + with patch( + "pyrit.backend.routes.registry.get_registry_service", + return_value=mock_service, + ): + response = client.get("/registry/converters") + + assert response.status_code == 200 + + def test_list_scorers_returns_200(self, client: TestClient) -> None: + """Test that list scorers returns 200.""" + mock_service = MagicMock() + mock_service.get_scorers.return_value = [] + + with patch( + "pyrit.backend.routes.registry.get_registry_service", + return_value=mock_service, + ): + response = client.get("/registry/scorers") + + assert response.status_code == 200 + + def test_list_scenarios_returns_200(self, client: TestClient) -> None: + """Test that list scenarios returns 200.""" + mock_service = MagicMock() + mock_service.get_scenarios.return_value = [] + + with patch( + "pyrit.backend.routes.registry.get_registry_service", + return_value=mock_service, + ): + response = client.get("/registry/scenarios") + + assert response.status_code == 200 + + +class TestConversationRoutes: + """Tests for conversation endpoints.""" + + @pytest.fixture + def client(self, patch_central_database: MagicMock) -> TestClient: + """Create a test client for conversation routes. + + Args: + patch_central_database: The patched central database fixture. + + Returns: + TestClient: The test client. + """ + from pyrit.backend.routes import conversations + + app = FastAPI() + app.include_router(conversations.router) + return TestClient(app) + + def test_create_conversation_returns_201(self, client: TestClient, patch_central_database: MagicMock) -> None: + """Test that create conversation returns 201.""" + mock_service = MagicMock() + mock_response = CreateConversationResponse( + conversation_id="test-id", + target_identifier={"__type__": "TextTarget"}, + labels=None, + created_at=datetime.now(), + ) + mock_service.create_conversation = AsyncMock(return_value=mock_response) + + with patch( + "pyrit.backend.routes.conversations.get_conversation_service", + return_value=mock_service, + ): + response = client.post( + "/conversations", + json={ + "target_class": "TextTarget", + "target_params": None, + }, + ) + + assert response.status_code == 201 + + def test_get_conversation_returns_404_for_missing( + self, client: TestClient, patch_central_database: MagicMock + ) -> None: + """Test that get conversation returns 404 for missing.""" + mock_service = MagicMock() + mock_service.get_conversation = AsyncMock(return_value=None) + + with patch( + "pyrit.backend.routes.conversations.get_conversation_service", + return_value=mock_service, + ): + response = client.get("/conversations/nonexistent") + + assert response.status_code == 404 + + def test_delete_conversation_returns_204(self, client: TestClient, patch_central_database: MagicMock) -> None: + """Test that delete conversation returns 204.""" + mock_service = MagicMock() + # Must return a conversation state for delete to work + mock_service.get_conversation = AsyncMock(return_value=MagicMock(conversation_id="conv-1")) + mock_service.cleanup_conversation = MagicMock() + + with patch( + "pyrit.backend.routes.conversations.get_conversation_service", + return_value=mock_service, + ): + response = client.delete("/conversations/conv-1") + + assert response.status_code == 204 + + +class TestMemoryRoutes: + """Tests for memory endpoints.""" + + @pytest.fixture + def client(self) -> TestClient: + """Create a test client for memory routes. + + Returns: + TestClient: The test client. + """ + from pyrit.backend.routes import memory + + app = FastAPI() + app.include_router(memory.router) + return TestClient(app) + + def test_query_messages_returns_200(self, client: TestClient) -> None: + """Test that query messages returns 200.""" + mock_service = MagicMock() + mock_response = PaginatedResponse[MessageQueryResponse]( + items=[], + pagination=PaginationInfo(offset=0, limit=50, total=0, has_more=False), + ) + mock_service.get_messages = AsyncMock(return_value=mock_response) + + with patch( + "pyrit.backend.routes.memory.get_memory_service", + return_value=mock_service, + ): + response = client.get("/memory/messages") + + assert response.status_code == 200 + + def test_query_scores_returns_200(self, client: TestClient) -> None: + """Test that query scores returns 200.""" + mock_service = MagicMock() + mock_response = PaginatedResponse( + items=[], + pagination=PaginationInfo(offset=0, limit=50, total=0, has_more=False), + ) + mock_service.get_scores = AsyncMock(return_value=mock_response) + + with patch( + "pyrit.backend.routes.memory.get_memory_service", + return_value=mock_service, + ): + response = client.get("/memory/scores") + + assert response.status_code == 200 + + def test_query_attack_results_returns_200(self, client: TestClient) -> None: + """Test that query attack results returns 200.""" + mock_service = MagicMock() + mock_response = PaginatedResponse( + items=[], + pagination=PaginationInfo(offset=0, limit=50, total=0, has_more=False), + ) + mock_service.get_attack_results = AsyncMock(return_value=mock_response) + + with patch( + "pyrit.backend.routes.memory.get_memory_service", + return_value=mock_service, + ): + response = client.get("/memory/attack-results") + + assert response.status_code == 200 + + def test_query_seeds_returns_200(self, client: TestClient) -> None: + """Test that query seeds returns 200.""" + mock_service = MagicMock() + mock_response = PaginatedResponse( + items=[], + pagination=PaginationInfo(offset=0, limit=50, total=0, has_more=False), + ) + mock_service.get_seeds = AsyncMock(return_value=mock_response) + + with patch( + "pyrit.backend.routes.memory.get_memory_service", + return_value=mock_service, + ): + response = client.get("/memory/seeds") + + assert response.status_code == 200 From 560c7b9406467f1b7c7648e7ab103dce968554b1 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Tue, 27 Jan 2026 14:38:22 -0800 Subject: [PATCH 02/35] add updated backend after review --- pyrit/backend/main.py | 24 +- pyrit/backend/models/__init__.py | 118 ++-- pyrit/backend/models/attacks.py | 199 ++++++ pyrit/backend/models/conversations.py | 201 ------ pyrit/backend/models/converters.py | 140 +++- pyrit/backend/models/memory.py | 125 ---- pyrit/backend/models/targets.py | 57 ++ pyrit/backend/routes/__init__.py | 6 +- pyrit/backend/routes/attacks.py | 216 +++++++ pyrit/backend/routes/conversations.py | 338 ---------- pyrit/backend/routes/converters.py | 205 ++++-- pyrit/backend/routes/memory.py | 206 ------ pyrit/backend/routes/targets.py | 135 ++++ pyrit/backend/services/__init__.py | 30 +- pyrit/backend/services/attack_service.py | 458 +++++++++++++ .../backend/services/conversation_service.py | 508 --------------- pyrit/backend/services/converter_service.py | 344 ++++++++++ pyrit/backend/services/memory_service.py | 444 ------------- pyrit/backend/services/target_service.py | 222 +++++++ tests/unit/backend/test_api_routes.py | 602 ++++++++++++++++++ tests/unit/backend/test_attack_service.py | 497 +++++++++++++++ tests/unit/backend/test_common_models.py | 3 - .../unit/backend/test_conversation_service.py | 243 ------- tests/unit/backend/test_converter_service.py | 538 ++++++++++++++++ tests/unit/backend/test_error_handlers.py | 32 +- tests/unit/backend/test_memory_service.py | 267 -------- tests/unit/backend/test_registry_service.py | 2 - tests/unit/backend/test_routes.py | 351 ---------- tests/unit/backend/test_target_service.py | 368 +++++++++++ 29 files changed, 4008 insertions(+), 2871 deletions(-) create mode 100644 pyrit/backend/models/attacks.py delete mode 100644 pyrit/backend/models/conversations.py delete mode 100644 pyrit/backend/models/memory.py create mode 100644 pyrit/backend/models/targets.py create mode 100644 pyrit/backend/routes/attacks.py delete mode 100644 pyrit/backend/routes/conversations.py delete mode 100644 pyrit/backend/routes/memory.py create mode 100644 pyrit/backend/routes/targets.py create mode 100644 pyrit/backend/services/attack_service.py delete mode 100644 pyrit/backend/services/conversation_service.py create mode 100644 pyrit/backend/services/converter_service.py delete mode 100644 pyrit/backend/services/memory_service.py create mode 100644 pyrit/backend/services/target_service.py create mode 100644 tests/unit/backend/test_api_routes.py create mode 100644 tests/unit/backend/test_attack_service.py delete mode 100644 tests/unit/backend/test_conversation_service.py create mode 100644 tests/unit/backend/test_converter_service.py delete mode 100644 tests/unit/backend/test_memory_service.py delete mode 100644 tests/unit/backend/test_routes.py create mode 100644 tests/unit/backend/test_target_service.py diff --git a/pyrit/backend/main.py b/pyrit/backend/main.py index e4dc422c8f..49bc824397 100644 --- a/pyrit/backend/main.py +++ b/pyrit/backend/main.py @@ -3,20 +3,22 @@ """ FastAPI application entry point for PyRIT backend. + +This is the attack-centric API - all interactions are modeled as "attacks". """ import os import sys from pathlib import Path -from fastapi import APIRouter, FastAPI +from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from fastapi.staticfiles import StaticFiles import pyrit from pyrit.backend.middleware import register_error_handlers -from pyrit.backend.routes import conversations, converters, health, memory, registry, version +from pyrit.backend.routes import attacks, converters, health, registry, targets, version from pyrit.setup.initialization import initialize_pyrit_async # Check for development mode from environment variable @@ -50,19 +52,11 @@ async def startup_event_async() -> None: ) -# Create versioned API router -api_v1 = APIRouter(prefix="/api/v1") - -# Include v1 routes -api_v1.include_router(conversations.router) -api_v1.include_router(converters.router) -api_v1.include_router(memory.router) -api_v1.include_router(registry.router) - -# Mount versioned API -app.include_router(api_v1) - -# Include legacy/non-versioned routes +# Include API routes +app.include_router(attacks.router, prefix="/api", tags=["attacks"]) +app.include_router(targets.router, prefix="/api", tags=["targets"]) +app.include_router(converters.router, prefix="/api", tags=["converters"]) +app.include_router(registry.router, prefix="/api", tags=["registry"]) app.include_router(health.router, prefix="/api", tags=["health"]) app.include_router(version.router, tags=["version"]) diff --git a/pyrit/backend/models/__init__.py b/pyrit/backend/models/__init__.py index 17554d2781..191f6c85cb 100644 --- a/pyrit/backend/models/__init__.py +++ b/pyrit/backend/models/__init__.py @@ -7,6 +7,21 @@ Pydantic models for API requests and responses. """ +from pyrit.backend.models.attacks import ( + AttackDetail, + AttackListResponse, + AttackSummary, + CreateAttackRequest, + CreateAttackResponse, + Message, + MessagePiece, + MessagePieceRequest, + PrependedMessageRequest, + Score, + SendMessageRequest, + SendMessageResponse, + UpdateAttackRequest, +) from pyrit.backend.models.common import ( ALLOWED_IDENTIFIER_FIELDS, SENSITIVE_FIELD_PATTERNS, @@ -17,36 +32,17 @@ ProblemDetail, filter_sensitive_fields, ) -from pyrit.backend.models.conversations import ( - BranchConversationRequest, - BranchConversationResponse, - ConversationResponse, - ConverterConfig, - ConvertersResponse, - CreateConversationRequest, - CreateConversationResponse, - MessagePieceInput, - MessagePieceResponse, - MessageResponse, - SendMessageRequest, - SendMessageResponse, - SetConvertersRequest, - SetSystemPromptRequest, - SystemPromptResponse, -) from pyrit.backend.models.converters import ( - ConversionStep, - ConverterListResponse, + ConverterInstance, + ConverterInstanceListResponse, ConverterMetadataResponse, - PreviewConverterRequest, - PreviewConverterResponse, -) -from pyrit.backend.models.memory import ( - AttackResultQueryResponse, - MessageQueryResponse, - ScenarioResultQueryResponse, - ScoreQueryResponse, - SeedQueryResponse, + ConverterPreviewRequest, + ConverterPreviewResponse, + CreateConverterRequest, + CreateConverterResponse, + InlineConverterConfig, + NestedConverterConfig, + PreviewStep, ) from pyrit.backend.models.registry import ( InitializerListResponse, @@ -55,11 +51,33 @@ ScenarioMetadataResponse, ScorerListResponse, ScorerMetadataResponse, - TargetListResponse, TargetMetadataResponse, ) +from pyrit.backend.models.registry import ( + TargetListResponse as RegistryTargetListResponse, +) +from pyrit.backend.models.targets import ( + CreateTargetRequest, + CreateTargetResponse, + TargetInstance, + TargetListResponse, +) __all__ = [ + # Attacks + "AttackDetail", + "AttackListResponse", + "AttackSummary", + "CreateAttackRequest", + "CreateAttackResponse", + "Message", + "MessagePiece", + "MessagePieceRequest", + "PrependedMessageRequest", + "Score", + "SendMessageRequest", + "SendMessageResponse", + "UpdateAttackRequest", # Common "ALLOWED_IDENTIFIER_FIELDS", "SENSITIVE_FIELD_PATTERNS", @@ -69,34 +87,17 @@ "PaginatedResponse", "PaginationInfo", "ProblemDetail", - # Conversations - "BranchConversationRequest", - "BranchConversationResponse", - "ConversationResponse", - "ConverterConfig", - "ConvertersResponse", - "CreateConversationRequest", - "CreateConversationResponse", - "MessagePieceInput", - "MessagePieceResponse", - "MessageResponse", - "SendMessageRequest", - "SendMessageResponse", - "SetConvertersRequest", - "SetSystemPromptRequest", - "SystemPromptResponse", # Converters - "ConversionStep", - "ConverterListResponse", + "ConverterInstance", + "ConverterInstanceListResponse", "ConverterMetadataResponse", - "PreviewConverterRequest", - "PreviewConverterResponse", - # Memory - "AttackResultQueryResponse", - "MessageQueryResponse", - "ScenarioResultQueryResponse", - "ScoreQueryResponse", - "SeedQueryResponse", + "ConverterPreviewRequest", + "ConverterPreviewResponse", + "CreateConverterRequest", + "CreateConverterResponse", + "InlineConverterConfig", + "NestedConverterConfig", + "PreviewStep", # Registry "InitializerListResponse", "InitializerMetadataResponse", @@ -104,6 +105,11 @@ "ScenarioMetadataResponse", "ScorerListResponse", "ScorerMetadataResponse", - "TargetListResponse", + "RegistryTargetListResponse", "TargetMetadataResponse", + # Targets + "CreateTargetRequest", + "CreateTargetResponse", + "TargetInstance", + "TargetListResponse", ] diff --git a/pyrit/backend/models/attacks.py b/pyrit/backend/models/attacks.py new file mode 100644 index 0000000000..fd0d851d84 --- /dev/null +++ b/pyrit/backend/models/attacks.py @@ -0,0 +1,199 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Attack-related request and response models. + +All interactions in the UI are modeled as "attacks" - including manual conversations. +This is the attack-centric API design where every user interaction targets a model. +""" + +from datetime import datetime +from typing import Any, Dict, List, Literal, Optional + +from pydantic import BaseModel, Field + +from pyrit.backend.models.common import PaginationInfo +from pyrit.backend.models.converters import InlineConverterConfig + + +class Score(BaseModel): + """A score associated with a message piece.""" + + score_id: str = Field(..., description="Unique score identifier") + scorer_type: str = Field(..., description="Type of scorer (e.g., 'bias', 'toxicity')") + score_value: float = Field(..., description="Numeric score value") + score_rationale: Optional[str] = Field(None, description="Explanation for the score") + scored_at: datetime = Field(..., description="When the score was generated") + + +class MessagePiece(BaseModel): + """ + A piece of a message (text, image, audio, etc.). + + Supports multimodal content with original/converted values and embedded scores. + Media content is base64-encoded since frontend can't access server file paths. + """ + + piece_id: str = Field(..., description="Unique piece identifier") + data_type: str = Field(..., description="Data type: 'text', 'image', 'audio', 'video', etc.") + original_value: Optional[str] = Field(None, description="Original value before conversion") + original_value_mime_type: Optional[str] = Field(None, description="MIME type of original value") + converted_value: str = Field(..., description="Converted value (text or base64 for media)") + converted_value_mime_type: Optional[str] = Field(None, description="MIME type of converted value") + scores: List[Score] = Field(default_factory=list, description="Scores embedded in this piece") + + +class Message(BaseModel): + """A message within an attack.""" + + message_id: str = Field(..., description="Unique message identifier") + turn_number: int = Field(..., description="Turn number in the conversation (1-indexed)") + role: Literal["user", "assistant", "system"] = Field(..., description="Message role") + pieces: List[MessagePiece] = Field(..., description="Message pieces (multimodal support)") + created_at: datetime = Field(..., description="Message creation timestamp") + + +# ============================================================================ +# Attack Summary (List View) +# ============================================================================ + + +class AttackSummary(BaseModel): + """Summary view of an attack (for list views, omits full message content).""" + + attack_id: str = Field(..., description="Unique attack identifier") + name: Optional[str] = Field(None, description="Attack name/label") + target_id: str = Field(..., description="Target instance ID") + target_type: str = Field(..., description="Target type (e.g., 'azure_openai')") + outcome: Optional[Literal["pending", "success", "failure"]] = Field( + None, description="Attack outcome (null if not yet determined)" + ) + last_message_preview: Optional[str] = Field( + None, description="Preview of the last message (truncated to ~100 chars)" + ) + message_count: int = Field(0, description="Total number of messages in the attack") + created_at: datetime = Field(..., description="Attack creation timestamp") + updated_at: datetime = Field(..., description="Last update timestamp") + + +# ============================================================================ +# Attack Detail (Single Attack View) +# ============================================================================ + + +class AttackDetail(BaseModel): + """Detailed view of an attack (includes all messages).""" + + attack_id: str = Field(..., description="Unique attack identifier") + name: Optional[str] = Field(None, description="Attack name/label") + target_id: str = Field(..., description="Target instance ID") + target_type: str = Field(..., description="Target type (e.g., 'azure_openai')") + outcome: Optional[Literal["pending", "success", "failure"]] = Field( + None, description="Attack outcome" + ) + prepended_conversation: List[Message] = Field( + default_factory=list, description="Prepended messages (system prompts, branching context)" + ) + messages: List[Message] = Field(default_factory=list, description="Attack messages in order") + converter_ids: List[str] = Field(default_factory=list, description="Converter instance IDs applied") + created_at: datetime = Field(..., description="Attack creation timestamp") + updated_at: datetime = Field(..., description="Last update timestamp") + + +# ============================================================================ +# Attack List Response (Paginated) +# ============================================================================ + + +class AttackListResponse(BaseModel): + """Paginated response for listing attacks.""" + + items: List[AttackSummary] = Field(..., description="List of attack summaries") + pagination: PaginationInfo = Field(..., description="Pagination metadata") + + +# ============================================================================ +# Create Attack +# ============================================================================ + + +class PrependedMessageRequest(BaseModel): + """A message to prepend to the attack (for system prompt/branching).""" + + role: Literal["user", "assistant", "system"] = Field(..., description="Message role") + content: str = Field(..., description="Message content (text)") + + +class CreateAttackRequest(BaseModel): + """Request to create a new attack.""" + + name: Optional[str] = Field(None, description="Attack name/label") + target_id: str = Field(..., description="Target instance ID to attack") + prepended_conversation: Optional[List[PrependedMessageRequest]] = Field( + None, description="Messages to prepend (system prompts, branching context)" + ) + converter_ids: Optional[List[str]] = Field( + None, description="Converter instance IDs to apply to user messages" + ) + + +class CreateAttackResponse(BaseModel): + """Response after creating an attack.""" + + attack_id: str = Field(..., description="Unique attack identifier") + name: Optional[str] = Field(None, description="Attack name/label") + target_id: str = Field(..., description="Target instance ID") + target_type: str = Field(..., description="Target type") + outcome: Optional[str] = Field(None, description="Attack outcome (initially null)") + prepended_conversation: List[Message] = Field( + default_factory=list, description="Prepended messages (converted to Message format)" + ) + created_at: datetime = Field(..., description="Attack creation timestamp") + updated_at: datetime = Field(..., description="Last update timestamp") + + +# ============================================================================ +# Update Attack +# ============================================================================ + + +class UpdateAttackRequest(BaseModel): + """Request to update an attack's outcome.""" + + outcome: Literal["pending", "success", "failure"] = Field( + ..., description="Updated attack outcome" + ) + + +# ============================================================================ +# Send Message +# ============================================================================ + + +class MessagePieceRequest(BaseModel): + """A piece of content to send in a message.""" + + data_type: str = Field(default="text", description="Data type: 'text', 'image', 'audio', etc.") + content: str = Field(..., description="Content to send (text or base64 for media)") + mime_type: Optional[str] = Field(None, description="MIME type for media content") + + +class SendMessageRequest(BaseModel): + """Request to send a message within an attack.""" + + pieces: List[MessagePieceRequest] = Field(..., description="Message pieces to send") + converter_ids: Optional[List[str]] = Field( + None, description="Converter instance IDs to apply (overrides attack-level)" + ) + converters: Optional[List[InlineConverterConfig]] = Field( + None, description="Inline converter definitions (for one-off use)" + ) + + +class SendMessageResponse(BaseModel): + """Response after sending a message.""" + + user_message: Message = Field(..., description="The user message that was sent") + assistant_message: Message = Field(..., description="The assistant's response") + attack_summary: AttackSummary = Field(..., description="Updated attack summary") diff --git a/pyrit/backend/models/conversations.py b/pyrit/backend/models/conversations.py deleted file mode 100644 index f151148c36..0000000000 --- a/pyrit/backend/models/conversations.py +++ /dev/null @@ -1,201 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Conversation-related request and response models. - -These models align with PyRIT's MessagePiece and Message structures. -""" - -from datetime import datetime -from typing import Any, Dict, List, Literal, Optional - -from pydantic import BaseModel, Field - -from pyrit.models import PromptDataType, PromptResponseError - - -class ConverterConfig(BaseModel): - """Configuration for a single converter.""" - - class_name: str = Field(..., description="Converter class name (e.g., 'TranslationConverter')") - module: str = Field( - default="pyrit.prompt_converter", - description="Module containing the converter class", - ) - params: Optional[Dict[str, Any]] = Field(default=None, description="Constructor parameters") - - -# ============================================================================ -# Conversation Creation -# ============================================================================ - - -class CreateConversationRequest(BaseModel): - """Request to create a new conversation.""" - - target_class: str = Field( - ..., - description="Target class name (e.g., 'TextTarget', 'AzureOpenAIGPT4OChatTarget')", - ) - target_params: Optional[Dict[str, Any]] = Field( - None, - description="Constructor parameters for the target", - ) - labels: Optional[Dict[str, str]] = Field(None, description="Key-value labels for filtering") - - -class CreateConversationResponse(BaseModel): - """Response after creating a conversation.""" - - conversation_id: str = Field(..., description="Unique conversation identifier") - target_identifier: Dict[str, Any] = Field(..., description="Target identifier (filtered)") - labels: Optional[Dict[str, str]] = Field(None, description="Applied labels") - created_at: datetime = Field(..., description="Creation timestamp") - - -# ============================================================================ -# System Prompt -# ============================================================================ - - -class SetSystemPromptRequest(BaseModel): - """Request to set the system prompt for a conversation.""" - - system_prompt: str = Field(..., description="The system prompt text") - - -class SystemPromptResponse(BaseModel): - """Response containing the system prompt.""" - - system_prompt: Optional[str] = Field(None, description="Current system prompt") - piece_id: Optional[str] = Field(None, description="ID of the system prompt message piece") - - -# ============================================================================ -# Converter Configuration -# ============================================================================ - - -class SetConvertersRequest(BaseModel): - """Request to set the converter chain for a conversation.""" - - converters: List[ConverterConfig] = Field(..., description="Ordered list of converters") - - -class ConvertersResponse(BaseModel): - """Response containing the converter chain.""" - - converters: List[ConverterConfig] = Field(default_factory=list, description="Current converter chain") - - -# ============================================================================ -# Message Pieces (aligned with MessagePiece) -# ============================================================================ - - -class MessagePieceInput(BaseModel): - """ - Input for a single message piece. - - Aligned with pyrit.models.MessagePiece fields. - """ - - original_value: Optional[str] = Field(None, description="Text content (for text type)") - original_value_data_type: PromptDataType = Field(..., description="Data type of the content") - file_name: Optional[str] = Field(None, description="Filename in multipart request (for file types)") - converted_value: Optional[str] = Field(None, description="Pre-converted content (if pre_converted=true)") - converted_value_data_type: Optional[PromptDataType] = Field(None, description="Data type after conversion") - converter_identifiers: Optional[List[Dict[str, Any]]] = Field( - None, description="Converters already applied (if pre_converted=true)" - ) - - -class MessagePieceResponse(BaseModel): - """ - Response model for a single message piece. - - Aligned with pyrit.models.MessagePiece fields. - """ - - id: str = Field(..., description="Unique piece identifier (UUID)") - original_value: str = Field(..., description="Original content or file path") - original_value_data_type: PromptDataType = Field(..., description="Original data type") - converted_value: str = Field(..., description="Converted content or file path") - converted_value_data_type: PromptDataType = Field(..., description="Converted data type") - converter_identifiers: List[Dict[str, Any]] = Field( - default_factory=list, description="Applied converters with params" - ) - response_error: Optional[PromptResponseError] = Field(None, description="Error type if any") - timestamp: Optional[datetime] = Field(None, description="Piece timestamp") - - -# ============================================================================ -# Messages -# ============================================================================ - -ChatMessageRole = Literal["system", "user", "assistant", "simulated_assistant", "tool", "developer"] - - -class MessageResponse(BaseModel): - """Response model for a message (group of pieces with same sequence).""" - - sequence: int = Field(..., description="Sequence number in conversation") - role: ChatMessageRole = Field(..., description="Message role") - pieces: List[MessagePieceResponse] = Field(..., description="Message content pieces") - timestamp: datetime = Field(..., description="Message timestamp") - - -class SendMessageRequest(BaseModel): - """ - Request to send a message. - - Note: For file uploads, use multipart/form-data with 'pieces' as JSON - and files attached with their filenames. - """ - - pieces: List[MessagePieceInput] = Field(..., description="Message content pieces") - pre_converted: bool = Field(False, description="If true, skip converter chain") - - -class SendMessageResponse(BaseModel): - """Response after sending a message.""" - - user_message: MessageResponse = Field(..., description="The sent user message") - assistant_message: Optional[MessageResponse] = Field(None, description="The assistant's response") - - -# ============================================================================ -# Branch -# ============================================================================ - - -class BranchConversationRequest(BaseModel): - """Request to branch a conversation.""" - - last_included_sequence: int = Field(..., description="Copy messages with sequence <= this value") - - -class BranchConversationResponse(BaseModel): - """Response after branching a conversation.""" - - conversation_id: str = Field(..., description="New conversation ID") - branched_from: Dict[str, Any] = Field(..., description="Source conversation info") - message_count: int = Field(..., description="Number of messages copied") - created_at: datetime = Field(..., description="Branch creation timestamp") - - -# ============================================================================ -# Full Conversation -# ============================================================================ - - -class ConversationResponse(BaseModel): - """Full conversation with all messages.""" - - conversation_id: str = Field(..., description="Unique conversation identifier") - target_identifier: Dict[str, Any] = Field(..., description="Target identifier (filtered)") - labels: Optional[Dict[str, str]] = Field(None, description="Applied labels") - converters: List[ConverterConfig] = Field(default_factory=list, description="Configured converters") - created_at: datetime = Field(..., description="Creation timestamp") - messages: List[MessageResponse] = Field(default_factory=list, description="All messages in order") diff --git a/pyrit/backend/models/converters.py b/pyrit/backend/models/converters.py index 2bd56a3c52..f507f2d2c8 100644 --- a/pyrit/backend/models/converters.py +++ b/pyrit/backend/models/converters.py @@ -3,58 +3,138 @@ """ Converter-related request and response models. + +Converters have two concepts: +- Types: Static metadata bundled with frontend (from registry) +- Instances: Runtime objects created via API with specific configuration + +This module defines both the Instance models and preview functionality. +Nested converters (e.g., SelectiveTextConverter wrapping Base64Converter) are supported. """ -from typing import Any, Dict, List +from datetime import datetime +from typing import Any, Dict, List, Literal, Optional from pydantic import BaseModel, Field -from pyrit.backend.models.conversations import ConverterConfig from pyrit.backend.models.registry import ConverterMetadataResponse from pyrit.models import PromptDataType -# Re-export for convenience __all__ = [ "ConverterMetadataResponse", - "ConverterListResponse", - "ConverterConfig", - "ConversionStep", - "PreviewConverterRequest", - "PreviewConverterResponse", + "ConverterInstance", + "ConverterInstanceListResponse", + "CreateConverterRequest", + "CreateConverterResponse", + "InlineConverterConfig", + "NestedConverterConfig", + "ConverterPreviewRequest", + "ConverterPreviewResponse", + "PreviewStep", ] -class ConverterListResponse(BaseModel): - """Response containing list of available converters.""" +# ============================================================================ +# Converter Instances (Runtime Objects) +# ============================================================================ + + +class InlineConverterConfig(BaseModel): + """Inline converter configuration (type + params).""" + + type: str = Field(..., description="Converter type (e.g., 'base64', 'translation')") + params: Dict[str, Any] = Field(default_factory=dict, description="Converter parameters") + + +class NestedConverterConfig(BaseModel): + """ + Converter config that may contain nested converters. + + Used for composite converters like SelectiveTextConverter that wrap other converters. + The 'converter' param can contain another NestedConverterConfig. + """ + + type: str = Field(..., description="Converter type") + params: Dict[str, Any] = Field(default_factory=dict, description="Converter parameters") + + +class ConverterInstance(BaseModel): + """A registered converter instance.""" + + converter_id: str = Field(..., description="Unique converter instance identifier") + type: str = Field(..., description="Converter type (e.g., 'base64', 'translation')") + display_name: Optional[str] = Field(None, description="Human-readable display name") + params: Dict[str, Any] = Field(default_factory=dict, description="Converter parameters (resolved)") + created_at: datetime = Field(..., description="Creation timestamp") + source: Literal["initializer", "user"] = Field(..., description="How the converter was created") + - converters: List[ConverterMetadataResponse] = Field(..., description="Available converter types") +class ConverterInstanceListResponse(BaseModel): + """Response for listing converter instances.""" + items: List[ConverterInstance] = Field(..., description="List of converter instances") -class ConversionStep(BaseModel): - """Single step in a conversion chain.""" - converter_class: str = Field(..., description="Converter class that was applied") - input: str = Field(..., description="Input to this converter") +class CreateConverterRequest(BaseModel): + """ + Request to create a new converter instance. + + Supports nested converters - if params contains a 'converter' key with + an InlineConverterConfig, the backend will create both and link them. + """ + + type: str = Field(..., description="Converter type (e.g., 'base64', 'translation')") + display_name: Optional[str] = Field(None, description="Human-readable display name") + params: Dict[str, Any] = Field( + default_factory=dict, + description="Converter parameters (may include nested 'converter' config)", + ) + + +class CreateConverterResponse(BaseModel): + """Response after creating a converter instance.""" + + converter_id: str = Field(..., description="Unique converter instance identifier") + type: str = Field(..., description="Converter type") + display_name: Optional[str] = Field(None, description="Human-readable display name") + params: Dict[str, Any] = Field(default_factory=dict, description="Resolved parameters (nested converters have IDs)") + created_converters: Optional[List[ConverterInstance]] = Field( + None, description="All converters created (including nested), ordered inner-to-outer" + ) + created_at: datetime = Field(..., description="Creation timestamp") + source: Literal["user"] = Field(default="user", description="Source is always 'user' for API-created") + + +# ============================================================================ +# Converter Preview +# ============================================================================ + + +class PreviewStep(BaseModel): + """A single step in the conversion preview.""" + + converter_id: Optional[str] = Field(None, description="Converter instance ID (if using ID)") + converter_type: str = Field(..., description="Converter type") + input_value: str = Field(..., description="Input to this converter") input_data_type: PromptDataType = Field(..., description="Input data type") - output: str = Field(..., description="Output from this converter") + output_value: str = Field(..., description="Output from this converter") output_data_type: PromptDataType = Field(..., description="Output data type") -class PreviewConverterRequest(BaseModel): - """Request to preview converter output.""" +class ConverterPreviewRequest(BaseModel): + """Request to preview converter transformation.""" - content: str = Field(..., description="Original content to convert") - data_type: PromptDataType = Field("text", description="Content data type") - converters: List[ConverterConfig] = Field(..., description="Ordered list of converters to apply") + original_value: str = Field(..., description="Text to convert") + original_value_data_type: PromptDataType = Field(default="text", description="Data type of original value") + converter_ids: Optional[List[str]] = Field(None, description="Converter instance IDs to apply") + converters: Optional[List[InlineConverterConfig]] = Field(None, description="Inline converter definitions") -class PreviewConverterResponse(BaseModel): - """Response with converter preview results.""" +class ConverterPreviewResponse(BaseModel): + """Response from converter preview.""" - original_content: str = Field(..., description="Original input content") - converted_content: str = Field(..., description="Final converted content") - converted_data_type: PromptDataType = Field(..., description="Final data type") - conversion_chain: List[ConversionStep] = Field(..., description="Step-by-step conversion results") - converter_identifiers: List[Dict[str, Any]] = Field( - ..., description="Converter identifiers for use in pre_converted requests" - ) + original_value: str = Field(..., description="Original input text") + original_value_data_type: PromptDataType = Field(..., description="Data type of original value") + converted_value: str = Field(..., description="Final converted text") + converted_value_data_type: PromptDataType = Field(..., description="Data type of converted value") + steps: List[PreviewStep] = Field(..., description="Step-by-step conversion results") diff --git a/pyrit/backend/models/memory.py b/pyrit/backend/models/memory.py deleted file mode 100644 index 98741f6eeb..0000000000 --- a/pyrit/backend/models/memory.py +++ /dev/null @@ -1,125 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Memory query response models. - -Models for messages, scores, attack results, scenario results, and seeds. -""" - -from datetime import datetime -from typing import Any, Dict, List, Literal, Optional - -from pydantic import BaseModel, Field - -from pyrit.models import PromptDataType, PromptResponseError - -# ============================================================================ -# Message Queries -# ============================================================================ - - -class MessageQueryResponse(BaseModel): - """Response model for message piece queries.""" - - id: str = Field(..., description="Message piece ID") - conversation_id: str = Field(..., description="Parent conversation ID") - sequence: int = Field(..., description="Sequence in conversation") - role: str = Field(..., description="Message role") - original_value: str = Field(..., description="Original content") - original_value_data_type: PromptDataType = Field(..., description="Original data type") - converted_value: str = Field(..., description="Converted content") - converted_value_data_type: PromptDataType = Field(..., description="Converted data type") - converter_identifiers: List[Dict[str, Any]] = Field(default_factory=list, description="Applied converters") - target_identifier: Dict[str, Any] = Field(..., description="Target identifier (filtered)") - labels: Optional[Dict[str, str]] = Field(None, description="Message labels") - response_error: Optional[PromptResponseError] = Field(None, description="Error type if any") - timestamp: datetime = Field(..., description="Message timestamp") - - -# ============================================================================ -# Score Queries -# ============================================================================ - -ScoreType = Literal["true_false", "float_scale", "unknown"] - - -class ScoreQueryResponse(BaseModel): - """Response model for score queries.""" - - id: str = Field(..., description="Score ID") - message_piece_id: str = Field(..., description="Associated message piece ID") - score_value: str = Field(..., description="Score value ('true'/'false' or numeric)") - score_value_description: str = Field(..., description="Human-readable score description") - score_type: ScoreType = Field(..., description="Type of score") - score_category: Optional[List[str]] = Field(None, description="Score categories") - score_rationale: str = Field(..., description="Explanation for the score") - scorer_identifier: Dict[str, Any] = Field(..., description="Scorer identifier (filtered)") - objective: Optional[str] = Field(None, description="Scoring objective") - timestamp: datetime = Field(..., description="Score timestamp") - - -# ============================================================================ -# Attack Results -# ============================================================================ - -AttackOutcome = Literal["success", "failure", "undetermined"] - - -class AttackResultQueryResponse(BaseModel): - """Response model for attack result queries.""" - - id: str = Field(..., description="Attack result ID") - conversation_id: str = Field(..., description="Associated conversation ID") - objective: str = Field(..., description="Attack objective") - attack_identifier: Dict[str, Any] = Field(..., description="Attack identifier (filtered)") - outcome: Optional[str] = Field(None, description="Attack outcome (success, failure, undetermined)") - outcome_reason: Optional[str] = Field(None, description="Explanation for outcome") - executed_turns: int = Field(..., description="Number of turns executed") - execution_time_ms: int = Field(..., description="Execution time in milliseconds") - timestamp: Optional[datetime] = Field(None, description="Result timestamp") - - -# ============================================================================ -# Scenario Results -# ============================================================================ - -ScenarioRunState = Literal["CREATED", "IN_PROGRESS", "COMPLETED", "FAILED"] - - -class ScenarioResultQueryResponse(BaseModel): - """Response model for scenario result queries.""" - - id: str = Field(..., description="Scenario result ID") - scenario_name: str = Field(..., description="Scenario name") - scenario_description: Optional[str] = Field(None, description="Scenario description") - scenario_version: int = Field(..., description="Scenario version") - pyrit_version: str = Field(..., description="PyRIT version used") - run_state: ScenarioRunState = Field(..., description="Current run state") - objective_target_identifier: Dict[str, Any] = Field(..., description="Target identifier (filtered)") - labels: Optional[Dict[str, str]] = Field(None, description="Scenario labels") - number_tries: int = Field(..., description="Number of objectives attempted") - completion_time: Optional[datetime] = Field(None, description="Completion timestamp") - timestamp: datetime = Field(..., description="Creation timestamp") - - -# ============================================================================ -# Seeds -# ============================================================================ - -SeedType = Literal["prompt", "objective", "simulated_conversation"] - - -class SeedQueryResponse(BaseModel): - """Response model for seed queries.""" - - id: str = Field(..., description="Seed ID") - value: str = Field(..., description="Seed content") - data_type: PromptDataType = Field(..., description="Content data type") - name: Optional[str] = Field(None, description="Seed name") - dataset_name: Optional[str] = Field(None, description="Dataset name") - seed_type: SeedType = Field(..., description="Type of seed") - harm_categories: Optional[List[str]] = Field(None, description="Harm categories") - description: Optional[str] = Field(None, description="Seed description") - source: Optional[str] = Field(None, description="Seed source") - date_added: Optional[datetime] = Field(None, description="Date added") diff --git a/pyrit/backend/models/targets.py b/pyrit/backend/models/targets.py new file mode 100644 index 0000000000..d387680cc4 --- /dev/null +++ b/pyrit/backend/models/targets.py @@ -0,0 +1,57 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Target instance models. + +Targets have two concepts: +- Types: Static metadata bundled with frontend (from registry) +- Instances: Runtime objects created via API with specific configuration + +This module defines the Instance models for runtime target management. +""" + +from datetime import datetime +from typing import Any, Dict, List, Literal, Optional + +from pydantic import BaseModel, Field + + +class TargetInstance(BaseModel): + """ + A runtime target instance. + + Created either by an initializer (at startup) or by user (via API). + """ + + target_id: str = Field(..., description="Unique target instance identifier") + type: str = Field(..., description="Target type (e.g., 'azure_openai', 'text_target')") + display_name: Optional[str] = Field(None, description="Human-readable display name") + params: Dict[str, Any] = Field(default_factory=dict, description="Target configuration (sensitive fields filtered)") + created_at: datetime = Field(..., description="Instance creation timestamp") + source: Literal["initializer", "user"] = Field(..., description="How the target was created") + + +class TargetListResponse(BaseModel): + """Response for listing target instances.""" + + items: List[TargetInstance] = Field(..., description="List of target instances") + + +class CreateTargetRequest(BaseModel): + """Request to create a new target instance.""" + + type: str = Field(..., description="Target type (e.g., 'azure_openai', 'text_target')") + display_name: Optional[str] = Field(None, description="Human-readable display name") + params: Dict[str, Any] = Field(default_factory=dict, description="Target constructor parameters") + + +class CreateTargetResponse(BaseModel): + """Response after creating a target instance.""" + + target_id: str = Field(..., description="Unique target instance identifier") + type: str = Field(..., description="Target type") + display_name: Optional[str] = Field(None, description="Human-readable display name") + params: Dict[str, Any] = Field(default_factory=dict, description="Filtered configuration (no secrets)") + created_at: datetime = Field(..., description="Instance creation timestamp") + source: Literal["user"] = Field(default="user", description="Source is always 'user' for API-created targets") diff --git a/pyrit/backend/routes/__init__.py b/pyrit/backend/routes/__init__.py index 78bfbae3f3..4f16ed7759 100644 --- a/pyrit/backend/routes/__init__.py +++ b/pyrit/backend/routes/__init__.py @@ -5,13 +5,13 @@ API route handlers. """ -from pyrit.backend.routes import conversations, converters, health, memory, registry, version +from pyrit.backend.routes import attacks, converters, health, registry, targets, version __all__ = [ - "conversations", + "attacks", "converters", "health", - "memory", "registry", + "targets", "version", ] diff --git a/pyrit/backend/routes/attacks.py b/pyrit/backend/routes/attacks.py new file mode 100644 index 0000000000..23d6a316da --- /dev/null +++ b/pyrit/backend/routes/attacks.py @@ -0,0 +1,216 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Attack API routes. + +All interactions are modeled as "attacks" - including manual conversations. +This is the attack-centric API design. +""" + +from typing import Literal, Optional + +from fastapi import APIRouter, HTTPException, Query, status + +from pyrit.backend.models.attacks import ( + AttackDetail, + AttackListResponse, + CreateAttackRequest, + CreateAttackResponse, + SendMessageRequest, + SendMessageResponse, + UpdateAttackRequest, +) +from pyrit.backend.models.common import ProblemDetail +from pyrit.backend.services.attack_service import get_attack_service + +router = APIRouter(prefix="/attacks", tags=["attacks"]) + + +@router.get( + "", + response_model=AttackListResponse, +) +async def list_attacks( + target_id: Optional[str] = Query(None, description="Filter by target instance ID"), + outcome: Optional[Literal["pending", "success", "failure"]] = Query(None, description="Filter by outcome"), + limit: int = Query(20, ge=1, le=100, description="Maximum items per page"), + cursor: Optional[str] = Query(None, description="Pagination cursor (attack_id)"), +) -> AttackListResponse: + """ + List attacks with optional filtering and pagination. + + Returns attack summaries (not full message content). + Use GET /attacks/{id} for full details. + + Returns: + AttackListResponse: Paginated list of attack summaries. + """ + service = get_attack_service() + return await service.list_attacks( + target_id=target_id, + outcome=outcome, + limit=limit, + cursor=cursor, + ) + + +@router.post( + "", + response_model=CreateAttackResponse, + status_code=status.HTTP_201_CREATED, + responses={ + 400: {"model": ProblemDetail, "description": "Invalid request"}, + 404: {"model": ProblemDetail, "description": "Target or converter not found"}, + 422: {"model": ProblemDetail, "description": "Validation error"}, + }, +) +async def create_attack(request: CreateAttackRequest) -> CreateAttackResponse: + """ + Create a new attack. + + Establishes a new attack session with the specified target. + Optionally include prepended_conversation for system prompts or branching context. + + Returns: + CreateAttackResponse: The created attack details. + """ + service = get_attack_service() + + try: + return await service.create_attack(request) + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=str(e), + ) + + +@router.get( + "/{attack_id}", + response_model=AttackDetail, + responses={ + 404: {"model": ProblemDetail, "description": "Attack not found"}, + }, +) +async def get_attack(attack_id: str) -> AttackDetail: + """ + Get attack details including all messages. + + Returns the full attack with prepended_conversation and all messages. + + Returns: + AttackDetail: Full attack details with messages. + """ + service = get_attack_service() + + attack = await service.get_attack(attack_id) + if not attack: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Attack '{attack_id}' not found", + ) + + return attack + + +@router.patch( + "/{attack_id}", + response_model=AttackDetail, + responses={ + 404: {"model": ProblemDetail, "description": "Attack not found"}, + }, +) +async def update_attack( + attack_id: str, + request: UpdateAttackRequest, +) -> AttackDetail: + """ + Update an attack's outcome. + + Used to mark attacks as success/failure/pending. + + Returns: + AttackDetail: Updated attack details. + """ + service = get_attack_service() + + attack = await service.update_attack(attack_id, request) + if not attack: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Attack '{attack_id}' not found", + ) + + return attack + + +@router.post( + "/{attack_id}/messages", + response_model=SendMessageResponse, + responses={ + 404: {"model": ProblemDetail, "description": "Attack not found"}, + 400: {"model": ProblemDetail, "description": "Message send failed"}, + }, +) +async def send_message( + attack_id: str, + request: SendMessageRequest, +) -> SendMessageResponse: + """ + Send a message in an attack. + + Sends the user message to the target, applies converters, and returns + both the user message and assistant response. + + Converters can be specified at three levels (in priority order): + 1. request.converter_ids - per-message converter instances + 2. request.converters - inline converter definitions + 3. attack.converter_ids - attack-level defaults + + Returns: + SendMessageResponse: The sent message and assistant response. + """ + service = get_attack_service() + + try: + return await service.send_message(attack_id, request) + except ValueError as e: + error_msg = str(e) + if "not found" in error_msg.lower(): + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=error_msg, + ) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=error_msg, + ) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to send message: {str(e)}", + ) + + +@router.delete( + "/{attack_id}", + status_code=status.HTTP_204_NO_CONTENT, + responses={ + 404: {"model": ProblemDetail, "description": "Attack not found"}, + }, +) +async def delete_attack(attack_id: str) -> None: + """ + Delete an attack. + + Removes the attack and all associated messages. + """ + service = get_attack_service() + + deleted = await service.delete_attack(attack_id) + if not deleted: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Attack '{attack_id}' not found", + ) diff --git a/pyrit/backend/routes/conversations.py b/pyrit/backend/routes/conversations.py deleted file mode 100644 index 83de1cfd57..0000000000 --- a/pyrit/backend/routes/conversations.py +++ /dev/null @@ -1,338 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Conversation API routes. - -Provides endpoints for managing interactive conversation sessions. -""" - -from typing import List - -from fastapi import APIRouter, HTTPException, status - -from pyrit.backend.models.common import ProblemDetail -from pyrit.backend.models.conversations import ( - BranchConversationRequest, - BranchConversationResponse, - ConverterConfig, - ConvertersResponse, - CreateConversationRequest, - CreateConversationResponse, - MessageResponse, - SendMessageRequest, - SendMessageResponse, - SetSystemPromptRequest, - SystemPromptResponse, -) -from pyrit.backend.services import get_conversation_service - -router = APIRouter(prefix="/conversations", tags=["conversations"]) - - -@router.post( - "", - response_model=CreateConversationResponse, - status_code=status.HTTP_201_CREATED, - responses={ - 400: {"model": ProblemDetail, "description": "Invalid request"}, - 422: {"model": ProblemDetail, "description": "Validation error"}, - }, -) -async def create_conversation(request: CreateConversationRequest) -> CreateConversationResponse: - """ - Create a new conversation session. - - Establishes a new conversation with the specified target and optional - system prompt and converters. - - Returns: - CreateConversationResponse: The created conversation details. - """ - service = get_conversation_service() - - try: - return await service.create_conversation(request) - except ValueError as e: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=str(e), - ) - except Exception as e: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to create conversation: {str(e)}", - ) - - -@router.get( - "/{conversation_id}", - response_model=List[MessageResponse], - responses={ - 404: {"model": ProblemDetail, "description": "Conversation not found"}, - }, -) -async def get_conversation(conversation_id: str) -> List[MessageResponse]: - """ - Get all messages in a conversation. - - Returns messages in sequence order. - - Returns: - List[MessageResponse]: List of messages in the conversation. - """ - service = get_conversation_service() - - state = await service.get_conversation(conversation_id) - if not state: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Conversation {conversation_id} not found", - ) - - return await service.get_conversation_messages(conversation_id) - - -@router.post( - "/{conversation_id}/messages", - response_model=SendMessageResponse, - responses={ - 404: {"model": ProblemDetail, "description": "Conversation not found"}, - 400: {"model": ProblemDetail, "description": "Message send failed"}, - }, -) -async def send_message( - conversation_id: str, - request: SendMessageRequest, -) -> SendMessageResponse: - """ - Send a message in a conversation. - - Sends the user message to the target, applies converters, and returns - both the sent message and assistant response(s). - - Returns: - SendMessageResponse: The sent message and assistant response. - """ - service = get_conversation_service() - - state = await service.get_conversation(conversation_id) - if not state: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Conversation {conversation_id} not found", - ) - - try: - return await service.send_message(conversation_id, request) - except ValueError as e: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=str(e), - ) - except Exception as e: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to send message: {str(e)}", - ) - - -@router.get( - "/{conversation_id}/system-prompt", - response_model=SystemPromptResponse, - responses={ - 404: {"model": ProblemDetail, "description": "Conversation not found"}, - }, -) -async def get_system_prompt(conversation_id: str) -> SystemPromptResponse: - """ - Get the current system prompt for a conversation. - - Returns: - SystemPromptResponse: The current system prompt. - """ - service = get_conversation_service() - - state = await service.get_conversation(conversation_id) - if not state: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Conversation {conversation_id} not found", - ) - - return SystemPromptResponse( - system_prompt=state.system_prompt, - piece_id=None, # System prompts stored in state, not as MessagePiece - ) - - -@router.put( - "/{conversation_id}/system-prompt", - response_model=SystemPromptResponse, - responses={ - 404: {"model": ProblemDetail, "description": "Conversation not found"}, - }, -) -async def update_system_prompt( - conversation_id: str, - request: SetSystemPromptRequest, -) -> SystemPromptResponse: - """ - Update the system prompt for a conversation. - - Takes effect for subsequent messages. - - Returns: - SystemPromptResponse: The updated system prompt. - """ - service = get_conversation_service() - - state = await service.get_conversation(conversation_id) - if not state: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Conversation {conversation_id} not found", - ) - - await service.update_system_prompt(conversation_id, request.system_prompt) - - return SystemPromptResponse( - system_prompt=request.system_prompt, - piece_id=None, - ) - - -@router.get( - "/{conversation_id}/converters", - response_model=ConvertersResponse, - responses={ - 404: {"model": ProblemDetail, "description": "Conversation not found"}, - }, -) -async def get_converters(conversation_id: str) -> ConvertersResponse: - """ - Get the current converters for a conversation. - - Returns: - ConvertersResponse: The current converter configurations. - """ - service = get_conversation_service() - - state = await service.get_conversation(conversation_id) - if not state: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Conversation {conversation_id} not found", - ) - - return ConvertersResponse( - converters=state.converters, - ) - - -@router.put( - "/{conversation_id}/converters", - response_model=ConvertersResponse, - responses={ - 404: {"model": ProblemDetail, "description": "Conversation not found"}, - 400: {"model": ProblemDetail, "description": "Invalid converter configuration"}, - }, -) -async def update_converters( - conversation_id: str, - converters: List[ConverterConfig], -) -> ConvertersResponse: - """ - Update the converters for a conversation. - - Replaces all current converters with the provided list. - Takes effect for subsequent messages. - - Returns: - ConvertersResponse: The updated converter configurations. - """ - service = get_conversation_service() - - state = await service.get_conversation(conversation_id) - if not state: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Conversation {conversation_id} not found", - ) - - try: - await service.update_converters(conversation_id, converters) - except ValueError as e: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=str(e), - ) - - return ConvertersResponse( - converters=converters, - ) - - -@router.post( - "/{conversation_id}/branch", - response_model=BranchConversationResponse, - status_code=status.HTTP_201_CREATED, - responses={ - 404: {"model": ProblemDetail, "description": "Conversation not found"}, - 400: {"model": ProblemDetail, "description": "Invalid branch request"}, - }, -) -async def branch_conversation( - conversation_id: str, - request: BranchConversationRequest, -) -> BranchConversationResponse: - """ - Branch a conversation from a specific point. - - Creates a new conversation with messages copied up to and including - the specified sequence number. The original conversation is unchanged. - - Returns: - BranchConversationResponse: The new branched conversation details. - """ - service = get_conversation_service() - - state = await service.get_conversation(conversation_id) - if not state: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Conversation {conversation_id} not found", - ) - - try: - return await service.branch_conversation(conversation_id, request) - except ValueError as e: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=str(e), - ) - - -@router.delete( - "/{conversation_id}", - status_code=status.HTTP_204_NO_CONTENT, - responses={ - 404: {"model": ProblemDetail, "description": "Conversation not found"}, - }, -) -async def delete_conversation(conversation_id: str) -> None: - """ - Delete a conversation session. - - Cleans up in-memory resources. Messages remain in memory database. - """ - service = get_conversation_service() - - state = await service.get_conversation(conversation_id) - if not state: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Conversation {conversation_id} not found", - ) - - service.cleanup_conversation(conversation_id) diff --git a/pyrit/backend/routes/converters.py b/pyrit/backend/routes/converters.py index 2790598ab0..318f251b46 100644 --- a/pyrit/backend/routes/converters.py +++ b/pyrit/backend/routes/converters.py @@ -4,41 +4,53 @@ """ Converters API routes. -Provides endpoints for listing and previewing prompt converters. +Provides endpoints for: +- Listing converter types (metadata from registry) +- Managing converter instances (runtime objects) +- Previewing converter transformations """ -from typing import List, Optional +from typing import List, Literal, Optional from fastapi import APIRouter, HTTPException, Query, status from pyrit.backend.models.common import ProblemDetail from pyrit.backend.models.converters import ( - ConversionStep, + ConverterInstance, + ConverterInstanceListResponse, ConverterMetadataResponse, - PreviewConverterRequest, - PreviewConverterResponse, + ConverterPreviewRequest, + ConverterPreviewResponse, + CreateConverterRequest, + CreateConverterResponse, ) -from pyrit.backend.services import get_conversation_service, get_registry_service +from pyrit.backend.services import get_registry_service +from pyrit.backend.services.converter_service import get_converter_service router = APIRouter(prefix="/converters", tags=["converters"]) +# ============================================================================ +# Converter Types (from registry) +# ============================================================================ + + @router.get( - "", + "/types", response_model=List[ConverterMetadataResponse], ) -async def list_converters( +async def list_converter_types( is_llm_based: Optional[bool] = Query(None, description="Filter by LLM-based converters"), is_deterministic: Optional[bool] = Query(None, description="Filter by deterministic converters"), ) -> List[ConverterMetadataResponse]: """ - List available converters. + List available converter types. - Returns metadata about all available prompt converters, optionally - filtered by LLM-based status or determinism. + Returns metadata about all available prompt converter types (not instances). + For instances, use GET /converters/instances. Returns: - List[ConverterMetadataResponse]: List of converter metadata. + List[ConverterMetadataResponse]: List of converter type metadata. """ service = get_registry_service() @@ -48,50 +60,157 @@ async def list_converters( ) +# ============================================================================ +# Converter Instances (runtime objects) +# ============================================================================ + + +@router.get( + "/instances", + response_model=ConverterInstanceListResponse, +) +async def list_converter_instances( + source: Optional[Literal["initializer", "user"]] = Query( + None, description="Filter by source (initializer or user)" + ), +) -> ConverterInstanceListResponse: + """ + List converter instances. + + Returns all registered converter instances. + + Returns: + ConverterInstanceListResponse: List of converter instances. + """ + service = get_converter_service() + return await service.list_converters(source=source) + + +@router.post( + "/instances", + response_model=CreateConverterResponse, + status_code=status.HTTP_201_CREATED, + responses={ + 400: {"model": ProblemDetail, "description": "Invalid converter type or parameters"}, + 422: {"model": ProblemDetail, "description": "Validation error"}, + }, +) +async def create_converter_instance(request: CreateConverterRequest) -> CreateConverterResponse: + """ + Create a new converter instance. + + Supports nested converters - if params contains a 'converter' key with + a type/params object, the nested converter will be created first and + linked to the outer converter. + + Example for SelectiveTextConverter: + ```json + { + "type": "selective_text", + "params": { + "pattern": "\\[CONVERT\\]", + "converter": { + "type": "base64", + "params": {} + } + } + } + ``` + + Returns: + CreateConverterResponse: The created converter instance details. + """ + service = get_converter_service() + + try: + return await service.create_converter(request) + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), + ) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to create converter: {str(e)}", + ) + + +@router.get( + "/instances/{converter_id}", + response_model=ConverterInstance, + responses={ + 404: {"model": ProblemDetail, "description": "Converter not found"}, + }, +) +async def get_converter_instance(converter_id: str) -> ConverterInstance: + """ + Get a converter instance by ID. + + Returns: + ConverterInstance: The converter instance details. + """ + service = get_converter_service() + + converter = await service.get_converter(converter_id) + if not converter: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Converter '{converter_id}' not found", + ) + + return converter + + +@router.delete( + "/instances/{converter_id}", + status_code=status.HTTP_204_NO_CONTENT, + responses={ + 404: {"model": ProblemDetail, "description": "Converter not found"}, + }, +) +async def delete_converter_instance(converter_id: str) -> None: + """ + Delete a converter instance. + + Note: Converters in use by active attacks cannot be deleted. + """ + service = get_converter_service() + + deleted = await service.delete_converter(converter_id) + if not deleted: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Converter '{converter_id}' not found", + ) + + +# ============================================================================ +# Converter Preview +# ============================================================================ + + @router.post( "/preview", - response_model=PreviewConverterResponse, + response_model=ConverterPreviewResponse, responses={ 400: {"model": ProblemDetail, "description": "Invalid converter configuration"}, }, ) -async def preview_converters(request: PreviewConverterRequest) -> PreviewConverterResponse: +async def preview_conversion(request: ConverterPreviewRequest) -> ConverterPreviewResponse: """ - Preview text through a converter pipeline. + Preview conversion through a converter pipeline. - Applies the specified converters in sequence and returns - intermediate results at each step. Useful for testing converter - configurations before applying to conversations. + Applies converters to the input and returns step-by-step results. + Can use either converter_ids (existing instances) or inline converters. Returns: - PreviewConverterResponse: Original content, converted content, and conversion steps. + ConverterPreviewResponse: Original, converted values, and conversion steps. """ - service = get_conversation_service() + service = get_converter_service() try: - steps_data = await service.preview_converters(request.content, request.converters) - - steps = [ - ConversionStep( - converter_class=s["converter_type"], - input=s["input"], - input_data_type=s.get("input_data_type", "text"), - output=s["output"], - output_data_type=s.get("output_type", "text"), - ) - for s in steps_data - ] - - final_output = steps[-1].output if steps else request.content - final_data_type = steps[-1].output_data_type if steps else request.data_type - - return PreviewConverterResponse( - original_content=request.content, - converted_content=final_output, - converted_data_type=final_data_type, - conversion_chain=steps, - converter_identifiers=[{"class_name": s.converter_class} for s in steps], - ) + return await service.preview_conversion(request) except ValueError as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, diff --git a/pyrit/backend/routes/memory.py b/pyrit/backend/routes/memory.py deleted file mode 100644 index 775e34ae22..0000000000 --- a/pyrit/backend/routes/memory.py +++ /dev/null @@ -1,206 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Memory API routes. - -Provides endpoints for querying stored data with pagination. -""" - -from datetime import datetime -from typing import List, Optional - -from fastapi import APIRouter, Query - -from pyrit.backend.models.common import PaginatedResponse -from pyrit.backend.models.memory import ( - AttackResultQueryResponse, - MessageQueryResponse, - ScenarioResultQueryResponse, - ScoreQueryResponse, - SeedQueryResponse, -) -from pyrit.backend.services import get_memory_service - -router = APIRouter(prefix="/memory", tags=["memory"]) - - -@router.get( - "/messages", - response_model=PaginatedResponse[MessageQueryResponse], -) -async def query_messages( - conversation_id: Optional[str] = Query(None, description="Filter by conversation ID"), - role: Optional[str] = Query(None, description="Filter by role (user/assistant/system)"), - data_type: Optional[str] = Query(None, description="Filter by data type (text/image_path/audio_path)"), - harm_category: Optional[List[str]] = Query(None, description="Filter by harm categories"), - response_error: Optional[str] = Query(None, description="Filter by response error type"), - start_time: Optional[datetime] = Query(None, description="Messages after this time"), - end_time: Optional[datetime] = Query(None, description="Messages before this time"), - limit: int = Query(50, ge=1, le=200, description="Maximum results per page"), - cursor: Optional[str] = Query(None, description="Pagination cursor"), -) -> PaginatedResponse[MessageQueryResponse]: - """ - Query message pieces with pagination. - - Returns messages matching the specified filters, ordered by timestamp descending. - Use cursor for pagination through large result sets. - - Returns: - PaginatedResponse[MessageQueryResponse]: Paginated list of messages. - """ - service = get_memory_service() - - return await service.get_messages( - conversation_id=conversation_id, - role=role, - harm_categories=harm_category, - data_type=data_type, - response_error=response_error, - start_time=start_time, - end_time=end_time, - limit=limit, - cursor=cursor, - ) - - -@router.get( - "/scores", - response_model=PaginatedResponse[ScoreQueryResponse], -) -async def query_scores( - message_id: Optional[str] = Query(None, description="Filter by message piece ID"), - conversation_id: Optional[str] = Query(None, description="Filter by conversation ID"), - score_type: Optional[str] = Query(None, description="Filter by score type"), - scorer_type: Optional[str] = Query(None, description="Filter by scorer class name"), - start_time: Optional[datetime] = Query(None, description="Scores after this time"), - end_time: Optional[datetime] = Query(None, description="Scores before this time"), - limit: int = Query(50, ge=1, le=200, description="Maximum results per page"), - cursor: Optional[str] = Query(None, description="Pagination cursor"), -) -> PaginatedResponse[ScoreQueryResponse]: - """ - Query scores with pagination. - - Returns scores matching the specified filters, ordered by timestamp descending. - - Returns: - PaginatedResponse[ScoreQueryResponse]: Paginated list of scores. - """ - service = get_memory_service() - - return await service.get_scores( - message_id=message_id, - conversation_id=conversation_id, - score_type=score_type, - scorer_type=scorer_type, - start_time=start_time, - end_time=end_time, - limit=limit, - cursor=cursor, - ) - - -@router.get( - "/attack-results", - response_model=PaginatedResponse[AttackResultQueryResponse], -) -async def query_attack_results( - conversation_id: Optional[str] = Query(None, description="Filter by conversation ID"), - outcome: Optional[str] = Query(None, description="Filter by outcome"), - attack_type: Optional[str] = Query(None, description="Filter by attack class name"), - objective: Optional[str] = Query(None, description="Search by objective text"), - min_turns: Optional[int] = Query(None, ge=1, description="Minimum executed turns"), - max_turns: Optional[int] = Query(None, ge=1, description="Maximum executed turns"), - start_time: Optional[datetime] = Query(None, description="Results after this time"), - end_time: Optional[datetime] = Query(None, description="Results before this time"), - limit: int = Query(50, ge=1, le=200, description="Maximum results per page"), - cursor: Optional[str] = Query(None, description="Pagination cursor"), -) -> PaginatedResponse[AttackResultQueryResponse]: - """ - Query attack results with pagination. - - Returns attack results matching the specified filters, ordered by timestamp descending. - - Returns: - PaginatedResponse[AttackResultQueryResponse]: Paginated list of attack results. - """ - service = get_memory_service() - - return await service.get_attack_results( - conversation_id=conversation_id, - outcome=outcome, - attack_type=attack_type, - objective=objective, - min_turns=min_turns, - max_turns=max_turns, - start_time=start_time, - end_time=end_time, - limit=limit, - cursor=cursor, - ) - - -@router.get( - "/scenario-results", - response_model=PaginatedResponse[ScenarioResultQueryResponse], -) -async def query_scenario_results( - scenario_name: Optional[str] = Query(None, description="Filter by scenario name"), - run_state: Optional[str] = Query(None, description="Filter by run state"), - start_time: Optional[datetime] = Query(None, description="Results after this time"), - end_time: Optional[datetime] = Query(None, description="Results before this time"), - limit: int = Query(50, ge=1, le=200, description="Maximum results per page"), - cursor: Optional[str] = Query(None, description="Pagination cursor"), -) -> PaginatedResponse[ScenarioResultQueryResponse]: - """ - Query scenario results with pagination. - - Returns scenario results matching the specified filters, ordered by timestamp descending. - - Returns: - PaginatedResponse[ScenarioResultQueryResponse]: Paginated list of scenario results. - """ - service = get_memory_service() - - return await service.get_scenario_results( - scenario_name=scenario_name, - run_state=run_state, - start_time=start_time, - end_time=end_time, - limit=limit, - cursor=cursor, - ) - - -@router.get( - "/seeds", - response_model=PaginatedResponse[SeedQueryResponse], -) -async def query_seeds( - dataset_name: Optional[str] = Query(None, description="Filter by dataset name"), - seed_type: Optional[str] = Query(None, description="Filter by seed type"), - harm_category: Optional[List[str]] = Query(None, description="Filter by harm categories"), - data_type: Optional[str] = Query(None, description="Filter by data type"), - search: Optional[str] = Query(None, description="Search in seed value text"), - limit: int = Query(50, ge=1, le=200, description="Maximum results per page"), - cursor: Optional[str] = Query(None, description="Pagination cursor"), -) -> PaginatedResponse[SeedQueryResponse]: - """ - Query seeds with pagination. - - Returns seeds matching the specified filters, ordered by date_added descending. - - Returns: - PaginatedResponse[SeedQueryResponse]: Paginated list of seeds. - """ - service = get_memory_service() - - return await service.get_seeds( - dataset_name=dataset_name, - seed_type=seed_type, - harm_categories=harm_category, - data_type=data_type, - search=search, - limit=limit, - cursor=cursor, - ) diff --git a/pyrit/backend/routes/targets.py b/pyrit/backend/routes/targets.py new file mode 100644 index 0000000000..43a3465901 --- /dev/null +++ b/pyrit/backend/routes/targets.py @@ -0,0 +1,135 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Target instance API routes. + +Targets have two concepts: +- Types: Available via /api/registry/targets (static metadata) +- Instances: Runtime objects created via this API + +This module handles target instances (runtime objects). +""" + +from typing import Literal, Optional + +from fastapi import APIRouter, HTTPException, Query, status + +from pyrit.backend.models.common import ProblemDetail +from pyrit.backend.models.targets import ( + CreateTargetRequest, + CreateTargetResponse, + TargetInstance, + TargetListResponse, +) +from pyrit.backend.services.target_service import get_target_service + +router = APIRouter(prefix="/targets", tags=["targets"]) + + +@router.get( + "", + response_model=TargetListResponse, +) +async def list_targets( + source: Optional[Literal["initializer", "user"]] = Query( + None, description="Filter by source (initializer or user)" + ), +) -> TargetListResponse: + """ + List target instances. + + Returns all registered target instances. Use source filter to distinguish + between initializer-created (startup) and user-created (API) targets. + + Returns: + TargetListResponse: List of target instances. + """ + service = get_target_service() + return await service.list_targets(source=source) + + +@router.post( + "", + response_model=CreateTargetResponse, + status_code=status.HTTP_201_CREATED, + responses={ + 400: {"model": ProblemDetail, "description": "Invalid target type or parameters"}, + 422: {"model": ProblemDetail, "description": "Validation error"}, + }, +) +async def create_target(request: CreateTargetRequest) -> CreateTargetResponse: + """ + Create a new target instance. + + Instantiates a target with the given type and parameters. + The target becomes available for use in attacks. + + Note: Sensitive parameters (API keys, tokens) are filtered from the response. + + Returns: + CreateTargetResponse: The created target instance details. + """ + service = get_target_service() + + try: + return await service.create_target(request) + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), + ) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to create target: {str(e)}", + ) + + +@router.get( + "/{target_id}", + response_model=TargetInstance, + responses={ + 404: {"model": ProblemDetail, "description": "Target not found"}, + }, +) +async def get_target(target_id: str) -> TargetInstance: + """ + Get a target instance by ID. + + Returns: + TargetInstance: The target instance details. + """ + service = get_target_service() + + target = await service.get_target(target_id) + if not target: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Target '{target_id}' not found", + ) + + return target + + +@router.delete( + "/{target_id}", + status_code=status.HTTP_204_NO_CONTENT, + responses={ + 404: {"model": ProblemDetail, "description": "Target not found"}, + }, +) +async def delete_target(target_id: str) -> None: + """ + Delete a target instance. + + Note: Targets in use by active attacks cannot be deleted. + """ + service = get_target_service() + + deleted = await service.delete_target(target_id) + if not deleted: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Target '{target_id}' not found", + ) diff --git a/pyrit/backend/services/__init__.py b/pyrit/backend/services/__init__.py index 8536741133..c748793055 100644 --- a/pyrit/backend/services/__init__.py +++ b/pyrit/backend/services/__init__.py @@ -7,26 +7,32 @@ Provides business logic layer for API routes. """ -from pyrit.backend.services.conversation_service import ( - ConversationService, - ConversationState, - get_conversation_service, +from pyrit.backend.services.attack_service import ( + AttackService, + AttackState, + get_attack_service, ) -from pyrit.backend.services.memory_service import ( - MemoryService, - get_memory_service, +from pyrit.backend.services.converter_service import ( + ConverterService, + get_converter_service, ) from pyrit.backend.services.registry_service import ( RegistryService, get_registry_service, ) +from pyrit.backend.services.target_service import ( + TargetService, + get_target_service, +) __all__ = [ - "ConversationService", - "ConversationState", - "get_conversation_service", - "MemoryService", - "get_memory_service", + "AttackService", + "AttackState", + "get_attack_service", + "ConverterService", + "get_converter_service", "RegistryService", "get_registry_service", + "TargetService", + "get_target_service", ] diff --git a/pyrit/backend/services/attack_service.py b/pyrit/backend/services/attack_service.py new file mode 100644 index 0000000000..1c8ece7976 --- /dev/null +++ b/pyrit/backend/services/attack_service.py @@ -0,0 +1,458 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Attack service for managing attacks. + +All interactions are modeled as "attacks" - this is the attack-centric API design. +Handles attack lifecycle, message sending, prepended conversations, and scoring. +""" + +import uuid +from collections import defaultdict +from datetime import datetime, timezone +from typing import Any, Dict, List, Literal, Optional + +from pydantic import BaseModel + +from pyrit.backend.models.attacks import ( + AttackDetail, + AttackListResponse, + AttackSummary, + CreateAttackRequest, + CreateAttackResponse, + Message, + MessagePiece, + MessagePieceRequest, + PrependedMessageRequest, + Score, + SendMessageRequest, + SendMessageResponse, + UpdateAttackRequest, +) +from pyrit.backend.models.common import PaginationInfo +from pyrit.backend.services.converter_service import get_converter_service +from pyrit.backend.services.target_service import get_target_service +from pyrit.memory import CentralMemory + + +class AttackState(BaseModel): + """Internal state for an active attack.""" + + attack_id: str + name: Optional[str] = None + target_id: str + target_type: str + outcome: Optional[Literal["pending", "success", "failure"]] = None + prepended_conversation: List[Message] = [] + converter_ids: List[str] = [] + message_count: int = 0 + created_at: datetime + updated_at: datetime + + +class AttackService: + """Service for managing attacks.""" + + def __init__(self) -> None: + """Initialize the attack service.""" + self._memory = CentralMemory.get_memory_instance() + # Active attack states + self._attacks: Dict[str, AttackState] = {} + # Messages by attack ID (in-memory for now) + self._messages: Dict[str, List[Message]] = defaultdict(list) + + async def list_attacks( + self, + target_id: Optional[str] = None, + outcome: Optional[Literal["pending", "success", "failure"]] = None, + limit: int = 20, + cursor: Optional[str] = None, + ) -> AttackListResponse: + """ + List attacks with optional filtering and pagination. + + Args: + target_id: Filter by target instance ID + outcome: Filter by outcome + limit: Maximum items per page + cursor: Pagination cursor + + Returns: + AttackListResponse: Paginated list of attack summaries + """ + attacks = list(self._attacks.values()) + + # Apply filters + if target_id: + attacks = [a for a in attacks if a.target_id == target_id] + if outcome: + attacks = [a for a in attacks if a.outcome == outcome] + + # Sort by updated_at descending + attacks.sort(key=lambda a: a.updated_at, reverse=True) + + # Simple cursor-based pagination (cursor is the attack_id) + start_idx = 0 + if cursor: + for i, attack in enumerate(attacks): + if attack.attack_id == cursor: + start_idx = i + 1 + break + + page = attacks[start_idx : start_idx + limit] + has_more = len(attacks) > start_idx + limit + + summaries = [] + for attack in page: + messages = self._messages.get(attack.attack_id, []) + last_message_preview = None + if messages: + last_msg = messages[-1] + if last_msg.pieces: + preview_text = last_msg.pieces[0].converted_value + last_message_preview = preview_text[:100] + "..." if len(preview_text) > 100 else preview_text + + summaries.append( + AttackSummary( + attack_id=attack.attack_id, + name=attack.name, + target_id=attack.target_id, + target_type=attack.target_type, + outcome=attack.outcome, + last_message_preview=last_message_preview, + message_count=len(messages), + created_at=attack.created_at, + updated_at=attack.updated_at, + ) + ) + + next_cursor = page[-1].attack_id if has_more and page else None + + return AttackListResponse( + items=summaries, + pagination=PaginationInfo( + limit=limit, + has_more=has_more, + next_cursor=next_cursor, + prev_cursor=cursor, + ), + ) + + async def get_attack(self, attack_id: str) -> Optional[AttackDetail]: + """ + Get attack details including all messages. + + Args: + attack_id: Attack ID + + Returns: + AttackDetail or None if not found + """ + state = self._attacks.get(attack_id) + if not state: + return None + + messages = self._messages.get(attack_id, []) + + return AttackDetail( + attack_id=state.attack_id, + name=state.name, + target_id=state.target_id, + target_type=state.target_type, + outcome=state.outcome, + prepended_conversation=state.prepended_conversation, + messages=messages, + converter_ids=state.converter_ids, + created_at=state.created_at, + updated_at=state.updated_at, + ) + + async def create_attack( + self, + request: CreateAttackRequest, + ) -> CreateAttackResponse: + """ + Create a new attack. + + Args: + request: Attack creation request + + Returns: + CreateAttackResponse: Created attack details + """ + target_service = get_target_service() + + # Validate target exists + target_instance = await target_service.get_target(request.target_id) + if not target_instance: + raise ValueError(f"Target instance '{request.target_id}' not found") + + attack_id = str(uuid.uuid4()) + now = datetime.now(timezone.utc) + + # Convert prepended messages to Message format + prepended_messages: List[Message] = [] + if request.prepended_conversation: + for i, prep_msg in enumerate(request.prepended_conversation): + msg = Message( + message_id=str(uuid.uuid4()), + turn_number=0, # Prepended messages are turn 0 + role=prep_msg.role, + pieces=[ + MessagePiece( + piece_id=str(uuid.uuid4()), + data_type="text", + original_value=prep_msg.content, + converted_value=prep_msg.content, + scores=[], + ) + ], + created_at=now, + ) + prepended_messages.append(msg) + + # Validate converter IDs if provided + if request.converter_ids: + converter_service = get_converter_service() + for conv_id in request.converter_ids: + if await converter_service.get_converter(conv_id) is None: + raise ValueError(f"Converter instance '{conv_id}' not found") + + state = AttackState( + attack_id=attack_id, + name=request.name, + target_id=request.target_id, + target_type=target_instance.type, + outcome=None, + prepended_conversation=prepended_messages, + converter_ids=request.converter_ids or [], + message_count=0, + created_at=now, + updated_at=now, + ) + self._attacks[attack_id] = state + + return CreateAttackResponse( + attack_id=attack_id, + name=request.name, + target_id=request.target_id, + target_type=target_instance.type, + outcome=None, + prepended_conversation=prepended_messages, + created_at=now, + updated_at=now, + ) + + async def update_attack( + self, + attack_id: str, + request: UpdateAttackRequest, + ) -> Optional[AttackDetail]: + """ + Update an attack's outcome. + + Args: + attack_id: Attack ID + request: Update request with outcome + + Returns: + Updated AttackDetail or None if not found + """ + state = self._attacks.get(attack_id) + if not state: + return None + + state.outcome = request.outcome + state.updated_at = datetime.now(timezone.utc) + + return await self.get_attack(attack_id) + + async def send_message( + self, + attack_id: str, + request: SendMessageRequest, + ) -> SendMessageResponse: + """ + Send a message in an attack and get response. + + Args: + attack_id: Attack ID + request: Message send request + + Returns: + SendMessageResponse: User and assistant messages + """ + state = self._attacks.get(attack_id) + if not state: + raise ValueError(f"Attack '{attack_id}' not found") + + target_service = get_target_service() + converter_service = get_converter_service() + + target_obj = target_service.get_target_object(state.target_id) + if not target_obj: + raise ValueError(f"Target object for '{state.target_id}' not found") + + now = datetime.now(timezone.utc) + state.message_count += 1 + user_turn = state.message_count + + # Determine which converters to use + converters = [] + if request.converter_ids: + converters = converter_service.get_converter_objects_for_ids(request.converter_ids) + elif request.converters: + converters = converter_service.instantiate_inline_converters(request.converters) + elif state.converter_ids: + converters = converter_service.get_converter_objects_for_ids(state.converter_ids) + + # Build user message pieces + user_pieces: List[MessagePiece] = [] + for piece_req in request.pieces: + original_value = piece_req.content + converted_value = original_value + + # Apply converters + for converter in converters: + result = await converter.convert_async(prompt=converted_value) + converted_value = result.output_text + + user_pieces.append( + MessagePiece( + piece_id=str(uuid.uuid4()), + data_type=piece_req.data_type, + original_value=original_value, + original_value_mime_type=piece_req.mime_type, + converted_value=converted_value, + converted_value_mime_type=piece_req.mime_type, + scores=[], + ) + ) + + user_message = Message( + message_id=str(uuid.uuid4()), + turn_number=user_turn, + role="user", + pieces=user_pieces, + created_at=now, + ) + + # Store user message + self._messages[attack_id].append(user_message) + + # Build conversation for target (prepended + all messages) + from pyrit.models import Message as PyritMessage, MessagePiece as PyritMessagePiece + + # Create prompt pieces for target + user_prompt_pieces = [] + for piece in user_pieces: + pyrit_piece = PyritMessagePiece( + role="user", + original_value=piece.original_value or "", + original_value_data_type=piece.data_type, + converted_value=piece.converted_value, + converted_value_data_type=piece.data_type, + conversation_id=attack_id, + sequence=user_turn, + ) + user_prompt_pieces.append(pyrit_piece) + + user_pyrit_message = PyritMessage(user_prompt_pieces) + + # Send to target + response_messages = await target_obj.send_prompt_async(message=user_pyrit_message) + + # Build assistant response + state.message_count += 1 + assistant_turn = state.message_count + + assistant_pieces: List[MessagePiece] = [] + if response_messages: + for resp_msg in response_messages: + for resp_piece in resp_msg.message_pieces: + assistant_pieces.append( + MessagePiece( + piece_id=str(uuid.uuid4()), + data_type=resp_piece.converted_value_data_type or "text", + original_value=resp_piece.original_value, + converted_value=resp_piece.converted_value or "", + scores=[], + ) + ) + + assistant_message = Message( + message_id=str(uuid.uuid4()), + turn_number=assistant_turn, + role="assistant", + pieces=assistant_pieces if assistant_pieces else [ + MessagePiece( + piece_id=str(uuid.uuid4()), + data_type="text", + converted_value="", + scores=[], + ) + ], + created_at=datetime.now(timezone.utc), + ) + + # Store assistant message + self._messages[attack_id].append(assistant_message) + + # Update attack timestamp + state.updated_at = datetime.now(timezone.utc) + + # Build summary + messages = self._messages[attack_id] + last_message_preview = None + if messages: + last_msg = messages[-1] + if last_msg.pieces: + preview_text = last_msg.pieces[0].converted_value + last_message_preview = preview_text[:100] + "..." if len(preview_text) > 100 else preview_text + + attack_summary = AttackSummary( + attack_id=state.attack_id, + name=state.name, + target_id=state.target_id, + target_type=state.target_type, + outcome=state.outcome, + last_message_preview=last_message_preview, + message_count=len(messages), + created_at=state.created_at, + updated_at=state.updated_at, + ) + + return SendMessageResponse( + user_message=user_message, + assistant_message=assistant_message, + attack_summary=attack_summary, + ) + + async def delete_attack(self, attack_id: str) -> bool: + """ + Delete an attack. + + Args: + attack_id: Attack ID + + Returns: + True if deleted, False if not found + """ + if attack_id in self._attacks: + del self._attacks[attack_id] + self._messages.pop(attack_id, None) + return True + return False + + +# Global service instance +_attack_service: Optional[AttackService] = None + + +def get_attack_service() -> AttackService: + """Get the global attack service instance.""" + global _attack_service + if _attack_service is None: + _attack_service = AttackService() + return _attack_service diff --git a/pyrit/backend/services/conversation_service.py b/pyrit/backend/services/conversation_service.py deleted file mode 100644 index 24438de473..0000000000 --- a/pyrit/backend/services/conversation_service.py +++ /dev/null @@ -1,508 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Conversation service for managing interactive sessions. - -Handles conversation lifecycle, message sending, branching, and converter management. -""" - -import importlib -import uuid -from collections import defaultdict -from datetime import datetime -from typing import Any, Dict, List, Optional - -from pydantic import BaseModel - -from pyrit.backend.models.common import filter_sensitive_fields -from pyrit.backend.models.conversations import ( - BranchConversationRequest, - BranchConversationResponse, - ConverterConfig, - CreateConversationRequest, - CreateConversationResponse, - MessagePieceResponse, - MessageResponse, - SendMessageRequest, - SendMessageResponse, -) -from pyrit.memory import CentralMemory -from pyrit.models import Message, MessagePiece - - -class ConversationState(BaseModel): - """In-memory state for an active conversation.""" - - conversation_id: str - target_class: str - target_identifier: Dict[str, Any] - system_prompt: Optional[str] = None - labels: Optional[Dict[str, str]] = None - converters: List[ConverterConfig] = [] - created_at: datetime - message_count: int = 0 - - -class ConversationService: - """Service for managing conversation sessions.""" - - def __init__(self) -> None: - """Initialize the conversation service.""" - self._memory = CentralMemory.get_memory_instance() - # In-memory conversation state (for active sessions) - self._active_conversations: Dict[str, ConversationState] = {} - # Instantiated converters by conversation - self._converter_instances: Dict[str, List[Any]] = {} - # Instantiated targets by conversation - self._target_instances: Dict[str, Any] = {} - - def _instantiate_target_from_class(self, target_class: str, target_params: Optional[Dict[str, Any]]) -> Any: - """ - Instantiate a target from its class name. - - Args: - target_class: Target class name (e.g., 'TextTarget'). - target_params: Constructor parameters. - - Returns: - Instantiated target object. - """ - # Import the target class dynamically - module = importlib.import_module("pyrit.prompt_target") - cls = getattr(module, target_class, None) - - if cls is None: - raise ValueError(f"Target class '{target_class}' not found in pyrit.prompt_target") - - params = target_params or {} - return cls(**params) - - def _instantiate_converters(self, converter_configs: List[ConverterConfig]) -> List[Any]: - """ - Instantiate converters from their configurations. - - Args: - converter_configs: List of converter configurations. - - Returns: - List of instantiated converter objects. - """ - converters = [] - for config in converter_configs: - module = importlib.import_module(config.module) - converter_class = getattr(module, config.class_name) - params = config.params or {} - converter = converter_class(**params) - converters.append(converter) - - return converters - - async def create_conversation(self, request: CreateConversationRequest) -> CreateConversationResponse: - """ - Create a new conversation session. - - Args: - request: Conversation creation request. - - Returns: - Created conversation response with ID. - """ - conversation_id = str(uuid.uuid4()) - now = datetime.utcnow() - - # Instantiate the target - target = self._instantiate_target_from_class(request.target_class, request.target_params) - self._target_instances[conversation_id] = target - - # Get the target's identifier - target_identifier = target.get_identifier() if hasattr(target, "get_identifier") else {} - - # Store conversation state - state = ConversationState( - conversation_id=conversation_id, - target_class=request.target_class, - target_identifier=filter_sensitive_fields(target_identifier), - labels=request.labels, - converters=[], - created_at=now, - ) - self._active_conversations[conversation_id] = state - - return CreateConversationResponse( - conversation_id=conversation_id, - target_identifier=state.target_identifier, - labels=state.labels, - created_at=now, - ) - - async def get_conversation(self, conversation_id: str) -> Optional[ConversationState]: - """ - Get conversation state by ID. - - Returns: - Optional[ConversationState]: The conversation state or None if not found. - """ - return self._active_conversations.get(conversation_id) - - async def get_conversation_messages(self, conversation_id: str) -> List[MessageResponse]: - """ - Get all messages in a conversation. - - Args: - conversation_id: The conversation ID. - - Returns: - List of messages (grouped pieces) in order. - """ - pieces = self._memory.get_message_pieces(conversation_id=conversation_id) - - # Sort by sequence - pieces = sorted(pieces, key=lambda p: p.sequence) - - # Group pieces by sequence - by_sequence: Dict[int, List[Any]] = defaultdict(list) - for p in pieces: - by_sequence[p.sequence].append(p) - - messages = [] - for seq in sorted(by_sequence.keys()): - seq_pieces = by_sequence[seq] - if not seq_pieces: - continue - - first_piece = seq_pieces[0] - message_pieces = [ - MessagePieceResponse( - id=str(p.id) if hasattr(p, "id") and p.id else str(uuid.uuid4()), - original_value=p.original_value or "", - original_value_data_type=p.original_value_data_type, - converted_value=p.converted_value or "", - converted_value_data_type=p.converted_value_data_type, - converter_identifiers=p.converter_identifiers or [], - response_error=p.response_error if hasattr(p, "response_error") else None, - timestamp=p.timestamp, - ) - for p in seq_pieces - ] - - messages.append( - MessageResponse( - sequence=seq, - role=first_piece.role, - pieces=message_pieces, - timestamp=first_piece.timestamp, - ) - ) - - return messages - - async def send_message( - self, - conversation_id: str, - request: SendMessageRequest, - ) -> SendMessageResponse: - """ - Send a message to the target and get a response. - - This is a simplified stub - real implementation would involve - creating MessagePiece objects, applying converters, and calling target. - - Args: - conversation_id: The conversation ID. - request: Message send request. - - Returns: - Response containing sent and received messages. - """ - state = self._active_conversations.get(conversation_id) - if not state: - raise ValueError(f"Conversation {conversation_id} not found") - - target = self._target_instances.get(conversation_id) - if not target: - raise ValueError(f"Target for conversation {conversation_id} not found") - - now = datetime.utcnow() - state.message_count += 1 - user_seq = state.message_count - - # Get converters if any - converters = self._converter_instances.get(conversation_id, []) - - # Build user message pieces - user_pieces_response = [] - user_piece_objs = [] - - for piece_input in request.pieces: - original_value = piece_input.original_value or "" - original_type = piece_input.original_value_data_type - converted_value = piece_input.converted_value or original_value - converted_type = piece_input.converted_value_data_type or original_type - converter_ids = piece_input.converter_identifiers or [] - - # Apply converters if not pre-converted - if not request.pre_converted and converters: - for converter in converters: - result = await converter.convert_async(prompt=converted_value) - converted_value = result.output_text - converted_type = result.output_type - converter_ids.append(converter.get_identifier()) - - piece_id = str(uuid.uuid4()) - user_pieces_response.append( - MessagePieceResponse( - id=piece_id, - original_value=original_value, - original_value_data_type=original_type, - converted_value=converted_value, - converted_value_data_type=converted_type, - converter_identifiers=converter_ids, - response_error=None, - timestamp=now, - ) - ) - - # Create actual MessagePiece for target - user_piece_objs.append( - MessagePiece( - role="user", - original_value=original_value, - original_value_data_type=original_type, - converted_value=converted_value, - converted_value_data_type=converted_type, - converter_identifiers=converter_ids if converter_ids else None, - prompt_target_identifier=target.get_identifier(), - conversation_id=conversation_id, - sequence=user_seq, - ) - ) - - user_message_response = MessageResponse( - sequence=user_seq, - role="user", - pieces=user_pieces_response, - timestamp=now, - ) - - # Send to target - user_message_obj = Message(user_piece_objs) - response_messages = await target.send_prompt_async(message=user_message_obj) - - # Build assistant response - assistant_message_response = None - if response_messages: - state.message_count += 1 - assistant_seq = state.message_count - - assistant_pieces = [] - for resp_message in response_messages: - for resp_piece in resp_message.message_pieces: - assistant_pieces.append( - MessagePieceResponse( - id=str(resp_piece.id) if hasattr(resp_piece, "id") else str(uuid.uuid4()), - original_value=resp_piece.original_value or "", - original_value_data_type=resp_piece.original_value_data_type, - converted_value=resp_piece.converted_value or "", - converted_value_data_type=resp_piece.converted_value_data_type, - converter_identifiers=resp_piece.converter_identifiers or [], - response_error=getattr(resp_piece, "response_error", None), - timestamp=resp_piece.timestamp, - ) - ) - - if assistant_pieces: - assistant_message_response = MessageResponse( - sequence=assistant_seq, - role="assistant", - pieces=assistant_pieces, - timestamp=now, - ) - - return SendMessageResponse( - user_message=user_message_response, - assistant_message=assistant_message_response, - ) - - async def update_system_prompt(self, conversation_id: str, system_prompt: str) -> None: - """ - Update the system prompt for a conversation. - - Args: - conversation_id: The conversation ID. - system_prompt: New system prompt. - """ - state = self._active_conversations.get(conversation_id) - if not state: - raise ValueError(f"Conversation {conversation_id} not found") - - target = self._target_instances.get(conversation_id) - if not target: - raise ValueError(f"Target for conversation {conversation_id} not found") - - # Update target system prompt - target.set_system_prompt( - system_prompt=system_prompt, - conversation_id=conversation_id, - ) - - # Update state - state.system_prompt = system_prompt - - async def update_converters(self, conversation_id: str, converters: List[ConverterConfig]) -> None: - """ - Update the converters for a conversation. - - Args: - conversation_id: The conversation ID. - converters: New converter configurations. - """ - state = self._active_conversations.get(conversation_id) - if not state: - raise ValueError(f"Conversation {conversation_id} not found") - - # Instantiate new converters - converter_instances = self._instantiate_converters(converters) - self._converter_instances[conversation_id] = converter_instances - - # Update state - state.converters = converters - - async def branch_conversation( - self, - conversation_id: str, - request: BranchConversationRequest, - ) -> BranchConversationResponse: - """ - Branch a conversation from a specific point. - - Args: - conversation_id: The source conversation ID. - request: Branch request with last_included_sequence. - - Returns: - New conversation with copied messages. - """ - state = self._active_conversations.get(conversation_id) - if not state: - raise ValueError(f"Conversation {conversation_id} not found") - - # Get messages up to branch point - all_messages = await self.get_conversation_messages(conversation_id) - messages_to_copy = [m for m in all_messages if m.sequence <= request.last_included_sequence] - - if not messages_to_copy: - raise ValueError(f"No messages found at or before sequence {request.last_included_sequence}") - - # Create new conversation with same target and converters - new_conversation_id = str(uuid.uuid4()) - now = datetime.utcnow() - - # Copy target instance - original_target = self._target_instances.get(conversation_id) - if original_target: - # Create new target instance with same config - new_target = self._instantiate_target_from_class(state.target_class, None) - self._target_instances[new_conversation_id] = new_target - - # Copy converters - if state.converters: - self._converter_instances[new_conversation_id] = self._instantiate_converters(state.converters) - - # Create new state - new_state = ConversationState( - conversation_id=new_conversation_id, - target_class=state.target_class, - target_identifier=state.target_identifier, - labels=state.labels, - converters=state.converters, - created_at=now, - message_count=len(messages_to_copy), - ) - self._active_conversations[new_conversation_id] = new_state - - # Copy messages to memory with new conversation ID - for msg in messages_to_copy: - for piece in msg.pieces: - new_piece = MessagePiece( - role=msg.role, - original_value=piece.original_value, - original_value_data_type=piece.original_value_data_type, - converted_value=piece.converted_value, - converted_value_data_type=piece.converted_value_data_type, - converter_identifiers=piece.converter_identifiers if piece.converter_identifiers else None, - conversation_id=new_conversation_id, - sequence=msg.sequence, - ) - self._memory.add_message_pieces_to_memory(message_pieces=[new_piece]) - - return BranchConversationResponse( - conversation_id=new_conversation_id, - branched_from={ - "conversation_id": conversation_id, - "last_included_sequence": request.last_included_sequence, - }, - message_count=len(messages_to_copy), - created_at=now, - ) - - async def preview_converters( - self, - text: str, - converters: List[ConverterConfig], - ) -> List[Dict[str, Any]]: - """ - Preview text through a converter pipeline. - - Args: - text: Input text to convert. - converters: Converter configurations to apply. - - Returns: - List of conversion steps showing intermediate results. - """ - converter_instances = self._instantiate_converters(converters) - - steps = [] - current_text = text - - for i, converter in enumerate(converter_instances): - config = converters[i] - result = await converter.convert_async(prompt=current_text) - - steps.append( - { - "step": i + 1, - "converter_class": config.class_name, - "input": current_text, - "output": result.output_text, - "output_type": result.output_type, - } - ) - - current_text = result.output_text - - return steps - - def cleanup_conversation(self, conversation_id: str) -> None: - """Clean up resources for a conversation.""" - self._active_conversations.pop(conversation_id, None) - self._converter_instances.pop(conversation_id, None) - self._target_instances.pop(conversation_id, None) - - -# Singleton instance -_conversation_service: Optional[ConversationService] = None - - -def get_conversation_service() -> ConversationService: - """ - Get the conversation service singleton. - - Returns: - ConversationService: The conversation service instance. - """ - global _conversation_service - if _conversation_service is None: - _conversation_service = ConversationService() - return _conversation_service diff --git a/pyrit/backend/services/converter_service.py b/pyrit/backend/services/converter_service.py new file mode 100644 index 0000000000..af3e96133b --- /dev/null +++ b/pyrit/backend/services/converter_service.py @@ -0,0 +1,344 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Converter service for managing converter instances. + +Handles creation, retrieval, and nested converter support. +""" + +import importlib +import uuid +from datetime import datetime, timezone +from typing import Any, Dict, List, Literal, Optional, Tuple + +from pyrit.backend.models.converters import ( + ConverterInstance, + ConverterInstanceListResponse, + ConverterPreviewRequest, + ConverterPreviewResponse, + CreateConverterRequest, + CreateConverterResponse, + InlineConverterConfig, + PreviewStep, +) +from pyrit.models import PromptDataType + + +class ConverterService: + """Service for managing converter instances.""" + + def __init__(self) -> None: + """Initialize the converter service.""" + # In-memory storage for converter instances + self._instances: Dict[str, ConverterInstance] = {} + # Actual instantiated converter objects + self._converter_objects: Dict[str, Any] = {} + + def _get_converter_class(self, converter_type: str) -> type: + """ + Get the converter class for a given type. + + Args: + converter_type: Converter type string (e.g., 'base64', 'Base64Converter') + + Returns: + The converter class + """ + module = importlib.import_module("pyrit.prompt_converter") + + # Try direct attribute lookup first + cls = getattr(module, converter_type, None) + if cls is not None: + return cls + + # Try common class name patterns + class_name_patterns = [ + converter_type, + f"{converter_type}Converter", + "".join(word.capitalize() for word in converter_type.split("_")), + "".join(word.capitalize() for word in converter_type.split("_")) + "Converter", + ] + + for pattern in class_name_patterns: + cls = getattr(module, pattern, None) + if cls is not None: + return cls + + raise ValueError(f"Converter type '{converter_type}' not found in pyrit.prompt_converter") + + def _create_converter_recursive( + self, + config: Dict[str, Any], + source: Literal["initializer", "user"], + ) -> Tuple[str, Any, List[ConverterInstance]]: + """ + Recursively create converters, handling nested converter params. + + Args: + config: Converter configuration with 'type' and 'params' + source: Source of creation + + Returns: + Tuple of (converter_id, converter_object, list of all created instances) + """ + converter_type = config["type"] + params = dict(config.get("params", {})) + created_instances: List[ConverterInstance] = [] + + # Check for nested converter in params + if "converter" in params and isinstance(params["converter"], dict): + nested_config = params["converter"] + if "type" in nested_config: + # Recursively create nested converter + nested_id, nested_obj, nested_instances = self._create_converter_recursive( + nested_config, source + ) + created_instances.extend(nested_instances) + # Replace inline config with the actual converter object + params["converter"] = nested_obj + + # Create this converter + converter_class = self._get_converter_class(converter_type) + converter_obj = converter_class(**params) + + converter_id = str(uuid.uuid4()) + now = datetime.now(timezone.utc) + + # Store the converter object + self._converter_objects[converter_id] = converter_obj + + # Build resolved params (with nested converter IDs instead of objects) + resolved_params = dict(config.get("params", {})) + if "converter" in resolved_params and isinstance(resolved_params["converter"], dict): + # Replace with the nested converter ID + nested_id = created_instances[-1].converter_id if created_instances else None + resolved_params["converter"] = {"converter_id": nested_id} + + instance = ConverterInstance( + converter_id=converter_id, + type=converter_type, + display_name=None, + params=resolved_params, + created_at=now, + source=source, + ) + self._instances[converter_id] = instance + created_instances.append(instance) + + return converter_id, converter_obj, created_instances + + async def list_converters( + self, + source: Optional[Literal["initializer", "user"]] = None, + ) -> ConverterInstanceListResponse: + """ + List all converter instances. + + Args: + source: Optional filter by source + + Returns: + ConverterInstanceListResponse: List of converter instances + """ + items = list(self._instances.values()) + + if source is not None: + items = [c for c in items if c.source == source] + + return ConverterInstanceListResponse(items=items) + + async def get_converter(self, converter_id: str) -> Optional[ConverterInstance]: + """ + Get a converter instance by ID. + + Args: + converter_id: Converter instance ID + + Returns: + ConverterInstance or None if not found + """ + return self._instances.get(converter_id) + + def get_converter_object(self, converter_id: str) -> Optional[Any]: + """ + Get the actual converter object. + + Args: + converter_id: Converter instance ID + + Returns: + The instantiated converter object or None + """ + return self._converter_objects.get(converter_id) + + async def create_converter( + self, + request: CreateConverterRequest, + ) -> CreateConverterResponse: + """ + Create a new converter instance. + + Supports nested converters - if params contains a 'converter' key with + a type/params dict, the nested converter will be created first. + + Args: + request: Converter creation request + + Returns: + CreateConverterResponse: Created converter details + """ + config = { + "type": request.type, + "params": request.params, + } + + converter_id, converter_obj, created_instances = self._create_converter_recursive( + config, "user" + ) + + # Update display name for the outermost converter + if request.display_name and converter_id in self._instances: + self._instances[converter_id].display_name = request.display_name + + outer_instance = self._instances[converter_id] + + return CreateConverterResponse( + converter_id=converter_id, + type=request.type, + display_name=request.display_name, + params=outer_instance.params, + created_converters=created_instances if len(created_instances) > 1 else None, + created_at=outer_instance.created_at, + source="user", + ) + + async def delete_converter(self, converter_id: str) -> bool: + """ + Delete a converter instance. + + Args: + converter_id: Converter instance ID + + Returns: + True if deleted, False if not found + """ + if converter_id in self._instances: + del self._instances[converter_id] + self._converter_objects.pop(converter_id, None) + return True + return False + + async def preview_conversion( + self, + request: ConverterPreviewRequest, + ) -> ConverterPreviewResponse: + """ + Preview conversion through a converter pipeline. + + Args: + request: Preview request with content and converters + + Returns: + ConverterPreviewResponse: Conversion results with steps + """ + current_value = request.original_value + current_type: PromptDataType = request.original_value_data_type + steps: List[PreviewStep] = [] + + # Get converters to apply + converters_to_apply: List[Tuple[Optional[str], str, Any]] = [] + + if request.converter_ids: + for conv_id in request.converter_ids: + conv_obj = self.get_converter_object(conv_id) + if conv_obj is None: + raise ValueError(f"Converter instance '{conv_id}' not found") + instance = self._instances[conv_id] + converters_to_apply.append((conv_id, instance.type, conv_obj)) + + if request.converters: + for inline_config in request.converters: + converter_class = self._get_converter_class(inline_config.type) + conv_obj = converter_class(**inline_config.params) + converters_to_apply.append((None, inline_config.type, conv_obj)) + + # Apply converters in sequence + for conv_id, conv_type, conv_obj in converters_to_apply: + input_value = current_value + input_type = current_type + + result = await conv_obj.convert_async(prompt=current_value) + current_value = result.output_text + current_type = result.output_type + + steps.append( + PreviewStep( + converter_id=conv_id, + converter_type=conv_type, + input_value=input_value, + input_data_type=input_type, + output_value=current_value, + output_data_type=current_type, + ) + ) + + return ConverterPreviewResponse( + original_value=request.original_value, + original_value_data_type=request.original_value_data_type, + converted_value=current_value, + converted_value_data_type=current_type, + steps=steps, + ) + + def get_converter_objects_for_ids(self, converter_ids: List[str]) -> List[Any]: + """ + Get converter objects for a list of IDs. + + Args: + converter_ids: List of converter instance IDs + + Returns: + List of converter objects + + Raises: + ValueError: If any converter ID is not found + """ + converters = [] + for conv_id in converter_ids: + conv_obj = self.get_converter_object(conv_id) + if conv_obj is None: + raise ValueError(f"Converter instance '{conv_id}' not found") + converters.append(conv_obj) + return converters + + def instantiate_inline_converters( + self, configs: List[InlineConverterConfig] + ) -> List[Any]: + """ + Instantiate converters from inline configurations. + + Args: + configs: List of inline converter configs + + Returns: + List of converter objects + """ + converters = [] + for config in configs: + converter_class = self._get_converter_class(config.type) + conv_obj = converter_class(**config.params) + converters.append(conv_obj) + return converters + + +# Global service instance +_converter_service: Optional[ConverterService] = None + + +def get_converter_service() -> ConverterService: + """Get the global converter service instance.""" + global _converter_service + if _converter_service is None: + _converter_service = ConverterService() + return _converter_service diff --git a/pyrit/backend/services/memory_service.py b/pyrit/backend/services/memory_service.py deleted file mode 100644 index 8d9cbeab02..0000000000 --- a/pyrit/backend/services/memory_service.py +++ /dev/null @@ -1,444 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Memory service for API access to stored data. - -Wraps CentralMemory with pagination and filtering for API endpoints. -""" - -from datetime import datetime -from typing import Any, Dict, List, Optional, Tuple - -from pyrit.backend.models.common import PaginatedResponse, PaginationInfo, filter_sensitive_fields -from pyrit.backend.models.memory import ( - AttackResultQueryResponse, - MessageQueryResponse, - ScenarioResultQueryResponse, - ScoreQueryResponse, - SeedQueryResponse, -) -from pyrit.memory import CentralMemory -from pyrit.models.seeds import SeedObjective, SeedSimulatedConversation - - -def _parse_cursor(cursor: Optional[str]) -> Tuple[Optional[datetime], Optional[str]]: - """ - Parse a cursor string into timestamp and ID components. - - Cursor format: {ISO8601_timestamp}_{record_id} - - Returns: - Tuple[Optional[datetime], Optional[str]]: Parsed timestamp and record ID. - """ - if not cursor: - return None, None - - try: - parts = cursor.rsplit("_", 1) - if len(parts) != 2: - return None, None - timestamp_str, record_id = parts - timestamp = datetime.fromisoformat(timestamp_str.replace("Z", "+00:00")) - return timestamp, record_id - except (ValueError, AttributeError): - return None, None - - -def _build_cursor(timestamp: datetime, record_id: str) -> str: - """ - Build a cursor string from timestamp and ID. - - Returns: - str: Cursor string for pagination. - """ - return f"{timestamp.isoformat()}_{record_id}" - - -class MemoryService: - """Service for querying memory with pagination support.""" - - def __init__(self) -> None: - """Initialize the memory service.""" - self._memory = CentralMemory.get_memory_instance() - - async def get_messages( - self, - *, - conversation_id: Optional[str] = None, - role: Optional[str] = None, - labels: Optional[Dict[str, str]] = None, - harm_categories: Optional[List[str]] = None, - data_type: Optional[str] = None, - response_error: Optional[str] = None, - start_time: Optional[datetime] = None, - end_time: Optional[datetime] = None, - limit: int = 50, - cursor: Optional[str] = None, - ) -> PaginatedResponse[MessageQueryResponse]: - """ - Query message pieces with pagination. - - Args: - conversation_id: Filter by conversation. - role: Filter by message role. - labels: Filter by labels. - harm_categories: Filter by harm categories (not supported in current API). - data_type: Filter by data type. - response_error: Filter by response error type (not supported in current API). - start_time: Messages after this time. - end_time: Messages before this time. - limit: Maximum results per page. - cursor: Pagination cursor. - - Returns: - Paginated list of messages. - """ - # Parse cursor for pagination - cursor_time, cursor_id = _parse_cursor(cursor) - - # Query memory - use supported parameters only - pieces = self._memory.get_message_pieces( - conversation_id=conversation_id, - role=role, - labels=labels, - data_type=data_type, - sent_after=cursor_time or start_time, - sent_before=end_time, - ) - - # Apply start_time filter if provided and no cursor - if start_time and not cursor_time: - pieces = [p for p in pieces if p.timestamp and p.timestamp >= start_time] - - # Sort by timestamp descending - pieces = sorted(pieces, key=lambda p: p.timestamp or datetime.min, reverse=True) - - # Apply limit + 1 to check for more - has_more = len(pieces) > limit - pieces = pieces[:limit] - - # Build response items - items = [] - for piece in pieces: - items.append( - MessageQueryResponse( - id=str(piece.id), - conversation_id=piece.conversation_id, - sequence=piece.sequence, - role=piece.role, - original_value=piece.original_value, - original_value_data_type=piece.original_value_data_type, - converted_value=piece.converted_value, - converted_value_data_type=piece.converted_value_data_type, - converter_identifiers=piece.converter_identifiers or [], - target_identifier=filter_sensitive_fields(piece.prompt_target_identifier or {}), - labels=piece.labels, - response_error=piece.response_error, - timestamp=piece.timestamp, - ) - ) - - # Build pagination info - next_cursor = None - if has_more and pieces: - last_piece = pieces[-1] - next_cursor = _build_cursor(last_piece.timestamp, str(last_piece.id)) - - prev_cursor = None - if cursor and pieces: - first_piece = pieces[0] - prev_cursor = _build_cursor(first_piece.timestamp, str(first_piece.id)) - - return PaginatedResponse( - items=items, - pagination=PaginationInfo( - limit=limit, - has_more=has_more, - next_cursor=next_cursor, - prev_cursor=prev_cursor, - ), - ) - - async def get_scores( - self, - *, - message_id: Optional[str] = None, - conversation_id: Optional[str] = None, - score_type: Optional[str] = None, - scorer_type: Optional[str] = None, - start_time: Optional[datetime] = None, - end_time: Optional[datetime] = None, - limit: int = 50, - cursor: Optional[str] = None, - ) -> PaginatedResponse[ScoreQueryResponse]: - """ - Query scores with pagination. - - Returns: - PaginatedResponse[ScoreQueryResponse]: Paginated list of scores. - """ - scores = self._memory.get_scores( - score_type=score_type, - ) - - # Apply additional filters - if message_id: - scores = [s for s in scores if str(s.message_piece_id) == message_id] - - if scorer_type: - scores = [ - s - for s in scores - if s.scorer_class_identifier and s.scorer_class_identifier.get("__type__") == scorer_type - ] - - if start_time: - scores = [s for s in scores if s.timestamp and s.timestamp >= start_time] - if end_time: - scores = [s for s in scores if s.timestamp and s.timestamp <= end_time] - - # Sort and paginate - scores = sorted(scores, key=lambda s: s.timestamp or datetime.min, reverse=True) - has_more = len(scores) > limit - scores = scores[:limit] - - items = [] - for score in scores: - items.append( - ScoreQueryResponse( - id=str(score.id), - message_piece_id=str(score.message_piece_id), - score_value=score.score_value, - score_value_description=score.score_value_description or "", - score_type=score.score_type, - score_category=score.score_category, - score_rationale=score.score_rationale or "", - scorer_identifier=filter_sensitive_fields(score.scorer_class_identifier or {}), - objective=score.objective, - timestamp=score.timestamp, - ) - ) - - next_cursor = None - if has_more and scores: - last = scores[-1] - next_cursor = _build_cursor(last.timestamp, str(last.id)) - - return PaginatedResponse( - items=items, - pagination=PaginationInfo(limit=limit, has_more=has_more, next_cursor=next_cursor, prev_cursor=None), - ) - - async def get_attack_results( - self, - *, - conversation_id: Optional[str] = None, - outcome: Optional[str] = None, - attack_type: Optional[str] = None, - objective: Optional[str] = None, - min_turns: Optional[int] = None, - max_turns: Optional[int] = None, - start_time: Optional[datetime] = None, - end_time: Optional[datetime] = None, - limit: int = 50, - cursor: Optional[str] = None, - ) -> PaginatedResponse[AttackResultQueryResponse]: - """ - Query attack results with pagination. - - Returns: - PaginatedResponse[AttackResultQueryResponse]: Paginated list of attack results. - """ - results = self._memory.get_attack_results( - conversation_id=conversation_id, - outcome=outcome, - objective=objective, - ) - - # Apply additional filters - if attack_type: - results = [r for r in results if r.attack_identifier and r.attack_identifier.get("__type__") == attack_type] - - if min_turns: - results = [r for r in results if r.executed_turns >= min_turns] - if max_turns: - results = [r for r in results if r.executed_turns <= max_turns] - - # Note: AttackResult doesn't have timestamp field - skip time filtering - # Sort by executed_turns as a proxy for recency - results_list = list(results) - has_more = len(results_list) > limit - results_list = results_list[:limit] - - items = [] - for result in results_list: - items.append( - AttackResultQueryResponse( - id=result.conversation_id, # Use conversation_id as identifier - conversation_id=result.conversation_id, - objective=result.objective, - attack_identifier=filter_sensitive_fields(result.attack_identifier or {}), - outcome=str(result.outcome.value) if result.outcome else None, - outcome_reason=result.outcome_reason, - executed_turns=result.executed_turns, - execution_time_ms=result.execution_time_ms, - timestamp=None, # AttackResult doesn't have timestamp - ) - ) - - # No cursor-based pagination available without timestamps - return PaginatedResponse( - items=items, - pagination=PaginationInfo(limit=limit, has_more=has_more, next_cursor=None, prev_cursor=None), - ) - - async def get_scenario_results( - self, - *, - scenario_name: Optional[str] = None, - run_state: Optional[str] = None, - labels: Optional[Dict[str, str]] = None, - start_time: Optional[datetime] = None, - end_time: Optional[datetime] = None, - limit: int = 50, - cursor: Optional[str] = None, - ) -> PaginatedResponse[ScenarioResultQueryResponse]: - """ - Query scenario results with pagination. - - Returns: - PaginatedResponse[ScenarioResultQueryResponse]: Paginated list of scenario results. - """ - results = self._memory.get_scenario_results( - scenario_name=scenario_name, - labels=labels, - added_after=start_time, - added_before=end_time, - ) - - # Apply run_state filter if provided (not directly supported in API) - if run_state: - results = [r for r in results if r.scenario_run_state == run_state] - - # Sort by completion_time descending - results_list = list(results) - results_list = sorted(results_list, key=lambda r: r.completion_time or datetime.min, reverse=True) - has_more = len(results_list) > limit - results_list = results_list[:limit] - - items = [] - for result in results_list: - items.append( - ScenarioResultQueryResponse( - id=str(result.id), - scenario_name=result.scenario_identifier.name if result.scenario_identifier else "", - scenario_description=result.scenario_identifier.description if result.scenario_identifier else "", - scenario_version=result.scenario_identifier.version if result.scenario_identifier else 0, - pyrit_version=result.scenario_identifier.pyrit_version if result.scenario_identifier else "", - run_state=result.scenario_run_state, - objective_target_identifier=filter_sensitive_fields(result.objective_target_identifier or {}), - labels=result.labels, - number_tries=result.number_tries, - completion_time=result.completion_time, - timestamp=result.completion_time, # Use completion_time as timestamp - ) - ) - - next_cursor = None - if has_more and results_list: - last = results_list[-1] - next_cursor = _build_cursor(last.completion_time or datetime.min, str(last.id)) - - return PaginatedResponse( - items=items, - pagination=PaginationInfo(limit=limit, has_more=has_more, next_cursor=next_cursor, prev_cursor=None), - ) - - async def get_seeds( - self, - *, - dataset_name: Optional[str] = None, - seed_type: Optional[str] = None, - harm_categories: Optional[List[str]] = None, - data_type: Optional[str] = None, - search: Optional[str] = None, - limit: int = 50, - cursor: Optional[str] = None, - ) -> PaginatedResponse[SeedQueryResponse]: - """ - Query seeds with pagination. - - Returns: - PaginatedResponse[SeedQueryResponse]: Paginated list of seeds. - """ - # Build query params - seed_type needs conversion to SeedType - query_params: Dict[str, Any] = { - "dataset_name": dataset_name, - "harm_categories": harm_categories, - } - if seed_type: - query_params["seed_type"] = seed_type - if data_type: - query_params["data_types"] = [data_type] - if search: - query_params["value"] = search - - seeds = self._memory.get_seeds(**query_params) - - # Sort by date_added descending - seeds_list = sorted(list(seeds), key=lambda s: s.date_added or datetime.min, reverse=True) - has_more = len(seeds_list) > limit - seeds_list = seeds_list[:limit] - - items = [] - for seed in seeds_list: - # Determine seed_type based on class - if isinstance(seed, SeedObjective): - determined_seed_type = "objective" - elif isinstance(seed, SeedSimulatedConversation): - determined_seed_type = "simulated_conversation" - else: - determined_seed_type = "prompt" - - items.append( - SeedQueryResponse( - id=str(seed.id), - value=seed.value, - data_type=seed.data_type, - name=seed.name, - dataset_name=seed.dataset_name, - seed_type=determined_seed_type, # type: ignore - harm_categories=list(seed.harm_categories) if seed.harm_categories else None, - description=seed.description, - source=seed.source, - date_added=seed.date_added, - ) - ) - - next_cursor = None - if has_more and seeds: - last = seeds[-1] - next_cursor = _build_cursor(last.date_added or datetime.min, str(last.id)) - - return PaginatedResponse( - items=items, - pagination=PaginationInfo(limit=limit, has_more=has_more, next_cursor=next_cursor, prev_cursor=None), - ) - - -# Singleton instance -_memory_service: Optional[MemoryService] = None - - -def get_memory_service() -> MemoryService: - """ - Get the memory service singleton. - - Returns: - MemoryService: The memory service instance. - """ - global _memory_service - if _memory_service is None: - _memory_service = MemoryService() - return _memory_service diff --git a/pyrit/backend/services/target_service.py b/pyrit/backend/services/target_service.py new file mode 100644 index 0000000000..b58ef93985 --- /dev/null +++ b/pyrit/backend/services/target_service.py @@ -0,0 +1,222 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Target service for managing target instances. + +Handles creation, retrieval, and lifecycle of runtime target instances. +""" + +import importlib +import uuid +from datetime import datetime, timezone +from typing import Any, Dict, List, Literal, Optional + +from pyrit.backend.models.common import filter_sensitive_fields +from pyrit.backend.models.targets import ( + CreateTargetRequest, + CreateTargetResponse, + TargetInstance, + TargetListResponse, +) + + +class TargetService: + """Service for managing target instances.""" + + def __init__(self) -> None: + """Initialize the target service.""" + # In-memory storage for target instances + self._instances: Dict[str, TargetInstance] = {} + # Actual instantiated target objects (not serializable) + self._target_objects: Dict[str, Any] = {} + + def _get_target_class(self, target_type: str) -> type: + """ + Get the target class for a given type. + + Args: + target_type: Target type string (e.g., 'azure_openai', 'TextTarget') + + Returns: + The target class + """ + # Try to import from pyrit.prompt_target + module = importlib.import_module("pyrit.prompt_target") + + # Handle both snake_case and PascalCase + # First try direct attribute lookup + cls = getattr(module, target_type, None) + if cls is not None: + return cls + + # Try common class name patterns + class_name_patterns = [ + target_type, + f"{target_type}Target", + "".join(word.capitalize() for word in target_type.split("_")), # snake_case to PascalCase + "".join(word.capitalize() for word in target_type.split("_")) + "Target", + ] + + for pattern in class_name_patterns: + cls = getattr(module, pattern, None) + if cls is not None: + return cls + + raise ValueError(f"Target type '{target_type}' not found in pyrit.prompt_target") + + async def list_targets( + self, + source: Optional[Literal["initializer", "user"]] = None, + ) -> TargetListResponse: + """ + List all target instances. + + Args: + source: Optional filter by source ("initializer" or "user") + + Returns: + TargetListResponse: List of target instances + """ + items = list(self._instances.values()) + + if source is not None: + items = [t for t in items if t.source == source] + + return TargetListResponse(items=items) + + async def get_target(self, target_id: str) -> Optional[TargetInstance]: + """ + Get a target instance by ID. + + Args: + target_id: Target instance ID + + Returns: + TargetInstance or None if not found + """ + return self._instances.get(target_id) + + def get_target_object(self, target_id: str) -> Optional[Any]: + """ + Get the actual target object for use in attacks. + + Args: + target_id: Target instance ID + + Returns: + The instantiated target object or None if not found + """ + return self._target_objects.get(target_id) + + async def create_target( + self, + request: CreateTargetRequest, + ) -> CreateTargetResponse: + """ + Create a new target instance. + + Args: + request: Target creation request + + Returns: + CreateTargetResponse: Created target details + """ + target_id = str(uuid.uuid4()) + now = datetime.now(timezone.utc) + + # Get the target class and instantiate + target_class = self._get_target_class(request.type) + target_obj = target_class(**request.params) + self._target_objects[target_id] = target_obj + + # Get filtered params from target identifier + target_identifier = target_obj.get_identifier() if hasattr(target_obj, "get_identifier") else {} + filtered_params = filter_sensitive_fields(target_identifier) + + # Store the target instance metadata + instance = TargetInstance( + target_id=target_id, + type=request.type, + display_name=request.display_name, + params=filtered_params, + created_at=now, + source="user", + ) + self._instances[target_id] = instance + + return CreateTargetResponse( + target_id=target_id, + type=request.type, + display_name=request.display_name, + params=filtered_params, + created_at=now, + source="user", + ) + + async def delete_target(self, target_id: str) -> bool: + """ + Delete a target instance. + + Args: + target_id: Target instance ID + + Returns: + True if deleted, False if not found + """ + if target_id in self._instances: + del self._instances[target_id] + self._target_objects.pop(target_id, None) + return True + return False + + async def register_initializer_target( + self, + target_type: str, + target_obj: Any, + display_name: Optional[str] = None, + ) -> TargetInstance: + """ + Register a target from an initializer (not user-created). + + Args: + target_type: Target type string + target_obj: Already-instantiated target object + display_name: Optional display name + + Returns: + TargetInstance: The registered target + """ + target_id = str(uuid.uuid4()) + now = datetime.now(timezone.utc) + + # Store the target object + self._target_objects[target_id] = target_obj + + # Get filtered params from target identifier + target_identifier = target_obj.get_identifier() if hasattr(target_obj, "get_identifier") else {} + filtered_params = filter_sensitive_fields(target_identifier) + + instance = TargetInstance( + target_id=target_id, + type=target_type, + display_name=display_name, + params=filtered_params, + created_at=now, + source="initializer", + ) + self._instances[target_id] = instance + + return instance + + +# Global service instance +_target_service: Optional[TargetService] = None + + +def get_target_service() -> TargetService: + """Get the global target service instance.""" + global _target_service + if _target_service is None: + _target_service = TargetService() + return _target_service diff --git a/tests/unit/backend/test_api_routes.py b/tests/unit/backend/test_api_routes.py new file mode 100644 index 0000000000..12534cc891 --- /dev/null +++ b/tests/unit/backend/test_api_routes.py @@ -0,0 +1,602 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tests for backend API routes. +""" + +from datetime import datetime, timezone +from typing import List +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import status +from fastapi.testclient import TestClient + +from pyrit.backend.main import app +from pyrit.backend.models.attacks import ( + AttackDetail, + AttackListResponse, + AttackSummary, + CreateAttackResponse, + Message, + MessagePiece, + SendMessageResponse, +) +from pyrit.backend.models.common import PaginationInfo +from pyrit.backend.models.converters import ( + ConverterInstance, + ConverterInstanceListResponse, + ConverterPreviewResponse, + CreateConverterResponse, + PreviewStep, +) +from pyrit.backend.models.targets import ( + CreateTargetResponse, + TargetInstance, + TargetListResponse, +) + + +@pytest.fixture +def client() -> TestClient: + """Create a test client for the FastAPI app.""" + return TestClient(app) + + +# ============================================================================ +# Attack Routes Tests +# ============================================================================ + + +class TestAttackRoutes: + """Tests for attack API routes.""" + + def test_list_attacks_returns_empty_list(self, client: TestClient) -> None: + """Test that list attacks returns empty list initially.""" + with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: + mock_service = MagicMock() + mock_service.list_attacks = AsyncMock( + return_value=AttackListResponse( + items=[], + pagination=PaginationInfo(limit=20, has_more=False), + ) + ) + mock_get_service.return_value = mock_service + + response = client.get("/api/attacks") + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["items"] == [] + + def test_list_attacks_with_filters(self, client: TestClient) -> None: + """Test that list attacks accepts filter parameters.""" + with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: + mock_service = MagicMock() + mock_service.list_attacks = AsyncMock( + return_value=AttackListResponse( + items=[], + pagination=PaginationInfo(limit=10, has_more=False), + ) + ) + mock_get_service.return_value = mock_service + + response = client.get( + "/api/attacks", + params={"target_id": "t1", "outcome": "success", "limit": 10}, + ) + + assert response.status_code == status.HTTP_200_OK + mock_service.list_attacks.assert_called_once_with( + target_id="t1", + outcome="success", + limit=10, + cursor=None, + ) + + def test_create_attack_success(self, client: TestClient) -> None: + """Test successful attack creation.""" + now = datetime.now(timezone.utc) + + with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: + mock_service = MagicMock() + mock_service.create_attack = AsyncMock( + return_value=CreateAttackResponse( + attack_id="attack-1", + name="Test Attack", + target_id="target-1", + target_type="TextTarget", + outcome=None, + prepended_conversation=[], + created_at=now, + updated_at=now, + ) + ) + mock_get_service.return_value = mock_service + + response = client.post( + "/api/attacks", + json={"target_id": "target-1", "name": "Test Attack"}, + ) + + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["attack_id"] == "attack-1" + + def test_create_attack_target_not_found(self, client: TestClient) -> None: + """Test attack creation with non-existent target.""" + with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: + mock_service = MagicMock() + mock_service.create_attack = AsyncMock( + side_effect=ValueError("Target not found") + ) + mock_get_service.return_value = mock_service + + response = client.post( + "/api/attacks", + json={"target_id": "nonexistent"}, + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_get_attack_success(self, client: TestClient) -> None: + """Test getting an attack by ID.""" + now = datetime.now(timezone.utc) + + with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: + mock_service = MagicMock() + mock_service.get_attack = AsyncMock( + return_value=AttackDetail( + attack_id="attack-1", + name="Test", + target_id="target-1", + target_type="TextTarget", + outcome=None, + prepended_conversation=[], + messages=[], + converter_ids=[], + created_at=now, + updated_at=now, + ) + ) + mock_get_service.return_value = mock_service + + response = client.get("/api/attacks/attack-1") + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["attack_id"] == "attack-1" + + def test_get_attack_not_found(self, client: TestClient) -> None: + """Test getting a non-existent attack.""" + with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: + mock_service = MagicMock() + mock_service.get_attack = AsyncMock(return_value=None) + mock_get_service.return_value = mock_service + + response = client.get("/api/attacks/nonexistent") + + assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_update_attack_success(self, client: TestClient) -> None: + """Test updating an attack's outcome.""" + now = datetime.now(timezone.utc) + + with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: + mock_service = MagicMock() + mock_service.update_attack = AsyncMock( + return_value=AttackDetail( + attack_id="attack-1", + target_id="target-1", + target_type="TextTarget", + outcome="success", + prepended_conversation=[], + messages=[], + converter_ids=[], + created_at=now, + updated_at=now, + ) + ) + mock_get_service.return_value = mock_service + + response = client.patch( + "/api/attacks/attack-1", + json={"outcome": "success"}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["outcome"] == "success" + + def test_delete_attack_success(self, client: TestClient) -> None: + """Test deleting an attack.""" + with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: + mock_service = MagicMock() + mock_service.delete_attack = AsyncMock(return_value=True) + mock_get_service.return_value = mock_service + + response = client.delete("/api/attacks/attack-1") + + assert response.status_code == status.HTTP_204_NO_CONTENT + + def test_delete_attack_not_found(self, client: TestClient) -> None: + """Test deleting a non-existent attack.""" + with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: + mock_service = MagicMock() + mock_service.delete_attack = AsyncMock(return_value=False) + mock_get_service.return_value = mock_service + + response = client.delete("/api/attacks/nonexistent") + + assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_send_message_success(self, client: TestClient) -> None: + """Test sending a message in an attack.""" + now = datetime.now(timezone.utc) + + user_msg = Message( + message_id="msg-1", + turn_number=1, + role="user", + pieces=[ + MessagePiece( + piece_id="piece-1", + data_type="text", + converted_value="Hello", + scores=[], + ) + ], + created_at=now, + ) + assistant_msg = Message( + message_id="msg-2", + turn_number=2, + role="assistant", + pieces=[ + MessagePiece( + piece_id="piece-2", + data_type="text", + converted_value="Hi there!", + scores=[], + ) + ], + created_at=now, + ) + summary = AttackSummary( + attack_id="attack-1", + target_id="target-1", + target_type="TextTarget", + message_count=2, + created_at=now, + updated_at=now, + ) + + with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: + mock_service = MagicMock() + mock_service.send_message = AsyncMock( + return_value=SendMessageResponse( + user_message=user_msg, + assistant_message=assistant_msg, + attack_summary=summary, + ) + ) + mock_get_service.return_value = mock_service + + response = client.post( + "/api/attacks/attack-1/messages", + json={"pieces": [{"content": "Hello"}]}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["user_message"]["pieces"][0]["converted_value"] == "Hello" + + +# ============================================================================ +# Target Routes Tests +# ============================================================================ + + +class TestTargetRoutes: + """Tests for target API routes.""" + + def test_list_targets_returns_empty_list(self, client: TestClient) -> None: + """Test that list targets returns empty list initially.""" + with patch("pyrit.backend.routes.targets.get_target_service") as mock_get_service: + mock_service = MagicMock() + mock_service.list_targets = AsyncMock( + return_value=TargetListResponse(items=[]) + ) + mock_get_service.return_value = mock_service + + response = client.get("/api/targets") + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["items"] == [] + + def test_list_targets_with_source_filter(self, client: TestClient) -> None: + """Test that list targets accepts source filter.""" + with patch("pyrit.backend.routes.targets.get_target_service") as mock_get_service: + mock_service = MagicMock() + mock_service.list_targets = AsyncMock( + return_value=TargetListResponse(items=[]) + ) + mock_get_service.return_value = mock_service + + response = client.get("/api/targets", params={"source": "user"}) + + assert response.status_code == status.HTTP_200_OK + mock_service.list_targets.assert_called_once_with(source="user") + + def test_create_target_success(self, client: TestClient) -> None: + """Test successful target creation.""" + now = datetime.now(timezone.utc) + + with patch("pyrit.backend.routes.targets.get_target_service") as mock_get_service: + mock_service = MagicMock() + mock_service.create_target = AsyncMock( + return_value=CreateTargetResponse( + target_id="target-1", + type="TextTarget", + display_name="My Target", + params={}, + created_at=now, + source="user", + ) + ) + mock_get_service.return_value = mock_service + + response = client.post( + "/api/targets", + json={"type": "TextTarget", "display_name": "My Target", "params": {}}, + ) + + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["target_id"] == "target-1" + + def test_create_target_invalid_type(self, client: TestClient) -> None: + """Test target creation with invalid type.""" + with patch("pyrit.backend.routes.targets.get_target_service") as mock_get_service: + mock_service = MagicMock() + mock_service.create_target = AsyncMock( + side_effect=ValueError("Target type not found") + ) + mock_get_service.return_value = mock_service + + response = client.post( + "/api/targets", + json={"type": "InvalidTarget", "params": {}}, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + + def test_get_target_success(self, client: TestClient) -> None: + """Test getting a target by ID.""" + now = datetime.now(timezone.utc) + + with patch("pyrit.backend.routes.targets.get_target_service") as mock_get_service: + mock_service = MagicMock() + mock_service.get_target = AsyncMock( + return_value=TargetInstance( + target_id="target-1", + type="TextTarget", + params={}, + created_at=now, + source="user", + ) + ) + mock_get_service.return_value = mock_service + + response = client.get("/api/targets/target-1") + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["target_id"] == "target-1" + + def test_get_target_not_found(self, client: TestClient) -> None: + """Test getting a non-existent target.""" + with patch("pyrit.backend.routes.targets.get_target_service") as mock_get_service: + mock_service = MagicMock() + mock_service.get_target = AsyncMock(return_value=None) + mock_get_service.return_value = mock_service + + response = client.get("/api/targets/nonexistent") + + assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_delete_target_success(self, client: TestClient) -> None: + """Test deleting a target.""" + with patch("pyrit.backend.routes.targets.get_target_service") as mock_get_service: + mock_service = MagicMock() + mock_service.delete_target = AsyncMock(return_value=True) + mock_get_service.return_value = mock_service + + response = client.delete("/api/targets/target-1") + + assert response.status_code == status.HTTP_204_NO_CONTENT + + def test_delete_target_not_found(self, client: TestClient) -> None: + """Test deleting a non-existent target.""" + with patch("pyrit.backend.routes.targets.get_target_service") as mock_get_service: + mock_service = MagicMock() + mock_service.delete_target = AsyncMock(return_value=False) + mock_get_service.return_value = mock_service + + response = client.delete("/api/targets/nonexistent") + + assert response.status_code == status.HTTP_404_NOT_FOUND + + +# ============================================================================ +# Converter Routes Tests +# ============================================================================ + + +class TestConverterRoutes: + """Tests for converter API routes.""" + + def test_list_converter_types(self, client: TestClient) -> None: + """Test listing converter types from registry.""" + with patch("pyrit.backend.routes.converters.get_registry_service") as mock_get_service: + mock_service = MagicMock() + mock_service.get_converters.return_value = [] + mock_get_service.return_value = mock_service + + response = client.get("/api/converters/types") + + assert response.status_code == status.HTTP_200_OK + + def test_list_converter_instances(self, client: TestClient) -> None: + """Test listing converter instances.""" + with patch("pyrit.backend.routes.converters.get_converter_service") as mock_get_service: + mock_service = MagicMock() + mock_service.list_converters = AsyncMock( + return_value=ConverterInstanceListResponse(items=[]) + ) + mock_get_service.return_value = mock_service + + response = client.get("/api/converters/instances") + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["items"] == [] + + def test_create_converter_instance_success(self, client: TestClient) -> None: + """Test successful converter instance creation.""" + now = datetime.now(timezone.utc) + + with patch("pyrit.backend.routes.converters.get_converter_service") as mock_get_service: + mock_service = MagicMock() + mock_service.create_converter = AsyncMock( + return_value=CreateConverterResponse( + converter_id="conv-1", + type="Base64Converter", + display_name="My Base64", + params={}, + created_at=now, + source="user", + ) + ) + mock_get_service.return_value = mock_service + + response = client.post( + "/api/converters/instances", + json={"type": "Base64Converter", "display_name": "My Base64", "params": {}}, + ) + + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["converter_id"] == "conv-1" + + def test_create_converter_instance_invalid_type(self, client: TestClient) -> None: + """Test converter creation with invalid type.""" + with patch("pyrit.backend.routes.converters.get_converter_service") as mock_get_service: + mock_service = MagicMock() + mock_service.create_converter = AsyncMock( + side_effect=ValueError("Converter type not found") + ) + mock_get_service.return_value = mock_service + + response = client.post( + "/api/converters/instances", + json={"type": "InvalidConverter", "params": {}}, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + + def test_get_converter_instance_success(self, client: TestClient) -> None: + """Test getting a converter instance by ID.""" + now = datetime.now(timezone.utc) + + with patch("pyrit.backend.routes.converters.get_converter_service") as mock_get_service: + mock_service = MagicMock() + mock_service.get_converter = AsyncMock( + return_value=ConverterInstance( + converter_id="conv-1", + type="Base64Converter", + params={}, + created_at=now, + source="user", + ) + ) + mock_get_service.return_value = mock_service + + response = client.get("/api/converters/instances/conv-1") + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["converter_id"] == "conv-1" + + def test_get_converter_instance_not_found(self, client: TestClient) -> None: + """Test getting a non-existent converter instance.""" + with patch("pyrit.backend.routes.converters.get_converter_service") as mock_get_service: + mock_service = MagicMock() + mock_service.get_converter = AsyncMock(return_value=None) + mock_get_service.return_value = mock_service + + response = client.get("/api/converters/instances/nonexistent") + + assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_delete_converter_instance_success(self, client: TestClient) -> None: + """Test deleting a converter instance.""" + with patch("pyrit.backend.routes.converters.get_converter_service") as mock_get_service: + mock_service = MagicMock() + mock_service.delete_converter = AsyncMock(return_value=True) + mock_get_service.return_value = mock_service + + response = client.delete("/api/converters/instances/conv-1") + + assert response.status_code == status.HTTP_204_NO_CONTENT + + def test_delete_converter_instance_not_found(self, client: TestClient) -> None: + """Test deleting a non-existent converter instance.""" + with patch("pyrit.backend.routes.converters.get_converter_service") as mock_get_service: + mock_service = MagicMock() + mock_service.delete_converter = AsyncMock(return_value=False) + mock_get_service.return_value = mock_service + + response = client.delete("/api/converters/instances/nonexistent") + + assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_preview_conversion_success(self, client: TestClient) -> None: + """Test previewing a conversion.""" + with patch("pyrit.backend.routes.converters.get_converter_service") as mock_get_service: + mock_service = MagicMock() + mock_service.preview_conversion = AsyncMock( + return_value=ConverterPreviewResponse( + original_value="test", + original_value_data_type="text", + converted_value="dGVzdA==", + converted_value_data_type="text", + steps=[ + PreviewStep( + converter_id=None, + converter_type="Base64Converter", + input_value="test", + input_data_type="text", + output_value="dGVzdA==", + output_data_type="text", + ) + ], + ) + ) + mock_get_service.return_value = mock_service + + response = client.post( + "/api/converters/preview", + json={ + "original_value": "test", + "original_value_data_type": "text", + "converters": [{"type": "Base64Converter", "params": {}}], + }, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["converted_value"] == "dGVzdA==" + assert len(data["steps"]) == 1 diff --git a/tests/unit/backend/test_attack_service.py b/tests/unit/backend/test_attack_service.py new file mode 100644 index 0000000000..55c782e3ba --- /dev/null +++ b/tests/unit/backend/test_attack_service.py @@ -0,0 +1,497 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tests for backend attack service. +""" + +from datetime import datetime, timezone +from typing import List +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from pyrit.backend.models.attacks import ( + CreateAttackRequest, + MessagePieceRequest, + PrependedMessageRequest, + SendMessageRequest, + UpdateAttackRequest, +) +from pyrit.backend.models.targets import TargetInstance +from pyrit.backend.services.attack_service import AttackService, AttackState + + +@pytest.mark.usefixtures("patch_central_database") +class TestAttackServiceInit: + """Tests for AttackService initialization.""" + + def test_init_creates_empty_attacks_dict(self) -> None: + """Test that service initializes with empty attacks dictionary.""" + service = AttackService() + assert service._attacks == {} + + def test_init_creates_empty_messages_dict(self) -> None: + """Test that service initializes with empty messages dictionary.""" + service = AttackService() + assert len(service._messages) == 0 + + +@pytest.mark.usefixtures("patch_central_database") +class TestListAttacks: + """Tests for AttackService.list_attacks method.""" + + @pytest.mark.asyncio + async def test_list_attacks_returns_empty_when_no_attacks(self) -> None: + """Test that list_attacks returns empty list when no attacks exist.""" + service = AttackService() + + result = await service.list_attacks() + + assert result.items == [] + assert result.pagination.has_more is False + + @pytest.mark.asyncio + async def test_list_attacks_returns_attacks(self) -> None: + """Test that list_attacks returns existing attacks.""" + service = AttackService() + now = datetime.now(timezone.utc) + + # Add a test attack + service._attacks["test-id"] = AttackState( + attack_id="test-id", + name="Test Attack", + target_id="target-1", + target_type="TextTarget", + created_at=now, + updated_at=now, + ) + + result = await service.list_attacks() + + assert len(result.items) == 1 + assert result.items[0].attack_id == "test-id" + assert result.items[0].name == "Test Attack" + + @pytest.mark.asyncio + async def test_list_attacks_filters_by_target_id(self) -> None: + """Test that list_attacks filters by target_id.""" + service = AttackService() + now = datetime.now(timezone.utc) + + service._attacks["attack-1"] = AttackState( + attack_id="attack-1", + target_id="target-1", + target_type="TextTarget", + created_at=now, + updated_at=now, + ) + service._attacks["attack-2"] = AttackState( + attack_id="attack-2", + target_id="target-2", + target_type="TextTarget", + created_at=now, + updated_at=now, + ) + + result = await service.list_attacks(target_id="target-1") + + assert len(result.items) == 1 + assert result.items[0].target_id == "target-1" + + @pytest.mark.asyncio + async def test_list_attacks_filters_by_outcome(self) -> None: + """Test that list_attacks filters by outcome.""" + service = AttackService() + now = datetime.now(timezone.utc) + + service._attacks["attack-1"] = AttackState( + attack_id="attack-1", + target_id="target-1", + target_type="TextTarget", + outcome="success", + created_at=now, + updated_at=now, + ) + service._attacks["attack-2"] = AttackState( + attack_id="attack-2", + target_id="target-1", + target_type="TextTarget", + outcome="failure", + created_at=now, + updated_at=now, + ) + + result = await service.list_attacks(outcome="success") + + assert len(result.items) == 1 + assert result.items[0].outcome == "success" + + @pytest.mark.asyncio + async def test_list_attacks_respects_limit(self) -> None: + """Test that list_attacks respects the limit parameter.""" + service = AttackService() + now = datetime.now(timezone.utc) + + for i in range(5): + service._attacks[f"attack-{i}"] = AttackState( + attack_id=f"attack-{i}", + target_id="target-1", + target_type="TextTarget", + created_at=now, + updated_at=now, + ) + + result = await service.list_attacks(limit=2) + + assert len(result.items) == 2 + assert result.pagination.has_more is True + + @pytest.mark.asyncio + async def test_list_attacks_cursor_pagination(self) -> None: + """Test that list_attacks handles cursor-based pagination.""" + service = AttackService() + now = datetime.now(timezone.utc) + + # Create attacks with different updated_at times + for i in range(3): + service._attacks[f"attack-{i}"] = AttackState( + attack_id=f"attack-{i}", + target_id="target-1", + target_type="TextTarget", + created_at=now, + updated_at=now, + ) + + # Get first page + first_page = await service.list_attacks(limit=2) + assert len(first_page.items) == 2 + + # Get second page using cursor + if first_page.pagination.next_cursor: + second_page = await service.list_attacks( + limit=2, cursor=first_page.pagination.next_cursor + ) + assert len(second_page.items) == 1 + + +@pytest.mark.usefixtures("patch_central_database") +class TestGetAttack: + """Tests for AttackService.get_attack method.""" + + @pytest.mark.asyncio + async def test_get_attack_returns_none_for_nonexistent(self) -> None: + """Test that get_attack returns None for non-existent attack.""" + service = AttackService() + + result = await service.get_attack("nonexistent-id") + + assert result is None + + @pytest.mark.asyncio + async def test_get_attack_returns_attack_details(self) -> None: + """Test that get_attack returns full attack details.""" + service = AttackService() + now = datetime.now(timezone.utc) + + service._attacks["test-id"] = AttackState( + attack_id="test-id", + name="Test Attack", + target_id="target-1", + target_type="TextTarget", + outcome="pending", + created_at=now, + updated_at=now, + ) + + result = await service.get_attack("test-id") + + assert result is not None + assert result.attack_id == "test-id" + assert result.name == "Test Attack" + assert result.target_type == "TextTarget" + + +@pytest.mark.usefixtures("patch_central_database") +class TestCreateAttack: + """Tests for AttackService.create_attack method.""" + + @pytest.mark.asyncio + async def test_create_attack_validates_target_exists(self) -> None: + """Test that create_attack validates the target exists.""" + service = AttackService() + + with patch( + "pyrit.backend.services.attack_service.get_target_service" + ) as mock_get_target_service: + mock_target_service = MagicMock() + mock_target_service.get_target = AsyncMock(return_value=None) + mock_get_target_service.return_value = mock_target_service + + request = CreateAttackRequest(target_id="nonexistent") + + with pytest.raises(ValueError, match="not found"): + await service.create_attack(request) + + @pytest.mark.asyncio + async def test_create_attack_success(self) -> None: + """Test successful attack creation.""" + service = AttackService() + + mock_target = TargetInstance( + target_id="target-1", + type="TextTarget", + params={}, + created_at=datetime.now(timezone.utc), + source="user", + ) + + with patch( + "pyrit.backend.services.attack_service.get_target_service" + ) as mock_get_target_service: + mock_target_service = MagicMock() + mock_target_service.get_target = AsyncMock(return_value=mock_target) + mock_get_target_service.return_value = mock_target_service + + request = CreateAttackRequest(target_id="target-1", name="My Attack") + + result = await service.create_attack(request) + + assert result.attack_id is not None + assert result.name == "My Attack" + assert result.target_id == "target-1" + assert result.target_type == "TextTarget" + + @pytest.mark.asyncio + async def test_create_attack_with_prepended_conversation(self) -> None: + """Test attack creation with prepended conversation.""" + service = AttackService() + + mock_target = TargetInstance( + target_id="target-1", + type="TextTarget", + params={}, + created_at=datetime.now(timezone.utc), + source="user", + ) + + with patch( + "pyrit.backend.services.attack_service.get_target_service" + ) as mock_get_target_service: + mock_target_service = MagicMock() + mock_target_service.get_target = AsyncMock(return_value=mock_target) + mock_get_target_service.return_value = mock_target_service + + request = CreateAttackRequest( + target_id="target-1", + prepended_conversation=[ + PrependedMessageRequest(role="system", content="You are a helpful assistant."), + ], + ) + + result = await service.create_attack(request) + + assert len(result.prepended_conversation) == 1 + assert result.prepended_conversation[0].role == "system" + + @pytest.mark.asyncio + async def test_create_attack_validates_converter_ids(self) -> None: + """Test that create_attack validates converter IDs exist.""" + service = AttackService() + + mock_target = TargetInstance( + target_id="target-1", + type="TextTarget", + params={}, + created_at=datetime.now(timezone.utc), + source="user", + ) + + with patch( + "pyrit.backend.services.attack_service.get_target_service" + ) as mock_get_target_service: + mock_target_service = MagicMock() + mock_target_service.get_target = AsyncMock(return_value=mock_target) + mock_get_target_service.return_value = mock_target_service + + with patch( + "pyrit.backend.services.attack_service.get_converter_service" + ) as mock_get_converter_service: + mock_converter_service = MagicMock() + mock_converter_service.get_converter = AsyncMock(return_value=None) + mock_get_converter_service.return_value = mock_converter_service + + request = CreateAttackRequest( + target_id="target-1", + converter_ids=["nonexistent-converter"], + ) + + with pytest.raises(ValueError, match="Converter instance"): + await service.create_attack(request) + + +@pytest.mark.usefixtures("patch_central_database") +class TestUpdateAttack: + """Tests for AttackService.update_attack method.""" + + @pytest.mark.asyncio + async def test_update_attack_returns_none_for_nonexistent(self) -> None: + """Test that update_attack returns None for non-existent attack.""" + service = AttackService() + + request = UpdateAttackRequest(outcome="success") + result = await service.update_attack("nonexistent", request) + + assert result is None + + @pytest.mark.asyncio + async def test_update_attack_updates_outcome(self) -> None: + """Test that update_attack updates the outcome.""" + service = AttackService() + now = datetime.now(timezone.utc) + + service._attacks["test-id"] = AttackState( + attack_id="test-id", + target_id="target-1", + target_type="TextTarget", + outcome=None, + created_at=now, + updated_at=now, + ) + + request = UpdateAttackRequest(outcome="success") + result = await service.update_attack("test-id", request) + + assert result is not None + assert result.outcome == "success" + + +@pytest.mark.usefixtures("patch_central_database") +class TestDeleteAttack: + """Tests for AttackService.delete_attack method.""" + + @pytest.mark.asyncio + async def test_delete_attack_returns_false_for_nonexistent(self) -> None: + """Test that delete_attack returns False for non-existent attack.""" + service = AttackService() + + result = await service.delete_attack("nonexistent") + + assert result is False + + @pytest.mark.asyncio + async def test_delete_attack_deletes_attack(self) -> None: + """Test that delete_attack removes the attack.""" + service = AttackService() + now = datetime.now(timezone.utc) + + service._attacks["test-id"] = AttackState( + attack_id="test-id", + target_id="target-1", + target_type="TextTarget", + created_at=now, + updated_at=now, + ) + + result = await service.delete_attack("test-id") + + assert result is True + assert "test-id" not in service._attacks + + @pytest.mark.asyncio + async def test_delete_attack_removes_messages(self) -> None: + """Test that delete_attack also removes associated messages.""" + service = AttackService() + now = datetime.now(timezone.utc) + + service._attacks["test-id"] = AttackState( + attack_id="test-id", + target_id="target-1", + target_type="TextTarget", + created_at=now, + updated_at=now, + ) + service._messages["test-id"] = [] + + await service.delete_attack("test-id") + + assert "test-id" not in service._messages + + +@pytest.mark.usefixtures("patch_central_database") +class TestSendMessage: + """Tests for AttackService.send_message method.""" + + @pytest.mark.asyncio + async def test_send_message_raises_for_nonexistent_attack(self) -> None: + """Test that send_message raises ValueError for non-existent attack.""" + service = AttackService() + + request = SendMessageRequest( + pieces=[MessagePieceRequest(content="Hello")], + ) + + with pytest.raises(ValueError, match="Attack"): + await service.send_message("nonexistent", request) + + @pytest.mark.asyncio + async def test_send_message_raises_for_missing_target_object(self) -> None: + """Test that send_message raises when target object is not found.""" + service = AttackService() + now = datetime.now(timezone.utc) + + service._attacks["test-id"] = AttackState( + attack_id="test-id", + target_id="target-1", + target_type="TextTarget", + created_at=now, + updated_at=now, + ) + + with patch( + "pyrit.backend.services.attack_service.get_target_service" + ) as mock_get_target_service: + mock_target_service = MagicMock() + mock_target_service.get_target_object.return_value = None + mock_get_target_service.return_value = mock_target_service + + with patch( + "pyrit.backend.services.attack_service.get_converter_service" + ) as mock_get_converter_service: + mock_converter_service = MagicMock() + mock_get_converter_service.return_value = mock_converter_service + + request = SendMessageRequest( + pieces=[MessagePieceRequest(content="Hello")], + ) + + with pytest.raises(ValueError, match="Target object"): + await service.send_message("test-id", request) + + +@pytest.mark.usefixtures("patch_central_database") +class TestAttackServiceSingleton: + """Tests for get_attack_service singleton function.""" + + def test_get_attack_service_returns_attack_service(self) -> None: + """Test that get_attack_service returns an AttackService instance.""" + from pyrit.backend.services.attack_service import get_attack_service + + # Reset singleton for clean test + import pyrit.backend.services.attack_service as module + module._attack_service = None + + service = get_attack_service() + assert isinstance(service, AttackService) + + def test_get_attack_service_returns_same_instance(self) -> None: + """Test that get_attack_service returns the same instance.""" + from pyrit.backend.services.attack_service import get_attack_service + + # Reset singleton for clean test + import pyrit.backend.services.attack_service as module + module._attack_service = None + + service1 = get_attack_service() + service2 = get_attack_service() + assert service1 is service2 diff --git a/tests/unit/backend/test_common_models.py b/tests/unit/backend/test_common_models.py index 06739e5c9e..803e58db76 100644 --- a/tests/unit/backend/test_common_models.py +++ b/tests/unit/backend/test_common_models.py @@ -5,8 +5,6 @@ Tests for backend common models. """ - - from pyrit.backend.models.common import ( FieldError, IdentifierDict, @@ -402,4 +400,3 @@ def test_problem_detail_serialization(self) -> None: assert "instance" not in data # None should be excluded assert data["type"] == "/errors/test" - diff --git a/tests/unit/backend/test_conversation_service.py b/tests/unit/backend/test_conversation_service.py deleted file mode 100644 index da700c2d59..0000000000 --- a/tests/unit/backend/test_conversation_service.py +++ /dev/null @@ -1,243 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Tests for backend conversation service. -""" - -from datetime import datetime -from unittest.mock import MagicMock, patch - -import pytest - -from pyrit.backend.models.conversations import ( - CreateConversationRequest, - ConverterConfig, -) -from pyrit.backend.services.conversation_service import ( - ConversationService, - ConversationState, - get_conversation_service, -) - - -class TestConversationState: - """Tests for ConversationState model.""" - - def test_conversation_state_creation(self) -> None: - """Test creating a conversation state.""" - state = ConversationState( - conversation_id="test-id", - target_class="OpenAIChatTarget", - target_identifier={"endpoint": "test"}, - created_at=datetime.utcnow(), - ) - - assert state.conversation_id == "test-id" - assert state.target_class == "OpenAIChatTarget" - assert state.converters == [] - - def test_conversation_state_with_system_prompt(self) -> None: - """Test conversation state with system prompt.""" - state = ConversationState( - conversation_id="test-id", - target_class="OpenAIChatTarget", - target_identifier={}, - system_prompt="Test prompt", - created_at=datetime.utcnow(), - ) - - assert state.system_prompt == "Test prompt" - - def test_conversation_state_defaults(self) -> None: - """Test conversation state default values.""" - state = ConversationState( - conversation_id="test-id", - target_class="OpenAIChatTarget", - target_identifier={}, - created_at=datetime.utcnow(), - ) - - assert state.system_prompt is None - assert state.converters == [] - assert state.message_count == 0 - assert state.labels is None - - -class TestConversationService: - """Tests for ConversationService.""" - - @pytest.fixture - def service(self, patch_central_database: MagicMock) -> ConversationService: - """Create a conversation service instance. - - Args: - patch_central_database: The patched central database fixture. - - Returns: - ConversationService: The service instance. - """ - return ConversationService() - - @pytest.mark.asyncio - async def test_create_conversation_success( - self, service: ConversationService - ) -> None: - """Test creating a conversation successfully.""" - mock_target = MagicMock() - mock_target.get_identifier.return_value = {"type": "TextTarget"} - - with patch.object( - service, "_instantiate_target_from_class", return_value=mock_target - ): - request = CreateConversationRequest(target_class="TextTarget") - result = await service.create_conversation(request) - - assert result is not None - assert result.conversation_id is not None - - @pytest.mark.asyncio - async def test_create_conversation_with_labels( - self, service: ConversationService - ) -> None: - """Test creating a conversation with labels.""" - mock_target = MagicMock() - mock_target.get_identifier.return_value = {"type": "TextTarget"} - - with patch.object( - service, "_instantiate_target_from_class", return_value=mock_target - ): - request = CreateConversationRequest( - target_class="TextTarget", - labels={"test": "label"}, - ) - result = await service.create_conversation(request) - - assert result.labels == {"test": "label"} - - @pytest.mark.asyncio - async def test_create_conversation_invalid_target_class( - self, service: ConversationService - ) -> None: - """Test creating a conversation with invalid target class.""" - with patch.object( - service, - "_instantiate_target_from_class", - side_effect=ValueError("Target class 'InvalidTarget' not found"), - ): - request = CreateConversationRequest(target_class="InvalidTarget") - - with pytest.raises(ValueError, match="not found"): - await service.create_conversation(request) - - @pytest.mark.asyncio - async def test_get_conversation_existing( - self, service: ConversationService - ) -> None: - """Test getting an existing conversation.""" - mock_target = MagicMock() - mock_target.get_identifier.return_value = {"type": "TextTarget"} - - with patch.object( - service, "_instantiate_target_from_class", return_value=mock_target - ): - request = CreateConversationRequest(target_class="TextTarget") - created = await service.create_conversation(request) - - result = await service.get_conversation(created.conversation_id) - - assert result is not None - assert result.conversation_id == created.conversation_id - - @pytest.mark.asyncio - async def test_get_conversation_nonexistent( - self, service: ConversationService - ) -> None: - """Test getting a nonexistent conversation.""" - result = await service.get_conversation("nonexistent-id") - - assert result is None - - @pytest.mark.asyncio - async def test_get_conversation_messages_returns_list( - self, service: ConversationService - ) -> None: - """Test getting messages from a conversation.""" - messages = await service.get_conversation_messages("any-id") - - assert isinstance(messages, list) - - @pytest.mark.asyncio - async def test_cleanup_conversation_existing( - self, service: ConversationService - ) -> None: - """Test cleaning up an existing conversation.""" - mock_target = MagicMock() - mock_target.get_identifier.return_value = {"type": "TextTarget"} - - with patch.object( - service, "_instantiate_target_from_class", return_value=mock_target - ): - request = CreateConversationRequest(target_class="TextTarget") - created = await service.create_conversation(request) - - service.cleanup_conversation(created.conversation_id) - - result = await service.get_conversation(created.conversation_id) - assert result is None - - @pytest.mark.asyncio - async def test_cleanup_conversation_removes_target_instance( - self, service: ConversationService - ) -> None: - """Test that cleanup removes target instance.""" - mock_target = MagicMock() - mock_target.get_identifier.return_value = {"type": "TextTarget"} - - with patch.object( - service, "_instantiate_target_from_class", return_value=mock_target - ): - request = CreateConversationRequest(target_class="TextTarget") - created = await service.create_conversation(request) - - assert created.conversation_id in service._target_instances - - service.cleanup_conversation(created.conversation_id) - - assert created.conversation_id not in service._target_instances - - def test_cleanup_conversation_nonexistent_no_error( - self, service: ConversationService - ) -> None: - """Test cleaning up nonexistent conversation doesn't raise error.""" - # Should not raise any exception - service.cleanup_conversation("nonexistent-id") - - -class TestGetConversationServiceSingleton: - """Tests for get_conversation_service singleton function.""" - - def test_returns_conversation_service_instance( - self, patch_central_database: MagicMock - ) -> None: - """Test that get_conversation_service returns a ConversationService.""" - import pyrit.backend.services.conversation_service as module - - module._conversation_service = None - - service = get_conversation_service() - - assert isinstance(service, ConversationService) - - def test_returns_same_instance( - self, patch_central_database: MagicMock - ) -> None: - """Test that get_conversation_service returns the same instance.""" - import pyrit.backend.services.conversation_service as module - - module._conversation_service = None - - service1 = get_conversation_service() - service2 = get_conversation_service() - - assert service1 is service2 diff --git a/tests/unit/backend/test_converter_service.py b/tests/unit/backend/test_converter_service.py new file mode 100644 index 0000000000..5a8c45a276 --- /dev/null +++ b/tests/unit/backend/test_converter_service.py @@ -0,0 +1,538 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tests for backend converter service. +""" + +from datetime import datetime, timezone +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from pyrit.backend.models.converters import ( + ConverterInstance, + ConverterPreviewRequest, + CreateConverterRequest, + InlineConverterConfig, +) +from pyrit.backend.services.converter_service import ConverterService +from pyrit.models import PromptDataType + + +class TestConverterServiceInit: + """Tests for ConverterService initialization.""" + + def test_init_creates_empty_instances_dict(self) -> None: + """Test that service initializes with empty instances dictionary.""" + service = ConverterService() + assert service._instances == {} + + def test_init_creates_empty_converter_objects_dict(self) -> None: + """Test that service initializes with empty converter objects dictionary.""" + service = ConverterService() + assert service._converter_objects == {} + + +class TestListConverters: + """Tests for ConverterService.list_converters method.""" + + @pytest.mark.asyncio + async def test_list_converters_returns_empty_when_no_converters(self) -> None: + """Test that list_converters returns empty list when no converters exist.""" + service = ConverterService() + + result = await service.list_converters() + + assert result.items == [] + + @pytest.mark.asyncio + async def test_list_converters_returns_converters(self) -> None: + """Test that list_converters returns existing converters.""" + service = ConverterService() + now = datetime.now(timezone.utc) + + service._instances["conv-1"] = ConverterInstance( + converter_id="conv-1", + type="Base64Converter", + display_name="My Converter", + params={}, + created_at=now, + source="user", + ) + + result = await service.list_converters() + + assert len(result.items) == 1 + assert result.items[0].converter_id == "conv-1" + assert result.items[0].display_name == "My Converter" + + @pytest.mark.asyncio + async def test_list_converters_filters_by_source_user(self) -> None: + """Test that list_converters filters by source='user'.""" + service = ConverterService() + now = datetime.now(timezone.utc) + + service._instances["conv-1"] = ConverterInstance( + converter_id="conv-1", + type="Base64Converter", + params={}, + created_at=now, + source="user", + ) + service._instances["conv-2"] = ConverterInstance( + converter_id="conv-2", + type="Base64Converter", + params={}, + created_at=now, + source="initializer", + ) + + result = await service.list_converters(source="user") + + assert len(result.items) == 1 + assert result.items[0].source == "user" + + @pytest.mark.asyncio + async def test_list_converters_filters_by_source_initializer(self) -> None: + """Test that list_converters filters by source='initializer'.""" + service = ConverterService() + now = datetime.now(timezone.utc) + + service._instances["conv-1"] = ConverterInstance( + converter_id="conv-1", + type="Base64Converter", + params={}, + created_at=now, + source="user", + ) + service._instances["conv-2"] = ConverterInstance( + converter_id="conv-2", + type="Base64Converter", + params={}, + created_at=now, + source="initializer", + ) + + result = await service.list_converters(source="initializer") + + assert len(result.items) == 1 + assert result.items[0].source == "initializer" + + +class TestGetConverter: + """Tests for ConverterService.get_converter method.""" + + @pytest.mark.asyncio + async def test_get_converter_returns_none_for_nonexistent(self) -> None: + """Test that get_converter returns None for non-existent converter.""" + service = ConverterService() + + result = await service.get_converter("nonexistent-id") + + assert result is None + + @pytest.mark.asyncio + async def test_get_converter_returns_converter(self) -> None: + """Test that get_converter returns the converter instance.""" + service = ConverterService() + now = datetime.now(timezone.utc) + + service._instances["conv-1"] = ConverterInstance( + converter_id="conv-1", + type="Base64Converter", + display_name="Test Converter", + params={"key": "value"}, + created_at=now, + source="user", + ) + + result = await service.get_converter("conv-1") + + assert result is not None + assert result.converter_id == "conv-1" + assert result.display_name == "Test Converter" + + +class TestGetConverterObject: + """Tests for ConverterService.get_converter_object method.""" + + def test_get_converter_object_returns_none_for_nonexistent(self) -> None: + """Test that get_converter_object returns None for non-existent converter.""" + service = ConverterService() + + result = service.get_converter_object("nonexistent-id") + + assert result is None + + def test_get_converter_object_returns_object(self) -> None: + """Test that get_converter_object returns the actual converter object.""" + service = ConverterService() + mock_converter = MagicMock() + service._converter_objects["conv-1"] = mock_converter + + result = service.get_converter_object("conv-1") + + assert result is mock_converter + + +class TestGetConverterClass: + """Tests for ConverterService._get_converter_class method.""" + + def test_get_converter_class_raises_for_invalid_type(self) -> None: + """Test that _get_converter_class raises ValueError for invalid type.""" + service = ConverterService() + + with pytest.raises(ValueError, match="not found"): + service._get_converter_class("NonExistentConverter") + + def test_get_converter_class_finds_base64_converter(self) -> None: + """Test that _get_converter_class finds Base64Converter.""" + service = ConverterService() + + result = service._get_converter_class("Base64Converter") + + assert result is not None + assert "Base64" in result.__name__ + + def test_get_converter_class_handles_snake_case(self) -> None: + """Test that _get_converter_class handles snake_case names.""" + service = ConverterService() + + # base64 should resolve to Base64Converter + result = service._get_converter_class("base64") + + assert result is not None + + +class TestCreateConverter: + """Tests for ConverterService.create_converter method.""" + + @pytest.mark.asyncio + async def test_create_converter_raises_for_invalid_type(self) -> None: + """Test that create_converter raises for invalid converter type.""" + service = ConverterService() + + request = CreateConverterRequest( + type="NonExistentConverter", + params={}, + ) + + with pytest.raises(ValueError, match="not found"): + await service.create_converter(request) + + @pytest.mark.asyncio + async def test_create_converter_success(self) -> None: + """Test successful converter creation.""" + service = ConverterService() + + request = CreateConverterRequest( + type="Base64Converter", + display_name="My Base64", + params={}, + ) + + result = await service.create_converter(request) + + assert result.converter_id is not None + assert result.type == "Base64Converter" + assert result.display_name == "My Base64" + assert result.source == "user" + + @pytest.mark.asyncio + async def test_create_converter_stores_instance(self) -> None: + """Test that create_converter stores the instance.""" + service = ConverterService() + + request = CreateConverterRequest( + type="Base64Converter", + params={}, + ) + + result = await service.create_converter(request) + + assert result.converter_id in service._instances + assert result.converter_id in service._converter_objects + + +class TestDeleteConverter: + """Tests for ConverterService.delete_converter method.""" + + @pytest.mark.asyncio + async def test_delete_converter_returns_false_for_nonexistent(self) -> None: + """Test that delete_converter returns False for non-existent converter.""" + service = ConverterService() + + result = await service.delete_converter("nonexistent") + + assert result is False + + @pytest.mark.asyncio + async def test_delete_converter_deletes_converter(self) -> None: + """Test that delete_converter removes the converter.""" + service = ConverterService() + now = datetime.now(timezone.utc) + + service._instances["conv-1"] = ConverterInstance( + converter_id="conv-1", + type="Base64Converter", + params={}, + created_at=now, + source="user", + ) + service._converter_objects["conv-1"] = MagicMock() + + result = await service.delete_converter("conv-1") + + assert result is True + assert "conv-1" not in service._instances + assert "conv-1" not in service._converter_objects + + +class TestPreviewConversion: + """Tests for ConverterService.preview_conversion method.""" + + @pytest.mark.asyncio + async def test_preview_conversion_raises_for_nonexistent_converter(self) -> None: + """Test that preview raises ValueError for non-existent converter ID.""" + service = ConverterService() + + request = ConverterPreviewRequest( + original_value="test", + original_value_data_type="text", + converter_ids=["nonexistent"], + ) + + with pytest.raises(ValueError, match="not found"): + await service.preview_conversion(request) + + @pytest.mark.asyncio + async def test_preview_conversion_with_converter_ids(self) -> None: + """Test preview with converter IDs.""" + service = ConverterService() + now = datetime.now(timezone.utc) + + # Create a mock converter + mock_converter = MagicMock() + mock_result = MagicMock() + mock_result.output_text = "encoded_value" + mock_result.output_type = "text" + mock_converter.convert_async = AsyncMock(return_value=mock_result) + + service._instances["conv-1"] = ConverterInstance( + converter_id="conv-1", + type="MockConverter", + params={}, + created_at=now, + source="user", + ) + service._converter_objects["conv-1"] = mock_converter + + request = ConverterPreviewRequest( + original_value="test", + original_value_data_type="text", + converter_ids=["conv-1"], + ) + + result = await service.preview_conversion(request) + + assert result.original_value == "test" + assert result.converted_value == "encoded_value" + assert len(result.steps) == 1 + assert result.steps[0].converter_id == "conv-1" + + @pytest.mark.asyncio + async def test_preview_conversion_with_inline_converters(self) -> None: + """Test preview with inline converter configs.""" + service = ConverterService() + + request = ConverterPreviewRequest( + original_value="test", + original_value_data_type="text", + converters=[ + InlineConverterConfig(type="Base64Converter", params={}), + ], + ) + + result = await service.preview_conversion(request) + + assert result.original_value == "test" + assert result.converted_value is not None + assert len(result.steps) == 1 + # Base64 of "test" should be different from "test" + assert result.converted_value != "test" + + @pytest.mark.asyncio + async def test_preview_conversion_chains_multiple_converters(self) -> None: + """Test that preview chains multiple converters.""" + service = ConverterService() + now = datetime.now(timezone.utc) + + # Create two mock converters + mock_converter1 = MagicMock() + mock_result1 = MagicMock() + mock_result1.output_text = "step1_output" + mock_result1.output_type = "text" + mock_converter1.convert_async = AsyncMock(return_value=mock_result1) + + mock_converter2 = MagicMock() + mock_result2 = MagicMock() + mock_result2.output_text = "step2_output" + mock_result2.output_type = "text" + mock_converter2.convert_async = AsyncMock(return_value=mock_result2) + + service._instances["conv-1"] = ConverterInstance( + converter_id="conv-1", + type="MockConverter1", + params={}, + created_at=now, + source="user", + ) + service._converter_objects["conv-1"] = mock_converter1 + + service._instances["conv-2"] = ConverterInstance( + converter_id="conv-2", + type="MockConverter2", + params={}, + created_at=now, + source="user", + ) + service._converter_objects["conv-2"] = mock_converter2 + + request = ConverterPreviewRequest( + original_value="input", + original_value_data_type="text", + converter_ids=["conv-1", "conv-2"], + ) + + result = await service.preview_conversion(request) + + assert result.converted_value == "step2_output" + assert len(result.steps) == 2 + # Second converter should receive output from first + mock_converter2.convert_async.assert_called_with(prompt="step1_output") + + +class TestGetConverterObjectsForIds: + """Tests for ConverterService.get_converter_objects_for_ids method.""" + + def test_get_converter_objects_for_ids_raises_for_nonexistent(self) -> None: + """Test that method raises ValueError for non-existent ID.""" + service = ConverterService() + + with pytest.raises(ValueError, match="not found"): + service.get_converter_objects_for_ids(["nonexistent"]) + + def test_get_converter_objects_for_ids_returns_objects(self) -> None: + """Test that method returns converter objects in order.""" + service = ConverterService() + + mock1 = MagicMock() + mock2 = MagicMock() + service._converter_objects["conv-1"] = mock1 + service._converter_objects["conv-2"] = mock2 + + result = service.get_converter_objects_for_ids(["conv-1", "conv-2"]) + + assert result == [mock1, mock2] + + +class TestInstantiateInlineConverters: + """Tests for ConverterService.instantiate_inline_converters method.""" + + def test_instantiate_inline_converters_creates_objects(self) -> None: + """Test that inline converters are instantiated.""" + service = ConverterService() + + configs = [ + InlineConverterConfig(type="Base64Converter", params={}), + ] + + result = service.instantiate_inline_converters(configs) + + assert len(result) == 1 + # Verify it's a real converter object + assert hasattr(result[0], "convert_async") + + def test_instantiate_inline_converters_raises_for_invalid_type(self) -> None: + """Test that invalid type raises ValueError.""" + service = ConverterService() + + configs = [ + InlineConverterConfig(type="NonExistentConverter", params={}), + ] + + with pytest.raises(ValueError, match="not found"): + service.instantiate_inline_converters(configs) + + +class TestNestedConverterCreation: + """Tests for nested converter creation.""" + + @pytest.mark.asyncio + async def test_create_converter_with_nested_converter(self) -> None: + """Test creating a converter with a nested converter config.""" + service = ConverterService() + + # Mock the parent converter class that accepts a 'converter' param + mock_parent_class = MagicMock() + mock_parent_instance = MagicMock() + mock_parent_class.return_value = mock_parent_instance + + mock_child_class = MagicMock() + mock_child_instance = MagicMock() + mock_child_class.return_value = mock_child_instance + + def mock_get_class(converter_type: str) -> type: + if converter_type == "ParentConverter": + return mock_parent_class + elif converter_type == "ChildConverter": + return mock_child_class + raise ValueError(f"Unknown type: {converter_type}") + + with patch.object(service, "_get_converter_class", side_effect=mock_get_class): + request = CreateConverterRequest( + type="ParentConverter", + params={ + "converter": { + "type": "ChildConverter", + "params": {}, + }, + }, + ) + + result = await service.create_converter(request) + + # Parent should be created with child converter object + mock_parent_class.assert_called() + # The call should have received the child instance, not the dict + call_kwargs = mock_parent_class.call_args[1] + assert call_kwargs.get("converter") is mock_child_instance + + +class TestConverterServiceSingleton: + """Tests for get_converter_service singleton function.""" + + def test_get_converter_service_returns_converter_service(self) -> None: + """Test that get_converter_service returns a ConverterService instance.""" + from pyrit.backend.services.converter_service import get_converter_service + + # Reset singleton for clean test + import pyrit.backend.services.converter_service as module + module._converter_service = None + + service = get_converter_service() + assert isinstance(service, ConverterService) + + def test_get_converter_service_returns_same_instance(self) -> None: + """Test that get_converter_service returns the same instance.""" + from pyrit.backend.services.converter_service import get_converter_service + + # Reset singleton for clean test + import pyrit.backend.services.converter_service as module + module._converter_service = None + + service1 = get_converter_service() + service2 = get_converter_service() + assert service1 is service2 diff --git a/tests/unit/backend/test_error_handlers.py b/tests/unit/backend/test_error_handlers.py index b1afe028c7..93131360e8 100644 --- a/tests/unit/backend/test_error_handlers.py +++ b/tests/unit/backend/test_error_handlers.py @@ -5,10 +5,8 @@ Tests for backend error handler middleware. """ -from unittest.mock import MagicMock - import pytest -from fastapi import FastAPI, HTTPException +from fastapi import FastAPI from fastapi.testclient import TestClient from pyrit.backend.middleware.error_handlers import register_error_handlers @@ -61,9 +59,7 @@ async def test_endpoint(data: TestInput) -> dict: assert data["status"] == 422 assert "errors" in data - def test_validation_error_includes_field_details( - self, app: FastAPI, client: TestClient - ) -> None: + def test_validation_error_includes_field_details(self, app: FastAPI, client: TestClient) -> None: """Test that validation errors include field-level details.""" from pydantic import BaseModel @@ -100,9 +96,7 @@ async def test_endpoint() -> dict: assert data["status"] == 400 assert "Invalid input value" in data["detail"] - def test_file_not_found_error_returns_404( - self, app: FastAPI, client: TestClient - ) -> None: + def test_file_not_found_error_returns_404(self, app: FastAPI, client: TestClient) -> None: """Test that FileNotFoundError returns 404 with RFC 7807 format.""" @app.get("/test") @@ -117,9 +111,7 @@ async def test_endpoint() -> dict: assert data["title"] == "Not Found" assert data["status"] == 404 - def test_permission_error_returns_403( - self, app: FastAPI, client: TestClient - ) -> None: + def test_permission_error_returns_403(self, app: FastAPI, client: TestClient) -> None: """Test that PermissionError returns 403 with RFC 7807 format.""" @app.get("/test") @@ -134,9 +126,7 @@ async def test_endpoint() -> dict: assert data["title"] == "Forbidden" assert data["status"] == 403 - def test_not_implemented_error_returns_501( - self, app: FastAPI, client: TestClient - ) -> None: + def test_not_implemented_error_returns_501(self, app: FastAPI, client: TestClient) -> None: """Test that NotImplementedError returns 501 with RFC 7807 format.""" @app.get("/test") @@ -151,9 +141,7 @@ async def test_endpoint() -> dict: assert data["title"] == "Not Implemented" assert data["status"] == 501 - def test_generic_exception_returns_500( - self, app: FastAPI, client: TestClient - ) -> None: + def test_generic_exception_returns_500(self, app: FastAPI, client: TestClient) -> None: """Test that unexpected exceptions return 500 with RFC 7807 format.""" @app.get("/test") @@ -170,9 +158,7 @@ async def test_endpoint() -> dict: # Should not leak internal error details assert "An unexpected error occurred" in data["detail"] - def test_error_response_includes_instance( - self, app: FastAPI, client: TestClient - ) -> None: + def test_error_response_includes_instance(self, app: FastAPI, client: TestClient) -> None: """Test that error responses include the request path as instance.""" @app.get("/api/v1/test/resource") @@ -184,9 +170,7 @@ async def test_endpoint() -> dict: data = response.json() assert data["instance"] == "/api/v1/test/resource" - def test_error_excludes_none_fields( - self, app: FastAPI, client: TestClient - ) -> None: + def test_error_excludes_none_fields(self, app: FastAPI, client: TestClient) -> None: """Test that None fields are excluded from error response.""" @app.get("/test") diff --git a/tests/unit/backend/test_memory_service.py b/tests/unit/backend/test_memory_service.py deleted file mode 100644 index 95b08f63e3..0000000000 --- a/tests/unit/backend/test_memory_service.py +++ /dev/null @@ -1,267 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Tests for backend memory service. -""" - -from datetime import datetime, timezone -from unittest.mock import MagicMock - -import pytest - -from pyrit.backend.models.common import PaginatedResponse -from pyrit.backend.services.memory_service import ( - MemoryService, - get_memory_service, - _parse_cursor, - _build_cursor, -) - - -class TestCursorFunctions: - """Tests for cursor parsing and building functions.""" - - def test_parse_cursor_with_valid_cursor(self) -> None: - """Test parsing a valid cursor string.""" - timestamp = datetime(2024, 1, 15, 10, 30, 0, tzinfo=timezone.utc) - cursor = f"{timestamp.isoformat()}_abc123" - - parsed_time, parsed_id = _parse_cursor(cursor) - - assert parsed_id == "abc123" - assert parsed_time is not None - assert parsed_time.year == 2024 - - def test_parse_cursor_with_none(self) -> None: - """Test parsing None cursor.""" - parsed_time, parsed_id = _parse_cursor(None) - - assert parsed_time is None - assert parsed_id is None - - def test_parse_cursor_with_empty_string(self) -> None: - """Test parsing empty cursor string.""" - parsed_time, parsed_id = _parse_cursor("") - - assert parsed_time is None - assert parsed_id is None - - def test_parse_cursor_with_invalid_format(self) -> None: - """Test parsing cursor with invalid format.""" - parsed_time, parsed_id = _parse_cursor("invalid_cursor_without_timestamp") - - assert parsed_time is None - assert parsed_id is None - - def test_parse_cursor_with_malformed_timestamp(self) -> None: - """Test parsing cursor with malformed timestamp.""" - parsed_time, parsed_id = _parse_cursor("not-a-timestamp_abc123") - - assert parsed_time is None - assert parsed_id is None - - def test_build_cursor_creates_valid_string(self) -> None: - """Test building a cursor string.""" - timestamp = datetime(2024, 1, 15, 10, 30, 0, tzinfo=timezone.utc) - record_id = "test-id-123" - - cursor = _build_cursor(timestamp, record_id) - - assert record_id in cursor - assert timestamp.isoformat() in cursor - - def test_cursor_roundtrip(self) -> None: - """Test that a cursor can be built and parsed back.""" - original_time = datetime(2024, 6, 15, 14, 30, 0, tzinfo=timezone.utc) - original_id = "message-uuid-123" - - cursor = _build_cursor(original_time, original_id) - parsed_time, parsed_id = _parse_cursor(cursor) - - assert parsed_id == original_id - assert parsed_time is not None - - -class TestMemoryService: - """Tests for MemoryService.""" - - @pytest.fixture - def service(self, patch_central_database: MagicMock) -> MemoryService: - """Create a memory service with patched database. - - Args: - patch_central_database: The patched central database fixture. - - Returns: - MemoryService: The service instance. - """ - return MemoryService() - - @pytest.mark.asyncio - async def test_get_messages_returns_paginated_result( - self, service: MemoryService - ) -> None: - """Test that get_messages returns paginated results.""" - result = await service.get_messages() - - assert isinstance(result, PaginatedResponse) - assert isinstance(result.items, list) - assert result.pagination is not None - - @pytest.mark.asyncio - async def test_get_messages_with_conversation_id( - self, service: MemoryService - ) -> None: - """Test filtering messages by conversation ID.""" - result = await service.get_messages(conversation_id="test-conv-id") - - assert isinstance(result, PaginatedResponse) - - @pytest.mark.asyncio - async def test_get_messages_respects_limit( - self, service: MemoryService - ) -> None: - """Test that limit parameter is respected.""" - result = await service.get_messages(limit=10) - - assert len(result.items) <= 10 - - @pytest.mark.asyncio - async def test_get_messages_with_role_filter( - self, service: MemoryService - ) -> None: - """Test filtering messages by role.""" - result = await service.get_messages(role="user") - - assert isinstance(result, PaginatedResponse) - - @pytest.mark.asyncio - async def test_get_messages_with_time_filters( - self, service: MemoryService - ) -> None: - """Test filtering messages by time range.""" - start = datetime(2024, 1, 1, tzinfo=timezone.utc) - end = datetime(2024, 12, 31, tzinfo=timezone.utc) - - result = await service.get_messages(start_time=start, end_time=end) - - assert isinstance(result, PaginatedResponse) - - @pytest.mark.asyncio - async def test_get_messages_pagination_has_more( - self, service: MemoryService - ) -> None: - """Test that pagination correctly reports has_more.""" - result = await service.get_messages(limit=1) - - assert isinstance(result.pagination.has_more, bool) - - @pytest.mark.asyncio - async def test_get_scores_returns_paginated_result( - self, service: MemoryService - ) -> None: - """Test that get_scores returns paginated results.""" - result = await service.get_scores() - - assert isinstance(result, PaginatedResponse) - assert isinstance(result.items, list) - assert result.pagination is not None - - @pytest.mark.asyncio - async def test_get_scores_with_message_id( - self, service: MemoryService - ) -> None: - """Test filtering scores by message ID.""" - result = await service.get_scores(message_id="test-message-id") - - assert isinstance(result, PaginatedResponse) - - @pytest.mark.asyncio - async def test_get_scores_with_score_type( - self, service: MemoryService - ) -> None: - """Test filtering scores by score type.""" - result = await service.get_scores(score_type="true_false") - - assert isinstance(result, PaginatedResponse) - - @pytest.mark.asyncio - async def test_get_attack_results_returns_paginated_result( - self, service: MemoryService - ) -> None: - """Test that get_attack_results returns paginated results.""" - result = await service.get_attack_results() - - assert isinstance(result, PaginatedResponse) - assert isinstance(result.items, list) - assert result.pagination is not None - - @pytest.mark.asyncio - async def test_get_attack_results_with_outcome_filter( - self, service: MemoryService - ) -> None: - """Test filtering attack results by outcome.""" - result = await service.get_attack_results(outcome="success") - - assert isinstance(result, PaginatedResponse) - - @pytest.mark.asyncio - async def test_get_attack_results_with_turn_filters( - self, service: MemoryService - ) -> None: - """Test filtering attack results by turn count.""" - result = await service.get_attack_results(min_turns=1, max_turns=10) - - assert isinstance(result, PaginatedResponse) - - @pytest.mark.asyncio - async def test_get_seeds_returns_paginated_result( - self, service: MemoryService - ) -> None: - """Test that get_seeds returns paginated results.""" - result = await service.get_seeds() - - assert isinstance(result, PaginatedResponse) - assert isinstance(result.items, list) - assert result.pagination is not None - - @pytest.mark.asyncio - async def test_get_scenario_results_returns_paginated_result( - self, service: MemoryService - ) -> None: - """Test that get_scenario_results returns paginated results.""" - result = await service.get_scenario_results() - - assert isinstance(result, PaginatedResponse) - assert isinstance(result.items, list) - assert result.pagination is not None - - -class TestGetMemoryServiceSingleton: - """Tests for get_memory_service singleton function.""" - - def test_returns_memory_service_instance( - self, patch_central_database: MagicMock - ) -> None: - """Test that get_memory_service returns a MemoryService.""" - import pyrit.backend.services.memory_service as module - - module._memory_service = None - - service = get_memory_service() - - assert isinstance(service, MemoryService) - - def test_returns_same_instance( - self, patch_central_database: MagicMock - ) -> None: - """Test that get_memory_service returns the same instance.""" - import pyrit.backend.services.memory_service as module - - module._memory_service = None - - service1 = get_memory_service() - service2 = get_memory_service() - - assert service1 is service2 diff --git a/tests/unit/backend/test_registry_service.py b/tests/unit/backend/test_registry_service.py index 95e964c280..0f8ac2b65b 100644 --- a/tests/unit/backend/test_registry_service.py +++ b/tests/unit/backend/test_registry_service.py @@ -5,8 +5,6 @@ Tests for backend registry service. """ - - from pyrit.backend.services.registry_service import ( RegistryService, _extract_params_schema, diff --git a/tests/unit/backend/test_routes.py b/tests/unit/backend/test_routes.py deleted file mode 100644 index 5a907bdedc..0000000000 --- a/tests/unit/backend/test_routes.py +++ /dev/null @@ -1,351 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Tests for backend API routes. -""" - -from datetime import datetime -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest -from fastapi import FastAPI -from fastapi.testclient import TestClient - -from pyrit.backend.models.common import PaginatedResponse, PaginationInfo -from pyrit.backend.models.conversations import CreateConversationResponse -from pyrit.backend.models.memory import MessageQueryResponse -from pyrit.backend.routes import health, version - - -class TestHealthRoute: - """Tests for health check endpoint.""" - - @pytest.fixture - def client(self) -> TestClient: - """Create a test client for health routes. - - Returns: - TestClient: The test client. - """ - app = FastAPI() - app.include_router(health.router) - return TestClient(app) - - def test_health_returns_200(self, client: TestClient) -> None: - """Test that health endpoint returns 200.""" - response = client.get("/health") - - assert response.status_code == 200 - - def test_health_returns_healthy_status(self, client: TestClient) -> None: - """Test that health endpoint returns healthy status.""" - response = client.get("/health") - data = response.json() - - assert data["status"] == "healthy" - - def test_health_returns_service_name(self, client: TestClient) -> None: - """Test that health endpoint returns service name.""" - response = client.get("/health") - data = response.json() - - assert data["service"] == "pyrit-backend" - - def test_health_returns_timestamp(self, client: TestClient) -> None: - """Test that health endpoint returns timestamp.""" - response = client.get("/health") - data = response.json() - - assert "timestamp" in data - # Verify it's a valid ISO format timestamp - datetime.fromisoformat(data["timestamp"].replace("Z", "+00:00")) - - -class TestVersionRoute: - """Tests for version endpoint.""" - - @pytest.fixture - def client(self) -> TestClient: - """Create a test client for version routes. - - Returns: - TestClient: The test client. - """ - app = FastAPI() - app.include_router(version.router) - return TestClient(app) - - def test_version_returns_200(self, client: TestClient) -> None: - """Test that version endpoint returns 200.""" - response = client.get("/api/version") - - assert response.status_code == 200 - - def test_version_returns_version_string(self, client: TestClient) -> None: - """Test that version endpoint returns version string.""" - response = client.get("/api/version") - data = response.json() - - assert "version" in data - assert isinstance(data["version"], str) - - def test_version_returns_display_string(self, client: TestClient) -> None: - """Test that version endpoint returns display string.""" - response = client.get("/api/version") - data = response.json() - - assert "display" in data - assert isinstance(data["display"], str) - - -class TestRegistryRoutes: - """Tests for registry endpoints.""" - - @pytest.fixture - def mock_service(self) -> MagicMock: - """Create a mock registry service. - - Returns: - MagicMock: The mock service. - """ - return MagicMock() - - @pytest.fixture - def client(self) -> TestClient: - """Create a test client for registry routes. - - Returns: - TestClient: The test client. - """ - from pyrit.backend.routes import registry - - app = FastAPI() - app.include_router(registry.router) - return TestClient(app) - - def test_list_targets_returns_200(self, client: TestClient) -> None: - """Test that list targets returns 200.""" - mock_service = MagicMock() - mock_service.get_targets.return_value = [] - - with patch( - "pyrit.backend.routes.registry.get_registry_service", - return_value=mock_service, - ): - response = client.get("/registry/targets") - - assert response.status_code == 200 - - def test_list_targets_returns_list(self, client: TestClient) -> None: - """Test that list targets returns a list.""" - mock_service = MagicMock() - mock_service.get_targets.return_value = [] - - with patch( - "pyrit.backend.routes.registry.get_registry_service", - return_value=mock_service, - ): - response = client.get("/registry/targets") - data = response.json() - - assert isinstance(data, list) - - def test_list_converters_returns_200(self, client: TestClient) -> None: - """Test that list converters returns 200.""" - mock_service = MagicMock() - mock_service.get_converters.return_value = [] - - with patch( - "pyrit.backend.routes.registry.get_registry_service", - return_value=mock_service, - ): - response = client.get("/registry/converters") - - assert response.status_code == 200 - - def test_list_scorers_returns_200(self, client: TestClient) -> None: - """Test that list scorers returns 200.""" - mock_service = MagicMock() - mock_service.get_scorers.return_value = [] - - with patch( - "pyrit.backend.routes.registry.get_registry_service", - return_value=mock_service, - ): - response = client.get("/registry/scorers") - - assert response.status_code == 200 - - def test_list_scenarios_returns_200(self, client: TestClient) -> None: - """Test that list scenarios returns 200.""" - mock_service = MagicMock() - mock_service.get_scenarios.return_value = [] - - with patch( - "pyrit.backend.routes.registry.get_registry_service", - return_value=mock_service, - ): - response = client.get("/registry/scenarios") - - assert response.status_code == 200 - - -class TestConversationRoutes: - """Tests for conversation endpoints.""" - - @pytest.fixture - def client(self, patch_central_database: MagicMock) -> TestClient: - """Create a test client for conversation routes. - - Args: - patch_central_database: The patched central database fixture. - - Returns: - TestClient: The test client. - """ - from pyrit.backend.routes import conversations - - app = FastAPI() - app.include_router(conversations.router) - return TestClient(app) - - def test_create_conversation_returns_201(self, client: TestClient, patch_central_database: MagicMock) -> None: - """Test that create conversation returns 201.""" - mock_service = MagicMock() - mock_response = CreateConversationResponse( - conversation_id="test-id", - target_identifier={"__type__": "TextTarget"}, - labels=None, - created_at=datetime.now(), - ) - mock_service.create_conversation = AsyncMock(return_value=mock_response) - - with patch( - "pyrit.backend.routes.conversations.get_conversation_service", - return_value=mock_service, - ): - response = client.post( - "/conversations", - json={ - "target_class": "TextTarget", - "target_params": None, - }, - ) - - assert response.status_code == 201 - - def test_get_conversation_returns_404_for_missing( - self, client: TestClient, patch_central_database: MagicMock - ) -> None: - """Test that get conversation returns 404 for missing.""" - mock_service = MagicMock() - mock_service.get_conversation = AsyncMock(return_value=None) - - with patch( - "pyrit.backend.routes.conversations.get_conversation_service", - return_value=mock_service, - ): - response = client.get("/conversations/nonexistent") - - assert response.status_code == 404 - - def test_delete_conversation_returns_204(self, client: TestClient, patch_central_database: MagicMock) -> None: - """Test that delete conversation returns 204.""" - mock_service = MagicMock() - # Must return a conversation state for delete to work - mock_service.get_conversation = AsyncMock(return_value=MagicMock(conversation_id="conv-1")) - mock_service.cleanup_conversation = MagicMock() - - with patch( - "pyrit.backend.routes.conversations.get_conversation_service", - return_value=mock_service, - ): - response = client.delete("/conversations/conv-1") - - assert response.status_code == 204 - - -class TestMemoryRoutes: - """Tests for memory endpoints.""" - - @pytest.fixture - def client(self) -> TestClient: - """Create a test client for memory routes. - - Returns: - TestClient: The test client. - """ - from pyrit.backend.routes import memory - - app = FastAPI() - app.include_router(memory.router) - return TestClient(app) - - def test_query_messages_returns_200(self, client: TestClient) -> None: - """Test that query messages returns 200.""" - mock_service = MagicMock() - mock_response = PaginatedResponse[MessageQueryResponse]( - items=[], - pagination=PaginationInfo(offset=0, limit=50, total=0, has_more=False), - ) - mock_service.get_messages = AsyncMock(return_value=mock_response) - - with patch( - "pyrit.backend.routes.memory.get_memory_service", - return_value=mock_service, - ): - response = client.get("/memory/messages") - - assert response.status_code == 200 - - def test_query_scores_returns_200(self, client: TestClient) -> None: - """Test that query scores returns 200.""" - mock_service = MagicMock() - mock_response = PaginatedResponse( - items=[], - pagination=PaginationInfo(offset=0, limit=50, total=0, has_more=False), - ) - mock_service.get_scores = AsyncMock(return_value=mock_response) - - with patch( - "pyrit.backend.routes.memory.get_memory_service", - return_value=mock_service, - ): - response = client.get("/memory/scores") - - assert response.status_code == 200 - - def test_query_attack_results_returns_200(self, client: TestClient) -> None: - """Test that query attack results returns 200.""" - mock_service = MagicMock() - mock_response = PaginatedResponse( - items=[], - pagination=PaginationInfo(offset=0, limit=50, total=0, has_more=False), - ) - mock_service.get_attack_results = AsyncMock(return_value=mock_response) - - with patch( - "pyrit.backend.routes.memory.get_memory_service", - return_value=mock_service, - ): - response = client.get("/memory/attack-results") - - assert response.status_code == 200 - - def test_query_seeds_returns_200(self, client: TestClient) -> None: - """Test that query seeds returns 200.""" - mock_service = MagicMock() - mock_response = PaginatedResponse( - items=[], - pagination=PaginationInfo(offset=0, limit=50, total=0, has_more=False), - ) - mock_service.get_seeds = AsyncMock(return_value=mock_response) - - with patch( - "pyrit.backend.routes.memory.get_memory_service", - return_value=mock_service, - ): - response = client.get("/memory/seeds") - - assert response.status_code == 200 diff --git a/tests/unit/backend/test_target_service.py b/tests/unit/backend/test_target_service.py new file mode 100644 index 0000000000..226672b5fa --- /dev/null +++ b/tests/unit/backend/test_target_service.py @@ -0,0 +1,368 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tests for backend target service. +""" + +from datetime import datetime, timezone +from typing import Any, Dict +from unittest.mock import MagicMock, patch + +import pytest + +from pyrit.backend.models.targets import CreateTargetRequest, TargetInstance +from pyrit.backend.services.target_service import TargetService + + +class TestTargetServiceInit: + """Tests for TargetService initialization.""" + + def test_init_creates_empty_instances_dict(self) -> None: + """Test that service initializes with empty instances dictionary.""" + service = TargetService() + assert service._instances == {} + + def test_init_creates_empty_target_objects_dict(self) -> None: + """Test that service initializes with empty target objects dictionary.""" + service = TargetService() + assert service._target_objects == {} + + +class TestListTargets: + """Tests for TargetService.list_targets method.""" + + @pytest.mark.asyncio + async def test_list_targets_returns_empty_when_no_targets(self) -> None: + """Test that list_targets returns empty list when no targets exist.""" + service = TargetService() + + result = await service.list_targets() + + assert result.items == [] + + @pytest.mark.asyncio + async def test_list_targets_returns_targets(self) -> None: + """Test that list_targets returns existing targets.""" + service = TargetService() + now = datetime.now(timezone.utc) + + service._instances["target-1"] = TargetInstance( + target_id="target-1", + type="TextTarget", + display_name="My Target", + params={}, + created_at=now, + source="user", + ) + + result = await service.list_targets() + + assert len(result.items) == 1 + assert result.items[0].target_id == "target-1" + assert result.items[0].display_name == "My Target" + + @pytest.mark.asyncio + async def test_list_targets_filters_by_source_user(self) -> None: + """Test that list_targets filters by source='user'.""" + service = TargetService() + now = datetime.now(timezone.utc) + + service._instances["target-1"] = TargetInstance( + target_id="target-1", + type="TextTarget", + params={}, + created_at=now, + source="user", + ) + service._instances["target-2"] = TargetInstance( + target_id="target-2", + type="TextTarget", + params={}, + created_at=now, + source="initializer", + ) + + result = await service.list_targets(source="user") + + assert len(result.items) == 1 + assert result.items[0].source == "user" + + @pytest.mark.asyncio + async def test_list_targets_filters_by_source_initializer(self) -> None: + """Test that list_targets filters by source='initializer'.""" + service = TargetService() + now = datetime.now(timezone.utc) + + service._instances["target-1"] = TargetInstance( + target_id="target-1", + type="TextTarget", + params={}, + created_at=now, + source="user", + ) + service._instances["target-2"] = TargetInstance( + target_id="target-2", + type="TextTarget", + params={}, + created_at=now, + source="initializer", + ) + + result = await service.list_targets(source="initializer") + + assert len(result.items) == 1 + assert result.items[0].source == "initializer" + + +class TestGetTarget: + """Tests for TargetService.get_target method.""" + + @pytest.mark.asyncio + async def test_get_target_returns_none_for_nonexistent(self) -> None: + """Test that get_target returns None for non-existent target.""" + service = TargetService() + + result = await service.get_target("nonexistent-id") + + assert result is None + + @pytest.mark.asyncio + async def test_get_target_returns_target(self) -> None: + """Test that get_target returns the target instance.""" + service = TargetService() + now = datetime.now(timezone.utc) + + service._instances["target-1"] = TargetInstance( + target_id="target-1", + type="TextTarget", + display_name="Test Target", + params={"key": "value"}, + created_at=now, + source="user", + ) + + result = await service.get_target("target-1") + + assert result is not None + assert result.target_id == "target-1" + assert result.display_name == "Test Target" + + +class TestGetTargetObject: + """Tests for TargetService.get_target_object method.""" + + def test_get_target_object_returns_none_for_nonexistent(self) -> None: + """Test that get_target_object returns None for non-existent target.""" + service = TargetService() + + result = service.get_target_object("nonexistent-id") + + assert result is None + + def test_get_target_object_returns_object(self) -> None: + """Test that get_target_object returns the actual target object.""" + service = TargetService() + mock_target = MagicMock() + service._target_objects["target-1"] = mock_target + + result = service.get_target_object("target-1") + + assert result is mock_target + + +class TestGetTargetClass: + """Tests for TargetService._get_target_class method.""" + + def test_get_target_class_raises_for_invalid_type(self) -> None: + """Test that _get_target_class raises ValueError for invalid type.""" + service = TargetService() + + with pytest.raises(ValueError, match="not found"): + service._get_target_class("NonExistentTarget") + + def test_get_target_class_finds_text_target(self) -> None: + """Test that _get_target_class finds TextTarget.""" + service = TargetService() + + # TextTarget should exist in pyrit.prompt_target + result = service._get_target_class("TextTarget") + + assert result is not None + assert "TextTarget" in result.__name__ + + +class TestCreateTarget: + """Tests for TargetService.create_target method.""" + + @pytest.mark.asyncio + async def test_create_target_raises_for_invalid_type(self) -> None: + """Test that create_target raises for invalid target type.""" + service = TargetService() + + request = CreateTargetRequest( + type="NonExistentTarget", + params={}, + ) + + with pytest.raises(ValueError, match="not found"): + await service.create_target(request) + + @pytest.mark.asyncio + async def test_create_target_success(self) -> None: + """Test successful target creation.""" + service = TargetService() + + # Use a target that doesn't require external dependencies + request = CreateTargetRequest( + type="TextTarget", + display_name="My Text Target", + params={}, + ) + + result = await service.create_target(request) + + assert result.target_id is not None + assert result.type == "TextTarget" + assert result.display_name == "My Text Target" + assert result.source == "user" + + @pytest.mark.asyncio + async def test_create_target_stores_instance(self) -> None: + """Test that create_target stores the instance.""" + service = TargetService() + + request = CreateTargetRequest( + type="TextTarget", + params={}, + ) + + result = await service.create_target(request) + + assert result.target_id in service._instances + assert result.target_id in service._target_objects + + @pytest.mark.asyncio + async def test_create_target_filters_sensitive_params(self) -> None: + """Test that create_target filters sensitive parameters.""" + service = TargetService() + + # Create a mock target class that has sensitive identifier fields + mock_target_class = MagicMock() + mock_target_instance = MagicMock() + mock_target_instance.get_identifier.return_value = { + "type": "MockTarget", + "api_key": "secret-key", + "endpoint": "https://api.example.com", + } + mock_target_class.return_value = mock_target_instance + + with patch.object(service, "_get_target_class", return_value=mock_target_class): + request = CreateTargetRequest( + type="MockTarget", + params={}, + ) + + result = await service.create_target(request) + + # api_key should be filtered out + assert "api_key" not in result.params + # endpoint should remain + assert result.params.get("endpoint") == "https://api.example.com" + + +class TestDeleteTarget: + """Tests for TargetService.delete_target method.""" + + @pytest.mark.asyncio + async def test_delete_target_returns_false_for_nonexistent(self) -> None: + """Test that delete_target returns False for non-existent target.""" + service = TargetService() + + result = await service.delete_target("nonexistent") + + assert result is False + + @pytest.mark.asyncio + async def test_delete_target_deletes_target(self) -> None: + """Test that delete_target removes the target.""" + service = TargetService() + now = datetime.now(timezone.utc) + + service._instances["target-1"] = TargetInstance( + target_id="target-1", + type="TextTarget", + params={}, + created_at=now, + source="user", + ) + service._target_objects["target-1"] = MagicMock() + + result = await service.delete_target("target-1") + + assert result is True + assert "target-1" not in service._instances + assert "target-1" not in service._target_objects + + +class TestRegisterInitializerTarget: + """Tests for TargetService.register_initializer_target method.""" + + @pytest.mark.asyncio + async def test_register_initializer_target_creates_instance(self) -> None: + """Test that register_initializer_target creates an instance.""" + service = TargetService() + mock_target = MagicMock() + mock_target.get_identifier.return_value = {"type": "MockTarget"} + + result = await service.register_initializer_target( + target_type="MockTarget", + target_obj=mock_target, + display_name="Initializer Target", + ) + + assert result.target_id is not None + assert result.type == "MockTarget" + assert result.display_name == "Initializer Target" + assert result.source == "initializer" + + @pytest.mark.asyncio + async def test_register_initializer_target_stores_object(self) -> None: + """Test that register_initializer_target stores the target object.""" + service = TargetService() + mock_target = MagicMock() + mock_target.get_identifier.return_value = {} + + result = await service.register_initializer_target( + target_type="MockTarget", + target_obj=mock_target, + ) + + assert service._target_objects[result.target_id] is mock_target + + +class TestTargetServiceSingleton: + """Tests for get_target_service singleton function.""" + + def test_get_target_service_returns_target_service(self) -> None: + """Test that get_target_service returns a TargetService instance.""" + from pyrit.backend.services.target_service import get_target_service + + # Reset singleton for clean test + import pyrit.backend.services.target_service as module + module._target_service = None + + service = get_target_service() + assert isinstance(service, TargetService) + + def test_get_target_service_returns_same_instance(self) -> None: + """Test that get_target_service returns the same instance.""" + from pyrit.backend.services.target_service import get_target_service + + # Reset singleton for clean test + import pyrit.backend.services.target_service as module + module._target_service = None + + service1 = get_target_service() + service2 = get_target_service() + assert service1 is service2 From cef47dba1635c72c6ba72fb89209b516f5b96a31 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Wed, 28 Jan 2026 10:39:55 -0800 Subject: [PATCH 03/35] more refinements to simplify and remove obsolete parts --- frontend/playwright-report/index.html | 85 +++++++ frontend/test-results/.last-run.json | 4 + pyrit/backend/main.py | 2 - pyrit/backend/models/__init__.py | 8 +- pyrit/backend/models/attacks.py | 82 +++--- pyrit/backend/routes/attacks.py | 25 +- pyrit/backend/services/attack_service.py | 230 ++++++++--------- pyrit/backend/services/converter_service.py | 25 +- pyrit/backend/services/target_service.py | 13 +- tests/unit/backend/test_api_routes.py | 81 +++--- tests/unit/backend/test_attack_service.py | 248 +++++++++++++++---- tests/unit/backend/test_converter_service.py | 10 +- tests/unit/backend/test_target_service.py | 9 +- 13 files changed, 515 insertions(+), 307 deletions(-) create mode 100644 frontend/playwright-report/index.html create mode 100644 frontend/test-results/.last-run.json diff --git a/frontend/playwright-report/index.html b/frontend/playwright-report/index.html new file mode 100644 index 0000000000..f6ad474f7f --- /dev/null +++ b/frontend/playwright-report/index.html @@ -0,0 +1,85 @@ + + + + + + + + + Playwright Test Report + + + + +
+ + + \ No newline at end of file diff --git a/frontend/test-results/.last-run.json b/frontend/test-results/.last-run.json new file mode 100644 index 0000000000..cbcc1fbac1 --- /dev/null +++ b/frontend/test-results/.last-run.json @@ -0,0 +1,4 @@ +{ + "status": "passed", + "failedTests": [] +} \ No newline at end of file diff --git a/pyrit/backend/main.py b/pyrit/backend/main.py index 49bc824397..b6bb6988dd 100644 --- a/pyrit/backend/main.py +++ b/pyrit/backend/main.py @@ -3,8 +3,6 @@ """ FastAPI application entry point for PyRIT backend. - -This is the attack-centric API - all interactions are modeled as "attacks". """ import os diff --git a/pyrit/backend/models/__init__.py b/pyrit/backend/models/__init__.py index 191f6c85cb..3ec408399a 100644 --- a/pyrit/backend/models/__init__.py +++ b/pyrit/backend/models/__init__.py @@ -8,6 +8,8 @@ """ from pyrit.backend.models.attacks import ( + AddMessageRequest, + AddMessageResponse, AttackDetail, AttackListResponse, AttackSummary, @@ -18,8 +20,6 @@ MessagePieceRequest, PrependedMessageRequest, Score, - SendMessageRequest, - SendMessageResponse, UpdateAttackRequest, ) from pyrit.backend.models.common import ( @@ -65,6 +65,8 @@ __all__ = [ # Attacks + "AddMessageRequest", + "AddMessageResponse", "AttackDetail", "AttackListResponse", "AttackSummary", @@ -75,8 +77,6 @@ "MessagePieceRequest", "PrependedMessageRequest", "Score", - "SendMessageRequest", - "SendMessageResponse", "UpdateAttackRequest", # Common "ALLOWED_IDENTIFIER_FIELDS", diff --git a/pyrit/backend/models/attacks.py b/pyrit/backend/models/attacks.py index fd0d851d84..5d282ad28b 100644 --- a/pyrit/backend/models/attacks.py +++ b/pyrit/backend/models/attacks.py @@ -9,12 +9,12 @@ """ from datetime import datetime -from typing import Any, Dict, List, Literal, Optional +from typing import Dict, List, Literal, Optional from pydantic import BaseModel, Field from pyrit.backend.models.common import PaginationInfo -from pyrit.backend.models.converters import InlineConverterConfig +from pyrit.models import PromptResponseError class Score(BaseModel): @@ -36,12 +36,18 @@ class MessagePiece(BaseModel): """ piece_id: str = Field(..., description="Unique piece identifier") - data_type: str = Field(..., description="Data type: 'text', 'image', 'audio', 'video', etc.") - original_value: Optional[str] = Field(None, description="Original value before conversion") - original_value_mime_type: Optional[str] = Field(None, description="MIME type of original value") + data_type: str = Field(default="text", description="Data type: 'text', 'image', 'audio', 'video', etc.") + original_value: Optional[str] = Field(default=None, description="Original value before conversion") + original_value_mime_type: Optional[str] = Field(default=None, description="MIME type of original value") converted_value: str = Field(..., description="Converted value (text or base64 for media)") - converted_value_mime_type: Optional[str] = Field(None, description="MIME type of converted value") + converted_value_mime_type: Optional[str] = Field(default=None, description="MIME type of converted value") scores: List[Score] = Field(default_factory=list, description="Scores embedded in this piece") + response_error: PromptResponseError = Field( + default="none", description="Error status: none, processing, blocked, empty, unknown" + ) + response_error_description: Optional[str] = Field( + default=None, description="Description of the error if response_error is not 'none'" + ) class Message(BaseModel): @@ -73,6 +79,7 @@ class AttackSummary(BaseModel): None, description="Preview of the last message (truncated to ~100 chars)" ) message_count: int = Field(0, description="Total number of messages in the attack") + labels: Dict[str, str] = Field(default_factory=dict, description="User-defined labels for filtering") created_at: datetime = Field(..., description="Attack creation timestamp") updated_at: datetime = Field(..., description="Last update timestamp") @@ -89,14 +96,12 @@ class AttackDetail(BaseModel): name: Optional[str] = Field(None, description="Attack name/label") target_id: str = Field(..., description="Target instance ID") target_type: str = Field(..., description="Target type (e.g., 'azure_openai')") - outcome: Optional[Literal["pending", "success", "failure"]] = Field( - None, description="Attack outcome" - ) + outcome: Optional[Literal["pending", "success", "failure"]] = Field(None, description="Attack outcome") prepended_conversation: List[Message] = Field( default_factory=list, description="Prepended messages (system prompts, branching context)" ) messages: List[Message] = Field(default_factory=list, description="Attack messages in order") - converter_ids: List[str] = Field(default_factory=list, description="Converter instance IDs applied") + labels: Dict[str, str] = Field(default_factory=dict, description="User-defined labels for filtering") created_at: datetime = Field(..., description="Attack creation timestamp") updated_at: datetime = Field(..., description="Last update timestamp") @@ -133,24 +138,14 @@ class CreateAttackRequest(BaseModel): prepended_conversation: Optional[List[PrependedMessageRequest]] = Field( None, description="Messages to prepend (system prompts, branching context)" ) - converter_ids: Optional[List[str]] = Field( - None, description="Converter instance IDs to apply to user messages" - ) + labels: Optional[Dict[str, str]] = Field(None, description="User-defined labels for filtering") class CreateAttackResponse(BaseModel): """Response after creating an attack.""" attack_id: str = Field(..., description="Unique attack identifier") - name: Optional[str] = Field(None, description="Attack name/label") - target_id: str = Field(..., description="Target instance ID") - target_type: str = Field(..., description="Target type") - outcome: Optional[str] = Field(None, description="Attack outcome (initially null)") - prepended_conversation: List[Message] = Field( - default_factory=list, description="Prepended messages (converted to Message format)" - ) created_at: datetime = Field(..., description="Attack creation timestamp") - updated_at: datetime = Field(..., description="Last update timestamp") # ============================================================================ @@ -161,39 +156,50 @@ class CreateAttackResponse(BaseModel): class UpdateAttackRequest(BaseModel): """Request to update an attack's outcome.""" - outcome: Literal["pending", "success", "failure"] = Field( - ..., description="Updated attack outcome" - ) + outcome: Literal["pending", "success", "failure"] = Field(..., description="Updated attack outcome") # ============================================================================ -# Send Message +# Add Message # ============================================================================ class MessagePieceRequest(BaseModel): - """A piece of content to send in a message.""" + """A piece of content for a message.""" data_type: str = Field(default="text", description="Data type: 'text', 'image', 'audio', etc.") - content: str = Field(..., description="Content to send (text or base64 for media)") + content: str = Field(..., description="Content (text or base64 for media)") mime_type: Optional[str] = Field(None, description="MIME type for media content") -class SendMessageRequest(BaseModel): - """Request to send a message within an attack.""" +class AddMessageRequest(BaseModel): + """ + Request to add a message to an attack. + + If send=True (default for user role), the message is sent to the target + and we wait for a response. If send=False, the message is just stored + in memory without sending (useful for system messages, context injection). + """ - pieces: List[MessagePieceRequest] = Field(..., description="Message pieces to send") + role: Literal["user", "assistant", "system"] = Field(default="user", description="Message role") + pieces: List[MessagePieceRequest] = Field(..., description="Message pieces") + send: bool = Field( + default=True, + description="If True, send to target and wait for response. If False, just store in memory.", + ) converter_ids: Optional[List[str]] = Field( None, description="Converter instance IDs to apply (overrides attack-level)" ) - converters: Optional[List[InlineConverterConfig]] = Field( - None, description="Inline converter definitions (for one-off use)" - ) -class SendMessageResponse(BaseModel): - """Response after sending a message.""" +class AddMessageResponse(BaseModel): + """ + Response after adding a message. + + Returns the updated attack detail. If send=True was used, the new + assistant response will be in the messages list. Check response_error + on the assistant's message pieces if the target returned an error. + """ - user_message: Message = Field(..., description="The user message that was sent") - assistant_message: Message = Field(..., description="The assistant's response") - attack_summary: AttackSummary = Field(..., description="Updated attack summary") + attack: AttackDetail = Field(..., description="Updated attack with new message(s)") + error: Optional[str] = Field(None, description="Transport-level error if send=True and request failed entirely") diff --git a/pyrit/backend/routes/attacks.py b/pyrit/backend/routes/attacks.py index 23d6a316da..e6478d0cbf 100644 --- a/pyrit/backend/routes/attacks.py +++ b/pyrit/backend/routes/attacks.py @@ -13,12 +13,12 @@ from fastapi import APIRouter, HTTPException, Query, status from pyrit.backend.models.attacks import ( + AddMessageRequest, + AddMessageResponse, AttackDetail, AttackListResponse, CreateAttackRequest, CreateAttackResponse, - SendMessageRequest, - SendMessageResponse, UpdateAttackRequest, ) from pyrit.backend.models.common import ProblemDetail @@ -147,21 +147,22 @@ async def update_attack( @router.post( "/{attack_id}/messages", - response_model=SendMessageResponse, + response_model=AddMessageResponse, responses={ 404: {"model": ProblemDetail, "description": "Attack not found"}, 400: {"model": ProblemDetail, "description": "Message send failed"}, }, ) -async def send_message( +async def add_message( attack_id: str, - request: SendMessageRequest, -) -> SendMessageResponse: + request: AddMessageRequest, +) -> AddMessageResponse: """ - Send a message in an attack. + Add a message to an attack. - Sends the user message to the target, applies converters, and returns - both the user message and assistant response. + If send=True (default), sends the message to the target and waits for a response. + If send=False, just stores the message in memory without sending (useful for + system messages, context injection, or replaying assistant responses). Converters can be specified at three levels (in priority order): 1. request.converter_ids - per-message converter instances @@ -169,12 +170,12 @@ async def send_message( 3. attack.converter_ids - attack-level defaults Returns: - SendMessageResponse: The sent message and assistant response. + AddMessageResponse: Updated attack with new message(s). """ service = get_attack_service() try: - return await service.send_message(attack_id, request) + return await service.add_message(attack_id, request) except ValueError as e: error_msg = str(e) if "not found" in error_msg.lower(): @@ -189,7 +190,7 @@ async def send_message( except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to send message: {str(e)}", + detail=f"Failed to add message: {str(e)}", ) diff --git a/pyrit/backend/services/attack_service.py b/pyrit/backend/services/attack_service.py index 1c8ece7976..11eaed10bb 100644 --- a/pyrit/backend/services/attack_service.py +++ b/pyrit/backend/services/attack_service.py @@ -11,11 +11,13 @@ import uuid from collections import defaultdict from datetime import datetime, timezone -from typing import Any, Dict, List, Literal, Optional +from typing import Dict, List, Literal, Optional, cast from pydantic import BaseModel from pyrit.backend.models.attacks import ( + AddMessageRequest, + AddMessageResponse, AttackDetail, AttackListResponse, AttackSummary, @@ -23,11 +25,6 @@ CreateAttackResponse, Message, MessagePiece, - MessagePieceRequest, - PrependedMessageRequest, - Score, - SendMessageRequest, - SendMessageResponse, UpdateAttackRequest, ) from pyrit.backend.models.common import PaginationInfo @@ -45,7 +42,6 @@ class AttackState(BaseModel): target_type: str outcome: Optional[Literal["pending", "success", "failure"]] = None prepended_conversation: List[Message] = [] - converter_ids: List[str] = [] message_count: int = 0 created_at: datetime updated_at: datetime @@ -163,7 +159,6 @@ async def get_attack(self, attack_id: str) -> Optional[AttackDetail]: outcome=state.outcome, prepended_conversation=state.prepended_conversation, messages=messages, - converter_ids=state.converter_ids, created_at=state.created_at, updated_at=state.updated_at, ) @@ -212,13 +207,6 @@ async def create_attack( ) prepended_messages.append(msg) - # Validate converter IDs if provided - if request.converter_ids: - converter_service = get_converter_service() - for conv_id in request.converter_ids: - if await converter_service.get_converter(conv_id) is None: - raise ValueError(f"Converter instance '{conv_id}' not found") - state = AttackState( attack_id=attack_id, name=request.name, @@ -226,7 +214,6 @@ async def create_attack( target_type=target_instance.type, outcome=None, prepended_conversation=prepended_messages, - converter_ids=request.converter_ids or [], message_count=0, created_at=now, updated_at=now, @@ -235,13 +222,7 @@ async def create_attack( return CreateAttackResponse( attack_id=attack_id, - name=request.name, - target_id=request.target_id, - target_type=target_instance.type, - outcome=None, - prepended_conversation=prepended_messages, created_at=now, - updated_at=now, ) async def update_attack( @@ -268,20 +249,23 @@ async def update_attack( return await self.get_attack(attack_id) - async def send_message( + async def add_message( self, attack_id: str, - request: SendMessageRequest, - ) -> SendMessageResponse: + request: AddMessageRequest, + ) -> AddMessageResponse: """ - Send a message in an attack and get response. + Add a message to an attack. + + If send=True, sends to target and waits for response. + If send=False, just stores the message in memory. Args: attack_id: Attack ID - request: Message send request + request: Add message request Returns: - SendMessageResponse: User and assistant messages + AddMessageResponse: Updated attack detail """ state = self._attacks.get(attack_id) if not state: @@ -290,25 +274,17 @@ async def send_message( target_service = get_target_service() converter_service = get_converter_service() - target_obj = target_service.get_target_object(state.target_id) - if not target_obj: - raise ValueError(f"Target object for '{state.target_id}' not found") - now = datetime.now(timezone.utc) state.message_count += 1 - user_turn = state.message_count + msg_turn = state.message_count - # Determine which converters to use + # Determine which converters to use (only for user messages being sent) converters = [] - if request.converter_ids: + if request.send and request.role == "user" and request.converter_ids: converters = converter_service.get_converter_objects_for_ids(request.converter_ids) - elif request.converters: - converters = converter_service.instantiate_inline_converters(request.converters) - elif state.converter_ids: - converters = converter_service.get_converter_objects_for_ids(state.converter_ids) - # Build user message pieces - user_pieces: List[MessagePiece] = [] + # Build message pieces + msg_pieces: List[MessagePiece] = [] for piece_req in request.pieces: original_value = piece_req.content converted_value = original_value @@ -318,7 +294,7 @@ async def send_message( result = await converter.convert_async(prompt=converted_value) converted_value = result.output_text - user_pieces.append( + msg_pieces.append( MessagePiece( piece_id=str(uuid.uuid4()), data_type=piece_req.data_type, @@ -330,103 +306,102 @@ async def send_message( ) ) - user_message = Message( + message = Message( message_id=str(uuid.uuid4()), - turn_number=user_turn, - role="user", - pieces=user_pieces, + turn_number=msg_turn, + role=request.role, + pieces=msg_pieces, created_at=now, ) - # Store user message - self._messages[attack_id].append(user_message) - - # Build conversation for target (prepended + all messages) - from pyrit.models import Message as PyritMessage, MessagePiece as PyritMessagePiece - - # Create prompt pieces for target - user_prompt_pieces = [] - for piece in user_pieces: - pyrit_piece = PyritMessagePiece( - role="user", - original_value=piece.original_value or "", - original_value_data_type=piece.data_type, - converted_value=piece.converted_value, - converted_value_data_type=piece.data_type, - conversation_id=attack_id, - sequence=user_turn, - ) - user_prompt_pieces.append(pyrit_piece) - - user_pyrit_message = PyritMessage(user_prompt_pieces) - - # Send to target - response_messages = await target_obj.send_prompt_async(message=user_pyrit_message) - - # Build assistant response - state.message_count += 1 - assistant_turn = state.message_count - - assistant_pieces: List[MessagePiece] = [] - if response_messages: - for resp_msg in response_messages: - for resp_piece in resp_msg.message_pieces: - assistant_pieces.append( + # Store the message + self._messages[attack_id].append(message) + + # If send=True, send to target and get response + transport_error: Optional[str] = None + if request.send: + target_obj = target_service.get_target_object(state.target_id) + if not target_obj: + raise ValueError(f"Target object for '{state.target_id}' not found") + + try: + # Build conversation for target + from pyrit.models import Message as PyritMessage + from pyrit.models import MessagePiece as PyritMessagePiece + from pyrit.models import PromptDataType + + # Create prompt pieces for target + prompt_pieces = [] + for piece in msg_pieces: + pyrit_piece = PyritMessagePiece( + role=request.role, + original_value=piece.original_value or "", + original_value_data_type=cast(PromptDataType, piece.data_type), + converted_value=piece.converted_value, + converted_value_data_type=cast(PromptDataType, piece.data_type), + conversation_id=attack_id, + sequence=msg_turn, + ) + prompt_pieces.append(pyrit_piece) + + pyrit_message = PyritMessage(prompt_pieces) + + # Send to target + response_messages = await target_obj.send_prompt_async(message=pyrit_message) + + # Build assistant response + state.message_count += 1 + assistant_turn = state.message_count + + assistant_pieces: List[MessagePiece] = [] + if response_messages: + for resp_msg in response_messages: + for resp_piece in resp_msg.message_pieces: + assistant_pieces.append( + MessagePiece( + piece_id=str(uuid.uuid4()), + data_type=resp_piece.converted_value_data_type or "text", + original_value=resp_piece.original_value, + converted_value=resp_piece.converted_value or "", + scores=[], + response_error=resp_piece.response_error, + ) + ) + + assistant_message = Message( + message_id=str(uuid.uuid4()), + turn_number=assistant_turn, + role="assistant", + pieces=assistant_pieces + if assistant_pieces + else [ MessagePiece( piece_id=str(uuid.uuid4()), - data_type=resp_piece.converted_value_data_type or "text", - original_value=resp_piece.original_value, - converted_value=resp_piece.converted_value or "", + data_type="text", + converted_value="", scores=[], ) - ) - - assistant_message = Message( - message_id=str(uuid.uuid4()), - turn_number=assistant_turn, - role="assistant", - pieces=assistant_pieces if assistant_pieces else [ - MessagePiece( - piece_id=str(uuid.uuid4()), - data_type="text", - converted_value="", - scores=[], + ], + created_at=datetime.now(timezone.utc), ) - ], - created_at=datetime.now(timezone.utc), - ) - # Store assistant message - self._messages[attack_id].append(assistant_message) + # Store assistant message + self._messages[attack_id].append(assistant_message) + + except Exception as e: + transport_error = str(e) # Update attack timestamp state.updated_at = datetime.now(timezone.utc) - # Build summary - messages = self._messages[attack_id] - last_message_preview = None - if messages: - last_msg = messages[-1] - if last_msg.pieces: - preview_text = last_msg.pieces[0].converted_value - last_message_preview = preview_text[:100] + "..." if len(preview_text) > 100 else preview_text - - attack_summary = AttackSummary( - attack_id=state.attack_id, - name=state.name, - target_id=state.target_id, - target_type=state.target_type, - outcome=state.outcome, - last_message_preview=last_message_preview, - message_count=len(messages), - created_at=state.created_at, - updated_at=state.updated_at, - ) + # Get updated attack detail + attack_detail = await self.get_attack(attack_id) + if attack_detail is None: + raise ValueError(f"Attack '{attack_id}' not found after update") - return SendMessageResponse( - user_message=user_message, - assistant_message=assistant_message, - attack_summary=attack_summary, + return AddMessageResponse( + attack=attack_detail, + error=transport_error, ) async def delete_attack(self, attack_id: str) -> bool: @@ -451,7 +426,12 @@ async def delete_attack(self, attack_id: str) -> bool: def get_attack_service() -> AttackService: - """Get the global attack service instance.""" + """ + Get the global attack service instance. + + Returns: + AttackService: The singleton attack service instance. + """ global _attack_service if _attack_service is None: _attack_service = AttackService() diff --git a/pyrit/backend/services/converter_service.py b/pyrit/backend/services/converter_service.py index af3e96133b..c7afb8b0da 100644 --- a/pyrit/backend/services/converter_service.py +++ b/pyrit/backend/services/converter_service.py @@ -10,7 +10,7 @@ import importlib import uuid from datetime import datetime, timezone -from typing import Any, Dict, List, Literal, Optional, Tuple +from typing import Any, Dict, List, Literal, Optional, Tuple, cast from pyrit.backend.models.converters import ( ConverterInstance, @@ -50,7 +50,7 @@ def _get_converter_class(self, converter_type: str) -> type: # Try direct attribute lookup first cls = getattr(module, converter_type, None) if cls is not None: - return cls + return cast(type, cls) # Try common class name patterns class_name_patterns = [ @@ -63,7 +63,7 @@ def _get_converter_class(self, converter_type: str) -> type: for pattern in class_name_patterns: cls = getattr(module, pattern, None) if cls is not None: - return cls + return cast(type, cls) raise ValueError(f"Converter type '{converter_type}' not found in pyrit.prompt_converter") @@ -91,9 +91,7 @@ def _create_converter_recursive( nested_config = params["converter"] if "type" in nested_config: # Recursively create nested converter - nested_id, nested_obj, nested_instances = self._create_converter_recursive( - nested_config, source - ) + nested_id, nested_obj, nested_instances = self._create_converter_recursive(nested_config, source) created_instances.extend(nested_instances) # Replace inline config with the actual converter object params["converter"] = nested_obj @@ -193,9 +191,7 @@ async def create_converter( "params": request.params, } - converter_id, converter_obj, created_instances = self._create_converter_recursive( - config, "user" - ) + converter_id, converter_obj, created_instances = self._create_converter_recursive(config, "user") # Update display name for the outermost converter if request.display_name and converter_id in self._instances: @@ -312,9 +308,7 @@ def get_converter_objects_for_ids(self, converter_ids: List[str]) -> List[Any]: converters.append(conv_obj) return converters - def instantiate_inline_converters( - self, configs: List[InlineConverterConfig] - ) -> List[Any]: + def instantiate_inline_converters(self, configs: List[InlineConverterConfig]) -> List[Any]: """ Instantiate converters from inline configurations. @@ -337,7 +331,12 @@ def instantiate_inline_converters( def get_converter_service() -> ConverterService: - """Get the global converter service instance.""" + """ + Get the global converter service instance. + + Returns: + ConverterService: The singleton converter service instance. + """ global _converter_service if _converter_service is None: _converter_service = ConverterService() diff --git a/pyrit/backend/services/target_service.py b/pyrit/backend/services/target_service.py index b58ef93985..7b847e676f 100644 --- a/pyrit/backend/services/target_service.py +++ b/pyrit/backend/services/target_service.py @@ -10,7 +10,7 @@ import importlib import uuid from datetime import datetime, timezone -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Dict, Literal, Optional, cast from pyrit.backend.models.common import filter_sensitive_fields from pyrit.backend.models.targets import ( @@ -48,7 +48,7 @@ def _get_target_class(self, target_type: str) -> type: # First try direct attribute lookup cls = getattr(module, target_type, None) if cls is not None: - return cls + return cast(type, cls) # Try common class name patterns class_name_patterns = [ @@ -61,7 +61,7 @@ def _get_target_class(self, target_type: str) -> type: for pattern in class_name_patterns: cls = getattr(module, pattern, None) if cls is not None: - return cls + return cast(type, cls) raise ValueError(f"Target type '{target_type}' not found in pyrit.prompt_target") @@ -215,7 +215,12 @@ async def register_initializer_target( def get_target_service() -> TargetService: - """Get the global target service instance.""" + """ + Get the global target service instance. + + Returns: + The singleton TargetService instance. + """ global _target_service if _target_service is None: _target_service = TargetService() diff --git a/tests/unit/backend/test_api_routes.py b/tests/unit/backend/test_api_routes.py index 12534cc891..e02880c7e5 100644 --- a/tests/unit/backend/test_api_routes.py +++ b/tests/unit/backend/test_api_routes.py @@ -15,13 +15,13 @@ from pyrit.backend.main import app from pyrit.backend.models.attacks import ( + AddMessageResponse, AttackDetail, AttackListResponse, AttackSummary, CreateAttackResponse, Message, MessagePiece, - SendMessageResponse, ) from pyrit.backend.models.common import PaginationInfo from pyrit.backend.models.converters import ( @@ -104,13 +104,7 @@ def test_create_attack_success(self, client: TestClient) -> None: mock_service.create_attack = AsyncMock( return_value=CreateAttackResponse( attack_id="attack-1", - name="Test Attack", - target_id="target-1", - target_type="TextTarget", - outcome=None, - prepended_conversation=[], created_at=now, - updated_at=now, ) ) mock_get_service.return_value = mock_service @@ -155,7 +149,6 @@ def test_get_attack_success(self, client: TestClient) -> None: outcome=None, prepended_conversation=[], messages=[], - converter_ids=[], created_at=now, updated_at=now, ) @@ -193,7 +186,6 @@ def test_update_attack_success(self, client: TestClient) -> None: outcome="success", prepended_conversation=[], messages=[], - converter_ids=[], created_at=now, updated_at=now, ) @@ -231,54 +223,49 @@ def test_delete_attack_not_found(self, client: TestClient) -> None: assert response.status_code == status.HTTP_404_NOT_FOUND - def test_send_message_success(self, client: TestClient) -> None: - """Test sending a message in an attack.""" + def test_add_message_success(self, client: TestClient) -> None: + """Test adding a message to an attack.""" now = datetime.now(timezone.utc) - user_msg = Message( - message_id="msg-1", - turn_number=1, - role="user", - pieces=[ - MessagePiece( - piece_id="piece-1", - data_type="text", - converted_value="Hello", - scores=[], - ) - ], - created_at=now, - ) - assistant_msg = Message( - message_id="msg-2", - turn_number=2, - role="assistant", - pieces=[ - MessagePiece( - piece_id="piece-2", - data_type="text", - converted_value="Hi there!", - scores=[], - ) - ], - created_at=now, - ) - summary = AttackSummary( + attack_detail = AttackDetail( attack_id="attack-1", target_id="target-1", target_type="TextTarget", - message_count=2, + messages=[ + Message( + message_id="msg-1", + turn_number=1, + role="user", + pieces=[ + MessagePiece( + piece_id="piece-1", + converted_value="Hello", + ) + ], + created_at=now, + ), + Message( + message_id="msg-2", + turn_number=2, + role="assistant", + pieces=[ + MessagePiece( + piece_id="piece-2", + converted_value="Hi there!", + ) + ], + created_at=now, + ), + ], created_at=now, updated_at=now, ) with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: mock_service = MagicMock() - mock_service.send_message = AsyncMock( - return_value=SendMessageResponse( - user_message=user_msg, - assistant_message=assistant_msg, - attack_summary=summary, + mock_service.add_message = AsyncMock( + return_value=AddMessageResponse( + attack=attack_detail, ) ) mock_get_service.return_value = mock_service @@ -290,7 +277,7 @@ def test_send_message_success(self, client: TestClient) -> None: assert response.status_code == status.HTTP_200_OK data = response.json() - assert data["user_message"]["pieces"][0]["converted_value"] == "Hello" + assert len(data["attack"]["messages"]) == 2 # ============================================================================ diff --git a/tests/unit/backend/test_attack_service.py b/tests/unit/backend/test_attack_service.py index 55c782e3ba..b9379a0303 100644 --- a/tests/unit/backend/test_attack_service.py +++ b/tests/unit/backend/test_attack_service.py @@ -12,10 +12,10 @@ import pytest from pyrit.backend.models.attacks import ( + AddMessageRequest, CreateAttackRequest, MessagePieceRequest, PrependedMessageRequest, - SendMessageRequest, UpdateAttackRequest, ) from pyrit.backend.models.targets import TargetInstance @@ -258,9 +258,14 @@ async def test_create_attack_success(self) -> None: result = await service.create_attack(request) assert result.attack_id is not None - assert result.name == "My Attack" - assert result.target_id == "target-1" - assert result.target_type == "TextTarget" + assert result.created_at is not None + + # Verify the attack was stored correctly by fetching it + attack = await service.get_attack(result.attack_id) + assert attack is not None + assert attack.name == "My Attack" + assert attack.target_id == "target-1" + assert attack.target_type == "TextTarget" @pytest.mark.asyncio async def test_create_attack_with_prepended_conversation(self) -> None: @@ -291,43 +296,11 @@ async def test_create_attack_with_prepended_conversation(self) -> None: result = await service.create_attack(request) - assert len(result.prepended_conversation) == 1 - assert result.prepended_conversation[0].role == "system" - - @pytest.mark.asyncio - async def test_create_attack_validates_converter_ids(self) -> None: - """Test that create_attack validates converter IDs exist.""" - service = AttackService() - - mock_target = TargetInstance( - target_id="target-1", - type="TextTarget", - params={}, - created_at=datetime.now(timezone.utc), - source="user", - ) - - with patch( - "pyrit.backend.services.attack_service.get_target_service" - ) as mock_get_target_service: - mock_target_service = MagicMock() - mock_target_service.get_target = AsyncMock(return_value=mock_target) - mock_get_target_service.return_value = mock_target_service - - with patch( - "pyrit.backend.services.attack_service.get_converter_service" - ) as mock_get_converter_service: - mock_converter_service = MagicMock() - mock_converter_service.get_converter = AsyncMock(return_value=None) - mock_get_converter_service.return_value = mock_converter_service - - request = CreateAttackRequest( - target_id="target-1", - converter_ids=["nonexistent-converter"], - ) - - with pytest.raises(ValueError, match="Converter instance"): - await service.create_attack(request) + # Verify the attack was stored with prepended conversation + attack = await service.get_attack(result.attack_id) + assert attack is not None + assert len(attack.prepended_conversation) == 1 + assert attack.prepended_conversation[0].role == "system" @pytest.mark.usefixtures("patch_central_database") @@ -419,24 +392,24 @@ async def test_delete_attack_removes_messages(self) -> None: @pytest.mark.usefixtures("patch_central_database") -class TestSendMessage: - """Tests for AttackService.send_message method.""" +class TestAddMessage: + """Tests for AttackService.add_message method.""" @pytest.mark.asyncio - async def test_send_message_raises_for_nonexistent_attack(self) -> None: - """Test that send_message raises ValueError for non-existent attack.""" + async def test_add_message_raises_for_nonexistent_attack(self) -> None: + """Test that add_message raises ValueError for non-existent attack.""" service = AttackService() - request = SendMessageRequest( + request = AddMessageRequest( pieces=[MessagePieceRequest(content="Hello")], ) with pytest.raises(ValueError, match="Attack"): - await service.send_message("nonexistent", request) + await service.add_message("nonexistent", request) @pytest.mark.asyncio - async def test_send_message_raises_for_missing_target_object(self) -> None: - """Test that send_message raises when target object is not found.""" + async def test_add_message_with_send_raises_for_missing_target_object(self) -> None: + """Test that add_message with send=True raises when target object is not found.""" service = AttackService() now = datetime.now(timezone.utc) @@ -461,12 +434,185 @@ async def test_send_message_raises_for_missing_target_object(self) -> None: mock_converter_service = MagicMock() mock_get_converter_service.return_value = mock_converter_service - request = SendMessageRequest( + request = AddMessageRequest( pieces=[MessagePieceRequest(content="Hello")], + send=True, ) with pytest.raises(ValueError, match="Target object"): - await service.send_message("test-id", request) + await service.add_message("test-id", request) + + @pytest.mark.asyncio + async def test_add_message_without_send_stores_message(self) -> None: + """Test that add_message with send=False just stores the message.""" + service = AttackService() + now = datetime.now(timezone.utc) + + service._attacks["test-id"] = AttackState( + attack_id="test-id", + target_id="target-1", + target_type="TextTarget", + created_at=now, + updated_at=now, + ) + + with patch( + "pyrit.backend.services.attack_service.get_target_service" + ) as mock_get_target_service: + mock_target_service = MagicMock() + mock_get_target_service.return_value = mock_target_service + + with patch( + "pyrit.backend.services.attack_service.get_converter_service" + ) as mock_get_converter_service: + mock_converter_service = MagicMock() + mock_get_converter_service.return_value = mock_converter_service + + request = AddMessageRequest( + role="system", + pieces=[MessagePieceRequest(content="You are a helpful assistant.")], + send=False, + ) + + result = await service.add_message("test-id", request) + + assert result.attack is not None + assert len(result.attack.messages) == 1 + assert result.attack.messages[0].role == "system" + assert result.error is None + + @pytest.mark.asyncio + async def test_add_message_with_converter_ids_applies_converters(self) -> None: + """Test that add_message with converter_ids applies the converters.""" + service = AttackService() + now = datetime.now(timezone.utc) + + service._attacks["test-id"] = AttackState( + attack_id="test-id", + target_id="target-1", + target_type="TextTarget", + created_at=now, + updated_at=now, + ) + + # Create mock converter + mock_converter = MagicMock() + mock_converter.convert_async = AsyncMock( + return_value=MagicMock(output_text="converted text") + ) + + # Create mock target + mock_target = AsyncMock() + mock_target.send_prompt_async = AsyncMock( + return_value=MagicMock( + request_pieces=[ + MagicMock( + original_value="assistant response", + converted_value="assistant response", + original_value_data_type="text", + ) + ], + response_error_description="none", + ) + ) + + with patch( + "pyrit.backend.services.attack_service.get_target_service" + ) as mock_get_target_service: + mock_target_service = MagicMock() + mock_target_service.get_target_object.return_value = mock_target + mock_get_target_service.return_value = mock_target_service + + with patch( + "pyrit.backend.services.attack_service.get_converter_service" + ) as mock_get_converter_service: + mock_converter_service = MagicMock() + mock_converter_service.get_converter_objects_for_ids.return_value = [ + mock_converter + ] + mock_get_converter_service.return_value = mock_converter_service + + request = AddMessageRequest( + role="user", + pieces=[MessagePieceRequest(content="Hello")], + send=True, + converter_ids=["converter-1"], + ) + + result = await service.add_message("test-id", request) + + # Verify converter was applied + mock_converter_service.get_converter_objects_for_ids.assert_called_once_with( + ["converter-1"] + ) + mock_converter.convert_async.assert_called_once_with(prompt="Hello") + + # Verify message was converted + assert result.attack is not None + # First message is the user message with conversion + user_msg = result.attack.messages[0] + assert user_msg.role == "user" + assert user_msg.pieces[0].converted_value == "converted text" + + @pytest.mark.asyncio + async def test_add_message_without_converter_ids_does_not_apply_converters(self) -> None: + """Test that add_message without converter_ids does not apply any converters.""" + service = AttackService() + now = datetime.now(timezone.utc) + + service._attacks["test-id"] = AttackState( + attack_id="test-id", + target_id="target-1", + target_type="TextTarget", + created_at=now, + updated_at=now, + ) + + # Create mock target + mock_target = AsyncMock() + mock_target.send_prompt_async = AsyncMock( + return_value=MagicMock( + request_pieces=[ + MagicMock( + original_value="response", + converted_value="response", + original_value_data_type="text", + ) + ], + response_error_description="none", + ) + ) + + with patch( + "pyrit.backend.services.attack_service.get_target_service" + ) as mock_get_target_service: + mock_target_service = MagicMock() + mock_target_service.get_target_object.return_value = mock_target + mock_get_target_service.return_value = mock_target_service + + with patch( + "pyrit.backend.services.attack_service.get_converter_service" + ) as mock_get_converter_service: + mock_converter_service = MagicMock() + mock_get_converter_service.return_value = mock_converter_service + + # No converter_ids in request - should not apply any converters + request = AddMessageRequest( + role="user", + pieces=[MessagePieceRequest(content="Hello")], + send=True, + ) + + result = await service.add_message("test-id", request) + + # Verify no converter lookup was done + mock_converter_service.get_converter_objects_for_ids.assert_not_called() + + # Verify original value equals converted value (no conversion) + assert result.attack is not None + user_msg = result.attack.messages[0] + assert user_msg.pieces[0].original_value == "Hello" + assert user_msg.pieces[0].converted_value == "Hello" @pytest.mark.usefixtures("patch_central_database") diff --git a/tests/unit/backend/test_converter_service.py b/tests/unit/backend/test_converter_service.py index 5a8c45a276..e8ff4a5aad 100644 --- a/tests/unit/backend/test_converter_service.py +++ b/tests/unit/backend/test_converter_service.py @@ -6,7 +6,6 @@ """ from datetime import datetime, timezone -from typing import Any from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -18,7 +17,6 @@ InlineConverterConfig, ) from pyrit.backend.services.converter_service import ConverterService -from pyrit.models import PromptDataType class TestConverterServiceInit: @@ -516,10 +514,10 @@ class TestConverterServiceSingleton: def test_get_converter_service_returns_converter_service(self) -> None: """Test that get_converter_service returns a ConverterService instance.""" - from pyrit.backend.services.converter_service import get_converter_service - # Reset singleton for clean test import pyrit.backend.services.converter_service as module + from pyrit.backend.services.converter_service import get_converter_service + module._converter_service = None service = get_converter_service() @@ -527,10 +525,10 @@ def test_get_converter_service_returns_converter_service(self) -> None: def test_get_converter_service_returns_same_instance(self) -> None: """Test that get_converter_service returns the same instance.""" - from pyrit.backend.services.converter_service import get_converter_service - # Reset singleton for clean test import pyrit.backend.services.converter_service as module + from pyrit.backend.services.converter_service import get_converter_service + module._converter_service = None service1 = get_converter_service() diff --git a/tests/unit/backend/test_target_service.py b/tests/unit/backend/test_target_service.py index 226672b5fa..a95aed843d 100644 --- a/tests/unit/backend/test_target_service.py +++ b/tests/unit/backend/test_target_service.py @@ -6,7 +6,6 @@ """ from datetime import datetime, timezone -from typing import Any, Dict from unittest.mock import MagicMock, patch import pytest @@ -346,10 +345,10 @@ class TestTargetServiceSingleton: def test_get_target_service_returns_target_service(self) -> None: """Test that get_target_service returns a TargetService instance.""" - from pyrit.backend.services.target_service import get_target_service - # Reset singleton for clean test import pyrit.backend.services.target_service as module + from pyrit.backend.services.target_service import get_target_service + module._target_service = None service = get_target_service() @@ -357,10 +356,10 @@ def test_get_target_service_returns_target_service(self) -> None: def test_get_target_service_returns_same_instance(self) -> None: """Test that get_target_service returns the same instance.""" - from pyrit.backend.services.target_service import get_target_service - # Reset singleton for clean test import pyrit.backend.services.target_service as module + from pyrit.backend.services.target_service import get_target_service + module._target_service = None service1 = get_target_service() From a06e9b843d7d71e85631b562b6cdc54a8aef2fb2 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Wed, 28 Jan 2026 11:00:57 -0800 Subject: [PATCH 04/35] explicitly remove playwright-report --- frontend/playwright-report/index.html | 85 --------------------------- 1 file changed, 85 deletions(-) delete mode 100644 frontend/playwright-report/index.html diff --git a/frontend/playwright-report/index.html b/frontend/playwright-report/index.html deleted file mode 100644 index f6ad474f7f..0000000000 --- a/frontend/playwright-report/index.html +++ /dev/null @@ -1,85 +0,0 @@ - - - - - - - - - Playwright Test Report - - - - -
- - - \ No newline at end of file From d5679b3e282096a1ebc03897b224b2754ff92688 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Wed, 28 Jan 2026 11:01:33 -0800 Subject: [PATCH 05/35] explicitly remove last-run.json --- frontend/test-results/.last-run.json | 4 ---- 1 file changed, 4 deletions(-) delete mode 100644 frontend/test-results/.last-run.json diff --git a/frontend/test-results/.last-run.json b/frontend/test-results/.last-run.json deleted file mode 100644 index cbcc1fbac1..0000000000 --- a/frontend/test-results/.last-run.json +++ /dev/null @@ -1,4 +0,0 @@ -{ - "status": "passed", - "failedTests": [] -} \ No newline at end of file From 7be9760784c2c5f58ae563bb0b91d188c2ac1be3 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Wed, 28 Jan 2026 15:10:23 -0800 Subject: [PATCH 06/35] use normalizer properly, clean up, refactor --- pyrit/backend/main.py | 3 +- pyrit/backend/models/attacks.py | 6 +- pyrit/backend/routes/__init__.py | 3 +- pyrit/backend/routes/attacks.py | 23 - pyrit/backend/routes/converters.py | 103 +-- pyrit/backend/routes/initializers.py | 34 + pyrit/backend/routes/registry.py | 125 ---- pyrit/backend/routes/scenarios.py | 34 + pyrit/backend/routes/scorers.py | 38 + pyrit/backend/routes/targets.py | 33 +- pyrit/backend/services/__init__.py | 2 - pyrit/backend/services/attack_service.py | 647 +++++++++-------- pyrit/backend/services/converter_service.py | 391 +++++------ tests/unit/backend/test_api_routes.py | 297 +++++--- tests/unit/backend/test_attack_service.py | 740 +++++++------------- 15 files changed, 1092 insertions(+), 1387 deletions(-) create mode 100644 pyrit/backend/routes/initializers.py delete mode 100644 pyrit/backend/routes/registry.py create mode 100644 pyrit/backend/routes/scenarios.py create mode 100644 pyrit/backend/routes/scorers.py diff --git a/pyrit/backend/main.py b/pyrit/backend/main.py index b6bb6988dd..7314ca54c4 100644 --- a/pyrit/backend/main.py +++ b/pyrit/backend/main.py @@ -16,7 +16,7 @@ import pyrit from pyrit.backend.middleware import register_error_handlers -from pyrit.backend.routes import attacks, converters, health, registry, targets, version +from pyrit.backend.routes import attacks, converters, health, targets, version from pyrit.setup.initialization import initialize_pyrit_async # Check for development mode from environment variable @@ -54,7 +54,6 @@ async def startup_event_async() -> None: app.include_router(attacks.router, prefix="/api", tags=["attacks"]) app.include_router(targets.router, prefix="/api", tags=["targets"]) app.include_router(converters.router, prefix="/api", tags=["converters"]) -app.include_router(registry.router, prefix="/api", tags=["registry"]) app.include_router(health.router, prefix="/api", tags=["health"]) app.include_router(version.router, tags=["version"]) diff --git a/pyrit/backend/models/attacks.py b/pyrit/backend/models/attacks.py index 5d282ad28b..db85c13805 100644 --- a/pyrit/backend/models/attacks.py +++ b/pyrit/backend/models/attacks.py @@ -168,7 +168,10 @@ class MessagePieceRequest(BaseModel): """A piece of content for a message.""" data_type: str = Field(default="text", description="Data type: 'text', 'image', 'audio', etc.") - content: str = Field(..., description="Content (text or base64 for media)") + original_value: str = Field(..., description="Original value (text or base64 for media)") + converted_value: Optional[str] = Field( + None, description="Converted value. If provided, bypasses converters." + ) mime_type: Optional[str] = Field(None, description="MIME type for media content") @@ -202,4 +205,3 @@ class AddMessageResponse(BaseModel): """ attack: AttackDetail = Field(..., description="Updated attack with new message(s)") - error: Optional[str] = Field(None, description="Transport-level error if send=True and request failed entirely") diff --git a/pyrit/backend/routes/__init__.py b/pyrit/backend/routes/__init__.py index 4f16ed7759..7aa362d23f 100644 --- a/pyrit/backend/routes/__init__.py +++ b/pyrit/backend/routes/__init__.py @@ -5,13 +5,12 @@ API route handlers. """ -from pyrit.backend.routes import attacks, converters, health, registry, targets, version +from pyrit.backend.routes import attacks, converters, health, targets, version __all__ = [ "attacks", "converters", "health", - "registry", "targets", "version", ] diff --git a/pyrit/backend/routes/attacks.py b/pyrit/backend/routes/attacks.py index e6478d0cbf..7401923bcf 100644 --- a/pyrit/backend/routes/attacks.py +++ b/pyrit/backend/routes/attacks.py @@ -192,26 +192,3 @@ async def add_message( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to add message: {str(e)}", ) - - -@router.delete( - "/{attack_id}", - status_code=status.HTTP_204_NO_CONTENT, - responses={ - 404: {"model": ProblemDetail, "description": "Attack not found"}, - }, -) -async def delete_attack(attack_id: str) -> None: - """ - Delete an attack. - - Removes the attack and all associated messages. - """ - service = get_attack_service() - - deleted = await service.delete_attack(attack_id) - if not deleted: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Attack '{attack_id}' not found", - ) diff --git a/pyrit/backend/routes/converters.py b/pyrit/backend/routes/converters.py index 318f251b46..c1319104ac 100644 --- a/pyrit/backend/routes/converters.py +++ b/pyrit/backend/routes/converters.py @@ -4,13 +4,11 @@ """ Converters API routes. -Provides endpoints for: -- Listing converter types (metadata from registry) -- Managing converter instances (runtime objects) -- Previewing converter transformations +Provides endpoints for managing converter instances and previewing conversions. +Converter types are set at app startup - you cannot add new types at runtime. """ -from typing import List, Literal, Optional +from typing import Literal, Optional from fastapi import APIRouter, HTTPException, Query, status @@ -18,58 +16,21 @@ from pyrit.backend.models.converters import ( ConverterInstance, ConverterInstanceListResponse, - ConverterMetadataResponse, ConverterPreviewRequest, ConverterPreviewResponse, CreateConverterRequest, CreateConverterResponse, ) -from pyrit.backend.services import get_registry_service from pyrit.backend.services.converter_service import get_converter_service router = APIRouter(prefix="/converters", tags=["converters"]) -# ============================================================================ -# Converter Types (from registry) -# ============================================================================ - - -@router.get( - "/types", - response_model=List[ConverterMetadataResponse], -) -async def list_converter_types( - is_llm_based: Optional[bool] = Query(None, description="Filter by LLM-based converters"), - is_deterministic: Optional[bool] = Query(None, description="Filter by deterministic converters"), -) -> List[ConverterMetadataResponse]: - """ - List available converter types. - - Returns metadata about all available prompt converter types (not instances). - For instances, use GET /converters/instances. - - Returns: - List[ConverterMetadataResponse]: List of converter type metadata. - """ - service = get_registry_service() - - return service.get_converters( - is_llm_based=is_llm_based, - is_deterministic=is_deterministic, - ) - - -# ============================================================================ -# Converter Instances (runtime objects) -# ============================================================================ - - @router.get( - "/instances", + "", response_model=ConverterInstanceListResponse, ) -async def list_converter_instances( +async def list_converters( source: Optional[Literal["initializer", "user"]] = Query( None, description="Filter by source (initializer or user)" ), @@ -87,35 +48,19 @@ async def list_converter_instances( @router.post( - "/instances", + "", response_model=CreateConverterResponse, status_code=status.HTTP_201_CREATED, responses={ 400: {"model": ProblemDetail, "description": "Invalid converter type or parameters"}, - 422: {"model": ProblemDetail, "description": "Validation error"}, }, ) -async def create_converter_instance(request: CreateConverterRequest) -> CreateConverterResponse: +async def create_converter(request: CreateConverterRequest) -> CreateConverterResponse: """ Create a new converter instance. Supports nested converters - if params contains a 'converter' key with - a type/params object, the nested converter will be created first and - linked to the outer converter. - - Example for SelectiveTextConverter: - ```json - { - "type": "selective_text", - "params": { - "pattern": "\\[CONVERT\\]", - "converter": { - "type": "base64", - "params": {} - } - } - } - ``` + a type/params object, the nested converter will be created first. Returns: CreateConverterResponse: The created converter instance details. @@ -137,13 +82,13 @@ async def create_converter_instance(request: CreateConverterRequest) -> CreateCo @router.get( - "/instances/{converter_id}", + "/{converter_id}", response_model=ConverterInstance, responses={ 404: {"model": ProblemDetail, "description": "Converter not found"}, }, ) -async def get_converter_instance(converter_id: str) -> ConverterInstance: +async def get_converter(converter_id: str) -> ConverterInstance: """ Get a converter instance by ID. @@ -162,34 +107,6 @@ async def get_converter_instance(converter_id: str) -> ConverterInstance: return converter -@router.delete( - "/instances/{converter_id}", - status_code=status.HTTP_204_NO_CONTENT, - responses={ - 404: {"model": ProblemDetail, "description": "Converter not found"}, - }, -) -async def delete_converter_instance(converter_id: str) -> None: - """ - Delete a converter instance. - - Note: Converters in use by active attacks cannot be deleted. - """ - service = get_converter_service() - - deleted = await service.delete_converter(converter_id) - if not deleted: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Converter '{converter_id}' not found", - ) - - -# ============================================================================ -# Converter Preview -# ============================================================================ - - @router.post( "/preview", response_model=ConverterPreviewResponse, diff --git a/pyrit/backend/routes/initializers.py b/pyrit/backend/routes/initializers.py new file mode 100644 index 0000000000..24f0b1d642 --- /dev/null +++ b/pyrit/backend/routes/initializers.py @@ -0,0 +1,34 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Initializers API routes. + +Provides endpoints for listing available initializers. +""" + +from typing import List + +from fastapi import APIRouter + +from pyrit.backend.models.registry import InitializerMetadataResponse +from pyrit.backend.services import get_registry_service + +router = APIRouter(prefix="/initializers", tags=["initializers"]) + + +@router.get( + "", + response_model=List[InitializerMetadataResponse], +) +async def list_initializers() -> List[InitializerMetadataResponse]: + """ + List available initializers. + + Returns metadata about all registered initializers. + + Returns: + List[InitializerMetadataResponse]: List of initializer metadata. + """ + service = get_registry_service() + return service.get_initializers() diff --git a/pyrit/backend/routes/registry.py b/pyrit/backend/routes/registry.py deleted file mode 100644 index dd0659bd33..0000000000 --- a/pyrit/backend/routes/registry.py +++ /dev/null @@ -1,125 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Registry API routes. - -Provides endpoints for querying available components. -""" - -from typing import List, Optional - -from fastapi import APIRouter, Query - -from pyrit.backend.models.registry import ( - ConverterMetadataResponse, - InitializerMetadataResponse, - ScenarioMetadataResponse, - ScorerMetadataResponse, - TargetMetadataResponse, -) -from pyrit.backend.services import get_registry_service - -router = APIRouter(prefix="/registry", tags=["registry"]) - - -@router.get( - "/targets", - response_model=List[TargetMetadataResponse], -) -async def list_targets( - is_chat_target: Optional[bool] = Query(None, description="Filter by chat target support"), -) -> List[TargetMetadataResponse]: - """ - List available targets. - - Returns metadata about all available prompt targets, optionally - filtered by chat target support. - - Returns: - List[TargetMetadataResponse]: List of target metadata. - """ - service = get_registry_service() - - return service.get_targets(is_chat_target=is_chat_target) - - -@router.get( - "/scenarios", - response_model=List[ScenarioMetadataResponse], -) -async def list_scenarios() -> List[ScenarioMetadataResponse]: - """ - List available scenarios. - - Returns metadata about all registered scenarios. - - Returns: - List[ScenarioMetadataResponse]: List of scenario metadata. - """ - service = get_registry_service() - - return service.get_scenarios() - - -@router.get( - "/scorers", - response_model=List[ScorerMetadataResponse], -) -async def list_scorers( - scorer_type: Optional[str] = Query(None, description="Filter by scorer type (true_false or float_scale)"), -) -> List[ScorerMetadataResponse]: - """ - List registered scorers. - - Returns metadata about all registered scorer instances. - - Returns: - List[ScorerMetadataResponse]: List of scorer metadata. - """ - service = get_registry_service() - - return service.get_scorers(scorer_type=scorer_type) - - -@router.get( - "/converters", - response_model=List[ConverterMetadataResponse], -) -async def list_converters( - is_llm_based: Optional[bool] = Query(None, description="Filter by LLM-based converters"), - is_deterministic: Optional[bool] = Query(None, description="Filter by deterministic converters"), -) -> List[ConverterMetadataResponse]: - """ - List available converters. - - Returns metadata about all available prompt converters. - Note: Also available at /converters endpoint. - - Returns: - List[ConverterMetadataResponse]: List of converter metadata. - """ - service = get_registry_service() - - return service.get_converters( - is_llm_based=is_llm_based, - is_deterministic=is_deterministic, - ) - - -@router.get( - "/initializers", - response_model=List[InitializerMetadataResponse], -) -async def list_initializers() -> List[InitializerMetadataResponse]: - """ - List available initializers. - - Returns metadata about all registered initializers. - - Returns: - List[InitializerMetadataResponse]: List of initializer metadata. - """ - service = get_registry_service() - - return service.get_initializers() diff --git a/pyrit/backend/routes/scenarios.py b/pyrit/backend/routes/scenarios.py new file mode 100644 index 0000000000..0f45b8c53d --- /dev/null +++ b/pyrit/backend/routes/scenarios.py @@ -0,0 +1,34 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Scenarios API routes. + +Provides endpoints for listing available scenarios. +""" + +from typing import List + +from fastapi import APIRouter + +from pyrit.backend.models.registry import ScenarioMetadataResponse +from pyrit.backend.services import get_registry_service + +router = APIRouter(prefix="/scenarios", tags=["scenarios"]) + + +@router.get( + "", + response_model=List[ScenarioMetadataResponse], +) +async def list_scenarios() -> List[ScenarioMetadataResponse]: + """ + List available scenarios. + + Returns metadata about all registered scenarios. + + Returns: + List[ScenarioMetadataResponse]: List of scenario metadata. + """ + service = get_registry_service() + return service.get_scenarios() diff --git a/pyrit/backend/routes/scorers.py b/pyrit/backend/routes/scorers.py new file mode 100644 index 0000000000..b516746555 --- /dev/null +++ b/pyrit/backend/routes/scorers.py @@ -0,0 +1,38 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Scorers API routes. + +Provides endpoints for listing available scorers. +""" + +from typing import List, Optional + +from fastapi import APIRouter, Query + +from pyrit.backend.models.registry import ScorerMetadataResponse +from pyrit.backend.services import get_registry_service + +router = APIRouter(prefix="/scorers", tags=["scorers"]) + + +@router.get( + "", + response_model=List[ScorerMetadataResponse], +) +async def list_scorers( + scorer_type: Optional[str] = Query( + None, description="Filter by scorer type (true_false or float_scale)" + ), +) -> List[ScorerMetadataResponse]: + """ + List available scorers. + + Returns metadata about all registered scorer types. + + Returns: + List[ScorerMetadataResponse]: List of scorer metadata. + """ + service = get_registry_service() + return service.get_scorers(scorer_type=scorer_type) diff --git a/pyrit/backend/routes/targets.py b/pyrit/backend/routes/targets.py index 43a3465901..03aa41c8f2 100644 --- a/pyrit/backend/routes/targets.py +++ b/pyrit/backend/routes/targets.py @@ -2,13 +2,10 @@ # Licensed under the MIT license. """ -Target instance API routes. +Target API routes. -Targets have two concepts: -- Types: Available via /api/registry/targets (static metadata) -- Instances: Runtime objects created via this API - -This module handles target instances (runtime objects). +Provides endpoints for managing target instances. +Target types are set at app startup via initializers - you cannot add new types at runtime. """ from typing import Literal, Optional @@ -55,7 +52,6 @@ async def list_targets( status_code=status.HTTP_201_CREATED, responses={ 400: {"model": ProblemDetail, "description": "Invalid target type or parameters"}, - 422: {"model": ProblemDetail, "description": "Validation error"}, }, ) async def create_target(request: CreateTargetRequest) -> CreateTargetResponse: @@ -110,26 +106,3 @@ async def get_target(target_id: str) -> TargetInstance: ) return target - - -@router.delete( - "/{target_id}", - status_code=status.HTTP_204_NO_CONTENT, - responses={ - 404: {"model": ProblemDetail, "description": "Target not found"}, - }, -) -async def delete_target(target_id: str) -> None: - """ - Delete a target instance. - - Note: Targets in use by active attacks cannot be deleted. - """ - service = get_target_service() - - deleted = await service.delete_target(target_id) - if not deleted: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Target '{target_id}' not found", - ) diff --git a/pyrit/backend/services/__init__.py b/pyrit/backend/services/__init__.py index c748793055..c91c71cffa 100644 --- a/pyrit/backend/services/__init__.py +++ b/pyrit/backend/services/__init__.py @@ -9,7 +9,6 @@ from pyrit.backend.services.attack_service import ( AttackService, - AttackState, get_attack_service, ) from pyrit.backend.services.converter_service import ( @@ -27,7 +26,6 @@ __all__ = [ "AttackService", - "AttackState", "get_attack_service", "ConverterService", "get_converter_service", diff --git a/pyrit/backend/services/attack_service.py b/pyrit/backend/services/attack_service.py index 11eaed10bb..edceec3a8e 100644 --- a/pyrit/backend/services/attack_service.py +++ b/pyrit/backend/services/attack_service.py @@ -4,16 +4,20 @@ """ Attack service for managing attacks. -All interactions are modeled as "attacks" - this is the attack-centric API design. -Handles attack lifecycle, message sending, prepended conversations, and scoring. +All user interactions are modeled as "attacks" - this is the attack-centric API design. +Handles attack lifecycle, message sending, and scoring. + +ARCHITECTURE: +- Each attack is represented by an AttackResult stored in the database +- The AttackResult has a conversation_id that links to the main conversation +- Messages are stored via PyRIT memory with that conversation_id +- For GUI use, it's a 1-to-1 mapping: one AttackResult, one conversation +- Future AI-generated attacks may have multiple related conversations """ import uuid -from collections import defaultdict from datetime import datetime, timezone -from typing import Dict, List, Literal, Optional, cast - -from pydantic import BaseModel +from typing import List, Literal, Optional, cast from pyrit.backend.models.attacks import ( AddMessageRequest, @@ -25,41 +29,38 @@ CreateAttackResponse, Message, MessagePiece, + Score, UpdateAttackRequest, ) from pyrit.backend.models.common import PaginationInfo from pyrit.backend.services.converter_service import get_converter_service from pyrit.backend.services.target_service import get_target_service from pyrit.memory import CentralMemory - - -class AttackState(BaseModel): - """Internal state for an active attack.""" - - attack_id: str - name: Optional[str] = None - target_id: str - target_type: str - outcome: Optional[Literal["pending", "success", "failure"]] = None - prepended_conversation: List[Message] = [] - message_count: int = 0 - created_at: datetime - updated_at: datetime +from pyrit.models import AttackOutcome, AttackResult +from pyrit.models import Message as PyritMessage +from pyrit.models import MessagePiece as PyritMessagePiece +from pyrit.models import PromptDataType +from pyrit.prompt_normalizer import PromptConverterConfiguration, PromptNormalizer class AttackService: - """Service for managing attacks.""" + """ + Service for managing attacks. + + Uses PyRIT memory (database) as the source of truth via AttackResult. + """ def __init__(self) -> None: """Initialize the attack service.""" self._memory = CentralMemory.get_memory_instance() - # Active attack states - self._attacks: Dict[str, AttackState] = {} - # Messages by attack ID (in-memory for now) - self._messages: Dict[str, List[Message]] = defaultdict(list) + + # ======================================================================== + # Public API Methods + # ======================================================================== async def list_attacks( self, + *, target_id: Optional[str] = None, outcome: Optional[Literal["pending", "success", "failure"]] = None, limit: int = 20, @@ -68,70 +69,42 @@ async def list_attacks( """ List attacks with optional filtering and pagination. - Args: - target_id: Filter by target instance ID - outcome: Filter by outcome - limit: Maximum items per page - cursor: Pagination cursor - - Returns: - AttackListResponse: Paginated list of attack summaries + Queries AttackResult entries from the database. """ - attacks = list(self._attacks.values()) - - # Apply filters - if target_id: - attacks = [a for a in attacks if a.target_id == target_id] + # Map outcome string to AttackOutcome enum value for filtering + outcome_filter = None if outcome: - attacks = [a for a in attacks if a.outcome == outcome] + outcome_map = { + "pending": "undetermined", + "success": "success", + "failure": "failure", + } + outcome_filter = outcome_map.get(outcome) - # Sort by updated_at descending - attacks.sort(key=lambda a: a.updated_at, reverse=True) + attack_results = self._memory.get_attack_results(outcome=outcome_filter) - # Simple cursor-based pagination (cursor is the attack_id) - start_idx = 0 - if cursor: - for i, attack in enumerate(attacks): - if attack.attack_id == cursor: - start_idx = i + 1 - break + # Convert to summaries and filter + summaries = [] + for ar in attack_results: + # Get target info from attack_identifier + ar_target_id = ar.attack_identifier.get("target_id", "") + if target_id and ar_target_id != target_id: + continue - page = attacks[start_idx : start_idx + limit] - has_more = len(attacks) > start_idx + limit + summary = self._build_summary(ar) + summaries.append(summary) - summaries = [] - for attack in page: - messages = self._messages.get(attack.attack_id, []) - last_message_preview = None - if messages: - last_msg = messages[-1] - if last_msg.pieces: - preview_text = last_msg.pieces[0].converted_value - last_message_preview = preview_text[:100] + "..." if len(preview_text) > 100 else preview_text - - summaries.append( - AttackSummary( - attack_id=attack.attack_id, - name=attack.name, - target_id=attack.target_id, - target_type=attack.target_type, - outcome=attack.outcome, - last_message_preview=last_message_preview, - message_count=len(messages), - created_at=attack.created_at, - updated_at=attack.updated_at, - ) - ) + # Sort by most recent + summaries.sort(key=lambda s: s.updated_at, reverse=True) + # Paginate + page, has_more = self._paginate(summaries, cursor, limit) next_cursor = page[-1].attack_id if has_more and page else None return AttackListResponse( - items=summaries, + items=page, pagination=PaginationInfo( - limit=limit, - has_more=has_more, - next_cursor=next_cursor, - prev_cursor=cursor, + limit=limit, has_more=has_more, next_cursor=next_cursor, prev_cursor=cursor ), ) @@ -139,289 +112,349 @@ async def get_attack(self, attack_id: str) -> Optional[AttackDetail]: """ Get attack details including all messages. - Args: - attack_id: Attack ID - - Returns: - AttackDetail or None if not found + Queries the AttackResult and its conversation from the database. """ - state = self._attacks.get(attack_id) - if not state: + # Get the attack result + results = self._memory.get_attack_results(conversation_id=attack_id) + if not results: return None - messages = self._messages.get(attack_id, []) + ar = results[0] + + # Get messages for this conversation + pyrit_messages = self._memory.get_conversation(conversation_id=attack_id) + backend_messages = self._convert_pyrit_messages(pyrit_messages) return AttackDetail( - attack_id=state.attack_id, - name=state.name, - target_id=state.target_id, - target_type=state.target_type, - outcome=state.outcome, - prepended_conversation=state.prepended_conversation, - messages=messages, - created_at=state.created_at, - updated_at=state.updated_at, + attack_id=attack_id, + name=ar.attack_identifier.get("name"), + target_id=ar.attack_identifier.get("target_id", ""), + target_type=ar.attack_identifier.get("target_type", ""), + outcome=self._map_outcome(ar.outcome), + prepended_conversation=[], + messages=backend_messages, + created_at=ar.metadata.get("created_at", datetime.now(timezone.utc)), + updated_at=ar.metadata.get("updated_at", datetime.now(timezone.utc)), ) - async def create_attack( - self, - request: CreateAttackRequest, - ) -> CreateAttackResponse: + async def create_attack(self, request: CreateAttackRequest) -> CreateAttackResponse: """ Create a new attack. - Args: - request: Attack creation request - - Returns: - CreateAttackResponse: Created attack details + Creates an AttackResult with a new conversation_id. """ target_service = get_target_service() - - # Validate target exists target_instance = await target_service.get_target(request.target_id) if not target_instance: raise ValueError(f"Target instance '{request.target_id}' not found") - attack_id = str(uuid.uuid4()) + # Generate conversation_id (this is the attack_id) + conversation_id = str(uuid.uuid4()) now = datetime.now(timezone.utc) - # Convert prepended messages to Message format - prepended_messages: List[Message] = [] - if request.prepended_conversation: - for i, prep_msg in enumerate(request.prepended_conversation): - msg = Message( - message_id=str(uuid.uuid4()), - turn_number=0, # Prepended messages are turn 0 - role=prep_msg.role, - pieces=[ - MessagePiece( - piece_id=str(uuid.uuid4()), - data_type="text", - original_value=prep_msg.content, - converted_value=prep_msg.content, - scores=[], - ) - ], - created_at=now, - ) - prepended_messages.append(msg) - - state = AttackState( - attack_id=attack_id, - name=request.name, - target_id=request.target_id, - target_type=target_instance.type, - outcome=None, - prepended_conversation=prepended_messages, - message_count=0, - created_at=now, - updated_at=now, + # Create AttackResult + attack_result = AttackResult( + conversation_id=conversation_id, + objective=request.name or "Manual attack via GUI", + attack_identifier={ + "name": request.name or "", + "target_id": request.target_id, + "target_type": target_instance.type, + "source": "gui", + }, + outcome=AttackOutcome.UNDETERMINED, + metadata={ + "created_at": now.isoformat(), + "updated_at": now.isoformat(), + **(request.labels or {}), + }, ) - self._attacks[attack_id] = state - return CreateAttackResponse( - attack_id=attack_id, - created_at=now, - ) + # Store in memory + self._memory.add_attack_results_to_memory(attack_results=[attack_result]) + + # Store prepended conversation if provided + if request.prepended_conversation: + await self._store_prepended_messages( + conversation_id=conversation_id, + prepended=request.prepended_conversation, + ) + + return CreateAttackResponse(attack_id=conversation_id, created_at=now) async def update_attack( - self, - attack_id: str, - request: UpdateAttackRequest, + self, attack_id: str, request: UpdateAttackRequest ) -> Optional[AttackDetail]: """ Update an attack's outcome. - Args: - attack_id: Attack ID - request: Update request with outcome - - Returns: - Updated AttackDetail or None if not found + Updates the AttackResult in the database. """ - state = self._attacks.get(attack_id) - if not state: + results = self._memory.get_attack_results(conversation_id=attack_id) + if not results: return None - state.outcome = request.outcome - state.updated_at = datetime.now(timezone.utc) + # Map outcome + outcome_map = { + "pending": AttackOutcome.UNDETERMINED, + "success": AttackOutcome.SUCCESS, + "failure": AttackOutcome.FAILURE, + } + new_outcome = outcome_map.get(request.outcome, AttackOutcome.UNDETERMINED) + + # Update the attack result (need to update via memory interface) + # For now, we update metadata to track the change + ar = results[0] + ar.outcome = new_outcome + ar.metadata["updated_at"] = datetime.now(timezone.utc).isoformat() + + # Re-add to memory (this should update) + self._memory.add_attack_results_to_memory(attack_results=[ar]) return await self.get_attack(attack_id) async def add_message( - self, - attack_id: str, - request: AddMessageRequest, + self, attack_id: str, request: AddMessageRequest ) -> AddMessageResponse: """ - Add a message to an attack. - - If send=True, sends to target and waits for response. - If send=False, just stores the message in memory. + Add a message to an attack, optionally sending to target. - Args: - attack_id: Attack ID - request: Add message request - - Returns: - AddMessageResponse: Updated attack detail + Messages are stored in the database via PromptNormalizer. """ - state = self._attacks.get(attack_id) - if not state: + # Check if attack exists + results = self._memory.get_attack_results(conversation_id=attack_id) + if not results: raise ValueError(f"Attack '{attack_id}' not found") - target_service = get_target_service() - converter_service = get_converter_service() + ar = results[0] + target_id = ar.attack_identifier.get("target_id") + if not target_id: + raise ValueError(f"Attack '{attack_id}' has no target configured") - now = datetime.now(timezone.utc) - state.message_count += 1 - msg_turn = state.message_count - - # Determine which converters to use (only for user messages being sent) - converters = [] - if request.send and request.role == "user" and request.converter_ids: - converters = converter_service.get_converter_objects_for_ids(request.converter_ids) - - # Build message pieces - msg_pieces: List[MessagePiece] = [] - for piece_req in request.pieces: - original_value = piece_req.content - converted_value = original_value - - # Apply converters - for converter in converters: - result = await converter.convert_async(prompt=converted_value) - converted_value = result.output_text - - msg_pieces.append( - MessagePiece( - piece_id=str(uuid.uuid4()), - data_type=piece_req.data_type, - original_value=original_value, - original_value_mime_type=piece_req.mime_type, - converted_value=converted_value, - converted_value_mime_type=piece_req.mime_type, - scores=[], - ) - ) + # Get existing messages to determine sequence + existing = self._memory.get_message_pieces(conversation_id=attack_id) + sequence = max((p.sequence for p in existing), default=-1) + 1 - message = Message( - message_id=str(uuid.uuid4()), - turn_number=msg_turn, - role=request.role, - pieces=msg_pieces, - created_at=now, - ) - - # Store the message - self._messages[attack_id].append(message) - - # If send=True, send to target and get response - transport_error: Optional[str] = None if request.send: - target_obj = target_service.get_target_object(state.target_id) - if not target_obj: - raise ValueError(f"Target object for '{state.target_id}' not found") - - try: - # Build conversation for target - from pyrit.models import Message as PyritMessage - from pyrit.models import MessagePiece as PyritMessagePiece - from pyrit.models import PromptDataType - - # Create prompt pieces for target - prompt_pieces = [] - for piece in msg_pieces: - pyrit_piece = PyritMessagePiece( - role=request.role, - original_value=piece.original_value or "", - original_value_data_type=cast(PromptDataType, piece.data_type), - converted_value=piece.converted_value, - converted_value_data_type=cast(PromptDataType, piece.data_type), - conversation_id=attack_id, - sequence=msg_turn, - ) - prompt_pieces.append(pyrit_piece) - - pyrit_message = PyritMessage(prompt_pieces) - - # Send to target - response_messages = await target_obj.send_prompt_async(message=pyrit_message) - - # Build assistant response - state.message_count += 1 - assistant_turn = state.message_count - - assistant_pieces: List[MessagePiece] = [] - if response_messages: - for resp_msg in response_messages: - for resp_piece in resp_msg.message_pieces: - assistant_pieces.append( - MessagePiece( - piece_id=str(uuid.uuid4()), - data_type=resp_piece.converted_value_data_type or "text", - original_value=resp_piece.original_value, - converted_value=resp_piece.converted_value or "", - scores=[], - response_error=resp_piece.response_error, - ) - ) - - assistant_message = Message( - message_id=str(uuid.uuid4()), - turn_number=assistant_turn, - role="assistant", - pieces=assistant_pieces - if assistant_pieces - else [ - MessagePiece( - piece_id=str(uuid.uuid4()), - data_type="text", - converted_value="", - scores=[], - ) - ], - created_at=datetime.now(timezone.utc), - ) - - # Store assistant message - self._messages[attack_id].append(assistant_message) - - except Exception as e: - transport_error = str(e) + await self._send_and_store_message(attack_id, target_id, request, sequence) + else: + await self._store_message_only(attack_id, request, sequence) # Update attack timestamp - state.updated_at = datetime.now(timezone.utc) + ar.metadata["updated_at"] = datetime.now(timezone.utc).isoformat() - # Get updated attack detail attack_detail = await self.get_attack(attack_id) if attack_detail is None: raise ValueError(f"Attack '{attack_id}' not found after update") - return AddMessageResponse( - attack=attack_detail, - error=transport_error, + return AddMessageResponse(attack=attack_detail) + + # ======================================================================== + # Private Helper Methods - Summary Building + # ======================================================================== + + def _build_summary(self, ar: AttackResult) -> AttackSummary: + """Build an AttackSummary from an AttackResult.""" + # Get message count and last preview + pieces = self._memory.get_message_pieces(conversation_id=ar.conversation_id) + message_count = len(set(p.sequence for p in pieces)) + last_preview = None + if pieces: + last_piece = max(pieces, key=lambda p: p.sequence) + text = last_piece.converted_value or "" + last_preview = text[:100] + "..." if len(text) > 100 else text + + created_str = ar.metadata.get("created_at") + updated_str = ar.metadata.get("updated_at") + created_at = datetime.fromisoformat(created_str) if created_str else datetime.now(timezone.utc) + updated_at = datetime.fromisoformat(updated_str) if updated_str else created_at + + return AttackSummary( + attack_id=ar.conversation_id, + name=ar.attack_identifier.get("name"), + target_id=ar.attack_identifier.get("target_id", ""), + target_type=ar.attack_identifier.get("target_type", ""), + outcome=self._map_outcome(ar.outcome), + last_message_preview=last_preview, + message_count=message_count, + created_at=created_at, + updated_at=updated_at, ) - async def delete_attack(self, attack_id: str) -> bool: - """ - Delete an attack. + def _map_outcome(self, outcome: AttackOutcome) -> Optional[Literal["pending", "success", "failure"]]: + """Map AttackOutcome enum to API outcome string.""" + if outcome == AttackOutcome.SUCCESS: + return "success" + elif outcome == AttackOutcome.FAILURE: + return "failure" + else: + return "pending" + + # ======================================================================== + # Private Helper Methods - Pagination + # ======================================================================== + + def _paginate( + self, items: List[AttackSummary], cursor: Optional[str], limit: int + ) -> tuple[List[AttackSummary], bool]: + """Apply cursor-based pagination.""" + start_idx = 0 + if cursor: + for i, item in enumerate(items): + if item.attack_id == cursor: + start_idx = i + 1 + break - Args: - attack_id: Attack ID + page = items[start_idx : start_idx + limit] + has_more = len(items) > start_idx + limit + return page, has_more + + # ======================================================================== + # Private Helper Methods - Message Conversion + # ======================================================================== + + def _convert_pyrit_messages(self, pyrit_messages: list) -> List[Message]: + """Convert PyRIT messages to backend Message format.""" + messages = [] + for msg in pyrit_messages: + pieces = [ + MessagePiece( + piece_id=str(p.id), + data_type=p.converted_value_data_type or "text", + original_value=p.original_value, + converted_value=p.converted_value or "", + scores=self._convert_scores(p.scores) if hasattr(p, 'scores') and p.scores else [], + response_error=p.response_error or "none", + ) + for p in msg.message_pieces + ] + + first = msg.message_pieces[0] if msg.message_pieces else None + messages.append(Message( + message_id=str(first.id) if first else str(uuid.uuid4()), + turn_number=first.sequence if first else 0, + role=first.role if first else "user", + pieces=pieces, + created_at=first.timestamp if first else datetime.now(timezone.utc), + )) + + return messages + + def _convert_scores(self, scores: list) -> List[Score]: + """Convert PyRIT scores to backend Score format.""" + return [ + Score( + score_id=str(s.id), + scorer_type=s.scorer_class_identifier.get("__type__", "unknown"), + score_value=s.score_value, + score_rationale=s.score_rationale, + scored_at=s.timestamp, + ) + for s in scores + ] + + # ======================================================================== + # Private Helper Methods - Store Messages + # ======================================================================== + + async def _store_prepended_messages( + self, + conversation_id: str, + prepended: list, + ) -> None: + """Store prepended conversation messages in memory.""" + for seq, msg in enumerate(prepended): + piece = PyritMessagePiece( + role=msg.role, + original_value=msg.content, + original_value_data_type="text", + converted_value=msg.content, + converted_value_data_type="text", + conversation_id=conversation_id, + sequence=seq, + ) + self._memory.add_message_pieces_to_memory(message_pieces=[piece]) + + async def _send_and_store_message( + self, + attack_id: str, + target_id: str, + request: AddMessageRequest, + sequence: int, + ) -> None: + """Send message to target via normalizer and store response.""" + target_obj = get_target_service().get_target_object(target_id) + if not target_obj: + raise ValueError(f"Target object for '{target_id}' not found") + + pyrit_message = self._build_pyrit_message(request, attack_id, sequence) + converter_configs = self._get_converter_configs(request) + + normalizer = PromptNormalizer() + await normalizer.send_prompt_async( + message=pyrit_message, + target=target_obj, + conversation_id=attack_id, + request_converter_configurations=converter_configs, + ) + # PromptNormalizer stores both request and response in memory automatically + + async def _store_message_only( + self, + attack_id: str, + request: AddMessageRequest, + sequence: int, + ) -> None: + """Store message without sending (send=False).""" + for p in request.pieces: + piece = PyritMessagePiece( + role=request.role, + original_value=p.original_value, + original_value_data_type=cast(PromptDataType, p.data_type), + converted_value=p.converted_value or p.original_value, + converted_value_data_type=cast(PromptDataType, p.data_type), + conversation_id=attack_id, + sequence=sequence, + ) + self._memory.add_message_pieces_to_memory(message_pieces=[piece]) + + def _build_pyrit_message( + self, + request: AddMessageRequest, + conversation_id: str, + sequence: int, + ) -> PyritMessage: + """Build PyRIT Message from request.""" + pieces = [ + PyritMessagePiece( + role=request.role, + original_value=p.original_value, + original_value_data_type=cast(PromptDataType, p.data_type), + converted_value=p.converted_value or p.original_value, + converted_value_data_type=cast(PromptDataType, p.data_type), + conversation_id=conversation_id, + sequence=sequence, + ) + for p in request.pieces + ] + return PyritMessage(pieces) + + def _get_converter_configs( + self, request: AddMessageRequest + ) -> List[PromptConverterConfiguration]: + """Get converter configurations if needed.""" + has_preconverted = any(p.converted_value is not None for p in request.pieces) + if has_preconverted or not request.converter_ids: + return [] + + converters = get_converter_service().get_converter_objects_for_ids(request.converter_ids) + return PromptConverterConfiguration.from_converters(converters=converters) - Returns: - True if deleted, False if not found - """ - if attack_id in self._attacks: - del self._attacks[attack_id] - self._messages.pop(attack_id, None) - return True - return False +# ============================================================================ +# Singleton +# ============================================================================ -# Global service instance _attack_service: Optional[AttackService] = None @@ -430,7 +463,7 @@ def get_attack_service() -> AttackService: Get the global attack service instance. Returns: - AttackService: The singleton attack service instance. + The singleton AttackService instance. """ global _attack_service if _attack_service is None: diff --git a/pyrit/backend/services/converter_service.py b/pyrit/backend/services/converter_service.py index c7afb8b0da..9702e7ed13 100644 --- a/pyrit/backend/services/converter_service.py +++ b/pyrit/backend/services/converter_service.py @@ -30,220 +30,202 @@ class ConverterService: def __init__(self) -> None: """Initialize the converter service.""" - # In-memory storage for converter instances self._instances: Dict[str, ConverterInstance] = {} - # Actual instantiated converter objects self._converter_objects: Dict[str, Any] = {} - def _get_converter_class(self, converter_type: str) -> type: - """ - Get the converter class for a given type. + # ======================================================================== + # Public API Methods + # ======================================================================== + + async def list_converters( + self, source: Optional[Literal["initializer", "user"]] = None + ) -> ConverterInstanceListResponse: + """List all converter instances.""" + items = list(self._instances.values()) + if source is not None: + items = [c for c in items if c.source == source] + return ConverterInstanceListResponse(items=items) - Args: - converter_type: Converter type string (e.g., 'base64', 'Base64Converter') + async def get_converter(self, converter_id: str) -> Optional[ConverterInstance]: + """Get a converter instance by ID.""" + return self._instances.get(converter_id) - Returns: - The converter class - """ + def get_converter_object(self, converter_id: str) -> Optional[Any]: + """Get the actual converter object.""" + return self._converter_objects.get(converter_id) + + async def create_converter( + self, request: CreateConverterRequest + ) -> CreateConverterResponse: + """Create a new converter instance with optional nested converters.""" + config = {"type": request.type, "params": request.params} + converter_id, _, created_instances = self._create_converter_recursive(config, "user") + + if request.display_name and converter_id in self._instances: + self._instances[converter_id].display_name = request.display_name + + outer_instance = self._instances[converter_id] + return CreateConverterResponse( + converter_id=converter_id, + type=request.type, + display_name=request.display_name, + params=outer_instance.params, + created_converters=created_instances if len(created_instances) > 1 else None, + created_at=outer_instance.created_at, + source="user", + ) + + async def delete_converter(self, converter_id: str) -> bool: + """Delete a converter instance.""" + if converter_id in self._instances: + del self._instances[converter_id] + self._converter_objects.pop(converter_id, None) + return True + return False + + async def preview_conversion( + self, request: ConverterPreviewRequest + ) -> ConverterPreviewResponse: + """Preview conversion through a converter pipeline.""" + converters = self._gather_converters_for_preview(request) + steps, final_value, final_type = await self._apply_converters( + converters, request.original_value, request.original_value_data_type + ) + + return ConverterPreviewResponse( + original_value=request.original_value, + original_value_data_type=request.original_value_data_type, + converted_value=final_value, + converted_value_data_type=final_type, + steps=steps, + ) + + def get_converter_objects_for_ids(self, converter_ids: List[str]) -> List[Any]: + """Get converter objects for a list of IDs.""" + converters = [] + for conv_id in converter_ids: + conv_obj = self.get_converter_object(conv_id) + if conv_obj is None: + raise ValueError(f"Converter instance '{conv_id}' not found") + converters.append(conv_obj) + return converters + + def instantiate_inline_converters(self, configs: List[InlineConverterConfig]) -> List[Any]: + """Instantiate converters from inline configurations.""" + return [ + self._get_converter_class(config.type)(**config.params) + for config in configs + ] + + # ======================================================================== + # Private Helper Methods - Class Resolution + # ======================================================================== + + def _get_converter_class(self, converter_type: str) -> type: + """Get the converter class for a given type.""" module = importlib.import_module("pyrit.prompt_converter") - # Try direct attribute lookup first cls = getattr(module, converter_type, None) if cls is not None: return cast(type, cls) - # Try common class name patterns - class_name_patterns = [ - converter_type, - f"{converter_type}Converter", - "".join(word.capitalize() for word in converter_type.split("_")), - "".join(word.capitalize() for word in converter_type.split("_")) + "Converter", - ] - - for pattern in class_name_patterns: + for pattern in self._class_name_patterns(converter_type): cls = getattr(module, pattern, None) if cls is not None: return cast(type, cls) raise ValueError(f"Converter type '{converter_type}' not found in pyrit.prompt_converter") + def _class_name_patterns(self, type_name: str) -> List[str]: + """Generate class name patterns to try.""" + pascal = "".join(word.capitalize() for word in type_name.split("_")) + return [type_name, f"{type_name}Converter", pascal, f"{pascal}Converter"] + + # ======================================================================== + # Private Helper Methods - Recursive Creation + # ======================================================================== + def _create_converter_recursive( self, config: Dict[str, Any], source: Literal["initializer", "user"], ) -> Tuple[str, Any, List[ConverterInstance]]: - """ - Recursively create converters, handling nested converter params. - - Args: - config: Converter configuration with 'type' and 'params' - source: Source of creation - - Returns: - Tuple of (converter_id, converter_object, list of all created instances) - """ + """Recursively create converters, handling nested converter params.""" converter_type = config["type"] params = dict(config.get("params", {})) created_instances: List[ConverterInstance] = [] - # Check for nested converter in params + # Handle nested converter + params, created_instances = self._resolve_nested_converter(params, source) + + # Create this converter + converter_obj = self._get_converter_class(converter_type)(**params) + converter_id = self._store_converter(converter_type, converter_obj, config, created_instances, source) + + return converter_id, converter_obj, created_instances + + def _resolve_nested_converter( + self, + params: Dict[str, Any], + source: Literal["initializer", "user"], + ) -> Tuple[Dict[str, Any], List[ConverterInstance]]: + """Resolve nested converter in params if present.""" + created_instances: List[ConverterInstance] = [] + if "converter" in params and isinstance(params["converter"], dict): nested_config = params["converter"] if "type" in nested_config: - # Recursively create nested converter - nested_id, nested_obj, nested_instances = self._create_converter_recursive(nested_config, source) + _, nested_obj, nested_instances = self._create_converter_recursive(nested_config, source) created_instances.extend(nested_instances) - # Replace inline config with the actual converter object params["converter"] = nested_obj - # Create this converter - converter_class = self._get_converter_class(converter_type) - converter_obj = converter_class(**params) + return params, created_instances + def _store_converter( + self, + converter_type: str, + converter_obj: Any, + config: Dict[str, Any], + created_instances: List[ConverterInstance], + source: Literal["initializer", "user"], + ) -> str: + """Store converter and return its ID.""" converter_id = str(uuid.uuid4()) - now = datetime.now(timezone.utc) - - # Store the converter object self._converter_objects[converter_id] = converter_obj - # Build resolved params (with nested converter IDs instead of objects) - resolved_params = dict(config.get("params", {})) - if "converter" in resolved_params and isinstance(resolved_params["converter"], dict): - # Replace with the nested converter ID - nested_id = created_instances[-1].converter_id if created_instances else None - resolved_params["converter"] = {"converter_id": nested_id} - + resolved_params = self._build_resolved_params(config, created_instances) instance = ConverterInstance( converter_id=converter_id, type=converter_type, display_name=None, params=resolved_params, - created_at=now, + created_at=datetime.now(timezone.utc), source=source, ) self._instances[converter_id] = instance created_instances.append(instance) - return converter_id, converter_obj, created_instances + return converter_id - async def list_converters( - self, - source: Optional[Literal["initializer", "user"]] = None, - ) -> ConverterInstanceListResponse: - """ - List all converter instances. - - Args: - source: Optional filter by source - - Returns: - ConverterInstanceListResponse: List of converter instances - """ - items = list(self._instances.values()) - - if source is not None: - items = [c for c in items if c.source == source] - - return ConverterInstanceListResponse(items=items) - - async def get_converter(self, converter_id: str) -> Optional[ConverterInstance]: - """ - Get a converter instance by ID. - - Args: - converter_id: Converter instance ID - - Returns: - ConverterInstance or None if not found - """ - return self._instances.get(converter_id) - - def get_converter_object(self, converter_id: str) -> Optional[Any]: - """ - Get the actual converter object. - - Args: - converter_id: Converter instance ID - - Returns: - The instantiated converter object or None - """ - return self._converter_objects.get(converter_id) - - async def create_converter( - self, - request: CreateConverterRequest, - ) -> CreateConverterResponse: - """ - Create a new converter instance. - - Supports nested converters - if params contains a 'converter' key with - a type/params dict, the nested converter will be created first. - - Args: - request: Converter creation request - - Returns: - CreateConverterResponse: Created converter details - """ - config = { - "type": request.type, - "params": request.params, - } - - converter_id, converter_obj, created_instances = self._create_converter_recursive(config, "user") - - # Update display name for the outermost converter - if request.display_name and converter_id in self._instances: - self._instances[converter_id].display_name = request.display_name - - outer_instance = self._instances[converter_id] - - return CreateConverterResponse( - converter_id=converter_id, - type=request.type, - display_name=request.display_name, - params=outer_instance.params, - created_converters=created_instances if len(created_instances) > 1 else None, - created_at=outer_instance.created_at, - source="user", - ) - - async def delete_converter(self, converter_id: str) -> bool: - """ - Delete a converter instance. - - Args: - converter_id: Converter instance ID - - Returns: - True if deleted, False if not found - """ - if converter_id in self._instances: - del self._instances[converter_id] - self._converter_objects.pop(converter_id, None) - return True - return False - - async def preview_conversion( - self, - request: ConverterPreviewRequest, - ) -> ConverterPreviewResponse: - """ - Preview conversion through a converter pipeline. - - Args: - request: Preview request with content and converters + def _build_resolved_params( + self, config: Dict[str, Any], created_instances: List[ConverterInstance] + ) -> Dict[str, Any]: + """Build resolved params with nested converter IDs.""" + resolved_params = dict(config.get("params", {})) + if "converter" in resolved_params and isinstance(resolved_params["converter"], dict): + nested_id = created_instances[-1].converter_id if created_instances else None + resolved_params["converter"] = {"converter_id": nested_id} + return resolved_params - Returns: - ConverterPreviewResponse: Conversion results with steps - """ - current_value = request.original_value - current_type: PromptDataType = request.original_value_data_type - steps: List[PreviewStep] = [] + # ======================================================================== + # Private Helper Methods - Preview + # ======================================================================== - # Get converters to apply - converters_to_apply: List[Tuple[Optional[str], str, Any]] = [] + def _gather_converters_for_preview( + self, request: ConverterPreviewRequest + ) -> List[Tuple[Optional[str], str, Any]]: + """Gather converters to apply from request.""" + converters: List[Tuple[Optional[str], str, Any]] = [] if request.converter_ids: for conv_id in request.converter_ids: @@ -251,22 +233,30 @@ async def preview_conversion( if conv_obj is None: raise ValueError(f"Converter instance '{conv_id}' not found") instance = self._instances[conv_id] - converters_to_apply.append((conv_id, instance.type, conv_obj)) + converters.append((conv_id, instance.type, conv_obj)) if request.converters: for inline_config in request.converters: - converter_class = self._get_converter_class(inline_config.type) - conv_obj = converter_class(**inline_config.params) - converters_to_apply.append((None, inline_config.type, conv_obj)) + conv_obj = self._get_converter_class(inline_config.type)(**inline_config.params) + converters.append((None, inline_config.type, conv_obj)) - # Apply converters in sequence - for conv_id, conv_type, conv_obj in converters_to_apply: - input_value = current_value - input_type = current_type + return converters + async def _apply_converters( + self, + converters: List[Tuple[Optional[str], str, Any]], + initial_value: str, + initial_type: PromptDataType, + ) -> Tuple[List[PreviewStep], str, PromptDataType]: + """Apply converters and collect steps.""" + current_value = initial_value + current_type = initial_type + steps: List[PreviewStep] = [] + + for conv_id, conv_type, conv_obj in converters: + input_value, input_type = current_value, current_type result = await conv_obj.convert_async(prompt=current_value) - current_value = result.output_text - current_type = result.output_type + current_value, current_type = result.output_text, result.output_type steps.append( PreviewStep( @@ -279,54 +269,13 @@ async def preview_conversion( ) ) - return ConverterPreviewResponse( - original_value=request.original_value, - original_value_data_type=request.original_value_data_type, - converted_value=current_value, - converted_value_data_type=current_type, - steps=steps, - ) - - def get_converter_objects_for_ids(self, converter_ids: List[str]) -> List[Any]: - """ - Get converter objects for a list of IDs. - - Args: - converter_ids: List of converter instance IDs - - Returns: - List of converter objects - - Raises: - ValueError: If any converter ID is not found - """ - converters = [] - for conv_id in converter_ids: - conv_obj = self.get_converter_object(conv_id) - if conv_obj is None: - raise ValueError(f"Converter instance '{conv_id}' not found") - converters.append(conv_obj) - return converters - - def instantiate_inline_converters(self, configs: List[InlineConverterConfig]) -> List[Any]: - """ - Instantiate converters from inline configurations. - - Args: - configs: List of inline converter configs + return steps, current_value, current_type - Returns: - List of converter objects - """ - converters = [] - for config in configs: - converter_class = self._get_converter_class(config.type) - conv_obj = converter_class(**config.params) - converters.append(conv_obj) - return converters +# ============================================================================ +# Singleton +# ============================================================================ -# Global service instance _converter_service: Optional[ConverterService] = None @@ -335,7 +284,7 @@ def get_converter_service() -> ConverterService: Get the global converter service instance. Returns: - ConverterService: The singleton converter service instance. + The singleton ConverterService instance. """ global _converter_service if _converter_service is None: diff --git a/tests/unit/backend/test_api_routes.py b/tests/unit/backend/test_api_routes.py index e02880c7e5..99ba9baf73 100644 --- a/tests/unit/backend/test_api_routes.py +++ b/tests/unit/backend/test_api_routes.py @@ -201,28 +201,6 @@ def test_update_attack_success(self, client: TestClient) -> None: data = response.json() assert data["outcome"] == "success" - def test_delete_attack_success(self, client: TestClient) -> None: - """Test deleting an attack.""" - with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: - mock_service = MagicMock() - mock_service.delete_attack = AsyncMock(return_value=True) - mock_get_service.return_value = mock_service - - response = client.delete("/api/attacks/attack-1") - - assert response.status_code == status.HTTP_204_NO_CONTENT - - def test_delete_attack_not_found(self, client: TestClient) -> None: - """Test deleting a non-existent attack.""" - with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: - mock_service = MagicMock() - mock_service.delete_attack = AsyncMock(return_value=False) - mock_get_service.return_value = mock_service - - response = client.delete("/api/attacks/nonexistent") - - assert response.status_code == status.HTTP_404_NOT_FOUND - def test_add_message_success(self, client: TestClient) -> None: """Test adding a message to an attack.""" now = datetime.now(timezone.utc) @@ -272,13 +250,91 @@ def test_add_message_success(self, client: TestClient) -> None: response = client.post( "/api/attacks/attack-1/messages", - json={"pieces": [{"content": "Hello"}]}, + json={"pieces": [{"original_value": "Hello"}]}, ) assert response.status_code == status.HTTP_200_OK data = response.json() assert len(data["attack"]["messages"]) == 2 + def test_update_attack_not_found(self, client: TestClient) -> None: + """Test updating a non-existent attack returns 404.""" + with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: + mock_service = MagicMock() + mock_service.update_attack = AsyncMock(return_value=None) + mock_get_service.return_value = mock_service + + response = client.patch( + "/api/attacks/nonexistent", + json={"outcome": "success"}, + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_add_message_attack_not_found(self, client: TestClient) -> None: + """Test adding message to non-existent attack returns 404.""" + with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: + mock_service = MagicMock() + mock_service.add_message = AsyncMock( + side_effect=ValueError("Attack 'nonexistent' not found") + ) + mock_get_service.return_value = mock_service + + response = client.post( + "/api/attacks/nonexistent/messages", + json={"pieces": [{"original_value": "Hello"}]}, + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_add_message_target_not_found(self, client: TestClient) -> None: + """Test adding message when target object not found returns 404.""" + with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: + mock_service = MagicMock() + mock_service.add_message = AsyncMock( + side_effect=ValueError("Target object for 'target-1' not found") + ) + mock_get_service.return_value = mock_service + + response = client.post( + "/api/attacks/attack-1/messages", + json={"pieces": [{"original_value": "Hello"}]}, + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_add_message_bad_request(self, client: TestClient) -> None: + """Test adding message with invalid request returns 400.""" + with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: + mock_service = MagicMock() + mock_service.add_message = AsyncMock( + side_effect=ValueError("Invalid message format") + ) + mock_get_service.return_value = mock_service + + response = client.post( + "/api/attacks/attack-1/messages", + json={"pieces": [{"original_value": "Hello"}]}, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + + def test_add_message_internal_error(self, client: TestClient) -> None: + """Test adding message when internal error occurs returns 500.""" + with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: + mock_service = MagicMock() + mock_service.add_message = AsyncMock( + side_effect=RuntimeError("Unexpected internal error") + ) + mock_get_service.return_value = mock_service + + response = client.post( + "/api/attacks/attack-1/messages", + json={"pieces": [{"original_value": "Hello"}]}, + ) + + assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + # ============================================================================ # Target Routes Tests @@ -394,27 +450,21 @@ def test_get_target_not_found(self, client: TestClient) -> None: assert response.status_code == status.HTTP_404_NOT_FOUND - def test_delete_target_success(self, client: TestClient) -> None: - """Test deleting a target.""" + def test_create_target_internal_error(self, client: TestClient) -> None: + """Test target creation with internal error returns 500.""" with patch("pyrit.backend.routes.targets.get_target_service") as mock_get_service: mock_service = MagicMock() - mock_service.delete_target = AsyncMock(return_value=True) - mock_get_service.return_value = mock_service - - response = client.delete("/api/targets/target-1") - - assert response.status_code == status.HTTP_204_NO_CONTENT - - def test_delete_target_not_found(self, client: TestClient) -> None: - """Test deleting a non-existent target.""" - with patch("pyrit.backend.routes.targets.get_target_service") as mock_get_service: - mock_service = MagicMock() - mock_service.delete_target = AsyncMock(return_value=False) + mock_service.create_target = AsyncMock( + side_effect=RuntimeError("Unexpected internal error") + ) mock_get_service.return_value = mock_service - response = client.delete("/api/targets/nonexistent") + response = client.post( + "/api/targets", + json={"type": "TextTarget", "params": {}}, + ) - assert response.status_code == status.HTTP_404_NOT_FOUND + assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR # ============================================================================ @@ -425,18 +475,7 @@ def test_delete_target_not_found(self, client: TestClient) -> None: class TestConverterRoutes: """Tests for converter API routes.""" - def test_list_converter_types(self, client: TestClient) -> None: - """Test listing converter types from registry.""" - with patch("pyrit.backend.routes.converters.get_registry_service") as mock_get_service: - mock_service = MagicMock() - mock_service.get_converters.return_value = [] - mock_get_service.return_value = mock_service - - response = client.get("/api/converters/types") - - assert response.status_code == status.HTTP_200_OK - - def test_list_converter_instances(self, client: TestClient) -> None: + def test_list_converters(self, client: TestClient) -> None: """Test listing converter instances.""" with patch("pyrit.backend.routes.converters.get_converter_service") as mock_get_service: mock_service = MagicMock() @@ -445,13 +484,13 @@ def test_list_converter_instances(self, client: TestClient) -> None: ) mock_get_service.return_value = mock_service - response = client.get("/api/converters/instances") + response = client.get("/api/converters") assert response.status_code == status.HTTP_200_OK data = response.json() assert data["items"] == [] - def test_create_converter_instance_success(self, client: TestClient) -> None: + def test_create_converter_success(self, client: TestClient) -> None: """Test successful converter instance creation.""" now = datetime.now(timezone.utc) @@ -470,7 +509,7 @@ def test_create_converter_instance_success(self, client: TestClient) -> None: mock_get_service.return_value = mock_service response = client.post( - "/api/converters/instances", + "/api/converters", json={"type": "Base64Converter", "display_name": "My Base64", "params": {}}, ) @@ -478,7 +517,7 @@ def test_create_converter_instance_success(self, client: TestClient) -> None: data = response.json() assert data["converter_id"] == "conv-1" - def test_create_converter_instance_invalid_type(self, client: TestClient) -> None: + def test_create_converter_invalid_type(self, client: TestClient) -> None: """Test converter creation with invalid type.""" with patch("pyrit.backend.routes.converters.get_converter_service") as mock_get_service: mock_service = MagicMock() @@ -488,13 +527,13 @@ def test_create_converter_instance_invalid_type(self, client: TestClient) -> Non mock_get_service.return_value = mock_service response = client.post( - "/api/converters/instances", + "/api/converters", json={"type": "InvalidConverter", "params": {}}, ) assert response.status_code == status.HTTP_400_BAD_REQUEST - def test_get_converter_instance_success(self, client: TestClient) -> None: + def test_get_converter_success(self, client: TestClient) -> None: """Test getting a converter instance by ID.""" now = datetime.now(timezone.utc) @@ -511,42 +550,20 @@ def test_get_converter_instance_success(self, client: TestClient) -> None: ) mock_get_service.return_value = mock_service - response = client.get("/api/converters/instances/conv-1") + response = client.get("/api/converters/conv-1") assert response.status_code == status.HTTP_200_OK data = response.json() assert data["converter_id"] == "conv-1" - def test_get_converter_instance_not_found(self, client: TestClient) -> None: + def test_get_converter_not_found(self, client: TestClient) -> None: """Test getting a non-existent converter instance.""" with patch("pyrit.backend.routes.converters.get_converter_service") as mock_get_service: mock_service = MagicMock() mock_service.get_converter = AsyncMock(return_value=None) mock_get_service.return_value = mock_service - response = client.get("/api/converters/instances/nonexistent") - - assert response.status_code == status.HTTP_404_NOT_FOUND - - def test_delete_converter_instance_success(self, client: TestClient) -> None: - """Test deleting a converter instance.""" - with patch("pyrit.backend.routes.converters.get_converter_service") as mock_get_service: - mock_service = MagicMock() - mock_service.delete_converter = AsyncMock(return_value=True) - mock_get_service.return_value = mock_service - - response = client.delete("/api/converters/instances/conv-1") - - assert response.status_code == status.HTTP_204_NO_CONTENT - - def test_delete_converter_instance_not_found(self, client: TestClient) -> None: - """Test deleting a non-existent converter instance.""" - with patch("pyrit.backend.routes.converters.get_converter_service") as mock_get_service: - mock_service = MagicMock() - mock_service.delete_converter = AsyncMock(return_value=False) - mock_get_service.return_value = mock_service - - response = client.delete("/api/converters/instances/nonexistent") + response = client.get("/api/converters/nonexistent") assert response.status_code == status.HTTP_404_NOT_FOUND @@ -587,3 +604,119 @@ def test_preview_conversion_success(self, client: TestClient) -> None: data = response.json() assert data["converted_value"] == "dGVzdA==" assert len(data["steps"]) == 1 + + def test_create_converter_internal_error(self, client: TestClient) -> None: + """Test converter creation with internal error returns 500.""" + with patch("pyrit.backend.routes.converters.get_converter_service") as mock_get_service: + mock_service = MagicMock() + mock_service.create_converter = AsyncMock( + side_effect=RuntimeError("Unexpected internal error") + ) + mock_get_service.return_value = mock_service + + response = client.post( + "/api/converters", + json={"type": "Base64Converter", "params": {}}, + ) + + assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + + def test_preview_conversion_bad_request(self, client: TestClient) -> None: + """Test preview conversion with invalid parameters returns 400.""" + with patch("pyrit.backend.routes.converters.get_converter_service") as mock_get_service: + mock_service = MagicMock() + mock_service.preview_conversion = AsyncMock( + side_effect=ValueError("Invalid converter parameters") + ) + mock_get_service.return_value = mock_service + + response = client.post( + "/api/converters/preview", + json={ + "original_value": "test", + "original_value_data_type": "text", + "converters": [{"type": "InvalidConverter", "params": {}}], + }, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + + def test_preview_conversion_internal_error(self, client: TestClient) -> None: + """Test preview conversion with internal error returns 500.""" + with patch("pyrit.backend.routes.converters.get_converter_service") as mock_get_service: + mock_service = MagicMock() + mock_service.preview_conversion = AsyncMock( + side_effect=RuntimeError("Converter execution failed") + ) + mock_get_service.return_value = mock_service + + response = client.post( + "/api/converters/preview", + json={ + "original_value": "test", + "original_value_data_type": "text", + "converters": [{"type": "Base64Converter", "params": {}}], + }, + ) + + assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + + +# ============================================================================ +# Version Routes Tests +# ============================================================================ + + +class TestVersionRoutes: + """Tests for version API routes.""" + + def test_get_version(self, client: TestClient) -> None: + """Test getting version information.""" + response = client.get("/api/version") + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "version" in data + assert "display" in data + + def test_get_version_with_build_info(self, client: TestClient) -> None: + """Test getting version with build info from Docker.""" + import tempfile + import json as json_lib + import os + + # Create a temp file to simulate Docker build info + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json_lib.dump( + { + "source": "git", + "commit": "abc123", + "modified": False, + "display": "1.0.0-test", + }, + f, + ) + temp_path = f.name + + try: + with patch("pyrit.backend.routes.version.Path") as mock_path_class: + mock_path_instance = MagicMock() + mock_path_instance.exists.return_value = True + mock_path_class.return_value = mock_path_instance + + # Mock open to return our temp file content + with patch("builtins.open", create=True) as mock_open: + mock_open.return_value.__enter__.return_value.read.return_value = json_lib.dumps( + { + "source": "git", + "commit": "abc123", + "modified": False, + "display": "1.0.0-test", + } + ) + + response = client.get("/api/version") + + assert response.status_code == status.HTTP_200_OK + finally: + os.unlink(temp_path) diff --git a/tests/unit/backend/test_attack_service.py b/tests/unit/backend/test_attack_service.py index b9379a0303..e92d5816da 100644 --- a/tests/unit/backend/test_attack_service.py +++ b/tests/unit/backend/test_attack_service.py @@ -2,11 +2,12 @@ # Licensed under the MIT license. """ -Tests for backend attack service. +Tests for attack service. + +The attack service uses PyRIT memory with AttackResult as the source of truth. """ from datetime import datetime, timezone -from typing import List from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -15,212 +16,225 @@ AddMessageRequest, CreateAttackRequest, MessagePieceRequest, - PrependedMessageRequest, UpdateAttackRequest, ) -from pyrit.backend.models.targets import TargetInstance -from pyrit.backend.services.attack_service import AttackService, AttackState +from pyrit.backend.services.attack_service import ( + AttackService, + get_attack_service, +) +from pyrit.models import AttackOutcome, AttackResult + + +@pytest.fixture +def mock_memory(): + """Create a mock memory instance.""" + memory = MagicMock() + memory.get_attack_results.return_value = [] + memory.get_conversation.return_value = [] + memory.get_message_pieces.return_value = [] + return memory + + +@pytest.fixture +def attack_service(mock_memory): + """Create an attack service with mocked memory.""" + with patch("pyrit.backend.services.attack_service.CentralMemory") as mock_central: + mock_central.get_memory_instance.return_value = mock_memory + service = AttackService() + yield service + + +def make_attack_result( + *, + conversation_id: str = "attack-1", + objective: str = "Test objective", + target_id: str = "target-1", + target_type: str = "TextTarget", + name: str = "Test Attack", + outcome: AttackOutcome = AttackOutcome.UNDETERMINED, + created_at: datetime = None, + updated_at: datetime = None, +) -> AttackResult: + """Create a mock AttackResult for testing.""" + now = datetime.now(timezone.utc) + created = created_at or now + updated = updated_at or now + return AttackResult( + conversation_id=conversation_id, + objective=objective, + attack_identifier={ + "name": name, + "target_id": target_id, + "target_type": target_type, + "source": "gui", + }, + outcome=outcome, + metadata={ + "created_at": created.isoformat(), + "updated_at": updated.isoformat(), + }, + ) + + +def make_mock_piece( + *, + conversation_id: str, + role: str = "user", + sequence: int = 0, + original_value: str = "test", + converted_value: str = "test", + timestamp: datetime = None, +): + """Create a mock message piece.""" + piece = MagicMock() + piece.id = "piece-id" + piece.conversation_id = conversation_id + piece.role = role + piece.sequence = sequence + piece.original_value = original_value + piece.converted_value = converted_value + piece.converted_value_data_type = "text" + piece.response_error = "none" + piece.timestamp = timestamp or datetime.now(timezone.utc) + piece.scores = [] + return piece + + +def make_mock_message(pieces: list): + """Create a mock message from pieces.""" + msg = MagicMock() + msg.message_pieces = pieces + return msg + + +# ============================================================================ +# Init Tests +# ============================================================================ @pytest.mark.usefixtures("patch_central_database") class TestAttackServiceInit: """Tests for AttackService initialization.""" - def test_init_creates_empty_attacks_dict(self) -> None: - """Test that service initializes with empty attacks dictionary.""" - service = AttackService() - assert service._attacks == {} + def test_init_gets_memory_instance(self) -> None: + """Test that init gets the memory instance.""" + with patch("pyrit.backend.services.attack_service.CentralMemory") as mock_central: + mock_memory = MagicMock() + mock_central.get_memory_instance.return_value = mock_memory - def test_init_creates_empty_messages_dict(self) -> None: - """Test that service initializes with empty messages dictionary.""" - service = AttackService() - assert len(service._messages) == 0 + service = AttackService() + + mock_central.get_memory_instance.assert_called_once() + assert service._memory == mock_memory + + +# ============================================================================ +# List Attacks Tests +# ============================================================================ @pytest.mark.usefixtures("patch_central_database") class TestListAttacks: - """Tests for AttackService.list_attacks method.""" + """Tests for list_attacks method.""" @pytest.mark.asyncio - async def test_list_attacks_returns_empty_when_no_attacks(self) -> None: - """Test that list_attacks returns empty list when no attacks exist.""" - service = AttackService() + async def test_list_attacks_returns_empty_when_no_attacks( + self, attack_service, mock_memory + ) -> None: + """Test that list_attacks returns empty list when no AttackResults exist.""" + mock_memory.get_attack_results.return_value = [] - result = await service.list_attacks() + result = await attack_service.list_attacks() assert result.items == [] assert result.pagination.has_more is False @pytest.mark.asyncio - async def test_list_attacks_returns_attacks(self) -> None: - """Test that list_attacks returns existing attacks.""" - service = AttackService() - now = datetime.now(timezone.utc) - - # Add a test attack - service._attacks["test-id"] = AttackState( - attack_id="test-id", - name="Test Attack", - target_id="target-1", - target_type="TextTarget", - created_at=now, - updated_at=now, - ) + async def test_list_attacks_returns_attacks( + self, attack_service, mock_memory + ) -> None: + """Test that list_attacks returns attacks from AttackResult records.""" + ar = make_attack_result() + mock_memory.get_attack_results.return_value = [ar] + mock_memory.get_message_pieces.return_value = [] - result = await service.list_attacks() + result = await attack_service.list_attacks() assert len(result.items) == 1 - assert result.items[0].attack_id == "test-id" - assert result.items[0].name == "Test Attack" + assert result.items[0].attack_id == "attack-1" + assert result.items[0].target_id == "target-1" @pytest.mark.asyncio - async def test_list_attacks_filters_by_target_id(self) -> None: + async def test_list_attacks_filters_by_target_id( + self, attack_service, mock_memory + ) -> None: """Test that list_attacks filters by target_id.""" - service = AttackService() - now = datetime.now(timezone.utc) + ar1 = make_attack_result(conversation_id="attack-1", target_id="target-1") + ar2 = make_attack_result(conversation_id="attack-2", target_id="target-2") + mock_memory.get_attack_results.return_value = [ar1, ar2] + mock_memory.get_message_pieces.return_value = [] - service._attacks["attack-1"] = AttackState( - attack_id="attack-1", - target_id="target-1", - target_type="TextTarget", - created_at=now, - updated_at=now, - ) - service._attacks["attack-2"] = AttackState( - attack_id="attack-2", - target_id="target-2", - target_type="TextTarget", - created_at=now, - updated_at=now, - ) - - result = await service.list_attacks(target_id="target-1") + result = await attack_service.list_attacks(target_id="target-1") assert len(result.items) == 1 assert result.items[0].target_id == "target-1" - @pytest.mark.asyncio - async def test_list_attacks_filters_by_outcome(self) -> None: - """Test that list_attacks filters by outcome.""" - service = AttackService() - now = datetime.now(timezone.utc) - - service._attacks["attack-1"] = AttackState( - attack_id="attack-1", - target_id="target-1", - target_type="TextTarget", - outcome="success", - created_at=now, - updated_at=now, - ) - service._attacks["attack-2"] = AttackState( - attack_id="attack-2", - target_id="target-1", - target_type="TextTarget", - outcome="failure", - created_at=now, - updated_at=now, - ) - - result = await service.list_attacks(outcome="success") - - assert len(result.items) == 1 - assert result.items[0].outcome == "success" - - @pytest.mark.asyncio - async def test_list_attacks_respects_limit(self) -> None: - """Test that list_attacks respects the limit parameter.""" - service = AttackService() - now = datetime.now(timezone.utc) - - for i in range(5): - service._attacks[f"attack-{i}"] = AttackState( - attack_id=f"attack-{i}", - target_id="target-1", - target_type="TextTarget", - created_at=now, - updated_at=now, - ) - - result = await service.list_attacks(limit=2) - assert len(result.items) == 2 - assert result.pagination.has_more is True - - @pytest.mark.asyncio - async def test_list_attacks_cursor_pagination(self) -> None: - """Test that list_attacks handles cursor-based pagination.""" - service = AttackService() - now = datetime.now(timezone.utc) - - # Create attacks with different updated_at times - for i in range(3): - service._attacks[f"attack-{i}"] = AttackState( - attack_id=f"attack-{i}", - target_id="target-1", - target_type="TextTarget", - created_at=now, - updated_at=now, - ) - - # Get first page - first_page = await service.list_attacks(limit=2) - assert len(first_page.items) == 2 - - # Get second page using cursor - if first_page.pagination.next_cursor: - second_page = await service.list_attacks( - limit=2, cursor=first_page.pagination.next_cursor - ) - assert len(second_page.items) == 1 +# ============================================================================ +# Get Attack Tests +# ============================================================================ @pytest.mark.usefixtures("patch_central_database") class TestGetAttack: - """Tests for AttackService.get_attack method.""" + """Tests for get_attack method.""" @pytest.mark.asyncio - async def test_get_attack_returns_none_for_nonexistent(self) -> None: - """Test that get_attack returns None for non-existent attack.""" - service = AttackService() + async def test_get_attack_returns_none_for_nonexistent( + self, attack_service, mock_memory + ) -> None: + """Test that get_attack returns None when AttackResult doesn't exist.""" + mock_memory.get_attack_results.return_value = [] - result = await service.get_attack("nonexistent-id") + result = await attack_service.get_attack("nonexistent") assert result is None @pytest.mark.asyncio - async def test_get_attack_returns_attack_details(self) -> None: - """Test that get_attack returns full attack details.""" - service = AttackService() - now = datetime.now(timezone.utc) - - service._attacks["test-id"] = AttackState( - attack_id="test-id", - name="Test Attack", + async def test_get_attack_returns_attack_details( + self, attack_service, mock_memory + ) -> None: + """Test that get_attack returns attack details from AttackResult.""" + ar = make_attack_result( + conversation_id="test-id", + name="My Attack", target_id="target-1", target_type="TextTarget", - outcome="pending", - created_at=now, - updated_at=now, ) + mock_memory.get_attack_results.return_value = [ar] + mock_memory.get_conversation.return_value = [] - result = await service.get_attack("test-id") + result = await attack_service.get_attack("test-id") assert result is not None assert result.attack_id == "test-id" - assert result.name == "Test Attack" + assert result.target_id == "target-1" assert result.target_type == "TextTarget" + assert result.name == "My Attack" + + +# ============================================================================ +# Create Attack Tests +# ============================================================================ @pytest.mark.usefixtures("patch_central_database") class TestCreateAttack: - """Tests for AttackService.create_attack method.""" + """Tests for create_attack method.""" @pytest.mark.asyncio - async def test_create_attack_validates_target_exists(self) -> None: - """Test that create_attack validates the target exists.""" - service = AttackService() - + async def test_create_attack_validates_target_exists(self, attack_service) -> None: + """Test that create_attack validates target exists.""" with patch( "pyrit.backend.services.attack_service.get_target_service" ) as mock_get_target_service: @@ -228,391 +242,121 @@ async def test_create_attack_validates_target_exists(self) -> None: mock_target_service.get_target = AsyncMock(return_value=None) mock_get_target_service.return_value = mock_target_service - request = CreateAttackRequest(target_id="nonexistent") - with pytest.raises(ValueError, match="not found"): - await service.create_attack(request) + await attack_service.create_attack( + CreateAttackRequest(target_id="nonexistent") + ) @pytest.mark.asyncio - async def test_create_attack_success(self) -> None: - """Test successful attack creation.""" - service = AttackService() - - mock_target = TargetInstance( - target_id="target-1", - type="TextTarget", - params={}, - created_at=datetime.now(timezone.utc), - source="user", - ) - + async def test_create_attack_stores_attack_result( + self, attack_service, mock_memory + ) -> None: + """Test that create_attack stores AttackResult in memory.""" with patch( "pyrit.backend.services.attack_service.get_target_service" ) as mock_get_target_service: mock_target_service = MagicMock() - mock_target_service.get_target = AsyncMock(return_value=mock_target) + mock_target_service.get_target = AsyncMock( + return_value=MagicMock(type="TextTarget") + ) mock_get_target_service.return_value = mock_target_service - request = CreateAttackRequest(target_id="target-1", name="My Attack") - - result = await service.create_attack(request) + result = await attack_service.create_attack( + CreateAttackRequest(target_id="target-1", name="My Attack") + ) assert result.attack_id is not None assert result.created_at is not None + mock_memory.add_attack_results_to_memory.assert_called_once() - # Verify the attack was stored correctly by fetching it - attack = await service.get_attack(result.attack_id) - assert attack is not None - assert attack.name == "My Attack" - assert attack.target_id == "target-1" - assert attack.target_type == "TextTarget" - - @pytest.mark.asyncio - async def test_create_attack_with_prepended_conversation(self) -> None: - """Test attack creation with prepended conversation.""" - service = AttackService() - - mock_target = TargetInstance( - target_id="target-1", - type="TextTarget", - params={}, - created_at=datetime.now(timezone.utc), - source="user", - ) - - with patch( - "pyrit.backend.services.attack_service.get_target_service" - ) as mock_get_target_service: - mock_target_service = MagicMock() - mock_target_service.get_target = AsyncMock(return_value=mock_target) - mock_get_target_service.return_value = mock_target_service - - request = CreateAttackRequest( - target_id="target-1", - prepended_conversation=[ - PrependedMessageRequest(role="system", content="You are a helpful assistant."), - ], - ) - - result = await service.create_attack(request) - # Verify the attack was stored with prepended conversation - attack = await service.get_attack(result.attack_id) - assert attack is not None - assert len(attack.prepended_conversation) == 1 - assert attack.prepended_conversation[0].role == "system" +# ============================================================================ +# Update Attack Tests +# ============================================================================ @pytest.mark.usefixtures("patch_central_database") class TestUpdateAttack: - """Tests for AttackService.update_attack method.""" - - @pytest.mark.asyncio - async def test_update_attack_returns_none_for_nonexistent(self) -> None: - """Test that update_attack returns None for non-existent attack.""" - service = AttackService() - - request = UpdateAttackRequest(outcome="success") - result = await service.update_attack("nonexistent", request) - - assert result is None + """Tests for update_attack method.""" @pytest.mark.asyncio - async def test_update_attack_updates_outcome(self) -> None: - """Test that update_attack updates the outcome.""" - service = AttackService() - now = datetime.now(timezone.utc) - - service._attacks["test-id"] = AttackState( - attack_id="test-id", - target_id="target-1", - target_type="TextTarget", - outcome=None, - created_at=now, - updated_at=now, + async def test_update_attack_returns_none_for_nonexistent( + self, attack_service, mock_memory + ) -> None: + """Test that update_attack returns None for nonexistent attack.""" + mock_memory.get_attack_results.return_value = [] + + result = await attack_service.update_attack( + "nonexistent", UpdateAttackRequest(outcome="success") ) - request = UpdateAttackRequest(outcome="success") - result = await service.update_attack("test-id", request) - - assert result is not None - assert result.outcome == "success" - - -@pytest.mark.usefixtures("patch_central_database") -class TestDeleteAttack: - """Tests for AttackService.delete_attack method.""" - - @pytest.mark.asyncio - async def test_delete_attack_returns_false_for_nonexistent(self) -> None: - """Test that delete_attack returns False for non-existent attack.""" - service = AttackService() - - result = await service.delete_attack("nonexistent") - - assert result is False + assert result is None @pytest.mark.asyncio - async def test_delete_attack_deletes_attack(self) -> None: - """Test that delete_attack removes the attack.""" - service = AttackService() - now = datetime.now(timezone.utc) - - service._attacks["test-id"] = AttackState( - attack_id="test-id", - target_id="target-1", - target_type="TextTarget", - created_at=now, - updated_at=now, + async def test_update_attack_updates_outcome( + self, attack_service, mock_memory + ) -> None: + """Test that update_attack updates the AttackResult outcome.""" + ar = make_attack_result(conversation_id="test-id") + mock_memory.get_attack_results.return_value = [ar] + mock_memory.get_conversation.return_value = [] + + await attack_service.update_attack( + "test-id", UpdateAttackRequest(outcome="success") ) - result = await service.delete_attack("test-id") - - assert result is True - assert "test-id" not in service._attacks - - @pytest.mark.asyncio - async def test_delete_attack_removes_messages(self) -> None: - """Test that delete_attack also removes associated messages.""" - service = AttackService() - now = datetime.now(timezone.utc) - - service._attacks["test-id"] = AttackState( - attack_id="test-id", - target_id="target-1", - target_type="TextTarget", - created_at=now, - updated_at=now, - ) - service._messages["test-id"] = [] + # Should call add_attack_results_to_memory to update + mock_memory.add_attack_results_to_memory.assert_called() - await service.delete_attack("test-id") - assert "test-id" not in service._messages +# ============================================================================ +# Add Message Tests +# ============================================================================ @pytest.mark.usefixtures("patch_central_database") class TestAddMessage: - """Tests for AttackService.add_message method.""" + """Tests for add_message method.""" @pytest.mark.asyncio - async def test_add_message_raises_for_nonexistent_attack(self) -> None: - """Test that add_message raises ValueError for non-existent attack.""" - service = AttackService() + async def test_add_message_raises_for_nonexistent_attack( + self, attack_service, mock_memory + ) -> None: + """Test that add_message raises ValueError for nonexistent attack.""" + mock_memory.get_attack_results.return_value = [] request = AddMessageRequest( - pieces=[MessagePieceRequest(content="Hello")], - ) - - with pytest.raises(ValueError, match="Attack"): - await service.add_message("nonexistent", request) - - @pytest.mark.asyncio - async def test_add_message_with_send_raises_for_missing_target_object(self) -> None: - """Test that add_message with send=True raises when target object is not found.""" - service = AttackService() - now = datetime.now(timezone.utc) - - service._attacks["test-id"] = AttackState( - attack_id="test-id", - target_id="target-1", - target_type="TextTarget", - created_at=now, - updated_at=now, - ) - - with patch( - "pyrit.backend.services.attack_service.get_target_service" - ) as mock_get_target_service: - mock_target_service = MagicMock() - mock_target_service.get_target_object.return_value = None - mock_get_target_service.return_value = mock_target_service - - with patch( - "pyrit.backend.services.attack_service.get_converter_service" - ) as mock_get_converter_service: - mock_converter_service = MagicMock() - mock_get_converter_service.return_value = mock_converter_service - - request = AddMessageRequest( - pieces=[MessagePieceRequest(content="Hello")], - send=True, - ) - - with pytest.raises(ValueError, match="Target object"): - await service.add_message("test-id", request) - - @pytest.mark.asyncio - async def test_add_message_without_send_stores_message(self) -> None: - """Test that add_message with send=False just stores the message.""" - service = AttackService() - now = datetime.now(timezone.utc) - - service._attacks["test-id"] = AttackState( - attack_id="test-id", - target_id="target-1", - target_type="TextTarget", - created_at=now, - updated_at=now, - ) - - with patch( - "pyrit.backend.services.attack_service.get_target_service" - ) as mock_get_target_service: - mock_target_service = MagicMock() - mock_get_target_service.return_value = mock_target_service - - with patch( - "pyrit.backend.services.attack_service.get_converter_service" - ) as mock_get_converter_service: - mock_converter_service = MagicMock() - mock_get_converter_service.return_value = mock_converter_service - - request = AddMessageRequest( - role="system", - pieces=[MessagePieceRequest(content="You are a helpful assistant.")], - send=False, - ) - - result = await service.add_message("test-id", request) - - assert result.attack is not None - assert len(result.attack.messages) == 1 - assert result.attack.messages[0].role == "system" - assert result.error is None - - @pytest.mark.asyncio - async def test_add_message_with_converter_ids_applies_converters(self) -> None: - """Test that add_message with converter_ids applies the converters.""" - service = AttackService() - now = datetime.now(timezone.utc) - - service._attacks["test-id"] = AttackState( - attack_id="test-id", - target_id="target-1", - target_type="TextTarget", - created_at=now, - updated_at=now, - ) - - # Create mock converter - mock_converter = MagicMock() - mock_converter.convert_async = AsyncMock( - return_value=MagicMock(output_text="converted text") - ) - - # Create mock target - mock_target = AsyncMock() - mock_target.send_prompt_async = AsyncMock( - return_value=MagicMock( - request_pieces=[ - MagicMock( - original_value="assistant response", - converted_value="assistant response", - original_value_data_type="text", - ) - ], - response_error_description="none", - ) + pieces=[MessagePieceRequest(original_value="Hello")], ) - with patch( - "pyrit.backend.services.attack_service.get_target_service" - ) as mock_get_target_service: - mock_target_service = MagicMock() - mock_target_service.get_target_object.return_value = mock_target - mock_get_target_service.return_value = mock_target_service - - with patch( - "pyrit.backend.services.attack_service.get_converter_service" - ) as mock_get_converter_service: - mock_converter_service = MagicMock() - mock_converter_service.get_converter_objects_for_ids.return_value = [ - mock_converter - ] - mock_get_converter_service.return_value = mock_converter_service - - request = AddMessageRequest( - role="user", - pieces=[MessagePieceRequest(content="Hello")], - send=True, - converter_ids=["converter-1"], - ) - - result = await service.add_message("test-id", request) - - # Verify converter was applied - mock_converter_service.get_converter_objects_for_ids.assert_called_once_with( - ["converter-1"] - ) - mock_converter.convert_async.assert_called_once_with(prompt="Hello") - - # Verify message was converted - assert result.attack is not None - # First message is the user message with conversion - user_msg = result.attack.messages[0] - assert user_msg.role == "user" - assert user_msg.pieces[0].converted_value == "converted text" + with pytest.raises(ValueError, match="not found"): + await attack_service.add_message("nonexistent", request) @pytest.mark.asyncio - async def test_add_message_without_converter_ids_does_not_apply_converters(self) -> None: - """Test that add_message without converter_ids does not apply any converters.""" - service = AttackService() - now = datetime.now(timezone.utc) - - service._attacks["test-id"] = AttackState( - attack_id="test-id", - target_id="target-1", - target_type="TextTarget", - created_at=now, - updated_at=now, - ) + async def test_add_message_without_send_stores_message( + self, attack_service, mock_memory + ) -> None: + """Test that add_message with send=False stores message in memory.""" + ar = make_attack_result(conversation_id="test-id", target_id="target-1") + mock_memory.get_attack_results.return_value = [ar] + mock_memory.get_message_pieces.return_value = [] + mock_memory.get_conversation.return_value = [] - # Create mock target - mock_target = AsyncMock() - mock_target.send_prompt_async = AsyncMock( - return_value=MagicMock( - request_pieces=[ - MagicMock( - original_value="response", - converted_value="response", - original_value_data_type="text", - ) - ], - response_error_description="none", - ) + request = AddMessageRequest( + role="system", + pieces=[MessagePieceRequest(original_value="You are a helpful assistant.")], + send=False, ) - with patch( - "pyrit.backend.services.attack_service.get_target_service" - ) as mock_get_target_service: - mock_target_service = MagicMock() - mock_target_service.get_target_object.return_value = mock_target - mock_get_target_service.return_value = mock_target_service - - with patch( - "pyrit.backend.services.attack_service.get_converter_service" - ) as mock_get_converter_service: - mock_converter_service = MagicMock() - mock_get_converter_service.return_value = mock_converter_service - - # No converter_ids in request - should not apply any converters - request = AddMessageRequest( - role="user", - pieces=[MessagePieceRequest(content="Hello")], - send=True, - ) + result = await attack_service.add_message("test-id", request) - result = await service.add_message("test-id", request) + assert result.attack is not None + mock_memory.add_message_pieces_to_memory.assert_called() - # Verify no converter lookup was done - mock_converter_service.get_converter_objects_for_ids.assert_not_called() - # Verify original value equals converted value (no conversion) - assert result.attack is not None - user_msg = result.attack.messages[0] - assert user_msg.pieces[0].original_value == "Hello" - assert user_msg.pieces[0].converted_value == "Hello" +# ============================================================================ +# Singleton Tests +# ============================================================================ @pytest.mark.usefixtures("patch_central_database") @@ -621,23 +365,23 @@ class TestAttackServiceSingleton: def test_get_attack_service_returns_attack_service(self) -> None: """Test that get_attack_service returns an AttackService instance.""" - from pyrit.backend.services.attack_service import get_attack_service - # Reset singleton for clean test import pyrit.backend.services.attack_service as module + module._attack_service = None - service = get_attack_service() - assert isinstance(service, AttackService) + with patch("pyrit.backend.services.attack_service.CentralMemory"): + service = get_attack_service() + assert isinstance(service, AttackService) def test_get_attack_service_returns_same_instance(self) -> None: """Test that get_attack_service returns the same instance.""" - from pyrit.backend.services.attack_service import get_attack_service - # Reset singleton for clean test import pyrit.backend.services.attack_service as module + module._attack_service = None - service1 = get_attack_service() - service2 = get_attack_service() - assert service1 is service2 + with patch("pyrit.backend.services.attack_service.CentralMemory"): + service1 = get_attack_service() + service2 = get_attack_service() + assert service1 is service2 From 27ca930860ddccc9ebab86b771419123d0197200 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Thu, 5 Feb 2026 10:31:48 -0800 Subject: [PATCH 07/35] update with latest PR feedback, cleanup --- frontend/README.md | 32 +- frontend/dev.py | 61 +-- pyproject.toml | 1 + pyrit/backend/main.py | 3 +- pyrit/backend/models/__init__.py | 31 +- pyrit/backend/models/attacks.py | 56 ++- pyrit/backend/models/converters.py | 56 +-- pyrit/backend/models/registry.py | 140 ------- pyrit/backend/models/targets.py | 7 +- pyrit/backend/routes/__init__.py | 3 +- pyrit/backend/routes/attacks.py | 79 +++- pyrit/backend/routes/initializers.py | 34 -- pyrit/backend/routes/labels.py | 88 +++++ pyrit/backend/routes/scenarios.py | 34 -- pyrit/backend/routes/scorers.py | 38 -- pyrit/backend/services/__init__.py | 6 - pyrit/backend/services/attack_service.py | 263 +++++++++---- pyrit/backend/services/converter_service.py | 288 +++++++-------- pyrit/backend/services/registry_service.py | 307 --------------- pyrit/backend/services/target_service.py | 163 +++----- pyrit/cli/pyrit_backend.py | 217 +++++++++++ .../registry/instance_registries/__init__.py | 8 + .../instance_registries/converter_registry.py | 108 ++++++ .../instance_registries/target_registry.py | 108 ++++++ tests/unit/backend/test_api_routes.py | 335 +++++++++++++---- tests/unit/backend/test_attack_service.py | 348 +++++++++++++++--- tests/unit/backend/test_converter_service.py | 320 ++++------------ tests/unit/backend/test_error_handlers.py | 2 +- tests/unit/backend/test_registry_service.py | 154 -------- tests/unit/backend/test_target_service.py | 201 ++-------- 30 files changed, 1733 insertions(+), 1758 deletions(-) delete mode 100644 pyrit/backend/models/registry.py delete mode 100644 pyrit/backend/routes/initializers.py create mode 100644 pyrit/backend/routes/labels.py delete mode 100644 pyrit/backend/routes/scenarios.py delete mode 100644 pyrit/backend/routes/scorers.py delete mode 100644 pyrit/backend/services/registry_service.py create mode 100644 pyrit/cli/pyrit_backend.py create mode 100644 pyrit/registry/instance_registries/converter_registry.py create mode 100644 pyrit/registry/instance_registries/target_registry.py delete mode 100644 tests/unit/backend/test_registry_service.py diff --git a/frontend/README.md b/frontend/README.md index 78845e9d21..86d5fd213a 100644 --- a/frontend/README.md +++ b/frontend/README.md @@ -13,6 +13,14 @@ python dev.py start # OR use npm script npm run start +# Start backend only (with airt initializer by default) +python dev.py backend + +# Start frontend only (backend must be started separately) +python dev.py frontend +# OR +npm run dev + # Restart both servers python dev.py restart # OR @@ -23,9 +31,6 @@ python dev.py stop # OR npm run stop -# Run Vite dev server only (backend must be started separately) -npm run dev - # Build for production npm run build @@ -33,6 +38,27 @@ npm run build npm run preview ``` +### Backend CLI + +The backend uses `pyrit_backend` CLI which supports initializers: + +```bash +# Start with default airt initializer (loads targets from env vars) +pyrit_backend --initializers airt + +# Start without initializers +pyrit_backend + +# Start with custom initialization script +pyrit_backend --initialization-scripts ./my_targets.py + +# List available initializers +pyrit_backend --list-initializers + +# Custom host/port +pyrit_backend --host 127.0.0.1 --port 8080 +``` + **Development Mode**: The `dev.py` script sets `PYRIT_DEV_MODE=true` so the backend expects the frontend to run separately on port 3000. **Production Mode**: When installed from PyPI, the backend serves the bundled frontend and will exit if frontend files are missing. diff --git a/frontend/dev.py b/frontend/dev.py index 28895a2a62..b2f8043bf3 100644 --- a/frontend/dev.py +++ b/frontend/dev.py @@ -77,8 +77,13 @@ def stop_servers(): print("✅ Servers stopped") -def start_backend(): - """Start the FastAPI backend""" +def start_backend(initializers: list[str] | None = None): + """Start the FastAPI backend using pyrit_backend CLI. + + Args: + initializers: Optional list of initializer names to run at startup. + Defaults to ["airt"] to load targets from environment variables. + """ print("🚀 Starting backend on port 8000...") # Change to workspace root @@ -88,40 +93,36 @@ def start_backend(): env = os.environ.copy() env["PYRIT_DEV_MODE"] = "true" - # Start backend with uvicorn + # Default to airt initializer if not specified + if initializers is None: + initializers = ["airt"] + + # Build command using pyrit_backend CLI + cmd = [ + sys.executable, + "-m", + "pyrit.cli.pyrit_backend", + "--host", + "0.0.0.0", + "--port", + "8000", + "--log-level", + "info", + ] + + # Add initializers if specified + if initializers: + cmd.extend(["--initializers"] + initializers) + + # Start backend if is_windows(): backend = subprocess.Popen( - [ - sys.executable, - "-m", - "uvicorn", - "pyrit.backend.main:app", - "--host", - "0.0.0.0", - "--port", - "8000", - "--log-level", - "info", - ], + cmd, env=env, creationflags=subprocess.CREATE_NEW_PROCESS_GROUP if is_windows() else 0, ) else: - backend = subprocess.Popen( - [ - sys.executable, - "-m", - "uvicorn", - "pyrit.backend.main:app", - "--host", - "0.0.0.0", - "--port", - "8000", - "--log-level", - "info", - ], - env=env, - ) + backend = subprocess.Popen(cmd, env=env) return backend diff --git a/pyproject.toml b/pyproject.toml index abf251d8aa..a8ae271d48 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -154,6 +154,7 @@ all = [ ] [project.scripts] +pyrit_backend = "pyrit.cli.pyrit_backend:main" pyrit_scan = "pyrit.cli.pyrit_scan:main" pyrit_shell = "pyrit.cli.pyrit_shell:main" diff --git a/pyrit/backend/main.py b/pyrit/backend/main.py index 7314ca54c4..d7e032d510 100644 --- a/pyrit/backend/main.py +++ b/pyrit/backend/main.py @@ -16,7 +16,7 @@ import pyrit from pyrit.backend.middleware import register_error_handlers -from pyrit.backend.routes import attacks, converters, health, targets, version +from pyrit.backend.routes import attacks, converters, health, labels, targets, version from pyrit.setup.initialization import initialize_pyrit_async # Check for development mode from environment variable @@ -54,6 +54,7 @@ async def startup_event_async() -> None: app.include_router(attacks.router, prefix="/api", tags=["attacks"]) app.include_router(targets.router, prefix="/api", tags=["targets"]) app.include_router(converters.router, prefix="/api", tags=["converters"]) +app.include_router(labels.router, prefix="/api", tags=["labels"]) app.include_router(health.router, prefix="/api", tags=["health"]) app.include_router(version.router, tags=["version"]) diff --git a/pyrit/backend/models/__init__.py b/pyrit/backend/models/__init__.py index 3ec408399a..d338454db5 100644 --- a/pyrit/backend/models/__init__.py +++ b/pyrit/backend/models/__init__.py @@ -10,8 +10,8 @@ from pyrit.backend.models.attacks import ( AddMessageRequest, AddMessageResponse, - AttackDetail, AttackListResponse, + AttackMessagesResponse, AttackSummary, CreateAttackRequest, CreateAttackResponse, @@ -35,27 +35,12 @@ from pyrit.backend.models.converters import ( ConverterInstance, ConverterInstanceListResponse, - ConverterMetadataResponse, ConverterPreviewRequest, ConverterPreviewResponse, CreateConverterRequest, CreateConverterResponse, - InlineConverterConfig, - NestedConverterConfig, PreviewStep, ) -from pyrit.backend.models.registry import ( - InitializerListResponse, - InitializerMetadataResponse, - ScenarioListResponse, - ScenarioMetadataResponse, - ScorerListResponse, - ScorerMetadataResponse, - TargetMetadataResponse, -) -from pyrit.backend.models.registry import ( - TargetListResponse as RegistryTargetListResponse, -) from pyrit.backend.models.targets import ( CreateTargetRequest, CreateTargetResponse, @@ -67,8 +52,8 @@ # Attacks "AddMessageRequest", "AddMessageResponse", - "AttackDetail", "AttackListResponse", + "AttackMessagesResponse", "AttackSummary", "CreateAttackRequest", "CreateAttackResponse", @@ -90,23 +75,11 @@ # Converters "ConverterInstance", "ConverterInstanceListResponse", - "ConverterMetadataResponse", "ConverterPreviewRequest", "ConverterPreviewResponse", "CreateConverterRequest", "CreateConverterResponse", - "InlineConverterConfig", - "NestedConverterConfig", "PreviewStep", - # Registry - "InitializerListResponse", - "InitializerMetadataResponse", - "ScenarioListResponse", - "ScenarioMetadataResponse", - "ScorerListResponse", - "ScorerMetadataResponse", - "RegistryTargetListResponse", - "TargetMetadataResponse", # Targets "CreateTargetRequest", "CreateTargetResponse", diff --git a/pyrit/backend/models/attacks.py b/pyrit/backend/models/attacks.py index db85c13805..8c4621ed64 100644 --- a/pyrit/backend/models/attacks.py +++ b/pyrit/backend/models/attacks.py @@ -72,7 +72,7 @@ class AttackSummary(BaseModel): name: Optional[str] = Field(None, description="Attack name/label") target_id: str = Field(..., description="Target instance ID") target_type: str = Field(..., description="Target type (e.g., 'azure_openai')") - outcome: Optional[Literal["pending", "success", "failure"]] = Field( + outcome: Optional[Literal["undetermined", "success", "failure"]] = Field( None, description="Attack outcome (null if not yet determined)" ) last_message_preview: Optional[str] = Field( @@ -85,25 +85,15 @@ class AttackSummary(BaseModel): # ============================================================================ -# Attack Detail (Single Attack View) +# Attack Messages Response # ============================================================================ -class AttackDetail(BaseModel): - """Detailed view of an attack (includes all messages).""" +class AttackMessagesResponse(BaseModel): + """Response containing all messages for an attack.""" - attack_id: str = Field(..., description="Unique attack identifier") - name: Optional[str] = Field(None, description="Attack name/label") - target_id: str = Field(..., description="Target instance ID") - target_type: str = Field(..., description="Target type (e.g., 'azure_openai')") - outcome: Optional[Literal["pending", "success", "failure"]] = Field(None, description="Attack outcome") - prepended_conversation: List[Message] = Field( - default_factory=list, description="Prepended messages (system prompts, branching context)" - ) - messages: List[Message] = Field(default_factory=list, description="Attack messages in order") - labels: Dict[str, str] = Field(default_factory=dict, description="User-defined labels for filtering") - created_at: datetime = Field(..., description="Attack creation timestamp") - updated_at: datetime = Field(..., description="Last update timestamp") + attack_id: str = Field(..., description="Attack identifier") + messages: List[Message] = Field(default_factory=list, description="All messages in order") # ============================================================================ @@ -123,11 +113,25 @@ class AttackListResponse(BaseModel): # ============================================================================ +# ============================================================================ +# Message Input Models +# ============================================================================ + + +class MessagePieceRequest(BaseModel): + """A piece of content for a message.""" + + data_type: str = Field(default="text", description="Data type: 'text', 'image', 'audio', etc.") + original_value: str = Field(..., description="Original value (text or base64 for media)") + converted_value: Optional[str] = Field(None, description="Converted value. If provided, bypasses converters.") + mime_type: Optional[str] = Field(None, description="MIME type for media content") + + class PrependedMessageRequest(BaseModel): """A message to prepend to the attack (for system prompt/branching).""" role: Literal["user", "assistant", "system"] = Field(..., description="Message role") - content: str = Field(..., description="Message content (text)") + pieces: List[MessagePieceRequest] = Field(..., description="Message pieces (supports multimodal)") class CreateAttackRequest(BaseModel): @@ -156,7 +160,7 @@ class CreateAttackResponse(BaseModel): class UpdateAttackRequest(BaseModel): """Request to update an attack's outcome.""" - outcome: Literal["pending", "success", "failure"] = Field(..., description="Updated attack outcome") + outcome: Literal["undetermined", "success", "failure"] = Field(..., description="Updated attack outcome") # ============================================================================ @@ -164,17 +168,6 @@ class UpdateAttackRequest(BaseModel): # ============================================================================ -class MessagePieceRequest(BaseModel): - """A piece of content for a message.""" - - data_type: str = Field(default="text", description="Data type: 'text', 'image', 'audio', etc.") - original_value: str = Field(..., description="Original value (text or base64 for media)") - converted_value: Optional[str] = Field( - None, description="Converted value. If provided, bypasses converters." - ) - mime_type: Optional[str] = Field(None, description="MIME type for media content") - - class AddMessageRequest(BaseModel): """ Request to add a message to an attack. @@ -199,9 +192,10 @@ class AddMessageResponse(BaseModel): """ Response after adding a message. - Returns the updated attack detail. If send=True was used, the new + Returns the attack metadata and all messages. If send=True was used, the new assistant response will be in the messages list. Check response_error on the assistant's message pieces if the target returned an error. """ - attack: AttackDetail = Field(..., description="Updated attack with new message(s)") + attack: AttackSummary = Field(..., description="Updated attack metadata") + messages: AttackMessagesResponse = Field(..., description="All messages including new one(s)") diff --git a/pyrit/backend/models/converters.py b/pyrit/backend/models/converters.py index f507f2d2c8..903049ea1b 100644 --- a/pyrit/backend/models/converters.py +++ b/pyrit/backend/models/converters.py @@ -4,30 +4,20 @@ """ Converter-related request and response models. -Converters have two concepts: -- Types: Static metadata bundled with frontend (from registry) -- Instances: Runtime objects created via API with specific configuration - -This module defines both the Instance models and preview functionality. -Nested converters (e.g., SelectiveTextConverter wrapping Base64Converter) are supported. +This module defines the Instance models and preview functionality. """ -from datetime import datetime -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field -from pyrit.backend.models.registry import ConverterMetadataResponse from pyrit.models import PromptDataType __all__ = [ - "ConverterMetadataResponse", "ConverterInstance", "ConverterInstanceListResponse", "CreateConverterRequest", "CreateConverterResponse", - "InlineConverterConfig", - "NestedConverterConfig", "ConverterPreviewRequest", "ConverterPreviewResponse", "PreviewStep", @@ -39,25 +29,6 @@ # ============================================================================ -class InlineConverterConfig(BaseModel): - """Inline converter configuration (type + params).""" - - type: str = Field(..., description="Converter type (e.g., 'base64', 'translation')") - params: Dict[str, Any] = Field(default_factory=dict, description="Converter parameters") - - -class NestedConverterConfig(BaseModel): - """ - Converter config that may contain nested converters. - - Used for composite converters like SelectiveTextConverter that wrap other converters. - The 'converter' param can contain another NestedConverterConfig. - """ - - type: str = Field(..., description="Converter type") - params: Dict[str, Any] = Field(default_factory=dict, description="Converter parameters") - - class ConverterInstance(BaseModel): """A registered converter instance.""" @@ -65,8 +36,6 @@ class ConverterInstance(BaseModel): type: str = Field(..., description="Converter type (e.g., 'base64', 'translation')") display_name: Optional[str] = Field(None, description="Human-readable display name") params: Dict[str, Any] = Field(default_factory=dict, description="Converter parameters (resolved)") - created_at: datetime = Field(..., description="Creation timestamp") - source: Literal["initializer", "user"] = Field(..., description="How the converter was created") class ConverterInstanceListResponse(BaseModel): @@ -76,18 +45,13 @@ class ConverterInstanceListResponse(BaseModel): class CreateConverterRequest(BaseModel): - """ - Request to create a new converter instance. - - Supports nested converters - if params contains a 'converter' key with - an InlineConverterConfig, the backend will create both and link them. - """ + """Request to create a new converter instance.""" type: str = Field(..., description="Converter type (e.g., 'base64', 'translation')") display_name: Optional[str] = Field(None, description="Human-readable display name") params: Dict[str, Any] = Field( default_factory=dict, - description="Converter parameters (may include nested 'converter' config)", + description="Converter parameters", ) @@ -97,12 +61,7 @@ class CreateConverterResponse(BaseModel): converter_id: str = Field(..., description="Unique converter instance identifier") type: str = Field(..., description="Converter type") display_name: Optional[str] = Field(None, description="Human-readable display name") - params: Dict[str, Any] = Field(default_factory=dict, description="Resolved parameters (nested converters have IDs)") - created_converters: Optional[List[ConverterInstance]] = Field( - None, description="All converters created (including nested), ordered inner-to-outer" - ) - created_at: datetime = Field(..., description="Creation timestamp") - source: Literal["user"] = Field(default="user", description="Source is always 'user' for API-created") + params: Dict[str, Any] = Field(default_factory=dict, description="Converter parameters") # ============================================================================ @@ -113,7 +72,7 @@ class CreateConverterResponse(BaseModel): class PreviewStep(BaseModel): """A single step in the conversion preview.""" - converter_id: Optional[str] = Field(None, description="Converter instance ID (if using ID)") + converter_id: str = Field(..., description="Converter instance ID") converter_type: str = Field(..., description="Converter type") input_value: str = Field(..., description="Input to this converter") input_data_type: PromptDataType = Field(..., description="Input data type") @@ -126,8 +85,7 @@ class ConverterPreviewRequest(BaseModel): original_value: str = Field(..., description="Text to convert") original_value_data_type: PromptDataType = Field(default="text", description="Data type of original value") - converter_ids: Optional[List[str]] = Field(None, description="Converter instance IDs to apply") - converters: Optional[List[InlineConverterConfig]] = Field(None, description="Inline converter definitions") + converter_ids: List[str] = Field(..., description="Converter instance IDs to apply") class ConverterPreviewResponse(BaseModel): diff --git a/pyrit/backend/models/registry.py b/pyrit/backend/models/registry.py deleted file mode 100644 index 4c2cfffa8b..0000000000 --- a/pyrit/backend/models/registry.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Registry response models. - -Models for targets, scenarios, scorers, converters, and initializers. -""" - -from typing import Any, Dict, List, Optional - -from pydantic import BaseModel, Field - -from pyrit.models import PromptDataType - -# ============================================================================ -# Common -# ============================================================================ - - -class ParameterInfo(BaseModel): - """Information about a constructor parameter.""" - - name: str = Field(..., description="Parameter name") - type_hint: Optional[str] = Field(None, description="Type hint as string") - required: bool = Field(..., description="Whether parameter is required") - default: Optional[Any] = Field(None, description="Default value if not required") - - -# ============================================================================ -# Targets -# ============================================================================ - - -class TargetMetadataResponse(BaseModel): - """Metadata for a target type.""" - - name: str = Field(..., description="Registry name") - class_name: str = Field(..., description="Python class name") - description: str = Field(..., description="Target description") - is_chat_target: bool = Field(..., description="Whether target supports chat/system prompts") - supports_json_response: bool = Field(..., description="Whether target supports JSON response format") - supported_data_types: List[PromptDataType] = Field(..., description="Supported input data types") - params_schema: Optional[Dict[str, Any]] = Field(None, description="Parameter schema") - - -class TargetListResponse(BaseModel): - """Response containing list of available targets.""" - - targets: List[TargetMetadataResponse] = Field(..., description="Available target types") - - -# ============================================================================ -# Scenarios -# ============================================================================ - - -class ScenarioMetadataResponse(BaseModel): - """Metadata for a scenario type.""" - - name: str = Field(..., description="Registry name") - class_name: str = Field(..., description="Python class name") - description: str = Field(..., description="Scenario description") - default_strategy: str = Field(..., description="Default strategy name") - all_strategies: List[str] = Field(..., description="All available strategies") - aggregate_strategies: List[str] = Field(..., description="Composite/aggregate strategies") - default_datasets: List[str] = Field(..., description="Default dataset names") - max_dataset_size: Optional[int] = Field(None, description="Maximum dataset size limit") - - -class ScenarioListResponse(BaseModel): - """Response containing list of available scenarios.""" - - scenarios: List[ScenarioMetadataResponse] = Field(..., description="Available scenarios") - - -# ============================================================================ -# Scorers -# ============================================================================ - - -class ScorerMetadataResponse(BaseModel): - """Metadata for a registered scorer instance.""" - - name: str = Field(..., description="Registry name") - class_name: str = Field(..., description="Python class name") - description: str = Field(..., description="Scorer description") - scorer_type: str = Field(..., description="Score type (true_false or float_scale)") - scorer_identifier: Dict[str, Any] = Field(..., description="Scorer identifier (filtered)") - - -class ScorerListResponse(BaseModel): - """Response containing list of registered scorers.""" - - scorers: List[ScorerMetadataResponse] = Field(..., description="Registered scorer instances") - - -# ============================================================================ -# Initializers -# ============================================================================ - - -class InitializerMetadataResponse(BaseModel): - """Metadata for an initializer.""" - - name: str = Field(..., description="Registry name") - class_name: str = Field(..., description="Python class name") - description: str = Field(..., description="Initializer description") - required_env_vars: List[str] = Field(..., description="Required environment variables") - execution_order: int = Field(..., description="Execution order priority") - - -class InitializerListResponse(BaseModel): - """Response containing list of available initializers.""" - - initializers: List[InitializerMetadataResponse] = Field(..., description="Available initializers") - - -# ============================================================================ -# Converters -# ============================================================================ - - -class ConverterMetadataResponse(BaseModel): - """Metadata for a converter type.""" - - name: str = Field(..., description="Registry name (snake_case)") - class_name: str = Field(..., description="Python class name") - description: str = Field(..., description="Converter description") - supported_input_types: List[PromptDataType] = Field(..., description="Supported input data types") - supported_output_types: List[PromptDataType] = Field(..., description="Supported output data types") - is_llm_based: bool = Field(..., description="Whether converter requires LLM calls") - is_deterministic: bool = Field(..., description="Whether same input produces same output") - params_schema: Optional[Dict[str, Any]] = Field(None, description="Parameter schema") - - -class ConverterListResponse(BaseModel): - """Response containing list of available converters.""" - - converters: List[ConverterMetadataResponse] = Field(..., description="Available converter types") diff --git a/pyrit/backend/models/targets.py b/pyrit/backend/models/targets.py index d387680cc4..e1d1bf1dbb 100644 --- a/pyrit/backend/models/targets.py +++ b/pyrit/backend/models/targets.py @@ -11,8 +11,7 @@ This module defines the Instance models for runtime target management. """ -from datetime import datetime -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field @@ -28,8 +27,6 @@ class TargetInstance(BaseModel): type: str = Field(..., description="Target type (e.g., 'azure_openai', 'text_target')") display_name: Optional[str] = Field(None, description="Human-readable display name") params: Dict[str, Any] = Field(default_factory=dict, description="Target configuration (sensitive fields filtered)") - created_at: datetime = Field(..., description="Instance creation timestamp") - source: Literal["initializer", "user"] = Field(..., description="How the target was created") class TargetListResponse(BaseModel): @@ -53,5 +50,3 @@ class CreateTargetResponse(BaseModel): type: str = Field(..., description="Target type") display_name: Optional[str] = Field(None, description="Human-readable display name") params: Dict[str, Any] = Field(default_factory=dict, description="Filtered configuration (no secrets)") - created_at: datetime = Field(..., description="Instance creation timestamp") - source: Literal["user"] = Field(default="user", description="Source is always 'user' for API-created targets") diff --git a/pyrit/backend/routes/__init__.py b/pyrit/backend/routes/__init__.py index 7aa362d23f..6994f07d51 100644 --- a/pyrit/backend/routes/__init__.py +++ b/pyrit/backend/routes/__init__.py @@ -5,12 +5,13 @@ API route handlers. """ -from pyrit.backend.routes import attacks, converters, health, targets, version +from pyrit.backend.routes import attacks, converters, health, labels, targets, version __all__ = [ "attacks", "converters", "health", + "labels", "targets", "version", ] diff --git a/pyrit/backend/routes/attacks.py b/pyrit/backend/routes/attacks.py index 7401923bcf..30b0f060e9 100644 --- a/pyrit/backend/routes/attacks.py +++ b/pyrit/backend/routes/attacks.py @@ -8,15 +8,16 @@ This is the attack-centric API design. """ -from typing import Literal, Optional +from typing import Dict, List, Literal, Optional from fastapi import APIRouter, HTTPException, Query, status from pyrit.backend.models.attacks import ( AddMessageRequest, AddMessageResponse, - AttackDetail, AttackListResponse, + AttackMessagesResponse, + AttackSummary, CreateAttackRequest, CreateAttackResponse, UpdateAttackRequest, @@ -27,13 +28,34 @@ router = APIRouter(prefix="/attacks", tags=["attacks"]) +def _parse_labels(label_params: Optional[List[str]]) -> Optional[Dict[str, str]]: + """ + Parse label query params in 'key:value' format to a dict. + + Returns: + Dict mapping label keys to values, or None if no valid labels. + """ + if not label_params: + return None + labels = {} + for param in label_params: + if ":" in param: + key, value = param.split(":", 1) + labels[key.strip()] = value.strip() + return labels if labels else None + + @router.get( "", response_model=AttackListResponse, ) async def list_attacks( target_id: Optional[str] = Query(None, description="Filter by target instance ID"), - outcome: Optional[Literal["pending", "success", "failure"]] = Query(None, description="Filter by outcome"), + outcome: Optional[Literal["undetermined", "success", "failure"]] = Query(None, description="Filter by outcome"), + name: Optional[str] = Query(None, description="Filter by attack name (substring match)"), + label: Optional[List[str]] = Query(None, description="Filter by labels (format: key:value, repeatable)"), + min_turns: Optional[int] = Query(None, ge=0, description="Filter by minimum executed turns"), + max_turns: Optional[int] = Query(None, ge=0, description="Filter by maximum executed turns"), limit: int = Query(20, ge=1, le=100, description="Maximum items per page"), cursor: Optional[str] = Query(None, description="Pagination cursor (attack_id)"), ) -> AttackListResponse: @@ -47,9 +69,14 @@ async def list_attacks( AttackListResponse: Paginated list of attack summaries. """ service = get_attack_service() + labels = _parse_labels(label) return await service.list_attacks( target_id=target_id, outcome=outcome, + name=name, + labels=labels, + min_turns=min_turns, + max_turns=max_turns, limit=limit, cursor=cursor, ) @@ -88,19 +115,19 @@ async def create_attack(request: CreateAttackRequest) -> CreateAttackResponse: @router.get( "/{attack_id}", - response_model=AttackDetail, + response_model=AttackSummary, responses={ 404: {"model": ProblemDetail, "description": "Attack not found"}, }, ) -async def get_attack(attack_id: str) -> AttackDetail: +async def get_attack(attack_id: str) -> AttackSummary: """ - Get attack details including all messages. + Get attack details. - Returns the full attack with prepended_conversation and all messages. + Returns the attack metadata. Use GET /attacks/{id}/messages for messages. Returns: - AttackDetail: Full attack details with messages. + AttackSummary: Attack details. """ service = get_attack_service() @@ -116,7 +143,7 @@ async def get_attack(attack_id: str) -> AttackDetail: @router.patch( "/{attack_id}", - response_model=AttackDetail, + response_model=AttackSummary, responses={ 404: {"model": ProblemDetail, "description": "Attack not found"}, }, @@ -124,14 +151,14 @@ async def get_attack(attack_id: str) -> AttackDetail: async def update_attack( attack_id: str, request: UpdateAttackRequest, -) -> AttackDetail: +) -> AttackSummary: """ Update an attack's outcome. - Used to mark attacks as success/failure/pending. + Used to mark attacks as success/failure/undetermined. Returns: - AttackDetail: Updated attack details. + AttackSummary: Updated attack details. """ service = get_attack_service() @@ -145,6 +172,34 @@ async def update_attack( return attack +@router.get( + "/{attack_id}/messages", + response_model=AttackMessagesResponse, + responses={ + 404: {"model": ProblemDetail, "description": "Attack not found"}, + }, +) +async def get_attack_messages(attack_id: str) -> AttackMessagesResponse: + """ + Get all messages for an attack. + + Returns prepended conversation and all messages in order. + + Returns: + AttackMessagesResponse: All messages for the attack. + """ + service = get_attack_service() + + messages = await service.get_attack_messages(attack_id) + if not messages: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Attack '{attack_id}' not found", + ) + + return messages + + @router.post( "/{attack_id}/messages", response_model=AddMessageResponse, diff --git a/pyrit/backend/routes/initializers.py b/pyrit/backend/routes/initializers.py deleted file mode 100644 index 24f0b1d642..0000000000 --- a/pyrit/backend/routes/initializers.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Initializers API routes. - -Provides endpoints for listing available initializers. -""" - -from typing import List - -from fastapi import APIRouter - -from pyrit.backend.models.registry import InitializerMetadataResponse -from pyrit.backend.services import get_registry_service - -router = APIRouter(prefix="/initializers", tags=["initializers"]) - - -@router.get( - "", - response_model=List[InitializerMetadataResponse], -) -async def list_initializers() -> List[InitializerMetadataResponse]: - """ - List available initializers. - - Returns metadata about all registered initializers. - - Returns: - List[InitializerMetadataResponse]: List of initializer metadata. - """ - service = get_registry_service() - return service.get_initializers() diff --git a/pyrit/backend/routes/labels.py b/pyrit/backend/routes/labels.py new file mode 100644 index 0000000000..ce057fbf04 --- /dev/null +++ b/pyrit/backend/routes/labels.py @@ -0,0 +1,88 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Labels API routes. + +Provides access to unique label values for filtering in the GUI. +""" + +from typing import TYPE_CHECKING, Dict, List, Literal + +from fastapi import APIRouter, Query +from pydantic import BaseModel, Field + +from pyrit.memory import CentralMemory + +if TYPE_CHECKING: + from pyrit.memory import MemoryInterface + +router = APIRouter(prefix="/labels", tags=["labels"]) + + +class LabelOptionsResponse(BaseModel): + """Response containing unique label keys and their values.""" + + source: str = Field(..., description="Source type (e.g., 'attacks')") + labels: Dict[str, List[str]] = Field(..., description="Map of label keys to their unique values") + + +@router.get( + "", + response_model=LabelOptionsResponse, +) +async def get_label_options( + source: Literal["attacks"] = Query( + "attacks", + description="Source type to get labels from. Currently only 'attacks' is supported.", + ), +) -> LabelOptionsResponse: + """ + Get unique label keys and values for filtering. + + Returns all unique label key-value combinations from the specified source. + Useful for populating filter dropdowns in the GUI. + + Args: + source: The source type to query labels from. + + Returns: + LabelOptionsResponse: Map of label keys to their unique values. + """ + memory = CentralMemory.get_memory_instance() + + if source == "attacks": + labels = _get_attack_labels(memory) + else: + # Future: add support for other sources + labels = {} + + return LabelOptionsResponse(source=source, labels=labels) + + +def _get_attack_labels(memory: "MemoryInterface") -> Dict[str, List[str]]: + """ + Extract unique labels from all attack results. + + Returns: + Dict mapping label keys to sorted lists of unique values. + """ + attack_results = memory.get_attack_results() + + # Collect all unique key-value pairs + label_values: Dict[str, set[str]] = {} + + for ar in attack_results: + if ar.metadata: + for key, value in ar.metadata.items(): + # Skip internal metadata keys + if key.startswith("_") or key in ("created_at", "updated_at"): + continue + # Only include string values + if isinstance(value, str): + if key not in label_values: + label_values[key] = set() + label_values[key].add(value) + + # Convert sets to sorted lists + return {key: sorted(values) for key, values in sorted(label_values.items())} diff --git a/pyrit/backend/routes/scenarios.py b/pyrit/backend/routes/scenarios.py deleted file mode 100644 index 0f45b8c53d..0000000000 --- a/pyrit/backend/routes/scenarios.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Scenarios API routes. - -Provides endpoints for listing available scenarios. -""" - -from typing import List - -from fastapi import APIRouter - -from pyrit.backend.models.registry import ScenarioMetadataResponse -from pyrit.backend.services import get_registry_service - -router = APIRouter(prefix="/scenarios", tags=["scenarios"]) - - -@router.get( - "", - response_model=List[ScenarioMetadataResponse], -) -async def list_scenarios() -> List[ScenarioMetadataResponse]: - """ - List available scenarios. - - Returns metadata about all registered scenarios. - - Returns: - List[ScenarioMetadataResponse]: List of scenario metadata. - """ - service = get_registry_service() - return service.get_scenarios() diff --git a/pyrit/backend/routes/scorers.py b/pyrit/backend/routes/scorers.py deleted file mode 100644 index b516746555..0000000000 --- a/pyrit/backend/routes/scorers.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Scorers API routes. - -Provides endpoints for listing available scorers. -""" - -from typing import List, Optional - -from fastapi import APIRouter, Query - -from pyrit.backend.models.registry import ScorerMetadataResponse -from pyrit.backend.services import get_registry_service - -router = APIRouter(prefix="/scorers", tags=["scorers"]) - - -@router.get( - "", - response_model=List[ScorerMetadataResponse], -) -async def list_scorers( - scorer_type: Optional[str] = Query( - None, description="Filter by scorer type (true_false or float_scale)" - ), -) -> List[ScorerMetadataResponse]: - """ - List available scorers. - - Returns metadata about all registered scorer types. - - Returns: - List[ScorerMetadataResponse]: List of scorer metadata. - """ - service = get_registry_service() - return service.get_scorers(scorer_type=scorer_type) diff --git a/pyrit/backend/services/__init__.py b/pyrit/backend/services/__init__.py index c91c71cffa..fe7ac6c907 100644 --- a/pyrit/backend/services/__init__.py +++ b/pyrit/backend/services/__init__.py @@ -15,10 +15,6 @@ ConverterService, get_converter_service, ) -from pyrit.backend.services.registry_service import ( - RegistryService, - get_registry_service, -) from pyrit.backend.services.target_service import ( TargetService, get_target_service, @@ -29,8 +25,6 @@ "get_attack_service", "ConverterService", "get_converter_service", - "RegistryService", - "get_registry_service", "TargetService", "get_target_service", ] diff --git a/pyrit/backend/services/attack_service.py b/pyrit/backend/services/attack_service.py index edceec3a8e..fe29b2be1e 100644 --- a/pyrit/backend/services/attack_service.py +++ b/pyrit/backend/services/attack_service.py @@ -11,19 +11,19 @@ - Each attack is represented by an AttackResult stored in the database - The AttackResult has a conversation_id that links to the main conversation - Messages are stored via PyRIT memory with that conversation_id -- For GUI use, it's a 1-to-1 mapping: one AttackResult, one conversation -- Future AI-generated attacks may have multiple related conversations +- For human-led attacks, it's a 1-to-1 mapping: one AttackResult, one conversation +- AI-generated attacks may have multiple related conversations """ import uuid from datetime import datetime, timezone -from typing import List, Literal, Optional, cast +from typing import Any, Dict, List, Literal, Optional, cast from pyrit.backend.models.attacks import ( AddMessageRequest, AddMessageResponse, - AttackDetail, AttackListResponse, + AttackMessagesResponse, AttackSummary, CreateAttackRequest, CreateAttackResponse, @@ -36,10 +36,9 @@ from pyrit.backend.services.converter_service import get_converter_service from pyrit.backend.services.target_service import get_target_service from pyrit.memory import CentralMemory -from pyrit.models import AttackOutcome, AttackResult +from pyrit.models import AttackOutcome, AttackResult, PromptDataType from pyrit.models import Message as PyritMessage from pyrit.models import MessagePiece as PyritMessagePiece -from pyrit.models import PromptDataType from pyrit.prompt_normalizer import PromptConverterConfiguration, PromptNormalizer @@ -62,7 +61,11 @@ async def list_attacks( self, *, target_id: Optional[str] = None, - outcome: Optional[Literal["pending", "success", "failure"]] = None, + outcome: Optional[Literal["undetermined", "success", "failure"]] = None, + name: Optional[str] = None, + labels: Optional[Dict[str, str]] = None, + min_turns: Optional[int] = None, + max_turns: Optional[int] = None, limit: int = 20, cursor: Optional[str] = None, ) -> AttackListResponse: @@ -70,27 +73,48 @@ async def list_attacks( List attacks with optional filtering and pagination. Queries AttackResult entries from the database. + + Args: + target_id: Filter by target instance ID (from attack_identifier). + outcome: Filter by attack outcome. + name: Filter by attack name (substring match on attack_identifier.name). + labels: Filter by labels (all must match). + min_turns: Filter by minimum executed turns. + max_turns: Filter by maximum executed turns. + limit: Maximum items to return. + cursor: Pagination cursor. + + Returns: + AttackListResponse with filtered and paginated attack summaries. """ # Map outcome string to AttackOutcome enum value for filtering - outcome_filter = None - if outcome: - outcome_map = { - "pending": "undetermined", - "success": "success", - "failure": "failure", - } - outcome_filter = outcome_map.get(outcome) - - attack_results = self._memory.get_attack_results(outcome=outcome_filter) - - # Convert to summaries and filter + outcome_filter = outcome # Already matches enum values + + # Use labels filter at the database level if supported + attack_results = self._memory.get_attack_results( + outcome=outcome_filter, + labels=labels, + ) + + # Convert to summaries and apply filters summaries = [] for ar in attack_results: - # Get target info from attack_identifier + # Filter by target_id ar_target_id = ar.attack_identifier.get("target_id", "") if target_id and ar_target_id != target_id: continue + # Filter by name (substring match) + ar_name = ar.attack_identifier.get("name", "") + if name and name.lower() not in ar_name.lower(): + continue + + # Filter by executed_turns + if min_turns is not None and ar.executed_turns < min_turns: + continue + if max_turns is not None and ar.executed_turns > max_turns: + continue + summary = self._build_summary(ar) summaries.append(summary) @@ -103,16 +127,17 @@ async def list_attacks( return AttackListResponse( items=page, - pagination=PaginationInfo( - limit=limit, has_more=has_more, next_cursor=next_cursor, prev_cursor=cursor - ), + pagination=PaginationInfo(limit=limit, has_more=has_more, next_cursor=next_cursor, prev_cursor=cursor), ) - async def get_attack(self, attack_id: str) -> Optional[AttackDetail]: + async def get_attack(self, attack_id: str) -> Optional[AttackSummary]: """ - Get attack details including all messages. + Get attack details (high-level metadata, no messages). - Queries the AttackResult and its conversation from the database. + Queries the AttackResult from the database. + + Returns: + AttackSummary if found, None otherwise. """ # Get the attack result results = self._memory.get_attack_results(conversation_id=attack_id) @@ -121,20 +146,47 @@ async def get_attack(self, attack_id: str) -> Optional[AttackDetail]: ar = results[0] - # Get messages for this conversation + # Get message count pyrit_messages = self._memory.get_conversation(conversation_id=attack_id) - backend_messages = self._convert_pyrit_messages(pyrit_messages) + message_count = len(list(pyrit_messages)) - return AttackDetail( + created_str = ar.metadata.get("created_at") + updated_str = ar.metadata.get("updated_at") + created_at = datetime.fromisoformat(created_str) if created_str else datetime.now(timezone.utc) + updated_at = datetime.fromisoformat(updated_str) if updated_str else created_at + + return AttackSummary( attack_id=attack_id, name=ar.attack_identifier.get("name"), target_id=ar.attack_identifier.get("target_id", ""), target_type=ar.attack_identifier.get("target_type", ""), outcome=self._map_outcome(ar.outcome), - prepended_conversation=[], + last_message_preview=self._get_last_message_preview(attack_id), + message_count=message_count, + labels=ar.metadata.get("labels", {}), + created_at=created_at, + updated_at=updated_at, + ) + + async def get_attack_messages(self, attack_id: str) -> Optional[AttackMessagesResponse]: + """ + Get all messages for an attack. + + Returns: + AttackMessagesResponse if attack found, None otherwise. + """ + # Check attack exists + results = self._memory.get_attack_results(conversation_id=attack_id) + if not results: + return None + + # Get messages for this conversation + pyrit_messages = self._memory.get_conversation(conversation_id=attack_id) + backend_messages = self._translate_pyrit_messages_to_backend(list(pyrit_messages)) + + return AttackMessagesResponse( + attack_id=attack_id, messages=backend_messages, - created_at=ar.metadata.get("created_at", datetime.now(timezone.utc)), - updated_at=ar.metadata.get("updated_at", datetime.now(timezone.utc)), ) async def create_attack(self, request: CreateAttackRequest) -> CreateAttackResponse: @@ -142,6 +194,9 @@ async def create_attack(self, request: CreateAttackRequest) -> CreateAttackRespo Create a new attack. Creates an AttackResult with a new conversation_id. + + Returns: + CreateAttackResponse with the new attack's ID and creation time. """ target_service = get_target_service() target_instance = await target_service.get_target(request.target_id) @@ -182,13 +237,14 @@ async def create_attack(self, request: CreateAttackRequest) -> CreateAttackRespo return CreateAttackResponse(attack_id=conversation_id, created_at=now) - async def update_attack( - self, attack_id: str, request: UpdateAttackRequest - ) -> Optional[AttackDetail]: + async def update_attack(self, attack_id: str, request: UpdateAttackRequest) -> Optional[AttackSummary]: """ Update an attack's outcome. Updates the AttackResult in the database. + + Returns: + Updated AttackSummary if found, None otherwise. """ results = self._memory.get_attack_results(conversation_id=attack_id) if not results: @@ -196,7 +252,7 @@ async def update_attack( # Map outcome outcome_map = { - "pending": AttackOutcome.UNDETERMINED, + "undetermined": AttackOutcome.UNDETERMINED, "success": AttackOutcome.SUCCESS, "failure": AttackOutcome.FAILURE, } @@ -213,13 +269,14 @@ async def update_attack( return await self.get_attack(attack_id) - async def add_message( - self, attack_id: str, request: AddMessageRequest - ) -> AddMessageResponse: + async def add_message(self, attack_id: str, request: AddMessageRequest) -> AddMessageResponse: """ Add a message to an attack, optionally sending to target. Messages are stored in the database via PromptNormalizer. + + Returns: + AddMessageResponse containing the updated attack detail. """ # Check if attack exists results = self._memory.get_attack_results(conversation_id=attack_id) @@ -247,14 +304,23 @@ async def add_message( if attack_detail is None: raise ValueError(f"Attack '{attack_id}' not found after update") - return AddMessageResponse(attack=attack_detail) + attack_messages = await self.get_attack_messages(attack_id) + if attack_messages is None: + raise ValueError(f"Attack '{attack_id}' messages not found after update") + + return AddMessageResponse(attack=attack_detail, messages=attack_messages) # ======================================================================== # Private Helper Methods - Summary Building # ======================================================================== def _build_summary(self, ar: AttackResult) -> AttackSummary: - """Build an AttackSummary from an AttackResult.""" + """ + Build an AttackSummary from an AttackResult. + + Returns: + AttackSummary with message count and preview. + """ # Get message count and last preview pieces = self._memory.get_message_pieces(conversation_id=ar.conversation_id) message_count = len(set(p.sequence for p in pieces)) @@ -281,14 +347,33 @@ def _build_summary(self, ar: AttackResult) -> AttackSummary: updated_at=updated_at, ) - def _map_outcome(self, outcome: AttackOutcome) -> Optional[Literal["pending", "success", "failure"]]: - """Map AttackOutcome enum to API outcome string.""" + def _map_outcome(self, outcome: AttackOutcome) -> Optional[Literal["undetermined", "success", "failure"]]: + """ + Map AttackOutcome enum to API outcome string. + + Returns: + Outcome string ('success', 'failure', 'undetermined') or None. + """ if outcome == AttackOutcome.SUCCESS: return "success" elif outcome == AttackOutcome.FAILURE: return "failure" else: - return "pending" + return "undetermined" + + def _get_last_message_preview(self, conversation_id: str) -> Optional[str]: + """ + Get a preview of the last message in a conversation. + + Returns: + Truncated last message text, or None if no messages. + """ + pieces = self._memory.get_message_pieces(conversation_id=conversation_id) + if not pieces: + return None + last_piece = max(pieces, key=lambda p: p.sequence) + text = last_piece.converted_value or "" + return text[:100] + "..." if len(text) > 100 else text # ======================================================================== # Private Helper Methods - Pagination @@ -297,7 +382,12 @@ def _map_outcome(self, outcome: AttackOutcome) -> Optional[Literal["pending", "s def _paginate( self, items: List[AttackSummary], cursor: Optional[str], limit: int ) -> tuple[List[AttackSummary], bool]: - """Apply cursor-based pagination.""" + """ + Apply cursor-based pagination. + + Returns: + Tuple of (paginated items, has_more flag). + """ start_idx = 0 if cursor: for i, item in enumerate(items): @@ -313,8 +403,13 @@ def _paginate( # Private Helper Methods - Message Conversion # ======================================================================== - def _convert_pyrit_messages(self, pyrit_messages: list) -> List[Message]: - """Convert PyRIT messages to backend Message format.""" + def _translate_pyrit_messages_to_backend(self, pyrit_messages: List[Any]) -> List[Message]: + """ + Translate PyRIT messages to backend Message format. + + Returns: + List of Message models for the API. + """ messages = [] for msg in pyrit_messages: pieces = [ @@ -323,25 +418,34 @@ def _convert_pyrit_messages(self, pyrit_messages: list) -> List[Message]: data_type=p.converted_value_data_type or "text", original_value=p.original_value, converted_value=p.converted_value or "", - scores=self._convert_scores(p.scores) if hasattr(p, 'scores') and p.scores else [], + scores=self._translate_pyrit_scores_to_backend(p.scores) + if hasattr(p, "scores") and p.scores + else [], response_error=p.response_error or "none", ) for p in msg.message_pieces ] first = msg.message_pieces[0] if msg.message_pieces else None - messages.append(Message( - message_id=str(first.id) if first else str(uuid.uuid4()), - turn_number=first.sequence if first else 0, - role=first.role if first else "user", - pieces=pieces, - created_at=first.timestamp if first else datetime.now(timezone.utc), - )) + messages.append( + Message( + message_id=str(first.id) if first else str(uuid.uuid4()), + turn_number=first.sequence if first else 0, + role=first.role if first else "user", + pieces=pieces, + created_at=first.timestamp if first else datetime.now(timezone.utc), + ) + ) return messages - def _convert_scores(self, scores: list) -> List[Score]: - """Convert PyRIT scores to backend Score format.""" + def _translate_pyrit_scores_to_backend(self, scores: List[Any]) -> List[Score]: + """ + Translate PyRIT scores to backend Score format. + + Returns: + List of Score models for the API. + """ return [ Score( score_id=str(s.id), @@ -360,20 +464,23 @@ def _convert_scores(self, scores: list) -> List[Score]: async def _store_prepended_messages( self, conversation_id: str, - prepended: list, + prepended: List[Any], ) -> None: """Store prepended conversation messages in memory.""" - for seq, msg in enumerate(prepended): - piece = PyritMessagePiece( - role=msg.role, - original_value=msg.content, - original_value_data_type="text", - converted_value=msg.content, - converted_value_data_type="text", - conversation_id=conversation_id, - sequence=seq, - ) - self._memory.add_message_pieces_to_memory(message_pieces=[piece]) + seq = 0 + for msg in prepended: + for p in msg.pieces: + piece = PyritMessagePiece( + role=msg.role, + original_value=p.original_value, + original_value_data_type=cast(PromptDataType, p.data_type), + converted_value=p.converted_value or p.original_value, + converted_value_data_type=cast(PromptDataType, p.data_type), + conversation_id=conversation_id, + sequence=seq, + ) + self._memory.add_message_pieces_to_memory(message_pieces=[piece]) + seq += 1 async def _send_and_store_message( self, @@ -424,7 +531,12 @@ def _build_pyrit_message( conversation_id: str, sequence: int, ) -> PyritMessage: - """Build PyRIT Message from request.""" + """ + Build PyRIT Message from request. + + Returns: + PyritMessage ready to send to the target. + """ pieces = [ PyritMessagePiece( role=request.role, @@ -439,10 +551,13 @@ def _build_pyrit_message( ] return PyritMessage(pieces) - def _get_converter_configs( - self, request: AddMessageRequest - ) -> List[PromptConverterConfiguration]: - """Get converter configurations if needed.""" + def _get_converter_configs(self, request: AddMessageRequest) -> List[PromptConverterConfiguration]: + """ + Get converter configurations if needed. + + Returns: + List of PromptConverterConfiguration for the converters. + """ has_preconverted = any(p.converted_value is not None for p in request.pieces) if has_preconverted or not request.converter_ids: return [] diff --git a/pyrit/backend/services/converter_service.py b/pyrit/backend/services/converter_service.py index 9702e7ed13..edc471d205 100644 --- a/pyrit/backend/services/converter_service.py +++ b/pyrit/backend/services/converter_service.py @@ -4,12 +4,15 @@ """ Converter service for managing converter instances. -Handles creation, retrieval, and nested converter support. +Handles creation, retrieval, and preview of converters. +Uses ConverterRegistry as the source of truth. + +If a converter requires another converter (e.g., SelectiveTextConverter), +the inner converter must be created first and passed by ID in params. """ import importlib import uuid -from datetime import datetime, timezone from typing import Any, Dict, List, Literal, Optional, Tuple, cast from pyrit.backend.models.converters import ( @@ -19,19 +22,38 @@ ConverterPreviewResponse, CreateConverterRequest, CreateConverterResponse, - InlineConverterConfig, PreviewStep, ) from pyrit.models import PromptDataType +from pyrit.registry.instance_registries import ConverterRegistry class ConverterService: - """Service for managing converter instances.""" + """ + Service for managing converter instances. + + Uses ConverterRegistry as the sole source of truth. + API metadata is derived from the converter objects. + """ def __init__(self) -> None: """Initialize the converter service.""" - self._instances: Dict[str, ConverterInstance] = {} - self._converter_objects: Dict[str, Any] = {} + self._registry = ConverterRegistry.get_registry_singleton() + + def _build_instance_from_object(self, converter_id: str, converter_obj: Any) -> ConverterInstance: + """ + Build a ConverterInstance from a registry object. + + Returns: + ConverterInstance with metadata derived from the object. + """ + converter_type = converter_obj.__class__.__name__ + return ConverterInstance( + converter_id=converter_id, + type=converter_type, + display_name=None, + params={}, # Params aren't stored on converter objects + ) # ======================================================================== # Public API Methods @@ -40,54 +62,73 @@ def __init__(self) -> None: async def list_converters( self, source: Optional[Literal["initializer", "user"]] = None ) -> ConverterInstanceListResponse: - """List all converter instances.""" - items = list(self._instances.values()) - if source is not None: - items = [c for c in items if c.source == source] + """ + List all converter instances. + + Returns: + ConverterInstanceListResponse containing all registered converters. + """ + # source filter is ignored for now - all come from registry + items: List[ConverterInstance] = [] + for name in self._registry.get_names(): + obj = self._registry.get_instance_by_name(name) + if obj: + items.append(self._build_instance_from_object(name, obj)) return ConverterInstanceListResponse(items=items) async def get_converter(self, converter_id: str) -> Optional[ConverterInstance]: - """Get a converter instance by ID.""" - return self._instances.get(converter_id) + """ + Get a converter instance by ID. + + Returns: + ConverterInstance if found, None otherwise. + """ + obj = self._registry.get_instance_by_name(converter_id) + if obj is None: + return None + return self._build_instance_from_object(converter_id, obj) def get_converter_object(self, converter_id: str) -> Optional[Any]: - """Get the actual converter object.""" - return self._converter_objects.get(converter_id) + """ + Get the actual converter object. + + Returns: + The PromptConverter object if found, None otherwise. + """ + return self._registry.get_instance_by_name(converter_id) + + async def create_converter(self, request: CreateConverterRequest) -> CreateConverterResponse: + """ + Create a new converter instance. - async def create_converter( - self, request: CreateConverterRequest - ) -> CreateConverterResponse: - """Create a new converter instance with optional nested converters.""" - config = {"type": request.type, "params": request.params} - converter_id, _, created_instances = self._create_converter_recursive(config, "user") + If params contains a 'converter' key with a converter_id, + the referenced converter object will be resolved and passed. - if request.display_name and converter_id in self._instances: - self._instances[converter_id].display_name = request.display_name + Returns: + CreateConverterResponse with the new converter's details. + """ + converter_id = str(uuid.uuid4()) + + # Resolve any converter references in params and create the object + params = self._resolve_converter_params(request.params) + converter_obj = self._get_converter_class(request.type)(**params) + self._registry.register_instance(converter_obj, name=converter_id) - outer_instance = self._instances[converter_id] return CreateConverterResponse( converter_id=converter_id, type=request.type, display_name=request.display_name, - params=outer_instance.params, - created_converters=created_instances if len(created_instances) > 1 else None, - created_at=outer_instance.created_at, - source="user", + params=request.params, ) - async def delete_converter(self, converter_id: str) -> bool: - """Delete a converter instance.""" - if converter_id in self._instances: - del self._instances[converter_id] - self._converter_objects.pop(converter_id, None) - return True - return False - - async def preview_conversion( - self, request: ConverterPreviewRequest - ) -> ConverterPreviewResponse: - """Preview conversion through a converter pipeline.""" - converters = self._gather_converters_for_preview(request) + async def preview_conversion(self, request: ConverterPreviewRequest) -> ConverterPreviewResponse: + """ + Preview conversion through a converter pipeline. + + Returns: + ConverterPreviewResponse with step-by-step conversion results. + """ + converters = self._gather_converters(request.converter_ids) steps, final_value, final_type = await self._apply_converters( converters, request.original_value, request.original_value_data_type ) @@ -101,7 +142,12 @@ async def preview_conversion( ) def get_converter_objects_for_ids(self, converter_ids: List[str]) -> List[Any]: - """Get converter objects for a list of IDs.""" + """ + Get converter objects for a list of IDs. + + Returns: + List of converter objects in the same order as the input IDs. + """ converters = [] for conv_id in converter_ids: conv_obj = self.get_converter_object(conv_id) @@ -110,19 +156,34 @@ def get_converter_objects_for_ids(self, converter_ids: List[str]) -> List[Any]: converters.append(conv_obj) return converters - def instantiate_inline_converters(self, configs: List[InlineConverterConfig]) -> List[Any]: - """Instantiate converters from inline configurations.""" - return [ - self._get_converter_class(config.type)(**config.params) - for config in configs - ] - # ======================================================================== - # Private Helper Methods - Class Resolution + # Private Helper Methods # ======================================================================== + def _resolve_converter_params(self, params: Dict[str, Any]) -> Dict[str, Any]: + """ + Resolve converter references in params. + + Returns: + Params dict with converter_id references replaced by actual objects. + """ + resolved = dict(params) + if "converter" in resolved and isinstance(resolved["converter"], dict): + ref = resolved["converter"] + if "converter_id" in ref: + conv_obj = self.get_converter_object(ref["converter_id"]) + if conv_obj is None: + raise ValueError(f"Referenced converter '{ref['converter_id']}' not found") + resolved["converter"] = conv_obj + return resolved + def _get_converter_class(self, converter_type: str) -> type: - """Get the converter class for a given type.""" + """ + Get the converter class for a given type. + + Returns: + The converter class matching the given type. + """ module = importlib.import_module("pyrit.prompt_converter") cls = getattr(module, converter_type, None) @@ -137,118 +198,43 @@ def _get_converter_class(self, converter_type: str) -> type: raise ValueError(f"Converter type '{converter_type}' not found in pyrit.prompt_converter") def _class_name_patterns(self, type_name: str) -> List[str]: - """Generate class name patterns to try.""" + """ + Generate class name patterns to try. + + Returns: + List of possible class name variations. + """ pascal = "".join(word.capitalize() for word in type_name.split("_")) return [type_name, f"{type_name}Converter", pascal, f"{pascal}Converter"] - # ======================================================================== - # Private Helper Methods - Recursive Creation - # ======================================================================== - - def _create_converter_recursive( - self, - config: Dict[str, Any], - source: Literal["initializer", "user"], - ) -> Tuple[str, Any, List[ConverterInstance]]: - """Recursively create converters, handling nested converter params.""" - converter_type = config["type"] - params = dict(config.get("params", {})) - created_instances: List[ConverterInstance] = [] - - # Handle nested converter - params, created_instances = self._resolve_nested_converter(params, source) - - # Create this converter - converter_obj = self._get_converter_class(converter_type)(**params) - converter_id = self._store_converter(converter_type, converter_obj, config, created_instances, source) - - return converter_id, converter_obj, created_instances - - def _resolve_nested_converter( - self, - params: Dict[str, Any], - source: Literal["initializer", "user"], - ) -> Tuple[Dict[str, Any], List[ConverterInstance]]: - """Resolve nested converter in params if present.""" - created_instances: List[ConverterInstance] = [] - - if "converter" in params and isinstance(params["converter"], dict): - nested_config = params["converter"] - if "type" in nested_config: - _, nested_obj, nested_instances = self._create_converter_recursive(nested_config, source) - created_instances.extend(nested_instances) - params["converter"] = nested_obj - - return params, created_instances - - def _store_converter( - self, - converter_type: str, - converter_obj: Any, - config: Dict[str, Any], - created_instances: List[ConverterInstance], - source: Literal["initializer", "user"], - ) -> str: - """Store converter and return its ID.""" - converter_id = str(uuid.uuid4()) - self._converter_objects[converter_id] = converter_obj - - resolved_params = self._build_resolved_params(config, created_instances) - instance = ConverterInstance( - converter_id=converter_id, - type=converter_type, - display_name=None, - params=resolved_params, - created_at=datetime.now(timezone.utc), - source=source, - ) - self._instances[converter_id] = instance - created_instances.append(instance) - - return converter_id - - def _build_resolved_params( - self, config: Dict[str, Any], created_instances: List[ConverterInstance] - ) -> Dict[str, Any]: - """Build resolved params with nested converter IDs.""" - resolved_params = dict(config.get("params", {})) - if "converter" in resolved_params and isinstance(resolved_params["converter"], dict): - nested_id = created_instances[-1].converter_id if created_instances else None - resolved_params["converter"] = {"converter_id": nested_id} - return resolved_params - - # ======================================================================== - # Private Helper Methods - Preview - # ======================================================================== - - def _gather_converters_for_preview( - self, request: ConverterPreviewRequest - ) -> List[Tuple[Optional[str], str, Any]]: - """Gather converters to apply from request.""" - converters: List[Tuple[Optional[str], str, Any]] = [] - - if request.converter_ids: - for conv_id in request.converter_ids: - conv_obj = self.get_converter_object(conv_id) - if conv_obj is None: - raise ValueError(f"Converter instance '{conv_id}' not found") - instance = self._instances[conv_id] - converters.append((conv_id, instance.type, conv_obj)) - - if request.converters: - for inline_config in request.converters: - conv_obj = self._get_converter_class(inline_config.type)(**inline_config.params) - converters.append((None, inline_config.type, conv_obj)) + def _gather_converters(self, converter_ids: List[str]) -> List[Tuple[str, str, Any]]: + """ + Gather converters to apply from IDs. + Returns: + List of tuples (converter_id, converter_type, converter_obj). + """ + converters: List[Tuple[str, str, Any]] = [] + for conv_id in converter_ids: + conv_obj = self.get_converter_object(conv_id) + if conv_obj is None: + raise ValueError(f"Converter instance '{conv_id}' not found") + conv_type = conv_obj.__class__.__name__ + converters.append((conv_id, conv_type, conv_obj)) return converters async def _apply_converters( self, - converters: List[Tuple[Optional[str], str, Any]], + converters: List[Tuple[str, str, Any]], initial_value: str, initial_type: PromptDataType, ) -> Tuple[List[PreviewStep], str, PromptDataType]: - """Apply converters and collect steps.""" + """ + Apply converters and collect steps. + + Returns: + Tuple of (steps, final_value, final_type). + """ current_value = initial_value current_type = initial_type steps: List[PreviewStep] = [] diff --git a/pyrit/backend/services/registry_service.py b/pyrit/backend/services/registry_service.py deleted file mode 100644 index 2a4109bda3..0000000000 --- a/pyrit/backend/services/registry_service.py +++ /dev/null @@ -1,307 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Registry service for API access to registered components. - -Wraps component registries with filtering and metadata extraction. -Uses class introspection when registries are not available. -""" - -import inspect -import logging -from typing import Any, Dict, List, Optional, Type - -from pyrit.backend.models.common import filter_sensitive_fields -from pyrit.backend.models.registry import ( - ConverterMetadataResponse, - InitializerMetadataResponse, - ScenarioMetadataResponse, - ScorerMetadataResponse, - TargetMetadataResponse, -) -from pyrit.prompt_converter import PromptConverter -from pyrit.prompt_target import PromptChatTarget, PromptTarget -from pyrit.registry import InitializerRegistry, ScenarioRegistry, ScorerRegistry - -logger = logging.getLogger(__name__) - - -def _extract_params_schema(cls: Type[Any]) -> Dict[str, Any]: - """ - Extract parameter schema from a class constructor. - - Returns: - Dict[str, Any]: Dict with 'required' and 'optional' fields. - """ - required = [] - optional = [] - - try: - sig = inspect.signature(cls.__init__) - - for name, param in sig.parameters.items(): - if name in ("self", "cls", "args", "kwargs"): - continue - - if param.default == inspect.Parameter.empty: - required.append(name) - else: - optional.append(name) - except (ValueError, TypeError): - pass - - return {"required": required, "optional": optional} - - -def _get_all_subclasses(cls: type) -> List[type]: - """ - Recursively get all non-abstract subclasses of a class. - - Returns: - List[type]: List of non-abstract subclass types. - """ - subclasses = [] - for subclass in cls.__subclasses__(): - # Skip abstract classes - if hasattr(subclass, "__abstractmethods__") and subclass.__abstractmethods__: - subclasses.extend(_get_all_subclasses(subclass)) - else: - subclasses.append(subclass) - subclasses.extend(_get_all_subclasses(subclass)) - return subclasses - - -class RegistryService: - """Service for querying component registries.""" - - def get_targets( - self, - *, - is_chat_target: Optional[bool] = None, - ) -> List[TargetMetadataResponse]: - """ - Get available target types via introspection. - - Note: TargetRegistry may not exist yet (PR #1320). - Falls back to class introspection. - - Args: - is_chat_target: Filter by chat target support. - - Returns: - List of target metadata. - """ - # Get all concrete target subclasses via introspection - target_classes = _get_all_subclasses(PromptTarget) - - results = [] - for target_class in target_classes: - # Determine if chat target - is_chat = issubclass(target_class, PromptChatTarget) - - if is_chat_target is not None and is_chat != is_chat_target: - continue - - # Check JSON response support - supports_json = False - if is_chat: - try: - supports_json = hasattr(target_class, "is_json_response_supported") - except Exception: - pass - - # Get supported data types from class attribute if available - supported_types = getattr(target_class, "SUPPORTED_DATA_TYPES", ["text"]) - - results.append( - TargetMetadataResponse( - name=target_class.__name__, - class_name=target_class.__name__, - description=(target_class.__doc__ or "").split("\n")[0].strip(), - is_chat_target=is_chat, - supports_json_response=supports_json, - supported_data_types=list(supported_types), - params_schema=_extract_params_schema(target_class), - ) - ) - - return results - - def get_scenarios(self) -> List[ScenarioMetadataResponse]: - """ - Get all available scenarios from the registry. - - Returns: - List of scenario metadata. - """ - try: - registry = ScenarioRegistry.get_registry_singleton() - metadata_list = registry.list_metadata() - - results = [] - for m in metadata_list: - results.append( - ScenarioMetadataResponse( - name=m.name, - class_name=m.class_name, - description=m.class_description or "", - default_strategy=m.default_strategy, - all_strategies=list(m.all_strategies), - aggregate_strategies=list(m.aggregate_strategies), - default_datasets=list(m.default_datasets), - max_dataset_size=m.max_dataset_size, - ) - ) - return results - except Exception as e: - logger.warning(f"Failed to get scenarios from registry: {e}") - return [] - - def get_scorers( - self, - *, - scorer_type: Optional[str] = None, - ) -> List[ScorerMetadataResponse]: - """ - Get registered scorer instances. - - Args: - scorer_type: Filter by scorer type ('true_false' or 'float_scale'). - - Returns: - List of scorer metadata. - """ - try: - registry = ScorerRegistry.get_registry_singleton() - - # Build filter if scorer_type specified - include_filters: dict[str, object] | None = None - if scorer_type: - include_filters = {"scorer_type": scorer_type} - - metadata_list = registry.list_metadata(include_filters=include_filters) - - results = [] - for m in metadata_list: - # Get scorer identifier and filter sensitive fields - scorer_id = m.scorer_identifier.to_compact_dict() if m.scorer_identifier else {} - filtered_id = filter_sensitive_fields(scorer_id) - - results.append( - ScorerMetadataResponse( - name=m.name, - class_name=m.class_name, - description=m.class_description or "", - scorer_type=m.scorer_type, - scorer_identifier=filtered_id, - ) - ) - return results - except Exception as e: - logger.warning(f"Failed to get scorers from registry: {e}") - return [] - - def get_initializers(self) -> List[InitializerMetadataResponse]: - """ - Get all available initializers from the registry. - - Returns: - List of initializer metadata. - """ - try: - registry = InitializerRegistry.get_registry_singleton() - metadata_list = registry.list_metadata() - - results = [] - for m in metadata_list: - results.append( - InitializerMetadataResponse( - name=m.name, - class_name=m.class_name, - description=m.class_description or "", - required_env_vars=list(m.required_env_vars) if m.required_env_vars else [], - execution_order=getattr(m, "execution_order", 0), - ) - ) - return results - except Exception as e: - logger.warning(f"Failed to get initializers from registry: {e}") - return [] - - def get_converters( - self, - *, - is_llm_based: Optional[bool] = None, - is_deterministic: Optional[bool] = None, - ) -> List[ConverterMetadataResponse]: - """ - Get available converters via introspection. - - Note: ConverterRegistry may not exist yet. - Falls back to class introspection. - - Args: - is_llm_based: Filter by LLM-based converters. - is_deterministic: Filter by deterministic converters. - - Returns: - List of converter metadata. - """ - # Get all converter subclasses using the shared helper - converter_classes = _get_all_subclasses(PromptConverter) - - results = [] - for converter_class in converter_classes: - # Get supported types from class attributes - input_types = getattr(converter_class, "SUPPORTED_INPUT_TYPES", ["text"]) - output_types = getattr(converter_class, "SUPPORTED_OUTPUT_TYPES", ["text"]) - - # Determine if LLM-based (has converter_target parameter) - converter_is_llm_based = False - try: - sig = inspect.signature(converter_class) - converter_is_llm_based = "converter_target" in sig.parameters - except Exception: - pass - - if is_llm_based is not None and converter_is_llm_based != is_llm_based: - continue - - # Assume deterministic if not LLM-based - converter_is_deterministic = not converter_is_llm_based - - if is_deterministic is not None and converter_is_deterministic != is_deterministic: - continue - - results.append( - ConverterMetadataResponse( - name=converter_class.__name__, - class_name=converter_class.__name__, - description=(converter_class.__doc__ or "").split("\n")[0].strip(), - supported_input_types=list(input_types), - supported_output_types=list(output_types), - is_llm_based=converter_is_llm_based, - is_deterministic=converter_is_deterministic, - params_schema=_extract_params_schema(converter_class), - ) - ) - - return results - - -# Singleton instance -_registry_service: Optional[RegistryService] = None - - -def get_registry_service() -> RegistryService: - """ - Get the registry service singleton. - - Returns: - RegistryService: The registry service instance. - """ - global _registry_service - if _registry_service is None: - _registry_service = RegistryService() - return _registry_service diff --git a/pyrit/backend/services/target_service.py b/pyrit/backend/services/target_service.py index 7b847e676f..a55911dfd2 100644 --- a/pyrit/backend/services/target_service.py +++ b/pyrit/backend/services/target_service.py @@ -5,12 +5,12 @@ Target service for managing target instances. Handles creation, retrieval, and lifecycle of runtime target instances. +Uses TargetRegistry as the source of truth. """ import importlib import uuid -from datetime import datetime, timezone -from typing import Any, Dict, Literal, Optional, cast +from typing import Any, List, Literal, Optional, cast from pyrit.backend.models.common import filter_sensitive_fields from pyrit.backend.models.targets import ( @@ -19,42 +19,38 @@ TargetInstance, TargetListResponse, ) +from pyrit.registry.instance_registries import TargetRegistry class TargetService: - """Service for managing target instances.""" + """ + Service for managing target instances. + + Uses TargetRegistry as the sole source of truth. + API metadata is derived from the target objects' identifiers. + """ def __init__(self) -> None: """Initialize the target service.""" - # In-memory storage for target instances - self._instances: Dict[str, TargetInstance] = {} - # Actual instantiated target objects (not serializable) - self._target_objects: Dict[str, Any] = {} + self._registry = TargetRegistry.get_registry_singleton() def _get_target_class(self, target_type: str) -> type: """ Get the target class for a given type. - Args: - target_type: Target type string (e.g., 'azure_openai', 'TextTarget') - Returns: - The target class + The target class matching the given type. """ - # Try to import from pyrit.prompt_target module = importlib.import_module("pyrit.prompt_target") - # Handle both snake_case and PascalCase - # First try direct attribute lookup cls = getattr(module, target_type, None) if cls is not None: return cast(type, cls) - # Try common class name patterns class_name_patterns = [ target_type, f"{target_type}Target", - "".join(word.capitalize() for word in target_type.split("_")), # snake_case to PascalCase + "".join(word.capitalize() for word in target_type.split("_")), "".join(word.capitalize() for word in target_type.split("_")) + "Target", ] @@ -65,6 +61,24 @@ def _get_target_class(self, target_type: str) -> type: raise ValueError(f"Target type '{target_type}' not found in pyrit.prompt_target") + def _build_instance_from_object(self, target_id: str, target_obj: Any) -> TargetInstance: + """ + Build a TargetInstance from a registry object. + + Returns: + TargetInstance with metadata derived from the object. + """ + identifier = target_obj.get_identifier() if hasattr(target_obj, "get_identifier") else {} + target_type = identifier.get("__type__", target_obj.__class__.__name__) + filtered_params = filter_sensitive_fields(identifier) + + return TargetInstance( + target_id=target_id, + type=target_type, + display_name=None, # Could be added to identifier if needed + params=filtered_params, + ) + async def list_targets( self, source: Optional[Literal["initializer", "user"]] = None, @@ -72,42 +86,37 @@ async def list_targets( """ List all target instances. - Args: - source: Optional filter by source ("initializer" or "user") - Returns: - TargetListResponse: List of target instances + TargetListResponse containing all registered targets. """ - items = list(self._instances.values()) - - if source is not None: - items = [t for t in items if t.source == source] - + # source filter is ignored for now - all come from registry + items: List[TargetInstance] = [] + for name in self._registry.get_names(): + obj = self._registry.get_instance_by_name(name) + if obj: + items.append(self._build_instance_from_object(name, obj)) return TargetListResponse(items=items) async def get_target(self, target_id: str) -> Optional[TargetInstance]: """ Get a target instance by ID. - Args: - target_id: Target instance ID - Returns: - TargetInstance or None if not found + TargetInstance if found, None otherwise. """ - return self._instances.get(target_id) + obj = self._registry.get_instance_by_name(target_id) + if obj is None: + return None + return self._build_instance_from_object(target_id, obj) def get_target_object(self, target_id: str) -> Optional[Any]: """ Get the actual target object for use in attacks. - Args: - target_id: Target instance ID - Returns: - The instantiated target object or None if not found + The PromptTarget object if found, None otherwise. """ - return self._target_objects.get(target_id) + return self._registry.get_instance_by_name(target_id) async def create_target( self, @@ -116,98 +125,26 @@ async def create_target( """ Create a new target instance. - Args: - request: Target creation request - Returns: - CreateTargetResponse: Created target details + CreateTargetResponse with the new target's details. """ target_id = str(uuid.uuid4()) - now = datetime.now(timezone.utc) - # Get the target class and instantiate + # Create and register the target object target_class = self._get_target_class(request.type) target_obj = target_class(**request.params) - self._target_objects[target_id] = target_obj - - # Get filtered params from target identifier - target_identifier = target_obj.get_identifier() if hasattr(target_obj, "get_identifier") else {} - filtered_params = filter_sensitive_fields(target_identifier) + self._registry.register_instance(target_obj, name=target_id) - # Store the target instance metadata - instance = TargetInstance( - target_id=target_id, - type=request.type, - display_name=request.display_name, - params=filtered_params, - created_at=now, - source="user", - ) - self._instances[target_id] = instance + # Build response from the object's identifier + identifier = target_obj.get_identifier() if hasattr(target_obj, "get_identifier") else {} + filtered_params = filter_sensitive_fields(identifier) return CreateTargetResponse( target_id=target_id, type=request.type, display_name=request.display_name, params=filtered_params, - created_at=now, - source="user", - ) - - async def delete_target(self, target_id: str) -> bool: - """ - Delete a target instance. - - Args: - target_id: Target instance ID - - Returns: - True if deleted, False if not found - """ - if target_id in self._instances: - del self._instances[target_id] - self._target_objects.pop(target_id, None) - return True - return False - - async def register_initializer_target( - self, - target_type: str, - target_obj: Any, - display_name: Optional[str] = None, - ) -> TargetInstance: - """ - Register a target from an initializer (not user-created). - - Args: - target_type: Target type string - target_obj: Already-instantiated target object - display_name: Optional display name - - Returns: - TargetInstance: The registered target - """ - target_id = str(uuid.uuid4()) - now = datetime.now(timezone.utc) - - # Store the target object - self._target_objects[target_id] = target_obj - - # Get filtered params from target identifier - target_identifier = target_obj.get_identifier() if hasattr(target_obj, "get_identifier") else {} - filtered_params = filter_sensitive_fields(target_identifier) - - instance = TargetInstance( - target_id=target_id, - type=target_type, - display_name=display_name, - params=filtered_params, - created_at=now, - source="initializer", ) - self._instances[target_id] = instance - - return instance # Global service instance diff --git a/pyrit/cli/pyrit_backend.py b/pyrit/cli/pyrit_backend.py new file mode 100644 index 0000000000..657569ae9d --- /dev/null +++ b/pyrit/cli/pyrit_backend.py @@ -0,0 +1,217 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +PyRIT Backend CLI - Command-line interface for running the PyRIT backend server. + +This module provides the main entry point for the pyrit_backend command. +""" + +import asyncio +import sys +from argparse import ArgumentParser, Namespace, RawDescriptionHelpFormatter +from typing import Optional + +from pyrit.cli import frontend_core + + +def parse_args(args: Optional[list[str]] = None) -> Namespace: + """ + Parse command-line arguments for the PyRIT backend server. + + Returns: + Namespace: Parsed command-line arguments. + """ + parser = ArgumentParser( + prog="pyrit_backend", + description="""PyRIT Backend - Run the PyRIT backend API server + +Examples: + # Start backend with default settings + pyrit_backend + + # Start with built-in initializers + pyrit_backend --initializers airt + + # Start with custom initialization scripts + pyrit_backend --initialization-scripts ./my_targets.py + + # Start with custom port and host + pyrit_backend --host 0.0.0.0 --port 8080 + + # List available initializers + pyrit_backend --list-initializers +""", + formatter_class=RawDescriptionHelpFormatter, + ) + + parser.add_argument( + "--host", + type=str, + default="0.0.0.0", + help="Host to bind the server to (default: 0.0.0.0)", + ) + + parser.add_argument( + "--port", + type=int, + default=8000, + help="Port to bind the server to (default: 8000)", + ) + + parser.add_argument( + "--log-level", + type=frontend_core.validate_log_level_argparse, + default="INFO", + help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) (default: INFO)", + ) + + parser.add_argument( + "--list-initializers", + action="store_true", + help="List all available initializers and exit", + ) + + parser.add_argument( + "--database", + type=frontend_core.validate_database_argparse, + default=frontend_core.SQLITE, + help=( + f"Database type to use for memory storage ({frontend_core.IN_MEMORY}, " + f"{frontend_core.SQLITE}, {frontend_core.AZURE_SQL}) (default: {frontend_core.SQLITE})" + ), + ) + + parser.add_argument( + "--initializers", + type=str, + nargs="+", + help=frontend_core.ARG_HELP["initializers"], + ) + + parser.add_argument( + "--initialization-scripts", + type=str, + nargs="+", + help=frontend_core.ARG_HELP["initialization_scripts"], + ) + + parser.add_argument( + "--env-files", + type=str, + nargs="+", + help=frontend_core.ARG_HELP["env_files"], + ) + + parser.add_argument( + "--reload", + action="store_true", + help="Enable auto-reload for development (watches for file changes)", + ) + + return parser.parse_args(args) + + +async def initialize_and_run(parsed_args: Namespace) -> int: + """ + Initialize PyRIT and start the backend server. + + Returns: + int: Exit code (0 for success, 1 for error). + """ + from pyrit.setup import initialize_pyrit_async + + # Resolve initialization scripts if provided + initialization_scripts = None + if parsed_args.initialization_scripts: + try: + initialization_scripts = frontend_core.resolve_initialization_scripts( + script_paths=parsed_args.initialization_scripts + ) + except FileNotFoundError as e: + print(f"Error: {e}") + return 1 + + # Resolve env files if provided + env_files = None + if parsed_args.env_files: + try: + env_files = frontend_core.resolve_env_files(env_file_paths=parsed_args.env_files) + except ValueError as e: + print(f"Error: {e}") + return 1 + + # Resolve initializer instances if names provided + initializer_instances = None + if parsed_args.initializers: + from pyrit.registry import InitializerRegistry + + registry = InitializerRegistry() + initializer_instances = [] + for name in parsed_args.initializers: + try: + initializer_class = registry.get_class(name) + initializer_instances.append(initializer_class()) + except Exception as e: + print(f"Error: Could not load initializer '{name}': {e}") + return 1 + + # Initialize PyRIT with the provided configuration + print("🔧 Initializing PyRIT...") + await initialize_pyrit_async( + memory_db_type=parsed_args.database, + initialization_scripts=initialization_scripts, + initializers=initializer_instances, + env_files=env_files, + ) + + # Start uvicorn server + import uvicorn + + print(f"🚀 Starting PyRIT backend on http://{parsed_args.host}:{parsed_args.port}") + print(f" API Docs: http://{parsed_args.host}:{parsed_args.port}/docs") + + config = uvicorn.Config( + "pyrit.backend.main:app", + host=parsed_args.host, + port=parsed_args.port, + log_level=parsed_args.log_level.lower(), + reload=parsed_args.reload, + ) + server = uvicorn.Server(config) + await server.serve() + + return 0 + + +def main(args: Optional[list[str]] = None) -> int: + """ + Start the PyRIT backend server CLI. + + Returns: + int: Exit code (0 for success, 1 for error). + """ + try: + parsed_args = parse_args(args) + except SystemExit as e: + return e.code if isinstance(e.code, int) else 1 + + # Handle list-initializers command + if parsed_args.list_initializers: + context = frontend_core.FrontendCore(log_level=parsed_args.log_level) + scenarios_path = frontend_core.get_default_initializer_discovery_path() + return asyncio.run(frontend_core.print_initializers_list_async(context=context, discovery_path=scenarios_path)) + + # Run the server + try: + return asyncio.run(initialize_and_run(parsed_args)) + except KeyboardInterrupt: + print("\n🛑 Backend stopped") + return 0 + except Exception as e: + print(f"\nError: {e}") + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/pyrit/registry/instance_registries/__init__.py b/pyrit/registry/instance_registries/__init__.py index eab870f0e1..761735e261 100644 --- a/pyrit/registry/instance_registries/__init__.py +++ b/pyrit/registry/instance_registries/__init__.py @@ -14,13 +14,21 @@ from pyrit.registry.instance_registries.base_instance_registry import ( BaseInstanceRegistry, ) +from pyrit.registry.instance_registries.converter_registry import ( + ConverterRegistry, +) from pyrit.registry.instance_registries.scorer_registry import ( ScorerRegistry, ) +from pyrit.registry.instance_registries.target_registry import ( + TargetRegistry, +) __all__ = [ # Base class "BaseInstanceRegistry", # Concrete registries + "ConverterRegistry", "ScorerRegistry", + "TargetRegistry", ] diff --git a/pyrit/registry/instance_registries/converter_registry.py b/pyrit/registry/instance_registries/converter_registry.py new file mode 100644 index 0000000000..3439b88c9e --- /dev/null +++ b/pyrit/registry/instance_registries/converter_registry.py @@ -0,0 +1,108 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Converter registry for managing PyRIT converter instances. + +Converters are registered explicitly via initializers as pre-configured instances. + +NOTE: This is a placeholder implementation. A full implementation will be added soon. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional + +from pyrit.identifiers import Identifier +from pyrit.identifiers.class_name_utils import class_name_to_snake_case +from pyrit.registry.instance_registries.base_instance_registry import ( + BaseInstanceRegistry, +) + +if TYPE_CHECKING: + from pyrit.prompt_converter import PromptConverter + +logger = logging.getLogger(__name__) + + +# Placeholder identifier type until proper ConverterIdentifier is defined +# TODO: Replace with ConverterIdentifier when available +@dataclass(frozen=True) +class ConverterIdentifier(Identifier): + """Temporary identifier type for converters.""" + + pass + + +class ConverterRegistry(BaseInstanceRegistry["PromptConverter", ConverterIdentifier]): + """ + Registry for managing available converter instances. + + This registry stores pre-configured PromptConverter instances (not classes). + Converters are registered explicitly via initializers after being instantiated + with their required parameters. + + NOTE: This is a placeholder. A full implementation will be added soon. + """ + + @classmethod + def get_registry_singleton(cls) -> "ConverterRegistry": + """ + Get the singleton instance of the ConverterRegistry. + + Returns: + The singleton ConverterRegistry instance. + """ + return super().get_registry_singleton() # type: ignore[return-value] + + def register_instance( + self, + converter: "PromptConverter", + *, + name: Optional[str] = None, + ) -> None: + """ + Register a converter instance. + + Args: + converter: The pre-configured converter instance (not a class). + name: Optional custom registry name. If not provided, + derived from class name (e.g., Base64Converter -> base64). + """ + if name is None: + name = class_name_to_snake_case(converter.__class__.__name__, suffix="Converter") + + self.register(converter, name=name) + logger.debug(f"Registered converter instance: {name} ({converter.__class__.__name__})") + + def get_instance_by_name(self, name: str) -> Optional["PromptConverter"]: + """ + Get a registered converter instance by name. + + Args: + name: The registry name of the converter. + + Returns: + The converter instance, or None if not found. + """ + return self.get(name) + + def _build_metadata(self, name: str, instance: "PromptConverter") -> ConverterIdentifier: + """ + Build metadata for a converter instance. + + Args: + name: The registry name of the converter. + instance: The converter instance. + + Returns: + ConverterIdentifier with basic info about the converter. + """ + return ConverterIdentifier( + class_name=instance.__class__.__name__, + class_module=instance.__class__.__module__, + class_description=f"Converter: {name}", + identifier_type="instance", + ) diff --git a/pyrit/registry/instance_registries/target_registry.py b/pyrit/registry/instance_registries/target_registry.py new file mode 100644 index 0000000000..c430750f41 --- /dev/null +++ b/pyrit/registry/instance_registries/target_registry.py @@ -0,0 +1,108 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Target registry for managing PyRIT target instances. + +Targets are registered explicitly via initializers as pre-configured instances. + +NOTE: This is a placeholder implementation. PR #1320 will add the full implementation. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional + +from pyrit.identifiers import Identifier +from pyrit.identifiers.class_name_utils import class_name_to_snake_case +from pyrit.registry.instance_registries.base_instance_registry import ( + BaseInstanceRegistry, +) + +if TYPE_CHECKING: + from pyrit.prompt_target import PromptTarget + +logger = logging.getLogger(__name__) + + +# Placeholder identifier type until proper TargetIdentifier is defined +# TODO: Replace with TargetIdentifier when available +@dataclass(frozen=True) +class TargetIdentifier(Identifier): + """Temporary identifier type for targets.""" + + pass + + +class TargetRegistry(BaseInstanceRegistry["PromptTarget", TargetIdentifier]): + """ + Registry for managing available target instances. + + This registry stores pre-configured PromptTarget instances (not classes). + Targets are registered explicitly via initializers after being instantiated + with their required parameters. + + NOTE: This is a placeholder. PR #1320 will add the full implementation. + """ + + @classmethod + def get_registry_singleton(cls) -> "TargetRegistry": + """ + Get the singleton instance of the TargetRegistry. + + Returns: + The singleton TargetRegistry instance. + """ + return super().get_registry_singleton() # type: ignore[return-value] + + def register_instance( + self, + target: "PromptTarget", + *, + name: Optional[str] = None, + ) -> None: + """ + Register a target instance. + + Args: + target: The pre-configured target instance (not a class). + name: Optional custom registry name. If not provided, + derived from class name (e.g., AzureOpenAIGPT4OChatTarget -> azure_openai_gpt4o_chat). + """ + if name is None: + name = class_name_to_snake_case(target.__class__.__name__, suffix="Target") + + self.register(target, name=name) + logger.debug(f"Registered target instance: {name} ({target.__class__.__name__})") + + def get_instance_by_name(self, name: str) -> Optional["PromptTarget"]: + """ + Get a registered target instance by name. + + Args: + name: The registry name of the target. + + Returns: + The target instance, or None if not found. + """ + return self.get(name) + + def _build_metadata(self, name: str, instance: "PromptTarget") -> TargetIdentifier: + """ + Build metadata for a target instance. + + Args: + name: The registry name of the target. + instance: The target instance. + + Returns: + TargetIdentifier with basic info about the target. + """ + return TargetIdentifier( + class_name=instance.__class__.__name__, + class_module=instance.__class__.__module__, + class_description=f"Target: {name}", + identifier_type="instance", + ) diff --git a/tests/unit/backend/test_api_routes.py b/tests/unit/backend/test_api_routes.py index 99ba9baf73..c3acdb2e3f 100644 --- a/tests/unit/backend/test_api_routes.py +++ b/tests/unit/backend/test_api_routes.py @@ -5,8 +5,10 @@ Tests for backend API routes. """ +import json +import os +import tempfile from datetime import datetime, timezone -from typing import List from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -16,8 +18,8 @@ from pyrit.backend.main import app from pyrit.backend.models.attacks import ( AddMessageResponse, - AttackDetail, AttackListResponse, + AttackMessagesResponse, AttackSummary, CreateAttackResponse, Message, @@ -59,7 +61,7 @@ def test_list_attacks_returns_empty_list(self, client: TestClient) -> None: mock_service.list_attacks = AsyncMock( return_value=AttackListResponse( items=[], - pagination=PaginationInfo(limit=20, has_more=False), + pagination=PaginationInfo(limit=20, has_more=False, next_cursor=None, prev_cursor=None), ) ) mock_get_service.return_value = mock_service @@ -77,7 +79,7 @@ def test_list_attacks_with_filters(self, client: TestClient) -> None: mock_service.list_attacks = AsyncMock( return_value=AttackListResponse( items=[], - pagination=PaginationInfo(limit=10, has_more=False), + pagination=PaginationInfo(limit=10, has_more=False, next_cursor=None, prev_cursor=None), ) ) mock_get_service.return_value = mock_service @@ -91,6 +93,10 @@ def test_list_attacks_with_filters(self, client: TestClient) -> None: mock_service.list_attacks.assert_called_once_with( target_id="t1", outcome="success", + name=None, + labels=None, + min_turns=None, + max_turns=None, limit=10, cursor=None, ) @@ -122,9 +128,7 @@ def test_create_attack_target_not_found(self, client: TestClient) -> None: """Test attack creation with non-existent target.""" with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: mock_service = MagicMock() - mock_service.create_attack = AsyncMock( - side_effect=ValueError("Target not found") - ) + mock_service.create_attack = AsyncMock(side_effect=ValueError("Target not found")) mock_get_service.return_value = mock_service response = client.post( @@ -141,14 +145,14 @@ def test_get_attack_success(self, client: TestClient) -> None: with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: mock_service = MagicMock() mock_service.get_attack = AsyncMock( - return_value=AttackDetail( + return_value=AttackSummary( attack_id="attack-1", name="Test", target_id="target-1", target_type="TextTarget", outcome=None, - prepended_conversation=[], - messages=[], + last_message_preview=None, + message_count=0, created_at=now, updated_at=now, ) @@ -179,13 +183,14 @@ def test_update_attack_success(self, client: TestClient) -> None: with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: mock_service = MagicMock() mock_service.update_attack = AsyncMock( - return_value=AttackDetail( + return_value=AttackSummary( attack_id="attack-1", + name=None, target_id="target-1", target_type="TextTarget", outcome="success", - prepended_conversation=[], - messages=[], + last_message_preview=None, + message_count=0, created_at=now, updated_at=now, ) @@ -205,10 +210,20 @@ def test_add_message_success(self, client: TestClient) -> None: """Test adding a message to an attack.""" now = datetime.now(timezone.utc) - attack_detail = AttackDetail( + attack_summary = AttackSummary( attack_id="attack-1", + name=None, target_id="target-1", target_type="TextTarget", + outcome=None, + last_message_preview=None, + message_count=2, + created_at=now, + updated_at=now, + ) + + attack_messages = AttackMessagesResponse( + attack_id="attack-1", messages=[ Message( message_id="msg-1", @@ -235,15 +250,14 @@ def test_add_message_success(self, client: TestClient) -> None: created_at=now, ), ], - created_at=now, - updated_at=now, ) with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: mock_service = MagicMock() mock_service.add_message = AsyncMock( return_value=AddMessageResponse( - attack=attack_detail, + attack=attack_summary, + messages=attack_messages, ) ) mock_get_service.return_value = mock_service @@ -255,7 +269,7 @@ def test_add_message_success(self, client: TestClient) -> None: assert response.status_code == status.HTTP_200_OK data = response.json() - assert len(data["attack"]["messages"]) == 2 + assert len(data["messages"]["messages"]) == 2 def test_update_attack_not_found(self, client: TestClient) -> None: """Test updating a non-existent attack returns 404.""" @@ -275,9 +289,7 @@ def test_add_message_attack_not_found(self, client: TestClient) -> None: """Test adding message to non-existent attack returns 404.""" with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: mock_service = MagicMock() - mock_service.add_message = AsyncMock( - side_effect=ValueError("Attack 'nonexistent' not found") - ) + mock_service.add_message = AsyncMock(side_effect=ValueError("Attack 'nonexistent' not found")) mock_get_service.return_value = mock_service response = client.post( @@ -291,9 +303,7 @@ def test_add_message_target_not_found(self, client: TestClient) -> None: """Test adding message when target object not found returns 404.""" with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: mock_service = MagicMock() - mock_service.add_message = AsyncMock( - side_effect=ValueError("Target object for 'target-1' not found") - ) + mock_service.add_message = AsyncMock(side_effect=ValueError("Target object for 'target-1' not found")) mock_get_service.return_value = mock_service response = client.post( @@ -307,9 +317,7 @@ def test_add_message_bad_request(self, client: TestClient) -> None: """Test adding message with invalid request returns 400.""" with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: mock_service = MagicMock() - mock_service.add_message = AsyncMock( - side_effect=ValueError("Invalid message format") - ) + mock_service.add_message = AsyncMock(side_effect=ValueError("Invalid message format")) mock_get_service.return_value = mock_service response = client.post( @@ -323,9 +331,7 @@ def test_add_message_internal_error(self, client: TestClient) -> None: """Test adding message when internal error occurs returns 500.""" with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: mock_service = MagicMock() - mock_service.add_message = AsyncMock( - side_effect=RuntimeError("Unexpected internal error") - ) + mock_service.add_message = AsyncMock(side_effect=RuntimeError("Unexpected internal error")) mock_get_service.return_value = mock_service response = client.post( @@ -335,6 +341,81 @@ def test_add_message_internal_error(self, client: TestClient) -> None: assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + def test_get_attack_messages_success(self, client: TestClient) -> None: + """Test getting attack messages.""" + now = datetime.now(timezone.utc) + + with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: + mock_service = MagicMock() + mock_service.get_attack_messages = AsyncMock( + return_value=AttackMessagesResponse( + attack_id="attack-1", + messages=[ + Message( + message_id="msg-1", + turn_number=1, + role="user", + pieces=[MessagePiece(piece_id="p1", converted_value="Hello")], + created_at=now, + ) + ], + ) + ) + mock_get_service.return_value = mock_service + + response = client.get("/api/attacks/attack-1/messages") + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["attack_id"] == "attack-1" + assert len(data["messages"]) == 1 + + def test_get_attack_messages_not_found(self, client: TestClient) -> None: + """Test getting messages for non-existent attack returns 404.""" + with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: + mock_service = MagicMock() + mock_service.get_attack_messages = AsyncMock(return_value=None) + mock_get_service.return_value = mock_service + + response = client.get("/api/attacks/nonexistent/messages") + + assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_list_attacks_with_labels(self, client: TestClient) -> None: + """Test listing attacks with label filters.""" + now = datetime.now(timezone.utc) + + with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: + mock_service = MagicMock() + mock_service.list_attacks = AsyncMock( + return_value=AttackListResponse( + items=[ + AttackSummary( + attack_id="attack-1", + name=None, + target_id="target-1", + target_type="TextTarget", + outcome=None, + last_message_preview=None, + message_count=0, + labels={"env": "prod"}, + created_at=now, + updated_at=now, + ) + ], + pagination=PaginationInfo(limit=20, has_more=False, next_cursor=None, prev_cursor=None), + ) + ) + mock_get_service.return_value = mock_service + + response = client.get("/api/attacks?label=env:prod&label=team:red") + + assert response.status_code == status.HTTP_200_OK + # Verify labels were parsed and passed to service + mock_service.list_attacks.assert_called_once() + call_kwargs = mock_service.list_attacks.call_args[1] + assert call_kwargs["labels"] == {"env": "prod", "team": "red"} + # ============================================================================ # Target Routes Tests @@ -348,9 +429,7 @@ def test_list_targets_returns_empty_list(self, client: TestClient) -> None: """Test that list targets returns empty list initially.""" with patch("pyrit.backend.routes.targets.get_target_service") as mock_get_service: mock_service = MagicMock() - mock_service.list_targets = AsyncMock( - return_value=TargetListResponse(items=[]) - ) + mock_service.list_targets = AsyncMock(return_value=TargetListResponse(items=[])) mock_get_service.return_value = mock_service response = client.get("/api/targets") @@ -363,9 +442,7 @@ def test_list_targets_with_source_filter(self, client: TestClient) -> None: """Test that list targets accepts source filter.""" with patch("pyrit.backend.routes.targets.get_target_service") as mock_get_service: mock_service = MagicMock() - mock_service.list_targets = AsyncMock( - return_value=TargetListResponse(items=[]) - ) + mock_service.list_targets = AsyncMock(return_value=TargetListResponse(items=[])) mock_get_service.return_value = mock_service response = client.get("/api/targets", params={"source": "user"}) @@ -385,8 +462,6 @@ def test_create_target_success(self, client: TestClient) -> None: type="TextTarget", display_name="My Target", params={}, - created_at=now, - source="user", ) ) mock_get_service.return_value = mock_service @@ -404,9 +479,7 @@ def test_create_target_invalid_type(self, client: TestClient) -> None: """Test target creation with invalid type.""" with patch("pyrit.backend.routes.targets.get_target_service") as mock_get_service: mock_service = MagicMock() - mock_service.create_target = AsyncMock( - side_effect=ValueError("Target type not found") - ) + mock_service.create_target = AsyncMock(side_effect=ValueError("Target type not found")) mock_get_service.return_value = mock_service response = client.post( @@ -426,9 +499,8 @@ def test_get_target_success(self, client: TestClient) -> None: return_value=TargetInstance( target_id="target-1", type="TextTarget", + display_name=None, params={}, - created_at=now, - source="user", ) ) mock_get_service.return_value = mock_service @@ -454,9 +526,7 @@ def test_create_target_internal_error(self, client: TestClient) -> None: """Test target creation with internal error returns 500.""" with patch("pyrit.backend.routes.targets.get_target_service") as mock_get_service: mock_service = MagicMock() - mock_service.create_target = AsyncMock( - side_effect=RuntimeError("Unexpected internal error") - ) + mock_service.create_target = AsyncMock(side_effect=RuntimeError("Unexpected internal error")) mock_get_service.return_value = mock_service response = client.post( @@ -479,9 +549,7 @@ def test_list_converters(self, client: TestClient) -> None: """Test listing converter instances.""" with patch("pyrit.backend.routes.converters.get_converter_service") as mock_get_service: mock_service = MagicMock() - mock_service.list_converters = AsyncMock( - return_value=ConverterInstanceListResponse(items=[]) - ) + mock_service.list_converters = AsyncMock(return_value=ConverterInstanceListResponse(items=[])) mock_get_service.return_value = mock_service response = client.get("/api/converters") @@ -502,8 +570,6 @@ def test_create_converter_success(self, client: TestClient) -> None: type="Base64Converter", display_name="My Base64", params={}, - created_at=now, - source="user", ) ) mock_get_service.return_value = mock_service @@ -521,9 +587,7 @@ def test_create_converter_invalid_type(self, client: TestClient) -> None: """Test converter creation with invalid type.""" with patch("pyrit.backend.routes.converters.get_converter_service") as mock_get_service: mock_service = MagicMock() - mock_service.create_converter = AsyncMock( - side_effect=ValueError("Converter type not found") - ) + mock_service.create_converter = AsyncMock(side_effect=ValueError("Converter type not found")) mock_get_service.return_value = mock_service response = client.post( @@ -543,9 +607,8 @@ def test_get_converter_success(self, client: TestClient) -> None: return_value=ConverterInstance( converter_id="conv-1", type="Base64Converter", + display_name=None, params={}, - created_at=now, - source="user", ) ) mock_get_service.return_value = mock_service @@ -579,7 +642,7 @@ def test_preview_conversion_success(self, client: TestClient) -> None: converted_value_data_type="text", steps=[ PreviewStep( - converter_id=None, + converter_id="conv-1", converter_type="Base64Converter", input_value="test", input_data_type="text", @@ -596,7 +659,7 @@ def test_preview_conversion_success(self, client: TestClient) -> None: json={ "original_value": "test", "original_value_data_type": "text", - "converters": [{"type": "Base64Converter", "params": {}}], + "converter_ids": ["conv-1"], }, ) @@ -609,9 +672,7 @@ def test_create_converter_internal_error(self, client: TestClient) -> None: """Test converter creation with internal error returns 500.""" with patch("pyrit.backend.routes.converters.get_converter_service") as mock_get_service: mock_service = MagicMock() - mock_service.create_converter = AsyncMock( - side_effect=RuntimeError("Unexpected internal error") - ) + mock_service.create_converter = AsyncMock(side_effect=RuntimeError("Unexpected internal error")) mock_get_service.return_value = mock_service response = client.post( @@ -622,11 +683,11 @@ def test_create_converter_internal_error(self, client: TestClient) -> None: assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR def test_preview_conversion_bad_request(self, client: TestClient) -> None: - """Test preview conversion with invalid parameters returns 400.""" + """Test preview conversion with invalid converter ID returns 400.""" with patch("pyrit.backend.routes.converters.get_converter_service") as mock_get_service: mock_service = MagicMock() mock_service.preview_conversion = AsyncMock( - side_effect=ValueError("Invalid converter parameters") + side_effect=ValueError("Converter instance 'nonexistent' not found") ) mock_get_service.return_value = mock_service @@ -635,7 +696,7 @@ def test_preview_conversion_bad_request(self, client: TestClient) -> None: json={ "original_value": "test", "original_value_data_type": "text", - "converters": [{"type": "InvalidConverter", "params": {}}], + "converter_ids": ["nonexistent"], }, ) @@ -645,9 +706,7 @@ def test_preview_conversion_internal_error(self, client: TestClient) -> None: """Test preview conversion with internal error returns 500.""" with patch("pyrit.backend.routes.converters.get_converter_service") as mock_get_service: mock_service = MagicMock() - mock_service.preview_conversion = AsyncMock( - side_effect=RuntimeError("Converter execution failed") - ) + mock_service.preview_conversion = AsyncMock(side_effect=RuntimeError("Converter execution failed")) mock_get_service.return_value = mock_service response = client.post( @@ -655,7 +714,7 @@ def test_preview_conversion_internal_error(self, client: TestClient) -> None: json={ "original_value": "test", "original_value_data_type": "text", - "converters": [{"type": "Base64Converter", "params": {}}], + "converter_ids": ["conv-1"], }, ) @@ -681,13 +740,9 @@ def test_get_version(self, client: TestClient) -> None: def test_get_version_with_build_info(self, client: TestClient) -> None: """Test getting version with build info from Docker.""" - import tempfile - import json as json_lib - import os - # Create a temp file to simulate Docker build info with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: - json_lib.dump( + json.dump( { "source": "git", "commit": "abc123", @@ -706,7 +761,7 @@ def test_get_version_with_build_info(self, client: TestClient) -> None: # Mock open to return our temp file content with patch("builtins.open", create=True) as mock_open: - mock_open.return_value.__enter__.return_value.read.return_value = json_lib.dumps( + mock_open.return_value.__enter__.return_value.read.return_value = json.dumps( { "source": "git", "commit": "abc123", @@ -720,3 +775,135 @@ def test_get_version_with_build_info(self, client: TestClient) -> None: assert response.status_code == status.HTTP_200_OK finally: os.unlink(temp_path) + + +# ============================================================================ +# Health Routes Tests +# ============================================================================ + + +class TestHealthRoutes: + """Tests for health check API routes.""" + + def test_health_check(self, client: TestClient) -> None: + """Test health check endpoint returns ok.""" + response = client.get("/api/health") + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["status"] == "healthy" + + +# ============================================================================ +# Labels Routes Tests +# ============================================================================ + + +class TestLabelsRoutes: + """Tests for labels API routes.""" + + def test_get_labels_for_attacks(self, client: TestClient) -> None: + """Test getting labels from attack results.""" + mock_attack_result = MagicMock() + mock_attack_result.metadata = {"env": "prod", "team": "red", "created_at": "2024-01-01"} + + with patch("pyrit.backend.routes.labels.CentralMemory") as mock_memory_class: + mock_memory = MagicMock() + mock_memory.get_attack_results.return_value = [mock_attack_result] + mock_memory_class.get_memory_instance.return_value = mock_memory + + response = client.get("/api/labels?source=attacks") + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["source"] == "attacks" + # env and team should be included, created_at should be excluded + assert "env" in data["labels"] + assert "team" in data["labels"] + assert "created_at" not in data["labels"] + + def test_get_labels_empty(self, client: TestClient) -> None: + """Test getting labels when no attack results exist.""" + with patch("pyrit.backend.routes.labels.CentralMemory") as mock_memory_class: + mock_memory = MagicMock() + mock_memory.get_attack_results.return_value = [] + mock_memory_class.get_memory_instance.return_value = mock_memory + + response = client.get("/api/labels?source=attacks") + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["source"] == "attacks" + assert data["labels"] == {} + + def test_get_labels_multiple_values(self, client: TestClient) -> None: + """Test getting labels with multiple values per key.""" + mock_ar1 = MagicMock() + mock_ar1.metadata = {"env": "prod"} + mock_ar2 = MagicMock() + mock_ar2.metadata = {"env": "staging"} + mock_ar3 = MagicMock() + mock_ar3.metadata = {"env": "prod", "team": "blue"} + + with patch("pyrit.backend.routes.labels.CentralMemory") as mock_memory_class: + mock_memory = MagicMock() + mock_memory.get_attack_results.return_value = [mock_ar1, mock_ar2, mock_ar3] + mock_memory_class.get_memory_instance.return_value = mock_memory + + response = client.get("/api/labels") + + assert response.status_code == status.HTTP_200_OK + data = response.json() + # Should have both env values sorted + assert set(data["labels"]["env"]) == {"prod", "staging"} + assert data["labels"]["team"] == ["blue"] + + def test_get_labels_skips_internal_metadata(self, client: TestClient) -> None: + """Test that internal metadata keys are skipped.""" + mock_ar = MagicMock() + mock_ar.metadata = { + "_internal": "value", + "created_at": "2024-01-01", + "updated_at": "2024-01-02", + "visible_label": "keep", + } + + with patch("pyrit.backend.routes.labels.CentralMemory") as mock_memory_class: + mock_memory = MagicMock() + mock_memory.get_attack_results.return_value = [mock_ar] + mock_memory_class.get_memory_instance.return_value = mock_memory + + response = client.get("/api/labels") + + assert response.status_code == status.HTTP_200_OK + data = response.json() + # Only visible_label should be included + assert "visible_label" in data["labels"] + assert "_internal" not in data["labels"] + assert "created_at" not in data["labels"] + assert "updated_at" not in data["labels"] + + def test_get_labels_skips_non_string_values(self, client: TestClient) -> None: + """Test that non-string metadata values are skipped.""" + mock_ar = MagicMock() + mock_ar.metadata = { + "string_val": "keep", + "int_val": 123, + "list_val": ["a", "b"], + "dict_val": {"nested": "value"}, + } + + with patch("pyrit.backend.routes.labels.CentralMemory") as mock_memory_class: + mock_memory = MagicMock() + mock_memory.get_attack_results.return_value = [mock_ar] + mock_memory_class.get_memory_instance.return_value = mock_memory + + response = client.get("/api/labels") + + assert response.status_code == status.HTTP_200_OK + data = response.json() + # Only string_val should be included + assert "string_val" in data["labels"] + assert "int_val" not in data["labels"] + assert "list_val" not in data["labels"] + assert "dict_val" not in data["labels"] diff --git a/tests/unit/backend/test_attack_service.py b/tests/unit/backend/test_attack_service.py index e92d5816da..7217740e7f 100644 --- a/tests/unit/backend/test_attack_service.py +++ b/tests/unit/backend/test_attack_service.py @@ -138,9 +138,7 @@ class TestListAttacks: """Tests for list_attacks method.""" @pytest.mark.asyncio - async def test_list_attacks_returns_empty_when_no_attacks( - self, attack_service, mock_memory - ) -> None: + async def test_list_attacks_returns_empty_when_no_attacks(self, attack_service, mock_memory) -> None: """Test that list_attacks returns empty list when no AttackResults exist.""" mock_memory.get_attack_results.return_value = [] @@ -150,9 +148,7 @@ async def test_list_attacks_returns_empty_when_no_attacks( assert result.pagination.has_more is False @pytest.mark.asyncio - async def test_list_attacks_returns_attacks( - self, attack_service, mock_memory - ) -> None: + async def test_list_attacks_returns_attacks(self, attack_service, mock_memory) -> None: """Test that list_attacks returns attacks from AttackResult records.""" ar = make_attack_result() mock_memory.get_attack_results.return_value = [ar] @@ -165,9 +161,7 @@ async def test_list_attacks_returns_attacks( assert result.items[0].target_id == "target-1" @pytest.mark.asyncio - async def test_list_attacks_filters_by_target_id( - self, attack_service, mock_memory - ) -> None: + async def test_list_attacks_filters_by_target_id(self, attack_service, mock_memory) -> None: """Test that list_attacks filters by target_id.""" ar1 = make_attack_result(conversation_id="attack-1", target_id="target-1") ar2 = make_attack_result(conversation_id="attack-2", target_id="target-2") @@ -179,6 +173,49 @@ async def test_list_attacks_filters_by_target_id( assert len(result.items) == 1 assert result.items[0].target_id == "target-1" + @pytest.mark.asyncio + async def test_list_attacks_filters_by_name(self, attack_service, mock_memory) -> None: + """Test that list_attacks filters by name substring (case-insensitive).""" + ar1 = make_attack_result(conversation_id="attack-1", name="Test Attack") + ar2 = make_attack_result(conversation_id="attack-2", name="Other") + mock_memory.get_attack_results.return_value = [ar1, ar2] + mock_memory.get_message_pieces.return_value = [] + + result = await attack_service.list_attacks(name="test") + + assert len(result.items) == 1 + assert result.items[0].name == "Test Attack" + + @pytest.mark.asyncio + async def test_list_attacks_filters_by_min_turns(self, attack_service, mock_memory) -> None: + """Test that list_attacks filters by minimum executed turns.""" + ar1 = make_attack_result(conversation_id="attack-1") + ar1.executed_turns = 5 + ar2 = make_attack_result(conversation_id="attack-2") + ar2.executed_turns = 2 + mock_memory.get_attack_results.return_value = [ar1, ar2] + mock_memory.get_message_pieces.return_value = [] + + result = await attack_service.list_attacks(min_turns=3) + + assert len(result.items) == 1 + assert result.items[0].attack_id == "attack-1" + + @pytest.mark.asyncio + async def test_list_attacks_filters_by_max_turns(self, attack_service, mock_memory) -> None: + """Test that list_attacks filters by maximum executed turns.""" + ar1 = make_attack_result(conversation_id="attack-1") + ar1.executed_turns = 5 + ar2 = make_attack_result(conversation_id="attack-2") + ar2.executed_turns = 2 + mock_memory.get_attack_results.return_value = [ar1, ar2] + mock_memory.get_message_pieces.return_value = [] + + result = await attack_service.list_attacks(max_turns=3) + + assert len(result.items) == 1 + assert result.items[0].attack_id == "attack-2" + # ============================================================================ # Get Attack Tests @@ -190,9 +227,7 @@ class TestGetAttack: """Tests for get_attack method.""" @pytest.mark.asyncio - async def test_get_attack_returns_none_for_nonexistent( - self, attack_service, mock_memory - ) -> None: + async def test_get_attack_returns_none_for_nonexistent(self, attack_service, mock_memory) -> None: """Test that get_attack returns None when AttackResult doesn't exist.""" mock_memory.get_attack_results.return_value = [] @@ -201,9 +236,7 @@ async def test_get_attack_returns_none_for_nonexistent( assert result is None @pytest.mark.asyncio - async def test_get_attack_returns_attack_details( - self, attack_service, mock_memory - ) -> None: + async def test_get_attack_returns_attack_details(self, attack_service, mock_memory) -> None: """Test that get_attack returns attack details from AttackResult.""" ar = make_attack_result( conversation_id="test-id", @@ -223,6 +256,38 @@ async def test_get_attack_returns_attack_details( assert result.name == "My Attack" +# ============================================================================ +# Get Attack Messages Tests +# ============================================================================ + + +@pytest.mark.usefixtures("patch_central_database") +class TestGetAttackMessages: + """Tests for get_attack_messages method.""" + + @pytest.mark.asyncio + async def test_get_attack_messages_returns_none_for_nonexistent(self, attack_service, mock_memory) -> None: + """Test that get_attack_messages returns None when attack doesn't exist.""" + mock_memory.get_attack_results.return_value = [] + + result = await attack_service.get_attack_messages("nonexistent") + + assert result is None + + @pytest.mark.asyncio + async def test_get_attack_messages_returns_messages(self, attack_service, mock_memory) -> None: + """Test that get_attack_messages returns messages for existing attack.""" + ar = make_attack_result(conversation_id="test-id") + mock_memory.get_attack_results.return_value = [ar] + mock_memory.get_conversation.return_value = [] + + result = await attack_service.get_attack_messages("test-id") + + assert result is not None + assert result.attack_id == "test-id" + assert result.messages == [] + + # ============================================================================ # Create Attack Tests # ============================================================================ @@ -235,39 +300,53 @@ class TestCreateAttack: @pytest.mark.asyncio async def test_create_attack_validates_target_exists(self, attack_service) -> None: """Test that create_attack validates target exists.""" - with patch( - "pyrit.backend.services.attack_service.get_target_service" - ) as mock_get_target_service: + with patch("pyrit.backend.services.attack_service.get_target_service") as mock_get_target_service: mock_target_service = MagicMock() mock_target_service.get_target = AsyncMock(return_value=None) mock_get_target_service.return_value = mock_target_service with pytest.raises(ValueError, match="not found"): - await attack_service.create_attack( - CreateAttackRequest(target_id="nonexistent") - ) + await attack_service.create_attack(CreateAttackRequest(target_id="nonexistent")) @pytest.mark.asyncio - async def test_create_attack_stores_attack_result( - self, attack_service, mock_memory - ) -> None: + async def test_create_attack_stores_attack_result(self, attack_service, mock_memory) -> None: """Test that create_attack stores AttackResult in memory.""" - with patch( - "pyrit.backend.services.attack_service.get_target_service" - ) as mock_get_target_service: + with patch("pyrit.backend.services.attack_service.get_target_service") as mock_get_target_service: mock_target_service = MagicMock() - mock_target_service.get_target = AsyncMock( - return_value=MagicMock(type="TextTarget") - ) + mock_target_service.get_target = AsyncMock(return_value=MagicMock(type="TextTarget")) mock_get_target_service.return_value = mock_target_service + result = await attack_service.create_attack(CreateAttackRequest(target_id="target-1", name="My Attack")) + + assert result.attack_id is not None + assert result.created_at is not None + mock_memory.add_attack_results_to_memory.assert_called_once() + + @pytest.mark.asyncio + async def test_create_attack_stores_prepended_conversation(self, attack_service, mock_memory) -> None: + """Test that create_attack stores prepended conversation messages.""" + from pyrit.backend.models.attacks import PrependedMessageRequest + + with patch("pyrit.backend.services.attack_service.get_target_service") as mock_get_target_service: + mock_target_service = MagicMock() + mock_target_service.get_target = AsyncMock(return_value=MagicMock(type="TextTarget")) + mock_get_target_service.return_value = mock_target_service + + prepended = [ + PrependedMessageRequest( + role="system", + pieces=[MessagePieceRequest(original_value="You are a helpful assistant.")], + ) + ] + result = await attack_service.create_attack( - CreateAttackRequest(target_id="target-1", name="My Attack") + CreateAttackRequest(target_id="target-1", prepended_conversation=prepended) ) assert result.attack_id is not None - assert result.created_at is not None + # Both attack result and prepended message pieces should be stored mock_memory.add_attack_results_to_memory.assert_called_once() + mock_memory.add_message_pieces_to_memory.assert_called() # ============================================================================ @@ -280,30 +359,22 @@ class TestUpdateAttack: """Tests for update_attack method.""" @pytest.mark.asyncio - async def test_update_attack_returns_none_for_nonexistent( - self, attack_service, mock_memory - ) -> None: + async def test_update_attack_returns_none_for_nonexistent(self, attack_service, mock_memory) -> None: """Test that update_attack returns None for nonexistent attack.""" mock_memory.get_attack_results.return_value = [] - result = await attack_service.update_attack( - "nonexistent", UpdateAttackRequest(outcome="success") - ) + result = await attack_service.update_attack("nonexistent", UpdateAttackRequest(outcome="success")) assert result is None @pytest.mark.asyncio - async def test_update_attack_updates_outcome( - self, attack_service, mock_memory - ) -> None: + async def test_update_attack_updates_outcome(self, attack_service, mock_memory) -> None: """Test that update_attack updates the AttackResult outcome.""" ar = make_attack_result(conversation_id="test-id") mock_memory.get_attack_results.return_value = [ar] mock_memory.get_conversation.return_value = [] - await attack_service.update_attack( - "test-id", UpdateAttackRequest(outcome="success") - ) + await attack_service.update_attack("test-id", UpdateAttackRequest(outcome="success")) # Should call add_attack_results_to_memory to update mock_memory.add_attack_results_to_memory.assert_called() @@ -319,9 +390,7 @@ class TestAddMessage: """Tests for add_message method.""" @pytest.mark.asyncio - async def test_add_message_raises_for_nonexistent_attack( - self, attack_service, mock_memory - ) -> None: + async def test_add_message_raises_for_nonexistent_attack(self, attack_service, mock_memory) -> None: """Test that add_message raises ValueError for nonexistent attack.""" mock_memory.get_attack_results.return_value = [] @@ -333,9 +402,7 @@ async def test_add_message_raises_for_nonexistent_attack( await attack_service.add_message("nonexistent", request) @pytest.mark.asyncio - async def test_add_message_without_send_stores_message( - self, attack_service, mock_memory - ) -> None: + async def test_add_message_without_send_stores_message(self, attack_service, mock_memory) -> None: """Test that add_message with send=False stores message in memory.""" ar = make_attack_result(conversation_id="test-id", target_id="target-1") mock_memory.get_attack_results.return_value = [ar] @@ -353,6 +420,187 @@ async def test_add_message_without_send_stores_message( assert result.attack is not None mock_memory.add_message_pieces_to_memory.assert_called() + @pytest.mark.asyncio + async def test_add_message_raises_when_no_target_id(self, attack_service, mock_memory) -> None: + """Test that add_message raises ValueError when attack has no target configured.""" + ar = make_attack_result(conversation_id="test-id", target_id="") + ar.attack_identifier["target_id"] = "" # Explicitly set to empty + mock_memory.get_attack_results.return_value = [ar] + + request = AddMessageRequest( + pieces=[MessagePieceRequest(original_value="Hello")], + ) + + with pytest.raises(ValueError, match="has no target configured"): + await attack_service.add_message("test-id", request) + + @pytest.mark.asyncio + async def test_add_message_with_send_calls_normalizer(self, attack_service, mock_memory) -> None: + """Test that add_message with send=True sends message via normalizer.""" + ar = make_attack_result(conversation_id="test-id", target_id="target-1") + mock_memory.get_attack_results.return_value = [ar] + mock_memory.get_message_pieces.return_value = [] + mock_memory.get_conversation.return_value = [] + + with ( + patch("pyrit.backend.services.attack_service.get_target_service") as mock_get_target_svc, + patch("pyrit.backend.services.attack_service.PromptNormalizer") as mock_normalizer_cls, + ): + mock_target_svc = MagicMock() + mock_target_svc.get_target_object.return_value = MagicMock() + mock_get_target_svc.return_value = mock_target_svc + + mock_normalizer = MagicMock() + mock_normalizer.send_prompt_async = AsyncMock() + mock_normalizer_cls.return_value = mock_normalizer + + request = AddMessageRequest( + pieces=[MessagePieceRequest(original_value="Hello")], + send=True, + ) + + result = await attack_service.add_message("test-id", request) + + mock_normalizer.send_prompt_async.assert_called_once() + assert result.attack is not None + + @pytest.mark.asyncio + async def test_add_message_with_send_raises_when_target_not_found(self, attack_service, mock_memory) -> None: + """Test that add_message with send=True raises when target object not found.""" + ar = make_attack_result(conversation_id="test-id", target_id="target-1") + mock_memory.get_attack_results.return_value = [ar] + mock_memory.get_message_pieces.return_value = [] + + with patch("pyrit.backend.services.attack_service.get_target_service") as mock_get_target_svc: + mock_target_svc = MagicMock() + mock_target_svc.get_target_object.return_value = None + mock_get_target_svc.return_value = mock_target_svc + + request = AddMessageRequest( + pieces=[MessagePieceRequest(original_value="Hello")], + send=True, + ) + + with pytest.raises(ValueError, match="Target object .* not found"): + await attack_service.add_message("test-id", request) + + @pytest.mark.asyncio + async def test_add_message_with_converter_ids_gets_converters(self, attack_service, mock_memory) -> None: + """Test that add_message with converter_ids gets converters from service.""" + ar = make_attack_result(conversation_id="test-id", target_id="target-1") + mock_memory.get_attack_results.return_value = [ar] + mock_memory.get_message_pieces.return_value = [] + mock_memory.get_conversation.return_value = [] + + with ( + patch("pyrit.backend.services.attack_service.get_target_service") as mock_get_target_svc, + patch("pyrit.backend.services.attack_service.get_converter_service") as mock_get_conv_svc, + patch("pyrit.backend.services.attack_service.PromptNormalizer") as mock_normalizer_cls, + patch("pyrit.backend.services.attack_service.PromptConverterConfiguration") as mock_config, + ): + mock_target_svc = MagicMock() + mock_target_svc.get_target_object.return_value = MagicMock() + mock_get_target_svc.return_value = mock_target_svc + + mock_conv_svc = MagicMock() + mock_conv_svc.get_converter_objects_for_ids.return_value = [MagicMock()] + mock_get_conv_svc.return_value = mock_conv_svc + + mock_config.from_converters.return_value = [MagicMock()] + + mock_normalizer = MagicMock() + mock_normalizer.send_prompt_async = AsyncMock() + mock_normalizer_cls.return_value = mock_normalizer + + request = AddMessageRequest( + pieces=[MessagePieceRequest(original_value="Hello")], + send=True, + converter_ids=["conv-1"], + ) + + await attack_service.add_message("test-id", request) + + mock_conv_svc.get_converter_objects_for_ids.assert_called_once_with(["conv-1"]) + + +# ============================================================================ +# Pagination Tests +# ============================================================================ + + +@pytest.mark.usefixtures("patch_central_database") +class TestPagination: + """Tests for pagination in list_attacks.""" + + @pytest.mark.asyncio + async def test_list_attacks_with_cursor_paginates(self, attack_service, mock_memory) -> None: + """Test that list_attacks with cursor starts from the right position.""" + ar1 = make_attack_result(conversation_id="attack-1") + ar2 = make_attack_result(conversation_id="attack-2") + ar3 = make_attack_result(conversation_id="attack-3") + mock_memory.get_attack_results.return_value = [ar1, ar2, ar3] + mock_memory.get_message_pieces.return_value = [] + + # Get first page + result = await attack_service.list_attacks(limit=2) + # Results are sorted by updated_at desc, so order may vary + assert len(result.items) == 2 + + @pytest.mark.asyncio + async def test_list_attacks_has_more_flag(self, attack_service, mock_memory) -> None: + """Test that list_attacks sets has_more flag correctly.""" + ar1 = make_attack_result(conversation_id="attack-1") + ar2 = make_attack_result(conversation_id="attack-2") + ar3 = make_attack_result(conversation_id="attack-3") + mock_memory.get_attack_results.return_value = [ar1, ar2, ar3] + mock_memory.get_message_pieces.return_value = [] + + result = await attack_service.list_attacks(limit=2) + + assert result.pagination.has_more is True + assert len(result.items) == 2 + + +# ============================================================================ +# Message Building Tests +# ============================================================================ + + +@pytest.mark.usefixtures("patch_central_database") +class TestMessageBuilding: + """Tests for message translation and building.""" + + @pytest.mark.asyncio + async def test_get_attack_with_messages_translates_correctly(self, attack_service, mock_memory) -> None: + """Test that get_attack_messages translates PyRIT messages to backend format.""" + ar = make_attack_result(conversation_id="test-id") + mock_memory.get_attack_results.return_value = [ar] + + # Create mock message with pieces + mock_piece = MagicMock() + mock_piece.id = "piece-1" + mock_piece.converted_value_data_type = "text" + mock_piece.original_value = "Hello" + mock_piece.converted_value = "Hello" + mock_piece.response_error = None + mock_piece.sequence = 0 + mock_piece.role = "user" + mock_piece.timestamp = datetime.now(timezone.utc) + mock_piece.scores = None + + mock_msg = MagicMock() + mock_msg.message_pieces = [mock_piece] + + mock_memory.get_conversation.return_value = [mock_msg] + + result = await attack_service.get_attack_messages("test-id") + + assert result is not None + assert len(result.messages) == 1 + assert result.messages[0].role == "user" + assert len(result.messages[0].pieces) == 1 + assert result.messages[0].pieces[0].original_value == "Hello" + # ============================================================================ # Singleton Tests diff --git a/tests/unit/backend/test_converter_service.py b/tests/unit/backend/test_converter_service.py index e8ff4a5aad..c17e15bf26 100644 --- a/tests/unit/backend/test_converter_service.py +++ b/tests/unit/backend/test_converter_service.py @@ -5,32 +5,24 @@ Tests for backend converter service. """ -from datetime import datetime, timezone from unittest.mock import AsyncMock, MagicMock, patch import pytest from pyrit.backend.models.converters import ( - ConverterInstance, ConverterPreviewRequest, CreateConverterRequest, - InlineConverterConfig, ) from pyrit.backend.services.converter_service import ConverterService +from pyrit.registry.instance_registries import ConverterRegistry -class TestConverterServiceInit: - """Tests for ConverterService initialization.""" - - def test_init_creates_empty_instances_dict(self) -> None: - """Test that service initializes with empty instances dictionary.""" - service = ConverterService() - assert service._instances == {} - - def test_init_creates_empty_converter_objects_dict(self) -> None: - """Test that service initializes with empty converter objects dictionary.""" - service = ConverterService() - assert service._converter_objects == {} +@pytest.fixture(autouse=True) +def reset_registry(): + """Reset the ConverterRegistry singleton before each test.""" + ConverterRegistry.reset_instance() + yield + ConverterRegistry.reset_instance() class TestListConverters: @@ -46,77 +38,19 @@ async def test_list_converters_returns_empty_when_no_converters(self) -> None: assert result.items == [] @pytest.mark.asyncio - async def test_list_converters_returns_converters(self) -> None: - """Test that list_converters returns existing converters.""" + async def test_list_converters_returns_converters_from_registry(self) -> None: + """Test that list_converters returns converters from registry.""" service = ConverterService() - now = datetime.now(timezone.utc) - service._instances["conv-1"] = ConverterInstance( - converter_id="conv-1", - type="Base64Converter", - display_name="My Converter", - params={}, - created_at=now, - source="user", - ) + mock_converter = MagicMock() + mock_converter.__class__.__name__ = "MockConverter" + service._registry.register_instance(mock_converter, name="conv-1") result = await service.list_converters() assert len(result.items) == 1 assert result.items[0].converter_id == "conv-1" - assert result.items[0].display_name == "My Converter" - - @pytest.mark.asyncio - async def test_list_converters_filters_by_source_user(self) -> None: - """Test that list_converters filters by source='user'.""" - service = ConverterService() - now = datetime.now(timezone.utc) - - service._instances["conv-1"] = ConverterInstance( - converter_id="conv-1", - type="Base64Converter", - params={}, - created_at=now, - source="user", - ) - service._instances["conv-2"] = ConverterInstance( - converter_id="conv-2", - type="Base64Converter", - params={}, - created_at=now, - source="initializer", - ) - - result = await service.list_converters(source="user") - - assert len(result.items) == 1 - assert result.items[0].source == "user" - - @pytest.mark.asyncio - async def test_list_converters_filters_by_source_initializer(self) -> None: - """Test that list_converters filters by source='initializer'.""" - service = ConverterService() - now = datetime.now(timezone.utc) - - service._instances["conv-1"] = ConverterInstance( - converter_id="conv-1", - type="Base64Converter", - params={}, - created_at=now, - source="user", - ) - service._instances["conv-2"] = ConverterInstance( - converter_id="conv-2", - type="Base64Converter", - params={}, - created_at=now, - source="initializer", - ) - - result = await service.list_converters(source="initializer") - - assert len(result.items) == 1 - assert result.items[0].source == "initializer" + assert result.items[0].type == "MockConverter" class TestGetConverter: @@ -132,25 +66,19 @@ async def test_get_converter_returns_none_for_nonexistent(self) -> None: assert result is None @pytest.mark.asyncio - async def test_get_converter_returns_converter(self) -> None: - """Test that get_converter returns the converter instance.""" + async def test_get_converter_returns_converter_from_registry(self) -> None: + """Test that get_converter returns converter built from registry object.""" service = ConverterService() - now = datetime.now(timezone.utc) - service._instances["conv-1"] = ConverterInstance( - converter_id="conv-1", - type="Base64Converter", - display_name="Test Converter", - params={"key": "value"}, - created_at=now, - source="user", - ) + mock_converter = MagicMock() + mock_converter.__class__.__name__ = "MockConverter" + service._registry.register_instance(mock_converter, name="conv-1") result = await service.get_converter("conv-1") assert result is not None assert result.converter_id == "conv-1" - assert result.display_name == "Test Converter" + assert result.type == "MockConverter" class TestGetConverterObject: @@ -164,11 +92,11 @@ def test_get_converter_object_returns_none_for_nonexistent(self) -> None: assert result is None - def test_get_converter_object_returns_object(self) -> None: + def test_get_converter_object_returns_object_from_registry(self) -> None: """Test that get_converter_object returns the actual converter object.""" service = ConverterService() mock_converter = MagicMock() - service._converter_objects["conv-1"] = mock_converter + service._registry.register_instance(mock_converter, name="conv-1") result = service.get_converter_object("conv-1") @@ -198,7 +126,6 @@ def test_get_converter_class_handles_snake_case(self) -> None: """Test that _get_converter_class handles snake_case names.""" service = ConverterService() - # base64 should resolve to Base64Converter result = service._get_converter_class("base64") assert result is not None @@ -236,11 +163,10 @@ async def test_create_converter_success(self) -> None: assert result.converter_id is not None assert result.type == "Base64Converter" assert result.display_name == "My Base64" - assert result.source == "user" @pytest.mark.asyncio - async def test_create_converter_stores_instance(self) -> None: - """Test that create_converter stores the instance.""" + async def test_create_converter_registers_in_registry(self) -> None: + """Test that create_converter registers object in registry.""" service = ConverterService() request = CreateConverterRequest( @@ -250,42 +176,9 @@ async def test_create_converter_stores_instance(self) -> None: result = await service.create_converter(request) - assert result.converter_id in service._instances - assert result.converter_id in service._converter_objects - - -class TestDeleteConverter: - """Tests for ConverterService.delete_converter method.""" - - @pytest.mark.asyncio - async def test_delete_converter_returns_false_for_nonexistent(self) -> None: - """Test that delete_converter returns False for non-existent converter.""" - service = ConverterService() - - result = await service.delete_converter("nonexistent") - - assert result is False - - @pytest.mark.asyncio - async def test_delete_converter_deletes_converter(self) -> None: - """Test that delete_converter removes the converter.""" - service = ConverterService() - now = datetime.now(timezone.utc) - - service._instances["conv-1"] = ConverterInstance( - converter_id="conv-1", - type="Base64Converter", - params={}, - created_at=now, - source="user", - ) - service._converter_objects["conv-1"] = MagicMock() - - result = await service.delete_converter("conv-1") - - assert result is True - assert "conv-1" not in service._instances - assert "conv-1" not in service._converter_objects + # Object should be retrievable from registry + converter_obj = service.get_converter_object(result.converter_id) + assert converter_obj is not None class TestPreviewConversion: @@ -309,23 +202,14 @@ async def test_preview_conversion_raises_for_nonexistent_converter(self) -> None async def test_preview_conversion_with_converter_ids(self) -> None: """Test preview with converter IDs.""" service = ConverterService() - now = datetime.now(timezone.utc) - # Create a mock converter mock_converter = MagicMock() + mock_converter.__class__.__name__ = "MockConverter" mock_result = MagicMock() mock_result.output_text = "encoded_value" mock_result.output_type = "text" mock_converter.convert_async = AsyncMock(return_value=mock_result) - - service._instances["conv-1"] = ConverterInstance( - converter_id="conv-1", - type="MockConverter", - params={}, - created_at=now, - source="user", - ) - service._converter_objects["conv-1"] = mock_converter + service._registry.register_instance(mock_converter, name="conv-1") request = ConverterPreviewRequest( original_value="test", @@ -340,63 +224,27 @@ async def test_preview_conversion_with_converter_ids(self) -> None: assert len(result.steps) == 1 assert result.steps[0].converter_id == "conv-1" - @pytest.mark.asyncio - async def test_preview_conversion_with_inline_converters(self) -> None: - """Test preview with inline converter configs.""" - service = ConverterService() - - request = ConverterPreviewRequest( - original_value="test", - original_value_data_type="text", - converters=[ - InlineConverterConfig(type="Base64Converter", params={}), - ], - ) - - result = await service.preview_conversion(request) - - assert result.original_value == "test" - assert result.converted_value is not None - assert len(result.steps) == 1 - # Base64 of "test" should be different from "test" - assert result.converted_value != "test" - @pytest.mark.asyncio async def test_preview_conversion_chains_multiple_converters(self) -> None: """Test that preview chains multiple converters.""" service = ConverterService() - now = datetime.now(timezone.utc) - # Create two mock converters mock_converter1 = MagicMock() + mock_converter1.__class__.__name__ = "MockConverter1" mock_result1 = MagicMock() mock_result1.output_text = "step1_output" mock_result1.output_type = "text" mock_converter1.convert_async = AsyncMock(return_value=mock_result1) mock_converter2 = MagicMock() + mock_converter2.__class__.__name__ = "MockConverter2" mock_result2 = MagicMock() mock_result2.output_text = "step2_output" mock_result2.output_type = "text" mock_converter2.convert_async = AsyncMock(return_value=mock_result2) - service._instances["conv-1"] = ConverterInstance( - converter_id="conv-1", - type="MockConverter1", - params={}, - created_at=now, - source="user", - ) - service._converter_objects["conv-1"] = mock_converter1 - - service._instances["conv-2"] = ConverterInstance( - converter_id="conv-2", - type="MockConverter2", - params={}, - created_at=now, - source="user", - ) - service._converter_objects["conv-2"] = mock_converter2 + service._registry.register_instance(mock_converter1, name="conv-1") + service._registry.register_instance(mock_converter2, name="conv-2") request = ConverterPreviewRequest( original_value="input", @@ -408,7 +256,6 @@ async def test_preview_conversion_chains_multiple_converters(self) -> None: assert result.converted_value == "step2_output" assert len(result.steps) == 2 - # Second converter should receive output from first mock_converter2.convert_async.assert_called_with(prompt="step1_output") @@ -428,85 +275,66 @@ def test_get_converter_objects_for_ids_returns_objects(self) -> None: mock1 = MagicMock() mock2 = MagicMock() - service._converter_objects["conv-1"] = mock1 - service._converter_objects["conv-2"] = mock2 + service._registry.register_instance(mock1, name="conv-1") + service._registry.register_instance(mock2, name="conv-2") result = service.get_converter_objects_for_ids(["conv-1", "conv-2"]) assert result == [mock1, mock2] -class TestInstantiateInlineConverters: - """Tests for ConverterService.instantiate_inline_converters method.""" - - def test_instantiate_inline_converters_creates_objects(self) -> None: - """Test that inline converters are instantiated.""" - service = ConverterService() - - configs = [ - InlineConverterConfig(type="Base64Converter", params={}), - ] - - result = service.instantiate_inline_converters(configs) - - assert len(result) == 1 - # Verify it's a real converter object - assert hasattr(result[0], "convert_async") - - def test_instantiate_inline_converters_raises_for_invalid_type(self) -> None: - """Test that invalid type raises ValueError.""" - service = ConverterService() - - configs = [ - InlineConverterConfig(type="NonExistentConverter", params={}), - ] - - with pytest.raises(ValueError, match="not found"): - service.instantiate_inline_converters(configs) - - -class TestNestedConverterCreation: - """Tests for nested converter creation.""" +class TestConverterWithReferencedConverter: + """Tests for creating converters that reference other converters by ID.""" @pytest.mark.asyncio - async def test_create_converter_with_nested_converter(self) -> None: - """Test creating a converter with a nested converter config.""" + async def test_create_converter_with_referenced_converter(self) -> None: + """Test creating a converter that references another converter by ID.""" service = ConverterService() - # Mock the parent converter class that accepts a 'converter' param - mock_parent_class = MagicMock() - mock_parent_instance = MagicMock() - mock_parent_class.return_value = mock_parent_instance + mock_inner_class = MagicMock() + mock_inner_instance = MagicMock() + mock_inner_class.return_value = mock_inner_instance - mock_child_class = MagicMock() - mock_child_instance = MagicMock() - mock_child_class.return_value = mock_child_instance + mock_outer_class = MagicMock() + mock_outer_instance = MagicMock() + mock_outer_class.return_value = mock_outer_instance def mock_get_class(converter_type: str) -> type: - if converter_type == "ParentConverter": - return mock_parent_class - elif converter_type == "ChildConverter": - return mock_child_class + if converter_type == "OuterConverter": + return mock_outer_class + elif converter_type == "InnerConverter": + return mock_inner_class raise ValueError(f"Unknown type: {converter_type}") with patch.object(service, "_get_converter_class", side_effect=mock_get_class): - request = CreateConverterRequest( - type="ParentConverter", - params={ - "converter": { - "type": "ChildConverter", - "params": {}, - }, - }, + inner_result = await service.create_converter(CreateConverterRequest(type="InnerConverter", params={})) + inner_id = inner_result.converter_id + + await service.create_converter( + CreateConverterRequest( + type="OuterConverter", + params={"converter": {"converter_id": inner_id}}, + ) ) - result = await service.create_converter(request) + mock_outer_class.assert_called() + call_kwargs = mock_outer_class.call_args[1] + assert call_kwargs.get("converter") is mock_inner_instance + + @pytest.mark.asyncio + async def test_create_converter_with_invalid_reference_raises(self) -> None: + """Test that referencing a non-existent converter raises ValueError.""" + service = ConverterService() - # Parent should be created with child converter object - mock_parent_class.assert_called() - # The call should have received the child instance, not the dict - call_kwargs = mock_parent_class.call_args[1] - assert call_kwargs.get("converter") is mock_child_instance + mock_class = MagicMock() + with patch.object(service, "_get_converter_class", return_value=mock_class): + with pytest.raises(ValueError, match="not found"): + await service.create_converter( + CreateConverterRequest( + type="OuterConverter", + params={"converter": {"converter_id": "nonexistent"}}, + ) + ) class TestConverterServiceSingleton: @@ -514,7 +342,6 @@ class TestConverterServiceSingleton: def test_get_converter_service_returns_converter_service(self) -> None: """Test that get_converter_service returns a ConverterService instance.""" - # Reset singleton for clean test import pyrit.backend.services.converter_service as module from pyrit.backend.services.converter_service import get_converter_service @@ -525,7 +352,6 @@ def test_get_converter_service_returns_converter_service(self) -> None: def test_get_converter_service_returns_same_instance(self) -> None: """Test that get_converter_service returns the same instance.""" - # Reset singleton for clean test import pyrit.backend.services.converter_service as module from pyrit.backend.services.converter_service import get_converter_service diff --git a/tests/unit/backend/test_error_handlers.py b/tests/unit/backend/test_error_handlers.py index 93131360e8..ad370fce4d 100644 --- a/tests/unit/backend/test_error_handlers.py +++ b/tests/unit/backend/test_error_handlers.py @@ -8,6 +8,7 @@ import pytest from fastapi import FastAPI from fastapi.testclient import TestClient +from pydantic import BaseModel from pyrit.backend.middleware.error_handlers import register_error_handlers @@ -61,7 +62,6 @@ async def test_endpoint(data: TestInput) -> dict: def test_validation_error_includes_field_details(self, app: FastAPI, client: TestClient) -> None: """Test that validation errors include field-level details.""" - from pydantic import BaseModel class TestInput(BaseModel): name: str diff --git a/tests/unit/backend/test_registry_service.py b/tests/unit/backend/test_registry_service.py deleted file mode 100644 index 0f8ac2b65b..0000000000 --- a/tests/unit/backend/test_registry_service.py +++ /dev/null @@ -1,154 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Tests for backend registry service. -""" - -from pyrit.backend.services.registry_service import ( - RegistryService, - _extract_params_schema, - _get_all_subclasses, - get_registry_service, -) - - -class TestExtractParamsSchema: - """Tests for _extract_params_schema helper function.""" - - def test_extract_params_with_required_and_optional(self) -> None: - """Test extracting params from a class with required and optional params.""" - - class TestClass: - def __init__(self, required_param: str, optional_param: str = "default") -> None: - pass - - result = _extract_params_schema(TestClass) - - assert "required_param" in result["required"] - assert "optional_param" in result["optional"] - - def test_extract_params_ignores_self(self) -> None: - """Test that self is ignored in param extraction.""" - - class TestClass: - def __init__(self, param: str) -> None: - pass - - result = _extract_params_schema(TestClass) - - assert "self" not in result["required"] - assert "self" not in result["optional"] - - -class TestGetAllSubclasses: - """Tests for _get_all_subclasses helper function.""" - - def test_get_subclasses_finds_concrete_classes(self) -> None: - """Test that concrete subclasses are found.""" - - class Base: - pass - - class Child1(Base): - pass - - class Child2(Base): - pass - - result = _get_all_subclasses(Base) - - assert Child1 in result - assert Child2 in result - - -class TestRegistryService: - """Tests for RegistryService.""" - - def test_get_targets_returns_list(self) -> None: - """Test that get_targets returns a list.""" - service = RegistryService() - - result = service.get_targets() - - assert isinstance(result, list) - - def test_get_targets_filters_chat_targets(self) -> None: - """Test that get_targets can filter by chat target support.""" - service = RegistryService() - - chat_only = service.get_targets(is_chat_target=True) - non_chat = service.get_targets(is_chat_target=False) - - # Chat targets and non-chat targets should be different - chat_names = {t.name for t in chat_only} - non_chat_names = {t.name for t in non_chat} - # They should be disjoint (no overlap) - assert len(chat_names & non_chat_names) == 0 - - def test_get_converters_returns_list(self) -> None: - """Test that get_converters returns a list.""" - service = RegistryService() - - result = service.get_converters() - - assert isinstance(result, list) - - def test_get_converters_filters_llm_based(self) -> None: - """Test that get_converters can filter by LLM-based status.""" - service = RegistryService() - - llm_based = service.get_converters(is_llm_based=True) - non_llm = service.get_converters(is_llm_based=False) - - # LLM-based and non-LLM converters should be different - llm_names = {c.name for c in llm_based} - non_llm_names = {c.name for c in non_llm} - # They should be disjoint - assert len(llm_names & non_llm_names) == 0 - - def test_get_scenarios_returns_list(self) -> None: - """Test that get_scenarios returns a list.""" - service = RegistryService() - - result = service.get_scenarios() - - assert isinstance(result, list) - - def test_get_scorers_returns_list(self) -> None: - """Test that get_scorers returns a list.""" - service = RegistryService() - - result = service.get_scorers() - - assert isinstance(result, list) - - def test_get_initializers_returns_list(self) -> None: - """Test that get_initializers returns a list.""" - service = RegistryService() - - result = service.get_initializers() - - assert isinstance(result, list) - - -class TestGetRegistryServiceSingleton: - """Tests for get_registry_service singleton function.""" - - def test_returns_same_instance(self) -> None: - """Test that get_registry_service returns the same instance.""" - # Reset singleton for test - import pyrit.backend.services.registry_service as module - - module._registry_service = None - - service1 = get_registry_service() - service2 = get_registry_service() - - assert service1 is service2 - - def test_returns_registry_service_instance(self) -> None: - """Test that get_registry_service returns a RegistryService.""" - service = get_registry_service() - - assert isinstance(service, RegistryService) diff --git a/tests/unit/backend/test_target_service.py b/tests/unit/backend/test_target_service.py index a95aed843d..5b6b023b56 100644 --- a/tests/unit/backend/test_target_service.py +++ b/tests/unit/backend/test_target_service.py @@ -5,27 +5,21 @@ Tests for backend target service. """ -from datetime import datetime, timezone from unittest.mock import MagicMock, patch import pytest -from pyrit.backend.models.targets import CreateTargetRequest, TargetInstance +from pyrit.backend.models.targets import CreateTargetRequest from pyrit.backend.services.target_service import TargetService +from pyrit.registry.instance_registries import TargetRegistry -class TestTargetServiceInit: - """Tests for TargetService initialization.""" - - def test_init_creates_empty_instances_dict(self) -> None: - """Test that service initializes with empty instances dictionary.""" - service = TargetService() - assert service._instances == {} - - def test_init_creates_empty_target_objects_dict(self) -> None: - """Test that service initializes with empty target objects dictionary.""" - service = TargetService() - assert service._target_objects == {} +@pytest.fixture(autouse=True) +def reset_registry(): + """Reset the TargetRegistry singleton before each test.""" + TargetRegistry.reset_instance() + yield + TargetRegistry.reset_instance() class TestListTargets: @@ -41,77 +35,20 @@ async def test_list_targets_returns_empty_when_no_targets(self) -> None: assert result.items == [] @pytest.mark.asyncio - async def test_list_targets_returns_targets(self) -> None: - """Test that list_targets returns existing targets.""" + async def test_list_targets_returns_targets_from_registry(self) -> None: + """Test that list_targets returns targets from registry.""" service = TargetService() - now = datetime.now(timezone.utc) - service._instances["target-1"] = TargetInstance( - target_id="target-1", - type="TextTarget", - display_name="My Target", - params={}, - created_at=now, - source="user", - ) + # Register a mock target + mock_target = MagicMock() + mock_target.get_identifier.return_value = {"__type__": "MockTarget", "endpoint": "http://test"} + service._registry.register_instance(mock_target, name="target-1") result = await service.list_targets() assert len(result.items) == 1 assert result.items[0].target_id == "target-1" - assert result.items[0].display_name == "My Target" - - @pytest.mark.asyncio - async def test_list_targets_filters_by_source_user(self) -> None: - """Test that list_targets filters by source='user'.""" - service = TargetService() - now = datetime.now(timezone.utc) - - service._instances["target-1"] = TargetInstance( - target_id="target-1", - type="TextTarget", - params={}, - created_at=now, - source="user", - ) - service._instances["target-2"] = TargetInstance( - target_id="target-2", - type="TextTarget", - params={}, - created_at=now, - source="initializer", - ) - - result = await service.list_targets(source="user") - - assert len(result.items) == 1 - assert result.items[0].source == "user" - - @pytest.mark.asyncio - async def test_list_targets_filters_by_source_initializer(self) -> None: - """Test that list_targets filters by source='initializer'.""" - service = TargetService() - now = datetime.now(timezone.utc) - - service._instances["target-1"] = TargetInstance( - target_id="target-1", - type="TextTarget", - params={}, - created_at=now, - source="user", - ) - service._instances["target-2"] = TargetInstance( - target_id="target-2", - type="TextTarget", - params={}, - created_at=now, - source="initializer", - ) - - result = await service.list_targets(source="initializer") - - assert len(result.items) == 1 - assert result.items[0].source == "initializer" + assert result.items[0].type == "MockTarget" class TestGetTarget: @@ -127,25 +64,19 @@ async def test_get_target_returns_none_for_nonexistent(self) -> None: assert result is None @pytest.mark.asyncio - async def test_get_target_returns_target(self) -> None: - """Test that get_target returns the target instance.""" + async def test_get_target_returns_target_from_registry(self) -> None: + """Test that get_target returns target built from registry object.""" service = TargetService() - now = datetime.now(timezone.utc) - service._instances["target-1"] = TargetInstance( - target_id="target-1", - type="TextTarget", - display_name="Test Target", - params={"key": "value"}, - created_at=now, - source="user", - ) + mock_target = MagicMock() + mock_target.get_identifier.return_value = {"__type__": "MockTarget"} + service._registry.register_instance(mock_target, name="target-1") result = await service.get_target("target-1") assert result is not None assert result.target_id == "target-1" - assert result.display_name == "Test Target" + assert result.type == "MockTarget" class TestGetTargetObject: @@ -159,11 +90,11 @@ def test_get_target_object_returns_none_for_nonexistent(self) -> None: assert result is None - def test_get_target_object_returns_object(self) -> None: + def test_get_target_object_returns_object_from_registry(self) -> None: """Test that get_target_object returns the actual target object.""" service = TargetService() mock_target = MagicMock() - service._target_objects["target-1"] = mock_target + service._registry.register_instance(mock_target, name="target-1") result = service.get_target_object("target-1") @@ -184,7 +115,6 @@ def test_get_target_class_finds_text_target(self) -> None: """Test that _get_target_class finds TextTarget.""" service = TargetService() - # TextTarget should exist in pyrit.prompt_target result = service._get_target_class("TextTarget") assert result is not None @@ -212,7 +142,6 @@ async def test_create_target_success(self) -> None: """Test successful target creation.""" service = TargetService() - # Use a target that doesn't require external dependencies request = CreateTargetRequest( type="TextTarget", display_name="My Text Target", @@ -224,11 +153,10 @@ async def test_create_target_success(self) -> None: assert result.target_id is not None assert result.type == "TextTarget" assert result.display_name == "My Text Target" - assert result.source == "user" @pytest.mark.asyncio - async def test_create_target_stores_instance(self) -> None: - """Test that create_target stores the instance.""" + async def test_create_target_registers_in_registry(self) -> None: + """Test that create_target registers object in registry.""" service = TargetService() request = CreateTargetRequest( @@ -238,15 +166,15 @@ async def test_create_target_stores_instance(self) -> None: result = await service.create_target(request) - assert result.target_id in service._instances - assert result.target_id in service._target_objects + # Object should be retrievable from registry + target_obj = service.get_target_object(result.target_id) + assert target_obj is not None @pytest.mark.asyncio async def test_create_target_filters_sensitive_params(self) -> None: """Test that create_target filters sensitive parameters.""" service = TargetService() - # Create a mock target class that has sensitive identifier fields mock_target_class = MagicMock() mock_target_instance = MagicMock() mock_target_instance.get_identifier.return_value = { @@ -266,86 +194,14 @@ async def test_create_target_filters_sensitive_params(self) -> None: # api_key should be filtered out assert "api_key" not in result.params - # endpoint should remain assert result.params.get("endpoint") == "https://api.example.com" -class TestDeleteTarget: - """Tests for TargetService.delete_target method.""" - - @pytest.mark.asyncio - async def test_delete_target_returns_false_for_nonexistent(self) -> None: - """Test that delete_target returns False for non-existent target.""" - service = TargetService() - - result = await service.delete_target("nonexistent") - - assert result is False - - @pytest.mark.asyncio - async def test_delete_target_deletes_target(self) -> None: - """Test that delete_target removes the target.""" - service = TargetService() - now = datetime.now(timezone.utc) - - service._instances["target-1"] = TargetInstance( - target_id="target-1", - type="TextTarget", - params={}, - created_at=now, - source="user", - ) - service._target_objects["target-1"] = MagicMock() - - result = await service.delete_target("target-1") - - assert result is True - assert "target-1" not in service._instances - assert "target-1" not in service._target_objects - - -class TestRegisterInitializerTarget: - """Tests for TargetService.register_initializer_target method.""" - - @pytest.mark.asyncio - async def test_register_initializer_target_creates_instance(self) -> None: - """Test that register_initializer_target creates an instance.""" - service = TargetService() - mock_target = MagicMock() - mock_target.get_identifier.return_value = {"type": "MockTarget"} - - result = await service.register_initializer_target( - target_type="MockTarget", - target_obj=mock_target, - display_name="Initializer Target", - ) - - assert result.target_id is not None - assert result.type == "MockTarget" - assert result.display_name == "Initializer Target" - assert result.source == "initializer" - - @pytest.mark.asyncio - async def test_register_initializer_target_stores_object(self) -> None: - """Test that register_initializer_target stores the target object.""" - service = TargetService() - mock_target = MagicMock() - mock_target.get_identifier.return_value = {} - - result = await service.register_initializer_target( - target_type="MockTarget", - target_obj=mock_target, - ) - - assert service._target_objects[result.target_id] is mock_target - - class TestTargetServiceSingleton: """Tests for get_target_service singleton function.""" def test_get_target_service_returns_target_service(self) -> None: """Test that get_target_service returns a TargetService instance.""" - # Reset singleton for clean test import pyrit.backend.services.target_service as module from pyrit.backend.services.target_service import get_target_service @@ -356,7 +212,6 @@ def test_get_target_service_returns_target_service(self) -> None: def test_get_target_service_returns_same_instance(self) -> None: """Test that get_target_service returns the same instance.""" - # Reset singleton for clean test import pyrit.backend.services.target_service as module from pyrit.backend.services.target_service import get_target_service From 5fb3f2af4c5459a54dfb2807ff24ff0b10de3973 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Thu, 5 Feb 2026 13:05:01 -0800 Subject: [PATCH 08/35] more simplifications --- pyrit/backend/models/__init__.py | 6 - pyrit/backend/models/common.py | 52 +-- pyrit/backend/models/converters.py | 4 +- pyrit/backend/models/targets.py | 2 +- pyrit/backend/routes/converters.py | 16 +- pyrit/backend/routes/targets.py | 15 +- pyrit/backend/services/converter_service.py | 133 ++++---- pyrit/backend/services/target_service.py | 101 +++--- .../base_instance_registry.py | 9 + tests/unit/backend/test_api_routes.py | 72 ++--- tests/unit/backend/test_common_models.py | 29 -- tests/unit/backend/test_converter_service.py | 296 ++++++++++++------ tests/unit/backend/test_target_service.py | 48 +-- 13 files changed, 397 insertions(+), 386 deletions(-) diff --git a/pyrit/backend/models/__init__.py b/pyrit/backend/models/__init__.py index d338454db5..4e3bcf6635 100644 --- a/pyrit/backend/models/__init__.py +++ b/pyrit/backend/models/__init__.py @@ -23,11 +23,8 @@ UpdateAttackRequest, ) from pyrit.backend.models.common import ( - ALLOWED_IDENTIFIER_FIELDS, SENSITIVE_FIELD_PATTERNS, FieldError, - IdentifierDict, - PaginatedResponse, PaginationInfo, ProblemDetail, filter_sensitive_fields, @@ -64,12 +61,9 @@ "Score", "UpdateAttackRequest", # Common - "ALLOWED_IDENTIFIER_FIELDS", "SENSITIVE_FIELD_PATTERNS", "FieldError", "filter_sensitive_fields", - "IdentifierDict", - "PaginatedResponse", "PaginationInfo", "ProblemDetail", # Converters diff --git a/pyrit/backend/models/common.py b/pyrit/backend/models/common.py index 3aadc8157a..44203ddcd0 100644 --- a/pyrit/backend/models/common.py +++ b/pyrit/backend/models/common.py @@ -7,13 +7,10 @@ Includes pagination, error handling (RFC 7807), and shared base models. """ -from datetime import datetime -from typing import Any, Generic, List, Optional, TypeVar +from typing import Any, List, Optional from pydantic import BaseModel, Field -T = TypeVar("T") - class PaginationInfo(BaseModel): """Pagination metadata for list responses.""" @@ -24,13 +21,6 @@ class PaginationInfo(BaseModel): prev_cursor: Optional[str] = Field(None, description="Cursor for previous page") -class PaginatedResponse(BaseModel, Generic[T]): - """Generic paginated response wrapper.""" - - items: List[T] = Field(..., description="List of items") - pagination: PaginationInfo = Field(..., description="Pagination metadata") - - class FieldError(BaseModel): """Individual field validation error.""" @@ -55,30 +45,6 @@ class ProblemDetail(BaseModel): errors: Optional[List[FieldError]] = Field(None, description="Field-level errors for validation") -class IdentifierDict(BaseModel): - """ - Represents a filtered identifier dictionary. - - Only contains safe fields (no API keys, tokens, etc.). - Uses 'type_' and 'module_' as field names but serializes to '__type__' and '__module__'. - """ - - type_: str = Field(..., alias="__type__", description="Class name") - module_: Optional[str] = Field(None, alias="__module__", description="Module path") - - model_config = { - "extra": "allow", # Allow additional fields like endpoint, model_name, etc. - "populate_by_name": True, - } - - -class TimestampMixin(BaseModel): - """Mixin for models with timestamps.""" - - timestamp: datetime = Field(..., description="Creation/event timestamp") - created_at: Optional[datetime] = Field(None, description="Resource creation time") - - # Sensitive field patterns to filter from identifiers SENSITIVE_FIELD_PATTERNS = frozenset( [ @@ -93,22 +59,6 @@ class TimestampMixin(BaseModel): ] ) -# Fields allowed in identifier responses -ALLOWED_IDENTIFIER_FIELDS = frozenset( - [ - "__type__", - "__module__", - "endpoint", - "model_name", - "deployment_name", - "underlying_model", - "temperature", - "top_p", - "language", - "tone", - ] -) - def filter_sensitive_fields(data: dict[str, Any]) -> dict[str, Any]: """ diff --git a/pyrit/backend/models/converters.py b/pyrit/backend/models/converters.py index 903049ea1b..fd57dc0a9a 100644 --- a/pyrit/backend/models/converters.py +++ b/pyrit/backend/models/converters.py @@ -47,11 +47,11 @@ class ConverterInstanceListResponse(BaseModel): class CreateConverterRequest(BaseModel): """Request to create a new converter instance.""" - type: str = Field(..., description="Converter type (e.g., 'base64', 'translation')") + type: str = Field(..., description="Converter type (e.g., 'Base64Converter')") display_name: Optional[str] = Field(None, description="Human-readable display name") params: Dict[str, Any] = Field( default_factory=dict, - description="Converter parameters", + description="Converter constructor parameters", ) diff --git a/pyrit/backend/models/targets.py b/pyrit/backend/models/targets.py index e1d1bf1dbb..eaa151130c 100644 --- a/pyrit/backend/models/targets.py +++ b/pyrit/backend/models/targets.py @@ -38,7 +38,7 @@ class TargetListResponse(BaseModel): class CreateTargetRequest(BaseModel): """Request to create a new target instance.""" - type: str = Field(..., description="Target type (e.g., 'azure_openai', 'text_target')") + type: str = Field(..., description="Target type (e.g., 'OpenAIChatTarget')") display_name: Optional[str] = Field(None, description="Human-readable display name") params: Dict[str, Any] = Field(default_factory=dict, description="Target constructor parameters") diff --git a/pyrit/backend/routes/converters.py b/pyrit/backend/routes/converters.py index c1319104ac..45134a7adc 100644 --- a/pyrit/backend/routes/converters.py +++ b/pyrit/backend/routes/converters.py @@ -8,9 +8,7 @@ Converter types are set at app startup - you cannot add new types at runtime. """ -from typing import Literal, Optional - -from fastapi import APIRouter, HTTPException, Query, status +from fastapi import APIRouter, HTTPException, status from pyrit.backend.models.common import ProblemDetail from pyrit.backend.models.converters import ( @@ -30,11 +28,7 @@ "", response_model=ConverterInstanceListResponse, ) -async def list_converters( - source: Optional[Literal["initializer", "user"]] = Query( - None, description="Filter by source (initializer or user)" - ), -) -> ConverterInstanceListResponse: +async def list_converters() -> ConverterInstanceListResponse: """ List converter instances. @@ -44,7 +38,7 @@ async def list_converters( ConverterInstanceListResponse: List of converter instances. """ service = get_converter_service() - return await service.list_converters(source=source) + return await service.list_converters() @router.post( @@ -59,8 +53,8 @@ async def create_converter(request: CreateConverterRequest) -> CreateConverterRe """ Create a new converter instance. - Supports nested converters - if params contains a 'converter' key with - a type/params object, the nested converter will be created first. + Instantiates a converter with the given type and parameters. + Supports nested converters via converter_id references in params. Returns: CreateConverterResponse: The created converter instance details. diff --git a/pyrit/backend/routes/targets.py b/pyrit/backend/routes/targets.py index 03aa41c8f2..41b5ab3bdc 100644 --- a/pyrit/backend/routes/targets.py +++ b/pyrit/backend/routes/targets.py @@ -8,9 +8,7 @@ Target types are set at app startup via initializers - you cannot add new types at runtime. """ -from typing import Literal, Optional - -from fastapi import APIRouter, HTTPException, Query, status +from fastapi import APIRouter, HTTPException, status from pyrit.backend.models.common import ProblemDetail from pyrit.backend.models.targets import ( @@ -28,22 +26,17 @@ "", response_model=TargetListResponse, ) -async def list_targets( - source: Optional[Literal["initializer", "user"]] = Query( - None, description="Filter by source (initializer or user)" - ), -) -> TargetListResponse: +async def list_targets() -> TargetListResponse: """ List target instances. - Returns all registered target instances. Use source filter to distinguish - between initializer-created (startup) and user-created (API) targets. + Returns all registered target instances. Returns: TargetListResponse: List of target instances. """ service = get_target_service() - return await service.list_targets(source=source) + return await service.list_targets() @router.post( diff --git a/pyrit/backend/services/converter_service.py b/pyrit/backend/services/converter_service.py index edc471d205..52947a63c5 100644 --- a/pyrit/backend/services/converter_service.py +++ b/pyrit/backend/services/converter_service.py @@ -5,16 +5,17 @@ Converter service for managing converter instances. Handles creation, retrieval, and preview of converters. -Uses ConverterRegistry as the source of truth. +Uses ConverterRegistry as the source of truth for instances. -If a converter requires another converter (e.g., SelectiveTextConverter), -the inner converter must be created first and passed by ID in params. +Converters can be: +- Created via API request (instantiated from request params, then registered) +- Retrieved from registry (pre-registered at startup or created earlier) """ -import importlib import uuid -from typing import Any, Dict, List, Literal, Optional, Tuple, cast +from typing import Any, List, Optional, Tuple +from pyrit import prompt_converter from pyrit.backend.models.converters import ( ConverterInstance, ConverterInstanceListResponse, @@ -25,9 +26,31 @@ PreviewStep, ) from pyrit.models import PromptDataType +from pyrit.prompt_converter import PromptConverter from pyrit.registry.instance_registries import ConverterRegistry +def _build_converter_class_registry() -> dict[str, type]: + """ + Build a registry mapping converter class names to their classes. + + Uses the prompt_converter module's __all__ to discover all available converters. + + Returns: + Dict mapping class name (str) to class (type). + """ + registry: dict[str, type] = {} + for name in prompt_converter.__all__: + cls = getattr(prompt_converter, name, None) + if cls is not None and isinstance(cls, type) and issubclass(cls, PromptConverter): + registry[name] = cls + return registry + + +# Module-level class registry (built once on import) +_CONVERTER_CLASS_REGISTRY: dict[str, type] = _build_converter_class_registry() + + class ConverterService: """ Service for managing converter instances. @@ -44,36 +67,35 @@ def _build_instance_from_object(self, converter_id: str, converter_obj: Any) -> """ Build a ConverterInstance from a registry object. + Uses the converter's identifier to extract all relevant metadata. + Returns: - ConverterInstance with metadata derived from the object. + ConverterInstance with metadata derived from the object's identifier. """ - converter_type = converter_obj.__class__.__name__ + identifier = converter_obj.get_identifier() + identifier_dict = identifier.to_dict() + return ConverterInstance( converter_id=converter_id, - type=converter_type, + type=identifier_dict.get("class_name", converter_obj.__class__.__name__), display_name=None, - params={}, # Params aren't stored on converter objects + params=identifier_dict, ) # ======================================================================== # Public API Methods # ======================================================================== - async def list_converters( - self, source: Optional[Literal["initializer", "user"]] = None - ) -> ConverterInstanceListResponse: + async def list_converters(self) -> ConverterInstanceListResponse: """ List all converter instances. Returns: ConverterInstanceListResponse containing all registered converters. """ - # source filter is ignored for now - all come from registry - items: List[ConverterInstance] = [] - for name in self._registry.get_names(): - obj = self._registry.get_instance_by_name(name) - if obj: - items.append(self._build_instance_from_object(name, obj)) + items = [ + self._build_instance_from_object(name, obj) for name, obj in self._registry.get_all_instances().items() + ] return ConverterInstanceListResponse(items=items) async def get_converter(self, converter_id: str) -> Optional[ConverterInstance]: @@ -99,19 +121,26 @@ def get_converter_object(self, converter_id: str) -> Optional[Any]: async def create_converter(self, request: CreateConverterRequest) -> CreateConverterResponse: """ - Create a new converter instance. + Create a new converter instance from API request. + + Instantiates the converter with the given type and params, + then registers it in the registry. - If params contains a 'converter' key with a converter_id, - the referenced converter object will be resolved and passed. + Args: + request: The create converter request with type and params. Returns: CreateConverterResponse with the new converter's details. + + Raises: + ValueError: If the converter type is not found. """ converter_id = str(uuid.uuid4()) - # Resolve any converter references in params and create the object + # Resolve any converter references in params and instantiate params = self._resolve_converter_params(request.params) - converter_obj = self._get_converter_class(request.type)(**params) + converter_class = self._get_converter_class(request.type) + converter_obj = converter_class(**params) self._registry.register_instance(converter_obj, name=converter_id) return CreateConverterResponse( @@ -160,10 +189,36 @@ def get_converter_objects_for_ids(self, converter_ids: List[str]) -> List[Any]: # Private Helper Methods # ======================================================================== - def _resolve_converter_params(self, params: Dict[str, Any]) -> Dict[str, Any]: + def _get_converter_class(self, converter_type: str) -> type: + """ + Get the converter class for a given type name. + + Looks up the class in the module-level converter class registry. + + Args: + converter_type: The exact class name of the converter (e.g., 'Base64Converter'). + + Returns: + The converter class. + + Raises: + ValueError: If the converter type is not found. + """ + cls = _CONVERTER_CLASS_REGISTRY.get(converter_type) + if cls is None: + raise ValueError( + f"Converter type '{converter_type}' not found. " + f"Available types: {sorted(_CONVERTER_CLASS_REGISTRY.keys())}" + ) + return cls + + def _resolve_converter_params(self, params: dict[str, Any]) -> dict[str, Any]: """ Resolve converter references in params. + If params contains a 'converter' key with a converter_id reference, + resolve it to the actual converter object from the registry. + Returns: Params dict with converter_id references replaced by actual objects. """ @@ -177,36 +232,6 @@ def _resolve_converter_params(self, params: Dict[str, Any]) -> Dict[str, Any]: resolved["converter"] = conv_obj return resolved - def _get_converter_class(self, converter_type: str) -> type: - """ - Get the converter class for a given type. - - Returns: - The converter class matching the given type. - """ - module = importlib.import_module("pyrit.prompt_converter") - - cls = getattr(module, converter_type, None) - if cls is not None: - return cast(type, cls) - - for pattern in self._class_name_patterns(converter_type): - cls = getattr(module, pattern, None) - if cls is not None: - return cast(type, cls) - - raise ValueError(f"Converter type '{converter_type}' not found in pyrit.prompt_converter") - - def _class_name_patterns(self, type_name: str) -> List[str]: - """ - Generate class name patterns to try. - - Returns: - List of possible class name variations. - """ - pascal = "".join(word.capitalize() for word in type_name.split("_")) - return [type_name, f"{type_name}Converter", pascal, f"{pascal}Converter"] - def _gather_converters(self, converter_ids: List[str]) -> List[Tuple[str, str, Any]]: """ Gather converters to apply from IDs. diff --git a/pyrit/backend/services/target_service.py b/pyrit/backend/services/target_service.py index a55911dfd2..16e3bc4c2a 100644 --- a/pyrit/backend/services/target_service.py +++ b/pyrit/backend/services/target_service.py @@ -4,14 +4,18 @@ """ Target service for managing target instances. -Handles creation, retrieval, and lifecycle of runtime target instances. -Uses TargetRegistry as the source of truth. +Handles creation and retrieval of target instances. +Uses TargetRegistry as the source of truth for instances. + +Targets can be: +- Created via API request (instantiated from request params, then registered) +- Retrieved from registry (pre-registered at startup or created earlier) """ -import importlib import uuid -from typing import Any, List, Literal, Optional, cast +from typing import Any, Optional +from pyrit import prompt_target from pyrit.backend.models.common import filter_sensitive_fields from pyrit.backend.models.targets import ( CreateTargetRequest, @@ -19,9 +23,31 @@ TargetInstance, TargetListResponse, ) +from pyrit.prompt_target import PromptTarget from pyrit.registry.instance_registries import TargetRegistry +def _build_target_class_registry() -> dict[str, type]: + """ + Build a registry mapping target class names to their classes. + + Uses the prompt_target module's __all__ to discover all available targets. + + Returns: + Dict mapping class name (str) to class (type). + """ + registry: dict[str, type] = {} + for name in prompt_target.__all__: + cls = getattr(prompt_target, name, None) + if cls is not None and isinstance(cls, type) and issubclass(cls, PromptTarget): + registry[name] = cls + return registry + + +# Module-level class registry (built once on import) +_TARGET_CLASS_REGISTRY: dict[str, type] = _build_target_class_registry() + + class TargetService: """ Service for managing target instances. @@ -36,30 +62,25 @@ def __init__(self) -> None: def _get_target_class(self, target_type: str) -> type: """ - Get the target class for a given type. + Get the target class for a given type name. - Returns: - The target class matching the given type. - """ - module = importlib.import_module("pyrit.prompt_target") + Looks up the class in the module-level target class registry. - cls = getattr(module, target_type, None) - if cls is not None: - return cast(type, cls) + Args: + target_type: The exact class name of the target (e.g., 'TextTarget'). - class_name_patterns = [ - target_type, - f"{target_type}Target", - "".join(word.capitalize() for word in target_type.split("_")), - "".join(word.capitalize() for word in target_type.split("_")) + "Target", - ] - - for pattern in class_name_patterns: - cls = getattr(module, pattern, None) - if cls is not None: - return cast(type, cls) + Returns: + The target class. - raise ValueError(f"Target type '{target_type}' not found in pyrit.prompt_target") + Raises: + ValueError: If the target type is not found. + """ + cls = _TARGET_CLASS_REGISTRY.get(target_type) + if cls is None: + raise ValueError( + f"Target type '{target_type}' not found. Available types: {sorted(_TARGET_CLASS_REGISTRY.keys())}" + ) + return cls def _build_instance_from_object(self, target_id: str, target_obj: Any) -> TargetInstance: """ @@ -79,22 +100,16 @@ def _build_instance_from_object(self, target_id: str, target_obj: Any) -> Target params=filtered_params, ) - async def list_targets( - self, - source: Optional[Literal["initializer", "user"]] = None, - ) -> TargetListResponse: + async def list_targets(self) -> TargetListResponse: """ List all target instances. Returns: TargetListResponse containing all registered targets. """ - # source filter is ignored for now - all come from registry - items: List[TargetInstance] = [] - for name in self._registry.get_names(): - obj = self._registry.get_instance_by_name(name) - if obj: - items.append(self._build_instance_from_object(name, obj)) + items = [ + self._build_instance_from_object(name, obj) for name, obj in self._registry.get_all_instances().items() + ] return TargetListResponse(items=items) async def get_target(self, target_id: str) -> Optional[TargetInstance]: @@ -118,19 +133,25 @@ def get_target_object(self, target_id: str) -> Optional[Any]: """ return self._registry.get_instance_by_name(target_id) - async def create_target( - self, - request: CreateTargetRequest, - ) -> CreateTargetResponse: + async def create_target(self, request: CreateTargetRequest) -> CreateTargetResponse: """ - Create a new target instance. + Create a new target instance from API request. + + Instantiates the target with the given type and params, + then registers it in the registry. + + Args: + request: The create target request with type and params. Returns: CreateTargetResponse with the new target's details. + + Raises: + ValueError: If the target type is not found. """ target_id = str(uuid.uuid4()) - # Create and register the target object + # Instantiate from request params and register target_class = self._get_target_class(request.type) target_obj = target_class(**request.params) self._registry.register_instance(target_obj, name=target_id) diff --git a/pyrit/registry/instance_registries/base_instance_registry.py b/pyrit/registry/instance_registries/base_instance_registry.py index 18fc320a25..0897e2efbd 100644 --- a/pyrit/registry/instance_registries/base_instance_registry.py +++ b/pyrit/registry/instance_registries/base_instance_registry.py @@ -111,6 +111,15 @@ def get_names(self) -> List[str]: """ return sorted(self._registry_items.keys()) + def get_all_instances(self) -> Dict[str, T]: + """ + Get all registered instances as a name -> instance mapping. + + Returns: + Dict mapping registry names to their instances. + """ + return dict(self._registry_items) + def list_metadata( self, *, diff --git a/tests/unit/backend/test_api_routes.py b/tests/unit/backend/test_api_routes.py index c3acdb2e3f..420a3be142 100644 --- a/tests/unit/backend/test_api_routes.py +++ b/tests/unit/backend/test_api_routes.py @@ -438,22 +438,8 @@ def test_list_targets_returns_empty_list(self, client: TestClient) -> None: data = response.json() assert data["items"] == [] - def test_list_targets_with_source_filter(self, client: TestClient) -> None: - """Test that list targets accepts source filter.""" - with patch("pyrit.backend.routes.targets.get_target_service") as mock_get_service: - mock_service = MagicMock() - mock_service.list_targets = AsyncMock(return_value=TargetListResponse(items=[])) - mock_get_service.return_value = mock_service - - response = client.get("/api/targets", params={"source": "user"}) - - assert response.status_code == status.HTTP_200_OK - mock_service.list_targets.assert_called_once_with(source="user") - def test_create_target_success(self, client: TestClient) -> None: """Test successful target creation.""" - now = datetime.now(timezone.utc) - with patch("pyrit.backend.routes.targets.get_target_service") as mock_get_service: mock_service = MagicMock() mock_service.create_target = AsyncMock( @@ -489,6 +475,20 @@ def test_create_target_invalid_type(self, client: TestClient) -> None: assert response.status_code == status.HTTP_400_BAD_REQUEST + def test_create_target_internal_error(self, client: TestClient) -> None: + """Test target creation with internal error returns 500.""" + with patch("pyrit.backend.routes.targets.get_target_service") as mock_get_service: + mock_service = MagicMock() + mock_service.create_target = AsyncMock(side_effect=RuntimeError("Unexpected error")) + mock_get_service.return_value = mock_service + + response = client.post( + "/api/targets", + json={"type": "TextTarget", "params": {}}, + ) + + assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + def test_get_target_success(self, client: TestClient) -> None: """Test getting a target by ID.""" now = datetime.now(timezone.utc) @@ -522,20 +522,6 @@ def test_get_target_not_found(self, client: TestClient) -> None: assert response.status_code == status.HTTP_404_NOT_FOUND - def test_create_target_internal_error(self, client: TestClient) -> None: - """Test target creation with internal error returns 500.""" - with patch("pyrit.backend.routes.targets.get_target_service") as mock_get_service: - mock_service = MagicMock() - mock_service.create_target = AsyncMock(side_effect=RuntimeError("Unexpected internal error")) - mock_get_service.return_value = mock_service - - response = client.post( - "/api/targets", - json={"type": "TextTarget", "params": {}}, - ) - - assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR - # ============================================================================ # Converter Routes Tests @@ -560,8 +546,6 @@ def test_list_converters(self, client: TestClient) -> None: def test_create_converter_success(self, client: TestClient) -> None: """Test successful converter instance creation.""" - now = datetime.now(timezone.utc) - with patch("pyrit.backend.routes.converters.get_converter_service") as mock_get_service: mock_service = MagicMock() mock_service.create_converter = AsyncMock( @@ -597,6 +581,20 @@ def test_create_converter_invalid_type(self, client: TestClient) -> None: assert response.status_code == status.HTTP_400_BAD_REQUEST + def test_create_converter_internal_error(self, client: TestClient) -> None: + """Test converter creation with internal error returns 500.""" + with patch("pyrit.backend.routes.converters.get_converter_service") as mock_get_service: + mock_service = MagicMock() + mock_service.create_converter = AsyncMock(side_effect=RuntimeError("Unexpected error")) + mock_get_service.return_value = mock_service + + response = client.post( + "/api/converters", + json={"type": "Base64Converter", "params": {}}, + ) + + assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + def test_get_converter_success(self, client: TestClient) -> None: """Test getting a converter instance by ID.""" now = datetime.now(timezone.utc) @@ -668,20 +666,6 @@ def test_preview_conversion_success(self, client: TestClient) -> None: assert data["converted_value"] == "dGVzdA==" assert len(data["steps"]) == 1 - def test_create_converter_internal_error(self, client: TestClient) -> None: - """Test converter creation with internal error returns 500.""" - with patch("pyrit.backend.routes.converters.get_converter_service") as mock_get_service: - mock_service = MagicMock() - mock_service.create_converter = AsyncMock(side_effect=RuntimeError("Unexpected internal error")) - mock_get_service.return_value = mock_service - - response = client.post( - "/api/converters", - json={"type": "Base64Converter", "params": {}}, - ) - - assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR - def test_preview_conversion_bad_request(self, client: TestClient) -> None: """Test preview conversion with invalid converter ID returns 400.""" with patch("pyrit.backend.routes.converters.get_converter_service") as mock_get_service: diff --git a/tests/unit/backend/test_common_models.py b/tests/unit/backend/test_common_models.py index 803e58db76..00ae1b794f 100644 --- a/tests/unit/backend/test_common_models.py +++ b/tests/unit/backend/test_common_models.py @@ -7,8 +7,6 @@ from pyrit.backend.models.common import ( FieldError, - IdentifierDict, - PaginatedResponse, PaginationInfo, ProblemDetail, filter_sensitive_fields, @@ -42,22 +40,6 @@ def test_pagination_info_full(self) -> None: assert info.prev_cursor == "prev" -class TestPaginatedResponse: - """Tests for PaginatedResponse model.""" - - def test_paginated_response_with_strings(self) -> None: - """Test creating a paginated response with string items.""" - pagination = PaginationInfo(limit=10, has_more=False) - response = PaginatedResponse[str]( - items=["a", "b", "c"], - pagination=pagination, - ) - - assert len(response.items) == 3 - assert response.items[0] == "a" - assert response.pagination.limit == 10 - - class TestFieldError: """Tests for FieldError model.""" @@ -123,17 +105,6 @@ def test_problem_detail_with_errors(self) -> None: assert problem.instance == "/api/v1/test" -class TestIdentifierDict: - """Tests for IdentifierDict model.""" - - def test_identifier_dict_creation(self) -> None: - """Test creating an IdentifierDict.""" - identifier = IdentifierDict(__type__="TestClass", __module__="pyrit.test") - - assert identifier.type_ == "TestClass" - assert identifier.module_ == "pyrit.test" - - class TestFilterSensitiveFields: """Tests for filter_sensitive_fields function.""" diff --git a/tests/unit/backend/test_converter_service.py b/tests/unit/backend/test_converter_service.py index c17e15bf26..2b3e0576dd 100644 --- a/tests/unit/backend/test_converter_service.py +++ b/tests/unit/backend/test_converter_service.py @@ -5,15 +5,24 @@ Tests for backend converter service. """ -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock import pytest +import pyrit.backend.services.converter_service as converter_service_module +from pyrit import prompt_converter from pyrit.backend.models.converters import ( ConverterPreviewRequest, CreateConverterRequest, ) -from pyrit.backend.services.converter_service import ConverterService +from pyrit.backend.services.converter_service import ConverterService, get_converter_service +from pyrit.prompt_converter import ( + Base64Converter, + CaesarConverter, + RepeatTokenConverter, + SuffixAppendConverter, +) +from pyrit.prompt_converter.prompt_converter import get_converter_modalities from pyrit.registry.instance_registries import ConverterRegistry @@ -39,11 +48,19 @@ async def test_list_converters_returns_empty_when_no_converters(self) -> None: @pytest.mark.asyncio async def test_list_converters_returns_converters_from_registry(self) -> None: - """Test that list_converters returns converters from registry.""" + """Test that list_converters returns converters from registry with full params.""" service = ConverterService() mock_converter = MagicMock() mock_converter.__class__.__name__ = "MockConverter" + mock_identifier = MagicMock() + mock_identifier.to_dict.return_value = { + "class_name": "MockConverter", + "converter_specific_params": {"param1": "value1", "param2": 42}, + "supported_input_types": ["text"], + "supported_output_types": ["text"], + } + mock_converter.get_identifier.return_value = mock_identifier service._registry.register_instance(mock_converter, name="conv-1") result = await service.list_converters() @@ -51,6 +68,11 @@ async def test_list_converters_returns_converters_from_registry(self) -> None: assert len(result.items) == 1 assert result.items[0].converter_id == "conv-1" assert result.items[0].type == "MockConverter" + # Verify params contains the full identifier dict + assert result.items[0].params["class_name"] == "MockConverter" + assert result.items[0].params["converter_specific_params"] == {"param1": "value1", "param2": 42} + assert result.items[0].params["supported_input_types"] == ["text"] + assert result.items[0].params["supported_output_types"] == ["text"] class TestGetConverter: @@ -72,6 +94,12 @@ async def test_get_converter_returns_converter_from_registry(self) -> None: mock_converter = MagicMock() mock_converter.__class__.__name__ = "MockConverter" + mock_identifier = MagicMock() + mock_identifier.to_dict.return_value = { + "class_name": "MockConverter", + "converter_specific_params": {"param1": "value1"}, + } + mock_converter.get_identifier.return_value = mock_identifier service._registry.register_instance(mock_converter, name="conv-1") result = await service.get_converter("conv-1") @@ -103,34 +131,6 @@ def test_get_converter_object_returns_object_from_registry(self) -> None: assert result is mock_converter -class TestGetConverterClass: - """Tests for ConverterService._get_converter_class method.""" - - def test_get_converter_class_raises_for_invalid_type(self) -> None: - """Test that _get_converter_class raises ValueError for invalid type.""" - service = ConverterService() - - with pytest.raises(ValueError, match="not found"): - service._get_converter_class("NonExistentConverter") - - def test_get_converter_class_finds_base64_converter(self) -> None: - """Test that _get_converter_class finds Base64Converter.""" - service = ConverterService() - - result = service._get_converter_class("Base64Converter") - - assert result is not None - assert "Base64" in result.__name__ - - def test_get_converter_class_handles_snake_case(self) -> None: - """Test that _get_converter_class handles snake_case names.""" - service = ConverterService() - - result = service._get_converter_class("base64") - - assert result is not None - - class TestCreateConverter: """Tests for ConverterService.create_converter method.""" @@ -181,6 +181,51 @@ async def test_create_converter_registers_in_registry(self) -> None: assert converter_obj is not None +class TestResolveConverterParams: + """Tests for ConverterService._resolve_converter_params method.""" + + def test_resolve_converter_params_returns_params_unchanged_when_no_converter_ref(self) -> None: + """Test that params without converter reference are returned unchanged.""" + service = ConverterService() + params = {"key": "value", "number": 42} + + result = service._resolve_converter_params(params) + + assert result == params + + def test_resolve_converter_params_resolves_converter_id_reference(self) -> None: + """Test that converter_id reference is resolved to actual object.""" + service = ConverterService() + + # Register a mock converter + mock_converter = MagicMock() + service._registry.register_instance(mock_converter, name="inner-conv") + + params = {"converter": {"converter_id": "inner-conv"}} + + result = service._resolve_converter_params(params) + + assert result["converter"] is mock_converter + + def test_resolve_converter_params_raises_for_nonexistent_reference(self) -> None: + """Test that referencing a non-existent converter raises ValueError.""" + service = ConverterService() + + params = {"converter": {"converter_id": "nonexistent"}} + + with pytest.raises(ValueError, match="not found"): + service._resolve_converter_params(params) + + def test_resolve_converter_params_ignores_non_dict_converter(self) -> None: + """Test that non-dict converter values are not modified.""" + service = ConverterService() + params = {"converter": "some_string_value"} + + result = service._resolve_converter_params(params) + + assert result == params + + class TestPreviewConversion: """Tests for ConverterService.preview_conversion method.""" @@ -283,80 +328,151 @@ def test_get_converter_objects_for_ids_returns_objects(self) -> None: assert result == [mock1, mock2] -class TestConverterWithReferencedConverter: - """Tests for creating converters that reference other converters by ID.""" +class TestConverterServiceSingleton: + """Tests for get_converter_service singleton function.""" - @pytest.mark.asyncio - async def test_create_converter_with_referenced_converter(self) -> None: - """Test creating a converter that references another converter by ID.""" - service = ConverterService() + def test_get_converter_service_returns_converter_service(self) -> None: + """Test that get_converter_service returns a ConverterService instance.""" + converter_service_module._converter_service = None - mock_inner_class = MagicMock() - mock_inner_instance = MagicMock() - mock_inner_class.return_value = mock_inner_instance + service = get_converter_service() + assert isinstance(service, ConverterService) - mock_outer_class = MagicMock() - mock_outer_instance = MagicMock() - mock_outer_class.return_value = mock_outer_instance + def test_get_converter_service_returns_same_instance(self) -> None: + """Test that get_converter_service returns the same instance.""" + converter_service_module._converter_service = None - def mock_get_class(converter_type: str) -> type: - if converter_type == "OuterConverter": - return mock_outer_class - elif converter_type == "InnerConverter": - return mock_inner_class - raise ValueError(f"Unknown type: {converter_type}") + service1 = get_converter_service() + service2 = get_converter_service() + assert service1 is service2 - with patch.object(service, "_get_converter_class", side_effect=mock_get_class): - inner_result = await service.create_converter(CreateConverterRequest(type="InnerConverter", params={})) - inner_id = inner_result.converter_id - await service.create_converter( - CreateConverterRequest( - type="OuterConverter", - params={"converter": {"converter_id": inner_id}}, - ) - ) +# ============================================================================ +# Real Converter Integration Tests +# ============================================================================ - mock_outer_class.assert_called() - call_kwargs = mock_outer_class.call_args[1] - assert call_kwargs.get("converter") is mock_inner_instance - @pytest.mark.asyncio - async def test_create_converter_with_invalid_reference_raises(self) -> None: - """Test that referencing a non-existent converter raises ValueError.""" +def _get_all_converter_names() -> list[str]: + """ + Dynamically collect all converter class names from the codebase. + + Uses get_converter_modalities() which reads from prompt_converter.__all__ + and filters to only actual PromptConverter subclasses. + """ + return [name for name, _, _ in get_converter_modalities()] + + +def _try_instantiate_converter(converter_name: str): + """ + Try to instantiate a converter with no arguments. + + Returns: + Tuple of (converter_instance, error_message). + If successful, error_message is None. + If failed, converter_instance is None and error_message explains why. + """ + converter_cls = getattr(prompt_converter, converter_name, None) + if converter_cls is None: + return None, f"Converter {converter_name} not found in prompt_converter module" + + try: + instance = converter_cls() + return instance, None + except Exception as e: + return None, f"Could not instantiate {converter_name} with no args: {e}" + + +# Get all converter names dynamically +ALL_CONVERTERS = _get_all_converter_names() + + +class TestBuildInstanceFromObjectWithRealConverters: + """ + Integration tests that verify _build_instance_from_object works with real converters. + + These tests ensure the identifier extraction works correctly across all converter types. + Uses dynamic discovery to test ALL converters in the codebase. + """ + + @pytest.mark.parametrize("converter_name", ALL_CONVERTERS) + def test_build_instance_from_converter(self, converter_name: str) -> None: + """ + Test that _build_instance_from_object works with each converter. + + For converters that can be instantiated with no arguments, verifies: + - converter_id is set correctly + - type matches the class name + - params contains class_name from the identifier + + For converters requiring arguments, the test is skipped (since we can't + know the required parameters without external configuration). + """ + # Try to instantiate the converter + converter_instance, error = _try_instantiate_converter(converter_name) + + if error: + pytest.skip(error) + + # Build the instance using the service method service = ConverterService() + result = service._build_instance_from_object("test-id", converter_instance) - mock_class = MagicMock() - with patch.object(service, "_get_converter_class", return_value=mock_class): - with pytest.raises(ValueError, match="not found"): - await service.create_converter( - CreateConverterRequest( - type="OuterConverter", - params={"converter": {"converter_id": "nonexistent"}}, - ) - ) + # Verify the result + assert result.converter_id == "test-id" + assert result.type == converter_name + assert isinstance(result.params, dict) + # The params should contain at least class_name from the identifier + assert "class_name" in result.params + assert result.params["class_name"] == converter_name -class TestConverterServiceSingleton: - """Tests for get_converter_service singleton function.""" +class TestConverterParamsExtraction: + """ + Tests that verify converter_specific_params are correctly extracted. - def test_get_converter_service_returns_converter_service(self) -> None: - """Test that get_converter_service returns a ConverterService instance.""" - import pyrit.backend.services.converter_service as module - from pyrit.backend.services.converter_service import get_converter_service + Uses converters with known parameters to verify the params are properly + captured from the identifier. + """ - module._converter_service = None + def test_caesar_converter_params(self) -> None: + """Test that CaesarConverter params are extracted correctly.""" + converter = CaesarConverter(caesar_offset=13) + service = ConverterService() + result = service._build_instance_from_object("test-id", converter) - service = get_converter_service() - assert isinstance(service, ConverterService) + assert result.type == "CaesarConverter" + converter_specific = result.params.get("converter_specific_params", {}) + assert converter_specific.get("caesar_offset") == 13 - def test_get_converter_service_returns_same_instance(self) -> None: - """Test that get_converter_service returns the same instance.""" - import pyrit.backend.services.converter_service as module - from pyrit.backend.services.converter_service import get_converter_service + def test_suffix_append_converter_params(self) -> None: + """Test that SuffixAppendConverter params are extracted correctly.""" + converter = SuffixAppendConverter(suffix="test suffix") + service = ConverterService() + result = service._build_instance_from_object("test-id", converter) - module._converter_service = None + assert result.type == "SuffixAppendConverter" + converter_specific = result.params.get("converter_specific_params", {}) + assert converter_specific.get("suffix") == "test suffix" - service1 = get_converter_service() - service2 = get_converter_service() - assert service1 is service2 + def test_repeat_token_converter_params(self) -> None: + """Test that RepeatTokenConverter params are extracted correctly.""" + converter = RepeatTokenConverter(token_to_repeat="x", times_to_repeat=5) + service = ConverterService() + result = service._build_instance_from_object("test-id", converter) + + assert result.type == "RepeatTokenConverter" + converter_specific = result.params.get("converter_specific_params", {}) + assert converter_specific.get("token_to_repeat") == "x" + assert converter_specific.get("times_to_repeat") == 5 + + def test_base64_converter_default_params(self) -> None: + """Test that Base64Converter default params are captured.""" + converter = Base64Converter() + service = ConverterService() + result = service._build_instance_from_object("test-id", converter) + + assert result.type == "Base64Converter" + # Verify params dict is populated from identifier + assert "class_name" in result.params + assert "supported_input_types" in result.params + assert "supported_output_types" in result.params diff --git a/tests/unit/backend/test_target_service.py b/tests/unit/backend/test_target_service.py index 5b6b023b56..6ccfab608a 100644 --- a/tests/unit/backend/test_target_service.py +++ b/tests/unit/backend/test_target_service.py @@ -5,7 +5,7 @@ Tests for backend target service. """ -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest @@ -101,26 +101,6 @@ def test_get_target_object_returns_object_from_registry(self) -> None: assert result is mock_target -class TestGetTargetClass: - """Tests for TargetService._get_target_class method.""" - - def test_get_target_class_raises_for_invalid_type(self) -> None: - """Test that _get_target_class raises ValueError for invalid type.""" - service = TargetService() - - with pytest.raises(ValueError, match="not found"): - service._get_target_class("NonExistentTarget") - - def test_get_target_class_finds_text_target(self) -> None: - """Test that _get_target_class finds TextTarget.""" - service = TargetService() - - result = service._get_target_class("TextTarget") - - assert result is not None - assert "TextTarget" in result.__name__ - - class TestCreateTarget: """Tests for TargetService.create_target method.""" @@ -170,32 +150,6 @@ async def test_create_target_registers_in_registry(self) -> None: target_obj = service.get_target_object(result.target_id) assert target_obj is not None - @pytest.mark.asyncio - async def test_create_target_filters_sensitive_params(self) -> None: - """Test that create_target filters sensitive parameters.""" - service = TargetService() - - mock_target_class = MagicMock() - mock_target_instance = MagicMock() - mock_target_instance.get_identifier.return_value = { - "type": "MockTarget", - "api_key": "secret-key", - "endpoint": "https://api.example.com", - } - mock_target_class.return_value = mock_target_instance - - with patch.object(service, "_get_target_class", return_value=mock_target_class): - request = CreateTargetRequest( - type="MockTarget", - params={}, - ) - - result = await service.create_target(request) - - # api_key should be filtered out - assert "api_key" not in result.params - assert result.params.get("endpoint") == "https://api.example.com" - class TestTargetServiceSingleton: """Tests for get_target_service singleton function.""" From 9b539768d7acc0a8f456601f239a453369ce1de7 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Thu, 5 Feb 2026 13:29:37 -0800 Subject: [PATCH 09/35] update with target identifier --- pyrit/backend/services/target_service.py | 10 +++-- .../class_registries/initializer_registry.py | 4 +- .../class_registries/scenario_registry.py | 2 +- .../base_instance_registry.py | 2 +- .../instance_registries/converter_registry.py | 8 ++-- .../instance_registries/scorer_registry.py | 14 ++++--- .../instance_registries/target_registry.py | 38 +++++-------------- tests/unit/backend/test_target_service.py | 4 +- 8 files changed, 34 insertions(+), 48 deletions(-) diff --git a/pyrit/backend/services/target_service.py b/pyrit/backend/services/target_service.py index 16e3bc4c2a..23f94212f7 100644 --- a/pyrit/backend/services/target_service.py +++ b/pyrit/backend/services/target_service.py @@ -90,8 +90,9 @@ def _build_instance_from_object(self, target_id: str, target_obj: Any) -> Target TargetInstance with metadata derived from the object. """ identifier = target_obj.get_identifier() if hasattr(target_obj, "get_identifier") else {} - target_type = identifier.get("__type__", target_obj.__class__.__name__) - filtered_params = filter_sensitive_fields(identifier) + identifier_dict = identifier.to_dict() if hasattr(identifier, "to_dict") else identifier + target_type = identifier_dict.get("__type__", target_obj.__class__.__name__) + filtered_params = filter_sensitive_fields(identifier_dict) return TargetInstance( target_id=target_id, @@ -157,8 +158,9 @@ async def create_target(self, request: CreateTargetRequest) -> CreateTargetRespo self._registry.register_instance(target_obj, name=target_id) # Build response from the object's identifier - identifier = target_obj.get_identifier() if hasattr(target_obj, "get_identifier") else {} - filtered_params = filter_sensitive_fields(identifier) + identifier = target_obj.get_identifier() + identifier_dict = identifier.to_dict() if hasattr(identifier, "to_dict") else identifier + filtered_params = filter_sensitive_fields(identifier_dict) return CreateTargetResponse( target_id=target_id, diff --git a/pyrit/registry/class_registries/initializer_registry.py b/pyrit/registry/class_registries/initializer_registry.py index a334e87e72..657239b639 100644 --- a/pyrit/registry/class_registries/initializer_registry.py +++ b/pyrit/registry/class_registries/initializer_registry.py @@ -58,7 +58,7 @@ class InitializerRegistry(BaseClassRegistry["PyRITInitializer", InitializerMetad """ @classmethod - def get_registry_singleton(cls) -> "InitializerRegistry": + def get_registry_singleton(cls) -> InitializerRegistry: """ Get the singleton instance of the InitializerRegistry. @@ -164,7 +164,7 @@ def _register_initializer( *, short_name: str, file_path: Path, - initializer_class: "type[PyRITInitializer]", + initializer_class: type[PyRITInitializer], ) -> None: """ Register an initializer class. diff --git a/pyrit/registry/class_registries/scenario_registry.py b/pyrit/registry/class_registries/scenario_registry.py index f95ad93986..053b3b0968 100644 --- a/pyrit/registry/class_registries/scenario_registry.py +++ b/pyrit/registry/class_registries/scenario_registry.py @@ -59,7 +59,7 @@ class ScenarioRegistry(BaseClassRegistry["Scenario", ScenarioMetadata]): """ @classmethod - def get_registry_singleton(cls) -> "ScenarioRegistry": + def get_registry_singleton(cls) -> ScenarioRegistry: """ Get the singleton instance of the ScenarioRegistry. diff --git a/pyrit/registry/instance_registries/base_instance_registry.py b/pyrit/registry/instance_registries/base_instance_registry.py index 0897e2efbd..78e00a1443 100644 --- a/pyrit/registry/instance_registries/base_instance_registry.py +++ b/pyrit/registry/instance_registries/base_instance_registry.py @@ -45,7 +45,7 @@ class BaseInstanceRegistry(ABC, RegistryProtocol[MetadataT], Generic[T, Metadata _instances: Dict[type, "BaseInstanceRegistry[Any, Any]"] = {} @classmethod - def get_registry_singleton(cls) -> "BaseInstanceRegistry[T, MetadataT]": + def get_registry_singleton(cls) -> BaseInstanceRegistry[T, MetadataT]: """ Get the singleton instance of this registry. diff --git a/pyrit/registry/instance_registries/converter_registry.py b/pyrit/registry/instance_registries/converter_registry.py index 3439b88c9e..509f2fb68b 100644 --- a/pyrit/registry/instance_registries/converter_registry.py +++ b/pyrit/registry/instance_registries/converter_registry.py @@ -48,7 +48,7 @@ class ConverterRegistry(BaseInstanceRegistry["PromptConverter", ConverterIdentif """ @classmethod - def get_registry_singleton(cls) -> "ConverterRegistry": + def get_registry_singleton(cls) -> ConverterRegistry: """ Get the singleton instance of the ConverterRegistry. @@ -59,7 +59,7 @@ def get_registry_singleton(cls) -> "ConverterRegistry": def register_instance( self, - converter: "PromptConverter", + converter: PromptConverter, *, name: Optional[str] = None, ) -> None: @@ -77,7 +77,7 @@ def register_instance( self.register(converter, name=name) logger.debug(f"Registered converter instance: {name} ({converter.__class__.__name__})") - def get_instance_by_name(self, name: str) -> Optional["PromptConverter"]: + def get_instance_by_name(self, name: str) -> Optional[PromptConverter]: """ Get a registered converter instance by name. @@ -89,7 +89,7 @@ def get_instance_by_name(self, name: str) -> Optional["PromptConverter"]: """ return self.get(name) - def _build_metadata(self, name: str, instance: "PromptConverter") -> ConverterIdentifier: + def _build_metadata(self, name: str, instance: PromptConverter) -> ConverterIdentifier: """ Build metadata for a converter instance. diff --git a/pyrit/registry/instance_registries/scorer_registry.py b/pyrit/registry/instance_registries/scorer_registry.py index 9b5e5f59f4..1f4f767fb5 100644 --- a/pyrit/registry/instance_registries/scorer_registry.py +++ b/pyrit/registry/instance_registries/scorer_registry.py @@ -13,6 +13,7 @@ from typing import TYPE_CHECKING, Optional from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers.class_name_utils import class_name_to_snake_case from pyrit.registry.instance_registries.base_instance_registry import ( BaseInstanceRegistry, ) @@ -36,7 +37,7 @@ class ScorerRegistry(BaseInstanceRegistry["Scorer", ScorerIdentifier]): """ @classmethod - def get_registry_singleton(cls) -> "ScorerRegistry": + def get_registry_singleton(cls) -> ScorerRegistry: """ Get the singleton instance of the ScorerRegistry. @@ -47,7 +48,7 @@ def get_registry_singleton(cls) -> "ScorerRegistry": def register_instance( self, - scorer: "Scorer", + scorer: Scorer, *, name: Optional[str] = None, ) -> None: @@ -64,12 +65,15 @@ def register_instance( (e.g., SelfAskRefusalScorer -> self_ask_refusal_abc123). """ if name is None: - name = scorer.get_identifier().unique_name + base_name = class_name_to_snake_case(scorer.__class__.__name__, suffix="Scorer") + # Append identifier hash if available for uniqueness + identifier_hash = scorer.get_identifier().hash[:8] + name = f"{base_name}_{identifier_hash}" self.register(scorer, name=name) logger.debug(f"Registered scorer instance: {name} ({scorer.__class__.__name__})") - def get_instance_by_name(self, name: str) -> Optional["Scorer"]: + def get_instance_by_name(self, name: str) -> Optional[Scorer]: """ Get a registered scorer instance by name. @@ -83,7 +87,7 @@ def get_instance_by_name(self, name: str) -> Optional["Scorer"]: """ return self.get(name) - def _build_metadata(self, name: str, instance: "Scorer") -> ScorerIdentifier: + def _build_metadata(self, name: str, instance: Scorer) -> ScorerIdentifier: """ Build metadata for a scorer instance. diff --git a/pyrit/registry/instance_registries/target_registry.py b/pyrit/registry/instance_registries/target_registry.py index c430750f41..78d04c7e94 100644 --- a/pyrit/registry/instance_registries/target_registry.py +++ b/pyrit/registry/instance_registries/target_registry.py @@ -5,18 +5,14 @@ Target registry for managing PyRIT target instances. Targets are registered explicitly via initializers as pre-configured instances. - -NOTE: This is a placeholder implementation. PR #1320 will add the full implementation. """ from __future__ import annotations import logging -from dataclasses import dataclass from typing import TYPE_CHECKING, Optional -from pyrit.identifiers import Identifier -from pyrit.identifiers.class_name_utils import class_name_to_snake_case +from pyrit.identifiers import TargetIdentifier from pyrit.registry.instance_registries.base_instance_registry import ( BaseInstanceRegistry, ) @@ -27,15 +23,6 @@ logger = logging.getLogger(__name__) -# Placeholder identifier type until proper TargetIdentifier is defined -# TODO: Replace with TargetIdentifier when available -@dataclass(frozen=True) -class TargetIdentifier(Identifier): - """Temporary identifier type for targets.""" - - pass - - class TargetRegistry(BaseInstanceRegistry["PromptTarget", TargetIdentifier]): """ Registry for managing available target instances. @@ -43,12 +30,10 @@ class TargetRegistry(BaseInstanceRegistry["PromptTarget", TargetIdentifier]): This registry stores pre-configured PromptTarget instances (not classes). Targets are registered explicitly via initializers after being instantiated with their required parameters. - - NOTE: This is a placeholder. PR #1320 will add the full implementation. """ @classmethod - def get_registry_singleton(cls) -> "TargetRegistry": + def get_registry_singleton(cls) -> TargetRegistry: """ Get the singleton instance of the TargetRegistry. @@ -59,7 +44,7 @@ def get_registry_singleton(cls) -> "TargetRegistry": def register_instance( self, - target: "PromptTarget", + target: PromptTarget, *, name: Optional[str] = None, ) -> None: @@ -69,15 +54,15 @@ def register_instance( Args: target: The pre-configured target instance (not a class). name: Optional custom registry name. If not provided, - derived from class name (e.g., AzureOpenAIGPT4OChatTarget -> azure_openai_gpt4o_chat). + uses the target's identifier unique_name. """ if name is None: - name = class_name_to_snake_case(target.__class__.__name__, suffix="Target") + name = target.get_identifier().unique_name self.register(target, name=name) logger.debug(f"Registered target instance: {name} ({target.__class__.__name__})") - def get_instance_by_name(self, name: str) -> Optional["PromptTarget"]: + def get_instance_by_name(self, name: str) -> Optional[PromptTarget]: """ Get a registered target instance by name. @@ -89,7 +74,7 @@ def get_instance_by_name(self, name: str) -> Optional["PromptTarget"]: """ return self.get(name) - def _build_metadata(self, name: str, instance: "PromptTarget") -> TargetIdentifier: + def _build_metadata(self, name: str, instance: PromptTarget) -> TargetIdentifier: """ Build metadata for a target instance. @@ -98,11 +83,6 @@ def _build_metadata(self, name: str, instance: "PromptTarget") -> TargetIdentifi instance: The target instance. Returns: - TargetIdentifier with basic info about the target. + TargetIdentifier from the target's get_identifier() method. """ - return TargetIdentifier( - class_name=instance.__class__.__name__, - class_module=instance.__class__.__module__, - class_description=f"Target: {name}", - identifier_type="instance", - ) + return instance.get_identifier() diff --git a/tests/unit/backend/test_target_service.py b/tests/unit/backend/test_target_service.py index 6ccfab608a..0b213ed728 100644 --- a/tests/unit/backend/test_target_service.py +++ b/tests/unit/backend/test_target_service.py @@ -118,7 +118,7 @@ async def test_create_target_raises_for_invalid_type(self) -> None: await service.create_target(request) @pytest.mark.asyncio - async def test_create_target_success(self) -> None: + async def test_create_target_success(self, sqlite_instance) -> None: """Test successful target creation.""" service = TargetService() @@ -135,7 +135,7 @@ async def test_create_target_success(self) -> None: assert result.display_name == "My Text Target" @pytest.mark.asyncio - async def test_create_target_registers_in_registry(self) -> None: + async def test_create_target_registers_in_registry(self, sqlite_instance) -> None: """Test that create_target registers object in registry.""" service = TargetService() From cf018b9d18f146c4621b240d0361a2ad7856518b Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Thu, 5 Feb 2026 14:21:59 -0800 Subject: [PATCH 10/35] fix test, no default initializer --- frontend/dev.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/frontend/dev.py b/frontend/dev.py index b2f8043bf3..772d4a2996 100644 --- a/frontend/dev.py +++ b/frontend/dev.py @@ -82,7 +82,7 @@ def start_backend(initializers: list[str] | None = None): Args: initializers: Optional list of initializer names to run at startup. - Defaults to ["airt"] to load targets from environment variables. + If not specified, no initializers are run. """ print("🚀 Starting backend on port 8000...") @@ -93,9 +93,9 @@ def start_backend(initializers: list[str] | None = None): env = os.environ.copy() env["PYRIT_DEV_MODE"] = "true" - # Default to airt initializer if not specified + # Default to no initializers if initializers is None: - initializers = ["airt"] + initializers = [] # Build command using pyrit_backend CLI cmd = [ From edef8b2224d938f4871d8e1b655ac8d2a3d5077f Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Sat, 7 Feb 2026 12:26:34 -0800 Subject: [PATCH 11/35] Adding attack identifier --- pyrit/analytics/result_analysis.py | 2 +- pyrit/exceptions/exception_context.py | 10 +-- .../attack/component/conversation_manager.py | 6 +- pyrit/executor/attack/core/attack_strategy.py | 51 +++++++++++++++- .../attack/multi_turn/tree_of_attacks.py | 5 +- .../attack/printer/console_printer.py | 2 +- .../attack/printer/markdown_printer.py | 2 +- pyrit/executor/benchmark/fairness_bias.py | 6 +- pyrit/executor/core/strategy.py | 13 ---- pyrit/executor/promptgen/anecdoctor.py | 21 ++++++- pyrit/executor/promptgen/fuzzer/fuzzer.py | 21 ++++++- pyrit/executor/workflow/xpia.py | 23 ++++++- pyrit/identifiers/__init__.py | 5 +- pyrit/identifiers/attack_identifier.py | 61 +++++++++++++++++++ pyrit/identifiers/identifiable.py | 18 ------ pyrit/identifiers/identifier.py | 8 +-- pyrit/memory/azure_sql_memory.py | 2 +- pyrit/memory/memory_models.py | 8 +-- pyrit/memory/sqlite_memory.py | 2 +- pyrit/models/attack_result.py | 5 +- pyrit/models/message_piece.py | 14 +++-- pyrit/models/seeds/seed.py | 32 ++++++++++ pyrit/prompt_converter/prompt_converter.py | 1 - pyrit/prompt_normalizer/prompt_normalizer.py | 13 ++-- .../common/prompt_chat_target.py | 5 +- pyrit/prompt_target/common/prompt_target.py | 1 - pyrit/score/float_scale/float_scale_scorer.py | 5 +- pyrit/score/scorer.py | 10 +-- tests/unit/mocks.py | 4 +- 29 files changed, 267 insertions(+), 89 deletions(-) create mode 100644 pyrit/identifiers/attack_identifier.py diff --git a/pyrit/analytics/result_analysis.py b/pyrit/analytics/result_analysis.py index 7db3c02e87..a6e260af39 100644 --- a/pyrit/analytics/result_analysis.py +++ b/pyrit/analytics/result_analysis.py @@ -62,7 +62,7 @@ def analyze_results(attack_results: list[AttackResult]) -> dict[str, AttackStats raise TypeError(f"Expected AttackResult, got {type(attack).__name__}: {attack!r}") outcome = attack.outcome - attack_type = attack.attack_identifier.get("type", "unknown") + attack_type = attack.attack_identifier.class_name if attack.attack_identifier else "unknown" if outcome == AttackOutcome.SUCCESS: overall_counts["successes"] += 1 diff --git a/pyrit/exceptions/exception_context.py b/pyrit/exceptions/exception_context.py index bce5b92382..8375763207 100644 --- a/pyrit/exceptions/exception_context.py +++ b/pyrit/exceptions/exception_context.py @@ -15,7 +15,7 @@ from enum import Enum from typing import Any, Dict, Optional, Union -from pyrit.identifiers import Identifier +from pyrit.identifiers import AttackIdentifier, Identifier class ComponentRole(Enum): @@ -61,8 +61,8 @@ class ExecutionContext: # The attack strategy class name (e.g., "PromptSendingAttack") attack_strategy_name: Optional[str] = None - # The identifier from the attack strategy's get_identifier() - attack_identifier: Optional[Dict[str, Any]] = None + # The identifier for the attack strategy + attack_identifier: Optional[Union["AttackIdentifier", Dict[str, Any]]] = None # The identifier from the component's get_identifier() (target, scorer, etc.) component_identifier: Optional[Dict[str, Any]] = None @@ -192,7 +192,7 @@ def execution_context( *, component_role: ComponentRole, attack_strategy_name: Optional[str] = None, - attack_identifier: Optional[Dict[str, Any]] = None, + attack_identifier: Optional[Union[AttackIdentifier, Dict[str, Any]]] = None, component_identifier: Optional[Union[Identifier, Dict[str, Any]]] = None, objective_target_conversation_id: Optional[str] = None, objective: Optional[str] = None, @@ -203,7 +203,7 @@ def execution_context( Args: component_role: The role of the component being executed. attack_strategy_name: The name of the attack strategy class. - attack_identifier: The identifier from attack.get_identifier(). + attack_identifier: The attack identifier. Can be an AttackIdentifier or a dict. component_identifier: The identifier from component.get_identifier(). Can be an Identifier object or a dict (legacy format). objective_target_conversation_id: The objective target conversation ID if available. diff --git a/pyrit/executor/attack/component/conversation_manager.py b/pyrit/executor/attack/component/conversation_manager.py index 75228c0f2d..6cffcf825a 100644 --- a/pyrit/executor/attack/component/conversation_manager.py +++ b/pyrit/executor/attack/component/conversation_manager.py @@ -10,7 +10,7 @@ from pyrit.executor.attack.component.prepended_conversation_config import ( PrependedConversationConfig, ) -from pyrit.identifiers import TargetIdentifier +from pyrit.identifiers import AttackIdentifier, TargetIdentifier from pyrit.memory import CentralMemory from pyrit.message_normalizer import ConversationContextNormalizer from pyrit.models import ChatMessageRole, Message, MessagePiece, Score @@ -54,7 +54,7 @@ def get_adversarial_chat_messages( prepended_conversation: List[Message], *, adversarial_chat_conversation_id: str, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, adversarial_chat_target_identifier: Union[TargetIdentifier, Dict[str, Any]], labels: Optional[Dict[str, str]] = None, ) -> List[Message]: @@ -183,7 +183,7 @@ class ConversationManager: def __init__( self, *, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, prompt_normalizer: Optional[PromptNormalizer] = None, ): """ diff --git a/pyrit/executor/attack/core/attack_strategy.py b/pyrit/executor/attack/core/attack_strategy.py index 98fc4c6abe..802a10848e 100644 --- a/pyrit/executor/attack/core/attack_strategy.py +++ b/pyrit/executor/attack/core/attack_strategy.py @@ -20,6 +20,7 @@ StrategyEventData, StrategyEventHandler, ) +from pyrit.identifiers import AttackIdentifier, Identifiable from pyrit.memory.central_memory import CentralMemory from pyrit.models import ( AttackOutcome, @@ -224,7 +225,7 @@ def _log_attack_outcome(self, result: AttackResult) -> None: self._logger.info(message) -class AttackStrategy(Strategy[AttackStrategyContextT, AttackStrategyResultT], ABC): +class AttackStrategy(Strategy[AttackStrategyContextT, AttackStrategyResultT], Identifiable[AttackIdentifier], ABC): """ Abstract base class for attack strategies. Defines the interface for executing attacks and handling results. @@ -258,6 +259,45 @@ def __init__( ) self._objective_target = objective_target self._params_type = params_type + self._request_converters: list = [] + self._response_converters: list = [] + + def _build_identifier(self) -> AttackIdentifier: + """ + Build the typed identifier for this attack strategy. + + Captures the objective target, optional scorer, and converter pipeline. + This is the *stable* strategy-level identifier that does not change + between calls to ``execute_async``. + + Returns: + AttackIdentifier: The constructed identifier. + """ + # Get target identifier + objective_target_identifier = self.get_objective_target().get_identifier() + + # Get scorer identifier if present + scorer_identifier = None + scoring_config = self.get_attack_scoring_config() + if scoring_config and scoring_config.objective_scorer: + scorer_identifier = scoring_config.objective_scorer.get_identifier() + + # Get request converter identifiers if present + converter_identifiers = None + if self._request_converters: + converter_identifiers = [ + converter.get_identifier() + for config in self._request_converters + for converter in config.converters + ] + + return AttackIdentifier( + class_name=self.__class__.__name__, + class_module=self.__class__.__module__, + objective_target_identifier=objective_target_identifier, + objective_scorer_identifier=scorer_identifier, + request_converter_identifiers=converter_identifiers or None, + ) @property def params_type(self) -> Type[AttackParameters]: @@ -291,6 +331,15 @@ def get_attack_scoring_config(self) -> Optional[AttackScoringConfig]: """ return None + def get_request_converters(self) -> list: + """ + Get request converter configurations used by this strategy. + + Returns: + list: The list of request PromptConverterConfiguration objects. + """ + return self._request_converters + @overload async def execute_async( self, diff --git a/pyrit/executor/attack/multi_turn/tree_of_attacks.py b/pyrit/executor/attack/multi_turn/tree_of_attacks.py index cfad81d3ab..48d6025a47 100644 --- a/pyrit/executor/attack/multi_turn/tree_of_attacks.py +++ b/pyrit/executor/attack/multi_turn/tree_of_attacks.py @@ -37,6 +37,7 @@ ) from pyrit.executor.attack.core.attack_strategy import AttackStrategy from pyrit.executor.attack.multi_turn import MultiTurnAttackContext +from pyrit.identifiers import AttackIdentifier from pyrit.memory import CentralMemory from pyrit.models import ( AttackOutcome, @@ -267,7 +268,7 @@ def __init__( request_converters: List[PromptConverterConfiguration], response_converters: List[PromptConverterConfiguration], auxiliary_scorers: Optional[List[Scorer]], - attack_id: dict[str, str], + attack_id: AttackIdentifier, attack_strategy_name: str, memory_labels: Optional[dict[str, str]] = None, parent_id: Optional[str] = None, @@ -289,7 +290,7 @@ def __init__( request_converters (List[PromptConverterConfiguration]): Converters for request normalization response_converters (List[PromptConverterConfiguration]): Converters for response normalization auxiliary_scorers (Optional[List[Scorer]]): Additional scorers for the response - attack_id (dict[str, str]): Unique identifier for the attack. + attack_id (AttackIdentifier): Unique identifier for the attack. attack_strategy_name (str): Name of the attack strategy for execution context. memory_labels (Optional[dict[str, str]]): Labels for memory storage. parent_id (Optional[str]): ID of the parent node, if this is a child node diff --git a/pyrit/executor/attack/printer/console_printer.py b/pyrit/executor/attack/printer/console_printer.py index 7d7110d0ae..0dd162613c 100644 --- a/pyrit/executor/attack/printer/console_printer.py +++ b/pyrit/executor/attack/printer/console_printer.py @@ -259,7 +259,7 @@ async def print_summary_async(self, result: AttackResult) -> None: # Extract attack type name from attack_identifier attack_type = "Unknown" if isinstance(result.attack_identifier, dict) and "__type__" in result.attack_identifier: - attack_type = result.attack_identifier["__type__"] + attack_type = result.attack_identifier.class_name elif isinstance(result.attack_identifier, str): attack_type = result.attack_identifier diff --git a/pyrit/executor/attack/printer/markdown_printer.py b/pyrit/executor/attack/printer/markdown_printer.py index 27838a46c2..e62a80cabc 100644 --- a/pyrit/executor/attack/printer/markdown_printer.py +++ b/pyrit/executor/attack/printer/markdown_printer.py @@ -493,7 +493,7 @@ async def _get_summary_markdown_async(self, result: AttackResult) -> List[str]: markdown_lines.append("|-------|-------|") markdown_lines.append(f"| **Objective** | {result.objective} |") - attack_type = result.attack_identifier.get("__type__", "Unknown") + attack_type = result.attack_identifier.class_name if result.attack_identifier else "Unknown" markdown_lines.append(f"| **Attack Type** | `{attack_type}` |") markdown_lines.append(f"| **Conversation ID** | `{result.conversation_id}` |") diff --git a/pyrit/executor/benchmark/fairness_bias.py b/pyrit/executor/benchmark/fairness_bias.py index b894757eba..3e6bc8b785 100644 --- a/pyrit/executor/benchmark/fairness_bias.py +++ b/pyrit/executor/benchmark/fairness_bias.py @@ -17,6 +17,7 @@ PromptSendingAttack, ) from pyrit.executor.core import Strategy, StrategyContext +from pyrit.identifiers import AttackIdentifier from pyrit.memory import CentralMemory from pyrit.models import ( AttackOutcome, @@ -195,7 +196,10 @@ async def _perform_async(self, *, context: FairnessBiasBenchmarkContext) -> Atta conversation_id=str(uuid.UUID(int=0)), objective=context.generated_objective, outcome=AttackOutcome.FAILURE, - attack_identifier=self.get_identifier(), + attack_identifier=AttackIdentifier( + class_name=self.__class__.__name__, + class_module=self.__class__.__module__, + ), ) return last_attack_result diff --git a/pyrit/executor/core/strategy.py b/pyrit/executor/core/strategy.py index 7fc48a4173..1ef0f94cff 100644 --- a/pyrit/executor/core/strategy.py +++ b/pyrit/executor/core/strategy.py @@ -176,19 +176,6 @@ def __init__( default_values.get_non_required_value(env_var_name="GLOBAL_MEMORY_LABELS") or "{}" ) - def get_identifier(self) -> Dict[str, str]: - """ - Get a serializable identifier for the strategy instance. - - Returns: - dict: A dictionary containing the type, module, and unique ID of the strategy. - """ - return { - "__type__": self.__class__.__name__, - "__module__": self.__class__.__module__, - "id": str(self._id), - } - def _register_event_handler(self, event_handler: StrategyEventHandler[StrategyContextT, StrategyResultT]) -> None: """ Register an event handler for strategy events. diff --git a/pyrit/executor/promptgen/anecdoctor.py b/pyrit/executor/promptgen/anecdoctor.py index b627e477d8..82ecb25e5f 100644 --- a/pyrit/executor/promptgen/anecdoctor.py +++ b/pyrit/executor/promptgen/anecdoctor.py @@ -19,6 +19,7 @@ PromptGeneratorStrategyContext, PromptGeneratorStrategyResult, ) +from pyrit.identifiers import AttackIdentifier, Identifiable from pyrit.models import ( Message, ) @@ -67,7 +68,10 @@ class AnecdoctorResult(PromptGeneratorStrategyResult): generated_content: Message -class AnecdoctorGenerator(PromptGeneratorStrategy[AnecdoctorContext, AnecdoctorResult]): +class AnecdoctorGenerator( + PromptGeneratorStrategy[AnecdoctorContext, AnecdoctorResult], + Identifiable[AttackIdentifier], +): """ Implementation of the Anecdoctor prompt generation strategy. @@ -131,6 +135,21 @@ def __init__( else: self._system_prompt_template = self._load_prompt_from_yaml(yaml_filename=self._ANECDOCTOR_USE_FEWSHOT_YAML) + def _build_identifier(self) -> AttackIdentifier: + """ + Build the typed identifier for this prompt generator. + + Returns: + AttackIdentifier: The constructed identifier. + """ + objective_target_identifier = self._objective_target.get_identifier() + + return AttackIdentifier( + class_name=self.__class__.__name__, + class_module=self.__class__.__module__, + objective_target_identifier=objective_target_identifier, + ) + def _validate_context(self, *, context: AnecdoctorContext) -> None: """ Validate the context before executing the prompt generation. diff --git a/pyrit/executor/promptgen/fuzzer/fuzzer.py b/pyrit/executor/promptgen/fuzzer/fuzzer.py index 7949a398a9..93360a16dd 100644 --- a/pyrit/executor/promptgen/fuzzer/fuzzer.py +++ b/pyrit/executor/promptgen/fuzzer/fuzzer.py @@ -24,6 +24,7 @@ PromptGeneratorStrategyResult, ) from pyrit.executor.promptgen.fuzzer.fuzzer_converter_base import FuzzerConverter +from pyrit.identifiers import AttackIdentifier, Identifiable from pyrit.memory import CentralMemory from pyrit.models import ( Message, @@ -492,7 +493,10 @@ def print_templates_only(self, result: FuzzerResult) -> None: print("No successful templates found.") -class FuzzerGenerator(PromptGeneratorStrategy[FuzzerContext, FuzzerResult]): +class FuzzerGenerator( + PromptGeneratorStrategy[FuzzerContext, FuzzerResult], + Identifiable[AttackIdentifier], +): """ Implementation of the Fuzzer prompt generation strategy using Monte Carlo Tree Search (MCTS). @@ -675,6 +679,21 @@ def __init__( # Initialize utilities self._prompt_normalizer = prompt_normalizer or PromptNormalizer() + def _build_identifier(self) -> AttackIdentifier: + """ + Build the typed identifier for this prompt generator. + + Returns: + AttackIdentifier: The constructed identifier. + """ + objective_target_identifier = self._objective_target.get_identifier() + + return AttackIdentifier( + class_name=self.__class__.__name__, + class_module=self.__class__.__module__, + objective_target_identifier=objective_target_identifier, + ) + def _validate_inputs( self, *, diff --git a/pyrit/executor/workflow/xpia.py b/pyrit/executor/workflow/xpia.py index 1579f7e868..2f5baf83f4 100644 --- a/pyrit/executor/workflow/xpia.py +++ b/pyrit/executor/workflow/xpia.py @@ -14,6 +14,7 @@ WorkflowResult, WorkflowStrategy, ) +from pyrit.identifiers import AttackIdentifier, Identifiable from pyrit.memory import CentralMemory from pyrit.models import ( Message, @@ -127,7 +128,7 @@ def status(self) -> XPIAStatus: return XPIAStatus.SUCCESS if self.success else XPIAStatus.FAILURE -class XPIAWorkflow(WorkflowStrategy[XPIAContext, XPIAResult]): +class XPIAWorkflow(WorkflowStrategy[XPIAContext, XPIAResult], Identifiable[AttackIdentifier]): """ Implementation of Cross-Domain Prompt Injection Attack (XPIA) workflow. @@ -174,6 +175,26 @@ def __init__( self._prompt_normalizer = prompt_normalizer or PromptNormalizer() self._memory = CentralMemory.get_memory_instance() + def _build_identifier(self) -> AttackIdentifier: + """ + Build the typed identifier for this XPIA workflow. + + Returns: + AttackIdentifier: The constructed identifier. + """ + objective_target_identifier = self._attack_setup_target.get_identifier() + + scorer_identifier = None + if self._scorer: + scorer_identifier = self._scorer.get_identifier() + + return AttackIdentifier( + class_name=self.__class__.__name__, + class_module=self.__class__.__module__, + objective_target_identifier=objective_target_identifier, + objective_scorer_identifier=scorer_identifier, + ) + def _validate_context(self, *, context: XPIAContext) -> None: """ Validate the XPIA context before execution. diff --git a/pyrit/identifiers/__init__.py b/pyrit/identifiers/__init__.py index 30c501894c..0c8fe13cac 100644 --- a/pyrit/identifiers/__init__.py +++ b/pyrit/identifiers/__init__.py @@ -3,12 +3,13 @@ """Identifiers module for PyRIT components.""" +from pyrit.identifiers.attack_identifier import AttackIdentifier from pyrit.identifiers.class_name_utils import ( class_name_to_snake_case, snake_case_to_class_name, ) from pyrit.identifiers.converter_identifier import ConverterIdentifier -from pyrit.identifiers.identifiable import Identifiable, IdentifierT, LegacyIdentifiable +from pyrit.identifiers.identifiable import Identifiable, IdentifierT from pyrit.identifiers.identifier import ( Identifier, IdentifierType, @@ -17,13 +18,13 @@ from pyrit.identifiers.target_identifier import TargetIdentifier __all__ = [ + "AttackIdentifier", "class_name_to_snake_case", "ConverterIdentifier", "Identifiable", "Identifier", "IdentifierT", "IdentifierType", - "LegacyIdentifiable", "ScorerIdentifier", "snake_case_to_class_name", "TargetIdentifier", diff --git a/pyrit/identifiers/attack_identifier.py b/pyrit/identifiers/attack_identifier.py new file mode 100644 index 0000000000..92c8cf103d --- /dev/null +++ b/pyrit/identifiers/attack_identifier.py @@ -0,0 +1,61 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Type, cast + +from pyrit.identifiers.converter_identifier import ConverterIdentifier +from pyrit.identifiers.identifier import Identifier +from pyrit.identifiers.scorer_identifier import ScorerIdentifier +from pyrit.identifiers.target_identifier import TargetIdentifier + + +@dataclass(frozen=True) +class AttackIdentifier(Identifier): + """ + Typed identifier for an attack strategy instance. + + Captures the configuration that makes one attack strategy meaningfully + different from another: the objective target, optional scorer, and converter + pipeline. These do not change between calls to ``execute_async``. + """ + + objective_target_identifier: Optional[TargetIdentifier] = None + objective_scorer_identifier: Optional[ScorerIdentifier] = None + request_converter_identifiers: Optional[List[ConverterIdentifier]] = None + + # Additional attack-specific params for subclass flexibility + attack_specific_params: Optional[Dict[str, Any]] = None + + @classmethod + def from_dict(cls: Type["AttackIdentifier"], data: dict[str, Any]) -> "AttackIdentifier": + """ + Deserialize an AttackIdentifier from a dictionary. + + Handles nested sub-identifiers (target, scorer, converters) by + recursively calling their own ``from_dict`` implementations. + + Args: + data: Dictionary containing the serialized identifier fields. + + Returns: + AttackIdentifier: The deserialized identifier. + """ + data = dict(data) + + if "objective_target_identifier" in data and isinstance(data["objective_target_identifier"], dict): + data["objective_target_identifier"] = TargetIdentifier.from_dict(data["objective_target_identifier"]) + + if "objective_scorer_identifier" in data and isinstance(data["objective_scorer_identifier"], dict): + data["objective_scorer_identifier"] = ScorerIdentifier.from_dict(data["objective_scorer_identifier"]) + + if "request_converter_identifiers" in data and data["request_converter_identifiers"] is not None: + data["request_converter_identifiers"] = [ + ConverterIdentifier.from_dict(c) if isinstance(c, dict) else c + for c in data["request_converter_identifiers"] + ] + + result = Identifier.from_dict.__func__(cls, data) # type: ignore[attr-defined] + return cast(AttackIdentifier, result) diff --git a/pyrit/identifiers/identifiable.py b/pyrit/identifiers/identifiable.py index 94108e6eaf..7ef54c7f19 100644 --- a/pyrit/identifiers/identifiable.py +++ b/pyrit/identifiers/identifiable.py @@ -12,24 +12,6 @@ IdentifierT = TypeVar("IdentifierT", bound=Identifier) -class LegacyIdentifiable(ABC): - """ - Deprecated legacy interface for objects that can provide an identifier dictionary. - - This interface will eventually be replaced by Identifier dataclass. - Classes implementing this interface should return a dict describing their identity. - """ - - @abstractmethod - def get_identifier(self) -> dict[str, str]: - """Return a dictionary describing this object's identity.""" - pass - - def __str__(self) -> str: - """Return string representation of the identifier.""" - return f"{self.get_identifier()}" - - class Identifiable(ABC, Generic[IdentifierT]): """ Abstract base class for objects that can provide a typed identifier. diff --git a/pyrit/identifiers/identifier.py b/pyrit/identifiers/identifier.py index 64a754f0ae..ff7ee8832b 100644 --- a/pyrit/identifiers/identifier.py +++ b/pyrit/identifiers/identifier.py @@ -135,12 +135,12 @@ class Identifier: All component-specific identifier types should extend this with additional fields. """ - class_name: str # The actual class name, equivalent to __type__ (e.g., "SelfAskRefusalScorer") - class_module: str # The module path, equivalent to __module__ (e.g., "pyrit.score.self_ask_refusal_scorer") + class_name: str + class_module: str # Fields excluded from storage (STORAGE auto-expands to include HASH) - class_description: str = field(metadata={_EXCLUDE: {_ExcludeFrom.STORAGE}}) - identifier_type: IdentifierType = field(metadata={_EXCLUDE: {_ExcludeFrom.STORAGE}}) + class_description: str = field(default="", metadata={_EXCLUDE: {_ExcludeFrom.STORAGE}}) + identifier_type: IdentifierType = field(default="instance", metadata={_EXCLUDE: {_ExcludeFrom.STORAGE}}) # Auto-computed fields snake_class_name: str = field(init=False, metadata={_EXCLUDE: {_ExcludeFrom.STORAGE}}) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 5e3817c481..f53c4b6cd4 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -257,7 +257,7 @@ def _get_message_pieces_attack_conditions(self, *, attack_id: str) -> Any: Returns: Any: SQLAlchemy text condition with bound parameter. """ - return text("ISJSON(attack_identifier) = 1 AND JSON_VALUE(attack_identifier, '$.id') = :json_id").bindparams( + return text("ISJSON(attack_identifier) = 1 AND JSON_VALUE(attack_identifier, '$.hash') = :json_id").bindparams( json_id=str(attack_id) ) diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index bc8b9f5f36..5db7e4d8ae 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -32,7 +32,7 @@ import pyrit from pyrit.common.utils import to_sha256 -from pyrit.identifiers import ConverterIdentifier, ScorerIdentifier, TargetIdentifier +from pyrit.identifiers import AttackIdentifier, ConverterIdentifier, ScorerIdentifier, TargetIdentifier from pyrit.models import ( AttackOutcome, AttackResult, @@ -220,7 +220,7 @@ def __init__(self, *, entry: MessagePiece): self.prompt_target_identifier = ( entry.prompt_target_identifier.to_dict() if entry.prompt_target_identifier else {} ) - self.attack_identifier = entry.attack_identifier + self.attack_identifier = entry.attack_identifier.to_dict() if entry.attack_identifier else {} self.original_value = entry.original_value self.original_value_data_type = entry.original_value_data_type # type: ignore @@ -732,7 +732,7 @@ def __init__(self, *, entry: AttackResult): self.id = uuid.uuid4() self.conversation_id = entry.conversation_id self.objective = entry.objective - self.attack_identifier = entry.attack_identifier + self.attack_identifier = entry.attack_identifier.to_dict() if entry.attack_identifier else {} self.objective_sha256 = to_sha256(entry.objective) # Use helper method for UUID conversions @@ -833,7 +833,7 @@ def get_attack_result(self) -> AttackResult: return AttackResult( conversation_id=self.conversation_id, objective=self.objective, - attack_identifier=self.attack_identifier, + attack_identifier=AttackIdentifier.from_dict(self.attack_identifier) if self.attack_identifier else None, last_response=self.last_response.get_message_piece() if self.last_response else None, last_score=self.last_score.get_score() if self.last_score else None, executed_turns=self.executed_turns, diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index 30a251cf72..bca6a21817 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -163,7 +163,7 @@ def _get_message_pieces_attack_conditions(self, *, attack_id: str) -> Any: Returns: Any: A SQLAlchemy text condition with bound parameters. """ - return text("JSON_EXTRACT(attack_identifier, '$.id') = :attack_id").bindparams(attack_id=str(attack_id)) + return text("JSON_EXTRACT(attack_identifier, '$.hash') = :attack_id").bindparams(attack_id=str(attack_id)) def _get_seed_metadata_conditions(self, *, metadata: dict[str, Union[str, int]]) -> Any: """ diff --git a/pyrit/models/attack_result.py b/pyrit/models/attack_result.py index dc9e3a1a9e..7f92612f59 100644 --- a/pyrit/models/attack_result.py +++ b/pyrit/models/attack_result.py @@ -7,6 +7,7 @@ from enum import Enum from typing import Any, Dict, Optional, TypeVar +from pyrit.identifiers import AttackIdentifier from pyrit.models.conversation_reference import ConversationReference, ConversationType from pyrit.models.message_piece import MessagePiece from pyrit.models.score import Score @@ -41,8 +42,8 @@ class AttackResult(StrategyResult): # Natural-language description of the attacker's objective objective: str - # Identifier of the attack (e.g., name, module) - attack_identifier: dict[str, str] + # Identifier of the attack strategy that produced this result + attack_identifier: Optional[AttackIdentifier] = None # Evidence # Model response generated in the final turn of the attack diff --git a/pyrit/models/message_piece.py b/pyrit/models/message_piece.py index 460529475f..4a9180d79f 100644 --- a/pyrit/models/message_piece.py +++ b/pyrit/models/message_piece.py @@ -8,7 +8,7 @@ from typing import Any, Dict, List, Literal, Optional, Union, get_args from uuid import uuid4 -from pyrit.identifiers import ConverterIdentifier, ScorerIdentifier, TargetIdentifier +from pyrit.identifiers import AttackIdentifier, ConverterIdentifier, ScorerIdentifier, TargetIdentifier from pyrit.models.literals import ChatMessageRole, PromptDataType, PromptResponseError from pyrit.models.score import Score @@ -39,7 +39,7 @@ def __init__( prompt_metadata: Optional[Dict[str, Union[str, int]]] = None, converter_identifiers: Optional[List[Union[ConverterIdentifier, Dict[str, str]]]] = None, prompt_target_identifier: Optional[Union[TargetIdentifier, Dict[str, Any]]] = None, - attack_identifier: Optional[Dict[str, str]] = None, + attack_identifier: Optional[Union[AttackIdentifier, Dict[str, str]]] = None, scorer_identifier: Optional[Union[ScorerIdentifier, Dict[str, str]]] = None, original_value_data_type: PromptDataType = "text", converted_value_data_type: Optional[PromptDataType] = None, @@ -71,7 +71,8 @@ def __init__( converter_identifiers: The converter identifiers for the prompt. Can be ConverterIdentifier objects or dicts (deprecated, will be removed in 0.14.0). Defaults to None. prompt_target_identifier: The target identifier for the prompt. Defaults to None. - attack_identifier: The attack identifier for the prompt. Defaults to None. + attack_identifier: The attack identifier for the prompt. Can be an AttackIdentifier + object or a dict (deprecated, will be removed in 0.14.0). Defaults to None. scorer_identifier: The scorer identifier for the prompt. Can be a ScorerIdentifier or a dict (deprecated, will be removed in 0.13.0). Defaults to None. original_value_data_type: The data type of the original prompt (text, image). Defaults to "text". @@ -118,7 +119,10 @@ def __init__( TargetIdentifier.normalize(prompt_target_identifier) if prompt_target_identifier else None ) - self.attack_identifier = attack_identifier or {} + # Handle attack_identifier: normalize to AttackIdentifier (handles dict with deprecation warning) + self.attack_identifier: Optional[AttackIdentifier] = ( + AttackIdentifier.normalize(attack_identifier) if attack_identifier else None + ) # Handle scorer_identifier: normalize to ScorerIdentifier (handles dict with deprecation warning) self.scorer_identifier: Optional[ScorerIdentifier] = ( @@ -283,7 +287,7 @@ def to_dict(self) -> dict[str, object]: "prompt_target_identifier": ( self.prompt_target_identifier.to_dict() if self.prompt_target_identifier else None ), - "attack_identifier": self.attack_identifier, + "attack_identifier": self.attack_identifier.to_dict() if self.attack_identifier else None, "scorer_identifier": self.scorer_identifier.to_dict() if self.scorer_identifier else None, "original_value_data_type": self.original_value_data_type, "original_value": self.original_value, diff --git a/pyrit/models/seeds/seed.py b/pyrit/models/seeds/seed.py index f2a6bc4697..2cbe7c06cc 100644 --- a/pyrit/models/seeds/seed.py +++ b/pyrit/models/seeds/seed.py @@ -94,6 +94,38 @@ class Seed(YamlLoadable): # Alias for the prompt group prompt_group_alias: Optional[str] = None + @property + def is_general_attack_strategy(self) -> bool: + """ + Whether this seed represents a general attack strategy dataset. + + When True, the seed data (prepended conversation, next message) is + considered part of the attack strategy identity and will be included + in the attack identifier hash. When False (default), the data is + considered objective-specific and is not included. + + This value is stored in the ``metadata`` dict under the key + ``"is_general_attack_strategy"``. + + Returns: + bool: True if the seed is a general attack strategy dataset. + """ + if self.metadata and "is_general_attack_strategy" in self.metadata: + return bool(self.metadata["is_general_attack_strategy"]) + return False + + @is_general_attack_strategy.setter + def is_general_attack_strategy(self, value: bool) -> None: + """ + Set whether this seed represents a general attack strategy dataset. + + Args: + value: True to mark as a general attack strategy dataset. + """ + if self.metadata is None: + self.metadata = {} + self.metadata["is_general_attack_strategy"] = int(value) + @property def data_type(self) -> PromptDataType: """ diff --git a/pyrit/prompt_converter/prompt_converter.py b/pyrit/prompt_converter/prompt_converter.py index 756d3972e6..9af7ef9218 100644 --- a/pyrit/prompt_converter/prompt_converter.py +++ b/pyrit/prompt_converter/prompt_converter.py @@ -221,7 +221,6 @@ def _create_identifier( class_name=self.__class__.__name__, class_module=self.__class__.__module__, class_description=self.__class__.__doc__ or "", - identifier_type="instance", supported_input_types=self.SUPPORTED_INPUT_TYPES, supported_output_types=self.SUPPORTED_OUTPUT_TYPES, sub_identifier=sub_identifier, diff --git a/pyrit/prompt_normalizer/prompt_normalizer.py b/pyrit/prompt_normalizer/prompt_normalizer.py index 8fd8aaaba0..00f2f0f578 100644 --- a/pyrit/prompt_normalizer/prompt_normalizer.py +++ b/pyrit/prompt_normalizer/prompt_normalizer.py @@ -14,6 +14,7 @@ execution_context, get_execution_context, ) +from pyrit.identifiers import AttackIdentifier from pyrit.memory import CentralMemory, MemoryInterface from pyrit.models import ( Message, @@ -53,7 +54,7 @@ async def send_prompt_async( request_converter_configurations: list[PromptConverterConfiguration] = [], response_converter_configurations: list[PromptConverterConfiguration] = [], labels: Optional[dict[str, str]] = None, - attack_identifier: Optional[dict[str, str]] = None, + attack_identifier: Optional[AttackIdentifier] = None, ) -> Message: """ Send a single request to a target. @@ -67,7 +68,7 @@ async def send_prompt_async( response_converter_configurations (list[PromptConverterConfiguration], optional): Configurations for converting the response. Defaults to an empty list. labels (Optional[dict[str, str]], optional): Labels associated with the request. Defaults to None. - attack_identifier (Optional[dict[str, str]], optional): Identifier for the attack. Defaults to + attack_identifier (Optional[AttackIdentifier], optional): Identifier for the attack. Defaults to None. Raises: @@ -155,7 +156,7 @@ async def send_prompt_batch_to_target_async( requests: list[NormalizerRequest], target: PromptTarget, labels: Optional[dict[str, str]] = None, - attack_identifier: Optional[dict[str, str]] = None, + attack_identifier: Optional[AttackIdentifier] = None, batch_size: int = 10, ) -> list[Message]: """ @@ -166,7 +167,7 @@ async def send_prompt_batch_to_target_async( target (PromptTarget): The target to which the prompts are sent. labels (Optional[dict[str, str]], optional): A dictionary of labels to be included with the request. Defaults to None. - attack_identifier (Optional[dict[str, str]], optional): A dictionary identifying the attack. + attack_identifier (Optional[AttackIdentifier], optional): The attack identifier. Defaults to None. batch_size (int, optional): The number of prompts to include in each batch. Defaults to 10. @@ -274,7 +275,7 @@ async def add_prepended_conversation_to_memory( conversation_id: str, should_convert: bool = True, converter_configurations: Optional[list[PromptConverterConfiguration]] = None, - attack_identifier: Optional[dict[str, str]] = None, + attack_identifier: Optional[AttackIdentifier] = None, prepended_conversation: Optional[list[Message]] = None, ) -> Optional[list[Message]]: """ @@ -285,7 +286,7 @@ async def add_prepended_conversation_to_memory( should_convert (bool): Whether to convert the prepended conversation converter_configurations (Optional[list[PromptConverterConfiguration]]): Configurations for converting the request - attack_identifier (Optional[dict[str, str]]): Identifier for the attack + attack_identifier (Optional[AttackIdentifier]): Identifier for the attack prepended_conversation (Optional[list[Message]]): The conversation to prepend Returns: diff --git a/pyrit/prompt_target/common/prompt_chat_target.py b/pyrit/prompt_target/common/prompt_chat_target.py index 5eac0209f5..b2c51d23ac 100644 --- a/pyrit/prompt_target/common/prompt_chat_target.py +++ b/pyrit/prompt_target/common/prompt_chat_target.py @@ -2,8 +2,9 @@ # Licensed under the MIT license. import abc -from typing import Optional +from typing import Optional, Union +from pyrit.identifiers import AttackIdentifier from pyrit.models import MessagePiece from pyrit.models.json_response_config import _JsonResponseConfig from pyrit.prompt_target.common.prompt_target import PromptTarget @@ -51,7 +52,7 @@ def set_system_prompt( *, system_prompt: str, conversation_id: str, - attack_identifier: Optional[dict[str, str]] = None, + attack_identifier: Optional[AttackIdentifier] = None, labels: Optional[dict[str, str]] = None, ) -> None: """ diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index 653d008e65..29ba2cb47c 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -129,7 +129,6 @@ def _create_identifier( class_name=self.__class__.__name__, class_module=self.__class__.__module__, class_description=" ".join(self.__class__.__doc__.split()) if self.__class__.__doc__ else "", - identifier_type="instance", endpoint=self._endpoint, model_name=model_name, temperature=temperature, diff --git a/pyrit/score/float_scale/float_scale_scorer.py b/pyrit/score/float_scale/float_scale_scorer.py index a1150b88e3..b4ac4ddcfb 100644 --- a/pyrit/score/float_scale/float_scale_scorer.py +++ b/pyrit/score/float_scale/float_scale_scorer.py @@ -1,10 +1,11 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Dict, Optional, Union from uuid import UUID from pyrit.exceptions.exception_classes import InvalidJsonException +from pyrit.identifiers import AttackIdentifier from pyrit.models import PromptDataType, Score, UnvalidatedScore from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget from pyrit.score.scorer import Scorer @@ -75,7 +76,7 @@ async def _score_value_with_llm( description_output_key: str = "description", metadata_output_key: str = "metadata", category_output_key: str = "category", - attack_identifier: Optional[Dict[str, str]] = None, + attack_identifier: Optional[Union[AttackIdentifier, Dict[str, str]]] = None, ) -> UnvalidatedScore: score: UnvalidatedScore | None = None try: diff --git a/pyrit/score/scorer.py b/pyrit/score/scorer.py index 83cd795ecf..d973c1d445 100644 --- a/pyrit/score/scorer.py +++ b/pyrit/score/scorer.py @@ -26,7 +26,7 @@ pyrit_json_retry, remove_markdown_json, ) -from pyrit.identifiers import Identifiable, ScorerIdentifier +from pyrit.identifiers import AttackIdentifier, Identifiable, ScorerIdentifier from pyrit.memory import CentralMemory, MemoryInterface from pyrit.models import ( ChatMessageRole, @@ -145,7 +145,6 @@ def _create_identifier( class_name=self.__class__.__name__, class_module=self.__class__.__module__, class_description=" ".join(self.__class__.__doc__.split()) if self.__class__.__doc__ else "", - identifier_type="instance", scorer_type=self.scorer_type, system_prompt_template=system_prompt_template, user_prompt_template=user_prompt_template, @@ -521,7 +520,7 @@ async def _score_value_with_llm( description_output_key: str = "description", metadata_output_key: str = "metadata", category_output_key: str = "category", - attack_identifier: Optional[Dict[str, str]] = None, + attack_identifier: Optional[Union[AttackIdentifier, Dict[str, str]]] = None, ) -> UnvalidatedScore: """ Send a request to a target, and take care of retries. @@ -555,7 +554,7 @@ async def _score_value_with_llm( Defaults to "metadata". category_output_key (str): The key in the JSON response that contains the category. Defaults to "category". - attack_identifier (Optional[Dict[str, str]]): A dictionary containing attack-specific identifiers. + attack_identifier (Optional[Union[AttackIdentifier, Dict[str, str]]]): The attack identifier. Defaults to None. Returns: @@ -569,9 +568,6 @@ async def _score_value_with_llm( """ conversation_id = str(uuid.uuid4()) - if attack_identifier: - attack_identifier["scored_prompt_id"] = str(scored_prompt_id) - prompt_target.set_system_prompt( system_prompt=system_prompt, conversation_id=conversation_id, diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index ee1fdd6bee..dd2b0f0174 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -9,7 +9,7 @@ from typing import Generator, MutableSequence, Optional, Sequence from unittest.mock import MagicMock, patch -from pyrit.identifiers import ScorerIdentifier, TargetIdentifier +from pyrit.identifiers import AttackIdentifier, ScorerIdentifier, TargetIdentifier from pyrit.memory import AzureSQLMemory, CentralMemory, PromptMemoryEntry from pyrit.models import Message, MessagePiece from pyrit.prompt_target import PromptChatTarget, limit_requests_per_minute @@ -100,7 +100,7 @@ def set_system_prompt( *, system_prompt: str, conversation_id: str, - attack_identifier: Optional[dict[str, str]] = None, + attack_identifier: Optional[AttackIdentifier] = None, labels: Optional[dict[str, str]] = None, ) -> None: self.system_prompt = system_prompt From 201b1d62b5b097550f1f323de1b38d6ce19bdb81 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Mon, 9 Feb 2026 09:23:50 -0800 Subject: [PATCH 12/35] fixing all tests --- doc/api.rst | 1 + pyrit/models/message_piece.py | 3 +- pyrit/models/seeds/seed.py | 32 ----- .../class_registries/initializer_registry.py | 8 +- .../class_registries/scenario_registry.py | 12 +- tests/unit/analytics/test_result_analysis.py | 5 +- tests/unit/docs/test_api_documentation.py | 2 +- .../component/test_conversation_manager.py | 117 +++++++++--------- .../attack/core/test_attack_strategy.py | 7 -- .../attack/core/test_markdown_printer.py | 4 +- .../attack/multi_turn/test_chunked_request.py | 56 ++++++--- .../multi_turn/test_multi_prompt_sending.py | 19 ++- .../attack/multi_turn/test_red_teaming.py | 8 +- .../single_turn/test_context_compliance.py | 14 ++- .../attack/single_turn/test_flip_attack.py | 14 ++- .../single_turn/test_many_shot_jailbreak.py | 14 ++- .../attack/single_turn/test_prompt_sending.py | 14 ++- .../attack/single_turn/test_role_play.py | 1 + .../attack/single_turn/test_skeleton_key.py | 1 + .../executor/benchmark/test_fairness_bias.py | 14 ++- .../benchmark/test_question_answering.py | 13 ++ .../executor/promptgen/test_anecdoctor.py | 14 +++ tests/unit/executor/workflow/test_xpia.py | 24 ++++ tests/unit/identifiers/test_identifiers.py | 28 +---- .../test_interface_attack_results.py | 36 +----- .../memory_interface/test_interface_export.py | 2 +- .../test_interface_prompts.py | 44 +++---- .../test_interface_scenario_results.py | 1 - .../memory_interface/test_interface_scores.py | 8 +- tests/unit/mocks.py | 51 ++++++-- tests/unit/models/test_message_piece.py | 12 +- tests/unit/registry/test_base.py | 4 +- .../registry/test_base_instance_registry.py | 4 +- tests/unit/scenarios/test_atomic_attack.py | 7 -- tests/unit/scenarios/test_scenario.py | 15 --- .../test_scenario_partial_results.py | 7 -- tests/unit/scenarios/test_scenario_retry.py | 1 - tests/unit/score/test_scorer.py | 9 +- tests/unit/target/test_http_target.py | 24 +++- 39 files changed, 355 insertions(+), 295 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index 780ae04206..523efb0ad3 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -272,6 +272,7 @@ API Reference :toctree: _autosummary/ class_name_to_snake_case + AttackIdentifier ConverterIdentifier Identifiable Identifier diff --git a/pyrit/models/message_piece.py b/pyrit/models/message_piece.py index 4a9180d79f..f07b045318 100644 --- a/pyrit/models/message_piece.py +++ b/pyrit/models/message_piece.py @@ -71,8 +71,7 @@ def __init__( converter_identifiers: The converter identifiers for the prompt. Can be ConverterIdentifier objects or dicts (deprecated, will be removed in 0.14.0). Defaults to None. prompt_target_identifier: The target identifier for the prompt. Defaults to None. - attack_identifier: The attack identifier for the prompt. Can be an AttackIdentifier - object or a dict (deprecated, will be removed in 0.14.0). Defaults to None. + attack_identifier: The attack identifier for the prompt. Defaults to None. scorer_identifier: The scorer identifier for the prompt. Can be a ScorerIdentifier or a dict (deprecated, will be removed in 0.13.0). Defaults to None. original_value_data_type: The data type of the original prompt (text, image). Defaults to "text". diff --git a/pyrit/models/seeds/seed.py b/pyrit/models/seeds/seed.py index 2cbe7c06cc..f2a6bc4697 100644 --- a/pyrit/models/seeds/seed.py +++ b/pyrit/models/seeds/seed.py @@ -94,38 +94,6 @@ class Seed(YamlLoadable): # Alias for the prompt group prompt_group_alias: Optional[str] = None - @property - def is_general_attack_strategy(self) -> bool: - """ - Whether this seed represents a general attack strategy dataset. - - When True, the seed data (prepended conversation, next message) is - considered part of the attack strategy identity and will be included - in the attack identifier hash. When False (default), the data is - considered objective-specific and is not included. - - This value is stored in the ``metadata`` dict under the key - ``"is_general_attack_strategy"``. - - Returns: - bool: True if the seed is a general attack strategy dataset. - """ - if self.metadata and "is_general_attack_strategy" in self.metadata: - return bool(self.metadata["is_general_attack_strategy"]) - return False - - @is_general_attack_strategy.setter - def is_general_attack_strategy(self, value: bool) -> None: - """ - Set whether this seed represents a general attack strategy dataset. - - Args: - value: True to mark as a general attack strategy dataset. - """ - if self.metadata is None: - self.metadata = {} - self.metadata["is_general_attack_strategy"] = int(value) - @property def data_type(self) -> PromptDataType: """ diff --git a/pyrit/registry/class_registries/initializer_registry.py b/pyrit/registry/class_registries/initializer_registry.py index a334e87e72..bf6443afa6 100644 --- a/pyrit/registry/class_registries/initializer_registry.py +++ b/pyrit/registry/class_registries/initializer_registry.py @@ -12,7 +12,7 @@ import importlib.util import logging -from dataclasses import dataclass +from dataclasses import dataclass, field from pathlib import Path from typing import TYPE_CHECKING, Dict, Optional @@ -41,9 +41,9 @@ class InitializerMetadata(Identifier): Use get_class() to get the actual class. """ - display_name: str - required_env_vars: tuple[str, ...] - execution_order: int + display_name: str = field(kw_only=True) + required_env_vars: tuple[str, ...] = field(kw_only=True) + execution_order: int = field(kw_only=True) class InitializerRegistry(BaseClassRegistry["PyRITInitializer", InitializerMetadata]): diff --git a/pyrit/registry/class_registries/scenario_registry.py b/pyrit/registry/class_registries/scenario_registry.py index f95ad93986..5489c774f2 100644 --- a/pyrit/registry/class_registries/scenario_registry.py +++ b/pyrit/registry/class_registries/scenario_registry.py @@ -11,7 +11,7 @@ from __future__ import annotations import logging -from dataclasses import dataclass +from dataclasses import dataclass, field from pathlib import Path from typing import TYPE_CHECKING, Optional @@ -40,11 +40,11 @@ class ScenarioMetadata(Identifier): Use get_class() to get the actual class. """ - default_strategy: str - all_strategies: tuple[str, ...] - aggregate_strategies: tuple[str, ...] - default_datasets: tuple[str, ...] - max_dataset_size: Optional[int] + default_strategy: str = field(kw_only=True) + all_strategies: tuple[str, ...] = field(kw_only=True) + aggregate_strategies: tuple[str, ...] = field(kw_only=True) + default_datasets: tuple[str, ...] = field(kw_only=True) + max_dataset_size: Optional[int] = field(kw_only=True) class ScenarioRegistry(BaseClassRegistry["Scenario", ScenarioMetadata]): diff --git a/tests/unit/analytics/test_result_analysis.py b/tests/unit/analytics/test_result_analysis.py index 44b2a56e8a..812126b79e 100644 --- a/tests/unit/analytics/test_result_analysis.py +++ b/tests/unit/analytics/test_result_analysis.py @@ -4,6 +4,7 @@ import pytest from pyrit.analytics.result_analysis import AttackStats, analyze_results +from pyrit.identifiers import AttackIdentifier from pyrit.models import AttackOutcome, AttackResult @@ -16,9 +17,9 @@ def make_attack( """ Minimal valid AttackResult for analytics tests. """ - attack_identifier: dict[str, str] = {} + attack_identifier: AttackIdentifier | None = None if attack_type is not None: - attack_identifier["type"] = attack_type + attack_identifier = AttackIdentifier(class_name=attack_type, class_module="tests.unit.analytics") return AttackResult( conversation_id=conversation_id, diff --git a/tests/unit/docs/test_api_documentation.py b/tests/unit/docs/test_api_documentation.py index f67425aa19..48db273589 100644 --- a/tests/unit/docs/test_api_documentation.py +++ b/tests/unit/docs/test_api_documentation.py @@ -119,7 +119,7 @@ def get_module_exports(module_path: str) -> Set[str]: "exclude": set(), }, "pyrit.identifiers": { - "exclude": {"LegacyIdentifiable"}, + "exclude": set(), }, } diff --git a/tests/unit/executor/attack/component/test_conversation_manager.py b/tests/unit/executor/attack/component/test_conversation_manager.py index b874169d44..251fd023d6 100644 --- a/tests/unit/executor/attack/component/test_conversation_manager.py +++ b/tests/unit/executor/attack/component/test_conversation_manager.py @@ -34,7 +34,7 @@ ) from pyrit.executor.attack.core import AttackContext from pyrit.executor.attack.core.attack_parameters import AttackParameters -from pyrit.identifiers import TargetIdentifier +from pyrit.identifiers import AttackIdentifier, TargetIdentifier from pyrit.models import Message, MessagePiece, Score from pyrit.prompt_normalizer import PromptConverterConfiguration, PromptNormalizer from pyrit.prompt_target import PromptChatTarget, PromptTarget @@ -68,13 +68,12 @@ class _TestAttackContext(AttackContext): @pytest.fixture -def attack_identifier() -> Dict[str, str]: +def attack_identifier() -> AttackIdentifier: """Create a sample attack identifier.""" - return { - "__type__": "TestAttack", - "__module__": "pyrit.executor.attack.test_attack", - "id": str(uuid.uuid4()), - } + return AttackIdentifier( + class_name="TestAttack", + class_module="pyrit.executor.attack.test_attack", + ) @pytest.fixture @@ -246,7 +245,7 @@ def test_swaps_user_to_assistant(self) -> None: result = get_adversarial_chat_messages( messages, adversarial_chat_conversation_id="adversarial_conv", - attack_identifier={"__type__": "TestAttack"}, + attack_identifier=AttackIdentifier(class_name="TestAttack", class_module="test_module"), adversarial_chat_target_identifier={"id": "adversarial_target"}, ) @@ -262,7 +261,7 @@ def test_swaps_assistant_to_user(self) -> None: result = get_adversarial_chat_messages( messages, adversarial_chat_conversation_id="adversarial_conv", - attack_identifier={"__type__": "TestAttack"}, + attack_identifier=AttackIdentifier(class_name="TestAttack", class_module="test_module"), adversarial_chat_target_identifier={"id": "adversarial_target"}, ) @@ -281,7 +280,7 @@ def test_swaps_simulated_assistant_to_user(self) -> None: result = get_adversarial_chat_messages( messages, adversarial_chat_conversation_id="adversarial_conv", - attack_identifier={"__type__": "TestAttack"}, + attack_identifier=AttackIdentifier(class_name="TestAttack", class_module="test_module"), adversarial_chat_target_identifier={"id": "adversarial_target"}, ) @@ -300,7 +299,7 @@ def test_skips_system_messages(self) -> None: result = get_adversarial_chat_messages( messages, adversarial_chat_conversation_id="adversarial_conv", - attack_identifier={"__type__": "TestAttack"}, + attack_identifier=AttackIdentifier(class_name="TestAttack", class_module="test_module"), adversarial_chat_target_identifier={"id": "adversarial_target"}, ) @@ -317,7 +316,7 @@ def test_assigns_new_uuids(self) -> None: result = get_adversarial_chat_messages( messages, adversarial_chat_conversation_id="adversarial_conv", - attack_identifier={"__type__": "TestAttack"}, + attack_identifier=AttackIdentifier(class_name="TestAttack", class_module="test_module"), adversarial_chat_target_identifier={"id": "adversarial_target"}, ) @@ -339,7 +338,7 @@ def test_preserves_message_content(self) -> None: result = get_adversarial_chat_messages( messages, adversarial_chat_conversation_id="adversarial_conv", - attack_identifier={"__type__": "TestAttack"}, + attack_identifier=AttackIdentifier(class_name="TestAttack", class_module="test_module"), adversarial_chat_target_identifier={"id": "adversarial_target"}, ) @@ -351,7 +350,7 @@ def test_empty_prepended_conversation(self) -> None: result = get_adversarial_chat_messages( [], adversarial_chat_conversation_id="adversarial_conv", - attack_identifier={"__type__": "TestAttack"}, + attack_identifier=AttackIdentifier(class_name="TestAttack", class_module="test_module"), adversarial_chat_target_identifier={"id": "adversarial_target"}, ) @@ -366,7 +365,7 @@ def test_applies_labels(self) -> None: result = get_adversarial_chat_messages( messages, adversarial_chat_conversation_id="adversarial_conv", - attack_identifier={"__type__": "TestAttack"}, + attack_identifier=AttackIdentifier(class_name="TestAttack", class_module="test_module"), adversarial_chat_target_identifier={"id": "adversarial_target"}, labels=labels, ) @@ -476,7 +475,7 @@ def test_with_custom_values(self, sample_score: Score) -> None: class TestConversationManagerInitialization: """Tests for ConversationManager initialization.""" - def test_init_with_required_parameters(self, attack_identifier: Dict[str, str]) -> None: + def test_init_with_required_parameters(self, attack_identifier: AttackIdentifier) -> None: """Test initialization with only required parameters.""" manager = ConversationManager(attack_identifier=attack_identifier) @@ -485,7 +484,7 @@ def test_init_with_required_parameters(self, attack_identifier: Dict[str, str]) assert manager._memory is not None def test_init_with_custom_prompt_normalizer( - self, attack_identifier: Dict[str, str], mock_prompt_normalizer: MagicMock + self, attack_identifier: AttackIdentifier, mock_prompt_normalizer: MagicMock ) -> None: """Test initialization with a custom prompt normalizer.""" manager = ConversationManager(attack_identifier=attack_identifier, prompt_normalizer=mock_prompt_normalizer) @@ -502,7 +501,7 @@ def test_init_with_custom_prompt_normalizer( class TestConversationRetrieval: """Tests for conversation retrieval methods.""" - def test_get_conversation_returns_empty_list_when_no_messages(self, attack_identifier: Dict[str, str]) -> None: + def test_get_conversation_returns_empty_list_when_no_messages(self, attack_identifier: AttackIdentifier) -> None: """Test get_conversation returns empty list for non-existent conversation.""" manager = ConversationManager(attack_identifier=attack_identifier) conversation_id = str(uuid.uuid4()) @@ -512,7 +511,7 @@ def test_get_conversation_returns_empty_list_when_no_messages(self, attack_ident assert result == [] def test_get_conversation_returns_messages_in_order( - self, attack_identifier: Dict[str, str], sample_conversation: List[Message] + self, attack_identifier: AttackIdentifier, sample_conversation: List[Message] ) -> None: """Test get_conversation returns messages in order.""" manager = ConversationManager(attack_identifier=attack_identifier) @@ -530,7 +529,7 @@ def test_get_conversation_returns_messages_in_order( assert result[0].message_pieces[0].api_role == "user" assert result[1].message_pieces[0].api_role == "assistant" - def test_get_last_message_returns_none_for_empty_conversation(self, attack_identifier: Dict[str, str]) -> None: + def test_get_last_message_returns_none_for_empty_conversation(self, attack_identifier: AttackIdentifier) -> None: """Test get_last_message returns None for empty conversation.""" manager = ConversationManager(attack_identifier=attack_identifier) conversation_id = str(uuid.uuid4()) @@ -540,7 +539,7 @@ def test_get_last_message_returns_none_for_empty_conversation(self, attack_ident assert result is None def test_get_last_message_returns_last_piece( - self, attack_identifier: Dict[str, str], sample_conversation: List[Message] + self, attack_identifier: AttackIdentifier, sample_conversation: List[Message] ) -> None: """Test get_last_message returns the most recent message.""" manager = ConversationManager(attack_identifier=attack_identifier) @@ -558,7 +557,7 @@ def test_get_last_message_returns_last_piece( assert result.api_role == "assistant" def test_get_last_message_with_role_filter( - self, attack_identifier: Dict[str, str], sample_conversation: List[Message] + self, attack_identifier: AttackIdentifier, sample_conversation: List[Message] ) -> None: """Test get_last_message with role filter returns correct message.""" manager = ConversationManager(attack_identifier=attack_identifier) @@ -577,7 +576,7 @@ def test_get_last_message_with_role_filter( assert result.api_role == "user" def test_get_last_message_with_role_filter_returns_none_when_no_match( - self, attack_identifier: Dict[str, str], sample_conversation: List[Message] + self, attack_identifier: AttackIdentifier, sample_conversation: List[Message] ) -> None: """Test get_last_message returns None when no message matches role filter.""" manager = ConversationManager(attack_identifier=attack_identifier) @@ -605,7 +604,7 @@ class TestSystemPromptHandling: """Tests for system prompt functionality.""" def test_set_system_prompt_with_chat_target( - self, attack_identifier: Dict[str, str], mock_chat_target: MagicMock + self, attack_identifier: AttackIdentifier, mock_chat_target: MagicMock ) -> None: """Test set_system_prompt calls target's set_system_prompt method.""" manager = ConversationManager(attack_identifier=attack_identifier) @@ -628,7 +627,7 @@ def test_set_system_prompt_with_chat_target( ) def test_set_system_prompt_without_labels( - self, attack_identifier: Dict[str, str], mock_chat_target: MagicMock + self, attack_identifier: AttackIdentifier, mock_chat_target: MagicMock ) -> None: """Test set_system_prompt works without labels.""" manager = ConversationManager(attack_identifier=attack_identifier) @@ -658,7 +657,7 @@ class TestInitializeContext: @pytest.mark.asyncio async def test_raises_error_for_empty_conversation_id( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_chat_target: MagicMock, mock_attack_context: AttackContext, ) -> None: @@ -675,7 +674,7 @@ async def test_raises_error_for_empty_conversation_id( @pytest.mark.asyncio async def test_returns_default_state_for_no_prepended_conversation( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_chat_target: MagicMock, mock_attack_context: AttackContext, ) -> None: @@ -696,7 +695,7 @@ async def test_returns_default_state_for_no_prepended_conversation( @pytest.mark.asyncio async def test_merges_memory_labels( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_chat_target: MagicMock, ) -> None: """Test that memory_labels are merged with context labels.""" @@ -719,7 +718,7 @@ async def test_merges_memory_labels( @pytest.mark.asyncio async def test_adds_prepended_conversation_to_memory_for_chat_target( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_chat_target: MagicMock, sample_conversation: List[Message], ) -> None: @@ -742,7 +741,7 @@ async def test_adds_prepended_conversation_to_memory_for_chat_target( @pytest.mark.asyncio async def test_converts_assistant_to_simulated_assistant( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_chat_target: MagicMock, sample_assistant_piece: MessagePiece, ) -> None: @@ -767,7 +766,7 @@ async def test_converts_assistant_to_simulated_assistant( @pytest.mark.asyncio async def test_normalizes_for_non_chat_target_by_default( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_prompt_target: MagicMock, sample_conversation: List[Message], ) -> None: @@ -793,7 +792,7 @@ async def test_normalizes_for_non_chat_target_by_default( @pytest.mark.asyncio async def test_normalizes_for_non_chat_target_when_configured( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_prompt_target: MagicMock, sample_conversation: List[Message], ) -> None: @@ -822,7 +821,7 @@ async def test_normalizes_for_non_chat_target_when_configured( @pytest.mark.asyncio async def test_returns_turn_count_for_multi_turn_attacks( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_chat_target: MagicMock, sample_conversation: List[Message], ) -> None: @@ -845,7 +844,7 @@ async def test_returns_turn_count_for_multi_turn_attacks( @pytest.mark.asyncio async def test_multipart_message_extracts_scores_from_all_pieces( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_chat_target: MagicMock, sample_score: Score, ) -> None: @@ -919,7 +918,7 @@ async def test_multipart_message_extracts_scores_from_all_pieces( @pytest.mark.asyncio async def test_prepended_conversation_ignores_true_scores( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_chat_target: MagicMock, ) -> None: """Test that prepended conversations only extract false scores, ignoring true scores. @@ -1023,7 +1022,7 @@ class TestPrependedConversationConfigSettings: @pytest.mark.asyncio async def test_non_chat_target_behavior_normalize_is_default( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_prompt_target: MagicMock, sample_conversation: List[Message], ) -> None: @@ -1049,7 +1048,7 @@ async def test_non_chat_target_behavior_normalize_is_default( @pytest.mark.asyncio async def test_non_chat_target_behavior_raise_explicit( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_prompt_target: MagicMock, sample_conversation: List[Message], ) -> None: @@ -1074,7 +1073,7 @@ async def test_non_chat_target_behavior_raise_explicit( @pytest.mark.asyncio async def test_non_chat_target_behavior_normalize_first_turn_creates_next_message( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_prompt_target: MagicMock, sample_conversation: List[Message], ) -> None: @@ -1102,7 +1101,7 @@ async def test_non_chat_target_behavior_normalize_first_turn_creates_next_messag @pytest.mark.asyncio async def test_non_chat_target_behavior_normalize_first_turn_prepends_to_existing_message( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_prompt_target: MagicMock, sample_conversation: List[Message], ) -> None: @@ -1132,7 +1131,7 @@ async def test_non_chat_target_behavior_normalize_first_turn_prepends_to_existin @pytest.mark.asyncio async def test_non_chat_target_behavior_normalize_returns_empty_state( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_prompt_target: MagicMock, sample_conversation: List[Message], ) -> None: @@ -1162,7 +1161,7 @@ async def test_non_chat_target_behavior_normalize_returns_empty_state( @pytest.mark.asyncio async def test_apply_converters_to_roles_default_applies_to_all( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_chat_target: MagicMock, sample_conversation: List[Message], ) -> None: @@ -1189,7 +1188,7 @@ async def test_apply_converters_to_roles_default_applies_to_all( @pytest.mark.asyncio async def test_apply_converters_to_roles_user_only( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_chat_target: MagicMock, sample_conversation: List[Message], ) -> None: @@ -1218,7 +1217,7 @@ async def test_apply_converters_to_roles_user_only( @pytest.mark.asyncio async def test_apply_converters_to_roles_assistant_only( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_chat_target: MagicMock, sample_conversation: List[Message], ) -> None: @@ -1247,7 +1246,7 @@ async def test_apply_converters_to_roles_assistant_only( @pytest.mark.asyncio async def test_apply_converters_to_roles_empty_list_skips_all( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_chat_target: MagicMock, sample_conversation: List[Message], ) -> None: @@ -1280,7 +1279,7 @@ async def test_apply_converters_to_roles_empty_list_skips_all( @pytest.mark.asyncio async def test_message_normalizer_default_uses_conversation_context_normalizer( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_prompt_target: MagicMock, sample_conversation: List[Message], ) -> None: @@ -1308,7 +1307,7 @@ async def test_message_normalizer_default_uses_conversation_context_normalizer( @pytest.mark.asyncio async def test_message_normalizer_custom_normalizer_is_used( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_prompt_target: MagicMock, sample_conversation: List[Message], ) -> None: @@ -1389,7 +1388,7 @@ def test_for_non_chat_target_with_custom_roles(self) -> None: @pytest.mark.asyncio async def test_chat_target_ignores_non_chat_target_behavior( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_chat_target: MagicMock, sample_conversation: List[Message], ) -> None: @@ -1421,7 +1420,7 @@ async def test_chat_target_ignores_non_chat_target_behavior( @pytest.mark.asyncio async def test_config_with_max_turns_validation( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_chat_target: MagicMock, ) -> None: """Test that config works correctly with max_turns validation.""" @@ -1471,7 +1470,7 @@ class TestAddPrependedConversationToMemory: @pytest.mark.asyncio async def test_adds_messages_to_memory( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, sample_conversation: List[Message], ) -> None: """Test that messages are added to memory.""" @@ -1490,7 +1489,7 @@ async def test_adds_messages_to_memory( @pytest.mark.asyncio async def test_assigns_conversation_id_to_all_pieces( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, sample_conversation: List[Message], ) -> None: """Test that conversation_id is assigned to all message pieces.""" @@ -1510,7 +1509,7 @@ async def test_assigns_conversation_id_to_all_pieces( @pytest.mark.asyncio async def test_assigns_attack_identifier_to_all_pieces( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, sample_conversation: List[Message], ) -> None: """Test that attack_identifier is assigned to all message pieces.""" @@ -1530,7 +1529,7 @@ async def test_assigns_attack_identifier_to_all_pieces( @pytest.mark.asyncio async def test_raises_error_when_exceeds_max_turns( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, sample_user_piece: MessagePiece, sample_assistant_piece: MessagePiece, ) -> None: @@ -1556,7 +1555,7 @@ async def test_raises_error_when_exceeds_max_turns( @pytest.mark.asyncio async def test_multipart_response_counts_as_one_turn( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, ) -> None: """Test that a multi-part assistant response counts as only one turn.""" manager = ConversationManager(attack_identifier=attack_identifier) @@ -1595,7 +1594,7 @@ async def test_multipart_response_counts_as_one_turn( @pytest.mark.asyncio async def test_returns_zero_for_empty_conversation( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, ) -> None: """Test that empty conversation returns 0 turns.""" manager = ConversationManager(attack_identifier=attack_identifier) @@ -1611,7 +1610,7 @@ async def test_returns_zero_for_empty_conversation( @pytest.mark.asyncio async def test_applies_converters_when_provided( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_prompt_normalizer: MagicMock, sample_user_piece: MessagePiece, ) -> None: @@ -1633,7 +1632,7 @@ async def test_applies_converters_when_provided( @pytest.mark.asyncio async def test_handles_none_messages_gracefully( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, ) -> None: """Test that None messages are handled gracefully.""" manager = ConversationManager(attack_identifier=attack_identifier) @@ -1659,7 +1658,7 @@ class TestEdgeCasesAndErrorHandling: @pytest.mark.asyncio async def test_preserves_piece_metadata( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_chat_target: MagicMock, sample_user_piece: MessagePiece, ) -> None: @@ -1688,7 +1687,7 @@ async def test_preserves_piece_metadata( @pytest.mark.asyncio async def test_preserves_original_and_converted_values( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_chat_target: MagicMock, sample_user_piece: MessagePiece, ) -> None: @@ -1716,7 +1715,7 @@ async def test_preserves_original_and_converted_values( @pytest.mark.asyncio async def test_handles_system_messages_in_prepended_conversation( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_chat_target: MagicMock, sample_system_piece: MessagePiece, sample_user_piece: MessagePiece, diff --git a/tests/unit/executor/attack/core/test_attack_strategy.py b/tests/unit/executor/attack/core/test_attack_strategy.py index 918b9c7531..9c98edb38c 100644 --- a/tests/unit/executor/attack/core/test_attack_strategy.py +++ b/tests/unit/executor/attack/core/test_attack_strategy.py @@ -69,7 +69,6 @@ def sample_attack_result(): result = AttackResult( conversation_id="test-conversation-id", objective="Test objective", - attack_identifier={"name": "test_attack"}, outcome=AttackOutcome.SUCCESS, outcome_reason="Test successful", execution_time_ms=0, @@ -112,7 +111,6 @@ async def _perform_async(self, *, context): result = AttackResult( conversation_id="test-conversation-id", objective="Test objective", - attack_identifier={"name": "test_attack"}, outcome=AttackOutcome.SUCCESS, outcome_reason="Test successful", execution_time_ms=0, @@ -144,7 +142,6 @@ async def _perform_async(self, *, context): return AttackResult( conversation_id="test-conversation-id", objective="Test objective", - attack_identifier={"name": "test_attack"}, outcome=AttackOutcome.SUCCESS, outcome_reason="Test successful", execution_time_ms=0, @@ -176,7 +173,6 @@ async def _perform_async(self, *, context): return AttackResult( conversation_id="test-conversation-id", objective="Test objective", - attack_identifier={"name": "test_attack"}, outcome=AttackOutcome.SUCCESS, outcome_reason="Test successful", execution_time_ms=0, @@ -208,7 +204,6 @@ async def _perform_async(self, *, context): return AttackResult( conversation_id="test-conversation-id", objective="Test objective", - attack_identifier={"name": "test_attack"}, outcome=AttackOutcome.SUCCESS, outcome_reason="Test successful", execution_time_ms=0, @@ -497,7 +492,6 @@ async def _perform_async(self, *, context): result = AttackResult( conversation_id="test-conversation-id", objective="Test objective", - attack_identifier={"name": "test_attack"}, outcome=AttackOutcome.SUCCESS, outcome_reason="Test successful", executed_turns=1, @@ -542,7 +536,6 @@ async def _perform_async(self, *, context): result = AttackResult( conversation_id="test-conversation-id", objective="Test objective", - attack_identifier={"name": "test_attack"}, outcome=AttackOutcome.SUCCESS, outcome_reason="Test successful", execution_time_ms=0, diff --git a/tests/unit/executor/attack/core/test_markdown_printer.py b/tests/unit/executor/attack/core/test_markdown_printer.py index 8a3c430269..a66fcb4260 100644 --- a/tests/unit/executor/attack/core/test_markdown_printer.py +++ b/tests/unit/executor/attack/core/test_markdown_printer.py @@ -8,7 +8,7 @@ import pytest from pyrit.executor.attack.printer.markdown_printer import MarkdownAttackResultPrinter -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import AttackIdentifier, ScorerIdentifier from pyrit.memory import CentralMemory from pyrit.models import AttackOutcome, AttackResult, Message, MessagePiece, Score @@ -69,7 +69,7 @@ def sample_float_score(): def sample_attack_result(): return AttackResult( objective="Test objective", - attack_identifier={"__type__": "TestAttack"}, + attack_identifier=AttackIdentifier(class_name="TestAttack", class_module="test_module"), conversation_id="test-conv-123", executed_turns=3, execution_time_ms=1500, diff --git a/tests/unit/executor/attack/multi_turn/test_chunked_request.py b/tests/unit/executor/attack/multi_turn/test_chunked_request.py index b3bf1e8949..3c6736f586 100644 --- a/tests/unit/executor/attack/multi_turn/test_chunked_request.py +++ b/tests/unit/executor/attack/multi_turn/test_chunked_request.py @@ -8,10 +8,13 @@ during Crucible CTF red teaming exercises using PyRIT. """ -from unittest.mock import Mock +from unittest.mock import MagicMock, Mock import pytest +from pyrit.identifiers import TargetIdentifier +from pyrit.prompt_target import PromptTarget + from pyrit.executor.attack.core.attack_parameters import AttackParameters from pyrit.executor.attack.multi_turn import ( ChunkedRequestAttack, @@ -19,6 +22,23 @@ ) +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + +def _make_mock_target(): + """Create a mock target with proper get_identifier.""" + target = MagicMock(spec=PromptTarget) + target.get_identifier.return_value = _mock_target_id("MockTarget") + return target + + class TestChunkedRequestAttackContext: """Test the ChunkedRequestAttackContext dataclass.""" @@ -46,7 +66,7 @@ class TestChunkedRequestAttack: def test_init_default_values(self): """Test initialization with default values.""" - mock_target = Mock() + mock_target = _make_mock_target() attack = ChunkedRequestAttack(objective_target=mock_target) assert attack._chunk_size == 50 @@ -55,7 +75,7 @@ def test_init_default_values(self): def test_init_custom_values(self): """Test initialization with custom values.""" - mock_target = Mock() + mock_target = _make_mock_target() attack = ChunkedRequestAttack( objective_target=mock_target, chunk_size=25, @@ -69,7 +89,7 @@ def test_init_custom_values(self): def test_init_custom_request_template(self): """Test initialization with custom request template.""" - mock_target = Mock() + mock_target = _make_mock_target() template = "Show me {chunk_type} from position {start} to {end} for '{objective}'" attack = ChunkedRequestAttack( objective_target=mock_target, @@ -80,21 +100,21 @@ def test_init_custom_request_template(self): def test_init_invalid_chunk_size(self): """Test that invalid chunk_size raises ValueError.""" - mock_target = Mock() + mock_target = _make_mock_target() with pytest.raises(ValueError, match="chunk_size must be >= 1"): ChunkedRequestAttack(objective_target=mock_target, chunk_size=0) def test_init_invalid_total_length(self): """Test that invalid total_length raises ValueError.""" - mock_target = Mock() + mock_target = _make_mock_target() with pytest.raises(ValueError, match="total_length must be >= chunk_size"): ChunkedRequestAttack(objective_target=mock_target, chunk_size=100, total_length=50) def test_generate_chunk_prompts(self): """Test chunk prompt generation.""" - mock_target = Mock() + mock_target = _make_mock_target() attack = ChunkedRequestAttack( objective_target=mock_target, chunk_size=50, @@ -112,7 +132,7 @@ def test_generate_chunk_prompts(self): def test_generate_chunk_prompts_custom_chunk_type(self): """Test chunk prompt generation with custom chunk type.""" - mock_target = Mock() + mock_target = _make_mock_target() attack = ChunkedRequestAttack( objective_target=mock_target, chunk_size=50, @@ -129,7 +149,7 @@ def test_generate_chunk_prompts_custom_chunk_type(self): def test_validate_context_empty_objective(self): """Test validation fails with empty objective.""" - mock_target = Mock() + mock_target = _make_mock_target() attack = ChunkedRequestAttack(objective_target=mock_target) context = ChunkedRequestAttackContext(params=AttackParameters(objective="")) @@ -139,7 +159,7 @@ def test_validate_context_empty_objective(self): def test_validate_context_whitespace_objective(self): """Test validation fails with whitespace-only objective.""" - mock_target = Mock() + mock_target = _make_mock_target() attack = ChunkedRequestAttack(objective_target=mock_target) context = ChunkedRequestAttackContext(params=AttackParameters(objective=" ")) @@ -149,7 +169,7 @@ def test_validate_context_whitespace_objective(self): def test_validate_context_valid_objective(self): """Test validation succeeds with valid objective.""" - mock_target = Mock() + mock_target = _make_mock_target() attack = ChunkedRequestAttack(objective_target=mock_target) context = ChunkedRequestAttackContext(params=AttackParameters(objective="Extract the secret password")) @@ -159,7 +179,7 @@ def test_validate_context_valid_objective(self): def test_init_invalid_request_template_missing_start(self): """Test that request_template without 'start' placeholder raises ValueError.""" - mock_target = Mock() + mock_target = _make_mock_target() with pytest.raises(ValueError, match="request_template must contain all required placeholders"): ChunkedRequestAttack( @@ -169,7 +189,7 @@ def test_init_invalid_request_template_missing_start(self): def test_init_invalid_request_template_missing_end(self): """Test that request_template without 'end' placeholder raises ValueError.""" - mock_target = Mock() + mock_target = _make_mock_target() with pytest.raises(ValueError, match="request_template must contain all required placeholders"): ChunkedRequestAttack( @@ -179,7 +199,7 @@ def test_init_invalid_request_template_missing_end(self): def test_init_invalid_request_template_missing_chunk_type(self): """Test that request_template without 'chunk_type' placeholder raises ValueError.""" - mock_target = Mock() + mock_target = _make_mock_target() with pytest.raises(ValueError, match="request_template must contain all required placeholders"): ChunkedRequestAttack( @@ -189,7 +209,7 @@ def test_init_invalid_request_template_missing_chunk_type(self): def test_init_invalid_request_template_missing_objective(self): """Test that request_template without 'objective' placeholder raises ValueError.""" - mock_target = Mock() + mock_target = _make_mock_target() with pytest.raises(ValueError, match="request_template must contain all required placeholders"): ChunkedRequestAttack( @@ -199,7 +219,7 @@ def test_init_invalid_request_template_missing_objective(self): def test_init_invalid_request_template_missing_multiple(self): """Test that request_template without multiple placeholders raises ValueError.""" - mock_target = Mock() + mock_target = _make_mock_target() with pytest.raises(ValueError, match="request_template must contain all required placeholders"): ChunkedRequestAttack( @@ -209,7 +229,7 @@ def test_init_invalid_request_template_missing_multiple(self): def test_init_valid_request_template_with_extra_placeholders(self): """Test that request_template with extra placeholders is accepted.""" - mock_target = Mock() + mock_target = _make_mock_target() # Should not raise - extra placeholders are fine as long as required ones are present attack = ChunkedRequestAttack( @@ -221,7 +241,7 @@ def test_init_valid_request_template_with_extra_placeholders(self): def test_generate_chunk_prompts_with_objective(self): """Test that chunk prompts include the objective from context.""" - mock_target = Mock() + mock_target = _make_mock_target() attack = ChunkedRequestAttack( objective_target=mock_target, chunk_size=50, diff --git a/tests/unit/executor/attack/multi_turn/test_multi_prompt_sending.py b/tests/unit/executor/attack/multi_turn/test_multi_prompt_sending.py index bdf70a77d2..dc3e19952c 100644 --- a/tests/unit/executor/attack/multi_turn/test_multi_prompt_sending.py +++ b/tests/unit/executor/attack/multi_turn/test_multi_prompt_sending.py @@ -63,6 +63,7 @@ def mock_true_false_scorer(): """Create a mock true/false scorer for testing""" scorer = MagicMock(spec=TrueFalseScorer) scorer.score_async = AsyncMock() + scorer.get_identifier.return_value = _mock_scorer_id() return scorer @@ -70,6 +71,7 @@ def mock_true_false_scorer(): def mock_non_true_false_scorer(): """Create a mock scorer that is not a true/false type""" scorer = MagicMock(spec=Scorer) + scorer.get_identifier.return_value = _mock_scorer_id() return scorer @@ -162,7 +164,8 @@ def test_init_with_valid_true_false_scorer(self, mock_target, mock_true_false_sc def test_init_with_all_custom_configurations(self, mock_target, mock_true_false_scorer, mock_prompt_normalizer): converter_cfg = AttackConverterConfig( - request_converters=[Base64Converter()], response_converters=[StringJoinConverter()] + request_converters=[PromptConverterConfiguration(converters=[Base64Converter()])], + response_converters=[PromptConverterConfiguration(converters=[StringJoinConverter()])], ) scoring_cfg = AttackScoringConfig(objective_scorer=mock_true_false_scorer) @@ -604,7 +607,9 @@ class TestConverterIntegration: async def test_perform_attack_with_converters( self, mock_target, mock_prompt_normalizer, basic_context, sample_response ): - converter_config = AttackConverterConfig(request_converters=[Base64Converter()]) + converter_config = AttackConverterConfig( + request_converters=[PromptConverterConfiguration(converters=[Base64Converter()])] + ) mock_prompt_normalizer.send_prompt_async.return_value = sample_response attack = MultiPromptSendingAttack( @@ -623,7 +628,9 @@ async def test_perform_attack_with_converters( async def test_perform_attack_with_response_converters( self, mock_target, mock_prompt_normalizer, basic_context, sample_response ): - converter_config = AttackConverterConfig(response_converters=[StringJoinConverter()]) + converter_config = AttackConverterConfig( + response_converters=[PromptConverterConfiguration(converters=[StringJoinConverter()])] + ) mock_prompt_normalizer.send_prompt_async.return_value = sample_response attack = MultiPromptSendingAttack( @@ -683,11 +690,13 @@ async def test_perform_attack_with_single_prompt(self, mock_target, mock_prompt_ assert result.last_response is not None assert mock_prompt_normalizer.send_prompt_async.call_count == 1 - def test_attack_has_unique_identifier(self, mock_target): + def test_attack_has_same_identifier_for_same_config(self, mock_target): attack1 = MultiPromptSendingAttack(objective_target=mock_target) attack2 = MultiPromptSendingAttack(objective_target=mock_target) - assert attack1.get_identifier() != attack2.get_identifier() + # Same config produces the same deterministic identifier + assert attack1.get_identifier().hash == attack2.get_identifier().hash + assert attack1.get_identifier().class_name == "MultiPromptSendingAttack" @pytest.mark.asyncio async def test_teardown_async_is_noop(self, mock_target, basic_context): diff --git a/tests/unit/executor/attack/multi_turn/test_red_teaming.py b/tests/unit/executor/attack/multi_turn/test_red_teaming.py index 340dc10c60..e6ed898b4a 100644 --- a/tests/unit/executor/attack/multi_turn/test_red_teaming.py +++ b/tests/unit/executor/attack/multi_turn/test_red_teaming.py @@ -1230,7 +1230,9 @@ async def test_perform_attack_with_message_bypasses_adversarial_chat_on_first_tu ): """Test that providing a message parameter bypasses adversarial chat generation on first turn.""" adversarial_config = AttackAdversarialConfig(target=mock_adversarial_chat) - scoring_config = AttackScoringConfig(objective_scorer=MagicMock(spec=TrueFalseScorer)) + inline_scorer = MagicMock(spec=TrueFalseScorer) + inline_scorer.get_identifier.return_value = _mock_scorer_id() + scoring_config = AttackScoringConfig(objective_scorer=inline_scorer) attack = RedTeamingAttack( objective_target=mock_objective_target, @@ -1272,7 +1274,9 @@ async def test_perform_attack_with_multi_piece_message_uses_first_piece( ): """Test that multi-piece messages use only the first piece's converted_value.""" adversarial_config = AttackAdversarialConfig(target=mock_adversarial_chat) - scoring_config = AttackScoringConfig(objective_scorer=MagicMock(spec=TrueFalseScorer)) + inline_scorer = MagicMock(spec=TrueFalseScorer) + inline_scorer.get_identifier.return_value = _mock_scorer_id() + scoring_config = AttackScoringConfig(objective_scorer=inline_scorer) attack = RedTeamingAttack( objective_target=mock_objective_target, diff --git a/tests/unit/executor/attack/single_turn/test_context_compliance.py b/tests/unit/executor/attack/single_turn/test_context_compliance.py index 3908cb89fe..345a8a8b36 100644 --- a/tests/unit/executor/attack/single_turn/test_context_compliance.py +++ b/tests/unit/executor/attack/single_turn/test_context_compliance.py @@ -15,7 +15,7 @@ ContextComplianceAttack, SingleTurnAttackContext, ) -from pyrit.identifiers import TargetIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.models import ( Message, MessagePiece, @@ -37,6 +37,17 @@ def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: ) + +def _mock_scorer_id(name: str = "MockScorer") -> ScorerIdentifier: + """Helper to create ScorerIdentifier for tests.""" + return ScorerIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_objective_target(): """Create a mock PromptChatTarget for testing""" @@ -68,6 +79,7 @@ def mock_scorer(): """Create a mock true/false scorer""" scorer = MagicMock(spec=TrueFalseScorer) scorer.score_text_async = AsyncMock() + scorer.get_identifier.return_value = _mock_scorer_id() return scorer diff --git a/tests/unit/executor/attack/single_turn/test_flip_attack.py b/tests/unit/executor/attack/single_turn/test_flip_attack.py index faffd77099..db23e09320 100644 --- a/tests/unit/executor/attack/single_turn/test_flip_attack.py +++ b/tests/unit/executor/attack/single_turn/test_flip_attack.py @@ -13,7 +13,7 @@ FlipAttack, SingleTurnAttackContext, ) -from pyrit.identifiers import TargetIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.models import ( AttackOutcome, AttackResult, @@ -34,6 +34,17 @@ def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: ) + +def _mock_scorer_id(name: str = "MockScorer") -> ScorerIdentifier: + """Helper to create ScorerIdentifier for tests.""" + return ScorerIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_objective_target(): """Create a mock PromptChatTarget for testing""" @@ -54,6 +65,7 @@ def mock_scorer(): """Create a mock true/false scorer""" scorer = MagicMock(spec=TrueFalseScorer) scorer.score_text_async = AsyncMock() + scorer.get_identifier.return_value = _mock_scorer_id() return scorer diff --git a/tests/unit/executor/attack/single_turn/test_many_shot_jailbreak.py b/tests/unit/executor/attack/single_turn/test_many_shot_jailbreak.py index 9f9877c325..c81166ea07 100644 --- a/tests/unit/executor/attack/single_turn/test_many_shot_jailbreak.py +++ b/tests/unit/executor/attack/single_turn/test_many_shot_jailbreak.py @@ -13,7 +13,7 @@ ManyShotJailbreakAttack, SingleTurnAttackContext, ) -from pyrit.identifiers import TargetIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.models import ( AttackOutcome, AttackResult, @@ -35,6 +35,17 @@ def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: ) + +def _mock_scorer_id(name: str = "MockScorer") -> ScorerIdentifier: + """Helper to create ScorerIdentifier for tests.""" + return ScorerIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_objective_target(): """Create a mock PromptTarget for testing""" @@ -67,6 +78,7 @@ def mock_scorer(): """Create a mock true/false scorer""" scorer = MagicMock(spec=TrueFalseScorer) scorer.score_text_async = AsyncMock() + scorer.get_identifier.return_value = _mock_scorer_id() return scorer diff --git a/tests/unit/executor/attack/single_turn/test_prompt_sending.py b/tests/unit/executor/attack/single_turn/test_prompt_sending.py index b55da7cf3c..1e1b708249 100644 --- a/tests/unit/executor/attack/single_turn/test_prompt_sending.py +++ b/tests/unit/executor/attack/single_turn/test_prompt_sending.py @@ -44,6 +44,7 @@ def mock_true_false_scorer(): """Create a mock true/false scorer for testing""" scorer = MagicMock(spec=TrueFalseScorer) scorer.score_text_async = AsyncMock() + scorer.get_identifier.return_value = get_mock_scorer_identifier() return scorer @@ -51,6 +52,7 @@ def mock_true_false_scorer(): def mock_non_true_false_scorer(): """Create a mock scorer that is not a true/false type""" scorer = MagicMock(spec=Scorer) + scorer.get_identifier.return_value = get_mock_scorer_identifier() return scorer @@ -1146,13 +1148,13 @@ def test_attack_has_unique_identifier(self, mock_target): id2 = attack2.get_identifier() # Verify identifier structure - assert "__type__" in id1 - assert "__module__" in id1 - assert "id" in id1 + assert id1.class_name == "PromptSendingAttack" + assert id1.class_module is not None + assert id1.hash is not None - # Verify uniqueness - assert id1["id"] != id2["id"] - assert id1["__type__"] == id2["__type__"] == "PromptSendingAttack" + # Same config produces same identifier + assert id1.hash == id2.hash + assert id1.class_name == id2.class_name == "PromptSendingAttack" @pytest.mark.asyncio async def test_retry_stores_unsuccessful_conversation_and_updates_id( diff --git a/tests/unit/executor/attack/single_turn/test_role_play.py b/tests/unit/executor/attack/single_turn/test_role_play.py index 98218e3c9b..114e99bdf7 100644 --- a/tests/unit/executor/attack/single_turn/test_role_play.py +++ b/tests/unit/executor/attack/single_turn/test_role_play.py @@ -51,6 +51,7 @@ def mock_scorer(): """Create a mock true/false scorer for testing""" scorer = MagicMock(spec=TrueFalseScorer) scorer.score_text_async = AsyncMock() + scorer.get_identifier.return_value = get_mock_scorer_identifier() return scorer diff --git a/tests/unit/executor/attack/single_turn/test_skeleton_key.py b/tests/unit/executor/attack/single_turn/test_skeleton_key.py index 41278e9bdf..f2bd77ec12 100644 --- a/tests/unit/executor/attack/single_turn/test_skeleton_key.py +++ b/tests/unit/executor/attack/single_turn/test_skeleton_key.py @@ -41,6 +41,7 @@ def mock_true_false_scorer(): """Create a mock true/false scorer for testing""" scorer = MagicMock(spec=TrueFalseScorer) scorer.score_text_async = AsyncMock() + scorer.get_identifier.return_value = get_mock_scorer_identifier() return scorer diff --git a/tests/unit/executor/benchmark/test_fairness_bias.py b/tests/unit/executor/benchmark/test_fairness_bias.py index 8ad868d813..7e97ffd9ee 100644 --- a/tests/unit/executor/benchmark/test_fairness_bias.py +++ b/tests/unit/executor/benchmark/test_fairness_bias.py @@ -17,6 +17,7 @@ MessagePiece, ) from pyrit.prompt_target import PromptTarget +from pyrit.identifiers import TargetIdentifier def is_spacy_installed(): @@ -29,10 +30,22 @@ def is_spacy_installed(): # Fixtures at the top of the file + +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_prompt_target() -> MagicMock: """Mock prompt target for testing.""" target = MagicMock(spec=PromptTarget) + target.get_identifier.return_value = _mock_target_id("mock_prompt_target") return target @@ -64,7 +77,6 @@ def sample_attack_result() -> AttackResult: result = AttackResult( conversation_id="test-conversation-id", objective="Test objective", - attack_identifier={"name": "fairness_bias_benchmark"}, executed_turns=1, execution_time_ms=1000, outcome=AttackOutcome.SUCCESS, diff --git a/tests/unit/executor/benchmark/test_question_answering.py b/tests/unit/executor/benchmark/test_question_answering.py index 5e5af568c9..c88423326d 100644 --- a/tests/unit/executor/benchmark/test_question_answering.py +++ b/tests/unit/executor/benchmark/test_question_answering.py @@ -19,13 +19,26 @@ QuestionChoice, ) from pyrit.prompt_target import PromptTarget +from pyrit.identifiers import TargetIdentifier # Fixtures at the top of the file + +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_prompt_target() -> MagicMock: """Mock prompt target for testing.""" target = MagicMock(spec=PromptTarget) + target.get_identifier.return_value = _mock_target_id("mock_prompt_target") return target diff --git a/tests/unit/executor/promptgen/test_anecdoctor.py b/tests/unit/executor/promptgen/test_anecdoctor.py index 78dbfd1725..887e8e59d7 100644 --- a/tests/unit/executor/promptgen/test_anecdoctor.py +++ b/tests/unit/executor/promptgen/test_anecdoctor.py @@ -16,6 +16,18 @@ from pyrit.models import Message from pyrit.prompt_normalizer import PromptNormalizer from pyrit.prompt_target import PromptChatTarget +from pyrit.identifiers import TargetIdentifier + + + +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) @pytest.fixture @@ -23,6 +35,7 @@ def mock_objective_target() -> PromptChatTarget: """Create a mock objective target for testing.""" mock_target = MagicMock(spec=PromptChatTarget) mock_target.set_system_prompt = MagicMock() + mock_target.get_identifier.return_value = _mock_target_id("mock_objective_target") return mock_target @@ -31,6 +44,7 @@ def mock_processing_model() -> PromptChatTarget: """Create a mock processing model for testing.""" mock_model = MagicMock(spec=PromptChatTarget) mock_model.set_system_prompt = MagicMock() + mock_model.get_identifier.return_value = _mock_target_id("MockProcessingModel") return mock_model diff --git a/tests/unit/executor/workflow/test_xpia.py b/tests/unit/executor/workflow/test_xpia.py index d25cd00f5c..1c9162114a 100644 --- a/tests/unit/executor/workflow/test_xpia.py +++ b/tests/unit/executor/workflow/test_xpia.py @@ -16,13 +16,36 @@ from pyrit.prompt_normalizer import PromptNormalizer from pyrit.prompt_target import PromptTarget from pyrit.score import Scorer +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier # Shared fixtures for all test classes + +def _mock_scorer_id(name: str = "MockScorer") -> ScorerIdentifier: + """Helper to create ScorerIdentifier for tests.""" + return ScorerIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_attack_setup_target() -> MagicMock: """Create a mock attack setup target.""" target = MagicMock(spec=PromptTarget) + target.get_identifier.return_value = _mock_target_id("mock_attack_setup_target") return target @@ -31,6 +54,7 @@ def mock_scorer() -> MagicMock: """Create a mock scorer.""" scorer = MagicMock(spec=Scorer) scorer.score_text_async = AsyncMock() + scorer.get_identifier.return_value = _mock_scorer_id() return scorer diff --git a/tests/unit/identifiers/test_identifiers.py b/tests/unit/identifiers/test_identifiers.py index a9e373a7f7..64e1da8904 100644 --- a/tests/unit/identifiers/test_identifiers.py +++ b/tests/unit/identifiers/test_identifiers.py @@ -6,36 +6,10 @@ import pytest import pyrit -from pyrit.identifiers import Identifier, LegacyIdentifiable +from pyrit.identifiers import Identifier from pyrit.identifiers.identifier import _EXCLUDE, _ExcludeFrom, _expand_exclusions -class TestLegacyIdentifiable: - """Tests for the LegacyIdentifiable abstract base class.""" - - def test_legacy_identifiable_get_identifier_is_abstract(self): - """Test that get_identifier is an abstract method that must be implemented.""" - - class ConcreteLegacyIdentifiable(LegacyIdentifiable): - def get_identifier(self) -> dict[str, str]: - return {"type": "test", "name": "example"} - - obj = ConcreteLegacyIdentifiable() - result = obj.get_identifier() - assert result == {"type": "test", "name": "example"} - - def test_legacy_identifiable_str_returns_identifier_dict(self): - """Test that __str__ returns the get_identifier() result as a string.""" - - class ConcreteLegacyIdentifiable(LegacyIdentifiable): - def get_identifier(self) -> dict[str, str]: - return {"type": "test"} - - obj = ConcreteLegacyIdentifiable() - # __str__ returns the identifier dict as a string - assert str(obj) == "{'type': 'test'}" - - class TestIdentifier: """Tests for the Identifier dataclass.""" diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index be2fa4c64d..c37c39f2a9 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -6,7 +6,7 @@ from typing import Sequence from pyrit.common.utils import to_sha256 -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import AttackIdentifier, ScorerIdentifier from pyrit.memory import MemoryInterface from pyrit.memory.memory_models import AttackResultEntry from pyrit.models import ( @@ -36,7 +36,6 @@ def create_attack_result(conversation_id: str, objective_num: int, outcome: Atta return AttackResult( conversation_id=conversation_id, objective=f"Objective {objective_num}", - attack_identifier={"name": "test_attack"}, outcome=outcome, ) @@ -47,7 +46,6 @@ def test_add_attack_results_to_memory(sqlite_instance: MemoryInterface): attack_result1 = AttackResult( conversation_id="conv_1", objective="Test objective 1", - attack_identifier={"name": "test_attack_1", "module": "test_module"}, executed_turns=5, execution_time_ms=1000, outcome=AttackOutcome.SUCCESS, @@ -58,7 +56,6 @@ def test_add_attack_results_to_memory(sqlite_instance: MemoryInterface): attack_result2 = AttackResult( conversation_id="conv_2", objective="Test objective 2", - attack_identifier={"name": "test_attack_2", "module": "test_module"}, executed_turns=3, execution_time_ms=500, outcome=AttackOutcome.FAILURE, @@ -85,7 +82,6 @@ def test_get_attack_results_by_ids(sqlite_instance: MemoryInterface): attack_result1 = AttackResult( conversation_id="conv_1", objective="Test objective 1", - attack_identifier={"name": "test_attack_1"}, executed_turns=5, execution_time_ms=1000, outcome=AttackOutcome.SUCCESS, @@ -94,7 +90,6 @@ def test_get_attack_results_by_ids(sqlite_instance: MemoryInterface): attack_result2 = AttackResult( conversation_id="conv_2", objective="Test objective 2", - attack_identifier={"name": "test_attack_2"}, executed_turns=3, execution_time_ms=500, outcome=AttackOutcome.FAILURE, @@ -103,7 +98,6 @@ def test_get_attack_results_by_ids(sqlite_instance: MemoryInterface): attack_result3 = AttackResult( conversation_id="conv_3", objective="Test objective 3", - attack_identifier={"name": "test_attack_3"}, executed_turns=7, execution_time_ms=1500, outcome=AttackOutcome.UNDETERMINED, @@ -134,7 +128,6 @@ def test_get_attack_results_by_conversation_id(sqlite_instance: MemoryInterface) attack_result1 = AttackResult( conversation_id="conv_1", objective="Test objective 1", - attack_identifier={"name": "test_attack_1"}, executed_turns=5, execution_time_ms=1000, outcome=AttackOutcome.SUCCESS, @@ -143,7 +136,6 @@ def test_get_attack_results_by_conversation_id(sqlite_instance: MemoryInterface) attack_result2 = AttackResult( conversation_id="conv_1", # Same conversation ID objective="Test objective 2", - attack_identifier={"name": "test_attack_2"}, executed_turns=3, execution_time_ms=500, outcome=AttackOutcome.FAILURE, @@ -152,7 +144,6 @@ def test_get_attack_results_by_conversation_id(sqlite_instance: MemoryInterface) attack_result3 = AttackResult( conversation_id="conv_2", # Different conversation ID objective="Test objective 3", - attack_identifier={"name": "test_attack_3"}, executed_turns=7, execution_time_ms=1500, outcome=AttackOutcome.UNDETERMINED, @@ -176,7 +167,6 @@ def test_get_attack_results_by_objective(sqlite_instance: MemoryInterface): attack_result1 = AttackResult( conversation_id="conv_1", objective="Test objective for success", - attack_identifier={"name": "test_attack_1"}, executed_turns=5, execution_time_ms=1000, outcome=AttackOutcome.SUCCESS, @@ -185,7 +175,6 @@ def test_get_attack_results_by_objective(sqlite_instance: MemoryInterface): attack_result2 = AttackResult( conversation_id="conv_2", objective="Another objective for failure", - attack_identifier={"name": "test_attack_2"}, executed_turns=3, execution_time_ms=500, outcome=AttackOutcome.FAILURE, @@ -194,7 +183,6 @@ def test_get_attack_results_by_objective(sqlite_instance: MemoryInterface): attack_result3 = AttackResult( conversation_id="conv_3", objective="Different objective entirely", - attack_identifier={"name": "test_attack_3"}, executed_turns=7, execution_time_ms=1500, outcome=AttackOutcome.UNDETERMINED, @@ -219,7 +207,6 @@ def test_get_attack_results_by_outcome(sqlite_instance: MemoryInterface): attack_result1 = AttackResult( conversation_id="conv_1", objective="Test objective 1", - attack_identifier={"name": "test_attack_1"}, executed_turns=5, execution_time_ms=1000, outcome=AttackOutcome.SUCCESS, @@ -228,7 +215,6 @@ def test_get_attack_results_by_outcome(sqlite_instance: MemoryInterface): attack_result2 = AttackResult( conversation_id="conv_2", objective="Test objective 2", - attack_identifier={"name": "test_attack_2"}, executed_turns=3, execution_time_ms=500, outcome=AttackOutcome.SUCCESS, # Same outcome @@ -237,7 +223,6 @@ def test_get_attack_results_by_outcome(sqlite_instance: MemoryInterface): attack_result3 = AttackResult( conversation_id="conv_3", objective="Test objective 3", - attack_identifier={"name": "test_attack_3"}, executed_turns=7, execution_time_ms=1500, outcome=AttackOutcome.FAILURE, # Different outcome @@ -267,7 +252,6 @@ def test_get_attack_results_by_objective_sha256(sqlite_instance: MemoryInterface attack_result1 = AttackResult( conversation_id="conv_1", objective=objective1, - attack_identifier={"name": "test_attack"}, executed_turns=5, execution_time_ms=1000, outcome=AttackOutcome.SUCCESS, @@ -277,7 +261,6 @@ def test_get_attack_results_by_objective_sha256(sqlite_instance: MemoryInterface attack_result2 = AttackResult( conversation_id="conv_2", objective=objective2, - attack_identifier={"name": "test_attack"}, executed_turns=3, execution_time_ms=500, outcome=AttackOutcome.FAILURE, @@ -287,7 +270,6 @@ def test_get_attack_results_by_objective_sha256(sqlite_instance: MemoryInterface attack_result3 = AttackResult( conversation_id="conv_3", objective=objective3, - attack_identifier={"name": "test_attack"}, executed_turns=7, execution_time_ms=1500, outcome=AttackOutcome.UNDETERMINED, @@ -312,7 +294,6 @@ def test_get_attack_results_multiple_filters(sqlite_instance: MemoryInterface): attack_result1 = AttackResult( conversation_id="conv_1", objective="Test objective for success", - attack_identifier={"name": "test_attack_1"}, executed_turns=5, execution_time_ms=1000, outcome=AttackOutcome.SUCCESS, @@ -321,7 +302,6 @@ def test_get_attack_results_multiple_filters(sqlite_instance: MemoryInterface): attack_result2 = AttackResult( conversation_id="conv_1", # Same conversation ID objective="Another objective for failure", - attack_identifier={"name": "test_attack_2"}, executed_turns=3, execution_time_ms=500, outcome=AttackOutcome.FAILURE, # Different outcome @@ -330,7 +310,6 @@ def test_get_attack_results_multiple_filters(sqlite_instance: MemoryInterface): attack_result3 = AttackResult( conversation_id="conv_2", # Different conversation ID objective="Test objective for success", - attack_identifier={"name": "test_attack_3"}, executed_turns=7, execution_time_ms=1500, outcome=AttackOutcome.SUCCESS, @@ -357,7 +336,6 @@ def test_get_attack_results_no_filters(sqlite_instance: MemoryInterface): attack_result1 = AttackResult( conversation_id="conv_1", objective="Test objective 1", - attack_identifier={"name": "test_attack_1"}, executed_turns=5, execution_time_ms=1000, outcome=AttackOutcome.SUCCESS, @@ -366,7 +344,6 @@ def test_get_attack_results_no_filters(sqlite_instance: MemoryInterface): attack_result2 = AttackResult( conversation_id="conv_2", objective="Test objective 2", - attack_identifier={"name": "test_attack_2"}, executed_turns=3, execution_time_ms=500, outcome=AttackOutcome.FAILURE, @@ -388,7 +365,6 @@ def test_get_attack_results_empty_list(sqlite_instance: MemoryInterface): attack_result = AttackResult( conversation_id="conv_1", objective="Test objective", - attack_identifier={"name": "test_attack"}, executed_turns=5, execution_time_ms=1000, outcome=AttackOutcome.SUCCESS, @@ -407,7 +383,6 @@ def test_get_attack_results_nonexistent_ids(sqlite_instance: MemoryInterface): attack_result = AttackResult( conversation_id="conv_1", objective="Test objective", - attack_identifier={"name": "test_attack"}, executed_turns=5, execution_time_ms=1000, outcome=AttackOutcome.SUCCESS, @@ -457,7 +432,6 @@ def test_attack_result_with_last_response_and_score(sqlite_instance: MemoryInter attack_result = AttackResult( conversation_id="conv_1", objective="Test objective with relationships", - attack_identifier={"name": "test_attack"}, last_response=message_piece, last_score=score, executed_turns=5, @@ -487,7 +461,9 @@ def test_attack_result_all_outcomes(sqlite_instance: MemoryInterface): attack_result = AttackResult( conversation_id=f"conv_{i}", objective=f"Test objective {i}", - attack_identifier={"name": f"test_attack_{i}"}, + attack_identifier=AttackIdentifier( + class_name=f"TestAttack{i}", class_module="test.module" + ), executed_turns=i + 1, execution_time_ms=(i + 1) * 100, outcome=outcome, @@ -523,7 +499,6 @@ def test_attack_result_metadata_handling(sqlite_instance: MemoryInterface): attack_result = AttackResult( conversation_id="conv_1", objective="Test objective with metadata", - attack_identifier={"name": "test_attack"}, executed_turns=5, execution_time_ms=1000, outcome=AttackOutcome.SUCCESS, @@ -547,7 +522,6 @@ def test_attack_result_objective_sha256_auto_generation(sqlite_instance: MemoryI attack_result = AttackResult( conversation_id="conv_1", objective=objective, - attack_identifier={"name": "test_attack"}, executed_turns=5, execution_time_ms=1000, outcome=AttackOutcome.SUCCESS, @@ -577,7 +551,6 @@ def test_attack_result_with_attack_generation_conversation_ids(sqlite_instance: attack_result = AttackResult( conversation_id="conv_1", objective="Test objective with conversation IDs", - attack_identifier={"name": "test_attack"}, executed_turns=5, execution_time_ms=1000, outcome=AttackOutcome.SUCCESS, @@ -605,7 +578,6 @@ def test_attack_result_without_attack_generation_conversation_ids(sqlite_instanc attack_result = AttackResult( conversation_id="conv_1", objective="Test objective without conversation IDs", - attack_identifier={"name": "test_attack"}, executed_turns=5, execution_time_ms=1000, outcome=AttackOutcome.SUCCESS, diff --git a/tests/unit/memory/memory_interface/test_interface_export.py b/tests/unit/memory/memory_interface/test_interface_export.py index 16b5847522..8fe3b2fa50 100644 --- a/tests/unit/memory/memory_interface/test_interface_export.py +++ b/tests/unit/memory/memory_interface/test_interface_export.py @@ -15,7 +15,7 @@ def test_export_conversation_by_attack_id_file_created( sqlite_instance: MemoryInterface, sample_conversations: Sequence[MessagePiece] ): - attack1_id = sample_conversations[0].attack_identifier["id"] + attack1_id = sample_conversations[0].attack_identifier.hash # Default path in export_conversations() file_name = f"{attack1_id}.json" diff --git a/tests/unit/memory/memory_interface/test_interface_prompts.py b/tests/unit/memory/memory_interface/test_interface_prompts.py index a6b9b4fc05..ff054a4bec 100644 --- a/tests/unit/memory/memory_interface/test_interface_prompts.py +++ b/tests/unit/memory/memory_interface/test_interface_prompts.py @@ -10,6 +10,8 @@ import pytest +from unit.mocks import get_mock_target + from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack from pyrit.identifiers import ScorerIdentifier from pyrit.memory import MemoryInterface, PromptMemoryEntry @@ -110,8 +112,8 @@ def test_get_message_pieces_uuid_and_string_ids(sqlite_instance: MemoryInterface def test_duplicate_memory(sqlite_instance: MemoryInterface): - attack1 = PromptSendingAttack(objective_target=MagicMock()) - attack2 = PromptSendingAttack(objective_target=MagicMock()) + attack1 = PromptSendingAttack(objective_target=get_mock_target()) + attack2 = PromptSendingAttack(objective_target=get_mock_target("Target2")) conversation_id_1 = "11111" conversation_id_2 = "22222" conversation_id_3 = "33333" @@ -167,8 +169,8 @@ def test_duplicate_memory(sqlite_instance: MemoryInterface): all_pieces = sqlite_instance.get_message_pieces() assert len(all_pieces) == 9 # Attack IDs are preserved (not changed) when duplicating - assert len([p for p in all_pieces if p.attack_identifier["id"] == attack1.get_identifier()["id"]]) == 8 - assert len([p for p in all_pieces if p.attack_identifier["id"] == attack2.get_identifier()["id"]]) == 1 + assert len([p for p in all_pieces if p.attack_identifier.hash == attack1.get_identifier().hash]) == 8 + assert len([p for p in all_pieces if p.attack_identifier.hash == attack2.get_identifier().hash]) == 1 assert len([p for p in all_pieces if p.conversation_id == conversation_id_1]) == 2 assert len([p for p in all_pieces if p.conversation_id == conversation_id_2]) == 2 assert len([p for p in all_pieces if p.conversation_id == conversation_id_3]) == 1 @@ -181,7 +183,7 @@ def test_duplicate_conversation_pieces_not_score(sqlite_instance: MemoryInterfac conversation_id = str(uuid4()) prompt_id_1 = uuid4() prompt_id_2 = uuid4() - attack1 = PromptSendingAttack(objective_target=MagicMock()) + attack1 = PromptSendingAttack(objective_target=get_mock_target()) memory_labels = {"sample": "label"} pieces = [ MessagePiece( @@ -245,15 +247,15 @@ def test_duplicate_conversation_pieces_not_score(sqlite_instance: MemoryInterfac assert piece.id not in (prompt_id_1, prompt_id_2) assert len(sqlite_instance.get_prompt_scores(labels=memory_labels)) == 2 # Attack ID is preserved, so both original and duplicated pieces have the same attack ID - assert len(sqlite_instance.get_prompt_scores(attack_id=attack1.get_identifier()["id"])) == 2 + assert len(sqlite_instance.get_prompt_scores(attack_id=attack1.get_identifier().hash)) == 2 # The duplicate prompts ids should not have scores so only two scores are returned assert len(sqlite_instance.get_prompt_scores(prompt_ids=[str(prompt_id_1), str(prompt_id_2)] + new_pieces_ids)) == 2 def test_duplicate_conversation_excluding_last_turn(sqlite_instance: MemoryInterface): - attack1 = PromptSendingAttack(objective_target=MagicMock()) - attack2 = PromptSendingAttack(objective_target=MagicMock()) + attack1 = PromptSendingAttack(objective_target=get_mock_target()) + attack2 = PromptSendingAttack(objective_target=get_mock_target()) conversation_id_1 = "11111" conversation_id_2 = "22222" pieces = [ @@ -317,7 +319,7 @@ def test_duplicate_conversation_excluding_last_turn_not_score(sqlite_instance: M conversation_id = str(uuid4()) prompt_id_1 = uuid4() prompt_id_2 = uuid4() - attack1 = PromptSendingAttack(objective_target=MagicMock()) + attack1 = PromptSendingAttack(objective_target=get_mock_target()) memory_labels = {"sample": "label"} pieces = [ MessagePiece( @@ -399,13 +401,13 @@ def test_duplicate_conversation_excluding_last_turn_not_score(sqlite_instance: M assert new_pieces[1].id != prompt_id_2 assert len(sqlite_instance.get_prompt_scores(labels=memory_labels)) == 2 # Attack ID is preserved - assert len(sqlite_instance.get_prompt_scores(attack_id=attack1.get_identifier()["id"])) == 2 + assert len(sqlite_instance.get_prompt_scores(attack_id=attack1.get_identifier().hash)) == 2 # The duplicate prompts ids should not have scores so only two scores are returned assert len(sqlite_instance.get_prompt_scores(prompt_ids=[str(prompt_id_1), str(prompt_id_2)] + new_pieces_ids)) == 2 def test_duplicate_conversation_excluding_last_turn_same_attack(sqlite_instance: MemoryInterface): - attack1 = PromptSendingAttack(objective_target=MagicMock()) + attack1 = PromptSendingAttack(objective_target=get_mock_target()) conversation_id_1 = "11111" pieces = [ MessagePiece( @@ -455,7 +457,7 @@ def test_duplicate_conversation_excluding_last_turn_same_attack(sqlite_instance: def test_duplicate_memory_preserves_attack_id(sqlite_instance: MemoryInterface): - attack1 = PromptSendingAttack(objective_target=MagicMock()) + attack1 = PromptSendingAttack(objective_target=get_mock_target()) conversation_id = "11111" pieces = [ MessagePiece( @@ -481,14 +483,14 @@ def test_duplicate_memory_preserves_attack_id(sqlite_instance: MemoryInterface): assert new_conversation_id != conversation_id # Both pieces should have the same attack ID - attack_ids = {p.attack_identifier["id"] for p in all_pieces} + attack_ids = {p.attack_identifier.hash for p in all_pieces} assert len(attack_ids) == 1 - assert attack1.get_identifier()["id"] in attack_ids + assert attack1.get_identifier().hash in attack_ids def test_duplicate_conversation_creates_new_ids(sqlite_instance: MemoryInterface): """Test that duplicated conversation has new piece IDs.""" - attack1 = PromptSendingAttack(objective_target=MagicMock()) + attack1 = PromptSendingAttack(objective_target=get_mock_target()) conversation_id = "test-conv-123" original_piece = MessagePiece( role="user", @@ -520,7 +522,7 @@ def test_duplicate_conversation_creates_new_ids(sqlite_instance: MemoryInterface def test_duplicate_conversation_preserves_original_prompt_id(sqlite_instance: MemoryInterface): """Test that duplicated conversation preserves original_prompt_id for tracing.""" - attack1 = PromptSendingAttack(objective_target=MagicMock()) + attack1 = PromptSendingAttack(objective_target=get_mock_target()) conversation_id = "test-conv-456" original_piece = MessagePiece( role="user", @@ -544,7 +546,7 @@ def test_duplicate_conversation_preserves_original_prompt_id(sqlite_instance: Me def test_duplicate_conversation_with_multiple_pieces(sqlite_instance: MemoryInterface): """Test that duplicating a multi-piece conversation works correctly.""" - attack1 = PromptSendingAttack(objective_target=MagicMock()) + attack1 = PromptSendingAttack(objective_target=get_mock_target()) conversation_id = "multi-piece-conv" pieces = [ @@ -789,8 +791,8 @@ def test_get_message_pieces_id(sqlite_instance: MemoryInterface): def test_get_message_pieces_attack(sqlite_instance: MemoryInterface): - attack1 = PromptSendingAttack(objective_target=MagicMock()) - attack2 = PromptSendingAttack(objective_target=MagicMock()) + attack1 = PromptSendingAttack(objective_target=get_mock_target()) + attack2 = PromptSendingAttack(objective_target=get_mock_target("Target2")) entries = [ PromptMemoryEntry( @@ -818,7 +820,7 @@ def test_get_message_pieces_attack(sqlite_instance: MemoryInterface): sqlite_instance._insert_entries(entries=entries) - attack1_entries = sqlite_instance.get_message_pieces(attack_id=attack1.get_identifier()["id"]) + attack1_entries = sqlite_instance.get_message_pieces(attack_id=attack1.get_identifier().hash) assert len(attack1_entries) == 2 assert_original_value_in_list("Hello 1", attack1_entries) @@ -950,7 +952,7 @@ def test_get_message_pieces_by_hash(sqlite_instance: MemoryInterface): def test_get_message_pieces_with_non_matching_memory_labels(sqlite_instance: MemoryInterface): - attack = PromptSendingAttack(objective_target=MagicMock()) + attack = PromptSendingAttack(objective_target=get_mock_target()) labels = {"op_name": "op1", "user_name": "name1", "harm_category": "dummy1"} entries = [ PromptMemoryEntry( diff --git a/tests/unit/memory/memory_interface/test_interface_scenario_results.py b/tests/unit/memory/memory_interface/test_interface_scenario_results.py index 4fd86a4130..5319ba02d5 100644 --- a/tests/unit/memory/memory_interface/test_interface_scenario_results.py +++ b/tests/unit/memory/memory_interface/test_interface_scenario_results.py @@ -30,7 +30,6 @@ def create_attack_result(conversation_id: str, objective: str, outcome: AttackOu return AttackResult( conversation_id=conversation_id, objective=objective, - attack_identifier={"name": "test_attack"}, executed_turns=5, execution_time_ms=1000, outcome=outcome, diff --git a/tests/unit/memory/memory_interface/test_interface_scores.py b/tests/unit/memory/memory_interface/test_interface_scores.py index 7941f3b79d..080cccff77 100644 --- a/tests/unit/memory/memory_interface/test_interface_scores.py +++ b/tests/unit/memory/memory_interface/test_interface_scores.py @@ -9,6 +9,8 @@ import pytest +from unit.mocks import get_mock_target + from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack from pyrit.identifiers import ScorerIdentifier from pyrit.memory import MemoryInterface, PromptMemoryEntry @@ -54,7 +56,7 @@ def test_get_scores_by_attack_id_and_label( sqlite_instance.add_scores_to_memory(scores=[score]) # Fetch the score we just added - db_score = sqlite_instance.get_prompt_scores(attack_id=sample_conversations[0].attack_identifier["id"]) + db_score = sqlite_instance.get_prompt_scores(attack_id=sample_conversations[0].attack_identifier.hash) assert len(db_score) == 1 assert db_score[0].score_value == score.score_value @@ -75,7 +77,7 @@ def test_get_scores_by_attack_id_and_label( assert db_score[0].score_value == score.score_value db_score = sqlite_instance.get_prompt_scores( - attack_id=sample_conversations[0].attack_identifier["id"], + attack_id=sample_conversations[0].attack_identifier.hash, labels={"x": "y"}, ) assert len(db_score) == 0 @@ -133,7 +135,7 @@ def test_add_score_get_score( def test_add_score_duplicate_prompt(sqlite_instance: MemoryInterface): # Ensure that scores of duplicate prompts are linked back to the original original_id = uuid4() - attack = PromptSendingAttack(objective_target=MagicMock()) + attack = PromptSendingAttack(objective_target=get_mock_target()) conversation_id = str(uuid4()) pieces = [ MessagePiece( diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index dd2b0f0174..eb2514239b 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -12,7 +12,7 @@ from pyrit.identifiers import AttackIdentifier, ScorerIdentifier, TargetIdentifier from pyrit.memory import AzureSQLMemory, CentralMemory, PromptMemoryEntry from pyrit.models import Message, MessagePiece -from pyrit.prompt_target import PromptChatTarget, limit_requests_per_minute +from pyrit.prompt_target import PromptChatTarget, PromptTarget, limit_requests_per_minute def get_mock_scorer_identifier() -> ScorerIdentifier: @@ -48,6 +48,43 @@ def get_mock_target_identifier(name: str = "MockTarget", module: str = "tests.un ) +def get_mock_attack_identifier( + name: str = "MockAttack", module: str = "tests.unit.mocks" +) -> AttackIdentifier: + """ + Returns a mock AttackIdentifier for use in tests where the specific + attack identity doesn't matter. + + Args: + name: The class name for the mock attack. Defaults to "MockAttack". + module: The module path for the mock attack. Defaults to "tests.unit.mocks". + + Returns: + An AttackIdentifier configured with the provided name and module. + """ + return AttackIdentifier( + class_name=name, + class_module=module, + ) + + +def get_mock_target(name: str = "MockTarget") -> MagicMock: + """ + Returns a MagicMock target whose ``get_identifier()`` returns a real + :class:`TargetIdentifier`. Use this wherever a ``MagicMock(spec=PromptTarget)`` + is needed as an ``objective_target``. + + Args: + name: The class name for the mock target. Defaults to "MockTarget". + + Returns: + A MagicMock configured to return a real TargetIdentifier. + """ + target = MagicMock(spec=PromptTarget) + target.get_identifier.return_value = get_mock_target_identifier(name) + return target + + class MockHttpPostAsync(AbstractAsyncContextManager): def __init__(self, url, headers=None, json=None, params=None, ssl=None): self.status = 200 @@ -222,11 +259,7 @@ def get_test_message_piece() -> MessagePiece: def get_sample_conversations() -> MutableSequence[Message]: with patch.object(CentralMemory, "get_memory_instance", return_value=MagicMock()): conversation_1 = str(uuid.uuid4()) - attack_identifier = { - "__type__": "MockPromptTarget", - "__module__": "unit.mocks", - "id": str(uuid.uuid4()), - } + attack_id = get_mock_attack_identifier() return [ MessagePiece( @@ -235,7 +268,7 @@ def get_sample_conversations() -> MutableSequence[Message]: converted_value="Hello, how are you?", conversation_id=conversation_1, sequence=0, - attack_identifier=attack_identifier, + attack_identifier=attack_id, ).to_message(), MessagePiece( role="assistant", @@ -243,14 +276,14 @@ def get_sample_conversations() -> MutableSequence[Message]: converted_value="I'm fine, thank you!", conversation_id=conversation_1, sequence=1, - attack_identifier=attack_identifier, + attack_identifier=attack_id, ).to_message(), MessagePiece( role="assistant", original_value="original prompt text", converted_value="I'm fine, thank you!", conversation_id=str(uuid.uuid4()), - attack_identifier=attack_identifier, + attack_identifier=attack_id, ).to_message(), ] diff --git a/tests/unit/models/test_message_piece.py b/tests/unit/models/test_message_piece.py index 78ecb60080..d1f9a580c3 100644 --- a/tests/unit/models/test_message_piece.py +++ b/tests/unit/models/test_message_piece.py @@ -11,7 +11,7 @@ from unittest.mock import MagicMock import pytest -from unit.mocks import MockPromptTarget, get_sample_conversations +from unit.mocks import MockPromptTarget, get_mock_target, get_sample_conversations from pyrit.executor.attack import PromptSendingAttack from pyrit.identifiers import ScorerIdentifier @@ -83,7 +83,7 @@ def test_prompt_targets_serialize(patch_central_database): def test_executors_serialize(): - attack = PromptSendingAttack(objective_target=MagicMock()) + attack = PromptSendingAttack(objective_target=get_mock_target()) entry = MessagePiece( role="user", @@ -92,9 +92,9 @@ def test_executors_serialize(): attack_identifier=attack.get_identifier(), ) - assert entry.attack_identifier["id"] is not None - assert entry.attack_identifier["__type__"] == "PromptSendingAttack" - assert entry.attack_identifier["__module__"] == "pyrit.executor.attack.single_turn.prompt_sending" + assert entry.attack_identifier.hash is not None + assert entry.attack_identifier.class_name == "PromptSendingAttack" + assert entry.attack_identifier.class_module == "pyrit.executor.attack.single_turn.prompt_sending" @pytest.mark.asyncio @@ -746,7 +746,7 @@ def test_message_piece_to_dict(): assert result["prompt_metadata"] == entry.prompt_metadata assert result["converter_identifiers"] == [conv.to_dict() for conv in entry.converter_identifiers] assert result["prompt_target_identifier"] == entry.prompt_target_identifier.to_dict() - assert result["attack_identifier"] == entry.attack_identifier + assert result["attack_identifier"] == entry.attack_identifier.to_dict() assert result["scorer_identifier"] == entry.scorer_identifier.to_dict() assert result["original_value_data_type"] == entry.original_value_data_type assert result["original_value"] == entry.original_value diff --git a/tests/unit/registry/test_base.py b/tests/unit/registry/test_base.py index e02104dad9..3c8381dcd5 100644 --- a/tests/unit/registry/test_base.py +++ b/tests/unit/registry/test_base.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from dataclasses import dataclass +from dataclasses import dataclass, field import pytest @@ -13,7 +13,7 @@ class MetadataWithTags(Identifier): """Test metadata with a tags field for list filtering tests.""" - tags: tuple[str, ...] + tags: tuple[str, ...] = field(kw_only=True) class TestMatchesFilters: diff --git a/tests/unit/registry/test_base_instance_registry.py b/tests/unit/registry/test_base_instance_registry.py index 9f5744e2c1..0a774d6b1e 100644 --- a/tests/unit/registry/test_base_instance_registry.py +++ b/tests/unit/registry/test_base_instance_registry.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from dataclasses import dataclass +from dataclasses import dataclass, field from pyrit.identifiers import Identifier from pyrit.registry.instance_registries.base_instance_registry import BaseInstanceRegistry @@ -11,7 +11,7 @@ class SampleItemMetadata(Identifier): """Sample metadata with an extra field.""" - category: str + category: str = field(kw_only=True) class ConcreteTestRegistry(BaseInstanceRegistry[str, SampleItemMetadata]): diff --git a/tests/unit/scenarios/test_atomic_attack.py b/tests/unit/scenarios/test_atomic_attack.py index 7cb41d0a2f..b372777def 100644 --- a/tests/unit/scenarios/test_atomic_attack.py +++ b/tests/unit/scenarios/test_atomic_attack.py @@ -75,21 +75,18 @@ def sample_attack_results(): AttackResult( conversation_id="conv-1", objective="objective1", - attack_identifier={"__type__": "TestAttack", "__module__": "test", "id": "1"}, outcome=AttackOutcome.SUCCESS, executed_turns=1, ), AttackResult( conversation_id="conv-2", objective="objective2", - attack_identifier={"__type__": "TestAttack", "__module__": "test", "id": "2"}, outcome=AttackOutcome.SUCCESS, executed_turns=1, ), AttackResult( conversation_id="conv-3", objective="objective3", - attack_identifier={"__type__": "TestAttack", "__module__": "test", "id": "3"}, outcome=AttackOutcome.FAILURE, executed_turns=1, ), @@ -431,7 +428,6 @@ async def test_full_attack_run_execution_flow(self, mock_attack, sample_seed_gro AttackResult( conversation_id=f"conv-{i}", objective=f"objective{i + 1}", - attack_identifier={"__type__": "TestAttack", "__module__": "test", "id": str(i)}, outcome=AttackOutcome.SUCCESS, executed_turns=1, ) @@ -476,7 +472,6 @@ async def test_atomic_attack_with_single_seed_group(self, mock_attack): AttackResult( conversation_id="conv-1", objective="single_objective", - attack_identifier={"__type__": "TestAttack", "__module__": "test", "id": "1"}, outcome=AttackOutcome.SUCCESS, executed_turns=1, ) @@ -513,7 +508,6 @@ async def test_atomic_attack_with_many_seed_groups(self, mock_attack): AttackResult( conversation_id=f"conv-{i}", objective=f"objective_{i}", - attack_identifier={"__type__": "TestAttack", "__module__": "test", "id": str(i)}, outcome=AttackOutcome.SUCCESS, executed_turns=1, ) @@ -682,7 +676,6 @@ async def test_run_async_passes_seed_groups_with_messages(self, mock_attack, see AttackResult( conversation_id=f"conv-{i}", objective=seed_groups_with_messages[i].objective.value, - attack_identifier={"__type__": "TestAttack", "__module__": "test", "id": str(i)}, outcome=AttackOutcome.SUCCESS, executed_turns=len(seed_groups_with_messages[i].user_messages), ) diff --git a/tests/unit/scenarios/test_scenario.py b/tests/unit/scenarios/test_scenario.py index f7e0e6fe4e..948fa899b8 100644 --- a/tests/unit/scenarios/test_scenario.py +++ b/tests/unit/scenarios/test_scenario.py @@ -87,11 +87,6 @@ def sample_attack_results(): AttackResult( conversation_id=f"conv-{i}", objective=f"objective{i}", - attack_identifier={ - "__type__": "TestAttack", - "__module__": "test", - "id": str(i), - }, outcome=AttackOutcome.SUCCESS, executed_turns=1, ) @@ -574,22 +569,12 @@ def test_scenario_result_objective_achieved_rate(self, sample_attack_results): AttackResult( conversation_id="conv-fail", objective="objective", - attack_identifier={ - "__type__": "TestAttack", - "__module__": "test", - "id": "1", - }, outcome=AttackOutcome.FAILURE, executed_turns=1, ), AttackResult( conversation_id="conv-fail2", objective="objective", - attack_identifier={ - "__type__": "TestAttack", - "__module__": "test", - "id": "2", - }, outcome=AttackOutcome.FAILURE, executed_turns=1, ), diff --git a/tests/unit/scenarios/test_scenario_partial_results.py b/tests/unit/scenarios/test_scenario_partial_results.py index 1886c36395..c93d9a4309 100644 --- a/tests/unit/scenarios/test_scenario_partial_results.py +++ b/tests/unit/scenarios/test_scenario_partial_results.py @@ -138,7 +138,6 @@ async def mock_run(*args, **kwargs): AttackResult( conversation_id=f"conv-{i}", objective=f"obj{i}", - attack_identifier={"__type__": "TestAttack", "__module__": "test", "id": str(i)}, outcome=AttackOutcome.SUCCESS, executed_turns=1, ) @@ -156,7 +155,6 @@ async def mock_run(*args, **kwargs): AttackResult( conversation_id="conv-3", objective="obj3", - attack_identifier={"__type__": "TestAttack", "__module__": "test", "id": "3"}, outcome=AttackOutcome.SUCCESS, executed_turns=1, ) @@ -199,7 +197,6 @@ async def mock_run(*args, **kwargs): AttackResult( conversation_id=f"conv-{i}", objective=f"obj{i}", - attack_identifier={"__type__": "TestAttack", "__module__": "test", "id": str(i)}, outcome=AttackOutcome.SUCCESS, executed_turns=1, ) @@ -258,7 +255,6 @@ async def mock_run(*args, **kwargs): AttackResult( conversation_id=f"conv-{i}", objective=f"obj{i}", - attack_identifier={"__type__": "TestAttack", "__module__": "test", "id": str(i)}, outcome=AttackOutcome.SUCCESS, executed_turns=1, ) @@ -275,7 +271,6 @@ async def mock_run(*args, **kwargs): AttackResult( conversation_id=f"conv-{i}", objective=f"obj{i}", - attack_identifier={"__type__": "TestAttack", "__module__": "test", "id": str(i)}, outcome=AttackOutcome.SUCCESS, executed_turns=1, ) @@ -335,7 +330,6 @@ async def mock_run(*args, **kwargs): AttackResult( conversation_id="conv-a2-1", objective="a2_obj1", - attack_identifier={"__type__": "TestAttack", "__module__": "test", "id": "a2_1"}, outcome=AttackOutcome.SUCCESS, executed_turns=1, ) @@ -351,7 +345,6 @@ async def mock_run(*args, **kwargs): AttackResult( conversation_id=f"conv-{obj}", objective=obj, - attack_identifier={"__type__": "TestAttack", "__module__": "test", "id": obj}, outcome=AttackOutcome.SUCCESS, executed_turns=1, ) diff --git a/tests/unit/scenarios/test_scenario_retry.py b/tests/unit/scenarios/test_scenario_retry.py index 69bcecef01..36ba15d4c3 100644 --- a/tests/unit/scenarios/test_scenario_retry.py +++ b/tests/unit/scenarios/test_scenario_retry.py @@ -69,7 +69,6 @@ def create_attack_result( return AttackResult( conversation_id=conversation_id or f"{CONV_ID_PREFIX}{index}", objective=objective or f"{OBJECTIVE_PREFIX}{index}", - attack_identifier={"__type__": TEST_ATTACK_TYPE, "__module__": TEST_MODULE, "id": str(index)}, outcome=outcome, executed_turns=executed_turns, ) diff --git a/tests/unit/score/test_scorer.py b/tests/unit/score/test_scorer.py index 8fc2f55457..01c31c1127 100644 --- a/tests/unit/score/test_scorer.py +++ b/tests/unit/score/test_scorer.py @@ -10,7 +10,7 @@ from unit.mocks import get_mock_target_identifier from pyrit.exceptions import InvalidJsonException, remove_markdown_json -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import AttackIdentifier, ScorerIdentifier from pyrit.memory import CentralMemory from pyrit.models import Message, MessagePiece, Score from pyrit.prompt_target import PromptChatTarget @@ -206,7 +206,7 @@ async def test_scorer_score_value_with_llm_use_provided_attack_identifier(good_j chat_target.set_system_prompt = MagicMock() expected_system_prompt = "system_prompt" - expected_attack_id = "attack_id" + expected_attack_identifier = AttackIdentifier(class_name="TestAttack", class_module="test.module") expected_scored_prompt_id = "123" await scorer._score_value_with_llm( @@ -217,7 +217,7 @@ async def test_scorer_score_value_with_llm_use_provided_attack_identifier(good_j scored_prompt_id=expected_scored_prompt_id, category="category", objective="task", - attack_identifier={"id": expected_attack_id}, + attack_identifier=expected_attack_identifier, ) chat_target.set_system_prompt.assert_called_once() @@ -225,8 +225,7 @@ async def test_scorer_score_value_with_llm_use_provided_attack_identifier(good_j _, set_sys_prompt_args = chat_target.set_system_prompt.call_args assert set_sys_prompt_args["system_prompt"] == expected_system_prompt assert isinstance(set_sys_prompt_args["conversation_id"], str) - assert set_sys_prompt_args["attack_identifier"]["id"] == expected_attack_id - assert set_sys_prompt_args["attack_identifier"]["scored_prompt_id"] == expected_scored_prompt_id + assert set_sys_prompt_args["attack_identifier"] is expected_attack_identifier @pytest.mark.asyncio diff --git a/tests/unit/target/test_http_target.py b/tests/unit/target/test_http_target.py index 5d49702b07..68a8e715ed 100644 --- a/tests/unit/target/test_http_target.py +++ b/tests/unit/target/test_http_target.py @@ -67,7 +67,9 @@ def test_http_target_sets_endpoint_and_rate_limit(mock_callback_function, sqlite @patch("httpx.AsyncClient.request") async def test_send_prompt_async(mock_request, mock_http_target, mock_http_response): message = MagicMock() - message.message_pieces = [MagicMock(converted_value="test_prompt", prompt_target_identifier=None)] + message.message_pieces = [ + MagicMock(converted_value="test_prompt", prompt_target_identifier=None, attack_identifier=None) + ] mock_request.return_value = mock_http_response response = await mock_http_target.send_prompt_async(message=message) assert len(response) == 1 @@ -113,7 +115,9 @@ async def test_send_prompt_async_client_kwargs(): # Use **httpx_client_kwargs to pass them as keyword arguments http_target = HTTPTarget(http_request=sample_request, **httpx_client_kwargs) message = MagicMock() - message.message_pieces = [MagicMock(converted_value="", prompt_target_identifier=None)] + message.message_pieces = [ + MagicMock(converted_value="", prompt_target_identifier=None, attack_identifier=None) + ] mock_response = MagicMock() mock_response.content = b"Response content" mock_request.return_value = mock_response @@ -148,7 +152,9 @@ async def test_send_prompt_regex_parse_async(mock_request, mock_http_target): mock_http_target.callback_function = callback_function message = MagicMock() - message.message_pieces = [MagicMock(converted_value="test_prompt", prompt_target_identifier=None)] + message.message_pieces = [ + MagicMock(converted_value="test_prompt", prompt_target_identifier=None, attack_identifier=None) + ] mock_response = MagicMock() mock_response.content = b"Match: 1234" @@ -175,7 +181,9 @@ async def test_send_prompt_async_keeps_original_template(mock_request, mock_http # Send first prompt message = MagicMock() - message.message_pieces = [MagicMock(converted_value="test_prompt", prompt_target_identifier=None)] + message.message_pieces = [ + MagicMock(converted_value="test_prompt", prompt_target_identifier=None, attack_identifier=None) + ] response = await mock_http_target.send_prompt_async(message=message) assert len(response) == 1 @@ -193,7 +201,9 @@ async def test_send_prompt_async_keeps_original_template(mock_request, mock_http # Send second prompt second_message = MagicMock() - second_message.message_pieces = [MagicMock(converted_value="second_test_prompt", prompt_target_identifier=None)] + second_message.message_pieces = [ + MagicMock(converted_value="second_test_prompt", prompt_target_identifier=None, attack_identifier=None) + ] await mock_http_target.send_prompt_async(message=second_message) # Assert that the original template is still the same @@ -241,7 +251,9 @@ async def test_http_target_with_injected_client(): mock_request.return_value = mock_response message = MagicMock() - message.message_pieces = [MagicMock(converted_value="test_prompt", prompt_target_identifier=None)] + message.message_pieces = [ + MagicMock(converted_value="test_prompt", prompt_target_identifier=None, attack_identifier=None) + ] response = await target.send_prompt_async(message=message) From a785ee38564c2975dd9286acefbc90f17fa7cd9e Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Mon, 9 Feb 2026 09:48:49 -0800 Subject: [PATCH 13/35] pre-commit --- pyrit/executor/attack/core/attack_strategy.py | 12 +++++------- pyrit/prompt_target/common/prompt_chat_target.py | 2 +- pyrit/score/scorer.py | 4 ++-- .../attack/component/test_conversation_manager.py | 2 +- .../attack/multi_turn/test_chunked_request.py | 7 +++---- .../attack/single_turn/test_context_compliance.py | 1 - .../executor/attack/single_turn/test_flip_attack.py | 1 - .../attack/single_turn/test_many_shot_jailbreak.py | 1 - tests/unit/executor/benchmark/test_fairness_bias.py | 3 ++- .../executor/benchmark/test_question_answering.py | 4 ++-- tests/unit/executor/promptgen/test_anecdoctor.py | 3 +-- tests/unit/executor/workflow/test_xpia.py | 4 ++-- .../test_interface_attack_results.py | 4 +--- .../memory_interface/test_interface_prompts.py | 1 - .../memory/memory_interface/test_interface_scores.py | 2 -- tests/unit/mocks.py | 4 +--- tests/unit/models/test_message_piece.py | 1 - tests/unit/target/test_http_target.py | 4 +--- 18 files changed, 22 insertions(+), 38 deletions(-) diff --git a/pyrit/executor/attack/core/attack_strategy.py b/pyrit/executor/attack/core/attack_strategy.py index 802a10848e..9f7f433858 100644 --- a/pyrit/executor/attack/core/attack_strategy.py +++ b/pyrit/executor/attack/core/attack_strategy.py @@ -259,8 +259,8 @@ def __init__( ) self._objective_target = objective_target self._params_type = params_type - self._request_converters: list = [] - self._response_converters: list = [] + self._request_converters: list[Any] = [] + self._response_converters: list[Any] = [] def _build_identifier(self) -> AttackIdentifier: """ @@ -286,9 +286,7 @@ def _build_identifier(self) -> AttackIdentifier: converter_identifiers = None if self._request_converters: converter_identifiers = [ - converter.get_identifier() - for config in self._request_converters - for converter in config.converters + converter.get_identifier() for config in self._request_converters for converter in config.converters ] return AttackIdentifier( @@ -331,12 +329,12 @@ def get_attack_scoring_config(self) -> Optional[AttackScoringConfig]: """ return None - def get_request_converters(self) -> list: + def get_request_converters(self) -> list[Any]: """ Get request converter configurations used by this strategy. Returns: - list: The list of request PromptConverterConfiguration objects. + list[Any]: The list of request PromptConverterConfiguration objects. """ return self._request_converters diff --git a/pyrit/prompt_target/common/prompt_chat_target.py b/pyrit/prompt_target/common/prompt_chat_target.py index b2c51d23ac..b837918295 100644 --- a/pyrit/prompt_target/common/prompt_chat_target.py +++ b/pyrit/prompt_target/common/prompt_chat_target.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import abc -from typing import Optional, Union +from typing import Optional from pyrit.identifiers import AttackIdentifier from pyrit.models import MessagePiece diff --git a/pyrit/score/scorer.py b/pyrit/score/scorer.py index d973c1d445..6765d907e1 100644 --- a/pyrit/score/scorer.py +++ b/pyrit/score/scorer.py @@ -520,7 +520,7 @@ async def _score_value_with_llm( description_output_key: str = "description", metadata_output_key: str = "metadata", category_output_key: str = "category", - attack_identifier: Optional[Union[AttackIdentifier, Dict[str, str]]] = None, + attack_identifier: Optional[AttackIdentifier] = None, ) -> UnvalidatedScore: """ Send a request to a target, and take care of retries. @@ -554,7 +554,7 @@ async def _score_value_with_llm( Defaults to "metadata". category_output_key (str): The key in the JSON response that contains the category. Defaults to "category". - attack_identifier (Optional[Union[AttackIdentifier, Dict[str, str]]]): The attack identifier. + attack_identifier (Optional[AttackIdentifier]): The attack identifier. Defaults to None. Returns: diff --git a/tests/unit/executor/attack/component/test_conversation_manager.py b/tests/unit/executor/attack/component/test_conversation_manager.py index 251fd023d6..ecb3553758 100644 --- a/tests/unit/executor/attack/component/test_conversation_manager.py +++ b/tests/unit/executor/attack/component/test_conversation_manager.py @@ -18,7 +18,7 @@ """ import uuid -from typing import Dict, List, Optional +from typing import List, Optional from unittest.mock import AsyncMock, MagicMock import pytest diff --git a/tests/unit/executor/attack/multi_turn/test_chunked_request.py b/tests/unit/executor/attack/multi_turn/test_chunked_request.py index 3c6736f586..aad88f1099 100644 --- a/tests/unit/executor/attack/multi_turn/test_chunked_request.py +++ b/tests/unit/executor/attack/multi_turn/test_chunked_request.py @@ -8,18 +8,17 @@ during Crucible CTF red teaming exercises using PyRIT. """ -from unittest.mock import MagicMock, Mock +from unittest.mock import MagicMock import pytest -from pyrit.identifiers import TargetIdentifier -from pyrit.prompt_target import PromptTarget - from pyrit.executor.attack.core.attack_parameters import AttackParameters from pyrit.executor.attack.multi_turn import ( ChunkedRequestAttack, ChunkedRequestAttackContext, ) +from pyrit.identifiers import TargetIdentifier +from pyrit.prompt_target import PromptTarget def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: diff --git a/tests/unit/executor/attack/single_turn/test_context_compliance.py b/tests/unit/executor/attack/single_turn/test_context_compliance.py index 345a8a8b36..64975ba3a7 100644 --- a/tests/unit/executor/attack/single_turn/test_context_compliance.py +++ b/tests/unit/executor/attack/single_turn/test_context_compliance.py @@ -37,7 +37,6 @@ def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: ) - def _mock_scorer_id(name: str = "MockScorer") -> ScorerIdentifier: """Helper to create ScorerIdentifier for tests.""" return ScorerIdentifier( diff --git a/tests/unit/executor/attack/single_turn/test_flip_attack.py b/tests/unit/executor/attack/single_turn/test_flip_attack.py index db23e09320..53caecd71a 100644 --- a/tests/unit/executor/attack/single_turn/test_flip_attack.py +++ b/tests/unit/executor/attack/single_turn/test_flip_attack.py @@ -34,7 +34,6 @@ def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: ) - def _mock_scorer_id(name: str = "MockScorer") -> ScorerIdentifier: """Helper to create ScorerIdentifier for tests.""" return ScorerIdentifier( diff --git a/tests/unit/executor/attack/single_turn/test_many_shot_jailbreak.py b/tests/unit/executor/attack/single_turn/test_many_shot_jailbreak.py index c81166ea07..c15038c5da 100644 --- a/tests/unit/executor/attack/single_turn/test_many_shot_jailbreak.py +++ b/tests/unit/executor/attack/single_turn/test_many_shot_jailbreak.py @@ -35,7 +35,6 @@ def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: ) - def _mock_scorer_id(name: str = "MockScorer") -> ScorerIdentifier: """Helper to create ScorerIdentifier for tests.""" return ScorerIdentifier( diff --git a/tests/unit/executor/benchmark/test_fairness_bias.py b/tests/unit/executor/benchmark/test_fairness_bias.py index 7e97ffd9ee..b23e9cf0a1 100644 --- a/tests/unit/executor/benchmark/test_fairness_bias.py +++ b/tests/unit/executor/benchmark/test_fairness_bias.py @@ -10,6 +10,7 @@ FairnessBiasBenchmark, FairnessBiasBenchmarkContext, ) +from pyrit.identifiers import TargetIdentifier from pyrit.models import ( AttackOutcome, AttackResult, @@ -17,7 +18,6 @@ MessagePiece, ) from pyrit.prompt_target import PromptTarget -from pyrit.identifiers import TargetIdentifier def is_spacy_installed(): @@ -31,6 +31,7 @@ def is_spacy_installed(): # Fixtures at the top of the file + def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: """Helper to create TargetIdentifier for tests.""" return TargetIdentifier( diff --git a/tests/unit/executor/benchmark/test_question_answering.py b/tests/unit/executor/benchmark/test_question_answering.py index c88423326d..0233de781b 100644 --- a/tests/unit/executor/benchmark/test_question_answering.py +++ b/tests/unit/executor/benchmark/test_question_answering.py @@ -10,6 +10,7 @@ QuestionAnsweringBenchmark, QuestionAnsweringBenchmarkContext, ) +from pyrit.identifiers import TargetIdentifier from pyrit.models import ( AttackOutcome, AttackResult, @@ -19,11 +20,10 @@ QuestionChoice, ) from pyrit.prompt_target import PromptTarget -from pyrit.identifiers import TargetIdentifier - # Fixtures at the top of the file + def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: """Helper to create TargetIdentifier for tests.""" return TargetIdentifier( diff --git a/tests/unit/executor/promptgen/test_anecdoctor.py b/tests/unit/executor/promptgen/test_anecdoctor.py index 887e8e59d7..e841c0c1ed 100644 --- a/tests/unit/executor/promptgen/test_anecdoctor.py +++ b/tests/unit/executor/promptgen/test_anecdoctor.py @@ -13,11 +13,10 @@ AnecdoctorGenerator, AnecdoctorResult, ) +from pyrit.identifiers import TargetIdentifier from pyrit.models import Message from pyrit.prompt_normalizer import PromptNormalizer from pyrit.prompt_target import PromptChatTarget -from pyrit.identifiers import TargetIdentifier - def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: diff --git a/tests/unit/executor/workflow/test_xpia.py b/tests/unit/executor/workflow/test_xpia.py index 1c9162114a..baede625bf 100644 --- a/tests/unit/executor/workflow/test_xpia.py +++ b/tests/unit/executor/workflow/test_xpia.py @@ -12,15 +12,15 @@ XPIAStatus, XPIAWorkflow, ) +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.models import Message, MessagePiece, Score from pyrit.prompt_normalizer import PromptNormalizer from pyrit.prompt_target import PromptTarget from pyrit.score import Scorer -from pyrit.identifiers import ScorerIdentifier, TargetIdentifier - # Shared fixtures for all test classes + def _mock_scorer_id(name: str = "MockScorer") -> ScorerIdentifier: """Helper to create ScorerIdentifier for tests.""" return ScorerIdentifier( diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index c37c39f2a9..56813d4417 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -461,9 +461,7 @@ def test_attack_result_all_outcomes(sqlite_instance: MemoryInterface): attack_result = AttackResult( conversation_id=f"conv_{i}", objective=f"Test objective {i}", - attack_identifier=AttackIdentifier( - class_name=f"TestAttack{i}", class_module="test.module" - ), + attack_identifier=AttackIdentifier(class_name=f"TestAttack{i}", class_module="test.module"), executed_turns=i + 1, execution_time_ms=(i + 1) * 100, outcome=outcome, diff --git a/tests/unit/memory/memory_interface/test_interface_prompts.py b/tests/unit/memory/memory_interface/test_interface_prompts.py index ff054a4bec..1c064da99c 100644 --- a/tests/unit/memory/memory_interface/test_interface_prompts.py +++ b/tests/unit/memory/memory_interface/test_interface_prompts.py @@ -9,7 +9,6 @@ from uuid import uuid4 import pytest - from unit.mocks import get_mock_target from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack diff --git a/tests/unit/memory/memory_interface/test_interface_scores.py b/tests/unit/memory/memory_interface/test_interface_scores.py index 080cccff77..d6ee9201f1 100644 --- a/tests/unit/memory/memory_interface/test_interface_scores.py +++ b/tests/unit/memory/memory_interface/test_interface_scores.py @@ -4,11 +4,9 @@ import uuid from typing import Literal, Sequence -from unittest.mock import MagicMock from uuid import uuid4 import pytest - from unit.mocks import get_mock_target from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index eb2514239b..5a58d40c1c 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -48,9 +48,7 @@ def get_mock_target_identifier(name: str = "MockTarget", module: str = "tests.un ) -def get_mock_attack_identifier( - name: str = "MockAttack", module: str = "tests.unit.mocks" -) -> AttackIdentifier: +def get_mock_attack_identifier(name: str = "MockAttack", module: str = "tests.unit.mocks") -> AttackIdentifier: """ Returns a mock AttackIdentifier for use in tests where the specific attack identity doesn't matter. diff --git a/tests/unit/models/test_message_piece.py b/tests/unit/models/test_message_piece.py index d1f9a580c3..dca380baad 100644 --- a/tests/unit/models/test_message_piece.py +++ b/tests/unit/models/test_message_piece.py @@ -8,7 +8,6 @@ import warnings from datetime import datetime, timedelta from typing import MutableSequence -from unittest.mock import MagicMock import pytest from unit.mocks import MockPromptTarget, get_mock_target, get_sample_conversations diff --git a/tests/unit/target/test_http_target.py b/tests/unit/target/test_http_target.py index 68a8e715ed..088e12270f 100644 --- a/tests/unit/target/test_http_target.py +++ b/tests/unit/target/test_http_target.py @@ -115,9 +115,7 @@ async def test_send_prompt_async_client_kwargs(): # Use **httpx_client_kwargs to pass them as keyword arguments http_target = HTTPTarget(http_request=sample_request, **httpx_client_kwargs) message = MagicMock() - message.message_pieces = [ - MagicMock(converted_value="", prompt_target_identifier=None, attack_identifier=None) - ] + message.message_pieces = [MagicMock(converted_value="", prompt_target_identifier=None, attack_identifier=None)] mock_response = MagicMock() mock_response.content = b"Response content" mock_request.return_value = mock_response From d70eda84c335618df1b47b9c780e533febe0514c Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Tue, 10 Feb 2026 04:45:09 -0800 Subject: [PATCH 14/35] fix: store labels as nested key in metadata instead of spreading --- pyrit/backend/services/attack_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyrit/backend/services/attack_service.py b/pyrit/backend/services/attack_service.py index fe29b2be1e..87ab5b79d2 100644 --- a/pyrit/backend/services/attack_service.py +++ b/pyrit/backend/services/attack_service.py @@ -221,7 +221,7 @@ async def create_attack(self, request: CreateAttackRequest) -> CreateAttackRespo metadata={ "created_at": now.isoformat(), "updated_at": now.isoformat(), - **(request.labels or {}), + "labels": request.labels or {}, }, ) From 877f034d941cccc806592b7e64747aaee356da0f Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Tue, 10 Feb 2026 04:47:18 -0800 Subject: [PATCH 15/35] fix missing labels in _build_summary --- pyrit/backend/services/attack_service.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pyrit/backend/services/attack_service.py b/pyrit/backend/services/attack_service.py index 87ab5b79d2..bfa639b14e 100644 --- a/pyrit/backend/services/attack_service.py +++ b/pyrit/backend/services/attack_service.py @@ -343,6 +343,7 @@ def _build_summary(self, ar: AttackResult) -> AttackSummary: outcome=self._map_outcome(ar.outcome), last_message_preview=last_preview, message_count=message_count, + labels=ar.metadata.get("labels", {}), created_at=created_at, updated_at=updated_at, ) From 97ba8606ff1bad02b6142a212b6d8ec874a99aac Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Tue, 10 Feb 2026 04:50:18 -0800 Subject: [PATCH 16/35] update to actually test labels, too --- tests/unit/backend/test_attack_service.py | 42 +++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/tests/unit/backend/test_attack_service.py b/tests/unit/backend/test_attack_service.py index 7217740e7f..cb56b1da79 100644 --- a/tests/unit/backend/test_attack_service.py +++ b/tests/unit/backend/test_attack_service.py @@ -54,6 +54,7 @@ def make_attack_result( outcome: AttackOutcome = AttackOutcome.UNDETERMINED, created_at: datetime = None, updated_at: datetime = None, + labels: dict = None, ) -> AttackResult: """Create a mock AttackResult for testing.""" now = datetime.now(timezone.utc) @@ -72,6 +73,7 @@ def make_attack_result( metadata={ "created_at": created.isoformat(), "updated_at": updated.isoformat(), + "labels": labels or {}, }, ) @@ -216,6 +218,21 @@ async def test_list_attacks_filters_by_max_turns(self, attack_service, mock_memo assert len(result.items) == 1 assert result.items[0].attack_id == "attack-2" + @pytest.mark.asyncio + async def test_list_attacks_includes_labels_in_summary(self, attack_service, mock_memory) -> None: + """Test that list_attacks includes labels from metadata in summaries.""" + ar = make_attack_result( + conversation_id="attack-1", + labels={"env": "prod", "team": "red"}, + ) + mock_memory.get_attack_results.return_value = [ar] + mock_memory.get_message_pieces.return_value = [] + + result = await attack_service.list_attacks() + + assert len(result.items) == 1 + assert result.items[0].labels == {"env": "prod", "team": "red"} + # ============================================================================ # Get Attack Tests @@ -348,6 +365,31 @@ async def test_create_attack_stores_prepended_conversation(self, attack_service, mock_memory.add_attack_results_to_memory.assert_called_once() mock_memory.add_message_pieces_to_memory.assert_called() + @pytest.mark.asyncio + async def test_create_attack_stores_labels_under_metadata_key(self, attack_service, mock_memory) -> None: + """Test that create_attack stores labels under metadata['labels'], not spread.""" + with patch("pyrit.backend.services.attack_service.get_target_service") as mock_get_target_service: + mock_target_service = MagicMock() + mock_target_service.get_target = AsyncMock(return_value=MagicMock(type="TextTarget")) + mock_get_target_service.return_value = mock_target_service + + await attack_service.create_attack( + CreateAttackRequest( + target_id="target-1", + name="Labeled Attack", + labels={"env": "prod", "team": "red"}, + ) + ) + + # Verify the AttackResult stored in memory has labels nested under metadata["labels"] + call_args = mock_memory.add_attack_results_to_memory.call_args + stored_ar = call_args[1]["attack_results"][0] + assert "labels" in stored_ar.metadata + assert stored_ar.metadata["labels"] == {"env": "prod", "team": "red"} + # Labels should NOT be spread as top-level metadata keys + assert "env" not in stored_ar.metadata + assert "team" not in stored_ar.metadata + # ============================================================================ # Update Attack Tests From bd543b3ce1abe91e5de9a030d0f19d134473a2be Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Tue, 10 Feb 2026 04:51:33 -0800 Subject: [PATCH 17/35] remove duplicate exception handler --- pyrit/backend/main.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/pyrit/backend/main.py b/pyrit/backend/main.py index d7e032d510..cef98b0a2d 100644 --- a/pyrit/backend/main.py +++ b/pyrit/backend/main.py @@ -79,23 +79,6 @@ def setup_frontend() -> None: sys.exit(1) -@app.exception_handler(Exception) -async def global_exception_handler_async(request: object, exc: Exception) -> JSONResponse: - """ - Handle all unhandled exceptions globally. - - Note: This is a fallback handler. Most exceptions are handled by - the RFC 7807 error handlers in middleware/error_handlers.py. - - Returns: - JSONResponse: Error response with 500 status code. - """ - return JSONResponse( - status_code=500, - content={"detail": "Internal server error", "error": str(exc)}, - ) - - if __name__ == "__main__": import uvicorn From 86bd2cac4521626cfb11520e9a452bd166b2dffb Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Tue, 10 Feb 2026 04:53:13 -0800 Subject: [PATCH 18/35] use lifespan instead of on_event --- pyrit/backend/main.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/pyrit/backend/main.py b/pyrit/backend/main.py index cef98b0a2d..6aef78283c 100644 --- a/pyrit/backend/main.py +++ b/pyrit/backend/main.py @@ -7,6 +7,8 @@ import os import sys +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager from pathlib import Path from fastapi import FastAPI @@ -22,24 +24,27 @@ # Check for development mode from environment variable DEV_MODE = os.getenv("PYRIT_DEV_MODE", "false").lower() == "true" + +@asynccontextmanager +async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: + """Manage application startup and shutdown lifecycle.""" + # Startup: initialize PyRIT to load .env and .env.local files + await initialize_pyrit_async(memory_db_type="SQLite") + yield + # Shutdown: nothing to clean up currently + + app = FastAPI( title="PyRIT API", description="Python Risk Identification Tool for LLMs - REST API", version=pyrit.__version__, + lifespan=lifespan, ) # Register RFC 7807 error handlers register_error_handlers(app) -# Initialize PyRIT on startup to load .env and .env.local files -@app.on_event("startup") -async def startup_event_async() -> None: - """Initialize PyRIT on application startup.""" - # Use in-memory to avoid database initialization delays - await initialize_pyrit_async(memory_db_type="SQLite") - - # Configure CORS app.add_middleware( CORSMiddleware, From a1dd6bcdca7975271aec5c1117f69e147943752b Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Tue, 10 Feb 2026 04:57:18 -0800 Subject: [PATCH 19/35] thread-safe singletons (lru_cache) --- pyrit/backend/services/attack_service.py | 9 +++------ pyrit/backend/services/converter_service.py | 9 +++------ pyrit/backend/services/target_service.py | 11 +++-------- tests/unit/backend/test_attack_service.py | 10 ++-------- tests/unit/backend/test_converter_service.py | 4 ++-- tests/unit/backend/test_target_service.py | 6 ++---- 6 files changed, 15 insertions(+), 34 deletions(-) diff --git a/pyrit/backend/services/attack_service.py b/pyrit/backend/services/attack_service.py index bfa639b14e..826215e2a3 100644 --- a/pyrit/backend/services/attack_service.py +++ b/pyrit/backend/services/attack_service.py @@ -17,6 +17,7 @@ import uuid from datetime import datetime, timezone +from functools import lru_cache from typing import Any, Dict, List, Literal, Optional, cast from pyrit.backend.models.attacks import ( @@ -571,9 +572,8 @@ def _get_converter_configs(self, request: AddMessageRequest) -> List[PromptConve # Singleton # ============================================================================ -_attack_service: Optional[AttackService] = None - +@lru_cache(maxsize=1) def get_attack_service() -> AttackService: """ Get the global attack service instance. @@ -581,7 +581,4 @@ def get_attack_service() -> AttackService: Returns: The singleton AttackService instance. """ - global _attack_service - if _attack_service is None: - _attack_service = AttackService() - return _attack_service + return AttackService() diff --git a/pyrit/backend/services/converter_service.py b/pyrit/backend/services/converter_service.py index 52947a63c5..f86b3bf3f7 100644 --- a/pyrit/backend/services/converter_service.py +++ b/pyrit/backend/services/converter_service.py @@ -13,6 +13,7 @@ """ import uuid +from functools import lru_cache from typing import Any, List, Optional, Tuple from pyrit import prompt_converter @@ -287,9 +288,8 @@ async def _apply_converters( # Singleton # ============================================================================ -_converter_service: Optional[ConverterService] = None - +@lru_cache(maxsize=1) def get_converter_service() -> ConverterService: """ Get the global converter service instance. @@ -297,7 +297,4 @@ def get_converter_service() -> ConverterService: Returns: The singleton ConverterService instance. """ - global _converter_service - if _converter_service is None: - _converter_service = ConverterService() - return _converter_service + return ConverterService() diff --git a/pyrit/backend/services/target_service.py b/pyrit/backend/services/target_service.py index 23f94212f7..008779587e 100644 --- a/pyrit/backend/services/target_service.py +++ b/pyrit/backend/services/target_service.py @@ -13,6 +13,7 @@ """ import uuid +from functools import lru_cache from typing import Any, Optional from pyrit import prompt_target @@ -170,10 +171,7 @@ async def create_target(self, request: CreateTargetRequest) -> CreateTargetRespo ) -# Global service instance -_target_service: Optional[TargetService] = None - - +@lru_cache(maxsize=1) def get_target_service() -> TargetService: """ Get the global target service instance. @@ -181,7 +179,4 @@ def get_target_service() -> TargetService: Returns: The singleton TargetService instance. """ - global _target_service - if _target_service is None: - _target_service = TargetService() - return _target_service + return TargetService() diff --git a/tests/unit/backend/test_attack_service.py b/tests/unit/backend/test_attack_service.py index cb56b1da79..84d2874bfd 100644 --- a/tests/unit/backend/test_attack_service.py +++ b/tests/unit/backend/test_attack_service.py @@ -655,10 +655,7 @@ class TestAttackServiceSingleton: def test_get_attack_service_returns_attack_service(self) -> None: """Test that get_attack_service returns an AttackService instance.""" - # Reset singleton for clean test - import pyrit.backend.services.attack_service as module - - module._attack_service = None + get_attack_service.cache_clear() with patch("pyrit.backend.services.attack_service.CentralMemory"): service = get_attack_service() @@ -666,10 +663,7 @@ def test_get_attack_service_returns_attack_service(self) -> None: def test_get_attack_service_returns_same_instance(self) -> None: """Test that get_attack_service returns the same instance.""" - # Reset singleton for clean test - import pyrit.backend.services.attack_service as module - - module._attack_service = None + get_attack_service.cache_clear() with patch("pyrit.backend.services.attack_service.CentralMemory"): service1 = get_attack_service() diff --git a/tests/unit/backend/test_converter_service.py b/tests/unit/backend/test_converter_service.py index 2b3e0576dd..110ea45cf3 100644 --- a/tests/unit/backend/test_converter_service.py +++ b/tests/unit/backend/test_converter_service.py @@ -333,14 +333,14 @@ class TestConverterServiceSingleton: def test_get_converter_service_returns_converter_service(self) -> None: """Test that get_converter_service returns a ConverterService instance.""" - converter_service_module._converter_service = None + get_converter_service.cache_clear() service = get_converter_service() assert isinstance(service, ConverterService) def test_get_converter_service_returns_same_instance(self) -> None: """Test that get_converter_service returns the same instance.""" - converter_service_module._converter_service = None + get_converter_service.cache_clear() service1 = get_converter_service() service2 = get_converter_service() diff --git a/tests/unit/backend/test_target_service.py b/tests/unit/backend/test_target_service.py index 0b213ed728..75a39d817b 100644 --- a/tests/unit/backend/test_target_service.py +++ b/tests/unit/backend/test_target_service.py @@ -156,20 +156,18 @@ class TestTargetServiceSingleton: def test_get_target_service_returns_target_service(self) -> None: """Test that get_target_service returns a TargetService instance.""" - import pyrit.backend.services.target_service as module from pyrit.backend.services.target_service import get_target_service - module._target_service = None + get_target_service.cache_clear() service = get_target_service() assert isinstance(service, TargetService) def test_get_target_service_returns_same_instance(self) -> None: """Test that get_target_service returns the same instance.""" - import pyrit.backend.services.target_service as module from pyrit.backend.services.target_service import get_target_service - module._target_service = None + get_target_service.cache_clear() service1 = get_target_service() service2 = get_target_service() From 674eef57b55e0011758ed5ba56ab634cd204e22b Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Tue, 10 Feb 2026 05:00:10 -0800 Subject: [PATCH 20/35] CORS origins from env var --- pyrit/backend/main.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pyrit/backend/main.py b/pyrit/backend/main.py index 6aef78283c..170d3b7d53 100644 --- a/pyrit/backend/main.py +++ b/pyrit/backend/main.py @@ -46,9 +46,12 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # Configure CORS +_default_origins = "http://localhost:3000,http://localhost:5173" +_cors_origins = [o.strip() for o in os.getenv("PYRIT_CORS_ORIGINS", _default_origins).split(",") if o.strip()] + app.add_middleware( CORSMiddleware, - allow_origins=["http://localhost:3000", "http://localhost:5173"], # Vite default ports + allow_origins=_cors_origins, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], From 973b77a17104984ccb5ea65084b860546d9e8714 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Tue, 10 Feb 2026 05:02:38 -0800 Subject: [PATCH 21/35] add error response for list targets --- pyrit/backend/routes/targets.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyrit/backend/routes/targets.py b/pyrit/backend/routes/targets.py index 41b5ab3bdc..ce99ab3b89 100644 --- a/pyrit/backend/routes/targets.py +++ b/pyrit/backend/routes/targets.py @@ -25,6 +25,9 @@ @router.get( "", response_model=TargetListResponse, + responses={ + 500: {"model": ProblemDetail, "description": "Internal server error"}, + }, ) async def list_targets() -> TargetListResponse: """ From 2462125e9f53222c20741dc578a707d4df6be1dc Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Tue, 10 Feb 2026 05:06:00 -0800 Subject: [PATCH 22/35] pagination for targets --- pyrit/backend/models/targets.py | 3 ++ pyrit/backend/routes/targets.py | 17 +++++--- pyrit/backend/services/target_service.py | 47 +++++++++++++++++--- tests/unit/backend/test_api_routes.py | 8 +++- tests/unit/backend/test_target_service.py | 53 +++++++++++++++++++++++ 5 files changed, 115 insertions(+), 13 deletions(-) diff --git a/pyrit/backend/models/targets.py b/pyrit/backend/models/targets.py index eaa151130c..fd4d607a3c 100644 --- a/pyrit/backend/models/targets.py +++ b/pyrit/backend/models/targets.py @@ -15,6 +15,8 @@ from pydantic import BaseModel, Field +from pyrit.backend.models.common import PaginationInfo + class TargetInstance(BaseModel): """ @@ -33,6 +35,7 @@ class TargetListResponse(BaseModel): """Response for listing target instances.""" items: List[TargetInstance] = Field(..., description="List of target instances") + pagination: PaginationInfo = Field(..., description="Pagination metadata") class CreateTargetRequest(BaseModel): diff --git a/pyrit/backend/routes/targets.py b/pyrit/backend/routes/targets.py index ce99ab3b89..36ac61f8b2 100644 --- a/pyrit/backend/routes/targets.py +++ b/pyrit/backend/routes/targets.py @@ -8,7 +8,9 @@ Target types are set at app startup via initializers - you cannot add new types at runtime. """ -from fastapi import APIRouter, HTTPException, status +from typing import Optional + +from fastapi import APIRouter, HTTPException, Query, status from pyrit.backend.models.common import ProblemDetail from pyrit.backend.models.targets import ( @@ -29,17 +31,20 @@ 500: {"model": ProblemDetail, "description": "Internal server error"}, }, ) -async def list_targets() -> TargetListResponse: +async def list_targets( + limit: int = Query(50, ge=1, le=200, description="Maximum items per page"), + cursor: Optional[str] = Query(None, description="Pagination cursor (target_id)"), +) -> TargetListResponse: """ - List target instances. + List target instances with pagination. - Returns all registered target instances. + Returns paginated target instances. Returns: - TargetListResponse: List of target instances. + TargetListResponse: Paginated list of target instances. """ service = get_target_service() - return await service.list_targets() + return await service.list_targets(limit=limit, cursor=cursor) @router.post( diff --git a/pyrit/backend/services/target_service.py b/pyrit/backend/services/target_service.py index 008779587e..baf2939c61 100644 --- a/pyrit/backend/services/target_service.py +++ b/pyrit/backend/services/target_service.py @@ -14,10 +14,10 @@ import uuid from functools import lru_cache -from typing import Any, Optional +from typing import Any, List, Optional from pyrit import prompt_target -from pyrit.backend.models.common import filter_sensitive_fields +from pyrit.backend.models.common import PaginationInfo, filter_sensitive_fields from pyrit.backend.models.targets import ( CreateTargetRequest, CreateTargetResponse, @@ -102,17 +102,52 @@ def _build_instance_from_object(self, target_id: str, target_obj: Any) -> Target params=filtered_params, ) - async def list_targets(self) -> TargetListResponse: + async def list_targets( + self, + *, + limit: int = 50, + cursor: Optional[str] = None, + ) -> TargetListResponse: """ - List all target instances. + List all target instances with pagination. + + Args: + limit: Maximum items to return. + cursor: Pagination cursor (target_id to start after). Returns: - TargetListResponse containing all registered targets. + TargetListResponse containing paginated targets. """ items = [ self._build_instance_from_object(name, obj) for name, obj in self._registry.get_all_instances().items() ] - return TargetListResponse(items=items) + page, has_more = self._paginate(items, cursor, limit) + next_cursor = page[-1].target_id if has_more and page else None + return TargetListResponse( + items=page, + pagination=PaginationInfo(limit=limit, has_more=has_more, next_cursor=next_cursor, prev_cursor=cursor), + ) + + @staticmethod + def _paginate( + items: List[TargetInstance], cursor: Optional[str], limit: int + ) -> tuple[List[TargetInstance], bool]: + """ + Apply cursor-based pagination. + + Returns: + Tuple of (paginated items, has_more flag). + """ + start_idx = 0 + if cursor: + for i, item in enumerate(items): + if item.target_id == cursor: + start_idx = i + 1 + break + + page = items[start_idx : start_idx + limit] + has_more = len(items) > start_idx + limit + return page, has_more async def get_target(self, target_id: str) -> Optional[TargetInstance]: """ diff --git a/tests/unit/backend/test_api_routes.py b/tests/unit/backend/test_api_routes.py index 420a3be142..07fe4582cd 100644 --- a/tests/unit/backend/test_api_routes.py +++ b/tests/unit/backend/test_api_routes.py @@ -429,7 +429,12 @@ def test_list_targets_returns_empty_list(self, client: TestClient) -> None: """Test that list targets returns empty list initially.""" with patch("pyrit.backend.routes.targets.get_target_service") as mock_get_service: mock_service = MagicMock() - mock_service.list_targets = AsyncMock(return_value=TargetListResponse(items=[])) + mock_service.list_targets = AsyncMock( + return_value=TargetListResponse( + items=[], + pagination=PaginationInfo(limit=50, has_more=False, next_cursor=None, prev_cursor=None), + ) + ) mock_get_service.return_value = mock_service response = client.get("/api/targets") @@ -437,6 +442,7 @@ def test_list_targets_returns_empty_list(self, client: TestClient) -> None: assert response.status_code == status.HTTP_200_OK data = response.json() assert data["items"] == [] + assert data["pagination"]["has_more"] is False def test_create_target_success(self, client: TestClient) -> None: """Test successful target creation.""" diff --git a/tests/unit/backend/test_target_service.py b/tests/unit/backend/test_target_service.py index 75a39d817b..474fa6b724 100644 --- a/tests/unit/backend/test_target_service.py +++ b/tests/unit/backend/test_target_service.py @@ -33,6 +33,7 @@ async def test_list_targets_returns_empty_when_no_targets(self) -> None: result = await service.list_targets() assert result.items == [] + assert result.pagination.has_more is False @pytest.mark.asyncio async def test_list_targets_returns_targets_from_registry(self) -> None: @@ -49,6 +50,58 @@ async def test_list_targets_returns_targets_from_registry(self) -> None: assert len(result.items) == 1 assert result.items[0].target_id == "target-1" assert result.items[0].type == "MockTarget" + assert result.pagination.has_more is False + + @pytest.mark.asyncio + async def test_list_targets_paginates_with_limit(self) -> None: + """Test that list_targets respects the limit parameter.""" + service = TargetService() + + for i in range(5): + mock_target = MagicMock() + mock_target.get_identifier.return_value = {"__type__": "MockTarget"} + service._registry.register_instance(mock_target, name=f"target-{i}") + + result = await service.list_targets(limit=3) + + assert len(result.items) == 3 + assert result.pagination.limit == 3 + assert result.pagination.has_more is True + assert result.pagination.next_cursor == result.items[-1].target_id + + @pytest.mark.asyncio + async def test_list_targets_cursor_returns_next_page(self) -> None: + """Test that list_targets cursor skips to the correct position.""" + service = TargetService() + + for i in range(5): + mock_target = MagicMock() + mock_target.get_identifier.return_value = {"__type__": "MockTarget"} + service._registry.register_instance(mock_target, name=f"target-{i}") + + first_page = await service.list_targets(limit=2) + second_page = await service.list_targets(limit=2, cursor=first_page.pagination.next_cursor) + + assert len(second_page.items) == 2 + assert second_page.items[0].target_id != first_page.items[0].target_id + assert second_page.pagination.has_more is True + + @pytest.mark.asyncio + async def test_list_targets_last_page_has_no_more(self) -> None: + """Test that the last page has has_more=False and no next_cursor.""" + service = TargetService() + + for i in range(3): + mock_target = MagicMock() + mock_target.get_identifier.return_value = {"__type__": "MockTarget"} + service._registry.register_instance(mock_target, name=f"target-{i}") + + first_page = await service.list_targets(limit=2) + last_page = await service.list_targets(limit=2, cursor=first_page.pagination.next_cursor) + + assert len(last_page.items) == 1 + assert last_page.pagination.has_more is False + assert last_page.pagination.next_cursor is None class TestGetTarget: From aaf133d46e1bcb57d52d8edb13d595566eabe1ec Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Tue, 10 Feb 2026 05:18:51 -0800 Subject: [PATCH 23/35] mapping --- pyrit/backend/mappers/__init__.py | 35 ++ pyrit/backend/mappers/attack_mappers.py | 218 ++++++++++++ pyrit/backend/mappers/converter_mappers.py | 32 ++ pyrit/backend/mappers/target_mappers.py | 35 ++ pyrit/backend/services/attack_service.py | 215 ++---------- pyrit/backend/services/converter_service.py | 11 +- pyrit/backend/services/target_service.py | 13 +- tests/unit/backend/test_mappers.py | 350 ++++++++++++++++++++ 8 files changed, 696 insertions(+), 213 deletions(-) create mode 100644 pyrit/backend/mappers/__init__.py create mode 100644 pyrit/backend/mappers/attack_mappers.py create mode 100644 pyrit/backend/mappers/converter_mappers.py create mode 100644 pyrit/backend/mappers/target_mappers.py create mode 100644 tests/unit/backend/test_mappers.py diff --git a/pyrit/backend/mappers/__init__.py b/pyrit/backend/mappers/__init__.py new file mode 100644 index 0000000000..a2f7d4c029 --- /dev/null +++ b/pyrit/backend/mappers/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Backend mappers module. + +Pure mapping functions that translate between PyRIT domain models and backend API DTOs. +Centralizes all translation logic so domain models can evolve independently of the API contract. +""" + +from pyrit.backend.mappers.attack_mappers import ( + attack_result_to_summary, + map_outcome, + pyrit_messages_to_dto, + pyrit_scores_to_dto, + request_to_pyrit_message, + request_piece_to_pyrit_message_piece, +) +from pyrit.backend.mappers.converter_mappers import ( + converter_object_to_instance, +) +from pyrit.backend.mappers.target_mappers import ( + target_object_to_instance, +) + +__all__ = [ + "attack_result_to_summary", + "converter_object_to_instance", + "map_outcome", + "pyrit_messages_to_dto", + "pyrit_scores_to_dto", + "request_piece_to_pyrit_message_piece", + "request_to_pyrit_message", + "target_object_to_instance", +] diff --git a/pyrit/backend/mappers/attack_mappers.py b/pyrit/backend/mappers/attack_mappers.py new file mode 100644 index 0000000000..f1a9bd847d --- /dev/null +++ b/pyrit/backend/mappers/attack_mappers.py @@ -0,0 +1,218 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Attack mappers – domain ↔ DTO translation for attack-related models. + +All functions are pure (no database or service calls) so they are easy to test. +The one exception is `attack_result_to_summary` which receives pre-fetched pieces. +""" + +import uuid +from datetime import datetime, timezone +from typing import Any, Dict, List, Literal, Optional, cast + +from pyrit.backend.models.attacks import ( + AddMessageRequest, + AttackSummary, + Message, + MessagePiece, + Score, +) +from pyrit.models import AttackOutcome, AttackResult, PromptDataType +from pyrit.models import Message as PyritMessage +from pyrit.models import MessagePiece as PyritMessagePiece + + +# ============================================================================ +# Domain → DTO (for API responses) +# ============================================================================ + + +def map_outcome(outcome: AttackOutcome) -> Optional[Literal["undetermined", "success", "failure"]]: + """ + Map AttackOutcome enum to API outcome string. + + Returns: + Outcome string ('success', 'failure', 'undetermined') or None. + """ + if outcome == AttackOutcome.SUCCESS: + return "success" + elif outcome == AttackOutcome.FAILURE: + return "failure" + else: + return "undetermined" + + +def attack_result_to_summary( + ar: AttackResult, + *, + pieces: List[Any], +) -> AttackSummary: + """ + Build an AttackSummary DTO from an AttackResult and its message pieces. + + Args: + ar: The domain AttackResult. + pieces: Pre-fetched message pieces for this conversation. + + Returns: + AttackSummary DTO ready for the API response. + """ + message_count = len(set(p.sequence for p in pieces)) + last_preview = _get_preview_from_pieces(pieces) + + created_str = ar.metadata.get("created_at") + updated_str = ar.metadata.get("updated_at") + created_at = datetime.fromisoformat(created_str) if created_str else datetime.now(timezone.utc) + updated_at = datetime.fromisoformat(updated_str) if updated_str else created_at + + return AttackSummary( + attack_id=ar.conversation_id, + name=ar.attack_identifier.get("name"), + target_id=ar.attack_identifier.get("target_id", ""), + target_type=ar.attack_identifier.get("target_type", ""), + outcome=map_outcome(ar.outcome), + last_message_preview=last_preview, + message_count=message_count, + labels=ar.metadata.get("labels", {}), + created_at=created_at, + updated_at=updated_at, + ) + + +def pyrit_scores_to_dto(scores: List[Any]) -> List[Score]: + """ + Translate PyRIT score objects to backend Score DTOs. + + Returns: + List of Score DTOs for the API. + """ + return [ + Score( + score_id=str(s.id), + scorer_type=s.scorer_class_identifier.get("__type__", "unknown"), + score_value=s.score_value, + score_rationale=s.score_rationale, + scored_at=s.timestamp, + ) + for s in scores + ] + + +def pyrit_messages_to_dto(pyrit_messages: List[Any]) -> List[Message]: + """ + Translate PyRIT messages to backend Message DTOs. + + Returns: + List of Message DTOs for the API. + """ + messages = [] + for msg in pyrit_messages: + pieces = [ + MessagePiece( + piece_id=str(p.id), + data_type=p.converted_value_data_type or "text", + original_value=p.original_value, + converted_value=p.converted_value or "", + scores=pyrit_scores_to_dto(p.scores) if hasattr(p, "scores") and p.scores else [], + response_error=p.response_error or "none", + ) + for p in msg.message_pieces + ] + + first = msg.message_pieces[0] if msg.message_pieces else None + messages.append( + Message( + message_id=str(first.id) if first else str(uuid.uuid4()), + turn_number=first.sequence if first else 0, + role=first.role if first else "user", + pieces=pieces, + created_at=first.timestamp if first else datetime.now(timezone.utc), + ) + ) + + return messages + + +# ============================================================================ +# DTO → Domain (for inbound requests) +# ============================================================================ + + +def request_piece_to_pyrit_message_piece( + *, + piece: Any, + role: str, + conversation_id: str, + sequence: int, +) -> PyritMessagePiece: + """ + Convert a single request piece DTO to a PyRIT MessagePiece domain object. + + Args: + piece: The request piece (with data_type, original_value, converted_value). + role: The message role. + conversation_id: The conversation/attack ID. + sequence: The message sequence number. + + Returns: + PyritMessagePiece domain object. + """ + return PyritMessagePiece( + role=role, + original_value=piece.original_value, + original_value_data_type=cast(PromptDataType, piece.data_type), + converted_value=piece.converted_value or piece.original_value, + converted_value_data_type=cast(PromptDataType, piece.data_type), + conversation_id=conversation_id, + sequence=sequence, + ) + + +def request_to_pyrit_message( + *, + request: AddMessageRequest, + conversation_id: str, + sequence: int, +) -> PyritMessage: + """ + Build a PyRIT Message from an AddMessageRequest DTO. + + Args: + request: The inbound API request. + conversation_id: The conversation/attack ID. + sequence: The message sequence number. + + Returns: + PyritMessage ready to send to the target. + """ + pieces = [ + request_piece_to_pyrit_message_piece( + piece=p, + role=request.role, + conversation_id=conversation_id, + sequence=sequence, + ) + for p in request.pieces + ] + return PyritMessage(pieces) + + +# ============================================================================ +# Private Helpers +# ============================================================================ + + +def _get_preview_from_pieces(pieces: List[Any]) -> Optional[str]: + """ + Get a preview of the last message from a list of pieces. + + Returns: + Truncated last message text, or None if no pieces. + """ + if not pieces: + return None + last_piece = max(pieces, key=lambda p: p.sequence) + text = last_piece.converted_value or "" + return text[:100] + "..." if len(text) > 100 else text diff --git a/pyrit/backend/mappers/converter_mappers.py b/pyrit/backend/mappers/converter_mappers.py new file mode 100644 index 0000000000..292f73dd0c --- /dev/null +++ b/pyrit/backend/mappers/converter_mappers.py @@ -0,0 +1,32 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Converter mappers – domain → DTO translation for converter-related models. +""" + +from typing import Any + +from pyrit.backend.models.converters import ConverterInstance + + +def converter_object_to_instance(converter_id: str, converter_obj: Any) -> ConverterInstance: + """ + Build a ConverterInstance DTO from a registry converter object. + + Args: + converter_id: The unique converter instance identifier. + converter_obj: The domain PromptConverter object from the registry. + + Returns: + ConverterInstance DTO with metadata derived from the object. + """ + identifier = converter_obj.get_identifier() + identifier_dict = identifier.to_dict() + + return ConverterInstance( + converter_id=converter_id, + type=identifier_dict.get("class_name", converter_obj.__class__.__name__), + display_name=None, + params=identifier_dict, + ) diff --git a/pyrit/backend/mappers/target_mappers.py b/pyrit/backend/mappers/target_mappers.py new file mode 100644 index 0000000000..40ac967637 --- /dev/null +++ b/pyrit/backend/mappers/target_mappers.py @@ -0,0 +1,35 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Target mappers – domain → DTO translation for target-related models. +""" + +from typing import Any + +from pyrit.backend.models.common import filter_sensitive_fields +from pyrit.backend.models.targets import TargetInstance + + +def target_object_to_instance(target_id: str, target_obj: Any) -> TargetInstance: + """ + Build a TargetInstance DTO from a registry target object. + + Args: + target_id: The unique target instance identifier. + target_obj: The domain PromptTarget object from the registry. + + Returns: + TargetInstance DTO with metadata derived from the object. + """ + identifier = target_obj.get_identifier() if hasattr(target_obj, "get_identifier") else {} + identifier_dict = identifier.to_dict() if hasattr(identifier, "to_dict") else identifier + target_type = identifier_dict.get("__type__", target_obj.__class__.__name__) + filtered_params = filter_sensitive_fields(identifier_dict) + + return TargetInstance( + target_id=target_id, + type=target_type, + display_name=None, + params=filtered_params, + ) diff --git a/pyrit/backend/services/attack_service.py b/pyrit/backend/services/attack_service.py index 826215e2a3..cb904e6e8e 100644 --- a/pyrit/backend/services/attack_service.py +++ b/pyrit/backend/services/attack_service.py @@ -18,7 +18,7 @@ import uuid from datetime import datetime, timezone from functools import lru_cache -from typing import Any, Dict, List, Literal, Optional, cast +from typing import Any, Dict, List, Literal, Optional from pyrit.backend.models.attacks import ( AddMessageRequest, @@ -28,18 +28,19 @@ AttackSummary, CreateAttackRequest, CreateAttackResponse, - Message, - MessagePiece, - Score, UpdateAttackRequest, ) from pyrit.backend.models.common import PaginationInfo +from pyrit.backend.mappers.attack_mappers import ( + attack_result_to_summary, + pyrit_messages_to_dto, + request_piece_to_pyrit_message_piece, + request_to_pyrit_message, +) from pyrit.backend.services.converter_service import get_converter_service from pyrit.backend.services.target_service import get_target_service from pyrit.memory import CentralMemory -from pyrit.models import AttackOutcome, AttackResult, PromptDataType -from pyrit.models import Message as PyritMessage -from pyrit.models import MessagePiece as PyritMessagePiece +from pyrit.models import AttackOutcome, AttackResult from pyrit.prompt_normalizer import PromptConverterConfiguration, PromptNormalizer @@ -116,7 +117,8 @@ async def list_attacks( if max_turns is not None and ar.executed_turns > max_turns: continue - summary = self._build_summary(ar) + pieces = self._memory.get_message_pieces(conversation_id=ar.conversation_id) + summary = attack_result_to_summary(ar, pieces=pieces) summaries.append(summary) # Sort by most recent @@ -140,34 +142,13 @@ async def get_attack(self, attack_id: str) -> Optional[AttackSummary]: Returns: AttackSummary if found, None otherwise. """ - # Get the attack result results = self._memory.get_attack_results(conversation_id=attack_id) if not results: return None ar = results[0] - - # Get message count - pyrit_messages = self._memory.get_conversation(conversation_id=attack_id) - message_count = len(list(pyrit_messages)) - - created_str = ar.metadata.get("created_at") - updated_str = ar.metadata.get("updated_at") - created_at = datetime.fromisoformat(created_str) if created_str else datetime.now(timezone.utc) - updated_at = datetime.fromisoformat(updated_str) if updated_str else created_at - - return AttackSummary( - attack_id=attack_id, - name=ar.attack_identifier.get("name"), - target_id=ar.attack_identifier.get("target_id", ""), - target_type=ar.attack_identifier.get("target_type", ""), - outcome=self._map_outcome(ar.outcome), - last_message_preview=self._get_last_message_preview(attack_id), - message_count=message_count, - labels=ar.metadata.get("labels", {}), - created_at=created_at, - updated_at=updated_at, - ) + pieces = self._memory.get_message_pieces(conversation_id=ar.conversation_id) + return attack_result_to_summary(ar, pieces=pieces) async def get_attack_messages(self, attack_id: str) -> Optional[AttackMessagesResponse]: """ @@ -183,7 +164,7 @@ async def get_attack_messages(self, attack_id: str) -> Optional[AttackMessagesRe # Get messages for this conversation pyrit_messages = self._memory.get_conversation(conversation_id=attack_id) - backend_messages = self._translate_pyrit_messages_to_backend(list(pyrit_messages)) + backend_messages = pyrit_messages_to_dto(list(pyrit_messages)) return AttackMessagesResponse( attack_id=attack_id, @@ -311,72 +292,6 @@ async def add_message(self, attack_id: str, request: AddMessageRequest) -> AddMe return AddMessageResponse(attack=attack_detail, messages=attack_messages) - # ======================================================================== - # Private Helper Methods - Summary Building - # ======================================================================== - - def _build_summary(self, ar: AttackResult) -> AttackSummary: - """ - Build an AttackSummary from an AttackResult. - - Returns: - AttackSummary with message count and preview. - """ - # Get message count and last preview - pieces = self._memory.get_message_pieces(conversation_id=ar.conversation_id) - message_count = len(set(p.sequence for p in pieces)) - last_preview = None - if pieces: - last_piece = max(pieces, key=lambda p: p.sequence) - text = last_piece.converted_value or "" - last_preview = text[:100] + "..." if len(text) > 100 else text - - created_str = ar.metadata.get("created_at") - updated_str = ar.metadata.get("updated_at") - created_at = datetime.fromisoformat(created_str) if created_str else datetime.now(timezone.utc) - updated_at = datetime.fromisoformat(updated_str) if updated_str else created_at - - return AttackSummary( - attack_id=ar.conversation_id, - name=ar.attack_identifier.get("name"), - target_id=ar.attack_identifier.get("target_id", ""), - target_type=ar.attack_identifier.get("target_type", ""), - outcome=self._map_outcome(ar.outcome), - last_message_preview=last_preview, - message_count=message_count, - labels=ar.metadata.get("labels", {}), - created_at=created_at, - updated_at=updated_at, - ) - - def _map_outcome(self, outcome: AttackOutcome) -> Optional[Literal["undetermined", "success", "failure"]]: - """ - Map AttackOutcome enum to API outcome string. - - Returns: - Outcome string ('success', 'failure', 'undetermined') or None. - """ - if outcome == AttackOutcome.SUCCESS: - return "success" - elif outcome == AttackOutcome.FAILURE: - return "failure" - else: - return "undetermined" - - def _get_last_message_preview(self, conversation_id: str) -> Optional[str]: - """ - Get a preview of the last message in a conversation. - - Returns: - Truncated last message text, or None if no messages. - """ - pieces = self._memory.get_message_pieces(conversation_id=conversation_id) - if not pieces: - return None - last_piece = max(pieces, key=lambda p: p.sequence) - text = last_piece.converted_value or "" - return text[:100] + "..." if len(text) > 100 else text - # ======================================================================== # Private Helper Methods - Pagination # ======================================================================== @@ -401,64 +316,6 @@ def _paginate( has_more = len(items) > start_idx + limit return page, has_more - # ======================================================================== - # Private Helper Methods - Message Conversion - # ======================================================================== - - def _translate_pyrit_messages_to_backend(self, pyrit_messages: List[Any]) -> List[Message]: - """ - Translate PyRIT messages to backend Message format. - - Returns: - List of Message models for the API. - """ - messages = [] - for msg in pyrit_messages: - pieces = [ - MessagePiece( - piece_id=str(p.id), - data_type=p.converted_value_data_type or "text", - original_value=p.original_value, - converted_value=p.converted_value or "", - scores=self._translate_pyrit_scores_to_backend(p.scores) - if hasattr(p, "scores") and p.scores - else [], - response_error=p.response_error or "none", - ) - for p in msg.message_pieces - ] - - first = msg.message_pieces[0] if msg.message_pieces else None - messages.append( - Message( - message_id=str(first.id) if first else str(uuid.uuid4()), - turn_number=first.sequence if first else 0, - role=first.role if first else "user", - pieces=pieces, - created_at=first.timestamp if first else datetime.now(timezone.utc), - ) - ) - - return messages - - def _translate_pyrit_scores_to_backend(self, scores: List[Any]) -> List[Score]: - """ - Translate PyRIT scores to backend Score format. - - Returns: - List of Score models for the API. - """ - return [ - Score( - score_id=str(s.id), - scorer_type=s.scorer_class_identifier.get("__type__", "unknown"), - score_value=s.score_value, - score_rationale=s.score_rationale, - scored_at=s.timestamp, - ) - for s in scores - ] - # ======================================================================== # Private Helper Methods - Store Messages # ======================================================================== @@ -472,12 +329,9 @@ async def _store_prepended_messages( seq = 0 for msg in prepended: for p in msg.pieces: - piece = PyritMessagePiece( + piece = request_piece_to_pyrit_message_piece( + piece=p, role=msg.role, - original_value=p.original_value, - original_value_data_type=cast(PromptDataType, p.data_type), - converted_value=p.converted_value or p.original_value, - converted_value_data_type=cast(PromptDataType, p.data_type), conversation_id=conversation_id, sequence=seq, ) @@ -496,7 +350,11 @@ async def _send_and_store_message( if not target_obj: raise ValueError(f"Target object for '{target_id}' not found") - pyrit_message = self._build_pyrit_message(request, attack_id, sequence) + pyrit_message = request_to_pyrit_message( + request=request, + conversation_id=attack_id, + sequence=sequence, + ) converter_configs = self._get_converter_configs(request) normalizer = PromptNormalizer() @@ -516,43 +374,14 @@ async def _store_message_only( ) -> None: """Store message without sending (send=False).""" for p in request.pieces: - piece = PyritMessagePiece( + piece = request_piece_to_pyrit_message_piece( + piece=p, role=request.role, - original_value=p.original_value, - original_value_data_type=cast(PromptDataType, p.data_type), - converted_value=p.converted_value or p.original_value, - converted_value_data_type=cast(PromptDataType, p.data_type), conversation_id=attack_id, sequence=sequence, ) self._memory.add_message_pieces_to_memory(message_pieces=[piece]) - def _build_pyrit_message( - self, - request: AddMessageRequest, - conversation_id: str, - sequence: int, - ) -> PyritMessage: - """ - Build PyRIT Message from request. - - Returns: - PyritMessage ready to send to the target. - """ - pieces = [ - PyritMessagePiece( - role=request.role, - original_value=p.original_value, - original_value_data_type=cast(PromptDataType, p.data_type), - converted_value=p.converted_value or p.original_value, - converted_value_data_type=cast(PromptDataType, p.data_type), - conversation_id=conversation_id, - sequence=sequence, - ) - for p in request.pieces - ] - return PyritMessage(pieces) - def _get_converter_configs(self, request: AddMessageRequest) -> List[PromptConverterConfiguration]: """ Get converter configurations if needed. diff --git a/pyrit/backend/services/converter_service.py b/pyrit/backend/services/converter_service.py index f86b3bf3f7..93026b51a4 100644 --- a/pyrit/backend/services/converter_service.py +++ b/pyrit/backend/services/converter_service.py @@ -26,6 +26,7 @@ CreateConverterResponse, PreviewStep, ) +from pyrit.backend.mappers.converter_mappers import converter_object_to_instance from pyrit.models import PromptDataType from pyrit.prompt_converter import PromptConverter from pyrit.registry.instance_registries import ConverterRegistry @@ -73,15 +74,7 @@ def _build_instance_from_object(self, converter_id: str, converter_obj: Any) -> Returns: ConverterInstance with metadata derived from the object's identifier. """ - identifier = converter_obj.get_identifier() - identifier_dict = identifier.to_dict() - - return ConverterInstance( - converter_id=converter_id, - type=identifier_dict.get("class_name", converter_obj.__class__.__name__), - display_name=None, - params=identifier_dict, - ) + return converter_object_to_instance(converter_id, converter_obj) # ======================================================================== # Public API Methods diff --git a/pyrit/backend/services/target_service.py b/pyrit/backend/services/target_service.py index baf2939c61..e656f85a32 100644 --- a/pyrit/backend/services/target_service.py +++ b/pyrit/backend/services/target_service.py @@ -24,6 +24,7 @@ TargetInstance, TargetListResponse, ) +from pyrit.backend.mappers.target_mappers import target_object_to_instance from pyrit.prompt_target import PromptTarget from pyrit.registry.instance_registries import TargetRegistry @@ -90,17 +91,7 @@ def _build_instance_from_object(self, target_id: str, target_obj: Any) -> Target Returns: TargetInstance with metadata derived from the object. """ - identifier = target_obj.get_identifier() if hasattr(target_obj, "get_identifier") else {} - identifier_dict = identifier.to_dict() if hasattr(identifier, "to_dict") else identifier - target_type = identifier_dict.get("__type__", target_obj.__class__.__name__) - filtered_params = filter_sensitive_fields(identifier_dict) - - return TargetInstance( - target_id=target_id, - type=target_type, - display_name=None, # Could be added to identifier if needed - params=filtered_params, - ) + return target_object_to_instance(target_id, target_obj) async def list_targets( self, diff --git a/tests/unit/backend/test_mappers.py b/tests/unit/backend/test_mappers.py new file mode 100644 index 0000000000..1671786e9f --- /dev/null +++ b/tests/unit/backend/test_mappers.py @@ -0,0 +1,350 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tests for backend mapper functions. + +These tests verify the domain ↔ DTO translation layer in isolation, +without any database or service dependencies. +""" + +from datetime import datetime, timezone +from unittest.mock import MagicMock + +import pytest + +from pyrit.backend.mappers.attack_mappers import ( + attack_result_to_summary, + map_outcome, + pyrit_messages_to_dto, + pyrit_scores_to_dto, + request_piece_to_pyrit_message_piece, + request_to_pyrit_message, +) +from pyrit.backend.mappers.converter_mappers import converter_object_to_instance +from pyrit.backend.mappers.target_mappers import target_object_to_instance +from pyrit.models import AttackOutcome, AttackResult + + +# ============================================================================ +# Helpers +# ============================================================================ + + +def _make_attack_result( + *, + conversation_id: str = "attack-1", + target_id: str = "target-1", + target_type: str = "TextTarget", + name: str = "Test Attack", + outcome: AttackOutcome = AttackOutcome.UNDETERMINED, + labels: dict = None, +) -> AttackResult: + """Create an AttackResult for mapper tests.""" + now = datetime.now(timezone.utc) + return AttackResult( + conversation_id=conversation_id, + objective="test", + attack_identifier={ + "name": name, + "target_id": target_id, + "target_type": target_type, + }, + outcome=outcome, + metadata={ + "created_at": now.isoformat(), + "updated_at": now.isoformat(), + "labels": labels or {}, + }, + ) + + +def _make_mock_piece( + *, + sequence: int = 0, + converted_value: str = "hello", + original_value: str = "hello", +): + """Create a mock message piece for mapper tests.""" + p = MagicMock() + p.id = "piece-1" + p.sequence = sequence + p.converted_value = converted_value + p.original_value = original_value + p.converted_value_data_type = "text" + p.response_error = "none" + p.role = "user" + p.timestamp = datetime.now(timezone.utc) + p.scores = [] + return p + + +def _make_mock_score(): + """Create a mock score for mapper tests.""" + s = MagicMock() + s.id = "score-1" + s.scorer_class_identifier = {"__type__": "TrueFalseScorer"} + s.score_value = 1.0 + s.score_rationale = "Looks correct" + s.timestamp = datetime.now(timezone.utc) + return s + + +# ============================================================================ +# Attack Mapper Tests +# ============================================================================ + + +class TestMapOutcome: + """Tests for map_outcome function.""" + + def test_maps_success(self) -> None: + assert map_outcome(AttackOutcome.SUCCESS) == "success" + + def test_maps_failure(self) -> None: + assert map_outcome(AttackOutcome.FAILURE) == "failure" + + def test_maps_undetermined(self) -> None: + assert map_outcome(AttackOutcome.UNDETERMINED) == "undetermined" + + +class TestAttackResultToSummary: + """Tests for attack_result_to_summary function.""" + + def test_basic_mapping(self) -> None: + """Test that all fields are mapped correctly.""" + ar = _make_attack_result(name="My Attack", target_id="t-1", target_type="OpenAIChatTarget") + pieces = [_make_mock_piece(sequence=0), _make_mock_piece(sequence=1)] + + summary = attack_result_to_summary(ar, pieces=pieces) + + assert summary.attack_id == ar.conversation_id + assert summary.name == "My Attack" + assert summary.target_id == "t-1" + assert summary.target_type == "OpenAIChatTarget" + assert summary.outcome == "undetermined" + assert summary.message_count == 2 + + def test_empty_pieces_gives_zero_messages(self) -> None: + """Test mapping with no message pieces.""" + ar = _make_attack_result() + + summary = attack_result_to_summary(ar, pieces=[]) + + assert summary.message_count == 0 + assert summary.last_message_preview is None + + def test_last_message_preview_truncated(self) -> None: + """Test that long messages are truncated to 100 chars + ellipsis.""" + ar = _make_attack_result() + long_text = "x" * 200 + pieces = [_make_mock_piece(converted_value=long_text)] + + summary = attack_result_to_summary(ar, pieces=pieces) + + assert summary.last_message_preview is not None + assert len(summary.last_message_preview) == 103 # 100 + "..." + assert summary.last_message_preview.endswith("...") + + def test_labels_are_mapped(self) -> None: + """Test that labels are extracted from metadata.""" + ar = _make_attack_result(labels={"env": "prod", "team": "red"}) + + summary = attack_result_to_summary(ar, pieces=[]) + + assert summary.labels == {"env": "prod", "team": "red"} + + def test_outcome_success(self) -> None: + """Test that success outcome is mapped.""" + ar = _make_attack_result(outcome=AttackOutcome.SUCCESS) + + summary = attack_result_to_summary(ar, pieces=[]) + + assert summary.outcome == "success" + + +class TestPyritScoresToDto: + """Tests for pyrit_scores_to_dto function.""" + + def test_maps_scores(self) -> None: + """Test that scores are correctly translated.""" + mock_score = _make_mock_score() + + result = pyrit_scores_to_dto([mock_score]) + + assert len(result) == 1 + assert result[0].score_id == "score-1" + assert result[0].scorer_type == "TrueFalseScorer" + assert result[0].score_value == 1.0 + assert result[0].score_rationale == "Looks correct" + + def test_empty_scores(self) -> None: + """Test mapping empty scores list.""" + result = pyrit_scores_to_dto([]) + assert result == [] + + +class TestPyritMessagesToDto: + """Tests for pyrit_messages_to_dto function.""" + + def test_maps_single_message(self) -> None: + """Test mapping a single message with one piece.""" + piece = _make_mock_piece(original_value="hi", converted_value="hi") + msg = MagicMock() + msg.message_pieces = [piece] + + result = pyrit_messages_to_dto([msg]) + + assert len(result) == 1 + assert result[0].role == "user" + assert len(result[0].pieces) == 1 + assert result[0].pieces[0].original_value == "hi" + assert result[0].pieces[0].converted_value == "hi" + + def test_maps_empty_list(self) -> None: + """Test mapping an empty messages list.""" + result = pyrit_messages_to_dto([]) + assert result == [] + + +class TestRequestToPyritMessage: + """Tests for request_to_pyrit_message function.""" + + def test_converts_request_to_domain(self) -> None: + """Test that DTO request is correctly converted to domain message.""" + request = MagicMock() + request.role = "user" + piece = MagicMock() + piece.data_type = "text" + piece.original_value = "hello" + piece.converted_value = None + request.pieces = [piece] + + result = request_to_pyrit_message( + request=request, + conversation_id="conv-1", + sequence=0, + ) + + assert len(result.message_pieces) == 1 + assert result.message_pieces[0].original_value == "hello" + assert result.message_pieces[0].conversation_id == "conv-1" + assert result.message_pieces[0].sequence == 0 + + +class TestRequestPieceToPyritMessagePiece: + """Tests for request_piece_to_pyrit_message_piece function.""" + + def test_uses_converted_value_when_present(self) -> None: + """Test that converted_value is used when provided.""" + piece = MagicMock() + piece.data_type = "text" + piece.original_value = "original" + piece.converted_value = "converted" + + result = request_piece_to_pyrit_message_piece( + piece=piece, + role="assistant", + conversation_id="conv-1", + sequence=5, + ) + + assert result.original_value == "original" + assert result.converted_value == "converted" + assert result.api_role == "assistant" + assert result.sequence == 5 + + def test_falls_back_to_original_when_no_converted(self) -> None: + """Test that original_value is used when converted_value is None.""" + piece = MagicMock() + piece.data_type = "text" + piece.original_value = "fallback" + piece.converted_value = None + + result = request_piece_to_pyrit_message_piece( + piece=piece, + role="user", + conversation_id="conv-1", + sequence=0, + ) + + assert result.converted_value == "fallback" + + +# ============================================================================ +# Target Mapper Tests +# ============================================================================ + + +class TestTargetObjectToInstance: + """Tests for target_object_to_instance function.""" + + def test_maps_target_with_identifier(self) -> None: + """Test mapping a target object that has get_identifier.""" + target_obj = MagicMock() + target_obj.get_identifier.return_value = {"__type__": "OpenAIChatTarget", "endpoint": "http://test"} + + result = target_object_to_instance("t-1", target_obj) + + assert result.target_id == "t-1" + assert result.type == "OpenAIChatTarget" + assert result.display_name is None + + def test_filters_sensitive_fields(self) -> None: + """Test that sensitive fields are removed from params.""" + target_obj = MagicMock() + target_obj.get_identifier.return_value = { + "__type__": "TestTarget", + "api_key": "secret-key", + "endpoint": "http://test", + } + + result = target_object_to_instance("t-1", target_obj) + + assert "api_key" not in result.params + assert result.params.get("endpoint") == "http://test" + + def test_fallback_to_class_name(self) -> None: + """Test fallback to __class__.__name__ when no __type__ in identifier.""" + target_obj = MagicMock() + target_obj.__class__.__name__ = "FallbackTarget" + target_obj.get_identifier.return_value = {"endpoint": "http://test"} + + result = target_object_to_instance("t-1", target_obj) + + assert result.type == "FallbackTarget" + + +# ============================================================================ +# Converter Mapper Tests +# ============================================================================ + + +class TestConverterObjectToInstance: + """Tests for converter_object_to_instance function.""" + + def test_maps_converter_with_identifier(self) -> None: + """Test mapping a converter object.""" + converter_obj = MagicMock() + identifier = MagicMock() + identifier.to_dict.return_value = {"class_name": "Base64Converter", "param1": "value1"} + converter_obj.get_identifier.return_value = identifier + + result = converter_object_to_instance("c-1", converter_obj) + + assert result.converter_id == "c-1" + assert result.type == "Base64Converter" + assert result.display_name is None + assert result.params["class_name"] == "Base64Converter" + + def test_fallback_to_class_name(self) -> None: + """Test fallback to __class__.__name__ when no class_name in identifier.""" + converter_obj = MagicMock() + converter_obj.__class__.__name__ = "FallbackConverter" + identifier = MagicMock() + identifier.to_dict.return_value = {"param1": "value1"} + converter_obj.get_identifier.return_value = identifier + + result = converter_object_to_instance("c-1", converter_obj) + + assert result.type == "FallbackConverter" From 88b1a4f080473ca633e66aaebd725f42181b6382 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Tue, 10 Feb 2026 05:37:58 -0800 Subject: [PATCH 24/35] handle original_value_data_type --- pyrit/backend/mappers/attack_mappers.py | 3 ++- pyrit/backend/models/attacks.py | 3 ++- tests/unit/backend/test_attack_service.py | 2 ++ tests/unit/backend/test_mappers.py | 14 ++++++++++++++ 4 files changed, 20 insertions(+), 2 deletions(-) diff --git a/pyrit/backend/mappers/attack_mappers.py b/pyrit/backend/mappers/attack_mappers.py index f1a9bd847d..c4fe871951 100644 --- a/pyrit/backend/mappers/attack_mappers.py +++ b/pyrit/backend/mappers/attack_mappers.py @@ -112,7 +112,8 @@ def pyrit_messages_to_dto(pyrit_messages: List[Any]) -> List[Message]: pieces = [ MessagePiece( piece_id=str(p.id), - data_type=p.converted_value_data_type or "text", + original_value_data_type=p.original_value_data_type or "text", + converted_value_data_type=p.converted_value_data_type or "text", original_value=p.original_value, converted_value=p.converted_value or "", scores=pyrit_scores_to_dto(p.scores) if hasattr(p, "scores") and p.scores else [], diff --git a/pyrit/backend/models/attacks.py b/pyrit/backend/models/attacks.py index 8c4621ed64..215e9cb419 100644 --- a/pyrit/backend/models/attacks.py +++ b/pyrit/backend/models/attacks.py @@ -36,7 +36,8 @@ class MessagePiece(BaseModel): """ piece_id: str = Field(..., description="Unique piece identifier") - data_type: str = Field(default="text", description="Data type: 'text', 'image', 'audio', 'video', etc.") + original_value_data_type: str = Field(default="text", description="Data type of the original value: 'text', 'image', 'audio', etc.") + converted_value_data_type: str = Field(default="text", description="Data type of the converted value: 'text', 'image', 'audio', etc.") original_value: Optional[str] = Field(default=None, description="Original value before conversion") original_value_mime_type: Optional[str] = Field(default=None, description="MIME type of original value") converted_value: str = Field(..., description="Converted value (text or base64 for media)") diff --git a/tests/unit/backend/test_attack_service.py b/tests/unit/backend/test_attack_service.py index 84d2874bfd..d70a701b20 100644 --- a/tests/unit/backend/test_attack_service.py +++ b/tests/unit/backend/test_attack_service.py @@ -96,6 +96,7 @@ def make_mock_piece( piece.original_value = original_value piece.converted_value = converted_value piece.converted_value_data_type = "text" + piece.original_value_data_type = "text" piece.response_error = "none" piece.timestamp = timestamp or datetime.now(timezone.utc) piece.scores = [] @@ -622,6 +623,7 @@ async def test_get_attack_with_messages_translates_correctly(self, attack_servic mock_piece = MagicMock() mock_piece.id = "piece-1" mock_piece.converted_value_data_type = "text" + mock_piece.original_value_data_type = "text" mock_piece.original_value = "Hello" mock_piece.converted_value = "Hello" mock_piece.response_error = None diff --git a/tests/unit/backend/test_mappers.py b/tests/unit/backend/test_mappers.py index 1671786e9f..2faa44b05f 100644 --- a/tests/unit/backend/test_mappers.py +++ b/tests/unit/backend/test_mappers.py @@ -72,6 +72,7 @@ def _make_mock_piece( p.converted_value = converted_value p.original_value = original_value p.converted_value_data_type = "text" + p.original_value_data_type = "text" p.response_error = "none" p.role = "user" p.timestamp = datetime.now(timezone.utc) @@ -201,6 +202,19 @@ def test_maps_single_message(self) -> None: assert result[0].pieces[0].original_value == "hi" assert result[0].pieces[0].converted_value == "hi" + def test_maps_data_types_separately(self) -> None: + """Test that original and converted data types are mapped independently.""" + piece = _make_mock_piece(original_value="describe this", converted_value="base64data") + piece.original_value_data_type = "text" + piece.converted_value_data_type = "image" + msg = MagicMock() + msg.message_pieces = [piece] + + result = pyrit_messages_to_dto([msg]) + + assert result[0].pieces[0].original_value_data_type == "text" + assert result[0].pieces[0].converted_value_data_type == "image" + def test_maps_empty_list(self) -> None: """Test mapping an empty messages list.""" result = pyrit_messages_to_dto([]) From 881a2cb52ad0da9d1c53d0b81348c5ba18063e99 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Wed, 11 Feb 2026 00:14:14 -0800 Subject: [PATCH 25/35] fixing warnings --- .../attack/component/conversation_manager.py | 4 +- pyrit/executor/attack/core/attack_strategy.py | 7 ++- .../attack/printer/console_printer.py | 4 +- pyrit/memory/memory_models.py | 7 ++- pyrit/score/float_scale/float_scale_scorer.py | 4 +- tests/integration/mocks.py | 3 +- .../converter/test_persuasion_converter.py | 5 +- .../converter/test_translation_converter.py | 3 +- .../converter/test_variation_converter.py | 5 +- .../component/test_conversation_manager.py | 16 +++--- .../promptgen/fuzzer/test_fuzzer_converter.py | 5 +- .../test_interface_scenario_results.py | 50 ++++++++++++------- tests/unit/models/test_message_piece.py | 23 ++++++--- tests/unit/scenarios/test_jailbreak.py | 5 +- .../unit/scenarios/test_psychosocial_harms.py | 11 ++-- tests/unit/scenarios/test_scenario.py | 34 ++++++------- .../score/test_conversation_history_scorer.py | 6 +-- tests/unit/target/test_openai_chat_target.py | 25 +++++----- .../target/test_openai_response_target.py | 25 +++++----- tests/unit/target/test_prompt_target.py | 11 ++-- 20 files changed, 144 insertions(+), 109 deletions(-) diff --git a/pyrit/executor/attack/component/conversation_manager.py b/pyrit/executor/attack/component/conversation_manager.py index 6cffcf825a..6b6e648da3 100644 --- a/pyrit/executor/attack/component/conversation_manager.py +++ b/pyrit/executor/attack/component/conversation_manager.py @@ -4,7 +4,7 @@ import logging import uuid from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence from pyrit.common.utils import combine_dict from pyrit.executor.attack.component.prepended_conversation_config import ( @@ -55,7 +55,7 @@ def get_adversarial_chat_messages( *, adversarial_chat_conversation_id: str, attack_identifier: AttackIdentifier, - adversarial_chat_target_identifier: Union[TargetIdentifier, Dict[str, Any]], + adversarial_chat_target_identifier: TargetIdentifier, labels: Optional[Dict[str, str]] = None, ) -> List[Message]: """ diff --git a/pyrit/executor/attack/core/attack_strategy.py b/pyrit/executor/attack/core/attack_strategy.py index 9f7f433858..763004d5fe 100644 --- a/pyrit/executor/attack/core/attack_strategy.py +++ b/pyrit/executor/attack/core/attack_strategy.py @@ -259,8 +259,11 @@ def __init__( ) self._objective_target = objective_target self._params_type = params_type - self._request_converters: list[Any] = [] - self._response_converters: list[Any] = [] + # Guard so subclasses that set converters before calling super() aren't clobbered + if not hasattr(self, "_request_converters"): + self._request_converters: list[Any] = [] + if not hasattr(self, "_response_converters"): + self._response_converters: list[Any] = [] def _build_identifier(self) -> AttackIdentifier: """ diff --git a/pyrit/executor/attack/printer/console_printer.py b/pyrit/executor/attack/printer/console_printer.py index 0dd162613c..c71b40b31c 100644 --- a/pyrit/executor/attack/printer/console_printer.py +++ b/pyrit/executor/attack/printer/console_printer.py @@ -258,10 +258,8 @@ async def print_summary_async(self, result: AttackResult) -> None: # Extract attack type name from attack_identifier attack_type = "Unknown" - if isinstance(result.attack_identifier, dict) and "__type__" in result.attack_identifier: + if result.attack_identifier: attack_type = result.attack_identifier.class_name - elif isinstance(result.attack_identifier, str): - attack_type = result.attack_identifier self._print_colored(f"{self._indent * 2}• Attack Type: {attack_type}", Fore.CYAN) self._print_colored(f"{self._indent * 2}• Conversation ID: {result.conversation_id}", Fore.CYAN) diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 5db7e4d8ae..5ae66e4d6c 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -256,6 +256,11 @@ def get_message_piece(self) -> MessagePiece: if self.prompt_target_identifier: target_id = TargetIdentifier.from_dict({**self.prompt_target_identifier, "pyrit_version": stored_version}) + # Reconstruct AttackIdentifier with the stored pyrit_version + attack_id: Optional[AttackIdentifier] = None + if self.attack_identifier: + attack_id = AttackIdentifier.from_dict({**self.attack_identifier, "pyrit_version": stored_version}) + message_piece = MessagePiece( role=self.role, original_value=self.original_value, @@ -270,7 +275,7 @@ def get_message_piece(self) -> MessagePiece: targeted_harm_categories=self.targeted_harm_categories, converter_identifiers=converter_ids, prompt_target_identifier=target_id, - attack_identifier=self.attack_identifier, + attack_identifier=attack_id, original_value_data_type=self.original_value_data_type, converted_value_data_type=self.converted_value_data_type, response_error=self.response_error, diff --git a/pyrit/score/float_scale/float_scale_scorer.py b/pyrit/score/float_scale/float_scale_scorer.py index b4ac4ddcfb..30650dc637 100644 --- a/pyrit/score/float_scale/float_scale_scorer.py +++ b/pyrit/score/float_scale/float_scale_scorer.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import TYPE_CHECKING, Dict, Optional, Union +from typing import TYPE_CHECKING, Optional from uuid import UUID from pyrit.exceptions.exception_classes import InvalidJsonException @@ -76,7 +76,7 @@ async def _score_value_with_llm( description_output_key: str = "description", metadata_output_key: str = "metadata", category_output_key: str = "category", - attack_identifier: Optional[Union[AttackIdentifier, Dict[str, str]]] = None, + attack_identifier: Optional[AttackIdentifier] = None, ) -> UnvalidatedScore: score: UnvalidatedScore | None = None try: diff --git a/tests/integration/mocks.py b/tests/integration/mocks.py index 47efa4cae9..c924fca504 100644 --- a/tests/integration/mocks.py +++ b/tests/integration/mocks.py @@ -5,6 +5,7 @@ from sqlalchemy import inspect +from pyrit.identifiers import AttackIdentifier from pyrit.memory import MemoryInterface, SQLiteMemory from pyrit.models import Message, MessagePiece from pyrit.prompt_target import PromptChatTarget, limit_requests_per_minute @@ -47,7 +48,7 @@ def set_system_prompt( *, system_prompt: str, conversation_id: str, - attack_identifier: Optional[dict[str, str]] = None, + attack_identifier: Optional[AttackIdentifier] = None, labels: Optional[dict[str, str]] = None, ) -> None: self.system_prompt = system_prompt diff --git a/tests/unit/converter/test_persuasion_converter.py b/tests/unit/converter/test_persuasion_converter.py index 79c31b082d..9962945859 100644 --- a/tests/unit/converter/test_persuasion_converter.py +++ b/tests/unit/converter/test_persuasion_converter.py @@ -7,6 +7,7 @@ from unit.mocks import MockPromptTarget from pyrit.exceptions.exception_classes import InvalidJsonException +from pyrit.identifiers import AttackIdentifier, TargetIdentifier from pyrit.models import Message, MessagePiece from pyrit.prompt_converter import PersuasionConverter @@ -73,8 +74,8 @@ async def test_persuasion_converter_send_prompt_async_bad_json_exception_retries converted_value=converted_value, original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier={"target": "target-identifier"}, - attack_identifier={"test": "test"}, + prompt_target_identifier=TargetIdentifier(class_name="target-identifier", class_module="test"), + attack_identifier=AttackIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ) ] diff --git a/tests/unit/converter/test_translation_converter.py b/tests/unit/converter/test_translation_converter.py index 80ba8b0da4..c8c26254a5 100644 --- a/tests/unit/converter/test_translation_converter.py +++ b/tests/unit/converter/test_translation_converter.py @@ -6,6 +6,7 @@ import pytest from unit.mocks import MockPromptTarget +from pyrit.identifiers import TargetIdentifier from pyrit.models import Message, MessagePiece from pyrit.prompt_converter import TranslationConverter @@ -79,7 +80,7 @@ async def test_translation_converter_succeeds_after_retries(sqlite_instance): converted_value="hola", original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier={"target": "test-identifier"}, + prompt_target_identifier=TargetIdentifier(class_name="test-identifier", class_module="test"), sequence=1, ) ] diff --git a/tests/unit/converter/test_variation_converter.py b/tests/unit/converter/test_variation_converter.py index 023ca02b0d..e11f2e9642 100644 --- a/tests/unit/converter/test_variation_converter.py +++ b/tests/unit/converter/test_variation_converter.py @@ -7,6 +7,7 @@ from unit.mocks import MockPromptTarget from pyrit.exceptions.exception_classes import InvalidJsonException +from pyrit.identifiers import AttackIdentifier, TargetIdentifier from pyrit.models import Message, MessagePiece from pyrit.prompt_converter import VariationConverter @@ -45,8 +46,8 @@ async def test_variation_converter_send_prompt_async_bad_json_exception_retries( converted_value=converted_value, original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier={"target": "target-identifier"}, - attack_identifier={"test": "test"}, + prompt_target_identifier=TargetIdentifier(class_name="target-identifier", class_module="test"), + attack_identifier=AttackIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ) ] diff --git a/tests/unit/executor/attack/component/test_conversation_manager.py b/tests/unit/executor/attack/component/test_conversation_manager.py index ecb3553758..1876c0cd46 100644 --- a/tests/unit/executor/attack/component/test_conversation_manager.py +++ b/tests/unit/executor/attack/component/test_conversation_manager.py @@ -246,7 +246,7 @@ def test_swaps_user_to_assistant(self) -> None: messages, adversarial_chat_conversation_id="adversarial_conv", attack_identifier=AttackIdentifier(class_name="TestAttack", class_module="test_module"), - adversarial_chat_target_identifier={"id": "adversarial_target"}, + adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), ) assert len(result) == 1 @@ -262,7 +262,7 @@ def test_swaps_assistant_to_user(self) -> None: messages, adversarial_chat_conversation_id="adversarial_conv", attack_identifier=AttackIdentifier(class_name="TestAttack", class_module="test_module"), - adversarial_chat_target_identifier={"id": "adversarial_target"}, + adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), ) assert len(result) == 1 @@ -281,7 +281,7 @@ def test_swaps_simulated_assistant_to_user(self) -> None: messages, adversarial_chat_conversation_id="adversarial_conv", attack_identifier=AttackIdentifier(class_name="TestAttack", class_module="test_module"), - adversarial_chat_target_identifier={"id": "adversarial_target"}, + adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), ) assert len(result) == 1 @@ -300,7 +300,7 @@ def test_skips_system_messages(self) -> None: messages, adversarial_chat_conversation_id="adversarial_conv", attack_identifier=AttackIdentifier(class_name="TestAttack", class_module="test_module"), - adversarial_chat_target_identifier={"id": "adversarial_target"}, + adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), ) # Only user message should be present, system skipped @@ -317,7 +317,7 @@ def test_assigns_new_uuids(self) -> None: messages, adversarial_chat_conversation_id="adversarial_conv", attack_identifier=AttackIdentifier(class_name="TestAttack", class_module="test_module"), - adversarial_chat_target_identifier={"id": "adversarial_target"}, + adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), ) # New ID should be different from original @@ -339,7 +339,7 @@ def test_preserves_message_content(self) -> None: messages, adversarial_chat_conversation_id="adversarial_conv", attack_identifier=AttackIdentifier(class_name="TestAttack", class_module="test_module"), - adversarial_chat_target_identifier={"id": "adversarial_target"}, + adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), ) assert result[0].get_piece().original_value == "Original content" @@ -351,7 +351,7 @@ def test_empty_prepended_conversation(self) -> None: [], adversarial_chat_conversation_id="adversarial_conv", attack_identifier=AttackIdentifier(class_name="TestAttack", class_module="test_module"), - adversarial_chat_target_identifier={"id": "adversarial_target"}, + adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), ) assert result == [] @@ -366,7 +366,7 @@ def test_applies_labels(self) -> None: messages, adversarial_chat_conversation_id="adversarial_conv", attack_identifier=AttackIdentifier(class_name="TestAttack", class_module="test_module"), - adversarial_chat_target_identifier={"id": "adversarial_target"}, + adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), labels=labels, ) diff --git a/tests/unit/executor/promptgen/fuzzer/test_fuzzer_converter.py b/tests/unit/executor/promptgen/fuzzer/test_fuzzer_converter.py index 25e49a1c10..3be6513502 100644 --- a/tests/unit/executor/promptgen/fuzzer/test_fuzzer_converter.py +++ b/tests/unit/executor/promptgen/fuzzer/test_fuzzer_converter.py @@ -14,6 +14,7 @@ FuzzerShortenConverter, FuzzerSimilarConverter, ) +from pyrit.identifiers import AttackIdentifier, TargetIdentifier from pyrit.models import Message, MessagePiece @@ -90,8 +91,8 @@ async def test_converter_send_prompt_async_bad_json_exception_retries( converted_value=converted_value, original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier={"target": "target-identifier"}, - attack_identifier={"test": "test"}, + prompt_target_identifier=TargetIdentifier(class_name="target-identifier", class_module="test"), + attack_identifier=AttackIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ) ] diff --git a/tests/unit/memory/memory_interface/test_interface_scenario_results.py b/tests/unit/memory/memory_interface/test_interface_scenario_results.py index 5319ba02d5..810300b98b 100644 --- a/tests/unit/memory/memory_interface/test_interface_scenario_results.py +++ b/tests/unit/memory/memory_interface/test_interface_scenario_results.py @@ -7,7 +7,7 @@ import pytest from unit.mocks import get_mock_scorer_identifier -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.memory import MemoryInterface from pyrit.models import ( AttackOutcome, @@ -63,7 +63,7 @@ def create_scenario_result( return ScenarioResult( scenario_identifier=scenario_identifier, - objective_target_identifier={"target": "test_target"}, + objective_target_identifier=TargetIdentifier(class_name="test_target", class_module="test"), attack_results=attack_results, objective_scorer_identifier=scorer_identifier, ) @@ -278,7 +278,9 @@ def test_preserves_metadata(sqlite_instance: MemoryInterface): scenario_result = ScenarioResult( scenario_identifier=scenario_identifier, - objective_target_identifier={"target": "test_target", "endpoint": "https://example.com"}, + objective_target_identifier=TargetIdentifier( + class_name="test_target", class_module="test", endpoint="https://example.com" + ), attack_results={}, objective_scorer_identifier=scorer_identifier, ) @@ -360,7 +362,7 @@ def test_filter_by_labels(sqlite_instance: MemoryInterface, sample_attack_result scenario_identifier = ScenarioIdentifier(name="Labeled Scenario", scenario_version=1) scenario_result = ScenarioResult( scenario_identifier=scenario_identifier, - objective_target_identifier={"target": "test_target"}, + objective_target_identifier=TargetIdentifier(class_name="test_target", class_module="test"), attack_results={"Attack1": [sample_attack_results[0]]}, labels={"environment": "testing", "team": "red-team"}, objective_scorer_identifier=get_mock_scorer_identifier(), @@ -384,7 +386,7 @@ def test_filter_by_multiple_labels(sqlite_instance: MemoryInterface): scenario1_identifier = ScenarioIdentifier(name="Scenario 1", scenario_version=1) scenario1 = ScenarioResult( scenario_identifier=scenario1_identifier, - objective_target_identifier={"target": "test_target"}, + objective_target_identifier=TargetIdentifier(class_name="test_target", class_module="test"), attack_results={"Attack1": [attack_result1]}, labels={"environment": "testing", "team": "red-team"}, objective_scorer_identifier=get_mock_scorer_identifier(), @@ -393,7 +395,7 @@ def test_filter_by_multiple_labels(sqlite_instance: MemoryInterface): scenario2_identifier = ScenarioIdentifier(name="Scenario 2", scenario_version=1) scenario2 = ScenarioResult( scenario_identifier=scenario2_identifier, - objective_target_identifier={"target": "test_target"}, + objective_target_identifier=TargetIdentifier(class_name="test_target", class_module="test"), attack_results={"Attack2": [attack_result2]}, labels={"environment": "production", "team": "red-team"}, objective_scorer_identifier=get_mock_scorer_identifier(), @@ -422,7 +424,7 @@ def test_filter_by_completion_time(sqlite_instance: MemoryInterface): scenario1_identifier = ScenarioIdentifier(name="Recent Scenario", scenario_version=1) scenario1 = ScenarioResult( scenario_identifier=scenario1_identifier, - objective_target_identifier={"target": "test_target"}, + objective_target_identifier=TargetIdentifier(class_name="test_target", class_module="test"), attack_results={"Attack1": [attack_result1]}, completion_time=now, objective_scorer_identifier=get_mock_scorer_identifier(), @@ -431,7 +433,7 @@ def test_filter_by_completion_time(sqlite_instance: MemoryInterface): scenario2_identifier = ScenarioIdentifier(name="Yesterday Scenario", scenario_version=1) scenario2 = ScenarioResult( scenario_identifier=scenario2_identifier, - objective_target_identifier={"target": "test_target"}, + objective_target_identifier=TargetIdentifier(class_name="test_target", class_module="test"), attack_results={"Attack2": [attack_result2]}, completion_time=yesterday, objective_scorer_identifier=get_mock_scorer_identifier(), @@ -440,7 +442,7 @@ def test_filter_by_completion_time(sqlite_instance: MemoryInterface): scenario3_identifier = ScenarioIdentifier(name="Old Scenario", scenario_version=1) scenario3 = ScenarioResult( scenario_identifier=scenario3_identifier, - objective_target_identifier={"target": "test_target"}, + objective_target_identifier=TargetIdentifier(class_name="test_target", class_module="test"), attack_results={"Attack3": [attack_result3]}, completion_time=last_week, objective_scorer_identifier=get_mock_scorer_identifier(), @@ -473,7 +475,7 @@ def test_filter_by_pyrit_version(sqlite_instance: MemoryInterface): scenario1_identifier = ScenarioIdentifier(name="Old Version Scenario", scenario_version=1, pyrit_version="0.4.0") scenario1 = ScenarioResult( scenario_identifier=scenario1_identifier, - objective_target_identifier={"target": "test_target"}, + objective_target_identifier=TargetIdentifier(class_name="test_target", class_module="test"), attack_results={"Attack1": [attack_result1]}, objective_scorer_identifier=get_mock_scorer_identifier(), ) @@ -481,7 +483,7 @@ def test_filter_by_pyrit_version(sqlite_instance: MemoryInterface): scenario2_identifier = ScenarioIdentifier(name="New Version Scenario", scenario_version=1, pyrit_version="0.5.0") scenario2 = ScenarioResult( scenario_identifier=scenario2_identifier, - objective_target_identifier={"target": "test_target"}, + objective_target_identifier=TargetIdentifier(class_name="test_target", class_module="test"), attack_results={"Attack2": [attack_result2]}, objective_scorer_identifier=get_mock_scorer_identifier(), ) @@ -506,7 +508,9 @@ def test_filter_by_target_endpoint(sqlite_instance: MemoryInterface): scenario1_identifier = ScenarioIdentifier(name="Azure Scenario", scenario_version=1) scenario1 = ScenarioResult( scenario_identifier=scenario1_identifier, - objective_target_identifier={"target": "OpenAI", "endpoint": "https://myresource.openai.azure.com"}, + objective_target_identifier=TargetIdentifier( + class_name="OpenAI", class_module="test", endpoint="https://myresource.openai.azure.com" + ), attack_results={"Attack1": [attack_result1]}, objective_scorer_identifier=get_mock_scorer_identifier(), ) @@ -514,7 +518,9 @@ def test_filter_by_target_endpoint(sqlite_instance: MemoryInterface): scenario2_identifier = ScenarioIdentifier(name="OpenAI Scenario", scenario_version=1) scenario2 = ScenarioResult( scenario_identifier=scenario2_identifier, - objective_target_identifier={"target": "OpenAI", "endpoint": "https://api.openai.com/v1"}, + objective_target_identifier=TargetIdentifier( + class_name="OpenAI", class_module="test", endpoint="https://api.openai.com/v1" + ), attack_results={"Attack2": [attack_result2]}, objective_scorer_identifier=get_mock_scorer_identifier(), ) @@ -522,7 +528,7 @@ def test_filter_by_target_endpoint(sqlite_instance: MemoryInterface): scenario3_identifier = ScenarioIdentifier(name="No Endpoint Scenario", scenario_version=1) scenario3 = ScenarioResult( scenario_identifier=scenario3_identifier, - objective_target_identifier={"target": "Local"}, + objective_target_identifier=TargetIdentifier(class_name="Local", class_module="test"), attack_results={"Attack3": [attack_result3]}, objective_scorer_identifier=get_mock_scorer_identifier(), ) @@ -553,7 +559,7 @@ def test_filter_by_target_model_name(sqlite_instance: MemoryInterface): scenario1_identifier = ScenarioIdentifier(name="GPT-4 Scenario", scenario_version=1) scenario1 = ScenarioResult( scenario_identifier=scenario1_identifier, - objective_target_identifier={"target": "OpenAI", "model_name": "gpt-4-0613"}, + objective_target_identifier=TargetIdentifier(class_name="OpenAI", class_module="test", model_name="gpt-4-0613"), attack_results={"Attack1": [attack_result1]}, objective_scorer_identifier=get_mock_scorer_identifier(), ) @@ -561,7 +567,7 @@ def test_filter_by_target_model_name(sqlite_instance: MemoryInterface): scenario2_identifier = ScenarioIdentifier(name="GPT-4o Scenario", scenario_version=1) scenario2 = ScenarioResult( scenario_identifier=scenario2_identifier, - objective_target_identifier={"target": "OpenAI", "model_name": "gpt-4o"}, + objective_target_identifier=TargetIdentifier(class_name="OpenAI", class_module="test", model_name="gpt-4o"), attack_results={"Attack2": [attack_result2]}, objective_scorer_identifier=get_mock_scorer_identifier(), ) @@ -569,7 +575,9 @@ def test_filter_by_target_model_name(sqlite_instance: MemoryInterface): scenario3_identifier = ScenarioIdentifier(name="GPT-3.5 Scenario", scenario_version=1) scenario3 = ScenarioResult( scenario_identifier=scenario3_identifier, - objective_target_identifier={"target": "OpenAI", "model_name": "gpt-3.5-turbo"}, + objective_target_identifier=TargetIdentifier( + class_name="OpenAI", class_module="test", model_name="gpt-3.5-turbo" + ), attack_results={"Attack3": [attack_result3]}, objective_scorer_identifier=get_mock_scorer_identifier(), ) @@ -602,7 +610,9 @@ def test_combined_filters(sqlite_instance: MemoryInterface): scenario1_identifier = ScenarioIdentifier(name="Test Scenario", scenario_version=1, pyrit_version="0.5.0") scenario1 = ScenarioResult( scenario_identifier=scenario1_identifier, - objective_target_identifier={"target": "OpenAI", "endpoint": "https://api.openai.com", "model_name": "gpt-4"}, + objective_target_identifier=TargetIdentifier( + class_name="OpenAI", class_module="test", endpoint="https://api.openai.com", model_name="gpt-4" + ), attack_results={"Attack1": [attack_result1]}, labels={"environment": "testing"}, completion_time=now, @@ -612,7 +622,9 @@ def test_combined_filters(sqlite_instance: MemoryInterface): scenario2_identifier = ScenarioIdentifier(name="Test Scenario", scenario_version=1, pyrit_version="0.4.0") scenario2 = ScenarioResult( scenario_identifier=scenario2_identifier, - objective_target_identifier={"target": "Azure", "endpoint": "https://azure.com", "model_name": "gpt-3.5"}, + objective_target_identifier=TargetIdentifier( + class_name="Azure", class_module="test", endpoint="https://azure.com", model_name="gpt-3.5" + ), attack_results={"Attack2": [attack_result2]}, labels={"environment": "production"}, completion_time=yesterday, diff --git a/tests/unit/models/test_message_piece.py b/tests/unit/models/test_message_piece.py index dca380baad..4881b7070f 100644 --- a/tests/unit/models/test_message_piece.py +++ b/tests/unit/models/test_message_piece.py @@ -13,7 +13,7 @@ from unit.mocks import MockPromptTarget, get_mock_target, get_sample_conversations from pyrit.executor.attack import PromptSendingAttack -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import AttackIdentifier, ConverterIdentifier, ScorerIdentifier, TargetIdentifier from pyrit.models import ( Message, MessagePiece, @@ -663,14 +663,21 @@ def test_message_piece_to_dict(): targeted_harm_categories=["violence", "illegal"], prompt_metadata={"key": "metadata"}, converter_identifiers=[ - {"__type__": "Base64Converter", "__module__": "pyrit.prompt_converter.base64_converter"} + ConverterIdentifier( + class_name="Base64Converter", + class_module="pyrit.prompt_converter.base64_converter", + supported_input_types=["text"], + supported_output_types=["text"], + ) ], - prompt_target_identifier={"__type__": "MockPromptTarget", "__module__": "unit.mocks"}, - attack_identifier={ - "id": str(uuid.uuid4()), - "__type__": "PromptSendingAttack", - "__module__": "pyrit.executor.attack.single_turn.prompt_sending_attack", - }, + prompt_target_identifier=TargetIdentifier( + class_name="MockPromptTarget", + class_module="unit.mocks", + ), + attack_identifier=AttackIdentifier( + class_name="PromptSendingAttack", + class_module="pyrit.executor.attack.single_turn.prompt_sending_attack", + ), scorer_identifier=ScorerIdentifier( class_name="TestScorer", class_module="pyrit.score.test_scorer", diff --git a/tests/unit/scenarios/test_jailbreak.py b/tests/unit/scenarios/test_jailbreak.py index 047334131c..c5c6f6b42d 100644 --- a/tests/unit/scenarios/test_jailbreak.py +++ b/tests/unit/scenarios/test_jailbreak.py @@ -10,6 +10,7 @@ from pyrit.executor.attack.core.attack_config import AttackScoringConfig from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.models import SeedGroup, SeedObjective from pyrit.prompt_target import PromptTarget from pyrit.scenario.scenarios.airt.jailbreak import Jailbreak, JailbreakStrategy @@ -43,7 +44,7 @@ def mock_memory_seed_groups() -> List[SeedGroup]: def mock_objective_target() -> PromptTarget: """Create a mock objective target for testing.""" mock = MagicMock(spec=PromptTarget) - mock.get_identifier.return_value = {"__type__": "MockObjectiveTarget", "__module__": "test"} + mock.get_identifier.return_value = TargetIdentifier(class_name="MockObjectiveTarget", class_module="test") return mock @@ -51,7 +52,7 @@ def mock_objective_target() -> PromptTarget: def mock_objective_scorer() -> TrueFalseInverterScorer: """Create a mock scorer for testing.""" mock = MagicMock(spec=TrueFalseInverterScorer) - mock.get_identifier.return_value = {"__type__": "MockObjectiveScorer", "__module__": "test"} + mock.get_identifier.return_value = ScorerIdentifier(class_name="MockObjectiveScorer", class_module="test") return mock diff --git a/tests/unit/scenarios/test_psychosocial_harms.py b/tests/unit/scenarios/test_psychosocial_harms.py index 4a178da7c1..8ecf25206d 100644 --- a/tests/unit/scenarios/test_psychosocial_harms.py +++ b/tests/unit/scenarios/test_psychosocial_harms.py @@ -14,6 +14,7 @@ PromptSendingAttack, RolePlayAttack, ) +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.models import SeedDataset, SeedGroup, SeedObjective from pyrit.prompt_target import OpenAIChatTarget, PromptChatTarget from pyrit.scenario.scenarios.airt import ( @@ -72,21 +73,21 @@ def mock_runtime_env(): @pytest.fixture def mock_objective_target() -> PromptChatTarget: mock = MagicMock(spec=PromptChatTarget) - mock.get_identifier.return_value = {"__type__": "MockObjectiveTarget", "__module__": "test"} + mock.get_identifier.return_value = TargetIdentifier(class_name="MockObjectiveTarget", class_module="test") return mock @pytest.fixture def mock_objective_scorer() -> FloatScaleThresholdScorer: mock = MagicMock(spec=FloatScaleThresholdScorer) - mock.get_identifier.return_value = {"__type__": "MockObjectiveScorer", "__module__": "test"} + mock.get_identifier.return_value = ScorerIdentifier(class_name="MockObjectiveScorer", class_module="test") return mock @pytest.fixture def mock_adversarial_target() -> PromptChatTarget: mock = MagicMock(spec=PromptChatTarget) - mock.get_identifier.return_value = {"__type__": "MockAdversarialTarget", "__module__": "test"} + mock.get_identifier.return_value = TargetIdentifier(class_name="MockAdversarialTarget", class_module="test") return mock @@ -173,7 +174,9 @@ def test_init_default_adversarial_chat(self, *, mock_objective_scorer: FloatScal def test_init_with_adversarial_chat(self, *, mock_objective_scorer: FloatScaleThresholdScorer) -> None: adversarial_chat = MagicMock(OpenAIChatTarget) - adversarial_chat.get_identifier.return_value = {"type": "CustomAdversary"} + adversarial_chat.get_identifier.return_value = TargetIdentifier( + class_name="CustomAdversary", class_module="test" + ) scenario = PsychosocialScenario( adversarial_chat=adversarial_chat, diff --git a/tests/unit/scenarios/test_scenario.py b/tests/unit/scenarios/test_scenario.py index 948fa899b8..804796167e 100644 --- a/tests/unit/scenarios/test_scenario.py +++ b/tests/unit/scenarios/test_scenario.py @@ -521,7 +521,7 @@ def test_scenario_result_initialization(self, sample_attack_results): identifier = ScenarioIdentifier(name="Test", scenario_version=1) result = ScenarioResult( scenario_identifier=identifier, - objective_target_identifier={"__type__": "TestTarget", "__module__": "test"}, + objective_target_identifier=TargetIdentifier(class_name="TestTarget", class_module="test"), attack_results={"base64": sample_attack_results[:3], "rot13": sample_attack_results[3:]}, objective_scorer_identifier=_TEST_SCORER_ID, ) @@ -537,10 +537,10 @@ def test_scenario_result_with_empty_results(self): identifier = ScenarioIdentifier(name="TestScenario", scenario_version=1) result = ScenarioResult( scenario_identifier=identifier, - objective_target_identifier={ - "__type__": "TestTarget", - "__module__": "test", - }, + objective_target_identifier=TargetIdentifier( + class_name="TestTarget", + class_module="test", + ), attack_results={"base64": []}, objective_scorer_identifier=_TEST_SCORER_ID, ) @@ -555,10 +555,10 @@ def test_scenario_result_objective_achieved_rate(self, sample_attack_results): # All successful result = ScenarioResult( scenario_identifier=identifier, - objective_target_identifier={ - "__type__": "TestTarget", - "__module__": "test", - }, + objective_target_identifier=TargetIdentifier( + class_name="TestTarget", + class_module="test", + ), attack_results={"base64": sample_attack_results}, objective_scorer_identifier=_TEST_SCORER_ID, ) @@ -581,10 +581,10 @@ def test_scenario_result_objective_achieved_rate(self, sample_attack_results): ] result2 = ScenarioResult( scenario_identifier=identifier, - objective_target_identifier={ - "__type__": "TestTarget", - "__module__": "test", - }, + objective_target_identifier=TargetIdentifier( + class_name="TestTarget", + class_module="test", + ), attack_results={"base64": mixed_results}, objective_scorer_identifier=_TEST_SCORER_ID, ) @@ -623,10 +623,10 @@ def create_mock_truefalse_scorer(): from pyrit.score import TrueFalseScorer mock_scorer = MagicMock(spec=TrueFalseScorer) - mock_scorer.get_identifier.return_value = { - "__type__": "MockTrueFalseScorer", - "__module__": "test", - } + mock_scorer.get_identifier.return_value = ScorerIdentifier( + class_name="MockTrueFalseScorer", + class_module="test", + ) mock_scorer.get_scorer_metrics.return_value = None # Make isinstance check work mock_scorer.__class__ = TrueFalseScorer diff --git a/tests/unit/score/test_conversation_history_scorer.py b/tests/unit/score/test_conversation_history_scorer.py index 0cdbd6cae0..dd27ac642f 100644 --- a/tests/unit/score/test_conversation_history_scorer.py +++ b/tests/unit/score/test_conversation_history_scorer.py @@ -7,7 +7,7 @@ import pytest -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import AttackIdentifier, ScorerIdentifier, TargetIdentifier from pyrit.memory import CentralMemory from pyrit.models import MessagePiece, Score from pyrit.score import ( @@ -244,8 +244,8 @@ async def test_conversation_history_scorer_preserves_metadata(patch_central_data original_value="Response", conversation_id=conversation_id, labels={"test": "label"}, - prompt_target_identifier={"target": "test"}, - attack_identifier={"attack": "test"}, + prompt_target_identifier=TargetIdentifier(class_name="test", class_module="test"), + attack_identifier=AttackIdentifier(class_name="test", class_module="test"), sequence=1, ) diff --git a/tests/unit/target/test_openai_chat_target.py b/tests/unit/target/test_openai_chat_target.py index 6ba29eb2ea..e257aecb17 100644 --- a/tests/unit/target/test_openai_chat_target.py +++ b/tests/unit/target/test_openai_chat_target.py @@ -24,6 +24,7 @@ PyritException, RateLimitException, ) +from pyrit.identifiers import AttackIdentifier, TargetIdentifier from pyrit.memory.memory_interface import MemoryInterface from pyrit.models import Message, MessagePiece from pyrit.models.json_response_config import _JsonResponseConfig @@ -305,8 +306,8 @@ async def test_send_prompt_async_empty_response_adds_to_memory(openai_response_j converted_value="hello", original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier={"target": "target-identifier"}, - attack_identifier={"test": "test"}, + prompt_target_identifier=TargetIdentifier(class_name="target-identifier", class_module="test"), + attack_identifier=AttackIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), MessagePiece( @@ -316,8 +317,8 @@ async def test_send_prompt_async_empty_response_adds_to_memory(openai_response_j converted_value=tmp_file_name, original_value_data_type="image_path", converted_value_data_type="image_path", - prompt_target_identifier={"target": "target-identifier"}, - attack_identifier={"test": "test"}, + prompt_target_identifier=TargetIdentifier(class_name="target-identifier", class_module="test"), + attack_identifier=AttackIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), ] @@ -399,8 +400,8 @@ async def test_send_prompt_async(openai_response_json: dict, target: OpenAIChatT converted_value="hello", original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier={"target": "target-identifier"}, - attack_identifier={"test": "test"}, + prompt_target_identifier=TargetIdentifier(class_name="target-identifier", class_module="test"), + attack_identifier=AttackIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), MessagePiece( @@ -410,8 +411,8 @@ async def test_send_prompt_async(openai_response_json: dict, target: OpenAIChatT converted_value=tmp_file_name, original_value_data_type="image_path", converted_value_data_type="image_path", - prompt_target_identifier={"target": "target-identifier"}, - attack_identifier={"test": "test"}, + prompt_target_identifier=TargetIdentifier(class_name="target-identifier", class_module="test"), + attack_identifier=AttackIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), ] @@ -447,8 +448,8 @@ async def test_send_prompt_async_empty_response_retries(openai_response_json: di converted_value="hello", original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier={"target": "target-identifier"}, - attack_identifier={"test": "test"}, + prompt_target_identifier=TargetIdentifier(class_name="target-identifier", class_module="test"), + attack_identifier=AttackIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), MessagePiece( @@ -458,8 +459,8 @@ async def test_send_prompt_async_empty_response_retries(openai_response_json: di converted_value=tmp_file_name, original_value_data_type="image_path", converted_value_data_type="image_path", - prompt_target_identifier={"target": "target-identifier"}, - attack_identifier={"test": "test"}, + prompt_target_identifier=TargetIdentifier(class_name="target-identifier", class_module="test"), + attack_identifier=AttackIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), ] diff --git a/tests/unit/target/test_openai_response_target.py b/tests/unit/target/test_openai_response_target.py index c60e1ca559..e6b083efe8 100644 --- a/tests/unit/target/test_openai_response_target.py +++ b/tests/unit/target/test_openai_response_target.py @@ -21,6 +21,7 @@ PyritException, RateLimitException, ) +from pyrit.identifiers import AttackIdentifier, TargetIdentifier from pyrit.memory.memory_interface import MemoryInterface from pyrit.models import Message, MessagePiece from pyrit.models.json_response_config import _JsonResponseConfig @@ -317,8 +318,8 @@ async def test_send_prompt_async_empty_response_adds_to_memory( converted_value="hello", original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier={"target": "target-identifier"}, - attack_identifier={"test": "test"}, + prompt_target_identifier=TargetIdentifier(class_name="target-identifier", class_module="test"), + attack_identifier=AttackIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), MessagePiece( @@ -328,8 +329,8 @@ async def test_send_prompt_async_empty_response_adds_to_memory( converted_value=tmp_file_name, original_value_data_type="image_path", converted_value_data_type="image_path", - prompt_target_identifier={"target": "target-identifier"}, - attack_identifier={"test": "test"}, + prompt_target_identifier=TargetIdentifier(class_name="target-identifier", class_module="test"), + attack_identifier=AttackIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), ] @@ -412,8 +413,8 @@ async def test_send_prompt_async(openai_response_json: dict, target: OpenAIRespo converted_value="hello", original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier={"target": "target-identifier"}, - attack_identifier={"test": "test"}, + prompt_target_identifier=TargetIdentifier(class_name="target-identifier", class_module="test"), + attack_identifier=AttackIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), MessagePiece( @@ -423,8 +424,8 @@ async def test_send_prompt_async(openai_response_json: dict, target: OpenAIRespo converted_value=tmp_file_name, original_value_data_type="image_path", converted_value_data_type="image_path", - prompt_target_identifier={"target": "target-identifier"}, - attack_identifier={"test": "test"}, + prompt_target_identifier=TargetIdentifier(class_name="target-identifier", class_module="test"), + attack_identifier=AttackIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), ] @@ -459,8 +460,8 @@ async def test_send_prompt_async_empty_response_retries(openai_response_json: di converted_value="hello", original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier={"target": "target-identifier"}, - attack_identifier={"test": "test"}, + prompt_target_identifier=TargetIdentifier(class_name="target-identifier", class_module="test"), + attack_identifier=AttackIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), MessagePiece( @@ -470,8 +471,8 @@ async def test_send_prompt_async_empty_response_retries(openai_response_json: di converted_value=tmp_file_name, original_value_data_type="image_path", converted_value_data_type="image_path", - prompt_target_identifier={"target": "target-identifier"}, - attack_identifier={"test": "test"}, + prompt_target_identifier=TargetIdentifier(class_name="target-identifier", class_module="test"), + attack_identifier=AttackIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), ] diff --git a/tests/unit/target/test_prompt_target.py b/tests/unit/target/test_prompt_target.py index e258d2e0c2..e6eb37bb35 100644 --- a/tests/unit/target/test_prompt_target.py +++ b/tests/unit/target/test_prompt_target.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -import uuid from typing import MutableSequence from unittest.mock import AsyncMock, MagicMock, patch @@ -9,6 +8,7 @@ from unit.mocks import get_sample_conversations, openai_chat_response_json_dict from pyrit.executor.attack.core.attack_strategy import AttackStrategy +from pyrit.identifiers import AttackIdentifier from pyrit.models import Message, MessagePiece from pyrit.prompt_target import OpenAIChatTarget @@ -39,11 +39,10 @@ def mock_attack_strategy(): strategy = MagicMock(spec=AttackStrategy) strategy.execute_async = AsyncMock() strategy.execute_with_context_async = AsyncMock() - strategy.get_identifier.return_value = { - "__type__": "TestAttack", - "__module__": "pyrit.executor.attack.test_attack", - "id": str(uuid.uuid4()), - } + strategy.get_identifier.return_value = AttackIdentifier( + class_name="TestAttack", + class_module="pyrit.executor.attack.test_attack", + ) return strategy From 3d30e0517fa44d58e250032500afa419c73f80c2 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Wed, 11 Feb 2026 06:13:43 -0800 Subject: [PATCH 26/35] address copilot comments --- pyrit/backend/routes/attacks.py | 12 +- pyrit/backend/routes/converters.py | 8 +- pyrit/backend/routes/targets.py | 6 +- pyrit/backend/services/attack_service.py | 24 ++-- pyrit/backend/services/converter_service.py | 39 +++--- pyrit/backend/services/target_service.py | 18 +-- pyrit/cli/pyrit_backend.py | 10 +- tests/unit/backend/test_api_routes.py | 103 +++++++++------ tests/unit/backend/test_attack_service.py | 130 ++++++++++++++----- tests/unit/backend/test_converter_service.py | 49 ++++--- tests/unit/backend/test_main.py | 92 +++++++++++++ tests/unit/backend/test_target_service.py | 30 ++--- 12 files changed, 352 insertions(+), 169 deletions(-) create mode 100644 tests/unit/backend/test_main.py diff --git a/pyrit/backend/routes/attacks.py b/pyrit/backend/routes/attacks.py index 30b0f060e9..73076c6e5f 100644 --- a/pyrit/backend/routes/attacks.py +++ b/pyrit/backend/routes/attacks.py @@ -70,7 +70,7 @@ async def list_attacks( """ service = get_attack_service() labels = _parse_labels(label) - return await service.list_attacks( + return await service.list_attacks_async( target_id=target_id, outcome=outcome, name=name, @@ -105,7 +105,7 @@ async def create_attack(request: CreateAttackRequest) -> CreateAttackResponse: service = get_attack_service() try: - return await service.create_attack(request) + return await service.create_attack_async(request=request) except ValueError as e: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -131,7 +131,7 @@ async def get_attack(attack_id: str) -> AttackSummary: """ service = get_attack_service() - attack = await service.get_attack(attack_id) + attack = await service.get_attack_async(attack_id=attack_id) if not attack: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -162,7 +162,7 @@ async def update_attack( """ service = get_attack_service() - attack = await service.update_attack(attack_id, request) + attack = await service.update_attack_async(attack_id=attack_id, request=request) if not attack: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -190,7 +190,7 @@ async def get_attack_messages(attack_id: str) -> AttackMessagesResponse: """ service = get_attack_service() - messages = await service.get_attack_messages(attack_id) + messages = await service.get_attack_messages_async(attack_id=attack_id) if not messages: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -230,7 +230,7 @@ async def add_message( service = get_attack_service() try: - return await service.add_message(attack_id, request) + return await service.add_message_async(attack_id=attack_id, request=request) except ValueError as e: error_msg = str(e) if "not found" in error_msg.lower(): diff --git a/pyrit/backend/routes/converters.py b/pyrit/backend/routes/converters.py index 45134a7adc..f4354ba50d 100644 --- a/pyrit/backend/routes/converters.py +++ b/pyrit/backend/routes/converters.py @@ -38,7 +38,7 @@ async def list_converters() -> ConverterInstanceListResponse: ConverterInstanceListResponse: List of converter instances. """ service = get_converter_service() - return await service.list_converters() + return await service.list_converters_async() @router.post( @@ -62,7 +62,7 @@ async def create_converter(request: CreateConverterRequest) -> CreateConverterRe service = get_converter_service() try: - return await service.create_converter(request) + return await service.create_converter_async(request=request) except ValueError as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -91,7 +91,7 @@ async def get_converter(converter_id: str) -> ConverterInstance: """ service = get_converter_service() - converter = await service.get_converter(converter_id) + converter = await service.get_converter_async(converter_id=converter_id) if not converter: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -121,7 +121,7 @@ async def preview_conversion(request: ConverterPreviewRequest) -> ConverterPrevi service = get_converter_service() try: - return await service.preview_conversion(request) + return await service.preview_conversion_async(request=request) except ValueError as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, diff --git a/pyrit/backend/routes/targets.py b/pyrit/backend/routes/targets.py index 36ac61f8b2..fd7e04508c 100644 --- a/pyrit/backend/routes/targets.py +++ b/pyrit/backend/routes/targets.py @@ -44,7 +44,7 @@ async def list_targets( TargetListResponse: Paginated list of target instances. """ service = get_target_service() - return await service.list_targets(limit=limit, cursor=cursor) + return await service.list_targets_async(limit=limit, cursor=cursor) @router.post( @@ -70,7 +70,7 @@ async def create_target(request: CreateTargetRequest) -> CreateTargetResponse: service = get_target_service() try: - return await service.create_target(request) + return await service.create_target_async(request=request) except ValueError as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -99,7 +99,7 @@ async def get_target(target_id: str) -> TargetInstance: """ service = get_target_service() - target = await service.get_target(target_id) + target = await service.get_target_async(target_id=target_id) if not target: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, diff --git a/pyrit/backend/services/attack_service.py b/pyrit/backend/services/attack_service.py index cb904e6e8e..d41ccdfab7 100644 --- a/pyrit/backend/services/attack_service.py +++ b/pyrit/backend/services/attack_service.py @@ -59,7 +59,7 @@ def __init__(self) -> None: # Public API Methods # ======================================================================== - async def list_attacks( + async def list_attacks_async( self, *, target_id: Optional[str] = None, @@ -133,7 +133,7 @@ async def list_attacks( pagination=PaginationInfo(limit=limit, has_more=has_more, next_cursor=next_cursor, prev_cursor=cursor), ) - async def get_attack(self, attack_id: str) -> Optional[AttackSummary]: + async def get_attack_async(self, *, attack_id: str) -> Optional[AttackSummary]: """ Get attack details (high-level metadata, no messages). @@ -150,7 +150,7 @@ async def get_attack(self, attack_id: str) -> Optional[AttackSummary]: pieces = self._memory.get_message_pieces(conversation_id=ar.conversation_id) return attack_result_to_summary(ar, pieces=pieces) - async def get_attack_messages(self, attack_id: str) -> Optional[AttackMessagesResponse]: + async def get_attack_messages_async(self, *, attack_id: str) -> Optional[AttackMessagesResponse]: """ Get all messages for an attack. @@ -171,7 +171,7 @@ async def get_attack_messages(self, attack_id: str) -> Optional[AttackMessagesRe messages=backend_messages, ) - async def create_attack(self, request: CreateAttackRequest) -> CreateAttackResponse: + async def create_attack_async(self, *, request: CreateAttackRequest) -> CreateAttackResponse: """ Create a new attack. @@ -181,7 +181,7 @@ async def create_attack(self, request: CreateAttackRequest) -> CreateAttackRespo CreateAttackResponse with the new attack's ID and creation time. """ target_service = get_target_service() - target_instance = await target_service.get_target(request.target_id) + target_instance = await target_service.get_target_async(target_id=request.target_id) if not target_instance: raise ValueError(f"Target instance '{request.target_id}' not found") @@ -219,7 +219,7 @@ async def create_attack(self, request: CreateAttackRequest) -> CreateAttackRespo return CreateAttackResponse(attack_id=conversation_id, created_at=now) - async def update_attack(self, attack_id: str, request: UpdateAttackRequest) -> Optional[AttackSummary]: + async def update_attack_async(self, *, attack_id: str, request: UpdateAttackRequest) -> Optional[AttackSummary]: """ Update an attack's outcome. @@ -249,9 +249,9 @@ async def update_attack(self, attack_id: str, request: UpdateAttackRequest) -> O # Re-add to memory (this should update) self._memory.add_attack_results_to_memory(attack_results=[ar]) - return await self.get_attack(attack_id) + return await self.get_attack_async(attack_id=attack_id) - async def add_message(self, attack_id: str, request: AddMessageRequest) -> AddMessageResponse: + async def add_message_async(self, *, attack_id: str, request: AddMessageRequest) -> AddMessageResponse: """ Add a message to an attack, optionally sending to target. @@ -282,11 +282,11 @@ async def add_message(self, attack_id: str, request: AddMessageRequest) -> AddMe # Update attack timestamp ar.metadata["updated_at"] = datetime.now(timezone.utc).isoformat() - attack_detail = await self.get_attack(attack_id) + attack_detail = await self.get_attack_async(attack_id=attack_id) if attack_detail is None: raise ValueError(f"Attack '{attack_id}' not found after update") - attack_messages = await self.get_attack_messages(attack_id) + attack_messages = await self.get_attack_messages_async(attack_id=attack_id) if attack_messages is None: raise ValueError(f"Attack '{attack_id}' messages not found after update") @@ -346,7 +346,7 @@ async def _send_and_store_message( sequence: int, ) -> None: """Send message to target via normalizer and store response.""" - target_obj = get_target_service().get_target_object(target_id) + target_obj = get_target_service().get_target_object(target_id=target_id) if not target_obj: raise ValueError(f"Target object for '{target_id}' not found") @@ -393,7 +393,7 @@ def _get_converter_configs(self, request: AddMessageRequest) -> List[PromptConve if has_preconverted or not request.converter_ids: return [] - converters = get_converter_service().get_converter_objects_for_ids(request.converter_ids) + converters = get_converter_service().get_converter_objects_for_ids(converter_ids=request.converter_ids) return PromptConverterConfiguration.from_converters(converters=converters) diff --git a/pyrit/backend/services/converter_service.py b/pyrit/backend/services/converter_service.py index 93026b51a4..3c1049e103 100644 --- a/pyrit/backend/services/converter_service.py +++ b/pyrit/backend/services/converter_service.py @@ -65,7 +65,7 @@ def __init__(self) -> None: """Initialize the converter service.""" self._registry = ConverterRegistry.get_registry_singleton() - def _build_instance_from_object(self, converter_id: str, converter_obj: Any) -> ConverterInstance: + def _build_instance_from_object(self, *, converter_id: str, converter_obj: Any) -> ConverterInstance: """ Build a ConverterInstance from a registry object. @@ -80,7 +80,7 @@ def _build_instance_from_object(self, converter_id: str, converter_obj: Any) -> # Public API Methods # ======================================================================== - async def list_converters(self) -> ConverterInstanceListResponse: + async def list_converters_async(self) -> ConverterInstanceListResponse: """ List all converter instances. @@ -88,11 +88,11 @@ async def list_converters(self) -> ConverterInstanceListResponse: ConverterInstanceListResponse containing all registered converters. """ items = [ - self._build_instance_from_object(name, obj) for name, obj in self._registry.get_all_instances().items() + self._build_instance_from_object(converter_id=name, converter_obj=obj) for name, obj in self._registry.get_all_instances().items() ] return ConverterInstanceListResponse(items=items) - async def get_converter(self, converter_id: str) -> Optional[ConverterInstance]: + async def get_converter_async(self, *, converter_id: str) -> Optional[ConverterInstance]: """ Get a converter instance by ID. @@ -102,9 +102,9 @@ async def get_converter(self, converter_id: str) -> Optional[ConverterInstance]: obj = self._registry.get_instance_by_name(converter_id) if obj is None: return None - return self._build_instance_from_object(converter_id, obj) + return self._build_instance_from_object(converter_id=converter_id, converter_obj=obj) - def get_converter_object(self, converter_id: str) -> Optional[Any]: + def get_converter_object(self, *, converter_id: str) -> Optional[Any]: """ Get the actual converter object. @@ -113,7 +113,7 @@ def get_converter_object(self, converter_id: str) -> Optional[Any]: """ return self._registry.get_instance_by_name(converter_id) - async def create_converter(self, request: CreateConverterRequest) -> CreateConverterResponse: + async def create_converter_async(self, *, request: CreateConverterRequest) -> CreateConverterResponse: """ Create a new converter instance from API request. @@ -132,8 +132,8 @@ async def create_converter(self, request: CreateConverterRequest) -> CreateConve converter_id = str(uuid.uuid4()) # Resolve any converter references in params and instantiate - params = self._resolve_converter_params(request.params) - converter_class = self._get_converter_class(request.type) + params = self._resolve_converter_params(params=request.params) + converter_class = self._get_converter_class(converter_type=request.type) converter_obj = converter_class(**params) self._registry.register_instance(converter_obj, name=converter_id) @@ -144,16 +144,16 @@ async def create_converter(self, request: CreateConverterRequest) -> CreateConve params=request.params, ) - async def preview_conversion(self, request: ConverterPreviewRequest) -> ConverterPreviewResponse: + async def preview_conversion_async(self, *, request: ConverterPreviewRequest) -> ConverterPreviewResponse: """ Preview conversion through a converter pipeline. Returns: ConverterPreviewResponse with step-by-step conversion results. """ - converters = self._gather_converters(request.converter_ids) + converters = self._gather_converters(converter_ids=request.converter_ids) steps, final_value, final_type = await self._apply_converters( - converters, request.original_value, request.original_value_data_type + converters=converters, initial_value=request.original_value, initial_type=request.original_value_data_type ) return ConverterPreviewResponse( @@ -164,7 +164,7 @@ async def preview_conversion(self, request: ConverterPreviewRequest) -> Converte steps=steps, ) - def get_converter_objects_for_ids(self, converter_ids: List[str]) -> List[Any]: + def get_converter_objects_for_ids(self, *, converter_ids: List[str]) -> List[Any]: """ Get converter objects for a list of IDs. @@ -173,7 +173,7 @@ def get_converter_objects_for_ids(self, converter_ids: List[str]) -> List[Any]: """ converters = [] for conv_id in converter_ids: - conv_obj = self.get_converter_object(conv_id) + conv_obj = self.get_converter_object(converter_id=conv_id) if conv_obj is None: raise ValueError(f"Converter instance '{conv_id}' not found") converters.append(conv_obj) @@ -183,7 +183,7 @@ def get_converter_objects_for_ids(self, converter_ids: List[str]) -> List[Any]: # Private Helper Methods # ======================================================================== - def _get_converter_class(self, converter_type: str) -> type: + def _get_converter_class(self, *, converter_type: str) -> type: """ Get the converter class for a given type name. @@ -206,7 +206,7 @@ def _get_converter_class(self, converter_type: str) -> type: ) return cls - def _resolve_converter_params(self, params: dict[str, Any]) -> dict[str, Any]: + def _resolve_converter_params(self, *, params: dict[str, Any]) -> dict[str, Any]: """ Resolve converter references in params. @@ -220,13 +220,13 @@ def _resolve_converter_params(self, params: dict[str, Any]) -> dict[str, Any]: if "converter" in resolved and isinstance(resolved["converter"], dict): ref = resolved["converter"] if "converter_id" in ref: - conv_obj = self.get_converter_object(ref["converter_id"]) + conv_obj = self.get_converter_object(converter_id=ref["converter_id"]) if conv_obj is None: raise ValueError(f"Referenced converter '{ref['converter_id']}' not found") resolved["converter"] = conv_obj return resolved - def _gather_converters(self, converter_ids: List[str]) -> List[Tuple[str, str, Any]]: + def _gather_converters(self, *, converter_ids: List[str]) -> List[Tuple[str, str, Any]]: """ Gather converters to apply from IDs. @@ -235,7 +235,7 @@ def _gather_converters(self, converter_ids: List[str]) -> List[Tuple[str, str, A """ converters: List[Tuple[str, str, Any]] = [] for conv_id in converter_ids: - conv_obj = self.get_converter_object(conv_id) + conv_obj = self.get_converter_object(converter_id=conv_id) if conv_obj is None: raise ValueError(f"Converter instance '{conv_id}' not found") conv_type = conv_obj.__class__.__name__ @@ -244,6 +244,7 @@ def _gather_converters(self, converter_ids: List[str]) -> List[Tuple[str, str, A async def _apply_converters( self, + *, converters: List[Tuple[str, str, Any]], initial_value: str, initial_type: PromptDataType, diff --git a/pyrit/backend/services/target_service.py b/pyrit/backend/services/target_service.py index e656f85a32..1e662d0eec 100644 --- a/pyrit/backend/services/target_service.py +++ b/pyrit/backend/services/target_service.py @@ -62,7 +62,7 @@ def __init__(self) -> None: """Initialize the target service.""" self._registry = TargetRegistry.get_registry_singleton() - def _get_target_class(self, target_type: str) -> type: + def _get_target_class(self, *, target_type: str) -> type: """ Get the target class for a given type name. @@ -84,7 +84,7 @@ def _get_target_class(self, target_type: str) -> type: ) return cls - def _build_instance_from_object(self, target_id: str, target_obj: Any) -> TargetInstance: + def _build_instance_from_object(self, *, target_id: str, target_obj: Any) -> TargetInstance: """ Build a TargetInstance from a registry object. @@ -93,7 +93,7 @@ def _build_instance_from_object(self, target_id: str, target_obj: Any) -> Target """ return target_object_to_instance(target_id, target_obj) - async def list_targets( + async def list_targets_async( self, *, limit: int = 50, @@ -110,7 +110,7 @@ async def list_targets( TargetListResponse containing paginated targets. """ items = [ - self._build_instance_from_object(name, obj) for name, obj in self._registry.get_all_instances().items() + self._build_instance_from_object(target_id=name, target_obj=obj) for name, obj in self._registry.get_all_instances().items() ] page, has_more = self._paginate(items, cursor, limit) next_cursor = page[-1].target_id if has_more and page else None @@ -140,7 +140,7 @@ def _paginate( has_more = len(items) > start_idx + limit return page, has_more - async def get_target(self, target_id: str) -> Optional[TargetInstance]: + async def get_target_async(self, *, target_id: str) -> Optional[TargetInstance]: """ Get a target instance by ID. @@ -150,9 +150,9 @@ async def get_target(self, target_id: str) -> Optional[TargetInstance]: obj = self._registry.get_instance_by_name(target_id) if obj is None: return None - return self._build_instance_from_object(target_id, obj) + return self._build_instance_from_object(target_id=target_id, target_obj=obj) - def get_target_object(self, target_id: str) -> Optional[Any]: + def get_target_object(self, *, target_id: str) -> Optional[Any]: """ Get the actual target object for use in attacks. @@ -161,7 +161,7 @@ def get_target_object(self, target_id: str) -> Optional[Any]: """ return self._registry.get_instance_by_name(target_id) - async def create_target(self, request: CreateTargetRequest) -> CreateTargetResponse: + async def create_target_async(self, *, request: CreateTargetRequest) -> CreateTargetResponse: """ Create a new target instance from API request. @@ -180,7 +180,7 @@ async def create_target(self, request: CreateTargetRequest) -> CreateTargetRespo target_id = str(uuid.uuid4()) # Instantiate from request params and register - target_class = self._get_target_class(request.type) + target_class = self._get_target_class(target_type=request.type) target_obj = target_class(**request.params) self._registry.register_instance(target_obj, name=target_id) diff --git a/pyrit/cli/pyrit_backend.py b/pyrit/cli/pyrit_backend.py index 657569ae9d..1043135958 100644 --- a/pyrit/cli/pyrit_backend.py +++ b/pyrit/cli/pyrit_backend.py @@ -15,7 +15,7 @@ from pyrit.cli import frontend_core -def parse_args(args: Optional[list[str]] = None) -> Namespace: +def parse_args(*, args: Optional[list[str]] = None) -> Namespace: """ Parse command-line arguments for the PyRIT backend server. @@ -112,7 +112,7 @@ def parse_args(args: Optional[list[str]] = None) -> Namespace: return parser.parse_args(args) -async def initialize_and_run(parsed_args: Namespace) -> int: +async def initialize_and_run(*, parsed_args: Namespace) -> int: """ Initialize PyRIT and start the backend server. @@ -184,7 +184,7 @@ async def initialize_and_run(parsed_args: Namespace) -> int: return 0 -def main(args: Optional[list[str]] = None) -> int: +def main(*, args: Optional[list[str]] = None) -> int: """ Start the PyRIT backend server CLI. @@ -192,7 +192,7 @@ def main(args: Optional[list[str]] = None) -> int: int: Exit code (0 for success, 1 for error). """ try: - parsed_args = parse_args(args) + parsed_args = parse_args(args=args) except SystemExit as e: return e.code if isinstance(e.code, int) else 1 @@ -204,7 +204,7 @@ def main(args: Optional[list[str]] = None) -> int: # Run the server try: - return asyncio.run(initialize_and_run(parsed_args)) + return asyncio.run(initialize_and_run(parsed_args=parsed_args)) except KeyboardInterrupt: print("\n🛑 Backend stopped") return 0 diff --git a/tests/unit/backend/test_api_routes.py b/tests/unit/backend/test_api_routes.py index 07fe4582cd..4d53c39b4d 100644 --- a/tests/unit/backend/test_api_routes.py +++ b/tests/unit/backend/test_api_routes.py @@ -58,7 +58,7 @@ def test_list_attacks_returns_empty_list(self, client: TestClient) -> None: """Test that list attacks returns empty list initially.""" with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: mock_service = MagicMock() - mock_service.list_attacks = AsyncMock( + mock_service.list_attacks_async = AsyncMock( return_value=AttackListResponse( items=[], pagination=PaginationInfo(limit=20, has_more=False, next_cursor=None, prev_cursor=None), @@ -76,7 +76,7 @@ def test_list_attacks_with_filters(self, client: TestClient) -> None: """Test that list attacks accepts filter parameters.""" with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: mock_service = MagicMock() - mock_service.list_attacks = AsyncMock( + mock_service.list_attacks_async = AsyncMock( return_value=AttackListResponse( items=[], pagination=PaginationInfo(limit=10, has_more=False, next_cursor=None, prev_cursor=None), @@ -90,7 +90,7 @@ def test_list_attacks_with_filters(self, client: TestClient) -> None: ) assert response.status_code == status.HTTP_200_OK - mock_service.list_attacks.assert_called_once_with( + mock_service.list_attacks_async.assert_called_once_with( target_id="t1", outcome="success", name=None, @@ -107,7 +107,7 @@ def test_create_attack_success(self, client: TestClient) -> None: with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: mock_service = MagicMock() - mock_service.create_attack = AsyncMock( + mock_service.create_attack_async = AsyncMock( return_value=CreateAttackResponse( attack_id="attack-1", created_at=now, @@ -128,7 +128,7 @@ def test_create_attack_target_not_found(self, client: TestClient) -> None: """Test attack creation with non-existent target.""" with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: mock_service = MagicMock() - mock_service.create_attack = AsyncMock(side_effect=ValueError("Target not found")) + mock_service.create_attack_async = AsyncMock(side_effect=ValueError("Target not found")) mock_get_service.return_value = mock_service response = client.post( @@ -144,7 +144,7 @@ def test_get_attack_success(self, client: TestClient) -> None: with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: mock_service = MagicMock() - mock_service.get_attack = AsyncMock( + mock_service.get_attack_async = AsyncMock( return_value=AttackSummary( attack_id="attack-1", name="Test", @@ -169,7 +169,7 @@ def test_get_attack_not_found(self, client: TestClient) -> None: """Test getting a non-existent attack.""" with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: mock_service = MagicMock() - mock_service.get_attack = AsyncMock(return_value=None) + mock_service.get_attack_async = AsyncMock(return_value=None) mock_get_service.return_value = mock_service response = client.get("/api/attacks/nonexistent") @@ -182,7 +182,7 @@ def test_update_attack_success(self, client: TestClient) -> None: with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: mock_service = MagicMock() - mock_service.update_attack = AsyncMock( + mock_service.update_attack_async = AsyncMock( return_value=AttackSummary( attack_id="attack-1", name=None, @@ -254,7 +254,7 @@ def test_add_message_success(self, client: TestClient) -> None: with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: mock_service = MagicMock() - mock_service.add_message = AsyncMock( + mock_service.add_message_async = AsyncMock( return_value=AddMessageResponse( attack=attack_summary, messages=attack_messages, @@ -275,7 +275,7 @@ def test_update_attack_not_found(self, client: TestClient) -> None: """Test updating a non-existent attack returns 404.""" with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: mock_service = MagicMock() - mock_service.update_attack = AsyncMock(return_value=None) + mock_service.update_attack_async = AsyncMock(return_value=None) mock_get_service.return_value = mock_service response = client.patch( @@ -289,7 +289,7 @@ def test_add_message_attack_not_found(self, client: TestClient) -> None: """Test adding message to non-existent attack returns 404.""" with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: mock_service = MagicMock() - mock_service.add_message = AsyncMock(side_effect=ValueError("Attack 'nonexistent' not found")) + mock_service.add_message_async = AsyncMock(side_effect=ValueError("Attack 'nonexistent' not found")) mock_get_service.return_value = mock_service response = client.post( @@ -303,7 +303,7 @@ def test_add_message_target_not_found(self, client: TestClient) -> None: """Test adding message when target object not found returns 404.""" with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: mock_service = MagicMock() - mock_service.add_message = AsyncMock(side_effect=ValueError("Target object for 'target-1' not found")) + mock_service.add_message_async = AsyncMock(side_effect=ValueError("Target object for 'target-1' not found")) mock_get_service.return_value = mock_service response = client.post( @@ -317,7 +317,7 @@ def test_add_message_bad_request(self, client: TestClient) -> None: """Test adding message with invalid request returns 400.""" with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: mock_service = MagicMock() - mock_service.add_message = AsyncMock(side_effect=ValueError("Invalid message format")) + mock_service.add_message_async = AsyncMock(side_effect=ValueError("Invalid message format")) mock_get_service.return_value = mock_service response = client.post( @@ -331,7 +331,7 @@ def test_add_message_internal_error(self, client: TestClient) -> None: """Test adding message when internal error occurs returns 500.""" with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: mock_service = MagicMock() - mock_service.add_message = AsyncMock(side_effect=RuntimeError("Unexpected internal error")) + mock_service.add_message_async = AsyncMock(side_effect=RuntimeError("Unexpected internal error")) mock_get_service.return_value = mock_service response = client.post( @@ -347,7 +347,7 @@ def test_get_attack_messages_success(self, client: TestClient) -> None: with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: mock_service = MagicMock() - mock_service.get_attack_messages = AsyncMock( + mock_service.get_attack_messages_async = AsyncMock( return_value=AttackMessagesResponse( attack_id="attack-1", messages=[ @@ -374,7 +374,7 @@ def test_get_attack_messages_not_found(self, client: TestClient) -> None: """Test getting messages for non-existent attack returns 404.""" with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: mock_service = MagicMock() - mock_service.get_attack_messages = AsyncMock(return_value=None) + mock_service.get_attack_messages_async = AsyncMock(return_value=None) mock_get_service.return_value = mock_service response = client.get("/api/attacks/nonexistent/messages") @@ -387,7 +387,7 @@ def test_list_attacks_with_labels(self, client: TestClient) -> None: with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: mock_service = MagicMock() - mock_service.list_attacks = AsyncMock( + mock_service.list_attacks_async = AsyncMock( return_value=AttackListResponse( items=[ AttackSummary( @@ -412,8 +412,8 @@ def test_list_attacks_with_labels(self, client: TestClient) -> None: assert response.status_code == status.HTTP_200_OK # Verify labels were parsed and passed to service - mock_service.list_attacks.assert_called_once() - call_kwargs = mock_service.list_attacks.call_args[1] + mock_service.list_attacks_async.assert_called_once() + call_kwargs = mock_service.list_attacks_async.call_args[1] assert call_kwargs["labels"] == {"env": "prod", "team": "red"} @@ -429,7 +429,7 @@ def test_list_targets_returns_empty_list(self, client: TestClient) -> None: """Test that list targets returns empty list initially.""" with patch("pyrit.backend.routes.targets.get_target_service") as mock_get_service: mock_service = MagicMock() - mock_service.list_targets = AsyncMock( + mock_service.list_targets_async = AsyncMock( return_value=TargetListResponse( items=[], pagination=PaginationInfo(limit=50, has_more=False, next_cursor=None, prev_cursor=None), @@ -448,7 +448,7 @@ def test_create_target_success(self, client: TestClient) -> None: """Test successful target creation.""" with patch("pyrit.backend.routes.targets.get_target_service") as mock_get_service: mock_service = MagicMock() - mock_service.create_target = AsyncMock( + mock_service.create_target_async = AsyncMock( return_value=CreateTargetResponse( target_id="target-1", type="TextTarget", @@ -471,7 +471,7 @@ def test_create_target_invalid_type(self, client: TestClient) -> None: """Test target creation with invalid type.""" with patch("pyrit.backend.routes.targets.get_target_service") as mock_get_service: mock_service = MagicMock() - mock_service.create_target = AsyncMock(side_effect=ValueError("Target type not found")) + mock_service.create_target_async = AsyncMock(side_effect=ValueError("Target type not found")) mock_get_service.return_value = mock_service response = client.post( @@ -485,7 +485,7 @@ def test_create_target_internal_error(self, client: TestClient) -> None: """Test target creation with internal error returns 500.""" with patch("pyrit.backend.routes.targets.get_target_service") as mock_get_service: mock_service = MagicMock() - mock_service.create_target = AsyncMock(side_effect=RuntimeError("Unexpected error")) + mock_service.create_target_async = AsyncMock(side_effect=RuntimeError("Unexpected error")) mock_get_service.return_value = mock_service response = client.post( @@ -497,11 +497,9 @@ def test_create_target_internal_error(self, client: TestClient) -> None: def test_get_target_success(self, client: TestClient) -> None: """Test getting a target by ID.""" - now = datetime.now(timezone.utc) - with patch("pyrit.backend.routes.targets.get_target_service") as mock_get_service: mock_service = MagicMock() - mock_service.get_target = AsyncMock( + mock_service.get_target_async = AsyncMock( return_value=TargetInstance( target_id="target-1", type="TextTarget", @@ -521,7 +519,7 @@ def test_get_target_not_found(self, client: TestClient) -> None: """Test getting a non-existent target.""" with patch("pyrit.backend.routes.targets.get_target_service") as mock_get_service: mock_service = MagicMock() - mock_service.get_target = AsyncMock(return_value=None) + mock_service.get_target_async = AsyncMock(return_value=None) mock_get_service.return_value = mock_service response = client.get("/api/targets/nonexistent") @@ -541,7 +539,7 @@ def test_list_converters(self, client: TestClient) -> None: """Test listing converter instances.""" with patch("pyrit.backend.routes.converters.get_converter_service") as mock_get_service: mock_service = MagicMock() - mock_service.list_converters = AsyncMock(return_value=ConverterInstanceListResponse(items=[])) + mock_service.list_converters_async = AsyncMock(return_value=ConverterInstanceListResponse(items=[])) mock_get_service.return_value = mock_service response = client.get("/api/converters") @@ -554,7 +552,7 @@ def test_create_converter_success(self, client: TestClient) -> None: """Test successful converter instance creation.""" with patch("pyrit.backend.routes.converters.get_converter_service") as mock_get_service: mock_service = MagicMock() - mock_service.create_converter = AsyncMock( + mock_service.create_converter_async = AsyncMock( return_value=CreateConverterResponse( converter_id="conv-1", type="Base64Converter", @@ -577,7 +575,7 @@ def test_create_converter_invalid_type(self, client: TestClient) -> None: """Test converter creation with invalid type.""" with patch("pyrit.backend.routes.converters.get_converter_service") as mock_get_service: mock_service = MagicMock() - mock_service.create_converter = AsyncMock(side_effect=ValueError("Converter type not found")) + mock_service.create_converter_async = AsyncMock(side_effect=ValueError("Converter type not found")) mock_get_service.return_value = mock_service response = client.post( @@ -591,7 +589,7 @@ def test_create_converter_internal_error(self, client: TestClient) -> None: """Test converter creation with internal error returns 500.""" with patch("pyrit.backend.routes.converters.get_converter_service") as mock_get_service: mock_service = MagicMock() - mock_service.create_converter = AsyncMock(side_effect=RuntimeError("Unexpected error")) + mock_service.create_converter_async = AsyncMock(side_effect=RuntimeError("Unexpected error")) mock_get_service.return_value = mock_service response = client.post( @@ -603,11 +601,9 @@ def test_create_converter_internal_error(self, client: TestClient) -> None: def test_get_converter_success(self, client: TestClient) -> None: """Test getting a converter instance by ID.""" - now = datetime.now(timezone.utc) - with patch("pyrit.backend.routes.converters.get_converter_service") as mock_get_service: mock_service = MagicMock() - mock_service.get_converter = AsyncMock( + mock_service.get_converter_async = AsyncMock( return_value=ConverterInstance( converter_id="conv-1", type="Base64Converter", @@ -627,7 +623,7 @@ def test_get_converter_not_found(self, client: TestClient) -> None: """Test getting a non-existent converter instance.""" with patch("pyrit.backend.routes.converters.get_converter_service") as mock_get_service: mock_service = MagicMock() - mock_service.get_converter = AsyncMock(return_value=None) + mock_service.get_converter_async = AsyncMock(return_value=None) mock_get_service.return_value = mock_service response = client.get("/api/converters/nonexistent") @@ -638,7 +634,7 @@ def test_preview_conversion_success(self, client: TestClient) -> None: """Test previewing a conversion.""" with patch("pyrit.backend.routes.converters.get_converter_service") as mock_get_service: mock_service = MagicMock() - mock_service.preview_conversion = AsyncMock( + mock_service.preview_conversion_async = AsyncMock( return_value=ConverterPreviewResponse( original_value="test", original_value_data_type="text", @@ -676,7 +672,7 @@ def test_preview_conversion_bad_request(self, client: TestClient) -> None: """Test preview conversion with invalid converter ID returns 400.""" with patch("pyrit.backend.routes.converters.get_converter_service") as mock_get_service: mock_service = MagicMock() - mock_service.preview_conversion = AsyncMock( + mock_service.preview_conversion_async = AsyncMock( side_effect=ValueError("Converter instance 'nonexistent' not found") ) mock_get_service.return_value = mock_service @@ -696,7 +692,7 @@ def test_preview_conversion_internal_error(self, client: TestClient) -> None: """Test preview conversion with internal error returns 500.""" with patch("pyrit.backend.routes.converters.get_converter_service") as mock_get_service: mock_service = MagicMock() - mock_service.preview_conversion = AsyncMock(side_effect=RuntimeError("Converter execution failed")) + mock_service.preview_conversion_async = AsyncMock(side_effect=RuntimeError("Converter execution failed")) mock_get_service.return_value = mock_service response = client.post( @@ -766,6 +762,23 @@ def test_get_version_with_build_info(self, client: TestClient) -> None: finally: os.unlink(temp_path) + def test_get_version_build_info_load_failure(self, client: TestClient) -> None: + """Test getting version when build_info.json exists but fails to load.""" + with patch("pyrit.backend.routes.version.Path") as mock_path_class: + mock_path_instance = MagicMock() + mock_path_instance.exists.return_value = True + mock_path_class.return_value = mock_path_instance + + with patch("builtins.open", side_effect=OSError("permission denied")): + response = client.get("/api/version") + + assert response.status_code == status.HTTP_200_OK + data = response.json() + # Falls back to default values when load fails + assert "version" in data + assert data["source"] is None + assert data["commit"] is None + # ============================================================================ # Health Routes Tests @@ -897,3 +910,17 @@ def test_get_labels_skips_non_string_values(self, client: TestClient) -> None: assert "int_val" not in data["labels"] assert "list_val" not in data["labels"] assert "dict_val" not in data["labels"] + + @pytest.mark.asyncio + async def test_get_label_options_unsupported_source_returns_empty_labels(self) -> None: + """Test that get_label_options returns empty labels for unsupported source types.""" + from pyrit.backend.routes.labels import get_label_options + + with patch("pyrit.backend.routes.labels.CentralMemory"): + # Call the function directly with a non-"attacks" source to cover the else branch. + # The Literal["attacks"] type hint prevents this via the API, but the function + # handles it gracefully. + result = await get_label_options(source="other") # type: ignore[arg-type] + + assert result.source == "other" + assert result.labels == {} diff --git a/tests/unit/backend/test_attack_service.py b/tests/unit/backend/test_attack_service.py index d70a701b20..0b80513a5d 100644 --- a/tests/unit/backend/test_attack_service.py +++ b/tests/unit/backend/test_attack_service.py @@ -145,7 +145,7 @@ async def test_list_attacks_returns_empty_when_no_attacks(self, attack_service, """Test that list_attacks returns empty list when no AttackResults exist.""" mock_memory.get_attack_results.return_value = [] - result = await attack_service.list_attacks() + result = await attack_service.list_attacks_async() assert result.items == [] assert result.pagination.has_more is False @@ -157,7 +157,7 @@ async def test_list_attacks_returns_attacks(self, attack_service, mock_memory) - mock_memory.get_attack_results.return_value = [ar] mock_memory.get_message_pieces.return_value = [] - result = await attack_service.list_attacks() + result = await attack_service.list_attacks_async() assert len(result.items) == 1 assert result.items[0].attack_id == "attack-1" @@ -171,7 +171,7 @@ async def test_list_attacks_filters_by_target_id(self, attack_service, mock_memo mock_memory.get_attack_results.return_value = [ar1, ar2] mock_memory.get_message_pieces.return_value = [] - result = await attack_service.list_attacks(target_id="target-1") + result = await attack_service.list_attacks_async(target_id="target-1") assert len(result.items) == 1 assert result.items[0].target_id == "target-1" @@ -184,7 +184,7 @@ async def test_list_attacks_filters_by_name(self, attack_service, mock_memory) - mock_memory.get_attack_results.return_value = [ar1, ar2] mock_memory.get_message_pieces.return_value = [] - result = await attack_service.list_attacks(name="test") + result = await attack_service.list_attacks_async(name="test") assert len(result.items) == 1 assert result.items[0].name == "Test Attack" @@ -199,7 +199,7 @@ async def test_list_attacks_filters_by_min_turns(self, attack_service, mock_memo mock_memory.get_attack_results.return_value = [ar1, ar2] mock_memory.get_message_pieces.return_value = [] - result = await attack_service.list_attacks(min_turns=3) + result = await attack_service.list_attacks_async(min_turns=3) assert len(result.items) == 1 assert result.items[0].attack_id == "attack-1" @@ -214,7 +214,7 @@ async def test_list_attacks_filters_by_max_turns(self, attack_service, mock_memo mock_memory.get_attack_results.return_value = [ar1, ar2] mock_memory.get_message_pieces.return_value = [] - result = await attack_service.list_attacks(max_turns=3) + result = await attack_service.list_attacks_async(max_turns=3) assert len(result.items) == 1 assert result.items[0].attack_id == "attack-2" @@ -229,7 +229,7 @@ async def test_list_attacks_includes_labels_in_summary(self, attack_service, moc mock_memory.get_attack_results.return_value = [ar] mock_memory.get_message_pieces.return_value = [] - result = await attack_service.list_attacks() + result = await attack_service.list_attacks_async() assert len(result.items) == 1 assert result.items[0].labels == {"env": "prod", "team": "red"} @@ -249,7 +249,7 @@ async def test_get_attack_returns_none_for_nonexistent(self, attack_service, moc """Test that get_attack returns None when AttackResult doesn't exist.""" mock_memory.get_attack_results.return_value = [] - result = await attack_service.get_attack("nonexistent") + result = await attack_service.get_attack_async(attack_id="nonexistent") assert result is None @@ -265,7 +265,7 @@ async def test_get_attack_returns_attack_details(self, attack_service, mock_memo mock_memory.get_attack_results.return_value = [ar] mock_memory.get_conversation.return_value = [] - result = await attack_service.get_attack("test-id") + result = await attack_service.get_attack_async(attack_id="test-id") assert result is not None assert result.attack_id == "test-id" @@ -288,7 +288,7 @@ async def test_get_attack_messages_returns_none_for_nonexistent(self, attack_ser """Test that get_attack_messages returns None when attack doesn't exist.""" mock_memory.get_attack_results.return_value = [] - result = await attack_service.get_attack_messages("nonexistent") + result = await attack_service.get_attack_messages_async(attack_id="nonexistent") assert result is None @@ -299,7 +299,7 @@ async def test_get_attack_messages_returns_messages(self, attack_service, mock_m mock_memory.get_attack_results.return_value = [ar] mock_memory.get_conversation.return_value = [] - result = await attack_service.get_attack_messages("test-id") + result = await attack_service.get_attack_messages_async(attack_id="test-id") assert result is not None assert result.attack_id == "test-id" @@ -320,21 +320,21 @@ async def test_create_attack_validates_target_exists(self, attack_service) -> No """Test that create_attack validates target exists.""" with patch("pyrit.backend.services.attack_service.get_target_service") as mock_get_target_service: mock_target_service = MagicMock() - mock_target_service.get_target = AsyncMock(return_value=None) + mock_target_service.get_target_async = AsyncMock(return_value=None) mock_get_target_service.return_value = mock_target_service with pytest.raises(ValueError, match="not found"): - await attack_service.create_attack(CreateAttackRequest(target_id="nonexistent")) + await attack_service.create_attack_async(request=CreateAttackRequest(target_id="nonexistent")) @pytest.mark.asyncio async def test_create_attack_stores_attack_result(self, attack_service, mock_memory) -> None: """Test that create_attack stores AttackResult in memory.""" with patch("pyrit.backend.services.attack_service.get_target_service") as mock_get_target_service: mock_target_service = MagicMock() - mock_target_service.get_target = AsyncMock(return_value=MagicMock(type="TextTarget")) + mock_target_service.get_target_async = AsyncMock(return_value=MagicMock(type="TextTarget")) mock_get_target_service.return_value = mock_target_service - result = await attack_service.create_attack(CreateAttackRequest(target_id="target-1", name="My Attack")) + result = await attack_service.create_attack_async(request=CreateAttackRequest(target_id="target-1", name="My Attack")) assert result.attack_id is not None assert result.created_at is not None @@ -347,7 +347,7 @@ async def test_create_attack_stores_prepended_conversation(self, attack_service, with patch("pyrit.backend.services.attack_service.get_target_service") as mock_get_target_service: mock_target_service = MagicMock() - mock_target_service.get_target = AsyncMock(return_value=MagicMock(type="TextTarget")) + mock_target_service.get_target_async = AsyncMock(return_value=MagicMock(type="TextTarget")) mock_get_target_service.return_value = mock_target_service prepended = [ @@ -357,8 +357,8 @@ async def test_create_attack_stores_prepended_conversation(self, attack_service, ) ] - result = await attack_service.create_attack( - CreateAttackRequest(target_id="target-1", prepended_conversation=prepended) + result = await attack_service.create_attack_async( + request=CreateAttackRequest(target_id="target-1", prepended_conversation=prepended) ) assert result.attack_id is not None @@ -371,11 +371,11 @@ async def test_create_attack_stores_labels_under_metadata_key(self, attack_servi """Test that create_attack stores labels under metadata['labels'], not spread.""" with patch("pyrit.backend.services.attack_service.get_target_service") as mock_get_target_service: mock_target_service = MagicMock() - mock_target_service.get_target = AsyncMock(return_value=MagicMock(type="TextTarget")) + mock_target_service.get_target_async = AsyncMock(return_value=MagicMock(type="TextTarget")) mock_get_target_service.return_value = mock_target_service - await attack_service.create_attack( - CreateAttackRequest( + await attack_service.create_attack_async( + request=CreateAttackRequest( target_id="target-1", name="Labeled Attack", labels={"env": "prod", "team": "red"}, @@ -406,7 +406,7 @@ async def test_update_attack_returns_none_for_nonexistent(self, attack_service, """Test that update_attack returns None for nonexistent attack.""" mock_memory.get_attack_results.return_value = [] - result = await attack_service.update_attack("nonexistent", UpdateAttackRequest(outcome="success")) + result = await attack_service.update_attack_async(attack_id="nonexistent", request=UpdateAttackRequest(outcome="success")) assert result is None @@ -417,7 +417,7 @@ async def test_update_attack_updates_outcome(self, attack_service, mock_memory) mock_memory.get_attack_results.return_value = [ar] mock_memory.get_conversation.return_value = [] - await attack_service.update_attack("test-id", UpdateAttackRequest(outcome="success")) + await attack_service.update_attack_async(attack_id="test-id", request=UpdateAttackRequest(outcome="success")) # Should call add_attack_results_to_memory to update mock_memory.add_attack_results_to_memory.assert_called() @@ -442,7 +442,7 @@ async def test_add_message_raises_for_nonexistent_attack(self, attack_service, m ) with pytest.raises(ValueError, match="not found"): - await attack_service.add_message("nonexistent", request) + await attack_service.add_message_async(attack_id="nonexistent", request=request) @pytest.mark.asyncio async def test_add_message_without_send_stores_message(self, attack_service, mock_memory) -> None: @@ -458,7 +458,7 @@ async def test_add_message_without_send_stores_message(self, attack_service, moc send=False, ) - result = await attack_service.add_message("test-id", request) + result = await attack_service.add_message_async(attack_id="test-id", request=request) assert result.attack is not None mock_memory.add_message_pieces_to_memory.assert_called() @@ -475,7 +475,7 @@ async def test_add_message_raises_when_no_target_id(self, attack_service, mock_m ) with pytest.raises(ValueError, match="has no target configured"): - await attack_service.add_message("test-id", request) + await attack_service.add_message_async(attack_id="test-id", request=request) @pytest.mark.asyncio async def test_add_message_with_send_calls_normalizer(self, attack_service, mock_memory) -> None: @@ -502,7 +502,7 @@ async def test_add_message_with_send_calls_normalizer(self, attack_service, mock send=True, ) - result = await attack_service.add_message("test-id", request) + result = await attack_service.add_message_async(attack_id="test-id", request=request) mock_normalizer.send_prompt_async.assert_called_once() assert result.attack is not None @@ -525,7 +525,7 @@ async def test_add_message_with_send_raises_when_target_not_found(self, attack_s ) with pytest.raises(ValueError, match="Target object .* not found"): - await attack_service.add_message("test-id", request) + await attack_service.add_message_async(attack_id="test-id", request=request) @pytest.mark.asyncio async def test_add_message_with_converter_ids_gets_converters(self, attack_service, mock_memory) -> None: @@ -561,9 +561,48 @@ async def test_add_message_with_converter_ids_gets_converters(self, attack_servi converter_ids=["conv-1"], ) - await attack_service.add_message("test-id", request) + await attack_service.add_message_async(attack_id="test-id", request=request) - mock_conv_svc.get_converter_objects_for_ids.assert_called_once_with(["conv-1"]) + mock_conv_svc.get_converter_objects_for_ids.assert_called_once_with(converter_ids=["conv-1"]) + + @pytest.mark.asyncio + async def test_add_message_raises_when_attack_not_found_after_update(self, attack_service, mock_memory) -> None: + """Test that add_message raises ValueError when attack disappears after update.""" + ar = make_attack_result(conversation_id="test-id", target_id="target-1") + mock_memory.get_attack_results.return_value = [ar] + mock_memory.get_message_pieces.return_value = [] + mock_memory.get_conversation.return_value = [] + + request = AddMessageRequest( + role="system", + pieces=[MessagePieceRequest(original_value="Hello")], + send=False, + ) + + with patch.object(attack_service, "get_attack_async", new=AsyncMock(return_value=None)): + with pytest.raises(ValueError, match="not found after update"): + await attack_service.add_message_async(attack_id="test-id", request=request) + + @pytest.mark.asyncio + async def test_add_message_raises_when_messages_not_found_after_update(self, attack_service, mock_memory) -> None: + """Test that add_message raises ValueError when messages disappear after update.""" + ar = make_attack_result(conversation_id="test-id", target_id="target-1") + mock_memory.get_attack_results.return_value = [ar] + mock_memory.get_message_pieces.return_value = [] + mock_memory.get_conversation.return_value = [] + + request = AddMessageRequest( + role="system", + pieces=[MessagePieceRequest(original_value="Hello")], + send=False, + ) + + with ( + patch.object(attack_service, "get_attack_async", new=AsyncMock(return_value=MagicMock())), + patch.object(attack_service, "get_attack_messages_async", new=AsyncMock(return_value=None)), + ): + with pytest.raises(ValueError, match="messages not found after update"): + await attack_service.add_message_async(attack_id="test-id", request=request) # ============================================================================ @@ -585,7 +624,7 @@ async def test_list_attacks_with_cursor_paginates(self, attack_service, mock_mem mock_memory.get_message_pieces.return_value = [] # Get first page - result = await attack_service.list_attacks(limit=2) + result = await attack_service.list_attacks_async(limit=2) # Results are sorted by updated_at desc, so order may vary assert len(result.items) == 2 @@ -598,11 +637,36 @@ async def test_list_attacks_has_more_flag(self, attack_service, mock_memory) -> mock_memory.get_attack_results.return_value = [ar1, ar2, ar3] mock_memory.get_message_pieces.return_value = [] - result = await attack_service.list_attacks(limit=2) + result = await attack_service.list_attacks_async(limit=2) assert result.pagination.has_more is True assert len(result.items) == 2 + @pytest.mark.asyncio + async def test_list_attacks_cursor_skips_to_correct_position(self, attack_service, mock_memory) -> None: + """Test that list_attacks with cursor skips items before cursor.""" + ar1 = make_attack_result( + conversation_id="attack-1", + updated_at=datetime(2024, 1, 3, tzinfo=timezone.utc), + ) + ar2 = make_attack_result( + conversation_id="attack-2", + updated_at=datetime(2024, 1, 2, tzinfo=timezone.utc), + ) + ar3 = make_attack_result( + conversation_id="attack-3", + updated_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + ) + mock_memory.get_attack_results.return_value = [ar1, ar2, ar3] + mock_memory.get_message_pieces.return_value = [] + + # Cursor = attack-1 should skip attack-1 and return from attack-2 onward + result = await attack_service.list_attacks_async(cursor="attack-1", limit=10) + + attack_ids = [item.attack_id for item in result.items] + assert "attack-1" not in attack_ids + assert len(result.items) == 2 + # ============================================================================ # Message Building Tests @@ -637,7 +701,7 @@ async def test_get_attack_with_messages_translates_correctly(self, attack_servic mock_memory.get_conversation.return_value = [mock_msg] - result = await attack_service.get_attack_messages("test-id") + result = await attack_service.get_attack_messages_async(attack_id="test-id") assert result is not None assert len(result.messages) == 1 diff --git a/tests/unit/backend/test_converter_service.py b/tests/unit/backend/test_converter_service.py index 110ea45cf3..796db1d271 100644 --- a/tests/unit/backend/test_converter_service.py +++ b/tests/unit/backend/test_converter_service.py @@ -9,7 +9,6 @@ import pytest -import pyrit.backend.services.converter_service as converter_service_module from pyrit import prompt_converter from pyrit.backend.models.converters import ( ConverterPreviewRequest, @@ -42,7 +41,7 @@ async def test_list_converters_returns_empty_when_no_converters(self) -> None: """Test that list_converters returns empty list when no converters exist.""" service = ConverterService() - result = await service.list_converters() + result = await service.list_converters_async() assert result.items == [] @@ -63,7 +62,7 @@ async def test_list_converters_returns_converters_from_registry(self) -> None: mock_converter.get_identifier.return_value = mock_identifier service._registry.register_instance(mock_converter, name="conv-1") - result = await service.list_converters() + result = await service.list_converters_async() assert len(result.items) == 1 assert result.items[0].converter_id == "conv-1" @@ -83,7 +82,7 @@ async def test_get_converter_returns_none_for_nonexistent(self) -> None: """Test that get_converter returns None for non-existent converter.""" service = ConverterService() - result = await service.get_converter("nonexistent-id") + result = await service.get_converter_async(converter_id="nonexistent-id") assert result is None @@ -102,7 +101,7 @@ async def test_get_converter_returns_converter_from_registry(self) -> None: mock_converter.get_identifier.return_value = mock_identifier service._registry.register_instance(mock_converter, name="conv-1") - result = await service.get_converter("conv-1") + result = await service.get_converter_async(converter_id="conv-1") assert result is not None assert result.converter_id == "conv-1" @@ -116,7 +115,7 @@ def test_get_converter_object_returns_none_for_nonexistent(self) -> None: """Test that get_converter_object returns None for non-existent converter.""" service = ConverterService() - result = service.get_converter_object("nonexistent-id") + result = service.get_converter_object(converter_id="nonexistent-id") assert result is None @@ -126,7 +125,7 @@ def test_get_converter_object_returns_object_from_registry(self) -> None: mock_converter = MagicMock() service._registry.register_instance(mock_converter, name="conv-1") - result = service.get_converter_object("conv-1") + result = service.get_converter_object(converter_id="conv-1") assert result is mock_converter @@ -145,7 +144,7 @@ async def test_create_converter_raises_for_invalid_type(self) -> None: ) with pytest.raises(ValueError, match="not found"): - await service.create_converter(request) + await service.create_converter_async(request=request) @pytest.mark.asyncio async def test_create_converter_success(self) -> None: @@ -158,7 +157,7 @@ async def test_create_converter_success(self) -> None: params={}, ) - result = await service.create_converter(request) + result = await service.create_converter_async(request=request) assert result.converter_id is not None assert result.type == "Base64Converter" @@ -174,10 +173,10 @@ async def test_create_converter_registers_in_registry(self) -> None: params={}, ) - result = await service.create_converter(request) + result = await service.create_converter_async(request=request) # Object should be retrievable from registry - converter_obj = service.get_converter_object(result.converter_id) + converter_obj = service.get_converter_object(converter_id=result.converter_id) assert converter_obj is not None @@ -189,7 +188,7 @@ def test_resolve_converter_params_returns_params_unchanged_when_no_converter_ref service = ConverterService() params = {"key": "value", "number": 42} - result = service._resolve_converter_params(params) + result = service._resolve_converter_params(params=params) assert result == params @@ -203,7 +202,7 @@ def test_resolve_converter_params_resolves_converter_id_reference(self) -> None: params = {"converter": {"converter_id": "inner-conv"}} - result = service._resolve_converter_params(params) + result = service._resolve_converter_params(params=params) assert result["converter"] is mock_converter @@ -214,14 +213,14 @@ def test_resolve_converter_params_raises_for_nonexistent_reference(self) -> None params = {"converter": {"converter_id": "nonexistent"}} with pytest.raises(ValueError, match="not found"): - service._resolve_converter_params(params) + service._resolve_converter_params(params=params) def test_resolve_converter_params_ignores_non_dict_converter(self) -> None: """Test that non-dict converter values are not modified.""" service = ConverterService() params = {"converter": "some_string_value"} - result = service._resolve_converter_params(params) + result = service._resolve_converter_params(params=params) assert result == params @@ -241,7 +240,7 @@ async def test_preview_conversion_raises_for_nonexistent_converter(self) -> None ) with pytest.raises(ValueError, match="not found"): - await service.preview_conversion(request) + await service.preview_conversion_async(request=request) @pytest.mark.asyncio async def test_preview_conversion_with_converter_ids(self) -> None: @@ -262,7 +261,7 @@ async def test_preview_conversion_with_converter_ids(self) -> None: converter_ids=["conv-1"], ) - result = await service.preview_conversion(request) + result = await service.preview_conversion_async(request=request) assert result.original_value == "test" assert result.converted_value == "encoded_value" @@ -297,7 +296,7 @@ async def test_preview_conversion_chains_multiple_converters(self) -> None: converter_ids=["conv-1", "conv-2"], ) - result = await service.preview_conversion(request) + result = await service.preview_conversion_async(request=request) assert result.converted_value == "step2_output" assert len(result.steps) == 2 @@ -312,7 +311,7 @@ def test_get_converter_objects_for_ids_raises_for_nonexistent(self) -> None: service = ConverterService() with pytest.raises(ValueError, match="not found"): - service.get_converter_objects_for_ids(["nonexistent"]) + service.get_converter_objects_for_ids(converter_ids=["nonexistent"]) def test_get_converter_objects_for_ids_returns_objects(self) -> None: """Test that method returns converter objects in order.""" @@ -323,7 +322,7 @@ def test_get_converter_objects_for_ids_returns_objects(self) -> None: service._registry.register_instance(mock1, name="conv-1") service._registry.register_instance(mock2, name="conv-2") - result = service.get_converter_objects_for_ids(["conv-1", "conv-2"]) + result = service.get_converter_objects_for_ids(converter_ids=["conv-1", "conv-2"]) assert result == [mock1, mock2] @@ -415,7 +414,7 @@ def test_build_instance_from_converter(self, converter_name: str) -> None: # Build the instance using the service method service = ConverterService() - result = service._build_instance_from_object("test-id", converter_instance) + result = service._build_instance_from_object(converter_id="test-id", converter_obj=converter_instance) # Verify the result assert result.converter_id == "test-id" @@ -438,7 +437,7 @@ def test_caesar_converter_params(self) -> None: """Test that CaesarConverter params are extracted correctly.""" converter = CaesarConverter(caesar_offset=13) service = ConverterService() - result = service._build_instance_from_object("test-id", converter) + result = service._build_instance_from_object(converter_id="test-id", converter_obj=converter) assert result.type == "CaesarConverter" converter_specific = result.params.get("converter_specific_params", {}) @@ -448,7 +447,7 @@ def test_suffix_append_converter_params(self) -> None: """Test that SuffixAppendConverter params are extracted correctly.""" converter = SuffixAppendConverter(suffix="test suffix") service = ConverterService() - result = service._build_instance_from_object("test-id", converter) + result = service._build_instance_from_object(converter_id="test-id", converter_obj=converter) assert result.type == "SuffixAppendConverter" converter_specific = result.params.get("converter_specific_params", {}) @@ -458,7 +457,7 @@ def test_repeat_token_converter_params(self) -> None: """Test that RepeatTokenConverter params are extracted correctly.""" converter = RepeatTokenConverter(token_to_repeat="x", times_to_repeat=5) service = ConverterService() - result = service._build_instance_from_object("test-id", converter) + result = service._build_instance_from_object(converter_id="test-id", converter_obj=converter) assert result.type == "RepeatTokenConverter" converter_specific = result.params.get("converter_specific_params", {}) @@ -469,7 +468,7 @@ def test_base64_converter_default_params(self) -> None: """Test that Base64Converter default params are captured.""" converter = Base64Converter() service = ConverterService() - result = service._build_instance_from_object("test-id", converter) + result = service._build_instance_from_object(converter_id="test-id", converter_obj=converter) assert result.type == "Base64Converter" # Verify params dict is populated from identifier diff --git a/tests/unit/backend/test_main.py b/tests/unit/backend/test_main.py new file mode 100644 index 0000000000..272bc953b7 --- /dev/null +++ b/tests/unit/backend/test_main.py @@ -0,0 +1,92 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tests for the FastAPI application entry point (main.py). + +Covers the lifespan manager and setup_frontend function. +""" + +import sys +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + + +class TestLifespan: + """Tests for the application lifespan context manager.""" + + @pytest.mark.asyncio + async def test_lifespan_initializes_pyrit_and_yields(self) -> None: + """Test that lifespan calls initialize_pyrit_async on startup and yields.""" + with patch("pyrit.backend.main.initialize_pyrit_async", new_callable=AsyncMock) as mock_init: + from pyrit.backend.main import app, lifespan + + async with lifespan(app): + pass # The body of the context manager is the "yield" phase + + mock_init.assert_awaited_once_with(memory_db_type="SQLite") + + +class TestSetupFrontend: + """Tests for the setup_frontend function.""" + + def test_dev_mode_does_not_mount_static(self) -> None: + """Test that DEV_MODE skips static file serving.""" + with ( + patch("pyrit.backend.main.DEV_MODE", True), + patch("builtins.print") as mock_print, + ): + from pyrit.backend.main import setup_frontend + + setup_frontend() + + mock_print.assert_called_once() + assert "DEVELOPMENT" in mock_print.call_args[0][0] + + def test_frontend_exists_mounts_static(self) -> None: + """Test that setup_frontend mounts StaticFiles when frontend exists.""" + mock_frontend_path = MagicMock() + mock_frontend_path.exists.return_value = True + mock_frontend_path.__str__ = lambda self: "/tmp/fake_frontend" + + # Create the directory so StaticFiles doesn't raise + import os + + os.makedirs("/tmp/fake_frontend", exist_ok=True) + + with ( + patch("pyrit.backend.main.DEV_MODE", False), + patch("pyrit.backend.main.Path") as mock_path_cls, + patch("builtins.print"), + ): + mock_path_instance = MagicMock() + mock_path_instance.parent.__truediv__ = MagicMock(return_value=mock_frontend_path) + mock_path_cls.return_value = mock_path_instance + + from pyrit.backend.main import setup_frontend + + setup_frontend() + + def test_frontend_missing_exits(self) -> None: + """Test that setup_frontend calls sys.exit when frontend is missing.""" + mock_frontend_path = MagicMock() + mock_frontend_path.exists.return_value = False + mock_frontend_path.__str__ = lambda self: "/nonexistent/frontend" + + with ( + patch("pyrit.backend.main.DEV_MODE", False), + patch("pyrit.backend.main.Path") as mock_path_cls, + patch("builtins.print"), + patch.object(sys, "exit", side_effect=SystemExit(1)) as mock_exit, + ): + mock_path_instance = MagicMock() + mock_path_instance.parent.__truediv__ = MagicMock(return_value=mock_frontend_path) + mock_path_cls.return_value = mock_path_instance + + from pyrit.backend.main import setup_frontend + + with pytest.raises(SystemExit): + setup_frontend() + + mock_exit.assert_called_once_with(1) diff --git a/tests/unit/backend/test_target_service.py b/tests/unit/backend/test_target_service.py index 474fa6b724..b9307db492 100644 --- a/tests/unit/backend/test_target_service.py +++ b/tests/unit/backend/test_target_service.py @@ -30,7 +30,7 @@ async def test_list_targets_returns_empty_when_no_targets(self) -> None: """Test that list_targets returns empty list when no targets exist.""" service = TargetService() - result = await service.list_targets() + result = await service.list_targets_async() assert result.items == [] assert result.pagination.has_more is False @@ -45,7 +45,7 @@ async def test_list_targets_returns_targets_from_registry(self) -> None: mock_target.get_identifier.return_value = {"__type__": "MockTarget", "endpoint": "http://test"} service._registry.register_instance(mock_target, name="target-1") - result = await service.list_targets() + result = await service.list_targets_async() assert len(result.items) == 1 assert result.items[0].target_id == "target-1" @@ -62,7 +62,7 @@ async def test_list_targets_paginates_with_limit(self) -> None: mock_target.get_identifier.return_value = {"__type__": "MockTarget"} service._registry.register_instance(mock_target, name=f"target-{i}") - result = await service.list_targets(limit=3) + result = await service.list_targets_async(limit=3) assert len(result.items) == 3 assert result.pagination.limit == 3 @@ -79,8 +79,8 @@ async def test_list_targets_cursor_returns_next_page(self) -> None: mock_target.get_identifier.return_value = {"__type__": "MockTarget"} service._registry.register_instance(mock_target, name=f"target-{i}") - first_page = await service.list_targets(limit=2) - second_page = await service.list_targets(limit=2, cursor=first_page.pagination.next_cursor) + first_page = await service.list_targets_async(limit=2) + second_page = await service.list_targets_async(limit=2, cursor=first_page.pagination.next_cursor) assert len(second_page.items) == 2 assert second_page.items[0].target_id != first_page.items[0].target_id @@ -96,8 +96,8 @@ async def test_list_targets_last_page_has_no_more(self) -> None: mock_target.get_identifier.return_value = {"__type__": "MockTarget"} service._registry.register_instance(mock_target, name=f"target-{i}") - first_page = await service.list_targets(limit=2) - last_page = await service.list_targets(limit=2, cursor=first_page.pagination.next_cursor) + first_page = await service.list_targets_async(limit=2) + last_page = await service.list_targets_async(limit=2, cursor=first_page.pagination.next_cursor) assert len(last_page.items) == 1 assert last_page.pagination.has_more is False @@ -112,7 +112,7 @@ async def test_get_target_returns_none_for_nonexistent(self) -> None: """Test that get_target returns None for non-existent target.""" service = TargetService() - result = await service.get_target("nonexistent-id") + result = await service.get_target_async(target_id="nonexistent-id") assert result is None @@ -125,7 +125,7 @@ async def test_get_target_returns_target_from_registry(self) -> None: mock_target.get_identifier.return_value = {"__type__": "MockTarget"} service._registry.register_instance(mock_target, name="target-1") - result = await service.get_target("target-1") + result = await service.get_target_async(target_id="target-1") assert result is not None assert result.target_id == "target-1" @@ -139,7 +139,7 @@ def test_get_target_object_returns_none_for_nonexistent(self) -> None: """Test that get_target_object returns None for non-existent target.""" service = TargetService() - result = service.get_target_object("nonexistent-id") + result = service.get_target_object(target_id="nonexistent-id") assert result is None @@ -149,7 +149,7 @@ def test_get_target_object_returns_object_from_registry(self) -> None: mock_target = MagicMock() service._registry.register_instance(mock_target, name="target-1") - result = service.get_target_object("target-1") + result = service.get_target_object(target_id="target-1") assert result is mock_target @@ -168,7 +168,7 @@ async def test_create_target_raises_for_invalid_type(self) -> None: ) with pytest.raises(ValueError, match="not found"): - await service.create_target(request) + await service.create_target_async(request=request) @pytest.mark.asyncio async def test_create_target_success(self, sqlite_instance) -> None: @@ -181,7 +181,7 @@ async def test_create_target_success(self, sqlite_instance) -> None: params={}, ) - result = await service.create_target(request) + result = await service.create_target_async(request=request) assert result.target_id is not None assert result.type == "TextTarget" @@ -197,10 +197,10 @@ async def test_create_target_registers_in_registry(self, sqlite_instance) -> Non params={}, ) - result = await service.create_target(request) + result = await service.create_target_async(request=request) # Object should be retrievable from registry - target_obj = service.get_target_object(result.target_id) + target_obj = service.get_target_object(target_id=result.target_id) assert target_obj is not None From 2a2d26327d0d11d7b1b29e8d577a3da7c7d2230e Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Wed, 11 Feb 2026 06:26:34 -0800 Subject: [PATCH 27/35] clean up imports --- tests/unit/backend/test_api_routes.py | 3 +-- tests/unit/backend/test_attack_service.py | 3 +-- tests/unit/backend/test_error_handlers.py | 1 - tests/unit/backend/test_main.py | 13 +++---------- tests/unit/backend/test_target_service.py | 6 +----- 5 files changed, 6 insertions(+), 20 deletions(-) diff --git a/tests/unit/backend/test_api_routes.py b/tests/unit/backend/test_api_routes.py index 4d53c39b4d..e42b031e57 100644 --- a/tests/unit/backend/test_api_routes.py +++ b/tests/unit/backend/test_api_routes.py @@ -38,6 +38,7 @@ TargetInstance, TargetListResponse, ) +from pyrit.backend.routes.labels import get_label_options @pytest.fixture @@ -914,8 +915,6 @@ def test_get_labels_skips_non_string_values(self, client: TestClient) -> None: @pytest.mark.asyncio async def test_get_label_options_unsupported_source_returns_empty_labels(self) -> None: """Test that get_label_options returns empty labels for unsupported source types.""" - from pyrit.backend.routes.labels import get_label_options - with patch("pyrit.backend.routes.labels.CentralMemory"): # Call the function directly with a non-"attacks" source to cover the else branch. # The Literal["attacks"] type hint prevents this via the API, but the function diff --git a/tests/unit/backend/test_attack_service.py b/tests/unit/backend/test_attack_service.py index 0b80513a5d..c443923104 100644 --- a/tests/unit/backend/test_attack_service.py +++ b/tests/unit/backend/test_attack_service.py @@ -16,6 +16,7 @@ AddMessageRequest, CreateAttackRequest, MessagePieceRequest, + PrependedMessageRequest, UpdateAttackRequest, ) from pyrit.backend.services.attack_service import ( @@ -343,8 +344,6 @@ async def test_create_attack_stores_attack_result(self, attack_service, mock_mem @pytest.mark.asyncio async def test_create_attack_stores_prepended_conversation(self, attack_service, mock_memory) -> None: """Test that create_attack stores prepended conversation messages.""" - from pyrit.backend.models.attacks import PrependedMessageRequest - with patch("pyrit.backend.services.attack_service.get_target_service") as mock_get_target_service: mock_target_service = MagicMock() mock_target_service.get_target_async = AsyncMock(return_value=MagicMock(type="TextTarget")) diff --git a/tests/unit/backend/test_error_handlers.py b/tests/unit/backend/test_error_handlers.py index ad370fce4d..8fb2be55a4 100644 --- a/tests/unit/backend/test_error_handlers.py +++ b/tests/unit/backend/test_error_handlers.py @@ -41,7 +41,6 @@ def client(self, app: FastAPI) -> TestClient: def test_validation_error_returns_422(self, app: FastAPI, client: TestClient) -> None: """Test that validation errors return 422 with RFC 7807 format.""" - from pydantic import BaseModel class TestInput(BaseModel): name: str diff --git a/tests/unit/backend/test_main.py b/tests/unit/backend/test_main.py index 272bc953b7..8290038868 100644 --- a/tests/unit/backend/test_main.py +++ b/tests/unit/backend/test_main.py @@ -7,11 +7,14 @@ Covers the lifespan manager and setup_frontend function. """ +import os import sys from unittest.mock import AsyncMock, MagicMock, patch import pytest +from pyrit.backend.main import app, lifespan, setup_frontend + class TestLifespan: """Tests for the application lifespan context manager.""" @@ -20,8 +23,6 @@ class TestLifespan: async def test_lifespan_initializes_pyrit_and_yields(self) -> None: """Test that lifespan calls initialize_pyrit_async on startup and yields.""" with patch("pyrit.backend.main.initialize_pyrit_async", new_callable=AsyncMock) as mock_init: - from pyrit.backend.main import app, lifespan - async with lifespan(app): pass # The body of the context manager is the "yield" phase @@ -37,8 +38,6 @@ def test_dev_mode_does_not_mount_static(self) -> None: patch("pyrit.backend.main.DEV_MODE", True), patch("builtins.print") as mock_print, ): - from pyrit.backend.main import setup_frontend - setup_frontend() mock_print.assert_called_once() @@ -51,8 +50,6 @@ def test_frontend_exists_mounts_static(self) -> None: mock_frontend_path.__str__ = lambda self: "/tmp/fake_frontend" # Create the directory so StaticFiles doesn't raise - import os - os.makedirs("/tmp/fake_frontend", exist_ok=True) with ( @@ -64,8 +61,6 @@ def test_frontend_exists_mounts_static(self) -> None: mock_path_instance.parent.__truediv__ = MagicMock(return_value=mock_frontend_path) mock_path_cls.return_value = mock_path_instance - from pyrit.backend.main import setup_frontend - setup_frontend() def test_frontend_missing_exits(self) -> None: @@ -84,8 +79,6 @@ def test_frontend_missing_exits(self) -> None: mock_path_instance.parent.__truediv__ = MagicMock(return_value=mock_frontend_path) mock_path_cls.return_value = mock_path_instance - from pyrit.backend.main import setup_frontend - with pytest.raises(SystemExit): setup_frontend() diff --git a/tests/unit/backend/test_target_service.py b/tests/unit/backend/test_target_service.py index b9307db492..4935970a05 100644 --- a/tests/unit/backend/test_target_service.py +++ b/tests/unit/backend/test_target_service.py @@ -10,7 +10,7 @@ import pytest from pyrit.backend.models.targets import CreateTargetRequest -from pyrit.backend.services.target_service import TargetService +from pyrit.backend.services.target_service import TargetService, get_target_service from pyrit.registry.instance_registries import TargetRegistry @@ -209,8 +209,6 @@ class TestTargetServiceSingleton: def test_get_target_service_returns_target_service(self) -> None: """Test that get_target_service returns a TargetService instance.""" - from pyrit.backend.services.target_service import get_target_service - get_target_service.cache_clear() service = get_target_service() @@ -218,8 +216,6 @@ def test_get_target_service_returns_target_service(self) -> None: def test_get_target_service_returns_same_instance(self) -> None: """Test that get_target_service returns the same instance.""" - from pyrit.backend.services.target_service import get_target_service - get_target_service.cache_clear() service1 = get_target_service() From 426255760fb8fabcea6c9921ce1b44a768c45400 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Wed, 11 Feb 2026 17:29:38 -0800 Subject: [PATCH 28/35] pr feedback --- pyrit/exceptions/exception_context.py | 26 +- pyrit/identifiers/attack_identifier.py | 5 +- pyrit/identifiers/converter_identifier.py | 5 +- pyrit/identifiers/scorer_identifier.py | 5 +- pyrit/identifiers/target_identifier.py | 5 +- .../float_scale/self_ask_likert_scorer.py | 1 + .../float_scale/self_ask_scale_scorer.py | 1 + tests/unit/analytics/test_result_analysis.py | 6 +- .../identifiers/test_attack_identifier.py | 266 ++++++++++++++++++ 9 files changed, 289 insertions(+), 31 deletions(-) create mode 100644 tests/unit/identifiers/test_attack_identifier.py diff --git a/pyrit/exceptions/exception_context.py b/pyrit/exceptions/exception_context.py index 8375763207..b88c92a017 100644 --- a/pyrit/exceptions/exception_context.py +++ b/pyrit/exceptions/exception_context.py @@ -13,7 +13,7 @@ from contextvars import ContextVar from dataclasses import dataclass, field from enum import Enum -from typing import Any, Dict, Optional, Union +from typing import Any, Optional from pyrit.identifiers import AttackIdentifier, Identifier @@ -62,10 +62,10 @@ class ExecutionContext: attack_strategy_name: Optional[str] = None # The identifier for the attack strategy - attack_identifier: Optional[Union["AttackIdentifier", Dict[str, Any]]] = None + attack_identifier: Optional[AttackIdentifier] = None # The identifier from the component's get_identifier() (target, scorer, etc.) - component_identifier: Optional[Dict[str, Any]] = None + component_identifier: Optional[Identifier] = None # The objective target conversation ID if available objective_target_conversation_id: Optional[str] = None @@ -192,8 +192,8 @@ def execution_context( *, component_role: ComponentRole, attack_strategy_name: Optional[str] = None, - attack_identifier: Optional[Union[AttackIdentifier, Dict[str, Any]]] = None, - component_identifier: Optional[Union[Identifier, Dict[str, Any]]] = None, + attack_identifier: Optional[AttackIdentifier] = None, + component_identifier: Optional[Identifier] = None, objective_target_conversation_id: Optional[str] = None, objective: Optional[str] = None, ) -> ExecutionContextManager: @@ -203,9 +203,8 @@ def execution_context( Args: component_role: The role of the component being executed. attack_strategy_name: The name of the attack strategy class. - attack_identifier: The attack identifier. Can be an AttackIdentifier or a dict. + attack_identifier: The attack identifier. component_identifier: The identifier from component.get_identifier(). - Can be an Identifier object or a dict (legacy format). objective_target_conversation_id: The objective target conversation ID if available. objective: The attack objective if available. @@ -215,22 +214,15 @@ def execution_context( # Extract endpoint and component_name from component_identifier if available endpoint = None component_name = None - component_id_dict: Optional[Dict[str, Any]] = None if component_identifier: - if isinstance(component_identifier, Identifier): - endpoint = getattr(component_identifier, "endpoint", None) - component_name = component_identifier.class_name - component_id_dict = component_identifier.to_dict() - else: - endpoint = component_identifier.get("endpoint") - component_name = component_identifier.get("__type__") - component_id_dict = component_identifier + endpoint = getattr(component_identifier, "endpoint", None) + component_name = component_identifier.class_name context = ExecutionContext( component_role=component_role, attack_strategy_name=attack_strategy_name, attack_identifier=attack_identifier, - component_identifier=component_id_dict, + component_identifier=component_identifier, objective_target_conversation_id=objective_target_conversation_id, endpoint=endpoint, component_name=component_name, diff --git a/pyrit/identifiers/attack_identifier.py b/pyrit/identifiers/attack_identifier.py index 92c8cf103d..0a47d9f787 100644 --- a/pyrit/identifiers/attack_identifier.py +++ b/pyrit/identifiers/attack_identifier.py @@ -4,7 +4,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Type, cast +from typing import Any, Dict, List, Optional, Type from pyrit.identifiers.converter_identifier import ConverterIdentifier from pyrit.identifiers.identifier import Identifier @@ -57,5 +57,4 @@ def from_dict(cls: Type["AttackIdentifier"], data: dict[str, Any]) -> "AttackIde for c in data["request_converter_identifiers"] ] - result = Identifier.from_dict.__func__(cls, data) # type: ignore[attr-defined] - return cast(AttackIdentifier, result) + return super().from_dict(data) diff --git a/pyrit/identifiers/converter_identifier.py b/pyrit/identifiers/converter_identifier.py index 777672a932..3c86a0daca 100644 --- a/pyrit/identifiers/converter_identifier.py +++ b/pyrit/identifiers/converter_identifier.py @@ -4,7 +4,7 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Tuple, Type, cast +from typing import Any, Dict, List, Optional, Tuple, Type from pyrit.identifiers.identifier import Identifier @@ -73,5 +73,4 @@ def from_dict(cls: Type["ConverterIdentifier"], data: dict[str, Any]) -> "Conver data["supported_output_types"] = () # Delegate to parent class for standard processing - result = Identifier.from_dict.__func__(cls, data) # type: ignore[attr-defined] - return cast(ConverterIdentifier, result) + return super().from_dict(data) diff --git a/pyrit/identifiers/scorer_identifier.py b/pyrit/identifiers/scorer_identifier.py index d467504fe4..8dac5fc676 100644 --- a/pyrit/identifiers/scorer_identifier.py +++ b/pyrit/identifiers/scorer_identifier.py @@ -4,7 +4,7 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Type, cast +from typing import Any, Dict, List, Optional, Type from pyrit.identifiers.identifier import _MAX_STORAGE_LENGTH, Identifier from pyrit.models.score import ScoreType @@ -64,5 +64,4 @@ def from_dict(cls: Type["ScorerIdentifier"], data: dict[str, Any]) -> "ScorerIde ] # Delegate to parent class for standard processing - result = Identifier.from_dict.__func__(cls, data) # type: ignore[attr-defined] - return cast(ScorerIdentifier, result) + return super().from_dict(data) diff --git a/pyrit/identifiers/target_identifier.py b/pyrit/identifiers/target_identifier.py index b8924fb0c0..9d31170182 100644 --- a/pyrit/identifiers/target_identifier.py +++ b/pyrit/identifiers/target_identifier.py @@ -4,7 +4,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Dict, Optional, Type, cast +from typing import Any, Dict, Optional, Type from pyrit.identifiers.identifier import Identifier @@ -54,5 +54,4 @@ def from_dict(cls: Type["TargetIdentifier"], data: dict[str, Any]) -> "TargetIde TargetIdentifier: A new TargetIdentifier instance. """ # Delegate to parent class for standard processing - result = Identifier.from_dict.__func__(cls, data) # type: ignore[attr-defined] - return cast(TargetIdentifier, result) + return super().from_dict(data) diff --git a/pyrit/score/float_scale/self_ask_likert_scorer.py b/pyrit/score/float_scale/self_ask_likert_scorer.py index c01fe65073..523e5d703c 100644 --- a/pyrit/score/float_scale/self_ask_likert_scorer.py +++ b/pyrit/score/float_scale/self_ask_likert_scorer.py @@ -280,6 +280,7 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op message_data_type=message_piece.converted_value_data_type, scored_prompt_id=message_piece.id, category=self._score_category, + attack_identifier=message_piece.attack_identifier, objective=objective, ) diff --git a/pyrit/score/float_scale/self_ask_scale_scorer.py b/pyrit/score/float_scale/self_ask_scale_scorer.py index 5e502681d0..6c2ed1e116 100644 --- a/pyrit/score/float_scale/self_ask_scale_scorer.py +++ b/pyrit/score/float_scale/self_ask_scale_scorer.py @@ -120,6 +120,7 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op scored_prompt_id=message_piece.id, category=self._category, objective=objective, + attack_identifier=message_piece.attack_identifier, ) score = unvalidated_score.to_score( diff --git a/tests/unit/analytics/test_result_analysis.py b/tests/unit/analytics/test_result_analysis.py index 812126b79e..5a074aefc6 100644 --- a/tests/unit/analytics/test_result_analysis.py +++ b/tests/unit/analytics/test_result_analysis.py @@ -1,6 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +from typing import Optional + import pytest from pyrit.analytics.result_analysis import AttackStats, analyze_results @@ -11,13 +13,13 @@ # helpers def make_attack( outcome: AttackOutcome, - attack_type: str | None = "default", + attack_type: Optional[str] = "default", conversation_id: str = "conv-1", ) -> AttackResult: """ Minimal valid AttackResult for analytics tests. """ - attack_identifier: AttackIdentifier | None = None + attack_identifier: Optional[AttackIdentifier] = None if attack_type is not None: attack_identifier = AttackIdentifier(class_name=attack_type, class_module="tests.unit.analytics") diff --git a/tests/unit/identifiers/test_attack_identifier.py b/tests/unit/identifiers/test_attack_identifier.py new file mode 100644 index 0000000000..d194175eb7 --- /dev/null +++ b/tests/unit/identifiers/test_attack_identifier.py @@ -0,0 +1,266 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tests for AttackIdentifier-specific functionality. + +Note: Base Identifier functionality (hash computation, to_dict/from_dict basics, +frozen/hashable properties) is tested via ScorerIdentifier in test_scorer_identifier.py. +These tests focus on AttackIdentifier-specific fields and from_dict deserialization +of nested sub-identifiers. +""" + +import pytest + +from pyrit.identifiers import AttackIdentifier, ConverterIdentifier, ScorerIdentifier, TargetIdentifier + + +def _make_target_identifier() -> TargetIdentifier: + return TargetIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target.openai.openai_chat_target", + class_description="OpenAI chat target", + identifier_type="instance", + endpoint="https://api.openai.com/v1", + model_name="gpt-4o", + ) + + +def _make_scorer_identifier() -> ScorerIdentifier: + return ScorerIdentifier( + class_name="SelfAskTrueFalseScorer", + class_module="pyrit.score.true_false.self_ask_true_false_scorer", + class_description="True/false scorer", + identifier_type="instance", + ) + + +def _make_converter_identifier() -> ConverterIdentifier: + return ConverterIdentifier( + class_name="Base64Converter", + class_module="pyrit.prompt_converter.base64_converter", + class_description="Base64 converter", + identifier_type="instance", + supported_input_types=["text"], + supported_output_types=["text"], + ) + + +class TestAttackIdentifierCreation: + """Test basic AttackIdentifier creation.""" + + def test_creation_minimal(self): + """Test creating an AttackIdentifier with only base fields.""" + identifier = AttackIdentifier( + class_name="PromptSendingAttack", + class_module="pyrit.executor.attack.single_turn.prompt_sending", + ) + + assert identifier.class_name == "PromptSendingAttack" + assert identifier.objective_target_identifier is None + assert identifier.objective_scorer_identifier is None + assert identifier.request_converter_identifiers is None + assert identifier.attack_specific_params is None + assert identifier.hash is not None + + def test_creation_all_fields(self): + """Test creating an AttackIdentifier with all sub-identifiers.""" + target_id = _make_target_identifier() + scorer_id = _make_scorer_identifier() + converter_id = _make_converter_identifier() + + identifier = AttackIdentifier( + class_name="CrescendoAttack", + class_module="pyrit.executor.attack.multi_turn.crescendo", + objective_target_identifier=target_id, + objective_scorer_identifier=scorer_id, + request_converter_identifiers=[converter_id], + attack_specific_params={"max_turns": 10}, + ) + + assert identifier.objective_target_identifier is target_id + assert identifier.objective_scorer_identifier is scorer_id + assert identifier.request_converter_identifiers == [converter_id] + assert identifier.attack_specific_params == {"max_turns": 10} + + def test_frozen(self): + """Test that AttackIdentifier is immutable.""" + identifier = AttackIdentifier( + class_name="PromptSendingAttack", + class_module="pyrit.executor.attack.single_turn.prompt_sending", + ) + + with pytest.raises(AttributeError): + identifier.class_name = "Other" # type: ignore[misc] + + def test_hashable(self): + """Test that AttackIdentifier can be used in sets/dicts.""" + identifier = AttackIdentifier( + class_name="PromptSendingAttack", + class_module="pyrit.executor.attack.single_turn.prompt_sending", + ) + # Should not raise + {identifier} + {identifier: 1} + + +class TestAttackIdentifierFromDict: + """Test AttackIdentifier.from_dict with nested sub-identifier deserialization.""" + + def test_from_dict_minimal(self): + """Test from_dict with no nested sub-identifiers.""" + data = { + "class_name": "PromptSendingAttack", + "class_module": "pyrit.executor.attack.single_turn.prompt_sending", + } + + result = AttackIdentifier.from_dict(data) + + assert isinstance(result, AttackIdentifier) + assert result.class_name == "PromptSendingAttack" + assert result.objective_target_identifier is None + assert result.objective_scorer_identifier is None + assert result.request_converter_identifiers is None + + def test_from_dict_deserializes_nested_target(self): + """Test that from_dict recursively deserializes the target sub-identifier.""" + target_id = _make_target_identifier() + data = { + "class_name": "PromptSendingAttack", + "class_module": "pyrit.executor.attack.single_turn.prompt_sending", + "objective_target_identifier": target_id.to_dict(), + } + + result = AttackIdentifier.from_dict(data) + + assert isinstance(result.objective_target_identifier, TargetIdentifier) + assert result.objective_target_identifier.class_name == "OpenAIChatTarget" + assert result.objective_target_identifier.endpoint == "https://api.openai.com/v1" + + def test_from_dict_deserializes_nested_scorer(self): + """Test that from_dict recursively deserializes the scorer sub-identifier.""" + scorer_id = _make_scorer_identifier() + data = { + "class_name": "CrescendoAttack", + "class_module": "pyrit.executor.attack.multi_turn.crescendo", + "objective_scorer_identifier": scorer_id.to_dict(), + } + + result = AttackIdentifier.from_dict(data) + + assert isinstance(result.objective_scorer_identifier, ScorerIdentifier) + assert result.objective_scorer_identifier.class_name == "SelfAskTrueFalseScorer" + + def test_from_dict_deserializes_nested_converters(self): + """Test that from_dict recursively deserializes converter sub-identifiers.""" + converter_id = _make_converter_identifier() + data = { + "class_name": "PromptSendingAttack", + "class_module": "pyrit.executor.attack.single_turn.prompt_sending", + "request_converter_identifiers": [converter_id.to_dict()], + } + + result = AttackIdentifier.from_dict(data) + + assert result.request_converter_identifiers is not None + assert len(result.request_converter_identifiers) == 1 + assert isinstance(result.request_converter_identifiers[0], ConverterIdentifier) + assert result.request_converter_identifiers[0].class_name == "Base64Converter" + + def test_from_dict_all_nested(self): + """Test from_dict with all nested sub-identifiers as dicts.""" + target_id = _make_target_identifier() + scorer_id = _make_scorer_identifier() + converter_id = _make_converter_identifier() + + data = { + "class_name": "CrescendoAttack", + "class_module": "pyrit.executor.attack.multi_turn.crescendo", + "objective_target_identifier": target_id.to_dict(), + "objective_scorer_identifier": scorer_id.to_dict(), + "request_converter_identifiers": [converter_id.to_dict()], + "attack_specific_params": {"max_turns": 10}, + } + + result = AttackIdentifier.from_dict(data) + + assert isinstance(result, AttackIdentifier) + assert isinstance(result.objective_target_identifier, TargetIdentifier) + assert isinstance(result.objective_scorer_identifier, ScorerIdentifier) + assert isinstance(result.request_converter_identifiers[0], ConverterIdentifier) + assert result.attack_specific_params == {"max_turns": 10} + + def test_from_dict_already_typed_sub_identifiers_not_re_parsed(self): + """Test that from_dict handles already-typed sub-identifiers without error.""" + target_id = _make_target_identifier() + converter_id = _make_converter_identifier() + + data = { + "class_name": "PromptSendingAttack", + "class_module": "pyrit.executor.attack.single_turn.prompt_sending", + "objective_target_identifier": target_id, # Already typed, not a dict + "request_converter_identifiers": [converter_id], # Already typed + } + + result = AttackIdentifier.from_dict(data) + + assert result.objective_target_identifier is target_id + assert result.request_converter_identifiers[0] is converter_id + + def test_from_dict_none_converters_stays_none(self): + """Test that None converter list is preserved as None.""" + data = { + "class_name": "PromptSendingAttack", + "class_module": "pyrit.executor.attack.single_turn.prompt_sending", + "request_converter_identifiers": None, + } + + result = AttackIdentifier.from_dict(data) + assert result.request_converter_identifiers is None + + +class TestAttackIdentifierRoundTrip: + """Test to_dict → from_dict round-trip fidelity.""" + + def test_round_trip_minimal(self): + """Test round-trip with minimal fields.""" + original = AttackIdentifier( + class_name="PromptSendingAttack", + class_module="pyrit.executor.attack.single_turn.prompt_sending", + ) + + restored = AttackIdentifier.from_dict(original.to_dict()) + + assert restored.class_name == original.class_name + assert restored.class_module == original.class_module + assert restored.hash == original.hash + + def test_round_trip_with_nested_identifiers(self): + """Test round-trip preserves nested sub-identifiers.""" + original = AttackIdentifier( + class_name="CrescendoAttack", + class_module="pyrit.executor.attack.multi_turn.crescendo", + objective_target_identifier=_make_target_identifier(), + objective_scorer_identifier=_make_scorer_identifier(), + request_converter_identifiers=[_make_converter_identifier()], + ) + + restored = AttackIdentifier.from_dict(original.to_dict()) + + assert isinstance(restored.objective_target_identifier, TargetIdentifier) + assert isinstance(restored.objective_scorer_identifier, ScorerIdentifier) + assert isinstance(restored.request_converter_identifiers[0], ConverterIdentifier) + assert restored.hash == original.hash + + def test_round_trip_with_attack_specific_params(self): + """Test round-trip preserves attack_specific_params.""" + original = AttackIdentifier( + class_name="TreeOfAttacks", + class_module="pyrit.executor.attack.multi_turn.tree_of_attacks", + attack_specific_params={"width": 3, "depth": 5, "pruning": True}, + ) + + restored = AttackIdentifier.from_dict(original.to_dict()) + + assert restored.attack_specific_params == {"width": 3, "depth": 5, "pruning": True} + assert restored.hash == original.hash From e0f1c4e79b38c2b68dd610af614087230505309a Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Thu, 12 Feb 2026 03:42:56 -0800 Subject: [PATCH 29/35] feat: align labels with SDK path, optimize pagination, add lineage tracking Labels ------ - Source labels from PromptMemoryEntry.labels instead of AttackResult.metadata - Add _collect_labels_from_pieces helper to derive labels from pieces - Add get_unique_attack_labels() to MemoryInterface using JOIN+DISTINCT - Simplify /labels route to delegate to get_unique_attack_labels() - Remove labels from attack metadata in create_attack - Forward labels to prepended pieces and add_message pieces - Inherit labels from existing pieces when adding new messages Pagination ---------- - Refactor list_attacks_async into two phases: Phase 1: query + filter + sort on lightweight AttackResult objects Phase 2: fetch pieces only for the final paginated page - Rename _paginate to _paginate_attack_results (operates on AttackResult) Lineage tracking ---------------- - Add original_prompt_id to MessagePieceRequest DTO - Forward original_prompt_id through mapper to MessagePiece domain object Validation & style ------------------ - Add max_length=50 on PrependedMessageRequest.pieces and AddMessageRequest.pieces - Add max_length=200 on CreateAttackRequest.prepended_conversation - Add TOCTOU comment on sequence read-then-write in add_message_async - Add import placement rule to style guide - Move contextlib.closing import to top of memory_interface.py - Remove unused TYPE_CHECKING import from labels route Tests ----- - Add get_unique_attack_labels tests (empty, single, merge, no pieces, no labels, non-attack pieces, non-string values, sorted keys) - Add pagination test verifying pieces fetched only for page - Add original_prompt_id forwarding and default tests in mappers - Add labels stamping tests for prepended pieces and add_message - Add prepend ordering + lineage preservation test - Add _collect_labels_from_pieces tests - Update existing tests to match new label source (pieces, not metadata) --- .../instructions/style-guide.instructions.md | 21 +++ pyrit/backend/mappers/attack_mappers.py | 29 ++- pyrit/backend/models/attacks.py | 13 +- pyrit/backend/routes/labels.py | 35 +--- pyrit/backend/services/attack_service.py | 79 ++++---- pyrit/memory/memory_interface.py | 42 ++++- tests/unit/backend/test_api_routes.py | 76 +------- tests/unit/backend/test_attack_service.py | 170 ++++++++++++++++-- tests/unit/backend/test_mappers.py | 124 ++++++++++++- .../test_interface_attack_results.py | 96 ++++++++++ 10 files changed, 521 insertions(+), 164 deletions(-) diff --git a/.github/instructions/style-guide.instructions.md b/.github/instructions/style-guide.instructions.md index 5522686767..818554ee88 100644 --- a/.github/instructions/style-guide.instructions.md +++ b/.github/instructions/style-guide.instructions.md @@ -96,6 +96,27 @@ def process(self, data: str) -> str: ## Documentation Standards +### Import Placement +- **MANDATORY**: All import statements MUST be at the top of the file +- Do NOT use inline/local imports inside functions or methods +- The only exception is breaking circular import dependencies, which should be rare and documented + +```python +# CORRECT — imports at the top of the file +from contextlib import closing +from sqlalchemy.exc import SQLAlchemyError + +def update_entry(self, entry: Base) -> None: + with closing(self.get_session()) as session: + ... + +# INCORRECT — inline import inside a function +def update_entry(self, entry: Base) -> None: + from contextlib import closing # ← WRONG, must be at top of file + with closing(self.get_session()) as session: + ... +``` + ### Docstring Format - Use Google-style docstrings - Include type information in parameter descriptions diff --git a/pyrit/backend/mappers/attack_mappers.py b/pyrit/backend/mappers/attack_mappers.py index 5c02ef88ef..06a3f09ed2 100644 --- a/pyrit/backend/mappers/attack_mappers.py +++ b/pyrit/backend/mappers/attack_mappers.py @@ -76,7 +76,7 @@ def attack_result_to_summary( outcome=map_outcome(ar.outcome), last_message_preview=last_preview, message_count=message_count, - labels=ar.metadata.get("labels", {}), + labels=_collect_labels_from_pieces(pieces), created_at=created_at, updated_at=updated_at, ) @@ -175,6 +175,7 @@ def request_piece_to_pyrit_message_piece( role: str, conversation_id: str, sequence: int, + labels: Optional[Dict[str, str]] = None, ) -> PyritMessagePiece: """ Convert a single request piece DTO to a PyRIT MessagePiece domain object. @@ -184,11 +185,14 @@ def request_piece_to_pyrit_message_piece( role: The message role. conversation_id: The conversation/attack ID. sequence: The message sequence number. + labels: Optional labels to stamp on the piece. Returns: PyritMessagePiece domain object. """ metadata = {"mime_type": piece.mime_type} if getattr(piece, "mime_type", None) else None + raw_id = getattr(piece, "original_prompt_id", None) + original_prompt_id = uuid.UUID(raw_id) if raw_id else None return PyritMessagePiece( role=role, original_value=piece.original_value, @@ -198,6 +202,8 @@ def request_piece_to_pyrit_message_piece( conversation_id=conversation_id, sequence=sequence, prompt_metadata=metadata, + labels=labels or {}, + original_prompt_id=original_prompt_id, ) @@ -206,6 +212,7 @@ def request_to_pyrit_message( request: AddMessageRequest, conversation_id: str, sequence: int, + labels: Optional[Dict[str, str]] = None, ) -> PyritMessage: """ Build a PyRIT Message from an AddMessageRequest DTO. @@ -214,6 +221,7 @@ def request_to_pyrit_message( request: The inbound API request. conversation_id: The conversation/attack ID. sequence: The message sequence number. + labels: Optional labels to stamp on each piece. Returns: PyritMessage ready to send to the target. @@ -224,6 +232,7 @@ def request_to_pyrit_message( role=request.role, conversation_id=conversation_id, sequence=sequence, + labels=labels, ) for p in request.pieces ] @@ -247,3 +256,21 @@ def _get_preview_from_pieces(pieces: List[Any]) -> Optional[str]: last_piece = max(pieces, key=lambda p: p.sequence) text = last_piece.converted_value or "" return text[:100] + "..." if len(text) > 100 else text + + +def _collect_labels_from_pieces(pieces: List[Any]) -> Dict[str, str]: + """ + Collect labels from message pieces. + + Returns the labels from the first piece that has non-empty labels. + All pieces in an attack share the same labels, so the first match + is representative. + + Returns: + Label dict, or empty dict if no pieces have labels. + """ + for p in pieces: + labels = getattr(p, "labels", None) + if labels: + return dict(labels) + return {} diff --git a/pyrit/backend/models/attacks.py b/pyrit/backend/models/attacks.py index 215e9cb419..5819ae0441 100644 --- a/pyrit/backend/models/attacks.py +++ b/pyrit/backend/models/attacks.py @@ -126,13 +126,20 @@ class MessagePieceRequest(BaseModel): original_value: str = Field(..., description="Original value (text or base64 for media)") converted_value: Optional[str] = Field(None, description="Converted value. If provided, bypasses converters.") mime_type: Optional[str] = Field(None, description="MIME type for media content") + original_prompt_id: Optional[str] = Field( + None, + description="ID of the source piece when prepending from an existing conversation. " + "Preserves lineage so the new piece traces back to the original.", + ) class PrependedMessageRequest(BaseModel): """A message to prepend to the attack (for system prompt/branching).""" role: Literal["user", "assistant", "system"] = Field(..., description="Message role") - pieces: List[MessagePieceRequest] = Field(..., description="Message pieces (supports multimodal)") + pieces: List[MessagePieceRequest] = Field( + ..., description="Message pieces (supports multimodal)", max_length=50 + ) class CreateAttackRequest(BaseModel): @@ -141,7 +148,7 @@ class CreateAttackRequest(BaseModel): name: Optional[str] = Field(None, description="Attack name/label") target_id: str = Field(..., description="Target instance ID to attack") prepended_conversation: Optional[List[PrependedMessageRequest]] = Field( - None, description="Messages to prepend (system prompts, branching context)" + None, description="Messages to prepend (system prompts, branching context)", max_length=200 ) labels: Optional[Dict[str, str]] = Field(None, description="User-defined labels for filtering") @@ -179,7 +186,7 @@ class AddMessageRequest(BaseModel): """ role: Literal["user", "assistant", "system"] = Field(default="user", description="Message role") - pieces: List[MessagePieceRequest] = Field(..., description="Message pieces") + pieces: List[MessagePieceRequest] = Field(..., description="Message pieces", max_length=50) send: bool = Field( default=True, description="If True, send to target and wait for response. If False, just store in memory.", diff --git a/pyrit/backend/routes/labels.py b/pyrit/backend/routes/labels.py index ce057fbf04..e2f2d64c40 100644 --- a/pyrit/backend/routes/labels.py +++ b/pyrit/backend/routes/labels.py @@ -7,16 +7,13 @@ Provides access to unique label values for filtering in the GUI. """ -from typing import TYPE_CHECKING, Dict, List, Literal +from typing import Dict, List, Literal from fastapi import APIRouter, Query from pydantic import BaseModel, Field from pyrit.memory import CentralMemory -if TYPE_CHECKING: - from pyrit.memory import MemoryInterface - router = APIRouter(prefix="/labels", tags=["labels"]) @@ -52,37 +49,9 @@ async def get_label_options( memory = CentralMemory.get_memory_instance() if source == "attacks": - labels = _get_attack_labels(memory) + labels = memory.get_unique_attack_labels() else: # Future: add support for other sources labels = {} return LabelOptionsResponse(source=source, labels=labels) - - -def _get_attack_labels(memory: "MemoryInterface") -> Dict[str, List[str]]: - """ - Extract unique labels from all attack results. - - Returns: - Dict mapping label keys to sorted lists of unique values. - """ - attack_results = memory.get_attack_results() - - # Collect all unique key-value pairs - label_values: Dict[str, set[str]] = {} - - for ar in attack_results: - if ar.metadata: - for key, value in ar.metadata.items(): - # Skip internal metadata keys - if key.startswith("_") or key in ("created_at", "updated_at"): - continue - # Only include string values - if isinstance(value, str): - if key not in label_values: - label_values[key] = set() - label_values[key].add(value) - - # Convert sets to sorted lists - return {key: sorted(values) for key, values in sorted(label_values.items())} diff --git a/pyrit/backend/services/attack_service.py b/pyrit/backend/services/attack_service.py index d41ccdfab7..57fe79926e 100644 --- a/pyrit/backend/services/attack_service.py +++ b/pyrit/backend/services/attack_service.py @@ -89,44 +89,39 @@ async def list_attacks_async( Returns: AttackListResponse with filtered and paginated attack summaries. """ - # Map outcome string to AttackOutcome enum value for filtering - outcome_filter = outcome # Already matches enum values - - # Use labels filter at the database level if supported + # Phase 1: Query + lightweight filtering (no pieces needed) attack_results = self._memory.get_attack_results( - outcome=outcome_filter, + outcome=outcome, labels=labels, ) - # Convert to summaries and apply filters - summaries = [] + filtered: List[AttackResult] = [] for ar in attack_results: - # Filter by target_id - ar_target_id = ar.attack_identifier.get("target_id", "") - if target_id and ar_target_id != target_id: + if target_id and ar.attack_identifier.get("target_id", "") != target_id: continue - - # Filter by name (substring match) - ar_name = ar.attack_identifier.get("name", "") - if name and name.lower() not in ar_name.lower(): + if name and name.lower() not in ar.attack_identifier.get("name", "").lower(): continue - - # Filter by executed_turns if min_turns is not None and ar.executed_turns < min_turns: continue if max_turns is not None and ar.executed_turns > max_turns: continue + filtered.append(ar) - pieces = self._memory.get_message_pieces(conversation_id=ar.conversation_id) - summary = attack_result_to_summary(ar, pieces=pieces) - summaries.append(summary) + # Sort by most recent (metadata lives on AttackResult, no pieces needed) + filtered.sort( + key=lambda ar: ar.metadata.get("updated_at", ar.metadata.get("created_at", "")), + reverse=True, + ) - # Sort by most recent - summaries.sort(key=lambda s: s.updated_at, reverse=True) + # Paginate on the lightweight list first + page_results, has_more = self._paginate_attack_results(filtered, cursor, limit) + next_cursor = page_results[-1].conversation_id if has_more and page_results else None - # Paginate - page, has_more = self._paginate(summaries, cursor, limit) - next_cursor = page[-1].attack_id if has_more and page else None + # Phase 2: Fetch pieces only for the page we're returning + page: List[AttackSummary] = [] + for ar in page_results: + pieces = self._memory.get_message_pieces(conversation_id=ar.conversation_id) + page.append(attack_result_to_summary(ar, pieces=pieces)) return AttackListResponse( items=page, @@ -203,7 +198,6 @@ async def create_attack_async(self, *, request: CreateAttackRequest) -> CreateAt metadata={ "created_at": now.isoformat(), "updated_at": now.isoformat(), - "labels": request.labels or {}, }, ) @@ -215,6 +209,7 @@ async def create_attack_async(self, *, request: CreateAttackRequest) -> CreateAt await self._store_prepended_messages( conversation_id=conversation_id, prepended=request.prepended_conversation, + labels=request.labels, ) return CreateAttackResponse(attack_id=conversation_id, created_at=now) @@ -270,14 +265,20 @@ async def add_message_async(self, *, attack_id: str, request: AddMessageRequest) if not target_id: raise ValueError(f"Attack '{attack_id}' has no target configured") - # Get existing messages to determine sequence + # Get existing messages to determine sequence. + # NOTE: This read-then-write is not atomic (TOCTOU). Fine for the + # current single-user UI, but would need a DB-level sequence + # generator or optimistic locking if concurrent writes are supported. existing = self._memory.get_message_pieces(conversation_id=attack_id) sequence = max((p.sequence for p in existing), default=-1) + 1 + # Inherit labels from existing pieces so new messages stay consistent + attack_labels = next((p.labels for p in existing if getattr(p, "labels", None)), None) + if request.send: - await self._send_and_store_message(attack_id, target_id, request, sequence) + await self._send_and_store_message(attack_id, target_id, request, sequence, labels=attack_labels) else: - await self._store_message_only(attack_id, request, sequence) + await self._store_message_only(attack_id, request, sequence, labels=attack_labels) # Update attack timestamp ar.metadata["updated_at"] = datetime.now(timezone.utc).isoformat() @@ -296,11 +297,14 @@ async def add_message_async(self, *, attack_id: str, request: AddMessageRequest) # Private Helper Methods - Pagination # ======================================================================== - def _paginate( - self, items: List[AttackSummary], cursor: Optional[str], limit: int - ) -> tuple[List[AttackSummary], bool]: + def _paginate_attack_results( + self, items: List[AttackResult], cursor: Optional[str], limit: int + ) -> tuple[List[AttackResult], bool]: """ - Apply cursor-based pagination. + Apply cursor-based pagination over AttackResult objects. + + Operates on lightweight AttackResult objects before pieces are fetched, + so only the final page incurs per-attack piece queries. Returns: Tuple of (paginated items, has_more flag). @@ -308,7 +312,7 @@ def _paginate( start_idx = 0 if cursor: for i, item in enumerate(items): - if item.attack_id == cursor: + if item.conversation_id == cursor: start_idx = i + 1 break @@ -324,6 +328,7 @@ async def _store_prepended_messages( self, conversation_id: str, prepended: List[Any], + labels: Optional[Dict[str, str]] = None, ) -> None: """Store prepended conversation messages in memory.""" seq = 0 @@ -334,6 +339,7 @@ async def _store_prepended_messages( role=msg.role, conversation_id=conversation_id, sequence=seq, + labels=labels, ) self._memory.add_message_pieces_to_memory(message_pieces=[piece]) seq += 1 @@ -344,6 +350,8 @@ async def _send_and_store_message( target_id: str, request: AddMessageRequest, sequence: int, + *, + labels: Optional[Dict[str, str]] = None, ) -> None: """Send message to target via normalizer and store response.""" target_obj = get_target_service().get_target_object(target_id=target_id) @@ -354,6 +362,7 @@ async def _send_and_store_message( request=request, conversation_id=attack_id, sequence=sequence, + labels=labels, ) converter_configs = self._get_converter_configs(request) @@ -363,6 +372,7 @@ async def _send_and_store_message( target=target_obj, conversation_id=attack_id, request_converter_configurations=converter_configs, + labels=labels, ) # PromptNormalizer stores both request and response in memory automatically @@ -371,6 +381,8 @@ async def _store_message_only( attack_id: str, request: AddMessageRequest, sequence: int, + *, + labels: Optional[Dict[str, str]] = None, ) -> None: """Store message without sending (send=False).""" for p in request.pieces: @@ -379,6 +391,7 @@ async def _store_message_only( role=request.role, conversation_id=attack_id, sequence=sequence, + labels=labels, ) self._memory.add_message_pieces_to_memory(message_pieces=[piece]) diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 9fc682d6d0..0f0d2d8847 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -7,6 +7,7 @@ import uuid import warnings import weakref +from contextlib import closing from datetime import datetime from pathlib import Path from typing import Any, MutableSequence, Optional, Sequence, TypeVar, Union @@ -238,8 +239,6 @@ def _update_entry(self, entry: Base) -> None: Raises: SQLAlchemyError: If there's an error during the database operation. """ - from contextlib import closing - from sqlalchemy.exc import SQLAlchemyError with closing(self.get_session()) as session: @@ -1274,6 +1273,45 @@ def get_attack_results( logger.exception(f"Failed to retrieve attack results with error {e}") raise + def get_unique_attack_labels(self) -> dict[str, list[str]]: + """ + Return all unique label key-value pairs across attack results. + + Labels live on ``PromptMemoryEntry.labels`` (the established SDK + path). This method JOINs with ``AttackResultEntry`` to scope the + query to conversations that belong to an attack, applies DISTINCT + to reduce duplicate label dicts, then aggregates unique key-value + pairs in Python. + + Returns: + dict[str, list[str]]: Mapping of label keys to sorted lists of + unique values. + """ + label_values: dict[str, set[str]] = {} + + with closing(self.get_session()) as session: + rows = ( + session.query(PromptMemoryEntry.labels) + .join( + AttackResultEntry, + PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id, + ) + .filter(PromptMemoryEntry.labels.isnot(None)) + .distinct() + .all() + ) + + for (labels,) in rows: + if not isinstance(labels, dict): + continue + for key, value in labels.items(): + if isinstance(value, str): + if key not in label_values: + label_values[key] = set() + label_values[key].add(value) + + return {key: sorted(values) for key, values in sorted(label_values.items())} + def add_scenario_results_to_memory(self, *, scenario_results: Sequence[ScenarioResult]) -> None: """ Insert a list of scenario results into the memory storage. diff --git a/tests/unit/backend/test_api_routes.py b/tests/unit/backend/test_api_routes.py index e42b031e57..6c057b522f 100644 --- a/tests/unit/backend/test_api_routes.py +++ b/tests/unit/backend/test_api_routes.py @@ -808,12 +808,9 @@ class TestLabelsRoutes: def test_get_labels_for_attacks(self, client: TestClient) -> None: """Test getting labels from attack results.""" - mock_attack_result = MagicMock() - mock_attack_result.metadata = {"env": "prod", "team": "red", "created_at": "2024-01-01"} - with patch("pyrit.backend.routes.labels.CentralMemory") as mock_memory_class: mock_memory = MagicMock() - mock_memory.get_attack_results.return_value = [mock_attack_result] + mock_memory.get_unique_attack_labels.return_value = {"env": ["prod"], "team": ["red"]} mock_memory_class.get_memory_instance.return_value = mock_memory response = client.get("/api/labels?source=attacks") @@ -821,16 +818,14 @@ def test_get_labels_for_attacks(self, client: TestClient) -> None: assert response.status_code == status.HTTP_200_OK data = response.json() assert data["source"] == "attacks" - # env and team should be included, created_at should be excluded - assert "env" in data["labels"] - assert "team" in data["labels"] - assert "created_at" not in data["labels"] + assert data["labels"] == {"env": ["prod"], "team": ["red"]} + mock_memory.get_unique_attack_labels.assert_called_once() def test_get_labels_empty(self, client: TestClient) -> None: """Test getting labels when no attack results exist.""" with patch("pyrit.backend.routes.labels.CentralMemory") as mock_memory_class: mock_memory = MagicMock() - mock_memory.get_attack_results.return_value = [] + mock_memory.get_unique_attack_labels.return_value = {} mock_memory_class.get_memory_instance.return_value = mock_memory response = client.get("/api/labels?source=attacks") @@ -842,76 +837,21 @@ def test_get_labels_empty(self, client: TestClient) -> None: def test_get_labels_multiple_values(self, client: TestClient) -> None: """Test getting labels with multiple values per key.""" - mock_ar1 = MagicMock() - mock_ar1.metadata = {"env": "prod"} - mock_ar2 = MagicMock() - mock_ar2.metadata = {"env": "staging"} - mock_ar3 = MagicMock() - mock_ar3.metadata = {"env": "prod", "team": "blue"} - with patch("pyrit.backend.routes.labels.CentralMemory") as mock_memory_class: mock_memory = MagicMock() - mock_memory.get_attack_results.return_value = [mock_ar1, mock_ar2, mock_ar3] + mock_memory.get_unique_attack_labels.return_value = { + "env": ["prod", "staging"], + "team": ["blue"], + } mock_memory_class.get_memory_instance.return_value = mock_memory response = client.get("/api/labels") assert response.status_code == status.HTTP_200_OK data = response.json() - # Should have both env values sorted assert set(data["labels"]["env"]) == {"prod", "staging"} assert data["labels"]["team"] == ["blue"] - def test_get_labels_skips_internal_metadata(self, client: TestClient) -> None: - """Test that internal metadata keys are skipped.""" - mock_ar = MagicMock() - mock_ar.metadata = { - "_internal": "value", - "created_at": "2024-01-01", - "updated_at": "2024-01-02", - "visible_label": "keep", - } - - with patch("pyrit.backend.routes.labels.CentralMemory") as mock_memory_class: - mock_memory = MagicMock() - mock_memory.get_attack_results.return_value = [mock_ar] - mock_memory_class.get_memory_instance.return_value = mock_memory - - response = client.get("/api/labels") - - assert response.status_code == status.HTTP_200_OK - data = response.json() - # Only visible_label should be included - assert "visible_label" in data["labels"] - assert "_internal" not in data["labels"] - assert "created_at" not in data["labels"] - assert "updated_at" not in data["labels"] - - def test_get_labels_skips_non_string_values(self, client: TestClient) -> None: - """Test that non-string metadata values are skipped.""" - mock_ar = MagicMock() - mock_ar.metadata = { - "string_val": "keep", - "int_val": 123, - "list_val": ["a", "b"], - "dict_val": {"nested": "value"}, - } - - with patch("pyrit.backend.routes.labels.CentralMemory") as mock_memory_class: - mock_memory = MagicMock() - mock_memory.get_attack_results.return_value = [mock_ar] - mock_memory_class.get_memory_instance.return_value = mock_memory - - response = client.get("/api/labels") - - assert response.status_code == status.HTTP_200_OK - data = response.json() - # Only string_val should be included - assert "string_val" in data["labels"] - assert "int_val" not in data["labels"] - assert "list_val" not in data["labels"] - assert "dict_val" not in data["labels"] - @pytest.mark.asyncio async def test_get_label_options_unsupported_source_returns_empty_labels(self) -> None: """Test that get_label_options returns empty labels for unsupported source types.""" diff --git a/tests/unit/backend/test_attack_service.py b/tests/unit/backend/test_attack_service.py index c443923104..b1371e2ab9 100644 --- a/tests/unit/backend/test_attack_service.py +++ b/tests/unit/backend/test_attack_service.py @@ -55,7 +55,6 @@ def make_attack_result( outcome: AttackOutcome = AttackOutcome.UNDETERMINED, created_at: datetime = None, updated_at: datetime = None, - labels: dict = None, ) -> AttackResult: """Create a mock AttackResult for testing.""" now = datetime.now(timezone.utc) @@ -74,7 +73,6 @@ def make_attack_result( metadata={ "created_at": created.isoformat(), "updated_at": updated.isoformat(), - "labels": labels or {}, }, ) @@ -222,13 +220,14 @@ async def test_list_attacks_filters_by_max_turns(self, attack_service, mock_memo @pytest.mark.asyncio async def test_list_attacks_includes_labels_in_summary(self, attack_service, mock_memory) -> None: - """Test that list_attacks includes labels from metadata in summaries.""" + """Test that list_attacks includes labels from message pieces in summaries.""" ar = make_attack_result( conversation_id="attack-1", - labels={"env": "prod", "team": "red"}, ) mock_memory.get_attack_results.return_value = [ar] - mock_memory.get_message_pieces.return_value = [] + piece = make_mock_piece(conversation_id="attack-1") + piece.labels = {"env": "prod", "team": "red"} + mock_memory.get_message_pieces.return_value = [piece] result = await attack_service.list_attacks_async() @@ -366,8 +365,8 @@ async def test_create_attack_stores_prepended_conversation(self, attack_service, mock_memory.add_message_pieces_to_memory.assert_called() @pytest.mark.asyncio - async def test_create_attack_stores_labels_under_metadata_key(self, attack_service, mock_memory) -> None: - """Test that create_attack stores labels under metadata['labels'], not spread.""" + async def test_create_attack_does_not_store_labels_in_metadata(self, attack_service, mock_memory) -> None: + """Test that labels are not stored in attack metadata (they live on pieces).""" with patch("pyrit.backend.services.attack_service.get_target_service") as mock_get_target_service: mock_target_service = MagicMock() mock_target_service.get_target_async = AsyncMock(return_value=MagicMock(type="TextTarget")) @@ -381,14 +380,98 @@ async def test_create_attack_stores_labels_under_metadata_key(self, attack_servi ) ) - # Verify the AttackResult stored in memory has labels nested under metadata["labels"] call_args = mock_memory.add_attack_results_to_memory.call_args stored_ar = call_args[1]["attack_results"][0] - assert "labels" in stored_ar.metadata - assert stored_ar.metadata["labels"] == {"env": "prod", "team": "red"} - # Labels should NOT be spread as top-level metadata keys - assert "env" not in stored_ar.metadata - assert "team" not in stored_ar.metadata + assert "labels" not in stored_ar.metadata + + @pytest.mark.asyncio + async def test_create_attack_stamps_labels_on_prepended_pieces(self, attack_service, mock_memory) -> None: + """Test that labels are forwarded to prepended message pieces.""" + with patch("pyrit.backend.services.attack_service.get_target_service") as mock_get_target_service: + mock_target_service = MagicMock() + mock_target_service.get_target_async = AsyncMock(return_value=MagicMock(type="TextTarget")) + mock_get_target_service.return_value = mock_target_service + + prepended = [ + PrependedMessageRequest( + role="system", + pieces=[MessagePieceRequest(original_value="Be helpful.")], + ) + ] + + await attack_service.create_attack_async( + request=CreateAttackRequest( + target_id="target-1", + labels={"env": "prod"}, + prepended_conversation=prepended, + ) + ) + + stored_piece = mock_memory.add_message_pieces_to_memory.call_args[1]["message_pieces"][0] + assert stored_piece.labels == {"env": "prod"} + + @pytest.mark.asyncio + async def test_create_attack_prepended_messages_have_incrementing_sequences( + self, attack_service, mock_memory + ) -> None: + """Test that multiple prepended messages get incrementing sequence numbers and preserve lineage.""" + with patch("pyrit.backend.services.attack_service.get_target_service") as mock_get_target_service: + mock_target_service = MagicMock() + mock_target_service.get_target_async = AsyncMock(return_value=MagicMock(type="TextTarget")) + mock_get_target_service.return_value = mock_target_service + + original_id_1 = "aaaaaaaa-1111-2222-3333-444444444444" + original_id_2 = "bbbbbbbb-1111-2222-3333-444444444444" + original_id_3 = "cccccccc-1111-2222-3333-444444444444" + + prepended = [ + PrependedMessageRequest( + role="system", + pieces=[ + MessagePieceRequest( + original_value="You are a helpful assistant.", + original_prompt_id=original_id_1, + ) + ], + ), + PrependedMessageRequest( + role="user", + pieces=[ + MessagePieceRequest(original_value="Hello", original_prompt_id=original_id_2), + ], + ), + PrependedMessageRequest( + role="assistant", + pieces=[ + MessagePieceRequest(original_value="Hi there!", original_prompt_id=original_id_3), + ], + ), + ] + + await attack_service.create_attack_async( + request=CreateAttackRequest(target_id="target-1", prepended_conversation=prepended) + ) + + # Each message stored separately with incrementing sequence + calls = mock_memory.add_message_pieces_to_memory.call_args_list + assert len(calls) == 3 + sequences = [call[1]["message_pieces"][0].sequence for call in calls] + assert sequences == [0, 1, 2] + + roles = [call[1]["message_pieces"][0].api_role for call in calls] + assert roles == ["system", "user", "assistant"] + + # original_prompt_id preserved for lineage tracking + import uuid + + stored_pieces = [call[1]["message_pieces"][0] for call in calls] + assert stored_pieces[0].original_prompt_id == uuid.UUID(original_id_1) + assert stored_pieces[1].original_prompt_id == uuid.UUID(original_id_2) + assert stored_pieces[2].original_prompt_id == uuid.UUID(original_id_3) + + # Each piece gets its own new id, different from the original + for piece in stored_pieces: + assert piece.id != piece.original_prompt_id # ============================================================================ @@ -444,23 +527,60 @@ async def test_add_message_raises_for_nonexistent_attack(self, attack_service, m await attack_service.add_message_async(attack_id="nonexistent", request=request) @pytest.mark.asyncio - async def test_add_message_without_send_stores_message(self, attack_service, mock_memory) -> None: - """Test that add_message with send=False stores message in memory.""" + async def test_add_message_without_send_stamps_labels_on_pieces(self, attack_service, mock_memory) -> None: + """Test that add_message (send=False) inherits labels from existing pieces.""" ar = make_attack_result(conversation_id="test-id", target_id="target-1") mock_memory.get_attack_results.return_value = [ar] - mock_memory.get_message_pieces.return_value = [] + + existing_piece = make_mock_piece(conversation_id="test-id") + existing_piece.labels = {"env": "prod"} + mock_memory.get_message_pieces.return_value = [existing_piece] mock_memory.get_conversation.return_value = [] request = AddMessageRequest( - role="system", - pieces=[MessagePieceRequest(original_value="You are a helpful assistant.")], + role="user", + pieces=[MessagePieceRequest(original_value="Hello")], send=False, ) result = await attack_service.add_message_async(attack_id="test-id", request=request) + stored_piece = mock_memory.add_message_pieces_to_memory.call_args[1]["message_pieces"][0] + assert stored_piece.labels == {"env": "prod"} assert result.attack is not None - mock_memory.add_message_pieces_to_memory.assert_called() + + @pytest.mark.asyncio + async def test_add_message_with_send_passes_labels_to_normalizer(self, attack_service, mock_memory) -> None: + """Test that add_message (send=True) inherits labels from existing pieces.""" + ar = make_attack_result(conversation_id="test-id", target_id="target-1") + mock_memory.get_attack_results.return_value = [ar] + + existing_piece = make_mock_piece(conversation_id="test-id") + existing_piece.labels = {"env": "staging"} + mock_memory.get_message_pieces.return_value = [existing_piece] + mock_memory.get_conversation.return_value = [] + + with ( + patch("pyrit.backend.services.attack_service.get_target_service") as mock_get_target_svc, + patch("pyrit.backend.services.attack_service.PromptNormalizer") as mock_normalizer_cls, + ): + mock_target_svc = MagicMock() + mock_target_svc.get_target_object.return_value = MagicMock() + mock_get_target_svc.return_value = mock_target_svc + + mock_normalizer = MagicMock() + mock_normalizer.send_prompt_async = AsyncMock() + mock_normalizer_cls.return_value = mock_normalizer + + request = AddMessageRequest( + pieces=[MessagePieceRequest(original_value="Hello")], + send=True, + ) + + await attack_service.add_message_async(attack_id="test-id", request=request) + + call_kwargs = mock_normalizer.send_prompt_async.call_args[1] + assert call_kwargs["labels"] == {"env": "staging"} @pytest.mark.asyncio async def test_add_message_raises_when_no_target_id(self, attack_service, mock_memory) -> None: @@ -666,6 +786,18 @@ async def test_list_attacks_cursor_skips_to_correct_position(self, attack_servic assert "attack-1" not in attack_ids assert len(result.items) == 2 + @pytest.mark.asyncio + async def test_list_attacks_fetches_pieces_only_for_page(self, attack_service, mock_memory) -> None: + """Test that pieces are fetched only for the paginated page, not all attacks.""" + attacks = [make_attack_result(conversation_id=f"attack-{i}") for i in range(5)] + mock_memory.get_attack_results.return_value = attacks + mock_memory.get_message_pieces.return_value = [] + + await attack_service.list_attacks_async(limit=2) + + # get_message_pieces should be called only for the 2 items on the page, not all 5 + assert mock_memory.get_message_pieces.call_count == 2 + # ============================================================================ # Message Building Tests diff --git a/tests/unit/backend/test_mappers.py b/tests/unit/backend/test_mappers.py index 7069aa4282..51ba3c5e7d 100644 --- a/tests/unit/backend/test_mappers.py +++ b/tests/unit/backend/test_mappers.py @@ -14,6 +14,7 @@ import pytest from pyrit.backend.mappers.attack_mappers import ( + _collect_labels_from_pieces, _infer_mime_type, attack_result_to_summary, map_outcome, @@ -39,7 +40,6 @@ def _make_attack_result( target_type: str = "TextTarget", name: str = "Test Attack", outcome: AttackOutcome = AttackOutcome.UNDETERMINED, - labels: dict = None, ) -> AttackResult: """Create an AttackResult for mapper tests.""" now = datetime.now(timezone.utc) @@ -55,7 +55,6 @@ def _make_attack_result( metadata={ "created_at": now.isoformat(), "updated_at": now.isoformat(), - "labels": labels or {}, }, ) @@ -149,10 +148,12 @@ def test_last_message_preview_truncated(self) -> None: assert summary.last_message_preview.endswith("...") def test_labels_are_mapped(self) -> None: - """Test that labels are extracted from metadata.""" - ar = _make_attack_result(labels={"env": "prod", "team": "red"}) + """Test that labels are derived from pieces.""" + ar = _make_attack_result() + piece = _make_mock_piece() + piece.labels = {"env": "prod", "team": "red"} - summary = attack_result_to_summary(ar, pieces=[]) + summary = attack_result_to_summary(ar, pieces=[piece]) assert summary.labels == {"env": "prod", "team": "red"} @@ -270,6 +271,7 @@ def test_converts_request_to_domain(self) -> None: piece.data_type = "text" piece.original_value = "hello" piece.converted_value = None + piece.original_prompt_id = None request.pieces = [piece] result = request_to_pyrit_message( @@ -293,6 +295,7 @@ def test_uses_converted_value_when_present(self) -> None: piece.data_type = "text" piece.original_value = "original" piece.converted_value = "converted" + piece.original_prompt_id = None result = request_piece_to_pyrit_message_piece( piece=piece, @@ -312,6 +315,7 @@ def test_falls_back_to_original_when_no_converted(self) -> None: piece.data_type = "text" piece.original_value = "fallback" piece.converted_value = None + piece.original_prompt_id = None result = request_piece_to_pyrit_message_piece( piece=piece, @@ -329,6 +333,7 @@ def test_passes_mime_type_through_prompt_metadata(self) -> None: piece.original_value = "base64data" piece.converted_value = None piece.mime_type = "image/png" + piece.original_prompt_id = None result = request_piece_to_pyrit_message_piece( piece=piece, @@ -346,6 +351,7 @@ def test_no_metadata_when_mime_type_absent(self) -> None: piece.original_value = "hello" piece.converted_value = None piece.mime_type = None + piece.original_prompt_id = None result = request_piece_to_pyrit_message_piece( piece=piece, @@ -356,6 +362,83 @@ def test_no_metadata_when_mime_type_absent(self) -> None: assert result.prompt_metadata == {} + def test_labels_are_stamped_on_piece(self) -> None: + """Test that labels are passed through to the MessagePiece.""" + piece = MagicMock() + piece.data_type = "text" + piece.original_value = "hello" + piece.converted_value = None + piece.mime_type = None + piece.original_prompt_id = None + + result = request_piece_to_pyrit_message_piece( + piece=piece, + role="user", + conversation_id="conv-1", + sequence=0, + labels={"env": "prod"}, + ) + + assert result.labels == {"env": "prod"} + + def test_labels_default_to_empty_dict(self) -> None: + """Test that labels default to empty dict when not provided.""" + piece = MagicMock() + piece.data_type = "text" + piece.original_value = "hello" + piece.converted_value = None + piece.mime_type = None + piece.original_prompt_id = None + + result = request_piece_to_pyrit_message_piece( + piece=piece, + role="user", + conversation_id="conv-1", + sequence=0, + ) + + assert result.labels == {} + + def test_original_prompt_id_forwarded_when_provided(self) -> None: + """Test that original_prompt_id is passed through for lineage tracking.""" + piece = MagicMock() + piece.data_type = "text" + piece.original_value = "hello" + piece.converted_value = None + piece.mime_type = None + piece.original_prompt_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + + result = request_piece_to_pyrit_message_piece( + piece=piece, + role="user", + conversation_id="conv-1", + sequence=0, + ) + + import uuid + + assert result.original_prompt_id == uuid.UUID("aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee") + # New piece should have its own id, different from original_prompt_id + assert result.id != result.original_prompt_id + + def test_original_prompt_id_defaults_to_self_when_absent(self) -> None: + """Test that original_prompt_id defaults to the piece's own id when not provided.""" + piece = MagicMock() + piece.data_type = "text" + piece.original_value = "hello" + piece.converted_value = None + piece.mime_type = None + piece.original_prompt_id = None + + result = request_piece_to_pyrit_message_piece( + piece=piece, + role="user", + conversation_id="conv-1", + sequence=0, + ) + + assert result.original_prompt_id == result.id + class TestInferMimeType: """Tests for _infer_mime_type helper function.""" @@ -396,6 +479,37 @@ def test_infers_mp4(self) -> None: assert _infer_mime_type(value="/tmp/video.mp4", data_type="video") == "video/mp4" +class TestCollectLabelsFromPieces: + """Tests for _collect_labels_from_pieces helper.""" + + def test_returns_labels_from_first_piece(self) -> None: + """Returns labels from the first piece that has them.""" + p1 = MagicMock() + p1.labels = {"env": "prod"} + p2 = MagicMock() + p2.labels = {"env": "staging"} + + assert _collect_labels_from_pieces([p1, p2]) == {"env": "prod"} + + def test_returns_empty_when_no_pieces(self) -> None: + """Returns empty dict for empty list.""" + assert _collect_labels_from_pieces([]) == {} + + def test_returns_empty_when_pieces_have_no_labels(self) -> None: + """Returns empty dict when pieces have no labels attribute.""" + p = MagicMock(spec=[]) + assert _collect_labels_from_pieces([p]) == {} + + def test_skips_pieces_with_empty_labels(self) -> None: + """Skips pieces with empty/falsy labels.""" + p1 = MagicMock() + p1.labels = {} + p2 = MagicMock() + p2.labels = {"env": "prod"} + + assert _collect_labels_from_pieces([p1, p2]) == {"env": "prod"} + + # ============================================================================ # Target Mapper Tests # ============================================================================ diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index be2fa4c64d..a1972e2680 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -895,3 +895,99 @@ def test_get_attack_results_labels_key_exists_value_mismatch(sqlite_instance: Me results = sqlite_instance.get_attack_results(labels={"op_name": "op_exists", "researcher": "roakey"}) assert len(results) == 1 assert results[0].conversation_id == "conv_1" + + +# --------------------------------------------------------------------------- +# get_unique_attack_labels tests +# --------------------------------------------------------------------------- + + +def test_get_unique_attack_labels_empty(sqlite_instance: MemoryInterface): + """Returns empty dict when there are no attack results.""" + result = sqlite_instance.get_unique_attack_labels() + assert result == {} + + +def test_get_unique_attack_labels_single(sqlite_instance: MemoryInterface): + """Returns labels from a single attack result's message pieces.""" + message = create_message_piece("conv_1", 1, labels={"env": "prod", "team": "red"}) + sqlite_instance.add_message_pieces_to_memory(message_pieces=[message]) + + ar = create_attack_result("conv_1", 1) + sqlite_instance.add_attack_results_to_memory(attack_results=[ar]) + + result = sqlite_instance.get_unique_attack_labels() + assert result == {"env": ["prod"], "team": ["red"]} + + +def test_get_unique_attack_labels_multiple_attacks_merges_values(sqlite_instance: MemoryInterface): + """Values from different attacks are merged and sorted.""" + msg1 = create_message_piece("conv_1", 1, labels={"env": "prod", "team": "red"}) + msg2 = create_message_piece("conv_2", 2, labels={"env": "staging", "team": "red"}) + sqlite_instance.add_message_pieces_to_memory(message_pieces=[msg1, msg2]) + + ar1 = create_attack_result("conv_1", 1) + ar2 = create_attack_result("conv_2", 2) + sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2]) + + result = sqlite_instance.get_unique_attack_labels() + assert result == {"env": ["prod", "staging"], "team": ["red"]} + + +def test_get_unique_attack_labels_no_pieces(sqlite_instance: MemoryInterface): + """Attack results without any message pieces return empty dict.""" + ar = create_attack_result("conv_1", 1) + sqlite_instance.add_attack_results_to_memory(attack_results=[ar]) + + result = sqlite_instance.get_unique_attack_labels() + assert result == {} + + +def test_get_unique_attack_labels_pieces_without_labels(sqlite_instance: MemoryInterface): + """Message pieces with no labels are skipped.""" + msg = create_message_piece("conv_1", 1) # labels=None + sqlite_instance.add_message_pieces_to_memory(message_pieces=[msg]) + + ar = create_attack_result("conv_1", 1) + sqlite_instance.add_attack_results_to_memory(attack_results=[ar]) + + result = sqlite_instance.get_unique_attack_labels() + assert result == {} + + +def test_get_unique_attack_labels_ignores_non_attack_pieces(sqlite_instance: MemoryInterface): + """Labels on pieces not linked to any attack are excluded.""" + msg = create_message_piece("conv_no_attack", 1, labels={"env": "prod"}) + sqlite_instance.add_message_pieces_to_memory(message_pieces=[msg]) + + # No AttackResult for "conv_no_attack" + result = sqlite_instance.get_unique_attack_labels() + assert result == {} + + +def test_get_unique_attack_labels_non_string_values_skipped(sqlite_instance: MemoryInterface): + """Non-string label values are ignored.""" + msg = create_message_piece("conv_1", 1, labels={"env": "prod", "count": 42}) + sqlite_instance.add_message_pieces_to_memory(message_pieces=[msg]) + + ar = create_attack_result("conv_1", 1) + sqlite_instance.add_attack_results_to_memory(attack_results=[ar]) + + result = sqlite_instance.get_unique_attack_labels() + assert result == {"env": ["prod"]} + + +def test_get_unique_attack_labels_keys_sorted(sqlite_instance: MemoryInterface): + """Returned keys and values are sorted alphabetically.""" + msg1 = create_message_piece("conv_1", 1, labels={"zoo": "z_val", "alpha": "a"}) + msg2 = create_message_piece("conv_2", 2, labels={"alpha": "b"}) + sqlite_instance.add_message_pieces_to_memory(message_pieces=[msg1, msg2]) + + ar1 = create_attack_result("conv_1", 1) + ar2 = create_attack_result("conv_2", 2) + sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2]) + + result = sqlite_instance.get_unique_attack_labels() + assert list(result.keys()) == ["alpha", "zoo"] + assert result["alpha"] == ["a", "b"] + assert result["zoo"] == ["z_val"] From 72a2eaf426a87eba3c10ddf9eb584252819e42c5 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Fri, 13 Feb 2026 06:01:51 -0800 Subject: [PATCH 30/35] refactor: decouple API DTOs from internal PyRIT identifier objects Replace opaque Dict[str, Any] fields in API DTOs with explicit typed fields to prevent internal PyRIT core structures from leaking through the REST API boundary. Models: - AttackSummary: replace attack_identifier dict with attack_type, attack_specific_params, target_unique_name, target_type, converters - TargetInstance: replace identifier dict with target_type, endpoint, model_name, temperature, top_p, max_requests_per_minute, target_specific_params - ConverterInstance: replace type/params with converter_type, supported_input_types, supported_output_types, converter_specific_params, sub_converter_ids - CreateConverterResponse: use converter_type/display_name Mappers: - Extract specific fields from identifier objects via attribute access instead of calling .to_dict() + filter_sensitive_fields() - Add dedicated mapper modules for targets and converters API enhancements: - Add attack list filtering by attack_class, converter_classes, outcome, labels, min/max turns with SQL-level filtering - Add /attacks/attack-options and /attacks/converter-options endpoints - Support three-state converter_classes: None=no filter, []=no converters, non-empty=must have all listed - Stamp source:gui label via labels.setdefault() in create_attack_async - Add get_unique_attack_labels endpoint for label discovery Memory layer: - Add get_unique_attack_class_names, get_unique_converter_class_names, get_unique_attack_labels to MemoryInterface, SQLiteMemory, and AzureSQLMemory - SQL-level filtering for attack results by labels and harm categories Tests: - Expand mapper tests with coverage for no-target, converters extraction, attack_specific_params passthrough, None input/output types - Expand attack service tests for filtering, options, pagination - Update all test mocks to use attribute-based identifiers - Fix mypy errors (Sequence[Any] for pieces, ChatMessageRole for role) --- .../instructions/style-guide.instructions.md | 9 + pyrit/backend/main.py | 2 - pyrit/backend/mappers/__init__.py | 2 +- pyrit/backend/mappers/attack_mappers.py | 33 +- pyrit/backend/mappers/converter_mappers.py | 21 +- pyrit/backend/mappers/target_mappers.py | 25 +- pyrit/backend/models/__init__.py | 6 +- pyrit/backend/models/attacks.py | 49 +- pyrit/backend/models/converters.py | 18 +- pyrit/backend/models/targets.py | 23 +- pyrit/backend/routes/attacks.py | 83 ++- pyrit/backend/routes/targets.py | 17 +- pyrit/backend/services/attack_service.py | 151 ++++-- pyrit/backend/services/converter_service.py | 8 +- pyrit/backend/services/target_service.py | 63 +-- pyrit/memory/azure_sql_memory.py | 95 ++++ pyrit/memory/memory_interface.py | 74 +++ pyrit/memory/sqlite_memory.py | 83 ++- tests/unit/backend/test_api_routes.py | 175 +++++-- tests/unit/backend/test_attack_service.py | 491 +++++++++++++++--- tests/unit/backend/test_converter_service.py | 75 ++- tests/unit/backend/test_main.py | 1 - tests/unit/backend/test_mappers.py | 194 +++++-- tests/unit/backend/test_target_service.py | 52 +- .../test_interface_attack_results.py | 267 +++++++++- 25 files changed, 1601 insertions(+), 416 deletions(-) diff --git a/.github/instructions/style-guide.instructions.md b/.github/instructions/style-guide.instructions.md index 818554ee88..e411fa83d9 100644 --- a/.github/instructions/style-guide.instructions.md +++ b/.github/instructions/style-guide.instructions.md @@ -478,4 +478,13 @@ Before committing code, ensure: --- +## File Editing Rules + +### Never Use `sed` for File Edits +- **MANDATORY**: Never use `sed` (or similar stream-editing CLI tools) to modify source files +- `sed` frequently corrupts files, applies partial edits, or silently fails +- Always use the editor's built-in replace/edit tools (e.g., `replace_string_in_file`, `multi_replace_string_in_file`) to make targeted, verifiable changes + +--- + **Remember**: Clean code is written for humans to read. Make your intent clear and your code self-documenting. diff --git a/pyrit/backend/main.py b/pyrit/backend/main.py index 9dc095cab4..ae610a0531 100644 --- a/pyrit/backend/main.py +++ b/pyrit/backend/main.py @@ -6,14 +6,12 @@ """ import os -import sys from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from pathlib import Path from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import JSONResponse from fastapi.staticfiles import StaticFiles import pyrit diff --git a/pyrit/backend/mappers/__init__.py b/pyrit/backend/mappers/__init__.py index a2f7d4c029..780ffc8214 100644 --- a/pyrit/backend/mappers/__init__.py +++ b/pyrit/backend/mappers/__init__.py @@ -13,8 +13,8 @@ map_outcome, pyrit_messages_to_dto, pyrit_scores_to_dto, - request_to_pyrit_message, request_piece_to_pyrit_message_piece, + request_to_pyrit_message, ) from pyrit.backend.mappers.converter_mappers import ( converter_object_to_instance, diff --git a/pyrit/backend/mappers/attack_mappers.py b/pyrit/backend/mappers/attack_mappers.py index 06a3f09ed2..b278f7fb7a 100644 --- a/pyrit/backend/mappers/attack_mappers.py +++ b/pyrit/backend/mappers/attack_mappers.py @@ -11,7 +11,7 @@ import mimetypes import uuid from datetime import datetime, timezone -from typing import Any, Dict, List, Literal, Optional, cast +from typing import Any, Dict, List, Literal, Optional, Sequence, cast from pyrit.backend.models.attacks import ( AddMessageRequest, @@ -20,11 +20,10 @@ MessagePiece, Score, ) -from pyrit.models import AttackOutcome, AttackResult, PromptDataType +from pyrit.models import AttackOutcome, AttackResult, ChatMessageRole, PromptDataType from pyrit.models import Message as PyritMessage from pyrit.models import MessagePiece as PyritMessagePiece - # ============================================================================ # Domain → DTO (for API responses) # ============================================================================ @@ -48,11 +47,14 @@ def map_outcome(outcome: AttackOutcome) -> Optional[Literal["undetermined", "suc def attack_result_to_summary( ar: AttackResult, *, - pieces: List[Any], + pieces: Sequence[Any], ) -> AttackSummary: """ Build an AttackSummary DTO from an AttackResult and its message pieces. + Extracts only the frontend-relevant fields from the internal identifiers, + avoiding leakage of internal PyRIT core structures. + Args: ar: The domain AttackResult. pieces: Pre-fetched message pieces for this conversation. @@ -68,11 +70,19 @@ def attack_result_to_summary( created_at = datetime.fromisoformat(created_str) if created_str else datetime.now(timezone.utc) updated_at = datetime.fromisoformat(updated_str) if updated_str else created_at + aid = ar.attack_identifier + + # Extract only frontend-relevant fields from identifiers + target_id = aid.objective_target_identifier if aid else None + converter_ids = aid.request_converter_identifiers if aid else None + return AttackSummary( - attack_id=ar.conversation_id, - name=ar.attack_identifier.get("name"), - target_id=ar.attack_identifier.get("target_id", ""), - target_type=ar.attack_identifier.get("target_type", ""), + conversation_id=ar.conversation_id, + attack_type=aid.class_name if aid else "Unknown", + attack_specific_params=aid.attack_specific_params if aid else None, + target_unique_name=target_id.unique_name if target_id else None, + target_type=target_id.class_name if target_id else None, + converters=[c.class_name for c in converter_ids] if converter_ids else [], outcome=map_outcome(ar.outcome), last_message_preview=last_preview, message_count=message_count, @@ -153,7 +163,6 @@ def pyrit_messages_to_dto(pyrit_messages: List[Any]) -> List[Message]: first = msg.message_pieces[0] if msg.message_pieces else None messages.append( Message( - message_id=str(first.id) if first else str(uuid.uuid4()), turn_number=first.sequence if first else 0, role=first.role if first else "user", pieces=pieces, @@ -172,7 +181,7 @@ def pyrit_messages_to_dto(pyrit_messages: List[Any]) -> List[Message]: def request_piece_to_pyrit_message_piece( *, piece: Any, - role: str, + role: "ChatMessageRole", conversation_id: str, sequence: int, labels: Optional[Dict[str, str]] = None, @@ -244,7 +253,7 @@ def request_to_pyrit_message( # ============================================================================ -def _get_preview_from_pieces(pieces: List[Any]) -> Optional[str]: +def _get_preview_from_pieces(pieces: Sequence[Any]) -> Optional[str]: """ Get a preview of the last message from a list of pieces. @@ -258,7 +267,7 @@ def _get_preview_from_pieces(pieces: List[Any]) -> Optional[str]: return text[:100] + "..." if len(text) > 100 else text -def _collect_labels_from_pieces(pieces: List[Any]) -> Dict[str, str]: +def _collect_labels_from_pieces(pieces: Sequence[Any]) -> Dict[str, str]: """ Collect labels from message pieces. diff --git a/pyrit/backend/mappers/converter_mappers.py b/pyrit/backend/mappers/converter_mappers.py index 292f73dd0c..78e0c54915 100644 --- a/pyrit/backend/mappers/converter_mappers.py +++ b/pyrit/backend/mappers/converter_mappers.py @@ -5,28 +5,39 @@ Converter mappers – domain → DTO translation for converter-related models. """ -from typing import Any +from typing import Any, List, Optional from pyrit.backend.models.converters import ConverterInstance -def converter_object_to_instance(converter_id: str, converter_obj: Any) -> ConverterInstance: +def converter_object_to_instance( + converter_id: str, + converter_obj: Any, + *, + sub_converter_ids: Optional[List[str]] = None, +) -> ConverterInstance: """ Build a ConverterInstance DTO from a registry converter object. + Extracts only the frontend-relevant fields from the internal identifier, + avoiding leakage of internal PyRIT core structures. + Args: converter_id: The unique converter instance identifier. converter_obj: The domain PromptConverter object from the registry. + sub_converter_ids: Optional list of registered converter IDs for sub-converters. Returns: ConverterInstance DTO with metadata derived from the object. """ identifier = converter_obj.get_identifier() - identifier_dict = identifier.to_dict() return ConverterInstance( converter_id=converter_id, - type=identifier_dict.get("class_name", converter_obj.__class__.__name__), + converter_type=identifier.class_name, display_name=None, - params=identifier_dict, + supported_input_types=list(identifier.supported_input_types) if identifier.supported_input_types else [], + supported_output_types=list(identifier.supported_output_types) if identifier.supported_output_types else [], + converter_specific_params=identifier.converter_specific_params, + sub_converter_ids=sub_converter_ids, ) diff --git a/pyrit/backend/mappers/target_mappers.py b/pyrit/backend/mappers/target_mappers.py index 40ac967637..5b2b5b2972 100644 --- a/pyrit/backend/mappers/target_mappers.py +++ b/pyrit/backend/mappers/target_mappers.py @@ -7,29 +7,32 @@ from typing import Any -from pyrit.backend.models.common import filter_sensitive_fields from pyrit.backend.models.targets import TargetInstance -def target_object_to_instance(target_id: str, target_obj: Any) -> TargetInstance: +def target_object_to_instance(target_unique_name: str, target_obj: Any) -> TargetInstance: """ Build a TargetInstance DTO from a registry target object. + Extracts only the frontend-relevant fields from the internal identifier, + avoiding leakage of internal PyRIT core structures. + Args: - target_id: The unique target instance identifier. + target_unique_name: The unique target instance identifier (registry key / unique_name). target_obj: The domain PromptTarget object from the registry. Returns: TargetInstance DTO with metadata derived from the object. """ - identifier = target_obj.get_identifier() if hasattr(target_obj, "get_identifier") else {} - identifier_dict = identifier.to_dict() if hasattr(identifier, "to_dict") else identifier - target_type = identifier_dict.get("__type__", target_obj.__class__.__name__) - filtered_params = filter_sensitive_fields(identifier_dict) + identifier = target_obj.get_identifier() if hasattr(target_obj, "get_identifier") else None return TargetInstance( - target_id=target_id, - type=target_type, - display_name=None, - params=filtered_params, + target_unique_name=target_unique_name, + target_type=identifier.class_name if identifier else target_obj.__class__.__name__, + endpoint=getattr(identifier, "endpoint", None) if identifier else None, + model_name=getattr(identifier, "model_name", None) if identifier else None, + temperature=getattr(identifier, "temperature", None) if identifier else None, + top_p=getattr(identifier, "top_p", None) if identifier else None, + max_requests_per_minute=getattr(identifier, "max_requests_per_minute", None) if identifier else None, + target_specific_params=getattr(identifier, "target_specific_params", None) if identifier else None, ) diff --git a/pyrit/backend/models/__init__.py b/pyrit/backend/models/__init__.py index 4e3bcf6635..8f1c79e0c7 100644 --- a/pyrit/backend/models/__init__.py +++ b/pyrit/backend/models/__init__.py @@ -12,7 +12,9 @@ AddMessageResponse, AttackListResponse, AttackMessagesResponse, + AttackOptionsResponse, AttackSummary, + ConverterOptionsResponse, CreateAttackRequest, CreateAttackResponse, Message, @@ -40,7 +42,6 @@ ) from pyrit.backend.models.targets import ( CreateTargetRequest, - CreateTargetResponse, TargetInstance, TargetListResponse, ) @@ -76,7 +77,8 @@ "PreviewStep", # Targets "CreateTargetRequest", - "CreateTargetResponse", "TargetInstance", "TargetListResponse", + "AttackOptionsResponse", + "ConverterOptionsResponse", ] diff --git a/pyrit/backend/models/attacks.py b/pyrit/backend/models/attacks.py index 5819ae0441..60cb5b06f2 100644 --- a/pyrit/backend/models/attacks.py +++ b/pyrit/backend/models/attacks.py @@ -9,7 +9,7 @@ """ from datetime import datetime -from typing import Dict, List, Literal, Optional +from typing import Any, Dict, List, Literal, Optional from pydantic import BaseModel, Field @@ -36,8 +36,12 @@ class MessagePiece(BaseModel): """ piece_id: str = Field(..., description="Unique piece identifier") - original_value_data_type: str = Field(default="text", description="Data type of the original value: 'text', 'image', 'audio', etc.") - converted_value_data_type: str = Field(default="text", description="Data type of the converted value: 'text', 'image', 'audio', etc.") + original_value_data_type: str = Field( + default="text", description="Data type of the original value: 'text', 'image', 'audio', etc." + ) + converted_value_data_type: str = Field( + default="text", description="Data type of the converted value: 'text', 'image', 'audio', etc." + ) original_value: Optional[str] = Field(default=None, description="Original value before conversion") original_value_mime_type: Optional[str] = Field(default=None, description="MIME type of original value") converted_value: str = Field(..., description="Converted value (text or base64 for media)") @@ -54,7 +58,6 @@ class MessagePiece(BaseModel): class Message(BaseModel): """A message within an attack.""" - message_id: str = Field(..., description="Unique message identifier") turn_number: int = Field(..., description="Turn number in the conversation (1-indexed)") role: Literal["user", "assistant", "system"] = Field(..., description="Message role") pieces: List[MessagePiece] = Field(..., description="Message pieces (multimodal support)") @@ -69,10 +72,14 @@ class Message(BaseModel): class AttackSummary(BaseModel): """Summary view of an attack (for list views, omits full message content).""" - attack_id: str = Field(..., description="Unique attack identifier") - name: Optional[str] = Field(None, description="Attack name/label") - target_id: str = Field(..., description="Target instance ID") - target_type: str = Field(..., description="Target type (e.g., 'azure_openai')") + conversation_id: str = Field(..., description="Unique attack identifier") + attack_type: str = Field(..., description="Attack class name (e.g., 'CrescendoAttack', 'ManualAttack')") + attack_specific_params: Optional[Dict[str, Any]] = Field(None, description="Additional attack-specific parameters") + target_unique_name: Optional[str] = Field(None, description="Unique name of the objective target") + target_type: Optional[str] = Field(None, description="Target class name (e.g., 'OpenAIChatTarget')") + converters: List[str] = Field( + default_factory=list, description="Request converter class names applied in this attack" + ) outcome: Optional[Literal["undetermined", "success", "failure"]] = Field( None, description="Attack outcome (null if not yet determined)" ) @@ -93,7 +100,7 @@ class AttackSummary(BaseModel): class AttackMessagesResponse(BaseModel): """Response containing all messages for an attack.""" - attack_id: str = Field(..., description="Attack identifier") + conversation_id: str = Field(..., description="Attack identifier") messages: List[Message] = Field(default_factory=list, description="All messages in order") @@ -109,6 +116,22 @@ class AttackListResponse(BaseModel): pagination: PaginationInfo = Field(..., description="Pagination metadata") +class AttackOptionsResponse(BaseModel): + """Response containing unique attack class names used across attacks.""" + + attack_classes: List[str] = Field( + ..., description="Sorted list of unique attack class names found in attack results" + ) + + +class ConverterOptionsResponse(BaseModel): + """Response containing unique converter class names used across attacks.""" + + converter_classes: List[str] = Field( + ..., description="Sorted list of unique converter class names found in attack results" + ) + + # ============================================================================ # Create Attack # ============================================================================ @@ -137,16 +160,14 @@ class PrependedMessageRequest(BaseModel): """A message to prepend to the attack (for system prompt/branching).""" role: Literal["user", "assistant", "system"] = Field(..., description="Message role") - pieces: List[MessagePieceRequest] = Field( - ..., description="Message pieces (supports multimodal)", max_length=50 - ) + pieces: List[MessagePieceRequest] = Field(..., description="Message pieces (supports multimodal)", max_length=50) class CreateAttackRequest(BaseModel): """Request to create a new attack.""" name: Optional[str] = Field(None, description="Attack name/label") - target_id: str = Field(..., description="Target instance ID to attack") + target_unique_name: str = Field(..., description="Target instance ID to attack") prepended_conversation: Optional[List[PrependedMessageRequest]] = Field( None, description="Messages to prepend (system prompts, branching context)", max_length=200 ) @@ -156,7 +177,7 @@ class CreateAttackRequest(BaseModel): class CreateAttackResponse(BaseModel): """Response after creating an attack.""" - attack_id: str = Field(..., description="Unique attack identifier") + conversation_id: str = Field(..., description="Unique attack identifier") created_at: datetime = Field(..., description="Attack creation timestamp") diff --git a/pyrit/backend/models/converters.py b/pyrit/backend/models/converters.py index fd57dc0a9a..27304b3932 100644 --- a/pyrit/backend/models/converters.py +++ b/pyrit/backend/models/converters.py @@ -33,9 +33,20 @@ class ConverterInstance(BaseModel): """A registered converter instance.""" converter_id: str = Field(..., description="Unique converter instance identifier") - type: str = Field(..., description="Converter type (e.g., 'base64', 'translation')") + converter_type: str = Field(..., description="Converter class name (e.g., 'Base64Converter')") display_name: Optional[str] = Field(None, description="Human-readable display name") - params: Dict[str, Any] = Field(default_factory=dict, description="Converter parameters (resolved)") + supported_input_types: List[str] = Field( + default_factory=list, description="Input data types supported by this converter" + ) + supported_output_types: List[str] = Field( + default_factory=list, description="Output data types produced by this converter" + ) + converter_specific_params: Optional[Dict[str, Any]] = Field( + None, description="Additional converter-specific parameters" + ) + sub_converter_ids: Optional[List[str]] = Field( + None, description="Converter IDs of sub-converters (for pipelines/composites)" + ) class ConverterInstanceListResponse(BaseModel): @@ -59,9 +70,8 @@ class CreateConverterResponse(BaseModel): """Response after creating a converter instance.""" converter_id: str = Field(..., description="Unique converter instance identifier") - type: str = Field(..., description="Converter type") + converter_type: str = Field(..., description="Converter class name") display_name: Optional[str] = Field(None, description="Human-readable display name") - params: Dict[str, Any] = Field(default_factory=dict, description="Converter parameters") # ============================================================================ diff --git a/pyrit/backend/models/targets.py b/pyrit/backend/models/targets.py index fd4d607a3c..b22d2e3415 100644 --- a/pyrit/backend/models/targets.py +++ b/pyrit/backend/models/targets.py @@ -23,12 +23,17 @@ class TargetInstance(BaseModel): A runtime target instance. Created either by an initializer (at startup) or by user (via API). + Also used as the create-target response (same shape as GET). """ - target_id: str = Field(..., description="Unique target instance identifier") - type: str = Field(..., description="Target type (e.g., 'azure_openai', 'text_target')") - display_name: Optional[str] = Field(None, description="Human-readable display name") - params: Dict[str, Any] = Field(default_factory=dict, description="Target configuration (sensitive fields filtered)") + target_unique_name: str = Field(..., description="Unique target instance identifier (TargetIdentifier.unique_name)") + target_type: str = Field(..., description="Target class name (e.g., 'OpenAIChatTarget')") + endpoint: Optional[str] = Field(None, description="Target endpoint URL") + model_name: Optional[str] = Field(None, description="Model or deployment name") + temperature: Optional[float] = Field(None, description="Temperature parameter for generation") + top_p: Optional[float] = Field(None, description="Top-p parameter for generation") + max_requests_per_minute: Optional[int] = Field(None, description="Maximum requests per minute") + target_specific_params: Optional[Dict[str, Any]] = Field(None, description="Additional target-specific parameters") class TargetListResponse(BaseModel): @@ -42,14 +47,4 @@ class CreateTargetRequest(BaseModel): """Request to create a new target instance.""" type: str = Field(..., description="Target type (e.g., 'OpenAIChatTarget')") - display_name: Optional[str] = Field(None, description="Human-readable display name") params: Dict[str, Any] = Field(default_factory=dict, description="Target constructor parameters") - - -class CreateTargetResponse(BaseModel): - """Response after creating a target instance.""" - - target_id: str = Field(..., description="Unique target instance identifier") - type: str = Field(..., description="Target type") - display_name: Optional[str] = Field(None, description="Human-readable display name") - params: Dict[str, Any] = Field(default_factory=dict, description="Filtered configuration (no secrets)") diff --git a/pyrit/backend/routes/attacks.py b/pyrit/backend/routes/attacks.py index 73076c6e5f..6b9851f09a 100644 --- a/pyrit/backend/routes/attacks.py +++ b/pyrit/backend/routes/attacks.py @@ -17,7 +17,9 @@ AddMessageResponse, AttackListResponse, AttackMessagesResponse, + AttackOptionsResponse, AttackSummary, + ConverterOptionsResponse, CreateAttackRequest, CreateAttackResponse, UpdateAttackRequest, @@ -50,14 +52,17 @@ def _parse_labels(label_params: Optional[List[str]]) -> Optional[Dict[str, str]] response_model=AttackListResponse, ) async def list_attacks( - target_id: Optional[str] = Query(None, description="Filter by target instance ID"), + attack_class: Optional[str] = Query(None, description="Filter by exact attack class name"), + converter_classes: Optional[List[str]] = Query( + None, + description="Filter by converter class names (repeatable, AND logic). Pass empty to match no-converter attacks.", + ), outcome: Optional[Literal["undetermined", "success", "failure"]] = Query(None, description="Filter by outcome"), - name: Optional[str] = Query(None, description="Filter by attack name (substring match)"), label: Optional[List[str]] = Query(None, description="Filter by labels (format: key:value, repeatable)"), min_turns: Optional[int] = Query(None, ge=0, description="Filter by minimum executed turns"), max_turns: Optional[int] = Query(None, ge=0, description="Filter by maximum executed turns"), limit: int = Query(20, ge=1, le=100, description="Maximum items per page"), - cursor: Optional[str] = Query(None, description="Pagination cursor (attack_id)"), + cursor: Optional[str] = Query(None, description="Pagination cursor (conversation_id)"), ) -> AttackListResponse: """ List attacks with optional filtering and pagination. @@ -71,9 +76,9 @@ async def list_attacks( service = get_attack_service() labels = _parse_labels(label) return await service.list_attacks_async( - target_id=target_id, + attack_class=attack_class, + converter_classes=converter_classes, outcome=outcome, - name=name, labels=labels, min_turns=min_turns, max_turns=max_turns, @@ -82,6 +87,44 @@ async def list_attacks( ) +@router.get( + "/attack-options", + response_model=AttackOptionsResponse, +) +async def get_attack_options() -> AttackOptionsResponse: + """ + Get unique attack class names used across all attacks. + + Returns all attack class names found in stored attack results. + Useful for populating attack type filter dropdowns in the GUI. + + Returns: + AttackOptionsResponse: Sorted list of unique attack class names. + """ + service = get_attack_service() + class_names = await service.get_attack_options_async() + return AttackOptionsResponse(attack_classes=class_names) + + +@router.get( + "/converter-options", + response_model=ConverterOptionsResponse, +) +async def get_converter_options() -> ConverterOptionsResponse: + """ + Get unique converter class names used across all attacks. + + Returns all converter class names found in stored attack results. + Useful for populating converter filter dropdowns in the GUI. + + Returns: + ConverterOptionsResponse: Sorted list of unique converter class names. + """ + service = get_attack_service() + class_names = await service.get_converter_options_async() + return ConverterOptionsResponse(converter_classes=class_names) + + @router.post( "", response_model=CreateAttackResponse, @@ -114,13 +157,13 @@ async def create_attack(request: CreateAttackRequest) -> CreateAttackResponse: @router.get( - "/{attack_id}", + "/{conversation_id}", response_model=AttackSummary, responses={ 404: {"model": ProblemDetail, "description": "Attack not found"}, }, ) -async def get_attack(attack_id: str) -> AttackSummary: +async def get_attack(conversation_id: str) -> AttackSummary: """ Get attack details. @@ -131,25 +174,25 @@ async def get_attack(attack_id: str) -> AttackSummary: """ service = get_attack_service() - attack = await service.get_attack_async(attack_id=attack_id) + attack = await service.get_attack_async(conversation_id=conversation_id) if not attack: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Attack '{attack_id}' not found", + detail=f"Attack '{conversation_id}' not found", ) return attack @router.patch( - "/{attack_id}", + "/{conversation_id}", response_model=AttackSummary, responses={ 404: {"model": ProblemDetail, "description": "Attack not found"}, }, ) async def update_attack( - attack_id: str, + conversation_id: str, request: UpdateAttackRequest, ) -> AttackSummary: """ @@ -162,24 +205,24 @@ async def update_attack( """ service = get_attack_service() - attack = await service.update_attack_async(attack_id=attack_id, request=request) + attack = await service.update_attack_async(conversation_id=conversation_id, request=request) if not attack: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Attack '{attack_id}' not found", + detail=f"Attack '{conversation_id}' not found", ) return attack @router.get( - "/{attack_id}/messages", + "/{conversation_id}/messages", response_model=AttackMessagesResponse, responses={ 404: {"model": ProblemDetail, "description": "Attack not found"}, }, ) -async def get_attack_messages(attack_id: str) -> AttackMessagesResponse: +async def get_attack_messages(conversation_id: str) -> AttackMessagesResponse: """ Get all messages for an attack. @@ -190,18 +233,18 @@ async def get_attack_messages(attack_id: str) -> AttackMessagesResponse: """ service = get_attack_service() - messages = await service.get_attack_messages_async(attack_id=attack_id) + messages = await service.get_attack_messages_async(conversation_id=conversation_id) if not messages: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Attack '{attack_id}' not found", + detail=f"Attack '{conversation_id}' not found", ) return messages @router.post( - "/{attack_id}/messages", + "/{conversation_id}/messages", response_model=AddMessageResponse, responses={ 404: {"model": ProblemDetail, "description": "Attack not found"}, @@ -209,7 +252,7 @@ async def get_attack_messages(attack_id: str) -> AttackMessagesResponse: }, ) async def add_message( - attack_id: str, + conversation_id: str, request: AddMessageRequest, ) -> AddMessageResponse: """ @@ -230,7 +273,7 @@ async def add_message( service = get_attack_service() try: - return await service.add_message_async(attack_id=attack_id, request=request) + return await service.add_message_async(conversation_id=conversation_id, request=request) except ValueError as e: error_msg = str(e) if "not found" in error_msg.lower(): diff --git a/pyrit/backend/routes/targets.py b/pyrit/backend/routes/targets.py index fd7e04508c..437d8212ff 100644 --- a/pyrit/backend/routes/targets.py +++ b/pyrit/backend/routes/targets.py @@ -15,7 +15,6 @@ from pyrit.backend.models.common import ProblemDetail from pyrit.backend.models.targets import ( CreateTargetRequest, - CreateTargetResponse, TargetInstance, TargetListResponse, ) @@ -33,7 +32,7 @@ ) async def list_targets( limit: int = Query(50, ge=1, le=200, description="Maximum items per page"), - cursor: Optional[str] = Query(None, description="Pagination cursor (target_id)"), + cursor: Optional[str] = Query(None, description="Pagination cursor (target_unique_name)"), ) -> TargetListResponse: """ List target instances with pagination. @@ -49,13 +48,13 @@ async def list_targets( @router.post( "", - response_model=CreateTargetResponse, + response_model=TargetInstance, status_code=status.HTTP_201_CREATED, responses={ 400: {"model": ProblemDetail, "description": "Invalid target type or parameters"}, }, ) -async def create_target(request: CreateTargetRequest) -> CreateTargetResponse: +async def create_target(request: CreateTargetRequest) -> TargetInstance: """ Create a new target instance. @@ -84,26 +83,26 @@ async def create_target(request: CreateTargetRequest) -> CreateTargetResponse: @router.get( - "/{target_id}", + "/{target_unique_name}", response_model=TargetInstance, responses={ 404: {"model": ProblemDetail, "description": "Target not found"}, }, ) -async def get_target(target_id: str) -> TargetInstance: +async def get_target(target_unique_name: str) -> TargetInstance: """ - Get a target instance by ID. + Get a target instance by unique name. Returns: TargetInstance: The target instance details. """ service = get_target_service() - target = await service.get_target_async(target_id=target_id) + target = await service.get_target_async(target_unique_name=target_unique_name) if not target: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Target '{target_id}' not found", + detail=f"Target '{target_unique_name}' not found", ) return target diff --git a/pyrit/backend/services/attack_service.py b/pyrit/backend/services/attack_service.py index 57fe79926e..8ed7f5e5f8 100644 --- a/pyrit/backend/services/attack_service.py +++ b/pyrit/backend/services/attack_service.py @@ -20,6 +20,12 @@ from functools import lru_cache from typing import Any, Dict, List, Literal, Optional +from pyrit.backend.mappers.attack_mappers import ( + attack_result_to_summary, + pyrit_messages_to_dto, + request_piece_to_pyrit_message_piece, + request_to_pyrit_message, +) from pyrit.backend.models.attacks import ( AddMessageRequest, AddMessageResponse, @@ -31,14 +37,9 @@ UpdateAttackRequest, ) from pyrit.backend.models.common import PaginationInfo -from pyrit.backend.mappers.attack_mappers import ( - attack_result_to_summary, - pyrit_messages_to_dto, - request_piece_to_pyrit_message_piece, - request_to_pyrit_message, -) from pyrit.backend.services.converter_service import get_converter_service from pyrit.backend.services.target_service import get_target_service +from pyrit.identifiers import AttackIdentifier from pyrit.memory import CentralMemory from pyrit.models import AttackOutcome, AttackResult from pyrit.prompt_normalizer import PromptConverterConfiguration, PromptNormalizer @@ -62,9 +63,9 @@ def __init__(self) -> None: async def list_attacks_async( self, *, - target_id: Optional[str] = None, + attack_class: Optional[str] = None, + converter_classes: Optional[List[str]] = None, outcome: Optional[Literal["undetermined", "success", "failure"]] = None, - name: Optional[str] = None, labels: Optional[Dict[str, str]] = None, min_turns: Optional[int] = None, max_turns: Optional[int] = None, @@ -77,9 +78,11 @@ async def list_attacks_async( Queries AttackResult entries from the database. Args: - target_id: Filter by target instance ID (from attack_identifier). + attack_class: Filter by exact attack class_name (case-sensitive). + converter_classes: Filter by converter usage. + None = no filter, [] = only attacks with no converters, + ["A", "B"] = only attacks using ALL specified converters (AND logic, case-insensitive). outcome: Filter by attack outcome. - name: Filter by attack name (substring match on attack_identifier.name). labels: Filter by labels (all must match). min_turns: Filter by minimum executed turns. max_turns: Filter by maximum executed turns. @@ -93,14 +96,12 @@ async def list_attacks_async( attack_results = self._memory.get_attack_results( outcome=outcome, labels=labels, + attack_class=attack_class, + converter_classes=converter_classes, ) filtered: List[AttackResult] = [] for ar in attack_results: - if target_id and ar.attack_identifier.get("target_id", "") != target_id: - continue - if name and name.lower() not in ar.attack_identifier.get("name", "").lower(): - continue if min_turns is not None and ar.executed_turns < min_turns: continue if max_turns is not None and ar.executed_turns > max_turns: @@ -128,7 +129,31 @@ async def list_attacks_async( pagination=PaginationInfo(limit=limit, has_more=has_more, next_cursor=next_cursor, prev_cursor=cursor), ) - async def get_attack_async(self, *, attack_id: str) -> Optional[AttackSummary]: + async def get_attack_options_async(self) -> List[str]: + """ + Get all unique attack class names from stored attack results. + + Delegates to the memory layer which extracts distinct class_name + values from the attack_identifier JSON column via SQL. + + Returns: + Sorted list of unique attack class names. + """ + return self._memory.get_unique_attack_class_names() + + async def get_converter_options_async(self) -> List[str]: + """ + Get all unique converter class names used across attack results. + + Delegates to the memory layer which extracts distinct converter + class_name values from the attack_identifier JSON column via SQL. + + Returns: + Sorted list of unique converter class names. + """ + return self._memory.get_unique_converter_class_names() + + async def get_attack_async(self, *, conversation_id: str) -> Optional[AttackSummary]: """ Get attack details (high-level metadata, no messages). @@ -137,7 +162,7 @@ async def get_attack_async(self, *, attack_id: str) -> Optional[AttackSummary]: Returns: AttackSummary if found, None otherwise. """ - results = self._memory.get_attack_results(conversation_id=attack_id) + results = self._memory.get_attack_results(conversation_id=conversation_id) if not results: return None @@ -145,7 +170,7 @@ async def get_attack_async(self, *, attack_id: str) -> Optional[AttackSummary]: pieces = self._memory.get_message_pieces(conversation_id=ar.conversation_id) return attack_result_to_summary(ar, pieces=pieces) - async def get_attack_messages_async(self, *, attack_id: str) -> Optional[AttackMessagesResponse]: + async def get_attack_messages_async(self, *, conversation_id: str) -> Optional[AttackMessagesResponse]: """ Get all messages for an attack. @@ -153,16 +178,16 @@ async def get_attack_messages_async(self, *, attack_id: str) -> Optional[AttackM AttackMessagesResponse if attack found, None otherwise. """ # Check attack exists - results = self._memory.get_attack_results(conversation_id=attack_id) + results = self._memory.get_attack_results(conversation_id=conversation_id) if not results: return None # Get messages for this conversation - pyrit_messages = self._memory.get_conversation(conversation_id=attack_id) + pyrit_messages = self._memory.get_conversation(conversation_id=conversation_id) backend_messages = pyrit_messages_to_dto(list(pyrit_messages)) return AttackMessagesResponse( - attack_id=attack_id, + conversation_id=conversation_id, messages=backend_messages, ) @@ -176,11 +201,15 @@ async def create_attack_async(self, *, request: CreateAttackRequest) -> CreateAt CreateAttackResponse with the new attack's ID and creation time. """ target_service = get_target_service() - target_instance = await target_service.get_target_async(target_id=request.target_id) + target_instance = await target_service.get_target_async(target_unique_name=request.target_unique_name) if not target_instance: - raise ValueError(f"Target instance '{request.target_id}' not found") + raise ValueError(f"Target instance '{request.target_unique_name}' not found") + + # Get the actual target object so we can capture its TargetIdentifier + target_obj = target_service.get_target_object(target_unique_name=request.target_unique_name) + target_identifier = target_obj.get_identifier() if target_obj else None - # Generate conversation_id (this is the attack_id) + # Generate a new conversation_id for this attack conversation_id = str(uuid.uuid4()) now = datetime.now(timezone.utc) @@ -188,12 +217,11 @@ async def create_attack_async(self, *, request: CreateAttackRequest) -> CreateAt attack_result = AttackResult( conversation_id=conversation_id, objective=request.name or "Manual attack via GUI", - attack_identifier={ - "name": request.name or "", - "target_id": request.target_id, - "target_type": target_instance.type, - "source": "gui", - }, + attack_identifier=AttackIdentifier( + class_name=request.name or "ManualAttack", + class_module="pyrit.backend", + objective_target_identifier=target_identifier, + ), outcome=AttackOutcome.UNDETERMINED, metadata={ "created_at": now.isoformat(), @@ -201,6 +229,10 @@ async def create_attack_async(self, *, request: CreateAttackRequest) -> CreateAt }, ) + # Merge source label with any user-supplied labels + labels = dict(request.labels) if request.labels else {} + labels.setdefault("source", "gui") + # Store in memory self._memory.add_attack_results_to_memory(attack_results=[attack_result]) @@ -209,12 +241,14 @@ async def create_attack_async(self, *, request: CreateAttackRequest) -> CreateAt await self._store_prepended_messages( conversation_id=conversation_id, prepended=request.prepended_conversation, - labels=request.labels, + labels=labels, ) - return CreateAttackResponse(attack_id=conversation_id, created_at=now) + return CreateAttackResponse(conversation_id=conversation_id, created_at=now) - async def update_attack_async(self, *, attack_id: str, request: UpdateAttackRequest) -> Optional[AttackSummary]: + async def update_attack_async( + self, *, conversation_id: str, request: UpdateAttackRequest + ) -> Optional[AttackSummary]: """ Update an attack's outcome. @@ -223,7 +257,7 @@ async def update_attack_async(self, *, attack_id: str, request: UpdateAttackRequ Returns: Updated AttackSummary if found, None otherwise. """ - results = self._memory.get_attack_results(conversation_id=attack_id) + results = self._memory.get_attack_results(conversation_id=conversation_id) if not results: return None @@ -244,9 +278,9 @@ async def update_attack_async(self, *, attack_id: str, request: UpdateAttackRequ # Re-add to memory (this should update) self._memory.add_attack_results_to_memory(attack_results=[ar]) - return await self.get_attack_async(attack_id=attack_id) + return await self.get_attack_async(conversation_id=conversation_id) - async def add_message_async(self, *, attack_id: str, request: AddMessageRequest) -> AddMessageResponse: + async def add_message_async(self, *, conversation_id: str, request: AddMessageRequest) -> AddMessageResponse: """ Add a message to an attack, optionally sending to target. @@ -256,43 +290,50 @@ async def add_message_async(self, *, attack_id: str, request: AddMessageRequest) AddMessageResponse containing the updated attack detail. """ # Check if attack exists - results = self._memory.get_attack_results(conversation_id=attack_id) + results = self._memory.get_attack_results(conversation_id=conversation_id) if not results: - raise ValueError(f"Attack '{attack_id}' not found") + raise ValueError(f"Attack '{conversation_id}' not found") ar = results[0] - target_id = ar.attack_identifier.get("target_id") - if not target_id: - raise ValueError(f"Attack '{attack_id}' has no target configured") + aid = ar.attack_identifier + if not aid or not aid.objective_target_identifier: + raise ValueError(f"Attack '{conversation_id}' has no target configured") + target_unique_name = aid.objective_target_identifier.unique_name # Get existing messages to determine sequence. # NOTE: This read-then-write is not atomic (TOCTOU). Fine for the # current single-user UI, but would need a DB-level sequence # generator or optimistic locking if concurrent writes are supported. - existing = self._memory.get_message_pieces(conversation_id=attack_id) + existing = self._memory.get_message_pieces(conversation_id=conversation_id) sequence = max((p.sequence for p in existing), default=-1) + 1 # Inherit labels from existing pieces so new messages stay consistent attack_labels = next((p.labels for p in existing if getattr(p, "labels", None)), None) if request.send: - await self._send_and_store_message(attack_id, target_id, request, sequence, labels=attack_labels) + await self._send_and_store_message( + conversation_id, target_unique_name, request, sequence, labels=attack_labels + ) else: - await self._store_message_only(attack_id, request, sequence, labels=attack_labels) + await self._store_message_only(conversation_id, request, sequence, labels=attack_labels) # Update attack timestamp ar.metadata["updated_at"] = datetime.now(timezone.utc).isoformat() - attack_detail = await self.get_attack_async(attack_id=attack_id) + attack_detail = await self.get_attack_async(conversation_id=conversation_id) if attack_detail is None: - raise ValueError(f"Attack '{attack_id}' not found after update") + raise ValueError(f"Attack '{conversation_id}' not found after update") - attack_messages = await self.get_attack_messages_async(attack_id=attack_id) + attack_messages = await self.get_attack_messages_async(conversation_id=conversation_id) if attack_messages is None: - raise ValueError(f"Attack '{attack_id}' messages not found after update") + raise ValueError(f"Attack '{conversation_id}' messages not found after update") return AddMessageResponse(attack=attack_detail, messages=attack_messages) + # ======================================================================== + # Private Helper Methods - Identifier Access + # ======================================================================== + # ======================================================================== # Private Helper Methods - Pagination # ======================================================================== @@ -346,21 +387,21 @@ async def _store_prepended_messages( async def _send_and_store_message( self, - attack_id: str, - target_id: str, + conversation_id: str, + target_unique_name: str, request: AddMessageRequest, sequence: int, *, labels: Optional[Dict[str, str]] = None, ) -> None: """Send message to target via normalizer and store response.""" - target_obj = get_target_service().get_target_object(target_id=target_id) + target_obj = get_target_service().get_target_object(target_unique_name=target_unique_name) if not target_obj: - raise ValueError(f"Target object for '{target_id}' not found") + raise ValueError(f"Target object for '{target_unique_name}' not found") pyrit_message = request_to_pyrit_message( request=request, - conversation_id=attack_id, + conversation_id=conversation_id, sequence=sequence, labels=labels, ) @@ -370,7 +411,7 @@ async def _send_and_store_message( await normalizer.send_prompt_async( message=pyrit_message, target=target_obj, - conversation_id=attack_id, + conversation_id=conversation_id, request_converter_configurations=converter_configs, labels=labels, ) @@ -378,7 +419,7 @@ async def _send_and_store_message( async def _store_message_only( self, - attack_id: str, + conversation_id: str, request: AddMessageRequest, sequence: int, *, @@ -389,7 +430,7 @@ async def _store_message_only( piece = request_piece_to_pyrit_message_piece( piece=p, role=request.role, - conversation_id=attack_id, + conversation_id=conversation_id, sequence=sequence, labels=labels, ) diff --git a/pyrit/backend/services/converter_service.py b/pyrit/backend/services/converter_service.py index 3c1049e103..c18d0ec084 100644 --- a/pyrit/backend/services/converter_service.py +++ b/pyrit/backend/services/converter_service.py @@ -17,6 +17,7 @@ from typing import Any, List, Optional, Tuple from pyrit import prompt_converter +from pyrit.backend.mappers.converter_mappers import converter_object_to_instance from pyrit.backend.models.converters import ( ConverterInstance, ConverterInstanceListResponse, @@ -26,7 +27,6 @@ CreateConverterResponse, PreviewStep, ) -from pyrit.backend.mappers.converter_mappers import converter_object_to_instance from pyrit.models import PromptDataType from pyrit.prompt_converter import PromptConverter from pyrit.registry.instance_registries import ConverterRegistry @@ -88,7 +88,8 @@ async def list_converters_async(self) -> ConverterInstanceListResponse: ConverterInstanceListResponse containing all registered converters. """ items = [ - self._build_instance_from_object(converter_id=name, converter_obj=obj) for name, obj in self._registry.get_all_instances().items() + self._build_instance_from_object(converter_id=name, converter_obj=obj) + for name, obj in self._registry.get_all_instances().items() ] return ConverterInstanceListResponse(items=items) @@ -139,9 +140,8 @@ async def create_converter_async(self, *, request: CreateConverterRequest) -> Cr return CreateConverterResponse( converter_id=converter_id, - type=request.type, + converter_type=request.type, display_name=request.display_name, - params=request.params, ) async def preview_conversion_async(self, *, request: ConverterPreviewRequest) -> ConverterPreviewResponse: diff --git a/pyrit/backend/services/target_service.py b/pyrit/backend/services/target_service.py index 1e662d0eec..159c7e7ce2 100644 --- a/pyrit/backend/services/target_service.py +++ b/pyrit/backend/services/target_service.py @@ -12,19 +12,17 @@ - Retrieved from registry (pre-registered at startup or created earlier) """ -import uuid from functools import lru_cache from typing import Any, List, Optional from pyrit import prompt_target -from pyrit.backend.models.common import PaginationInfo, filter_sensitive_fields +from pyrit.backend.mappers.target_mappers import target_object_to_instance +from pyrit.backend.models.common import PaginationInfo from pyrit.backend.models.targets import ( CreateTargetRequest, - CreateTargetResponse, TargetInstance, TargetListResponse, ) -from pyrit.backend.mappers.target_mappers import target_object_to_instance from pyrit.prompt_target import PromptTarget from pyrit.registry.instance_registries import TargetRegistry @@ -84,14 +82,14 @@ def _get_target_class(self, *, target_type: str) -> type: ) return cls - def _build_instance_from_object(self, *, target_id: str, target_obj: Any) -> TargetInstance: + def _build_instance_from_object(self, *, target_unique_name: str, target_obj: Any) -> TargetInstance: """ Build a TargetInstance from a registry object. Returns: TargetInstance with metadata derived from the object. """ - return target_object_to_instance(target_id, target_obj) + return target_object_to_instance(target_unique_name, target_obj) async def list_targets_async( self, @@ -104,25 +102,24 @@ async def list_targets_async( Args: limit: Maximum items to return. - cursor: Pagination cursor (target_id to start after). + cursor: Pagination cursor (target_unique_name to start after). Returns: TargetListResponse containing paginated targets. """ items = [ - self._build_instance_from_object(target_id=name, target_obj=obj) for name, obj in self._registry.get_all_instances().items() + self._build_instance_from_object(target_unique_name=name, target_obj=obj) + for name, obj in self._registry.get_all_instances().items() ] page, has_more = self._paginate(items, cursor, limit) - next_cursor = page[-1].target_id if has_more and page else None + next_cursor = page[-1].target_unique_name if has_more and page else None return TargetListResponse( items=page, pagination=PaginationInfo(limit=limit, has_more=has_more, next_cursor=next_cursor, prev_cursor=cursor), ) @staticmethod - def _paginate( - items: List[TargetInstance], cursor: Optional[str], limit: int - ) -> tuple[List[TargetInstance], bool]: + def _paginate(items: List[TargetInstance], cursor: Optional[str], limit: int) -> tuple[List[TargetInstance], bool]: """ Apply cursor-based pagination. @@ -132,7 +129,7 @@ def _paginate( start_idx = 0 if cursor: for i, item in enumerate(items): - if item.target_id == cursor: + if item.target_unique_name == cursor: start_idx = i + 1 break @@ -140,61 +137,51 @@ def _paginate( has_more = len(items) > start_idx + limit return page, has_more - async def get_target_async(self, *, target_id: str) -> Optional[TargetInstance]: + async def get_target_async(self, *, target_unique_name: str) -> Optional[TargetInstance]: """ - Get a target instance by ID. + Get a target instance by unique name. Returns: TargetInstance if found, None otherwise. """ - obj = self._registry.get_instance_by_name(target_id) + obj = self._registry.get_instance_by_name(target_unique_name) if obj is None: return None - return self._build_instance_from_object(target_id=target_id, target_obj=obj) + return self._build_instance_from_object(target_unique_name=target_unique_name, target_obj=obj) - def get_target_object(self, *, target_id: str) -> Optional[Any]: + def get_target_object(self, *, target_unique_name: str) -> Optional[Any]: """ Get the actual target object for use in attacks. Returns: The PromptTarget object if found, None otherwise. """ - return self._registry.get_instance_by_name(target_id) + return self._registry.get_instance_by_name(target_unique_name) - async def create_target_async(self, *, request: CreateTargetRequest) -> CreateTargetResponse: + async def create_target_async(self, *, request: CreateTargetRequest) -> TargetInstance: """ Create a new target instance from API request. Instantiates the target with the given type and params, - then registers it in the registry. + then registers it in the registry under its unique_name. Args: request: The create target request with type and params. Returns: - CreateTargetResponse with the new target's details. + TargetInstance with the new target's details. Raises: ValueError: If the target type is not found. """ - target_id = str(uuid.uuid4()) - - # Instantiate from request params and register + # Instantiate from request params and register (uses unique_name as key by default) target_class = self._get_target_class(target_type=request.type) target_obj = target_class(**request.params) - self._registry.register_instance(target_obj, name=target_id) - - # Build response from the object's identifier - identifier = target_obj.get_identifier() - identifier_dict = identifier.to_dict() if hasattr(identifier, "to_dict") else identifier - filtered_params = filter_sensitive_fields(identifier_dict) - - return CreateTargetResponse( - target_id=target_id, - type=request.type, - display_name=request.display_name, - params=filtered_params, - ) + self._registry.register_instance(target_obj) + + # Build response from the registered instance + target_unique_name = target_obj.get_identifier().unique_name + return self._build_instance_from_object(target_unique_name=target_unique_name, target_obj=target_obj) @lru_cache(maxsize=1) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index f53c4b6cd4..52c8cbf6d8 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -385,6 +385,101 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: ) return labels_subquery + def _get_attack_result_attack_class_condition(self, *, attack_class: str) -> Any: + """ + Azure SQL implementation for filtering AttackResults by attack type. + Uses JSON_VALUE() to match class_name in the attack_identifier JSON column. + + Args: + attack_class (str): Exact attack class name to match. + + Returns: + Any: SQLAlchemy text condition with bound parameter. + """ + return text( + """ISJSON("AttackResultEntries".attack_identifier) = 1 + AND JSON_VALUE("AttackResultEntries".attack_identifier, '$.class_name') = :attack_class""" + ).bindparams(attack_class=attack_class) + + def _get_attack_result_converter_condition(self, *, converter_classes: Sequence[str]) -> Any: + """ + Azure SQL implementation for filtering AttackResults by converter classes. + + When converter_classes is empty, matches attacks with no converters. + When non-empty, uses OPENJSON() to check all specified classes are present + (AND logic, case-insensitive). + + Args: + converter_classes (Sequence[str]): List of converter class names. Empty list means no converters. + + Returns: + Any: SQLAlchemy combined condition with bound parameters. + """ + if len(converter_classes) == 0: + # Explicitly "no converters": match attacks where the converter list + # is absent, null, or empty in the stored JSON. + return text( + """("AttackResultEntries".attack_identifier IS NULL + OR "AttackResultEntries".attack_identifier = '{}' + OR JSON_QUERY("AttackResultEntries".attack_identifier, '$.request_converter_identifiers') IS NULL + OR JSON_QUERY("AttackResultEntries".attack_identifier, '$.request_converter_identifiers') = '[]')""" + ) + + conditions = [] + bindparams_dict: dict[str, str] = {} + for i, cls in enumerate(converter_classes): + param_name = f"conv_cls_{i}" + conditions.append( + f'EXISTS(SELECT 1 FROM OPENJSON(JSON_QUERY("AttackResultEntries".attack_identifier, ' + f"'$.request_converter_identifiers')) " + f"WHERE LOWER(JSON_VALUE(value, '$.class_name')) = :{param_name})" + ) + bindparams_dict[param_name] = cls.lower() + + combined = " AND ".join(conditions) + return text(f"""ISJSON("AttackResultEntries".attack_identifier) = 1 AND {combined}""").bindparams( + **bindparams_dict + ) + + def get_unique_attack_class_names(self) -> list[str]: + """ + Azure SQL implementation: extract unique class_name values from attack_identifier JSON. + + Returns: + Sorted list of unique attack class name strings. + """ + with closing(self.get_session()) as session: + rows = session.execute( + text( + """SELECT DISTINCT JSON_VALUE(attack_identifier, '$.class_name') AS cls + FROM "AttackResultEntries" + WHERE ISJSON(attack_identifier) = 1 + AND JSON_VALUE(attack_identifier, '$.class_name') IS NOT NULL""" + ) + ).fetchall() + return sorted(row[0] for row in rows) + + def get_unique_converter_class_names(self) -> list[str]: + """ + Azure SQL implementation: extract unique converter class_name values + from the request_converter_identifiers array in attack_identifier JSON. + + Returns: + Sorted list of unique converter class name strings. + """ + with closing(self.get_session()) as session: + rows = session.execute( + text( + """SELECT DISTINCT JSON_VALUE(c.value, '$.class_name') AS cls + FROM "AttackResultEntries" + CROSS APPLY OPENJSON(JSON_QUERY(attack_identifier, + '$.request_converter_identifiers')) AS c + WHERE ISJSON(attack_identifier) = 1 + AND JSON_VALUE(c.value, '$.class_name') IS NOT NULL""" + ) + ).fetchall() + return sorted(row[0] for row in rows) + def _get_scenario_result_label_condition(self, *, labels: dict[str, str]) -> Any: """ Get the SQL Azure implementation for filtering ScenarioResults by labels. diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 0f0d2d8847..8a09b63b80 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -286,6 +286,64 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: Database-specific SQLAlchemy condition. """ + @abc.abstractmethod + def _get_attack_result_attack_class_condition(self, *, attack_class: str) -> Any: + """ + Return a database-specific condition for filtering AttackResults by attack type + (class_name in the attack_identifier JSON column). + + Args: + attack_class: Exact attack class name to match. + + Returns: + Database-specific SQLAlchemy condition. + """ + + @abc.abstractmethod + def _get_attack_result_converter_condition(self, *, converter_classes: Sequence[str]) -> Any: + """ + Return a database-specific condition for filtering AttackResults by converter classes + in the request_converter_identifiers array within attack_identifier JSON column. + + This method is only called when converter filtering is requested (converter_classes + is not None). The caller handles the None-vs-list distinction: + + - ``len(converter_classes) == 0``: return a condition matching attacks with NO converters. + - ``len(converter_classes) > 0``: return a condition requiring ALL specified converter + class names to be present (AND logic, case-insensitive). + + Args: + converter_classes: Converter class names to require. An empty sequence means + "match only attacks that have no converters". + + Returns: + Database-specific SQLAlchemy condition. + """ + + @abc.abstractmethod + def get_unique_attack_class_names(self) -> list[str]: + """ + Return sorted unique attack class names from all stored attack results. + + Extracts class_name from the attack_identifier JSON column via a + database-level DISTINCT query. + + Returns: + Sorted list of unique attack class name strings. + """ + + @abc.abstractmethod + def get_unique_converter_class_names(self) -> list[str]: + """ + Return sorted unique converter class names used across all attack results. + + Extracts class_name values from the request_converter_identifiers array + within the attack_identifier JSON column via a database-level query. + + Returns: + Sorted list of unique converter class name strings. + """ + @abc.abstractmethod def _get_scenario_result_label_condition(self, *, labels: dict[str, str]) -> Any: """ @@ -1209,6 +1267,8 @@ def get_attack_results( objective: Optional[str] = None, objective_sha256: Optional[Sequence[str]] = None, outcome: Optional[str] = None, + attack_class: Optional[str] = None, + converter_classes: Optional[Sequence[str]] = None, targeted_harm_categories: Optional[Sequence[str]] = None, labels: Optional[dict[str, str]] = None, ) -> Sequence[AttackResult]: @@ -1223,6 +1283,11 @@ def get_attack_results( Defaults to None. outcome (Optional[str], optional): The outcome to filter by (success, failure, undetermined). Defaults to None. + attack_class (Optional[str], optional): Filter by exact attack class_name in attack_identifier. + Defaults to None. + converter_classes (Optional[Sequence[str]], optional): Filter by converter class names. + Returns only attacks that used ALL specified converters (AND logic, case-insensitive). + Defaults to None. targeted_harm_categories (Optional[Sequence[str]], optional): A list of targeted harm categories to filter results by. These targeted harm categories are associated with the prompts themselves, @@ -1254,6 +1319,15 @@ def get_attack_results( if outcome: conditions.append(AttackResultEntry.outcome == outcome) + if attack_class: + # Use database-specific JSON query method + conditions.append(self._get_attack_result_attack_class_condition(attack_class=attack_class)) + + if converter_classes is not None: + # converter_classes=[] means "only attacks with no converters" + # converter_classes=["A","B"] means "must have all listed converters" + conditions.append(self._get_attack_result_converter_condition(converter_classes=converter_classes)) + if targeted_harm_categories: # Use database-specific JSON query method conditions.append( diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index bca6a21817..7b038bd1a9 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -8,7 +8,7 @@ from pathlib import Path from typing import Any, MutableSequence, Optional, Sequence, TypeVar, Union -from sqlalchemy import create_engine, text +from sqlalchemy import and_, create_engine, func, or_, text from sqlalchemy.engine.base import Engine from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import joinedload, sessionmaker @@ -504,6 +504,87 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: ) return labels_subquery + def _get_attack_result_attack_class_condition(self, *, attack_class: str) -> Any: + """ + SQLite implementation for filtering AttackResults by attack type. + Uses json_extract() to match class_name in the attack_identifier JSON column. + + Returns: + Any: A SQLAlchemy condition for filtering by attack type. + """ + return func.json_extract(AttackResultEntry.attack_identifier, "$.class_name") == attack_class + + def _get_attack_result_converter_condition(self, *, converter_classes: Sequence[str]) -> Any: + """ + SQLite implementation for filtering AttackResults by converter classes. + + When converter_classes is empty, matches attacks with no converters + (request_converter_identifiers is absent or null in the JSON). + When non-empty, uses json_each() to check all specified classes are present + (AND logic, case-insensitive). + + Returns: + Any: A SQLAlchemy condition for filtering by converter classes. + """ + if len(converter_classes) == 0: + # Explicitly "no converters": match attacks where the converter list + # is absent, null, or empty in the stored JSON. + converter_json = func.json_extract(AttackResultEntry.attack_identifier, "$.request_converter_identifiers") + return or_( + AttackResultEntry.attack_identifier.is_(None), + AttackResultEntry.attack_identifier == "{}", + converter_json.is_(None), + converter_json == "[]", + ) + + conditions = [] + for i, cls in enumerate(converter_classes): + param_name = f"conv_cls_{i}" + conditions.append( + text( + f"""EXISTS(SELECT 1 FROM json_each( + json_extract("AttackResultEntries".attack_identifier, '$.request_converter_identifiers')) + WHERE LOWER(json_extract(value, '$.class_name')) = :{param_name})""" + ).bindparams(**{param_name: cls.lower()}) + ) + return and_(*conditions) + + def get_unique_attack_class_names(self) -> list[str]: + """ + SQLite implementation: extract unique class_name values from attack_identifier JSON. + + Returns: + Sorted list of unique attack class name strings. + """ + with closing(self.get_session()) as session: + rows = ( + session.query(func.json_extract(AttackResultEntry.attack_identifier, "$.class_name")) + .filter(func.json_extract(AttackResultEntry.attack_identifier, "$.class_name").isnot(None)) + .distinct() + .all() + ) + return sorted(row[0] for row in rows) + + def get_unique_converter_class_names(self) -> list[str]: + """ + SQLite implementation: extract unique converter class_name values + from the request_converter_identifiers array in attack_identifier JSON. + + Returns: + Sorted list of unique converter class name strings. + """ + with closing(self.get_session()) as session: + rows = session.execute( + text( + """SELECT DISTINCT json_extract(j.value, '$.class_name') AS cls + FROM "AttackResultEntries", + json_each(json_extract("AttackResultEntries".attack_identifier, + '$.request_converter_identifiers')) AS j + WHERE cls IS NOT NULL""" + ) + ).fetchall() + return sorted(row[0] for row in rows) + def _get_scenario_result_label_condition(self, *, labels: dict[str, str]) -> Any: """ SQLite implementation for filtering ScenarioResults by labels. diff --git a/tests/unit/backend/test_api_routes.py b/tests/unit/backend/test_api_routes.py index 6c057b522f..44ad3f34bb 100644 --- a/tests/unit/backend/test_api_routes.py +++ b/tests/unit/backend/test_api_routes.py @@ -34,7 +34,6 @@ PreviewStep, ) from pyrit.backend.models.targets import ( - CreateTargetResponse, TargetInstance, TargetListResponse, ) @@ -87,14 +86,14 @@ def test_list_attacks_with_filters(self, client: TestClient) -> None: response = client.get( "/api/attacks", - params={"target_id": "t1", "outcome": "success", "limit": 10}, + params={"attack_class": "CrescendoAttack", "outcome": "success", "limit": 10}, ) assert response.status_code == status.HTTP_200_OK mock_service.list_attacks_async.assert_called_once_with( - target_id="t1", + attack_class="CrescendoAttack", + converter_classes=None, outcome="success", - name=None, labels=None, min_turns=None, max_turns=None, @@ -110,7 +109,7 @@ def test_create_attack_success(self, client: TestClient) -> None: mock_service = MagicMock() mock_service.create_attack_async = AsyncMock( return_value=CreateAttackResponse( - attack_id="attack-1", + conversation_id="attack-1", created_at=now, ) ) @@ -118,12 +117,12 @@ def test_create_attack_success(self, client: TestClient) -> None: response = client.post( "/api/attacks", - json={"target_id": "target-1", "name": "Test Attack"}, + json={"target_unique_name": "target-1", "name": "Test Attack"}, ) assert response.status_code == status.HTTP_201_CREATED data = response.json() - assert data["attack_id"] == "attack-1" + assert data["conversation_id"] == "attack-1" def test_create_attack_target_not_found(self, client: TestClient) -> None: """Test attack creation with non-existent target.""" @@ -134,7 +133,7 @@ def test_create_attack_target_not_found(self, client: TestClient) -> None: response = client.post( "/api/attacks", - json={"target_id": "nonexistent"}, + json={"target_unique_name": "nonexistent"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND @@ -147,10 +146,8 @@ def test_get_attack_success(self, client: TestClient) -> None: mock_service = MagicMock() mock_service.get_attack_async = AsyncMock( return_value=AttackSummary( - attack_id="attack-1", - name="Test", - target_id="target-1", - target_type="TextTarget", + conversation_id="attack-1", + attack_type="TestAttack", outcome=None, last_message_preview=None, message_count=0, @@ -164,7 +161,7 @@ def test_get_attack_success(self, client: TestClient) -> None: assert response.status_code == status.HTTP_200_OK data = response.json() - assert data["attack_id"] == "attack-1" + assert data["conversation_id"] == "attack-1" def test_get_attack_not_found(self, client: TestClient) -> None: """Test getting a non-existent attack.""" @@ -185,10 +182,8 @@ def test_update_attack_success(self, client: TestClient) -> None: mock_service = MagicMock() mock_service.update_attack_async = AsyncMock( return_value=AttackSummary( - attack_id="attack-1", - name=None, - target_id="target-1", - target_type="TextTarget", + conversation_id="attack-1", + attack_type="TestAttack", outcome="success", last_message_preview=None, message_count=0, @@ -212,10 +207,8 @@ def test_add_message_success(self, client: TestClient) -> None: now = datetime.now(timezone.utc) attack_summary = AttackSummary( - attack_id="attack-1", - name=None, - target_id="target-1", - target_type="TextTarget", + conversation_id="attack-1", + attack_type="TestAttack", outcome=None, last_message_preview=None, message_count=2, @@ -224,10 +217,9 @@ def test_add_message_success(self, client: TestClient) -> None: ) attack_messages = AttackMessagesResponse( - attack_id="attack-1", + conversation_id="attack-1", messages=[ Message( - message_id="msg-1", turn_number=1, role="user", pieces=[ @@ -239,7 +231,6 @@ def test_add_message_success(self, client: TestClient) -> None: created_at=now, ), Message( - message_id="msg-2", turn_number=2, role="assistant", pieces=[ @@ -350,10 +341,9 @@ def test_get_attack_messages_success(self, client: TestClient) -> None: mock_service = MagicMock() mock_service.get_attack_messages_async = AsyncMock( return_value=AttackMessagesResponse( - attack_id="attack-1", + conversation_id="attack-1", messages=[ Message( - message_id="msg-1", turn_number=1, role="user", pieces=[MessagePiece(piece_id="p1", converted_value="Hello")], @@ -368,7 +358,7 @@ def test_get_attack_messages_success(self, client: TestClient) -> None: assert response.status_code == status.HTTP_200_OK data = response.json() - assert data["attack_id"] == "attack-1" + assert data["conversation_id"] == "attack-1" assert len(data["messages"]) == 1 def test_get_attack_messages_not_found(self, client: TestClient) -> None: @@ -392,10 +382,8 @@ def test_list_attacks_with_labels(self, client: TestClient) -> None: return_value=AttackListResponse( items=[ AttackSummary( - attack_id="attack-1", - name=None, - target_id="target-1", - target_type="TextTarget", + conversation_id="attack-1", + attack_type="TestAttack", outcome=None, last_message_preview=None, message_count=0, @@ -417,6 +405,105 @@ def test_list_attacks_with_labels(self, client: TestClient) -> None: call_kwargs = mock_service.list_attacks_async.call_args[1] assert call_kwargs["labels"] == {"env": "prod", "team": "red"} + def test_get_attack_options(self, client: TestClient) -> None: + """Test getting attack type options from attack results.""" + with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: + mock_service = MagicMock() + mock_service.get_attack_options_async = AsyncMock(return_value=["CrescendoAttack", "ManualAttack"]) + mock_get_service.return_value = mock_service + + response = client.get("/api/attacks/attack-options") + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["attack_classes"] == ["CrescendoAttack", "ManualAttack"] + + def test_get_converter_options(self, client: TestClient) -> None: + """Test getting converter options from attack results.""" + with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: + mock_service = MagicMock() + mock_service.get_converter_options_async = AsyncMock(return_value=["Base64Converter", "ROT13Converter"]) + mock_get_service.return_value = mock_service + + response = client.get("/api/attacks/converter-options") + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["converter_classes"] == ["Base64Converter", "ROT13Converter"] + + def test_parse_labels_skips_param_without_colon(self, client: TestClient) -> None: + """Test that _parse_labels skips label params that have no colon.""" + with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: + mock_service = MagicMock() + mock_service.list_attacks_async = AsyncMock( + return_value=AttackListResponse( + items=[], + pagination=PaginationInfo(limit=20, has_more=False, next_cursor=None, prev_cursor=None), + ) + ) + mock_get_service.return_value = mock_service + + response = client.get("/api/attacks?label=nocolon&label=env:prod") + + assert response.status_code == status.HTTP_200_OK + call_kwargs = mock_service.list_attacks_async.call_args[1] + # Only the valid label should be parsed + assert call_kwargs["labels"] == {"env": "prod"} + + def test_parse_labels_all_invalid_returns_none(self, client: TestClient) -> None: + """Test that _parse_labels returns None when all params lack colons.""" + with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: + mock_service = MagicMock() + mock_service.list_attacks_async = AsyncMock( + return_value=AttackListResponse( + items=[], + pagination=PaginationInfo(limit=20, has_more=False, next_cursor=None, prev_cursor=None), + ) + ) + mock_get_service.return_value = mock_service + + response = client.get("/api/attacks?label=nocolon&label=alsonocolon") + + assert response.status_code == status.HTTP_200_OK + call_kwargs = mock_service.list_attacks_async.call_args[1] + assert call_kwargs["labels"] is None + + def test_parse_labels_value_with_extra_colons(self, client: TestClient) -> None: + """Test that _parse_labels handles values containing colons (split on first only).""" + with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: + mock_service = MagicMock() + mock_service.list_attacks_async = AsyncMock( + return_value=AttackListResponse( + items=[], + pagination=PaginationInfo(limit=20, has_more=False, next_cursor=None, prev_cursor=None), + ) + ) + mock_get_service.return_value = mock_service + + response = client.get("/api/attacks?label=url:http://example.com:8080") + + assert response.status_code == status.HTTP_200_OK + call_kwargs = mock_service.list_attacks_async.call_args[1] + assert call_kwargs["labels"] == {"url": "http://example.com:8080"} + + def test_list_attacks_forwards_converter_classes_param(self, client: TestClient) -> None: + """Test that converter_classes query params are forwarded to service.""" + with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: + mock_service = MagicMock() + mock_service.list_attacks_async = AsyncMock( + return_value=AttackListResponse( + items=[], + pagination=PaginationInfo(limit=20, has_more=False, next_cursor=None, prev_cursor=None), + ) + ) + mock_get_service.return_value = mock_service + + response = client.get("/api/attacks?converter_classes=Base64&converter_classes=ROT13") + + assert response.status_code == status.HTTP_200_OK + call_kwargs = mock_service.list_attacks_async.call_args[1] + assert call_kwargs["converter_classes"] == ["Base64", "ROT13"] + # ============================================================================ # Target Routes Tests @@ -450,23 +537,21 @@ def test_create_target_success(self, client: TestClient) -> None: with patch("pyrit.backend.routes.targets.get_target_service") as mock_get_service: mock_service = MagicMock() mock_service.create_target_async = AsyncMock( - return_value=CreateTargetResponse( - target_id="target-1", - type="TextTarget", - display_name="My Target", - params={}, + return_value=TargetInstance( + target_unique_name="target-1", + target_type="TextTarget", ) ) mock_get_service.return_value = mock_service response = client.post( "/api/targets", - json={"type": "TextTarget", "display_name": "My Target", "params": {}}, + json={"type": "TextTarget", "params": {}}, ) assert response.status_code == status.HTTP_201_CREATED data = response.json() - assert data["target_id"] == "target-1" + assert data["target_unique_name"] == "target-1" def test_create_target_invalid_type(self, client: TestClient) -> None: """Test target creation with invalid type.""" @@ -502,10 +587,8 @@ def test_get_target_success(self, client: TestClient) -> None: mock_service = MagicMock() mock_service.get_target_async = AsyncMock( return_value=TargetInstance( - target_id="target-1", - type="TextTarget", - display_name=None, - params={}, + target_unique_name="target-1", + target_type="TextTarget", ) ) mock_get_service.return_value = mock_service @@ -514,7 +597,7 @@ def test_get_target_success(self, client: TestClient) -> None: assert response.status_code == status.HTTP_200_OK data = response.json() - assert data["target_id"] == "target-1" + assert data["target_unique_name"] == "target-1" def test_get_target_not_found(self, client: TestClient) -> None: """Test getting a non-existent target.""" @@ -556,9 +639,8 @@ def test_create_converter_success(self, client: TestClient) -> None: mock_service.create_converter_async = AsyncMock( return_value=CreateConverterResponse( converter_id="conv-1", - type="Base64Converter", + converter_type="Base64Converter", display_name="My Base64", - params={}, ) ) mock_get_service.return_value = mock_service @@ -607,9 +689,8 @@ def test_get_converter_success(self, client: TestClient) -> None: mock_service.get_converter_async = AsyncMock( return_value=ConverterInstance( converter_id="conv-1", - type="Base64Converter", + converter_type="Base64Converter", display_name=None, - params={}, ) ) mock_get_service.return_value = mock_service diff --git a/tests/unit/backend/test_attack_service.py b/tests/unit/backend/test_attack_service.py index b1371e2ab9..0dbf1da9d6 100644 --- a/tests/unit/backend/test_attack_service.py +++ b/tests/unit/backend/test_attack_service.py @@ -23,6 +23,7 @@ AttackService, get_attack_service, ) +from pyrit.identifiers import AttackIdentifier, TargetIdentifier from pyrit.models import AttackOutcome, AttackResult @@ -49,8 +50,7 @@ def make_attack_result( *, conversation_id: str = "attack-1", objective: str = "Test objective", - target_id: str = "target-1", - target_type: str = "TextTarget", + has_target: bool = True, name: str = "Test Attack", outcome: AttackOutcome = AttackOutcome.UNDETERMINED, created_at: datetime = None, @@ -60,15 +60,24 @@ def make_attack_result( now = datetime.now(timezone.utc) created = created_at or now updated = updated_at or now + + target_identifier = ( + TargetIdentifier( + class_name="TextTarget", + class_module="pyrit.prompt_target", + ) + if has_target + else None + ) + return AttackResult( conversation_id=conversation_id, objective=objective, - attack_identifier={ - "name": name, - "target_id": target_id, - "target_type": target_type, - "source": "gui", - }, + attack_identifier=AttackIdentifier( + class_name=name, + class_module="pyrit.backend", + objective_target_identifier=target_identifier, + ), outcome=outcome, metadata={ "created_at": created.isoformat(), @@ -159,34 +168,80 @@ async def test_list_attacks_returns_attacks(self, attack_service, mock_memory) - result = await attack_service.list_attacks_async() assert len(result.items) == 1 - assert result.items[0].attack_id == "attack-1" - assert result.items[0].target_id == "target-1" + assert result.items[0].conversation_id == "attack-1" + assert result.items[0].attack_type == "Test Attack" @pytest.mark.asyncio - async def test_list_attacks_filters_by_target_id(self, attack_service, mock_memory) -> None: - """Test that list_attacks filters by target_id.""" - ar1 = make_attack_result(conversation_id="attack-1", target_id="target-1") - ar2 = make_attack_result(conversation_id="attack-2", target_id="target-2") - mock_memory.get_attack_results.return_value = [ar1, ar2] + async def test_list_attacks_filters_by_attack_class_exact(self, attack_service, mock_memory) -> None: + """Test that list_attacks passes attack_class to memory layer.""" + ar1 = make_attack_result(conversation_id="attack-1", name="CrescendoAttack") + mock_memory.get_attack_results.return_value = [ar1] mock_memory.get_message_pieces.return_value = [] - result = await attack_service.list_attacks_async(target_id="target-1") + result = await attack_service.list_attacks_async(attack_class="CrescendoAttack") assert len(result.items) == 1 - assert result.items[0].target_id == "target-1" + assert result.items[0].conversation_id == "attack-1" + # Verify attack_class was forwarded to the memory layer + call_kwargs = mock_memory.get_attack_results.call_args[1] + assert call_kwargs["attack_class"] == "CrescendoAttack" @pytest.mark.asyncio - async def test_list_attacks_filters_by_name(self, attack_service, mock_memory) -> None: - """Test that list_attacks filters by name substring (case-insensitive).""" - ar1 = make_attack_result(conversation_id="attack-1", name="Test Attack") - ar2 = make_attack_result(conversation_id="attack-2", name="Other") - mock_memory.get_attack_results.return_value = [ar1, ar2] + async def test_list_attacks_attack_class_passed_to_memory(self, attack_service, mock_memory) -> None: + """Test that attack_class is forwarded to memory for DB-level filtering.""" + mock_memory.get_attack_results.return_value = [] + mock_memory.get_message_pieces.return_value = [] + + await attack_service.list_attacks_async(attack_class="Crescendo") + + call_kwargs = mock_memory.get_attack_results.call_args[1] + assert call_kwargs["attack_class"] == "Crescendo" + + @pytest.mark.asyncio + async def test_list_attacks_filters_by_no_converters(self, attack_service, mock_memory) -> None: + """Test that converter_classes=[] is forwarded to memory for DB-level filtering.""" + mock_memory.get_attack_results.return_value = [] mock_memory.get_message_pieces.return_value = [] - result = await attack_service.list_attacks_async(name="test") + await attack_service.list_attacks_async(converter_classes=[]) + + call_kwargs = mock_memory.get_attack_results.call_args[1] + assert call_kwargs["converter_classes"] == [] + + @pytest.mark.asyncio + async def test_list_attacks_filters_by_converter_classes_and_logic(self, attack_service, mock_memory) -> None: + """Test that list_attacks passes converter_classes to memory layer.""" + from pyrit.identifiers import ConverterIdentifier + + ar1 = make_attack_result(conversation_id="attack-1", name="Attack One") + ar1.attack_identifier = AttackIdentifier( + class_name="Attack One", + class_module="pyrit.backend", + request_converter_identifiers=[ + ConverterIdentifier( + class_name="Base64Converter", + class_module="pyrit.converters", + supported_input_types=("text",), + supported_output_types=("text",), + ), + ConverterIdentifier( + class_name="ROT13Converter", + class_module="pyrit.converters", + supported_input_types=("text",), + supported_output_types=("text",), + ), + ], + ) + mock_memory.get_attack_results.return_value = [ar1] + mock_memory.get_message_pieces.return_value = [] + + result = await attack_service.list_attacks_async(converter_classes=["Base64Converter", "ROT13Converter"]) assert len(result.items) == 1 - assert result.items[0].name == "Test Attack" + assert result.items[0].conversation_id == "attack-1" + # Verify converter_classes was forwarded to the memory layer + call_kwargs = mock_memory.get_attack_results.call_args[1] + assert call_kwargs["converter_classes"] == ["Base64Converter", "ROT13Converter"] @pytest.mark.asyncio async def test_list_attacks_filters_by_min_turns(self, attack_service, mock_memory) -> None: @@ -201,7 +256,7 @@ async def test_list_attacks_filters_by_min_turns(self, attack_service, mock_memo result = await attack_service.list_attacks_async(min_turns=3) assert len(result.items) == 1 - assert result.items[0].attack_id == "attack-1" + assert result.items[0].conversation_id == "attack-1" @pytest.mark.asyncio async def test_list_attacks_filters_by_max_turns(self, attack_service, mock_memory) -> None: @@ -216,7 +271,7 @@ async def test_list_attacks_filters_by_max_turns(self, attack_service, mock_memo result = await attack_service.list_attacks_async(max_turns=3) assert len(result.items) == 1 - assert result.items[0].attack_id == "attack-2" + assert result.items[0].conversation_id == "attack-2" @pytest.mark.asyncio async def test_list_attacks_includes_labels_in_summary(self, attack_service, mock_memory) -> None: @@ -234,6 +289,83 @@ async def test_list_attacks_includes_labels_in_summary(self, attack_service, moc assert len(result.items) == 1 assert result.items[0].labels == {"env": "prod", "team": "red"} + @pytest.mark.asyncio + async def test_list_attacks_combined_min_and_max_turns(self, attack_service, mock_memory) -> None: + """Test that list_attacks filters by both min_turns and max_turns together.""" + ar1 = make_attack_result(conversation_id="attack-1") + ar1.executed_turns = 1 + ar2 = make_attack_result(conversation_id="attack-2") + ar2.executed_turns = 3 + ar3 = make_attack_result(conversation_id="attack-3") + ar3.executed_turns = 7 + mock_memory.get_attack_results.return_value = [ar1, ar2, ar3] + mock_memory.get_message_pieces.return_value = [] + + result = await attack_service.list_attacks_async(min_turns=2, max_turns=5) + + assert len(result.items) == 1 + assert result.items[0].conversation_id == "attack-2" + + +# ============================================================================ +# Attack Options Tests +# ============================================================================ + + +@pytest.mark.usefixtures("patch_central_database") +class TestAttackOptions: + """Tests for get_attack_options_async method.""" + + @pytest.mark.asyncio + async def test_returns_empty_when_no_attacks(self, attack_service, mock_memory) -> None: + """Test that attack options returns empty list when no attacks exist.""" + mock_memory.get_unique_attack_class_names.return_value = [] + + result = await attack_service.get_attack_options_async() + + assert result == [] + mock_memory.get_unique_attack_class_names.assert_called_once() + + @pytest.mark.asyncio + async def test_returns_result_from_memory(self, attack_service, mock_memory) -> None: + """Test that attack options delegates to memory layer.""" + mock_memory.get_unique_attack_class_names.return_value = ["CrescendoAttack", "ManualAttack"] + + result = await attack_service.get_attack_options_async() + + assert result == ["CrescendoAttack", "ManualAttack"] + mock_memory.get_unique_attack_class_names.assert_called_once() + + +# ============================================================================ +# Converter Options Tests +# ============================================================================ + + +@pytest.mark.usefixtures("patch_central_database") +class TestConverterOptions: + """Tests for get_converter_options_async method.""" + + @pytest.mark.asyncio + async def test_returns_empty_when_no_attacks(self, attack_service, mock_memory) -> None: + """Test that converter options returns empty list when no attacks exist.""" + mock_memory.get_unique_converter_class_names.return_value = [] + + result = await attack_service.get_converter_options_async() + + assert result == [] + mock_memory.get_unique_converter_class_names.assert_called_once() + + @pytest.mark.asyncio + async def test_returns_result_from_memory(self, attack_service, mock_memory) -> None: + """Test that converter options delegates to memory layer.""" + mock_memory.get_unique_converter_class_names.return_value = ["Base64Converter", "ROT13Converter"] + + result = await attack_service.get_converter_options_async() + + assert result == ["Base64Converter", "ROT13Converter"] + mock_memory.get_unique_converter_class_names.assert_called_once() + # ============================================================================ # Get Attack Tests @@ -249,7 +381,7 @@ async def test_get_attack_returns_none_for_nonexistent(self, attack_service, moc """Test that get_attack returns None when AttackResult doesn't exist.""" mock_memory.get_attack_results.return_value = [] - result = await attack_service.get_attack_async(attack_id="nonexistent") + result = await attack_service.get_attack_async(conversation_id="nonexistent") assert result is None @@ -259,19 +391,15 @@ async def test_get_attack_returns_attack_details(self, attack_service, mock_memo ar = make_attack_result( conversation_id="test-id", name="My Attack", - target_id="target-1", - target_type="TextTarget", ) mock_memory.get_attack_results.return_value = [ar] mock_memory.get_conversation.return_value = [] - result = await attack_service.get_attack_async(attack_id="test-id") + result = await attack_service.get_attack_async(conversation_id="test-id") assert result is not None - assert result.attack_id == "test-id" - assert result.target_id == "target-1" - assert result.target_type == "TextTarget" - assert result.name == "My Attack" + assert result.conversation_id == "test-id" + assert result.attack_type == "My Attack" # ============================================================================ @@ -288,7 +416,7 @@ async def test_get_attack_messages_returns_none_for_nonexistent(self, attack_ser """Test that get_attack_messages returns None when attack doesn't exist.""" mock_memory.get_attack_results.return_value = [] - result = await attack_service.get_attack_messages_async(attack_id="nonexistent") + result = await attack_service.get_attack_messages_async(conversation_id="nonexistent") assert result is None @@ -299,10 +427,10 @@ async def test_get_attack_messages_returns_messages(self, attack_service, mock_m mock_memory.get_attack_results.return_value = [ar] mock_memory.get_conversation.return_value = [] - result = await attack_service.get_attack_messages_async(attack_id="test-id") + result = await attack_service.get_attack_messages_async(conversation_id="test-id") assert result is not None - assert result.attack_id == "test-id" + assert result.conversation_id == "test-id" assert result.messages == [] @@ -324,19 +452,26 @@ async def test_create_attack_validates_target_exists(self, attack_service) -> No mock_get_target_service.return_value = mock_target_service with pytest.raises(ValueError, match="not found"): - await attack_service.create_attack_async(request=CreateAttackRequest(target_id="nonexistent")) + await attack_service.create_attack_async(request=CreateAttackRequest(target_unique_name="nonexistent")) @pytest.mark.asyncio async def test_create_attack_stores_attack_result(self, attack_service, mock_memory) -> None: """Test that create_attack stores AttackResult in memory.""" with patch("pyrit.backend.services.attack_service.get_target_service") as mock_get_target_service: + mock_target_obj = MagicMock() + mock_target_obj.get_identifier.return_value = TargetIdentifier( + class_name="TextTarget", class_module="pyrit.prompt_target" + ) mock_target_service = MagicMock() mock_target_service.get_target_async = AsyncMock(return_value=MagicMock(type="TextTarget")) + mock_target_service.get_target_object.return_value = mock_target_obj mock_get_target_service.return_value = mock_target_service - result = await attack_service.create_attack_async(request=CreateAttackRequest(target_id="target-1", name="My Attack")) + result = await attack_service.create_attack_async( + request=CreateAttackRequest(target_unique_name="target-1", name="My Attack") + ) - assert result.attack_id is not None + assert result.conversation_id is not None assert result.created_at is not None mock_memory.add_attack_results_to_memory.assert_called_once() @@ -344,8 +479,13 @@ async def test_create_attack_stores_attack_result(self, attack_service, mock_mem async def test_create_attack_stores_prepended_conversation(self, attack_service, mock_memory) -> None: """Test that create_attack stores prepended conversation messages.""" with patch("pyrit.backend.services.attack_service.get_target_service") as mock_get_target_service: + mock_target_obj = MagicMock() + mock_target_obj.get_identifier.return_value = TargetIdentifier( + class_name="TextTarget", class_module="pyrit.prompt_target" + ) mock_target_service = MagicMock() mock_target_service.get_target_async = AsyncMock(return_value=MagicMock(type="TextTarget")) + mock_target_service.get_target_object.return_value = mock_target_obj mock_get_target_service.return_value = mock_target_service prepended = [ @@ -356,10 +496,10 @@ async def test_create_attack_stores_prepended_conversation(self, attack_service, ] result = await attack_service.create_attack_async( - request=CreateAttackRequest(target_id="target-1", prepended_conversation=prepended) + request=CreateAttackRequest(target_unique_name="target-1", prepended_conversation=prepended) ) - assert result.attack_id is not None + assert result.conversation_id is not None # Both attack result and prepended message pieces should be stored mock_memory.add_attack_results_to_memory.assert_called_once() mock_memory.add_message_pieces_to_memory.assert_called() @@ -368,13 +508,18 @@ async def test_create_attack_stores_prepended_conversation(self, attack_service, async def test_create_attack_does_not_store_labels_in_metadata(self, attack_service, mock_memory) -> None: """Test that labels are not stored in attack metadata (they live on pieces).""" with patch("pyrit.backend.services.attack_service.get_target_service") as mock_get_target_service: + mock_target_obj = MagicMock() + mock_target_obj.get_identifier.return_value = TargetIdentifier( + class_name="TextTarget", class_module="pyrit.prompt_target" + ) mock_target_service = MagicMock() mock_target_service.get_target_async = AsyncMock(return_value=MagicMock(type="TextTarget")) + mock_target_service.get_target_object.return_value = mock_target_obj mock_get_target_service.return_value = mock_target_service await attack_service.create_attack_async( request=CreateAttackRequest( - target_id="target-1", + target_unique_name="target-1", name="Labeled Attack", labels={"env": "prod", "team": "red"}, ) @@ -388,8 +533,13 @@ async def test_create_attack_does_not_store_labels_in_metadata(self, attack_serv async def test_create_attack_stamps_labels_on_prepended_pieces(self, attack_service, mock_memory) -> None: """Test that labels are forwarded to prepended message pieces.""" with patch("pyrit.backend.services.attack_service.get_target_service") as mock_get_target_service: + mock_target_obj = MagicMock() + mock_target_obj.get_identifier.return_value = TargetIdentifier( + class_name="TextTarget", class_module="pyrit.prompt_target" + ) mock_target_service = MagicMock() mock_target_service.get_target_async = AsyncMock(return_value=MagicMock(type="TextTarget")) + mock_target_service.get_target_object.return_value = mock_target_obj mock_get_target_service.return_value = mock_target_service prepended = [ @@ -401,14 +551,14 @@ async def test_create_attack_stamps_labels_on_prepended_pieces(self, attack_serv await attack_service.create_attack_async( request=CreateAttackRequest( - target_id="target-1", + target_unique_name="target-1", labels={"env": "prod"}, prepended_conversation=prepended, ) ) stored_piece = mock_memory.add_message_pieces_to_memory.call_args[1]["message_pieces"][0] - assert stored_piece.labels == {"env": "prod"} + assert stored_piece.labels == {"env": "prod", "source": "gui"} @pytest.mark.asyncio async def test_create_attack_prepended_messages_have_incrementing_sequences( @@ -416,8 +566,13 @@ async def test_create_attack_prepended_messages_have_incrementing_sequences( ) -> None: """Test that multiple prepended messages get incrementing sequence numbers and preserve lineage.""" with patch("pyrit.backend.services.attack_service.get_target_service") as mock_get_target_service: + mock_target_obj = MagicMock() + mock_target_obj.get_identifier.return_value = TargetIdentifier( + class_name="TextTarget", class_module="pyrit.prompt_target" + ) mock_target_service = MagicMock() mock_target_service.get_target_async = AsyncMock(return_value=MagicMock(type="TextTarget")) + mock_target_service.get_target_object.return_value = mock_target_obj mock_get_target_service.return_value = mock_target_service original_id_1 = "aaaaaaaa-1111-2222-3333-444444444444" @@ -449,7 +604,7 @@ async def test_create_attack_prepended_messages_have_incrementing_sequences( ] await attack_service.create_attack_async( - request=CreateAttackRequest(target_id="target-1", prepended_conversation=prepended) + request=CreateAttackRequest(target_unique_name="target-1", prepended_conversation=prepended) ) # Each message stored separately with incrementing sequence @@ -473,6 +628,57 @@ async def test_create_attack_prepended_messages_have_incrementing_sequences( for piece in stored_pieces: assert piece.id != piece.original_prompt_id + @pytest.mark.asyncio + async def test_create_attack_preserves_user_supplied_source_label(self, attack_service, mock_memory) -> None: + """Test that setdefault does not overwrite user-supplied 'source' label.""" + with patch("pyrit.backend.services.attack_service.get_target_service") as mock_get_target_service: + mock_target_obj = MagicMock() + mock_target_obj.get_identifier.return_value = TargetIdentifier( + class_name="TextTarget", class_module="pyrit.prompt_target" + ) + mock_target_service = MagicMock() + mock_target_service.get_target_async = AsyncMock(return_value=MagicMock(type="TextTarget")) + mock_target_service.get_target_object.return_value = mock_target_obj + mock_get_target_service.return_value = mock_target_service + + prepended = [ + PrependedMessageRequest( + role="system", + pieces=[MessagePieceRequest(original_value="Be helpful.")], + ) + ] + + await attack_service.create_attack_async( + request=CreateAttackRequest( + target_unique_name="target-1", + labels={"source": "api-test"}, + prepended_conversation=prepended, + ) + ) + + stored_piece = mock_memory.add_message_pieces_to_memory.call_args[1]["message_pieces"][0] + assert stored_piece.labels["source"] == "api-test" # not overwritten to "gui" + + @pytest.mark.asyncio + async def test_create_attack_default_name(self, attack_service, mock_memory) -> None: + """Test that request.name=None uses default class_name and objective.""" + with patch("pyrit.backend.services.attack_service.get_target_service") as mock_get_target_service: + mock_target_obj = MagicMock() + mock_target_obj.get_identifier.return_value = TargetIdentifier( + class_name="TextTarget", class_module="pyrit.prompt_target" + ) + mock_target_service = MagicMock() + mock_target_service.get_target_async = AsyncMock(return_value=MagicMock(type="TextTarget")) + mock_target_service.get_target_object.return_value = mock_target_obj + mock_get_target_service.return_value = mock_target_service + + await attack_service.create_attack_async(request=CreateAttackRequest(target_unique_name="target-1")) + + call_args = mock_memory.add_attack_results_to_memory.call_args + stored_ar = call_args[1]["attack_results"][0] + assert stored_ar.objective == "Manual attack via GUI" + assert stored_ar.attack_identifier.class_name == "ManualAttack" + # ============================================================================ # Update Attack Tests @@ -488,21 +694,68 @@ async def test_update_attack_returns_none_for_nonexistent(self, attack_service, """Test that update_attack returns None for nonexistent attack.""" mock_memory.get_attack_results.return_value = [] - result = await attack_service.update_attack_async(attack_id="nonexistent", request=UpdateAttackRequest(outcome="success")) + result = await attack_service.update_attack_async( + conversation_id="nonexistent", request=UpdateAttackRequest(outcome="success") + ) assert result is None @pytest.mark.asyncio - async def test_update_attack_updates_outcome(self, attack_service, mock_memory) -> None: - """Test that update_attack updates the AttackResult outcome.""" + async def test_update_attack_updates_outcome_success(self, attack_service, mock_memory) -> None: + """Test that update_attack maps 'success' to AttackOutcome.SUCCESS.""" ar = make_attack_result(conversation_id="test-id") mock_memory.get_attack_results.return_value = [ar] mock_memory.get_conversation.return_value = [] - await attack_service.update_attack_async(attack_id="test-id", request=UpdateAttackRequest(outcome="success")) + await attack_service.update_attack_async( + conversation_id="test-id", request=UpdateAttackRequest(outcome="success") + ) + + stored_ar = mock_memory.add_attack_results_to_memory.call_args[1]["attack_results"][0] + assert stored_ar.outcome == AttackOutcome.SUCCESS + + @pytest.mark.asyncio + async def test_update_attack_updates_outcome_failure(self, attack_service, mock_memory) -> None: + """Test that update_attack maps 'failure' to AttackOutcome.FAILURE.""" + ar = make_attack_result(conversation_id="test-id") + mock_memory.get_attack_results.return_value = [ar] + mock_memory.get_conversation.return_value = [] + + await attack_service.update_attack_async( + conversation_id="test-id", request=UpdateAttackRequest(outcome="failure") + ) + + stored_ar = mock_memory.add_attack_results_to_memory.call_args[1]["attack_results"][0] + assert stored_ar.outcome == AttackOutcome.FAILURE + + @pytest.mark.asyncio + async def test_update_attack_updates_outcome_undetermined(self, attack_service, mock_memory) -> None: + """Test that update_attack maps 'undetermined' to AttackOutcome.UNDETERMINED.""" + ar = make_attack_result(conversation_id="test-id", outcome=AttackOutcome.SUCCESS) + mock_memory.get_attack_results.return_value = [ar] + mock_memory.get_conversation.return_value = [] - # Should call add_attack_results_to_memory to update - mock_memory.add_attack_results_to_memory.assert_called() + await attack_service.update_attack_async( + conversation_id="test-id", request=UpdateAttackRequest(outcome="undetermined") + ) + + stored_ar = mock_memory.add_attack_results_to_memory.call_args[1]["attack_results"][0] + assert stored_ar.outcome == AttackOutcome.UNDETERMINED + + @pytest.mark.asyncio + async def test_update_attack_refreshes_updated_at(self, attack_service, mock_memory) -> None: + """Test that update_attack refreshes the updated_at metadata.""" + old_time = datetime(2024, 1, 1, tzinfo=timezone.utc) + ar = make_attack_result(conversation_id="test-id", updated_at=old_time) + mock_memory.get_attack_results.return_value = [ar] + mock_memory.get_conversation.return_value = [] + + await attack_service.update_attack_async( + conversation_id="test-id", request=UpdateAttackRequest(outcome="success") + ) + + stored_ar = mock_memory.add_attack_results_to_memory.call_args[1]["attack_results"][0] + assert stored_ar.metadata["updated_at"] != old_time.isoformat() # ============================================================================ @@ -524,12 +777,12 @@ async def test_add_message_raises_for_nonexistent_attack(self, attack_service, m ) with pytest.raises(ValueError, match="not found"): - await attack_service.add_message_async(attack_id="nonexistent", request=request) + await attack_service.add_message_async(conversation_id="nonexistent", request=request) @pytest.mark.asyncio async def test_add_message_without_send_stamps_labels_on_pieces(self, attack_service, mock_memory) -> None: """Test that add_message (send=False) inherits labels from existing pieces.""" - ar = make_attack_result(conversation_id="test-id", target_id="target-1") + ar = make_attack_result(conversation_id="test-id") mock_memory.get_attack_results.return_value = [ar] existing_piece = make_mock_piece(conversation_id="test-id") @@ -543,7 +796,7 @@ async def test_add_message_without_send_stamps_labels_on_pieces(self, attack_ser send=False, ) - result = await attack_service.add_message_async(attack_id="test-id", request=request) + result = await attack_service.add_message_async(conversation_id="test-id", request=request) stored_piece = mock_memory.add_message_pieces_to_memory.call_args[1]["message_pieces"][0] assert stored_piece.labels == {"env": "prod"} @@ -552,7 +805,7 @@ async def test_add_message_without_send_stamps_labels_on_pieces(self, attack_ser @pytest.mark.asyncio async def test_add_message_with_send_passes_labels_to_normalizer(self, attack_service, mock_memory) -> None: """Test that add_message (send=True) inherits labels from existing pieces.""" - ar = make_attack_result(conversation_id="test-id", target_id="target-1") + ar = make_attack_result(conversation_id="test-id") mock_memory.get_attack_results.return_value = [ar] existing_piece = make_mock_piece(conversation_id="test-id") @@ -577,7 +830,7 @@ async def test_add_message_with_send_passes_labels_to_normalizer(self, attack_se send=True, ) - await attack_service.add_message_async(attack_id="test-id", request=request) + await attack_service.add_message_async(conversation_id="test-id", request=request) call_kwargs = mock_normalizer.send_prompt_async.call_args[1] assert call_kwargs["labels"] == {"env": "staging"} @@ -585,8 +838,7 @@ async def test_add_message_with_send_passes_labels_to_normalizer(self, attack_se @pytest.mark.asyncio async def test_add_message_raises_when_no_target_id(self, attack_service, mock_memory) -> None: """Test that add_message raises ValueError when attack has no target configured.""" - ar = make_attack_result(conversation_id="test-id", target_id="") - ar.attack_identifier["target_id"] = "" # Explicitly set to empty + ar = make_attack_result(conversation_id="test-id", has_target=False) mock_memory.get_attack_results.return_value = [ar] request = AddMessageRequest( @@ -594,12 +846,12 @@ async def test_add_message_raises_when_no_target_id(self, attack_service, mock_m ) with pytest.raises(ValueError, match="has no target configured"): - await attack_service.add_message_async(attack_id="test-id", request=request) + await attack_service.add_message_async(conversation_id="test-id", request=request) @pytest.mark.asyncio async def test_add_message_with_send_calls_normalizer(self, attack_service, mock_memory) -> None: """Test that add_message with send=True sends message via normalizer.""" - ar = make_attack_result(conversation_id="test-id", target_id="target-1") + ar = make_attack_result(conversation_id="test-id") mock_memory.get_attack_results.return_value = [ar] mock_memory.get_message_pieces.return_value = [] mock_memory.get_conversation.return_value = [] @@ -621,7 +873,7 @@ async def test_add_message_with_send_calls_normalizer(self, attack_service, mock send=True, ) - result = await attack_service.add_message_async(attack_id="test-id", request=request) + result = await attack_service.add_message_async(conversation_id="test-id", request=request) mock_normalizer.send_prompt_async.assert_called_once() assert result.attack is not None @@ -629,7 +881,7 @@ async def test_add_message_with_send_calls_normalizer(self, attack_service, mock @pytest.mark.asyncio async def test_add_message_with_send_raises_when_target_not_found(self, attack_service, mock_memory) -> None: """Test that add_message with send=True raises when target object not found.""" - ar = make_attack_result(conversation_id="test-id", target_id="target-1") + ar = make_attack_result(conversation_id="test-id") mock_memory.get_attack_results.return_value = [ar] mock_memory.get_message_pieces.return_value = [] @@ -644,12 +896,12 @@ async def test_add_message_with_send_raises_when_target_not_found(self, attack_s ) with pytest.raises(ValueError, match="Target object .* not found"): - await attack_service.add_message_async(attack_id="test-id", request=request) + await attack_service.add_message_async(conversation_id="test-id", request=request) @pytest.mark.asyncio async def test_add_message_with_converter_ids_gets_converters(self, attack_service, mock_memory) -> None: """Test that add_message with converter_ids gets converters from service.""" - ar = make_attack_result(conversation_id="test-id", target_id="target-1") + ar = make_attack_result(conversation_id="test-id") mock_memory.get_attack_results.return_value = [ar] mock_memory.get_message_pieces.return_value = [] mock_memory.get_conversation.return_value = [] @@ -680,14 +932,14 @@ async def test_add_message_with_converter_ids_gets_converters(self, attack_servi converter_ids=["conv-1"], ) - await attack_service.add_message_async(attack_id="test-id", request=request) + await attack_service.add_message_async(conversation_id="test-id", request=request) mock_conv_svc.get_converter_objects_for_ids.assert_called_once_with(converter_ids=["conv-1"]) @pytest.mark.asyncio async def test_add_message_raises_when_attack_not_found_after_update(self, attack_service, mock_memory) -> None: """Test that add_message raises ValueError when attack disappears after update.""" - ar = make_attack_result(conversation_id="test-id", target_id="target-1") + ar = make_attack_result(conversation_id="test-id") mock_memory.get_attack_results.return_value = [ar] mock_memory.get_message_pieces.return_value = [] mock_memory.get_conversation.return_value = [] @@ -700,12 +952,12 @@ async def test_add_message_raises_when_attack_not_found_after_update(self, attac with patch.object(attack_service, "get_attack_async", new=AsyncMock(return_value=None)): with pytest.raises(ValueError, match="not found after update"): - await attack_service.add_message_async(attack_id="test-id", request=request) + await attack_service.add_message_async(conversation_id="test-id", request=request) @pytest.mark.asyncio async def test_add_message_raises_when_messages_not_found_after_update(self, attack_service, mock_memory) -> None: """Test that add_message raises ValueError when messages disappear after update.""" - ar = make_attack_result(conversation_id="test-id", target_id="target-1") + ar = make_attack_result(conversation_id="test-id") mock_memory.get_attack_results.return_value = [ar] mock_memory.get_message_pieces.return_value = [] mock_memory.get_conversation.return_value = [] @@ -721,7 +973,63 @@ async def test_add_message_raises_when_messages_not_found_after_update(self, att patch.object(attack_service, "get_attack_messages_async", new=AsyncMock(return_value=None)), ): with pytest.raises(ValueError, match="messages not found after update"): - await attack_service.add_message_async(attack_id="test-id", request=request) + await attack_service.add_message_async(conversation_id="test-id", request=request) + + @pytest.mark.asyncio + async def test_get_converter_configs_skips_when_preconverted(self, attack_service, mock_memory) -> None: + """Test that _get_converter_configs returns [] when pieces have converted_value set.""" + ar = make_attack_result(conversation_id="test-id") + mock_memory.get_attack_results.return_value = [ar] + mock_memory.get_message_pieces.return_value = [] + mock_memory.get_conversation.return_value = [] + + with ( + patch("pyrit.backend.services.attack_service.get_target_service") as mock_get_target_svc, + patch("pyrit.backend.services.attack_service.get_converter_service") as mock_get_conv_svc, + patch("pyrit.backend.services.attack_service.PromptNormalizer") as mock_normalizer_cls, + ): + mock_target_svc = MagicMock() + mock_target_svc.get_target_object.return_value = MagicMock() + mock_get_target_svc.return_value = mock_target_svc + + mock_normalizer = MagicMock() + mock_normalizer.send_prompt_async = AsyncMock() + mock_normalizer_cls.return_value = mock_normalizer + + request = AddMessageRequest( + pieces=[MessagePieceRequest(original_value="Hello", converted_value="SGVsbG8=")], + send=True, + converter_ids=["conv-1"], + ) + + await attack_service.add_message_async(conversation_id="test-id", request=request) + + # Converter service should NOT be called since pieces are preconverted + mock_get_conv_svc.assert_not_called() + # Normalizer should still be called with empty converter configs + call_kwargs = mock_normalizer.send_prompt_async.call_args[1] + assert call_kwargs["request_converter_configurations"] == [] + + @pytest.mark.asyncio + async def test_add_message_no_existing_pieces_labels_none(self, attack_service, mock_memory) -> None: + """Test that add_message with no existing pieces passes labels=None to storage.""" + ar = make_attack_result(conversation_id="test-id") + mock_memory.get_attack_results.return_value = [ar] + mock_memory.get_message_pieces.return_value = [] # No existing pieces + mock_memory.get_conversation.return_value = [] + + request = AddMessageRequest( + role="user", + pieces=[MessagePieceRequest(original_value="Hello")], + send=False, + ) + + result = await attack_service.add_message_async(conversation_id="test-id", request=request) + + stored_piece = mock_memory.add_message_pieces_to_memory.call_args[1]["message_pieces"][0] + # No labels inherited from existing pieces (no existing pieces had labels) + assert stored_piece.labels is None or stored_piece.labels == {} + assert result.attack is not None # ============================================================================ @@ -782,7 +1090,7 @@ async def test_list_attacks_cursor_skips_to_correct_position(self, attack_servic # Cursor = attack-1 should skip attack-1 and return from attack-2 onward result = await attack_service.list_attacks_async(cursor="attack-1", limit=10) - attack_ids = [item.attack_id for item in result.items] + attack_ids = [item.conversation_id for item in result.items] assert "attack-1" not in attack_ids assert len(result.items) == 2 @@ -798,6 +1106,45 @@ async def test_list_attacks_fetches_pieces_only_for_page(self, attack_service, m # get_message_pieces should be called only for the 2 items on the page, not all 5 assert mock_memory.get_message_pieces.call_count == 2 + @pytest.mark.asyncio + async def test_pagination_cursor_not_found_returns_from_start(self, attack_service, mock_memory) -> None: + """Test that a stale/invalid cursor defaults to returning from the beginning.""" + ar1 = make_attack_result( + conversation_id="attack-1", + updated_at=datetime(2024, 1, 2, tzinfo=timezone.utc), + ) + ar2 = make_attack_result( + conversation_id="attack-2", + updated_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + ) + mock_memory.get_attack_results.return_value = [ar1, ar2] + mock_memory.get_message_pieces.return_value = [] + + result = await attack_service.list_attacks_async(cursor="nonexistent-cursor", limit=10) + + # Should return all items (from beginning) since cursor wasn't found + assert len(result.items) == 2 + + @pytest.mark.asyncio + async def test_pagination_cursor_at_last_item_returns_empty(self, attack_service, mock_memory) -> None: + """Test that cursor pointing to the last item returns empty page with has_more=False.""" + ar1 = make_attack_result( + conversation_id="attack-1", + updated_at=datetime(2024, 1, 2, tzinfo=timezone.utc), + ) + ar2 = make_attack_result( + conversation_id="attack-2", + updated_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + ) + mock_memory.get_attack_results.return_value = [ar1, ar2] + mock_memory.get_message_pieces.return_value = [] + + # Cursor = last sorted item (attack-2 has the oldest updated_at, so it's last) + result = await attack_service.list_attacks_async(cursor="attack-2", limit=10) + + assert len(result.items) == 0 + assert result.pagination.has_more is False + # ============================================================================ # Message Building Tests @@ -832,7 +1179,7 @@ async def test_get_attack_with_messages_translates_correctly(self, attack_servic mock_memory.get_conversation.return_value = [mock_msg] - result = await attack_service.get_attack_messages_async(attack_id="test-id") + result = await attack_service.get_attack_messages_async(conversation_id="test-id") assert result is not None assert len(result.messages) == 1 diff --git a/tests/unit/backend/test_converter_service.py b/tests/unit/backend/test_converter_service.py index 796db1d271..9b8fc2e6d6 100644 --- a/tests/unit/backend/test_converter_service.py +++ b/tests/unit/backend/test_converter_service.py @@ -53,12 +53,11 @@ async def test_list_converters_returns_converters_from_registry(self) -> None: mock_converter = MagicMock() mock_converter.__class__.__name__ = "MockConverter" mock_identifier = MagicMock() - mock_identifier.to_dict.return_value = { - "class_name": "MockConverter", - "converter_specific_params": {"param1": "value1", "param2": 42}, - "supported_input_types": ["text"], - "supported_output_types": ["text"], - } + mock_identifier.class_name = "MockConverter" + mock_identifier.supported_input_types = ("text",) + mock_identifier.supported_output_types = ("text",) + mock_identifier.converter_specific_params = {"param1": "value1", "param2": 42} + mock_identifier.sub_identifier = None mock_converter.get_identifier.return_value = mock_identifier service._registry.register_instance(mock_converter, name="conv-1") @@ -66,12 +65,10 @@ async def test_list_converters_returns_converters_from_registry(self) -> None: assert len(result.items) == 1 assert result.items[0].converter_id == "conv-1" - assert result.items[0].type == "MockConverter" - # Verify params contains the full identifier dict - assert result.items[0].params["class_name"] == "MockConverter" - assert result.items[0].params["converter_specific_params"] == {"param1": "value1", "param2": 42} - assert result.items[0].params["supported_input_types"] == ["text"] - assert result.items[0].params["supported_output_types"] == ["text"] + assert result.items[0].converter_type == "MockConverter" + assert result.items[0].supported_input_types == ["text"] + assert result.items[0].supported_output_types == ["text"] + assert result.items[0].converter_specific_params == {"param1": "value1", "param2": 42} class TestGetConverter: @@ -94,10 +91,11 @@ async def test_get_converter_returns_converter_from_registry(self) -> None: mock_converter = MagicMock() mock_converter.__class__.__name__ = "MockConverter" mock_identifier = MagicMock() - mock_identifier.to_dict.return_value = { - "class_name": "MockConverter", - "converter_specific_params": {"param1": "value1"}, - } + mock_identifier.class_name = "MockConverter" + mock_identifier.supported_input_types = ("text",) + mock_identifier.supported_output_types = ("text",) + mock_identifier.converter_specific_params = {"param1": "value1"} + mock_identifier.sub_identifier = None mock_converter.get_identifier.return_value = mock_identifier service._registry.register_instance(mock_converter, name="conv-1") @@ -105,7 +103,7 @@ async def test_get_converter_returns_converter_from_registry(self) -> None: assert result is not None assert result.converter_id == "conv-1" - assert result.type == "MockConverter" + assert result.converter_type == "MockConverter" class TestGetConverterObject: @@ -160,7 +158,7 @@ async def test_create_converter_success(self) -> None: result = await service.create_converter_async(request=request) assert result.converter_id is not None - assert result.type == "Base64Converter" + assert result.converter_type == "Base64Converter" assert result.display_name == "My Base64" @pytest.mark.asyncio @@ -400,8 +398,8 @@ def test_build_instance_from_converter(self, converter_name: str) -> None: For converters that can be instantiated with no arguments, verifies: - converter_id is set correctly - - type matches the class name - - params contains class_name from the identifier + - converter_type matches the class name + - supported_input_types and supported_output_types are lists For converters requiring arguments, the test is skipped (since we can't know the required parameters without external configuration). @@ -418,11 +416,9 @@ def test_build_instance_from_converter(self, converter_name: str) -> None: # Verify the result assert result.converter_id == "test-id" - assert result.type == converter_name - assert isinstance(result.params, dict) - # The params should contain at least class_name from the identifier - assert "class_name" in result.params - assert result.params["class_name"] == converter_name + assert result.converter_type == converter_name + assert isinstance(result.supported_input_types, list) + assert isinstance(result.supported_output_types, list) class TestConverterParamsExtraction: @@ -439,9 +435,9 @@ def test_caesar_converter_params(self) -> None: service = ConverterService() result = service._build_instance_from_object(converter_id="test-id", converter_obj=converter) - assert result.type == "CaesarConverter" - converter_specific = result.params.get("converter_specific_params", {}) - assert converter_specific.get("caesar_offset") == 13 + assert result.converter_type == "CaesarConverter" + assert result.converter_specific_params is not None + assert result.converter_specific_params.get("caesar_offset") == 13 def test_suffix_append_converter_params(self) -> None: """Test that SuffixAppendConverter params are extracted correctly.""" @@ -449,9 +445,9 @@ def test_suffix_append_converter_params(self) -> None: service = ConverterService() result = service._build_instance_from_object(converter_id="test-id", converter_obj=converter) - assert result.type == "SuffixAppendConverter" - converter_specific = result.params.get("converter_specific_params", {}) - assert converter_specific.get("suffix") == "test suffix" + assert result.converter_type == "SuffixAppendConverter" + assert result.converter_specific_params is not None + assert result.converter_specific_params.get("suffix") == "test suffix" def test_repeat_token_converter_params(self) -> None: """Test that RepeatTokenConverter params are extracted correctly.""" @@ -459,10 +455,10 @@ def test_repeat_token_converter_params(self) -> None: service = ConverterService() result = service._build_instance_from_object(converter_id="test-id", converter_obj=converter) - assert result.type == "RepeatTokenConverter" - converter_specific = result.params.get("converter_specific_params", {}) - assert converter_specific.get("token_to_repeat") == "x" - assert converter_specific.get("times_to_repeat") == 5 + assert result.converter_type == "RepeatTokenConverter" + assert result.converter_specific_params is not None + assert result.converter_specific_params.get("token_to_repeat") == "x" + assert result.converter_specific_params.get("times_to_repeat") == 5 def test_base64_converter_default_params(self) -> None: """Test that Base64Converter default params are captured.""" @@ -470,8 +466,7 @@ def test_base64_converter_default_params(self) -> None: service = ConverterService() result = service._build_instance_from_object(converter_id="test-id", converter_obj=converter) - assert result.type == "Base64Converter" - # Verify params dict is populated from identifier - assert "class_name" in result.params - assert "supported_input_types" in result.params - assert "supported_output_types" in result.params + assert result.converter_type == "Base64Converter" + # Verify type info is populated from identifier + assert isinstance(result.supported_input_types, list) + assert isinstance(result.supported_output_types, list) diff --git a/tests/unit/backend/test_main.py b/tests/unit/backend/test_main.py index f9dd85f1f0..51e935d61e 100644 --- a/tests/unit/backend/test_main.py +++ b/tests/unit/backend/test_main.py @@ -8,7 +8,6 @@ """ import os -import sys from unittest.mock import AsyncMock, MagicMock, patch import pytest diff --git a/tests/unit/backend/test_mappers.py b/tests/unit/backend/test_mappers.py index 51ba3c5e7d..67ce7aa32b 100644 --- a/tests/unit/backend/test_mappers.py +++ b/tests/unit/backend/test_mappers.py @@ -11,8 +11,6 @@ from datetime import datetime, timezone from unittest.mock import MagicMock -import pytest - from pyrit.backend.mappers.attack_mappers import ( _collect_labels_from_pieces, _infer_mime_type, @@ -25,9 +23,9 @@ ) from pyrit.backend.mappers.converter_mappers import converter_object_to_instance from pyrit.backend.mappers.target_mappers import target_object_to_instance +from pyrit.identifiers import AttackIdentifier, ConverterIdentifier, TargetIdentifier from pyrit.models import AttackOutcome, AttackResult - # ============================================================================ # Helpers # ============================================================================ @@ -36,21 +34,33 @@ def _make_attack_result( *, conversation_id: str = "attack-1", - target_id: str = "target-1", - target_type: str = "TextTarget", + has_target: bool = True, name: str = "Test Attack", outcome: AttackOutcome = AttackOutcome.UNDETERMINED, ) -> AttackResult: """Create an AttackResult for mapper tests.""" now = datetime.now(timezone.utc) + + target_identifier = ( + TargetIdentifier( + class_name="TextTarget", + class_module="pyrit.prompt_target", + ) + if has_target + else None + ) + return AttackResult( conversation_id=conversation_id, objective="test", - attack_identifier={ - "name": name, - "target_id": target_id, - "target_type": target_type, - }, + attack_identifier=AttackIdentifier( + class_name=name, + class_module="pyrit.backend", + objective_target_identifier=target_identifier, + attack_specific_params={ + "source": "gui", + }, + ), outcome=outcome, metadata={ "created_at": now.isoformat(), @@ -114,17 +124,18 @@ class TestAttackResultToSummary: def test_basic_mapping(self) -> None: """Test that all fields are mapped correctly.""" - ar = _make_attack_result(name="My Attack", target_id="t-1", target_type="OpenAIChatTarget") + ar = _make_attack_result(name="My Attack") pieces = [_make_mock_piece(sequence=0), _make_mock_piece(sequence=1)] summary = attack_result_to_summary(ar, pieces=pieces) - assert summary.attack_id == ar.conversation_id - assert summary.name == "My Attack" - assert summary.target_id == "t-1" - assert summary.target_type == "OpenAIChatTarget" + assert summary.conversation_id == ar.conversation_id assert summary.outcome == "undetermined" assert summary.message_count == 2 + # Attack metadata should be extracted into explicit fields + assert summary.attack_type == "My Attack" + assert summary.target_type == "TextTarget" + assert summary.target_unique_name is not None def test_empty_pieces_gives_zero_messages(self) -> None: """Test mapping with no message pieces.""" @@ -165,6 +176,63 @@ def test_outcome_success(self) -> None: assert summary.outcome == "success" + def test_no_target_returns_none_fields(self) -> None: + """Test that target fields are None when no target identifier exists.""" + ar = _make_attack_result(has_target=False) + + summary = attack_result_to_summary(ar, pieces=[]) + + assert summary.target_unique_name is None + assert summary.target_type is None + + def test_attack_specific_params_passed_through(self) -> None: + """Test that attack_specific_params are extracted from identifier.""" + ar = _make_attack_result() + + summary = attack_result_to_summary(ar, pieces=[]) + + assert summary.attack_specific_params == {"source": "gui"} + + def test_converters_extracted_from_identifier(self) -> None: + """Test that converter class names are extracted into converters list.""" + now = datetime.now(timezone.utc) + ar = AttackResult( + conversation_id="attack-conv", + objective="test", + attack_identifier=AttackIdentifier( + class_name="TestAttack", + class_module="pyrit.backend", + request_converter_identifiers=[ + ConverterIdentifier( + class_name="Base64Converter", + class_module="pyrit.converters", + supported_input_types=("text",), + supported_output_types=("text",), + ), + ConverterIdentifier( + class_name="ROT13Converter", + class_module="pyrit.converters", + supported_input_types=("text",), + supported_output_types=("text",), + ), + ], + ), + outcome=AttackOutcome.UNDETERMINED, + metadata={"created_at": now.isoformat(), "updated_at": now.isoformat()}, + ) + + summary = attack_result_to_summary(ar, pieces=[]) + + assert summary.converters == ["Base64Converter", "ROT13Converter"] + + def test_no_converters_returns_empty_list(self) -> None: + """Test that converters is empty list when no converters in identifier.""" + ar = _make_attack_result() + + summary = attack_result_to_summary(ar, pieces=[]) + + assert summary.converters == [] + class TestPyritScoresToDto: """Tests for pyrit_scores_to_dto function.""" @@ -521,37 +589,52 @@ class TestTargetObjectToInstance: def test_maps_target_with_identifier(self) -> None: """Test mapping a target object that has get_identifier.""" target_obj = MagicMock() - target_obj.get_identifier.return_value = {"__type__": "OpenAIChatTarget", "endpoint": "http://test"} + mock_identifier = MagicMock() + mock_identifier.class_name = "OpenAIChatTarget" + mock_identifier.endpoint = "http://test" + mock_identifier.model_name = "gpt-4" + mock_identifier.temperature = 0.7 + mock_identifier.top_p = None + mock_identifier.max_requests_per_minute = None + mock_identifier.target_specific_params = None + target_obj.get_identifier.return_value = mock_identifier result = target_object_to_instance("t-1", target_obj) - assert result.target_id == "t-1" - assert result.type == "OpenAIChatTarget" - assert result.display_name is None + assert result.target_unique_name == "t-1" + assert result.target_type == "OpenAIChatTarget" + assert result.endpoint == "http://test" + assert result.model_name == "gpt-4" + assert result.temperature == 0.7 - def test_filters_sensitive_fields(self) -> None: - """Test that sensitive fields are removed from params.""" + def test_no_endpoint_returns_none(self) -> None: + """Test that missing endpoint returns None.""" target_obj = MagicMock() - target_obj.get_identifier.return_value = { - "__type__": "TestTarget", - "api_key": "secret-key", - "endpoint": "http://test", - } + mock_identifier = MagicMock() + mock_identifier.class_name = "TextTarget" + mock_identifier.endpoint = None + mock_identifier.model_name = None + mock_identifier.temperature = None + mock_identifier.top_p = None + mock_identifier.max_requests_per_minute = None + mock_identifier.target_specific_params = None + target_obj.get_identifier.return_value = mock_identifier result = target_object_to_instance("t-1", target_obj) - assert "api_key" not in result.params - assert result.params.get("endpoint") == "http://test" + assert result.target_type == "TextTarget" + assert result.endpoint is None + assert result.model_name is None - def test_fallback_to_class_name(self) -> None: - """Test fallback to __class__.__name__ when no __type__ in identifier.""" - target_obj = MagicMock() - target_obj.__class__.__name__ = "FallbackTarget" - target_obj.get_identifier.return_value = {"endpoint": "http://test"} + def test_no_get_identifier_falls_back_to_class_name(self) -> None: + """Test fallback when target has no get_identifier method.""" + target_obj = MagicMock(spec=[]) + target_obj.__class__ = type("FakeTarget", (), {}) result = target_object_to_instance("t-1", target_obj) - assert result.type == "FallbackTarget" + assert result.target_type == "FakeTarget" + assert result.endpoint is None # ============================================================================ @@ -566,24 +649,51 @@ def test_maps_converter_with_identifier(self) -> None: """Test mapping a converter object.""" converter_obj = MagicMock() identifier = MagicMock() - identifier.to_dict.return_value = {"class_name": "Base64Converter", "param1": "value1"} + identifier.class_name = "Base64Converter" + identifier.supported_input_types = ("text",) + identifier.supported_output_types = ("text",) + identifier.converter_specific_params = {"param1": "value1"} + identifier.sub_identifier = None converter_obj.get_identifier.return_value = identifier result = converter_object_to_instance("c-1", converter_obj) assert result.converter_id == "c-1" - assert result.type == "Base64Converter" + assert result.converter_type == "Base64Converter" assert result.display_name is None - assert result.params["class_name"] == "Base64Converter" + assert result.supported_input_types == ["text"] + assert result.supported_output_types == ["text"] + assert result.converter_specific_params == {"param1": "value1"} + assert result.sub_converter_ids is None + + def test_sub_converter_ids_passed_through(self) -> None: + """Test that sub_converter_ids are passed through when provided.""" + converter_obj = MagicMock() + identifier = MagicMock() + identifier.class_name = "PipelineConverter" + identifier.supported_input_types = ("text",) + identifier.supported_output_types = ("text",) + identifier.converter_specific_params = None + identifier.sub_identifier = None + converter_obj.get_identifier.return_value = identifier + + result = converter_object_to_instance("c-1", converter_obj, sub_converter_ids=["sub-1", "sub-2"]) + + assert result.sub_converter_ids == ["sub-1", "sub-2"] - def test_fallback_to_class_name(self) -> None: - """Test fallback to __class__.__name__ when no class_name in identifier.""" + def test_none_input_output_types_returns_empty_lists(self) -> None: + """Test that None supported types produce empty lists.""" converter_obj = MagicMock() - converter_obj.__class__.__name__ = "FallbackConverter" identifier = MagicMock() - identifier.to_dict.return_value = {"param1": "value1"} + identifier.class_name = "CustomConverter" + identifier.supported_input_types = None + identifier.supported_output_types = None + identifier.converter_specific_params = None converter_obj.get_identifier.return_value = identifier result = converter_object_to_instance("c-1", converter_obj) - assert result.type == "FallbackConverter" + assert result.supported_input_types == [] + assert result.supported_output_types == [] + assert result.converter_specific_params is None + assert result.sub_converter_ids is None diff --git a/tests/unit/backend/test_target_service.py b/tests/unit/backend/test_target_service.py index 4935970a05..34e79c31c5 100644 --- a/tests/unit/backend/test_target_service.py +++ b/tests/unit/backend/test_target_service.py @@ -5,6 +5,7 @@ Tests for backend target service. """ +from types import SimpleNamespace from unittest.mock import MagicMock import pytest @@ -22,6 +23,19 @@ def reset_registry(): TargetRegistry.reset_instance() +def _mock_target_identifier(*, class_name: str = "MockTarget", **kwargs) -> SimpleNamespace: + """Create a mock target identifier with attribute access.""" + return SimpleNamespace( + class_name=class_name, + endpoint=kwargs.get("endpoint"), + model_name=kwargs.get("model_name"), + temperature=kwargs.get("temperature"), + top_p=kwargs.get("top_p"), + max_requests_per_minute=kwargs.get("max_requests_per_minute"), + target_specific_params=kwargs.get("target_specific_params"), + ) + + class TestListTargets: """Tests for TargetService.list_targets method.""" @@ -42,14 +56,14 @@ async def test_list_targets_returns_targets_from_registry(self) -> None: # Register a mock target mock_target = MagicMock() - mock_target.get_identifier.return_value = {"__type__": "MockTarget", "endpoint": "http://test"} + mock_target.get_identifier.return_value = _mock_target_identifier(endpoint="http://test") service._registry.register_instance(mock_target, name="target-1") result = await service.list_targets_async() assert len(result.items) == 1 - assert result.items[0].target_id == "target-1" - assert result.items[0].type == "MockTarget" + assert result.items[0].target_unique_name == "target-1" + assert result.items[0].target_type == "MockTarget" assert result.pagination.has_more is False @pytest.mark.asyncio @@ -59,7 +73,7 @@ async def test_list_targets_paginates_with_limit(self) -> None: for i in range(5): mock_target = MagicMock() - mock_target.get_identifier.return_value = {"__type__": "MockTarget"} + mock_target.get_identifier.return_value = _mock_target_identifier() service._registry.register_instance(mock_target, name=f"target-{i}") result = await service.list_targets_async(limit=3) @@ -67,7 +81,7 @@ async def test_list_targets_paginates_with_limit(self) -> None: assert len(result.items) == 3 assert result.pagination.limit == 3 assert result.pagination.has_more is True - assert result.pagination.next_cursor == result.items[-1].target_id + assert result.pagination.next_cursor == result.items[-1].target_unique_name @pytest.mark.asyncio async def test_list_targets_cursor_returns_next_page(self) -> None: @@ -76,14 +90,14 @@ async def test_list_targets_cursor_returns_next_page(self) -> None: for i in range(5): mock_target = MagicMock() - mock_target.get_identifier.return_value = {"__type__": "MockTarget"} + mock_target.get_identifier.return_value = _mock_target_identifier() service._registry.register_instance(mock_target, name=f"target-{i}") first_page = await service.list_targets_async(limit=2) second_page = await service.list_targets_async(limit=2, cursor=first_page.pagination.next_cursor) assert len(second_page.items) == 2 - assert second_page.items[0].target_id != first_page.items[0].target_id + assert second_page.items[0].target_unique_name != first_page.items[0].target_unique_name assert second_page.pagination.has_more is True @pytest.mark.asyncio @@ -93,7 +107,7 @@ async def test_list_targets_last_page_has_no_more(self) -> None: for i in range(3): mock_target = MagicMock() - mock_target.get_identifier.return_value = {"__type__": "MockTarget"} + mock_target.get_identifier.return_value = _mock_target_identifier() service._registry.register_instance(mock_target, name=f"target-{i}") first_page = await service.list_targets_async(limit=2) @@ -112,7 +126,7 @@ async def test_get_target_returns_none_for_nonexistent(self) -> None: """Test that get_target returns None for non-existent target.""" service = TargetService() - result = await service.get_target_async(target_id="nonexistent-id") + result = await service.get_target_async(target_unique_name="nonexistent-id") assert result is None @@ -122,14 +136,14 @@ async def test_get_target_returns_target_from_registry(self) -> None: service = TargetService() mock_target = MagicMock() - mock_target.get_identifier.return_value = {"__type__": "MockTarget"} + mock_target.get_identifier.return_value = _mock_target_identifier() service._registry.register_instance(mock_target, name="target-1") - result = await service.get_target_async(target_id="target-1") + result = await service.get_target_async(target_unique_name="target-1") assert result is not None - assert result.target_id == "target-1" - assert result.type == "MockTarget" + assert result.target_unique_name == "target-1" + assert result.target_type == "MockTarget" class TestGetTargetObject: @@ -139,7 +153,7 @@ def test_get_target_object_returns_none_for_nonexistent(self) -> None: """Test that get_target_object returns None for non-existent target.""" service = TargetService() - result = service.get_target_object(target_id="nonexistent-id") + result = service.get_target_object(target_unique_name="nonexistent-id") assert result is None @@ -149,7 +163,7 @@ def test_get_target_object_returns_object_from_registry(self) -> None: mock_target = MagicMock() service._registry.register_instance(mock_target, name="target-1") - result = service.get_target_object(target_id="target-1") + result = service.get_target_object(target_unique_name="target-1") assert result is mock_target @@ -177,15 +191,13 @@ async def test_create_target_success(self, sqlite_instance) -> None: request = CreateTargetRequest( type="TextTarget", - display_name="My Text Target", params={}, ) result = await service.create_target_async(request=request) - assert result.target_id is not None - assert result.type == "TextTarget" - assert result.display_name == "My Text Target" + assert result.target_unique_name is not None + assert result.target_type == "TextTarget" @pytest.mark.asyncio async def test_create_target_registers_in_registry(self, sqlite_instance) -> None: @@ -200,7 +212,7 @@ async def test_create_target_registers_in_registry(self, sqlite_instance) -> Non result = await service.create_target_async(request=request) # Object should be retrievable from registry - target_obj = service.get_target_object(target_id=result.target_id) + target_obj = service.get_target_object(target_unique_name=result.target_unique_name) assert target_obj is not None diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index b1bb5ebda8..dcd2151a9d 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -3,10 +3,10 @@ import uuid -from typing import Sequence +from typing import Optional, Sequence from pyrit.common.utils import to_sha256 -from pyrit.identifiers import AttackIdentifier, ScorerIdentifier +from pyrit.identifiers import AttackIdentifier, ConverterIdentifier, ScorerIdentifier from pyrit.memory import MemoryInterface from pyrit.memory.memory_models import AttackResultEntry from pyrit.models import ( @@ -961,3 +961,266 @@ def test_get_unique_attack_labels_keys_sorted(sqlite_instance: MemoryInterface): assert list(result.keys()) == ["alpha", "zoo"] assert result["alpha"] == ["a", "b"] assert result["zoo"] == ["z_val"] + + +def test_get_unique_attack_labels_non_dict_labels_skipped(sqlite_instance: MemoryInterface): + """Labels stored as a non-dict JSON value (e.g. a string) are skipped.""" + from contextlib import closing + + from sqlalchemy import text + + # Insert a real attack + piece with normal labels first + msg1 = create_message_piece("conv_1", 1, labels={"env": "prod"}) + sqlite_instance.add_message_pieces_to_memory(message_pieces=[msg1]) + ar1 = create_attack_result("conv_1", 1) + sqlite_instance.add_attack_results_to_memory(attack_results=[ar1]) + + # Insert a second attack and use raw SQL to set labels to a JSON string + msg2 = create_message_piece("conv_2", 2, labels={"placeholder": "x"}) + sqlite_instance.add_message_pieces_to_memory(message_pieces=[msg2]) + ar2 = create_attack_result("conv_2", 2) + sqlite_instance.add_attack_results_to_memory(attack_results=[ar2]) + with closing(sqlite_instance.get_session()) as session: + session.execute( + text('UPDATE "PromptMemoryEntries" SET labels = \'"just_a_string"\' WHERE conversation_id = :cid'), + {"cid": "conv_2"}, + ) + session.commit() + + result = sqlite_instance.get_unique_attack_labels() + # Only the dict labels from conv_1 should appear + assert result == {"env": ["prod"]} + + +# ============================================================================ +# Attack class and converter class filtering tests +# ============================================================================ + + +def _make_attack_result_with_identifier( + conversation_id: str, + class_name: str, + converter_class_names: Optional[list[str]] = None, +) -> AttackResult: + """Helper to create an AttackResult with an AttackIdentifier containing converters.""" + converter_ids = None + if converter_class_names is not None: + converter_ids = [ + ConverterIdentifier( + class_name=name, + class_module="pyrit.converters", + supported_input_types=("text",), + supported_output_types=("text",), + ) + for name in converter_class_names + ] + + return AttackResult( + conversation_id=conversation_id, + objective=f"Objective for {conversation_id}", + attack_identifier=AttackIdentifier( + class_name=class_name, + class_module="pyrit.attacks", + request_converter_identifiers=converter_ids, + ), + ) + + +def test_get_attack_results_by_attack_class(sqlite_instance: MemoryInterface): + """Test filtering attack results by attack_class matches class_name in JSON.""" + ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack") + ar2 = _make_attack_result_with_identifier("conv_2", "ManualAttack") + ar3 = _make_attack_result_with_identifier("conv_3", "CrescendoAttack") + sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2, ar3]) + + results = sqlite_instance.get_attack_results(attack_class="CrescendoAttack") + assert len(results) == 2 + assert {r.conversation_id for r in results} == {"conv_1", "conv_3"} + + +def test_get_attack_results_by_attack_class_no_match(sqlite_instance: MemoryInterface): + """Test that attack_class filter returns empty when nothing matches.""" + ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack") + sqlite_instance.add_attack_results_to_memory(attack_results=[ar1]) + + results = sqlite_instance.get_attack_results(attack_class="NonExistentAttack") + assert len(results) == 0 + + +def test_get_attack_results_by_attack_class_case_sensitive(sqlite_instance: MemoryInterface): + """Test that attack_class filter is case-sensitive (exact match).""" + ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack") + sqlite_instance.add_attack_results_to_memory(attack_results=[ar1]) + + results = sqlite_instance.get_attack_results(attack_class="crescendoattack") + assert len(results) == 0 + + +def test_get_attack_results_by_attack_class_no_identifier(sqlite_instance: MemoryInterface): + """Test that attacks with no attack_identifier (empty JSON) are excluded by attack_class filter.""" + ar1 = create_attack_result("conv_1", 1) # No attack_identifier → stored as {} + ar2 = _make_attack_result_with_identifier("conv_2", "CrescendoAttack") + sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2]) + + results = sqlite_instance.get_attack_results(attack_class="CrescendoAttack") + assert len(results) == 1 + assert results[0].conversation_id == "conv_2" + + +def test_get_attack_results_converter_classes_none_returns_all(sqlite_instance: MemoryInterface): + """Test that converter_classes=None (omitted) returns all attacks unfiltered.""" + ar1 = _make_attack_result_with_identifier("conv_1", "Attack", ["Base64Converter"]) + ar2 = _make_attack_result_with_identifier("conv_2", "Attack") # No converters (None) + ar3 = create_attack_result("conv_3", 3) # No identifier at all + sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2, ar3]) + + results = sqlite_instance.get_attack_results(converter_classes=None) + assert len(results) == 3 + + +def test_get_attack_results_converter_classes_empty_matches_no_converters(sqlite_instance: MemoryInterface): + """Test that converter_classes=[] returns only attacks with no converters.""" + ar_with_conv = _make_attack_result_with_identifier("conv_1", "Attack", ["Base64Converter"]) + ar_no_conv_none = _make_attack_result_with_identifier("conv_2", "Attack") # converter_ids=None + ar_no_conv_empty = _make_attack_result_with_identifier("conv_3", "Attack", []) # converter_ids=[] + ar_no_identifier = create_attack_result("conv_4", 4) # No identifier → stored as {} + sqlite_instance.add_attack_results_to_memory( + attack_results=[ar_with_conv, ar_no_conv_none, ar_no_conv_empty, ar_no_identifier] + ) + + results = sqlite_instance.get_attack_results(converter_classes=[]) + conv_ids = {r.conversation_id for r in results} + # Should include attacks with no converters (None key, empty array, or empty identifier) + assert "conv_1" not in conv_ids, "Should not include attacks that have converters" + assert "conv_2" in conv_ids, "Should include attacks where converter key is absent (None)" + assert "conv_3" in conv_ids, "Should include attacks with empty converter list" + assert "conv_4" in conv_ids, "Should include attacks with empty attack_identifier" + + +def test_get_attack_results_converter_classes_single_match(sqlite_instance: MemoryInterface): + """Test that converter_classes with one class returns attacks using that converter.""" + ar1 = _make_attack_result_with_identifier("conv_1", "Attack", ["Base64Converter"]) + ar2 = _make_attack_result_with_identifier("conv_2", "Attack", ["ROT13Converter"]) + ar3 = _make_attack_result_with_identifier("conv_3", "Attack", ["Base64Converter", "ROT13Converter"]) + sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2, ar3]) + + results = sqlite_instance.get_attack_results(converter_classes=["Base64Converter"]) + conv_ids = {r.conversation_id for r in results} + assert conv_ids == {"conv_1", "conv_3"} + + +def test_get_attack_results_converter_classes_and_logic(sqlite_instance: MemoryInterface): + """Test that multiple converter_classes use AND logic — all must be present.""" + ar1 = _make_attack_result_with_identifier("conv_1", "Attack", ["Base64Converter"]) + ar2 = _make_attack_result_with_identifier("conv_2", "Attack", ["ROT13Converter"]) + ar3 = _make_attack_result_with_identifier("conv_3", "Attack", ["Base64Converter", "ROT13Converter"]) + ar4 = _make_attack_result_with_identifier("conv_4", "Attack", ["Base64Converter", "ROT13Converter", "UrlConverter"]) + sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2, ar3, ar4]) + + results = sqlite_instance.get_attack_results(converter_classes=["Base64Converter", "ROT13Converter"]) + conv_ids = {r.conversation_id for r in results} + # conv_3 and conv_4 have both; conv_1 and conv_2 have only one + assert conv_ids == {"conv_3", "conv_4"} + + +def test_get_attack_results_converter_classes_case_insensitive(sqlite_instance: MemoryInterface): + """Test that converter class matching is case-insensitive.""" + ar1 = _make_attack_result_with_identifier("conv_1", "Attack", ["Base64Converter"]) + sqlite_instance.add_attack_results_to_memory(attack_results=[ar1]) + + results = sqlite_instance.get_attack_results(converter_classes=["base64converter"]) + assert len(results) == 1 + assert results[0].conversation_id == "conv_1" + + +def test_get_attack_results_converter_classes_no_match(sqlite_instance: MemoryInterface): + """Test that converter_classes filter returns empty when no attack has the converter.""" + ar1 = _make_attack_result_with_identifier("conv_1", "Attack", ["Base64Converter"]) + sqlite_instance.add_attack_results_to_memory(attack_results=[ar1]) + + results = sqlite_instance.get_attack_results(converter_classes=["NonExistentConverter"]) + assert len(results) == 0 + + +def test_get_attack_results_attack_class_and_converter_classes_combined(sqlite_instance: MemoryInterface): + """Test combining attack_class and converter_classes filters.""" + ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack", ["Base64Converter"]) + ar2 = _make_attack_result_with_identifier("conv_2", "ManualAttack", ["Base64Converter"]) + ar3 = _make_attack_result_with_identifier("conv_3", "CrescendoAttack", ["ROT13Converter"]) + ar4 = _make_attack_result_with_identifier("conv_4", "CrescendoAttack") # No converters + sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2, ar3, ar4]) + + results = sqlite_instance.get_attack_results(attack_class="CrescendoAttack", converter_classes=["Base64Converter"]) + assert len(results) == 1 + assert results[0].conversation_id == "conv_1" + + +def test_get_attack_results_attack_class_with_no_converters(sqlite_instance: MemoryInterface): + """Test combining attack_class with converter_classes=[] (no converters).""" + ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack", ["Base64Converter"]) + ar2 = _make_attack_result_with_identifier("conv_2", "CrescendoAttack") # No converters + ar3 = _make_attack_result_with_identifier("conv_3", "ManualAttack") # No converters + sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2, ar3]) + + results = sqlite_instance.get_attack_results(attack_class="CrescendoAttack", converter_classes=[]) + assert len(results) == 1 + assert results[0].conversation_id == "conv_2" + + +# ============================================================================ +# Unique attack class and converter class name tests +# ============================================================================ + + +def test_get_unique_attack_class_names_empty(sqlite_instance: MemoryInterface): + """Test that no attacks returns empty list.""" + result = sqlite_instance.get_unique_attack_class_names() + assert result == [] + + +def test_get_unique_attack_class_names_sorted_unique(sqlite_instance: MemoryInterface): + """Test that unique class names are returned sorted, with duplicates removed.""" + ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack") + ar2 = _make_attack_result_with_identifier("conv_2", "ManualAttack") + ar3 = _make_attack_result_with_identifier("conv_3", "CrescendoAttack") + sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2, ar3]) + + result = sqlite_instance.get_unique_attack_class_names() + assert result == ["CrescendoAttack", "ManualAttack"] + + +def test_get_unique_attack_class_names_skips_empty_identifier(sqlite_instance: MemoryInterface): + """Test that attacks with empty attack_identifier (no class_name) are excluded.""" + ar_no_id = create_attack_result("conv_1", 1) # No attack_identifier → stored as {} + ar_with_id = _make_attack_result_with_identifier("conv_2", "CrescendoAttack") + sqlite_instance.add_attack_results_to_memory(attack_results=[ar_no_id, ar_with_id]) + + result = sqlite_instance.get_unique_attack_class_names() + assert result == ["CrescendoAttack"] + + +def test_get_unique_converter_class_names_empty(sqlite_instance: MemoryInterface): + """Test that no attacks returns empty list.""" + result = sqlite_instance.get_unique_converter_class_names() + assert result == [] + + +def test_get_unique_converter_class_names_sorted_unique(sqlite_instance: MemoryInterface): + """Test that unique converter class names are returned sorted, with duplicates removed.""" + ar1 = _make_attack_result_with_identifier("conv_1", "Attack", ["Base64Converter", "ROT13Converter"]) + ar2 = _make_attack_result_with_identifier("conv_2", "Attack", ["Base64Converter"]) + sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2]) + + result = sqlite_instance.get_unique_converter_class_names() + assert result == ["Base64Converter", "ROT13Converter"] + + +def test_get_unique_converter_class_names_skips_no_converters(sqlite_instance: MemoryInterface): + """Test that attacks with no converters don't contribute names.""" + ar_no_conv = _make_attack_result_with_identifier("conv_1", "Attack") # No converters + ar_with_conv = _make_attack_result_with_identifier("conv_2", "Attack", ["Base64Converter"]) + ar_empty_id = create_attack_result("conv_3", 3) # Empty attack_identifier + sqlite_instance.add_attack_results_to_memory(attack_results=[ar_no_conv, ar_with_conv, ar_empty_id]) + + result = sqlite_instance.get_unique_converter_class_names() + assert result == ["Base64Converter"] From 0b8f148f773b9a3683d2b4c60396ba3254aa42d2 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Fri, 13 Feb 2026 09:37:03 -0800 Subject: [PATCH 31/35] Preserve message roles for frontend and add drift detection tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Widen role fields in DTOs from Literal["user", "assistant", "system"] to ChatMessageRole so simulated_assistant (and other roles like tool, developer) flow through to the frontend without remapping - Switch mapper from deprecated .role property (which collapses simulated_assistant → assistant) to get_role_for_storage() which preserves the actual stored role - Add TestDomainModelFieldsExist: 16 parametrized tests that verify every field the mappers access still exists on AttackIdentifier, TargetIdentifier, and ConverterIdentifier dataclasses - Update mock pieces in test_mappers.py and test_attack_service.py to configure get_role_for_storage() Files changed: pyrit/backend/models/attacks.py pyrit/backend/mappers/attack_mappers.py tests/unit/backend/test_mappers.py tests/unit/backend/test_attack_service.py tests/unit/backend/test_converter_service.py --- pyrit/backend/mappers/attack_mappers.py | 2 +- pyrit/backend/models/attacks.py | 8 +- tests/unit/backend/test_attack_service.py | 2 + tests/unit/backend/test_converter_service.py | 116 +++++++++++++++++-- tests/unit/backend/test_mappers.py | 78 +++++++++++++ 5 files changed, 193 insertions(+), 13 deletions(-) diff --git a/pyrit/backend/mappers/attack_mappers.py b/pyrit/backend/mappers/attack_mappers.py index b278f7fb7a..d8f316ddee 100644 --- a/pyrit/backend/mappers/attack_mappers.py +++ b/pyrit/backend/mappers/attack_mappers.py @@ -164,7 +164,7 @@ def pyrit_messages_to_dto(pyrit_messages: List[Any]) -> List[Message]: messages.append( Message( turn_number=first.sequence if first else 0, - role=first.role if first else "user", + role=first.get_role_for_storage() if first else "user", pieces=pieces, created_at=first.timestamp if first else datetime.now(timezone.utc), ) diff --git a/pyrit/backend/models/attacks.py b/pyrit/backend/models/attacks.py index 60cb5b06f2..9bcf13ae19 100644 --- a/pyrit/backend/models/attacks.py +++ b/pyrit/backend/models/attacks.py @@ -14,7 +14,7 @@ from pydantic import BaseModel, Field from pyrit.backend.models.common import PaginationInfo -from pyrit.models import PromptResponseError +from pyrit.models import ChatMessageRole, PromptResponseError class Score(BaseModel): @@ -59,7 +59,7 @@ class Message(BaseModel): """A message within an attack.""" turn_number: int = Field(..., description="Turn number in the conversation (1-indexed)") - role: Literal["user", "assistant", "system"] = Field(..., description="Message role") + role: ChatMessageRole = Field(..., description="Message role") pieces: List[MessagePiece] = Field(..., description="Message pieces (multimodal support)") created_at: datetime = Field(..., description="Message creation timestamp") @@ -159,7 +159,7 @@ class MessagePieceRequest(BaseModel): class PrependedMessageRequest(BaseModel): """A message to prepend to the attack (for system prompt/branching).""" - role: Literal["user", "assistant", "system"] = Field(..., description="Message role") + role: ChatMessageRole = Field(..., description="Message role") pieces: List[MessagePieceRequest] = Field(..., description="Message pieces (supports multimodal)", max_length=50) @@ -206,7 +206,7 @@ class AddMessageRequest(BaseModel): in memory without sending (useful for system messages, context injection). """ - role: Literal["user", "assistant", "system"] = Field(default="user", description="Message role") + role: ChatMessageRole = Field(default="user", description="Message role") pieces: List[MessagePieceRequest] = Field(..., description="Message pieces", max_length=50) send: bool = Field( default=True, diff --git a/tests/unit/backend/test_attack_service.py b/tests/unit/backend/test_attack_service.py index 0dbf1da9d6..15aff12d79 100644 --- a/tests/unit/backend/test_attack_service.py +++ b/tests/unit/backend/test_attack_service.py @@ -100,6 +100,7 @@ def make_mock_piece( piece.id = "piece-id" piece.conversation_id = conversation_id piece.role = role + piece.get_role_for_storage.return_value = role piece.sequence = sequence piece.original_value = original_value piece.converted_value = converted_value @@ -1171,6 +1172,7 @@ async def test_get_attack_with_messages_translates_correctly(self, attack_servic mock_piece.response_error = None mock_piece.sequence = 0 mock_piece.role = "user" + mock_piece.get_role_for_storage.return_value = "user" mock_piece.timestamp = datetime.now(timezone.utc) mock_piece.scores = None diff --git a/tests/unit/backend/test_converter_service.py b/tests/unit/backend/test_converter_service.py index 9b8fc2e6d6..52c2199913 100644 --- a/tests/unit/backend/test_converter_service.py +++ b/tests/unit/backend/test_converter_service.py @@ -361,22 +361,122 @@ def _get_all_converter_names() -> list[str]: def _try_instantiate_converter(converter_name: str): """ - Try to instantiate a converter with no arguments. + Try to instantiate a converter with minimal representative arguments. + + Uses mock objects for complex dependencies (PromptChatTarget, PromptConverter) + and provides minimal valid values for simple required parameters so that the + identifier extraction test covers ALL converters without skipping. Returns: Tuple of (converter_instance, error_message). If successful, error_message is None. If failed, converter_instance is None and error_message explains why. """ + import inspect + import tempfile + from pathlib import Path + from unittest.mock import MagicMock + + from pyrit.common.apply_defaults import _RequiredValueSentinel + from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget + + # Converters requiring external credentials or resources that can't be mocked + # at the constructor level — these validate env vars / files in __init__ body + _SKIP_CONVERTERS = { + "AzureSpeechAudioToTextConverter", # requires AZURE_SPEECH_REGION env var + "AzureSpeechTextToAudioConverter", # requires AZURE_SPEECH_REGION env var + "TransparencyAttackConverter", # requires a real JPEG image file on disk + } + + # Converter-specific overrides for params with validation + _OVERRIDES: dict = { + "CodeChameleonConverter": {"encrypt_type": "reverse"}, + "SearchReplaceConverter": {"pattern": "foo", "replace": "bar"}, + "PersuasionConverter": {"persuasion_technique": "logical_appeal"}, + } + converter_cls = getattr(prompt_converter, converter_name, None) if converter_cls is None: return None, f"Converter {converter_name} not found in prompt_converter module" + if converter_name in _SKIP_CONVERTERS: + return None, None # Signal to skip without failure + + # Build minimal kwargs based on constructor signature + sig = inspect.signature(converter_cls.__init__) + kwargs: dict = {} + + for pname, param in sig.parameters.items(): + if pname in ("self", "args", "kwargs"): + continue + + # Check if this param has a REQUIRED_VALUE sentinel as its default + is_required_value = isinstance(param.default, _RequiredValueSentinel) + has_no_default = param.default is inspect.Parameter.empty + + if not has_no_default and not is_required_value: + continue # Has a real default — skip + + # Check overrides first + if converter_name in _OVERRIDES and pname in _OVERRIDES[converter_name]: + kwargs[pname] = _OVERRIDES[converter_name][pname] + continue + + ann = param.annotation + ann_str = str(ann) if ann is not inspect.Parameter.empty else "" + + # PromptChatTarget — mock it with a proper identifier + if ann is not inspect.Parameter.empty and ( + (isinstance(ann, type) and issubclass(ann, PromptChatTarget)) or "PromptChatTarget" in ann_str + ): + mock_target = MagicMock(spec=PromptChatTarget) + mock_target.__class__.__name__ = "MockChatTarget" + # Configure get_identifier() to return a proper identifier-like object + # so that _create_identifier can extract class_name, model_name, etc. + mock_id = MagicMock() + mock_id.class_name = "MockChatTarget" + mock_id.model_name = "test-model" + mock_id.temperature = None + mock_id.top_p = None + mock_target.get_identifier.return_value = mock_id + kwargs[pname] = mock_target + # PromptConverter — use a real simple converter to avoid JSON serialization issues + elif "PromptConverter" in ann_str: + kwargs[pname] = Base64Converter() + # TextSelectionStrategy — use a real concrete strategy + elif "TextSelectionStrategy" in ann_str: + from pyrit.prompt_converter.text_selection_strategy import AllWordsSelectionStrategy + + kwargs[pname] = AllWordsSelectionStrategy() + # TextJailBreak — use string template + elif "TextJailBreak" in ann_str: + from pyrit.datasets.jailbreak.text_jailbreak import TextJailBreak + + kwargs[pname] = TextJailBreak(string_template="Test {{ prompt }}") + # Path — use a temp JPEG file + elif ann is Path or "Path" in ann_str: + tmp = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) + # Minimal valid JPEG header + tmp.write(b"\xff\xd8\xff\xe0\x00\x10JFIF\x00") + tmp.close() + kwargs[pname] = Path(tmp.name) + # str + elif ann is str or ann_str == "": + kwargs[pname] = "test_value" + # int + elif ann is int or ann_str == "": + kwargs[pname] = 1 + # float + elif ann is float or ann_str == "": + kwargs[pname] = 0.5 + else: + kwargs[pname] = "test_value" + try: - instance = converter_cls() + instance = converter_cls(**kwargs) return instance, None except Exception as e: - return None, f"Could not instantiate {converter_name} with no args: {e}" + return None, f"Could not instantiate {converter_name}: {e}" # Get all converter names dynamically @@ -396,19 +496,19 @@ def test_build_instance_from_converter(self, converter_name: str) -> None: """ Test that _build_instance_from_object works with each converter. - For converters that can be instantiated with no arguments, verifies: + Instantiates every converter with minimal representative arguments + (using mocks for complex dependencies like PromptChatTarget) and verifies: - converter_id is set correctly - converter_type matches the class name - supported_input_types and supported_output_types are lists - - For converters requiring arguments, the test is skipped (since we can't - know the required parameters without external configuration). """ # Try to instantiate the converter converter_instance, error = _try_instantiate_converter(converter_name) + if converter_instance is None and error is None: + pytest.skip(f"{converter_name} requires external credentials/resources") if error: - pytest.skip(error) + pytest.fail(error) # Build the instance using the service method service = ConverterService() diff --git a/tests/unit/backend/test_mappers.py b/tests/unit/backend/test_mappers.py index 67ce7aa32b..8e442e0acf 100644 --- a/tests/unit/backend/test_mappers.py +++ b/tests/unit/backend/test_mappers.py @@ -8,9 +8,12 @@ without any database or service dependencies. """ +import dataclasses from datetime import datetime, timezone from unittest.mock import MagicMock +import pytest + from pyrit.backend.mappers.attack_mappers import ( _collect_labels_from_pieces, _infer_mime_type, @@ -85,6 +88,7 @@ def _make_mock_piece( p.original_value_data_type = "text" p.response_error = "none" p.role = "user" + p.get_role_for_storage.return_value = "user" p.timestamp = datetime.now(timezone.utc) p.scores = [] return p @@ -697,3 +701,77 @@ def test_none_input_output_types_returns_empty_lists(self) -> None: assert result.supported_output_types == [] assert result.converter_specific_params is None assert result.sub_converter_ids is None + + +# ============================================================================ +# Drift Detection Tests – verify mapper-accessed fields exist on domain models +# ============================================================================ + + +class TestDomainModelFieldsExist: + """Lightweight safety-net: ensure fields the mappers access still exist on the domain dataclasses. + + If a domain model field is renamed or removed, these tests fail immediately – + before a mapper silently starts returning incorrect data. + """ + + # -- AttackIdentifier fields used in attack_mappers.py -------------------- + + @pytest.mark.parametrize( + "field_name", + [ + "class_name", + "attack_specific_params", + "objective_target_identifier", + "request_converter_identifiers", + ], + ) + def test_attack_identifier_has_field(self, field_name: str) -> None: + from pyrit.identifiers.attack_identifier import AttackIdentifier + + field_names = {f.name for f in dataclasses.fields(AttackIdentifier)} + assert field_name in field_names, ( + f"AttackIdentifier is missing '{field_name}' – attack_mappers.py depends on this field" + ) + + # -- TargetIdentifier fields used in attack_mappers.py & target_mappers.py + + @pytest.mark.parametrize( + "field_name", + [ + "class_name", + "unique_name", + "endpoint", + "model_name", + "temperature", + "top_p", + "max_requests_per_minute", + "target_specific_params", + ], + ) + def test_target_identifier_has_field(self, field_name: str) -> None: + from pyrit.identifiers.target_identifier import TargetIdentifier + + field_names = {f.name for f in dataclasses.fields(TargetIdentifier)} + assert field_name in field_names, ( + f"TargetIdentifier is missing '{field_name}' – target_mappers.py depends on this field" + ) + + # -- ConverterIdentifier fields used in converter_mappers.py -------------- + + @pytest.mark.parametrize( + "field_name", + [ + "class_name", + "supported_input_types", + "supported_output_types", + "converter_specific_params", + ], + ) + def test_converter_identifier_has_field(self, field_name: str) -> None: + from pyrit.identifiers.converter_identifier import ConverterIdentifier + + field_names = {f.name for f in dataclasses.fields(ConverterIdentifier)} + assert field_name in field_names, ( + f"ConverterIdentifier is missing '{field_name}' – converter_mappers.py depends on this field" + ) From 8ed0ce8cb643f3b237ca3cf66dac6a9e860b5da8 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Fri, 13 Feb 2026 10:03:06 -0800 Subject: [PATCH 32/35] fix py310 test with mime type for wav --- tests/unit/backend/test_mappers.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/unit/backend/test_mappers.py b/tests/unit/backend/test_mappers.py index 8e442e0acf..ab158408e1 100644 --- a/tests/unit/backend/test_mappers.py +++ b/tests/unit/backend/test_mappers.py @@ -328,7 +328,8 @@ def test_mime_type_for_audio(self) -> None: result = pyrit_messages_to_dto([msg]) - assert result[0].pieces[0].original_value_mime_type == "audio/x-wav" + # Python 3.10 returns "audio/wav", 3.11+ returns "audio/x-wav" + assert result[0].pieces[0].original_value_mime_type in ("audio/wav", "audio/x-wav") assert result[0].pieces[0].converted_value_mime_type == "audio/mpeg" @@ -731,7 +732,8 @@ def test_attack_identifier_has_field(self, field_name: str) -> None: field_names = {f.name for f in dataclasses.fields(AttackIdentifier)} assert field_name in field_names, ( - f"AttackIdentifier is missing '{field_name}' – attack_mappers.py depends on this field" + f"AttackIdentifier is missing '{field_name}' – " + f"attack_mappers.py depends on this field" ) # -- TargetIdentifier fields used in attack_mappers.py & target_mappers.py @@ -754,7 +756,8 @@ def test_target_identifier_has_field(self, field_name: str) -> None: field_names = {f.name for f in dataclasses.fields(TargetIdentifier)} assert field_name in field_names, ( - f"TargetIdentifier is missing '{field_name}' – target_mappers.py depends on this field" + f"TargetIdentifier is missing '{field_name}' – " + f"target_mappers.py depends on this field" ) # -- ConverterIdentifier fields used in converter_mappers.py -------------- @@ -773,5 +776,6 @@ def test_converter_identifier_has_field(self, field_name: str) -> None: field_names = {f.name for f in dataclasses.fields(ConverterIdentifier)} assert field_name in field_names, ( - f"ConverterIdentifier is missing '{field_name}' – converter_mappers.py depends on this field" + f"ConverterIdentifier is missing '{field_name}' – " + f"converter_mappers.py depends on this field" ) From d3418b432c48781367c0d6656e9528ca6467ebec Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Fri, 13 Feb 2026 10:32:05 -0800 Subject: [PATCH 33/35] fix ruff --- tests/unit/backend/test_mappers.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/unit/backend/test_mappers.py b/tests/unit/backend/test_mappers.py index ab158408e1..5ff03515e5 100644 --- a/tests/unit/backend/test_mappers.py +++ b/tests/unit/backend/test_mappers.py @@ -732,8 +732,7 @@ def test_attack_identifier_has_field(self, field_name: str) -> None: field_names = {f.name for f in dataclasses.fields(AttackIdentifier)} assert field_name in field_names, ( - f"AttackIdentifier is missing '{field_name}' – " - f"attack_mappers.py depends on this field" + f"AttackIdentifier is missing '{field_name}' – attack_mappers.py depends on this field" ) # -- TargetIdentifier fields used in attack_mappers.py & target_mappers.py @@ -756,8 +755,7 @@ def test_target_identifier_has_field(self, field_name: str) -> None: field_names = {f.name for f in dataclasses.fields(TargetIdentifier)} assert field_name in field_names, ( - f"TargetIdentifier is missing '{field_name}' – " - f"target_mappers.py depends on this field" + f"TargetIdentifier is missing '{field_name}' – target_mappers.py depends on this field" ) # -- ConverterIdentifier fields used in converter_mappers.py -------------- @@ -776,6 +774,5 @@ def test_converter_identifier_has_field(self, field_name: str) -> None: field_names = {f.name for f in dataclasses.fields(ConverterIdentifier)} assert field_name in field_names, ( - f"ConverterIdentifier is missing '{field_name}' – " - f"converter_mappers.py depends on this field" + f"ConverterIdentifier is missing '{field_name}' – converter_mappers.py depends on this field" ) From 5e2a6abe418eab5df4c809476f8722eade62e031 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Sat, 14 Feb 2026 10:24:51 -0800 Subject: [PATCH 34/35] fix: cross-platform reliability for backend startup and e2e tests - Fix .lower() crash on int log_level in pyrit_backend CLI - Guard against double initialization in lifespan (CentralMemory check) - Remove CREATE_NEW_PROCESS_GROUP in dev.py (broke Ctrl+C on Windows) - Add sys.stdout/stderr.reconfigure(errors="replace") for Windows cp1252 - Route e2e tests through Vite proxy with beforeAll health polling - Use 127.0.0.1 instead of localhost to avoid IPv6 ECONNREFUSED - Suppress noisy proxy error logs via custom Vite logger - Return 502 on proxy errors so polling doesn't hang - Increase Playwright webServer timeout to 120s for CI - Add test for lifespan skip-init when CentralMemory is already set --- frontend/dev.py | 33 ++++++++++----------------------- frontend/e2e/api.spec.ts | 26 ++++++++++++++++++++++++-- frontend/playwright.config.ts | 6 ++++-- frontend/vite.config.ts | 32 ++++++++++++++++++++++++++++++-- pyrit/backend/main.py | 7 +++++-- pyrit/cli/pyrit_backend.py | 7 ++++++- tests/unit/backend/test_main.py | 21 ++++++++++++++++++--- uv.lock | 12 ++++-------- 8 files changed, 101 insertions(+), 43 deletions(-) diff --git a/frontend/dev.py b/frontend/dev.py index 772d4a2996..71acafeb53 100644 --- a/frontend/dev.py +++ b/frontend/dev.py @@ -8,12 +8,17 @@ import json import os import platform -import signal import subprocess import sys import time from pathlib import Path +# Ensure emoji and other Unicode characters don't crash on Windows consoles +# that use legacy encodings like cp1252. Characters that can't be encoded +# are replaced with '?' instead of raising UnicodeEncodeError. +sys.stdout.reconfigure(errors="replace") # type: ignore[attr-defined] +sys.stderr.reconfigure(errors="replace") # type: ignore[attr-defined] + # Determine workspace root (parent of frontend directory) FRONTEND_DIR = Path(__file__).parent.absolute() WORKSPACE_ROOT = FRONTEND_DIR.parent @@ -115,14 +120,7 @@ def start_backend(initializers: list[str] | None = None): cmd.extend(["--initializers"] + initializers) # Start backend - if is_windows(): - backend = subprocess.Popen( - cmd, - env=env, - creationflags=subprocess.CREATE_NEW_PROCESS_GROUP if is_windows() else 0, - ) - else: - backend = subprocess.Popen(cmd, env=env) + backend = subprocess.Popen(cmd, env=env) return backend @@ -136,14 +134,7 @@ def start_frontend(): # Start frontend process npm_cmd = "npm.cmd" if is_windows() else "npm" - - if is_windows(): - frontend = subprocess.Popen( - [npm_cmd, "run", "dev"], - creationflags=subprocess.CREATE_NEW_PROCESS_GROUP if is_windows() else 0, - ) - else: - frontend = subprocess.Popen([npm_cmd, "run", "dev"]) + frontend = subprocess.Popen([npm_cmd, "run", "dev"]) return frontend @@ -183,12 +174,8 @@ def wait_for_interrupt(backend, frontend): # Terminate processes try: - if is_windows(): - backend.send_signal(signal.CTRL_BREAK_EVENT) - frontend.send_signal(signal.CTRL_BREAK_EVENT) - else: - backend.terminate() - frontend.terminate() + backend.terminate() + frontend.terminate() # Wait for clean shutdown backend.wait(timeout=5) diff --git a/frontend/e2e/api.spec.ts b/frontend/e2e/api.spec.ts index 1f151ac09f..283f289093 100644 --- a/frontend/e2e/api.spec.ts +++ b/frontend/e2e/api.spec.ts @@ -1,8 +1,30 @@ import { test, expect } from "@playwright/test"; +// API tests go through the Vite dev server proxy (/api -> backend:8000) +// rather than hitting the backend directly, so they work as soon as +// Playwright's webServer (port 3000) is ready. + test.describe("API Health Check", () => { + // The backend may still be starting when Vite (port 3000) is already up. + // Poll the health endpoint through the proxy until the backend is ready. + test.beforeAll(async ({ request }) => { + const maxWait = 30_000; + const interval = 1_000; + const start = Date.now(); + while (Date.now() - start < maxWait) { + try { + const resp = await request.get("/api/health"); + if (resp.ok()) return; + } catch { + // Backend not ready yet + } + await new Promise((r) => setTimeout(r, interval)); + } + throw new Error("Backend did not become healthy within 30 seconds"); + }); + test("should have healthy backend API", async ({ request }) => { - const response = await request.get("http://localhost:8000/api/health"); + const response = await request.get("/api/health"); expect(response.ok()).toBe(true); const data = await response.json(); @@ -10,7 +32,7 @@ test.describe("API Health Check", () => { }); test("should get version from API", async ({ request }) => { - const response = await request.get("http://localhost:8000/api/version"); + const response = await request.get("/api/version"); expect(response.ok()).toBe(true); const data = await response.json(); diff --git a/frontend/playwright.config.ts b/frontend/playwright.config.ts index a35c626030..6958f1208d 100644 --- a/frontend/playwright.config.ts +++ b/frontend/playwright.config.ts @@ -30,8 +30,10 @@ export default defineConfig({ /* Automatically start servers before running tests */ webServer: { command: process.env.CI ? "cd .. && uv run python frontend/dev.py" : "python dev.py", - url: "http://localhost:3000", + // Use 127.0.0.1 to avoid Node.js 17+ resolving localhost to IPv6 ::1 + url: "http://127.0.0.1:3000", reuseExistingServer: !process.env.CI, - timeout: 30000, + // CI needs extra time for uv sync + backend startup + timeout: 120_000, }, }); diff --git a/frontend/vite.config.ts b/frontend/vite.config.ts index 4b1b52c1d1..ed7b84f1cf 100644 --- a/frontend/vite.config.ts +++ b/frontend/vite.config.ts @@ -1,9 +1,26 @@ -import { defineConfig } from 'vite' +import { createLogger, defineConfig } from 'vite' import react from '@vitejs/plugin-react' import path from 'path' +// Suppress noisy ECONNREFUSED proxy errors while the backend is starting. +// Without this, Vite logs dozens of "http proxy error" stack traces. +const logger = createLogger() +const originalError = logger.error +let proxyWarned = false +logger.error = (msg, options) => { + if (typeof msg === 'string' && msg.includes('http proxy error')) { + if (!proxyWarned) { + console.log('[vite] Waiting for backend on port 8000...') + proxyWarned = true + } + return + } + originalError(msg, options) +} + // https://vitejs.dev/config/ export default defineConfig({ + customLogger: logger, plugins: [react()], resolve: { alias: { @@ -22,8 +39,19 @@ export default defineConfig({ cors: true, proxy: { '/api': { - target: 'http://localhost:8000', + // Use 127.0.0.1 to avoid Node.js 17+ resolving localhost to IPv6 ::1 + target: 'http://127.0.0.1:8000', changeOrigin: true, + // Return 502 on proxy errors so in-flight requests fail fast + // instead of hanging until the backend comes up. + configure: (proxy) => { + proxy.on('error', (_err, _req, res) => { + if (res && 'writeHead' in res && !res.headersSent) { + (res as import('http').ServerResponse).writeHead(502); + (res as import('http').ServerResponse).end(); + } + }); + }, }, }, watch: { diff --git a/pyrit/backend/main.py b/pyrit/backend/main.py index ae610a0531..f346a3e7d6 100644 --- a/pyrit/backend/main.py +++ b/pyrit/backend/main.py @@ -17,6 +17,7 @@ import pyrit from pyrit.backend.middleware import register_error_handlers from pyrit.backend.routes import attacks, converters, health, labels, targets, version +from pyrit.memory import CentralMemory from pyrit.setup.initialization import initialize_pyrit_async # Check for development mode from environment variable @@ -26,8 +27,10 @@ @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: """Manage application startup and shutdown lifecycle.""" - # Startup: initialize PyRIT to load .env and .env.local files - await initialize_pyrit_async(memory_db_type="SQLite") + # When launched via pyrit_backend CLI, initialization is already done. + # Only initialize here for standalone uvicorn usage (e.g. uvicorn pyrit.backend.main:app). + if not CentralMemory._memory_instance: + await initialize_pyrit_async(memory_db_type="SQLite") yield # Shutdown: nothing to clean up currently diff --git a/pyrit/cli/pyrit_backend.py b/pyrit/cli/pyrit_backend.py index 1043135958..40aa0521f3 100644 --- a/pyrit/cli/pyrit_backend.py +++ b/pyrit/cli/pyrit_backend.py @@ -12,6 +12,11 @@ from argparse import ArgumentParser, Namespace, RawDescriptionHelpFormatter from typing import Optional +# Ensure emoji and other Unicode characters don't crash on Windows consoles +# that use legacy encodings like cp1252. +sys.stdout.reconfigure(errors="replace") # type: ignore[attr-defined] +sys.stderr.reconfigure(errors="replace") # type: ignore[attr-defined] + from pyrit.cli import frontend_core @@ -175,7 +180,7 @@ async def initialize_and_run(*, parsed_args: Namespace) -> int: "pyrit.backend.main:app", host=parsed_args.host, port=parsed_args.port, - log_level=parsed_args.log_level.lower(), + log_level=parsed_args.log_level, reload=parsed_args.reload, ) server = uvicorn.Server(config) diff --git a/tests/unit/backend/test_main.py b/tests/unit/backend/test_main.py index 51e935d61e..6bc21ceb3c 100644 --- a/tests/unit/backend/test_main.py +++ b/tests/unit/backend/test_main.py @@ -20,13 +20,28 @@ class TestLifespan: @pytest.mark.asyncio async def test_lifespan_initializes_pyrit_and_yields(self) -> None: - """Test that lifespan calls initialize_pyrit_async on startup and yields.""" - with patch("pyrit.backend.main.initialize_pyrit_async", new_callable=AsyncMock) as mock_init: + """Test that lifespan calls initialize_pyrit_async on startup when memory is not set.""" + with ( + patch("pyrit.backend.main.CentralMemory._memory_instance", None), + patch("pyrit.backend.main.initialize_pyrit_async", new_callable=AsyncMock) as mock_init, + ): async with lifespan(app): pass # The body of the context manager is the "yield" phase mock_init.assert_awaited_once_with(memory_db_type="SQLite") + @pytest.mark.asyncio + async def test_lifespan_skips_init_when_already_initialized(self) -> None: + """Test that lifespan skips initialization when CentralMemory is already set.""" + with ( + patch("pyrit.backend.main.CentralMemory._memory_instance", MagicMock()), + patch("pyrit.backend.main.initialize_pyrit_async", new_callable=AsyncMock) as mock_init, + ): + async with lifespan(app): + pass + + mock_init.assert_not_awaited() + class TestSetupFrontend: """Tests for the setup_frontend function.""" @@ -81,4 +96,4 @@ def test_frontend_missing_warns_but_continues(self) -> None: # Verify warning was printed printed = " ".join(str(c) for c in mock_print.call_args_list) - assert "WARNING" in printed + assert "warning" in printed.lower() diff --git a/uv.lock b/uv.lock index 45bb4bcbb4..413892500c 100644 --- a/uv.lock +++ b/uv.lock @@ -16,14 +16,6 @@ resolution-markers = [ "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')", ] -[manifest] -constraints = [ - { name = "google-auth", specifier = "!=2.46.0" }, - { name = "nbconvert", specifier = ">7.16.6" }, - { name = "numpy", specifier = "!=2.4.0" }, - { name = "virtualenv", specifier = ">=20.36.1" }, -] - [[package]] name = "absl-py" version = "2.3.1" @@ -6060,8 +6052,10 @@ dev = [ { name = "jupyter" }, { name = "jupyter-book" }, { name = "jupytext" }, + { name = "matplotlib" }, { name = "mock-alchemy" }, { name = "mypy" }, + { name = "pandas" }, { name = "pre-commit" }, { name = "pytest" }, { name = "pytest-asyncio" }, @@ -6155,6 +6149,7 @@ requires-dist = [ { name = "jupyter", marker = "extra == 'dev'", specifier = ">=1.1.1" }, { name = "jupyter-book", marker = "extra == 'dev'", specifier = "==1.0.4" }, { name = "jupytext", marker = "extra == 'dev'", specifier = ">=1.17.1" }, + { name = "matplotlib", marker = "extra == 'dev'", specifier = ">=3.10.0" }, { name = "ml-collections", marker = "extra == 'all'", specifier = ">=1.1.0" }, { name = "ml-collections", marker = "extra == 'gcg'", specifier = ">=1.1.0" }, { name = "mlflow", marker = "extra == 'all'", specifier = ">=2.22.0" }, @@ -6168,6 +6163,7 @@ requires-dist = [ { name = "opencv-python", marker = "extra == 'all'", specifier = ">=4.11.0.86" }, { name = "opencv-python", marker = "extra == 'opencv'", specifier = ">=4.11.0.86" }, { name = "openpyxl", specifier = ">=3.1.5" }, + { name = "pandas", marker = "extra == 'dev'", specifier = ">=2.2.0" }, { name = "pillow", specifier = ">=12.1.0" }, { name = "playwright", marker = "extra == 'all'", specifier = ">=1.49.0" }, { name = "playwright", marker = "extra == 'playwright'", specifier = ">=1.49.0" }, From 49822fce6f435b10ac769b168ce4657cd8017ad3 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Sat, 14 Feb 2026 11:28:15 -0800 Subject: [PATCH 35/35] refactor: improve type safety and fix mypy errors across mappers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add `from __future__ import annotations` to attack_mappers.py and remove string quotes from ChatMessageRole type hint - Use ScorerIdentifier.class_name instead of dict .get("__type__") for scorer type extraction (ScorerIdentifier is a dataclass, not dict) - Convert score_value str→float when building Score DTO - Annotate prompt_metadata as Optional[Dict[str, str | int]] to match MessagePiece parameter type - Fix type: ignore comments in pyrit_backend.py (attr-defined → union-attr) - Replace placeholder ConverterIdentifier in converter_registry with the real one from pyrit.identifiers, passing supported_input/output_types - Remove redundant string quotes from TargetRegistry return type - Update test mock to use real ScorerIdentifier instead of plain dict All 384 tests pass, pre-commit (including mypy strict) clean. --- pyrit/backend/mappers/__init__.py | 2 - pyrit/backend/mappers/attack_mappers.py | 57 +++++++------------ pyrit/backend/mappers/converter_mappers.py | 5 +- pyrit/backend/mappers/target_mappers.py | 21 ++++--- pyrit/cli/pyrit_backend.py | 4 +- pyrit/models/attack_result.py | 5 +- .../instance_registries/converter_registry.py | 16 +----- .../instance_registries/target_registry.py | 2 +- tests/unit/backend/test_mappers.py | 44 +++++++------- 9 files changed, 67 insertions(+), 89 deletions(-) diff --git a/pyrit/backend/mappers/__init__.py b/pyrit/backend/mappers/__init__.py index 780ffc8214..63577e6efc 100644 --- a/pyrit/backend/mappers/__init__.py +++ b/pyrit/backend/mappers/__init__.py @@ -10,7 +10,6 @@ from pyrit.backend.mappers.attack_mappers import ( attack_result_to_summary, - map_outcome, pyrit_messages_to_dto, pyrit_scores_to_dto, request_piece_to_pyrit_message_piece, @@ -26,7 +25,6 @@ __all__ = [ "attack_result_to_summary", "converter_object_to_instance", - "map_outcome", "pyrit_messages_to_dto", "pyrit_scores_to_dto", "request_piece_to_pyrit_message_piece", diff --git a/pyrit/backend/mappers/attack_mappers.py b/pyrit/backend/mappers/attack_mappers.py index d8f316ddee..d155b0e11a 100644 --- a/pyrit/backend/mappers/attack_mappers.py +++ b/pyrit/backend/mappers/attack_mappers.py @@ -1,6 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +from __future__ import annotations + """ Attack mappers – domain ↔ DTO translation for attack-related models. @@ -11,43 +13,30 @@ import mimetypes import uuid from datetime import datetime, timezone -from typing import Any, Dict, List, Literal, Optional, Sequence, cast +from typing import Dict, List, Optional, Sequence, cast from pyrit.backend.models.attacks import ( AddMessageRequest, AttackSummary, Message, MessagePiece, + MessagePieceRequest, Score, ) -from pyrit.models import AttackOutcome, AttackResult, ChatMessageRole, PromptDataType +from pyrit.models import AttackResult, ChatMessageRole, PromptDataType from pyrit.models import Message as PyritMessage from pyrit.models import MessagePiece as PyritMessagePiece +from pyrit.models import Score as PyritScore # ============================================================================ # Domain → DTO (for API responses) # ============================================================================ -def map_outcome(outcome: AttackOutcome) -> Optional[Literal["undetermined", "success", "failure"]]: - """ - Map AttackOutcome enum to API outcome string. - - Returns: - Outcome string ('success', 'failure', 'undetermined') or None. - """ - if outcome == AttackOutcome.SUCCESS: - return "success" - elif outcome == AttackOutcome.FAILURE: - return "failure" - else: - return "undetermined" - - def attack_result_to_summary( ar: AttackResult, *, - pieces: Sequence[Any], + pieces: Sequence[PyritMessagePiece], ) -> AttackSummary: """ Build an AttackSummary DTO from an AttackResult and its message pieces. @@ -83,7 +72,7 @@ def attack_result_to_summary( target_unique_name=target_id.unique_name if target_id else None, target_type=target_id.class_name if target_id else None, converters=[c.class_name for c in converter_ids] if converter_ids else [], - outcome=map_outcome(ar.outcome), + outcome=ar.outcome.value, last_message_preview=last_preview, message_count=message_count, labels=_collect_labels_from_pieces(pieces), @@ -92,7 +81,7 @@ def attack_result_to_summary( ) -def pyrit_scores_to_dto(scores: List[Any]) -> List[Score]: +def pyrit_scores_to_dto(scores: List[PyritScore]) -> List[Score]: """ Translate PyRIT score objects to backend Score DTOs. @@ -102,8 +91,8 @@ def pyrit_scores_to_dto(scores: List[Any]) -> List[Score]: return [ Score( score_id=str(s.id), - scorer_type=s.scorer_class_identifier.get("__type__", "unknown"), - score_value=s.score_value, + scorer_type=s.scorer_class_identifier.class_name, + score_value=float(s.score_value), score_rationale=s.score_rationale, scored_at=s.timestamp, ) @@ -111,7 +100,7 @@ def pyrit_scores_to_dto(scores: List[Any]) -> List[Score]: ] -def _infer_mime_type(*, value: Optional[str], data_type: str) -> Optional[str]: +def _infer_mime_type(*, value: Optional[str], data_type: PromptDataType) -> Optional[str]: """ Infer MIME type from a value and its data type. @@ -132,7 +121,7 @@ def _infer_mime_type(*, value: Optional[str], data_type: str) -> Optional[str]: return mime_type -def pyrit_messages_to_dto(pyrit_messages: List[Any]) -> List[Message]: +def pyrit_messages_to_dto(pyrit_messages: List[PyritMessage]) -> List[Message]: """ Translate PyRIT messages to backend Message DTOs. @@ -154,7 +143,7 @@ def pyrit_messages_to_dto(pyrit_messages: List[Any]) -> List[Message]: converted_value_mime_type=_infer_mime_type( value=p.converted_value, data_type=p.converted_value_data_type or "text" ), - scores=pyrit_scores_to_dto(p.scores) if hasattr(p, "scores") and p.scores else [], + scores=pyrit_scores_to_dto(p.scores) if p.scores else [], response_error=p.response_error or "none", ) for p in msg.message_pieces @@ -180,8 +169,8 @@ def pyrit_messages_to_dto(pyrit_messages: List[Any]) -> List[Message]: def request_piece_to_pyrit_message_piece( *, - piece: Any, - role: "ChatMessageRole", + piece: MessagePieceRequest, + role: ChatMessageRole, conversation_id: str, sequence: int, labels: Optional[Dict[str, str]] = None, @@ -199,9 +188,8 @@ def request_piece_to_pyrit_message_piece( Returns: PyritMessagePiece domain object. """ - metadata = {"mime_type": piece.mime_type} if getattr(piece, "mime_type", None) else None - raw_id = getattr(piece, "original_prompt_id", None) - original_prompt_id = uuid.UUID(raw_id) if raw_id else None + metadata: Optional[Dict[str, str | int]] = {"mime_type": piece.mime_type} if piece.mime_type else None + original_prompt_id = uuid.UUID(piece.original_prompt_id) if piece.original_prompt_id else None return PyritMessagePiece( role=role, original_value=piece.original_value, @@ -253,7 +241,7 @@ def request_to_pyrit_message( # ============================================================================ -def _get_preview_from_pieces(pieces: Sequence[Any]) -> Optional[str]: +def _get_preview_from_pieces(pieces: Sequence[PyritMessagePiece]) -> Optional[str]: """ Get a preview of the last message from a list of pieces. @@ -267,7 +255,7 @@ def _get_preview_from_pieces(pieces: Sequence[Any]) -> Optional[str]: return text[:100] + "..." if len(text) > 100 else text -def _collect_labels_from_pieces(pieces: Sequence[Any]) -> Dict[str, str]: +def _collect_labels_from_pieces(pieces: Sequence[PyritMessagePiece]) -> Dict[str, str]: """ Collect labels from message pieces. @@ -279,7 +267,6 @@ def _collect_labels_from_pieces(pieces: Sequence[Any]) -> Dict[str, str]: Label dict, or empty dict if no pieces have labels. """ for p in pieces: - labels = getattr(p, "labels", None) - if labels: - return dict(labels) + if p.labels: + return dict(p.labels) return {} diff --git a/pyrit/backend/mappers/converter_mappers.py b/pyrit/backend/mappers/converter_mappers.py index 78e0c54915..11e7c9aff7 100644 --- a/pyrit/backend/mappers/converter_mappers.py +++ b/pyrit/backend/mappers/converter_mappers.py @@ -5,14 +5,15 @@ Converter mappers – domain → DTO translation for converter-related models. """ -from typing import Any, List, Optional +from typing import List, Optional from pyrit.backend.models.converters import ConverterInstance +from pyrit.prompt_converter import PromptConverter def converter_object_to_instance( converter_id: str, - converter_obj: Any, + converter_obj: PromptConverter, *, sub_converter_ids: Optional[List[str]] = None, ) -> ConverterInstance: diff --git a/pyrit/backend/mappers/target_mappers.py b/pyrit/backend/mappers/target_mappers.py index 5b2b5b2972..88d426f9a1 100644 --- a/pyrit/backend/mappers/target_mappers.py +++ b/pyrit/backend/mappers/target_mappers.py @@ -5,12 +5,11 @@ Target mappers – domain → DTO translation for target-related models. """ -from typing import Any - from pyrit.backend.models.targets import TargetInstance +from pyrit.prompt_target import PromptTarget -def target_object_to_instance(target_unique_name: str, target_obj: Any) -> TargetInstance: +def target_object_to_instance(target_unique_name: str, target_obj: PromptTarget) -> TargetInstance: """ Build a TargetInstance DTO from a registry target object. @@ -24,15 +23,15 @@ def target_object_to_instance(target_unique_name: str, target_obj: Any) -> Targe Returns: TargetInstance DTO with metadata derived from the object. """ - identifier = target_obj.get_identifier() if hasattr(target_obj, "get_identifier") else None + identifier = target_obj.get_identifier() return TargetInstance( target_unique_name=target_unique_name, - target_type=identifier.class_name if identifier else target_obj.__class__.__name__, - endpoint=getattr(identifier, "endpoint", None) if identifier else None, - model_name=getattr(identifier, "model_name", None) if identifier else None, - temperature=getattr(identifier, "temperature", None) if identifier else None, - top_p=getattr(identifier, "top_p", None) if identifier else None, - max_requests_per_minute=getattr(identifier, "max_requests_per_minute", None) if identifier else None, - target_specific_params=getattr(identifier, "target_specific_params", None) if identifier else None, + target_type=identifier.class_name, + endpoint=identifier.endpoint or None, + model_name=identifier.model_name or None, + temperature=identifier.temperature, + top_p=identifier.top_p, + max_requests_per_minute=identifier.max_requests_per_minute, + target_specific_params=identifier.target_specific_params, ) diff --git a/pyrit/cli/pyrit_backend.py b/pyrit/cli/pyrit_backend.py index 40aa0521f3..a3e3fe647f 100644 --- a/pyrit/cli/pyrit_backend.py +++ b/pyrit/cli/pyrit_backend.py @@ -14,8 +14,8 @@ # Ensure emoji and other Unicode characters don't crash on Windows consoles # that use legacy encodings like cp1252. -sys.stdout.reconfigure(errors="replace") # type: ignore[attr-defined] -sys.stderr.reconfigure(errors="replace") # type: ignore[attr-defined] +sys.stdout.reconfigure(errors="replace") # type: ignore[union-attr] +sys.stderr.reconfigure(errors="replace") # type: ignore[union-attr] from pyrit.cli import frontend_core diff --git a/pyrit/models/attack_result.py b/pyrit/models/attack_result.py index 7f92612f59..3518ce4626 100644 --- a/pyrit/models/attack_result.py +++ b/pyrit/models/attack_result.py @@ -16,9 +16,12 @@ AttackResultT = TypeVar("AttackResultT", bound="AttackResult") -class AttackOutcome(Enum): +class AttackOutcome(str, Enum): """ Enum representing the possible outcomes of an attack. + + Inherits from ``str`` so that values serialize naturally in Pydantic + models and REST responses without a dedicated mapping function. """ # The attack was successful in achieving its objective diff --git a/pyrit/registry/instance_registries/converter_registry.py b/pyrit/registry/instance_registries/converter_registry.py index 509f2fb68b..8571b1a4db 100644 --- a/pyrit/registry/instance_registries/converter_registry.py +++ b/pyrit/registry/instance_registries/converter_registry.py @@ -12,10 +12,9 @@ from __future__ import annotations import logging -from dataclasses import dataclass from typing import TYPE_CHECKING, Optional -from pyrit.identifiers import Identifier +from pyrit.identifiers import ConverterIdentifier from pyrit.identifiers.class_name_utils import class_name_to_snake_case from pyrit.registry.instance_registries.base_instance_registry import ( BaseInstanceRegistry, @@ -27,15 +26,6 @@ logger = logging.getLogger(__name__) -# Placeholder identifier type until proper ConverterIdentifier is defined -# TODO: Replace with ConverterIdentifier when available -@dataclass(frozen=True) -class ConverterIdentifier(Identifier): - """Temporary identifier type for converters.""" - - pass - - class ConverterRegistry(BaseInstanceRegistry["PromptConverter", ConverterIdentifier]): """ Registry for managing available converter instances. @@ -43,8 +33,6 @@ class ConverterRegistry(BaseInstanceRegistry["PromptConverter", ConverterIdentif This registry stores pre-configured PromptConverter instances (not classes). Converters are registered explicitly via initializers after being instantiated with their required parameters. - - NOTE: This is a placeholder. A full implementation will be added soon. """ @classmethod @@ -105,4 +93,6 @@ def _build_metadata(self, name: str, instance: PromptConverter) -> ConverterIden class_module=instance.__class__.__module__, class_description=f"Converter: {name}", identifier_type="instance", + supported_input_types=instance.SUPPORTED_INPUT_TYPES, + supported_output_types=instance.SUPPORTED_OUTPUT_TYPES, ) diff --git a/pyrit/registry/instance_registries/target_registry.py b/pyrit/registry/instance_registries/target_registry.py index 3fcdbb3160..ca7d2c8c59 100644 --- a/pyrit/registry/instance_registries/target_registry.py +++ b/pyrit/registry/instance_registries/target_registry.py @@ -36,7 +36,7 @@ class TargetRegistry(BaseInstanceRegistry["PromptTarget", TargetIdentifier]): """ @classmethod - def get_registry_singleton(cls) -> "TargetRegistry": + def get_registry_singleton(cls) -> TargetRegistry: """ Get the singleton instance of the TargetRegistry. diff --git a/tests/unit/backend/test_mappers.py b/tests/unit/backend/test_mappers.py index 5ff03515e5..941ff06374 100644 --- a/tests/unit/backend/test_mappers.py +++ b/tests/unit/backend/test_mappers.py @@ -18,7 +18,6 @@ _collect_labels_from_pieces, _infer_mime_type, attack_result_to_summary, - map_outcome, pyrit_messages_to_dto, pyrit_scores_to_dto, request_piece_to_pyrit_message_piece, @@ -26,7 +25,7 @@ ) from pyrit.backend.mappers.converter_mappers import converter_object_to_instance from pyrit.backend.mappers.target_mappers import target_object_to_instance -from pyrit.identifiers import AttackIdentifier, ConverterIdentifier, TargetIdentifier +from pyrit.identifiers import AttackIdentifier, ConverterIdentifier, ScorerIdentifier, TargetIdentifier from pyrit.models import AttackOutcome, AttackResult # ============================================================================ @@ -98,7 +97,11 @@ def _make_mock_score(): """Create a mock score for mapper tests.""" s = MagicMock() s.id = "score-1" - s.scorer_class_identifier = {"__type__": "TrueFalseScorer"} + s.scorer_class_identifier = ScorerIdentifier( + class_name="TrueFalseScorer", + class_module="pyrit.score", + scorer_type="true_false", + ) s.score_value = 1.0 s.score_rationale = "Looks correct" s.timestamp = datetime.now(timezone.utc) @@ -110,19 +113,6 @@ def _make_mock_score(): # ============================================================================ -class TestMapOutcome: - """Tests for map_outcome function.""" - - def test_maps_success(self) -> None: - assert map_outcome(AttackOutcome.SUCCESS) == "success" - - def test_maps_failure(self) -> None: - assert map_outcome(AttackOutcome.FAILURE) == "failure" - - def test_maps_undetermined(self) -> None: - assert map_outcome(AttackOutcome.UNDETERMINED) == "undetermined" - - class TestAttackResultToSummary: """Tests for attack_result_to_summary function.""" @@ -569,8 +559,9 @@ def test_returns_empty_when_no_pieces(self) -> None: assert _collect_labels_from_pieces([]) == {} def test_returns_empty_when_pieces_have_no_labels(self) -> None: - """Returns empty dict when pieces have no labels attribute.""" - p = MagicMock(spec=[]) + """Returns empty dict when pieces have None/empty labels.""" + p = MagicMock() + p.labels = None assert _collect_labels_from_pieces([p]) == {} def test_skips_pieces_with_empty_labels(self) -> None: @@ -631,15 +622,24 @@ def test_no_endpoint_returns_none(self) -> None: assert result.endpoint is None assert result.model_name is None - def test_no_get_identifier_falls_back_to_class_name(self) -> None: - """Test fallback when target has no get_identifier method.""" - target_obj = MagicMock(spec=[]) - target_obj.__class__ = type("FakeTarget", (), {}) + def test_no_get_identifier_uses_class_name(self) -> None: + """Test that target uses class name from identifier.""" + target_obj = MagicMock() + mock_identifier = MagicMock() + mock_identifier.class_name = "FakeTarget" + mock_identifier.endpoint = "" + mock_identifier.model_name = "" + mock_identifier.temperature = None + mock_identifier.top_p = None + mock_identifier.max_requests_per_minute = None + mock_identifier.target_specific_params = None + target_obj.get_identifier.return_value = mock_identifier result = target_object_to_instance("t-1", target_obj) assert result.target_type == "FakeTarget" assert result.endpoint is None + assert result.model_name is None # ============================================================================