diff --git a/src/pydantic_ai_lightspeed/capabilities/__init__.py b/src/pydantic_ai_lightspeed/capabilities/__init__.py new file mode 100644 index 000000000..ba4c5074e --- /dev/null +++ b/src/pydantic_ai_lightspeed/capabilities/__init__.py @@ -0,0 +1 @@ +"""Pydantic AI capabilities for Lightspeed Core Stack.""" diff --git a/src/pydantic_ai_lightspeed/capabilities/redaction/__init__.py b/src/pydantic_ai_lightspeed/capabilities/redaction/__init__.py new file mode 100644 index 000000000..226ff41ef --- /dev/null +++ b/src/pydantic_ai_lightspeed/capabilities/redaction/__init__.py @@ -0,0 +1,21 @@ +"""PII redaction capability for Pydantic AI agents.""" + +from pydantic_ai_lightspeed.capabilities.redaction.capability import ( + PiiRedactionCapability, +) +from pydantic_ai_lightspeed.capabilities.redaction.config import ( + RedactionConfig, + RedactionRule, +) +from pydantic_ai_lightspeed.capabilities.redaction.core import ( + RedactionResult, + redact_text, +) + +__all__ = [ + "PiiRedactionCapability", + "RedactionConfig", + "RedactionResult", + "RedactionRule", + "redact_text", +] diff --git a/src/pydantic_ai_lightspeed/capabilities/redaction/capability.py b/src/pydantic_ai_lightspeed/capabilities/redaction/capability.py new file mode 100644 index 000000000..97b0c68ca --- /dev/null +++ b/src/pydantic_ai_lightspeed/capabilities/redaction/capability.py @@ -0,0 +1,322 @@ +"""Pydantic AI capability for PII redaction of model messages.""" + +from collections.abc import Sequence +from dataclasses import dataclass, replace +from typing import Any + +from pydantic_ai import RunContext +from pydantic_ai.capabilities import AbstractCapability +from pydantic_ai.messages import ( + ModelMessage, + ModelRequest, + ModelResponse, + TextContent, + TextPart, + UserPromptPart, +) +from pydantic_ai.models import ModelRequestContext + +from pydantic_ai_lightspeed.capabilities.redaction.config import ( + RedactionConfig, +) +from pydantic_ai_lightspeed.capabilities.redaction.core import ( + 
def _redact_string_content(
    text: str, compiled_patterns: CompiledPatterns
) -> str | None:
    """Run the redaction rules over a plain string.

    Args:
        text: The string to scan.
        compiled_patterns: Pre-compiled (pattern, replacement) pairs.

    Returns:
        The rewritten string when ``redact_text`` flags the result as
        redacted, otherwise None.
    """
    outcome = redact_text(text, compiled_patterns)
    return outcome.content if outcome.redacted else None


def _redact_text_content(
    item: TextContent, compiled_patterns: CompiledPatterns
) -> TextContent | None:
    """Build a redacted copy of a TextContent item.

    Args:
        item: The TextContent whose ``content`` is scanned.
        compiled_patterns: Pre-compiled (pattern, replacement) pairs.

    Returns:
        A copy of ``item`` carrying the redacted text, or None when the
        text was left untouched.
    """
    scrubbed = _redact_string_content(item.content, compiled_patterns)
    if scrubbed is None:
        return None
    return replace(item, content=scrubbed)


def _redact_content_item(
    item: Any, compiled_patterns: CompiledPatterns
) -> tuple[Any, bool]:
    """Redact one content item and report whether it was rewritten.

    Only ``TextContent`` and ``str`` items are inspected; every other
    type is handed back as-is.

    Args:
        item: The content item to process.
        compiled_patterns: Pre-compiled (pattern, replacement) pairs.

    Returns:
        A ``(item, changed)`` pair. When ``changed`` is False the first
        element is the very object that was passed in.
    """
    if isinstance(item, TextContent):
        substitute: Any = _redact_text_content(item, compiled_patterns)
    elif isinstance(item, str):
        substitute = _redact_string_content(item, compiled_patterns)
    else:
        substitute = None

    # Both helpers signal "nothing changed" with None.
    if substitute is None:
        return item, False
    return substitute, True


def _redact_content_list(
    content: Sequence[Any], compiled_patterns: CompiledPatterns
) -> list[Any] | None:
    """Redact every item of a content sequence.

    Args:
        content: The content items to process.
        compiled_patterns: Pre-compiled (pattern, replacement) pairs.

    Returns:
        A new list holding the processed items when at least one of them
        changed, None otherwise (including for an empty sequence).
    """
    outcomes = [_redact_content_item(entry, compiled_patterns) for entry in content]
    if not any(was_changed for _, was_changed in outcomes):
        return None
    return [entry for entry, _ in outcomes]


def _redact_user_prompt_part(
    part: UserPromptPart,
    compiled_patterns: CompiledPatterns,
) -> UserPromptPart:
    """Redact the textual content of a user prompt part.

    The input is never modified. When nothing matched, the same instance
    is returned, so callers can detect a change with an identity check
    (``result is not part``).

    Args:
        part: The user prompt part to process.
        compiled_patterns: Pre-compiled (pattern, replacement) pairs.

    Returns:
        A copy of ``part`` with redacted content, or ``part`` itself.
    """
    payload = part.content
    scrubbed: str | list[Any] | None
    if isinstance(payload, str):
        scrubbed = _redact_string_content(payload, compiled_patterns)
    else:
        scrubbed = _redact_content_list(payload, compiled_patterns)

    return part if scrubbed is None else replace(part, content=scrubbed)
def _redact_model_request(
    message: ModelRequest, compiled_patterns: CompiledPatterns
) -> ModelRequest | None:
    """Build a redacted copy of a ModelRequest.

    Args:
        message: The request whose parts are scanned.
        compiled_patterns: Pre-compiled (pattern, replacement) pairs.

    Returns:
        A copy of ``message`` with the redacted parts, or None when no
        part changed.
    """
    scrubbed_parts = _redact_message_parts(message.parts, compiled_patterns)
    if scrubbed_parts is None:
        return None
    return replace(message, parts=scrubbed_parts)


def _redact_messages(
    messages: list[ModelMessage],
    compiled_patterns: CompiledPatterns,
) -> list[ModelMessage]:
    """Redact the ModelRequest entries of a message history.

    Messages that are not ``ModelRequest`` instances are carried over
    untouched. The input list is never modified.

    Args:
        messages: The messages to scan.
        compiled_patterns: Pre-compiled (pattern, replacement) pairs.

    Returns:
        A new list when at least one message was rewritten; otherwise the
        ``messages`` object itself, so callers can compare by identity.
    """
    rebuilt: list[ModelMessage] = []
    touched = False

    for original in messages:
        scrubbed = (
            _redact_model_request(original, compiled_patterns)
            if isinstance(original, ModelRequest)
            else None
        )
        if scrubbed is None:
            rebuilt.append(original)
            continue
        rebuilt.append(scrubbed)
        touched = True

    return rebuilt if touched else messages
@dataclass
class PiiRedactionCapability(AbstractCapability[Any]):
    """Capability that strips PII from the traffic between agent and model.

    Two hooks are implemented: ``before_model_request`` rewrites the
    outgoing message list, and ``after_model_request`` rewrites the text
    parts of the model's reply. Both use the (pattern, replacement) pairs
    exposed by ``config.compiled_patterns``.

    Attributes:
        config: Redaction configuration providing the compiled rules.
    """

    config: RedactionConfig

    async def before_model_request(
        self,
        ctx: RunContext[Any],
        request_context: ModelRequestContext,
    ) -> ModelRequestContext:
        """Redact PII from the outgoing messages.

        Args:
            ctx: The current run context (unused).
            request_context: The request context holding the messages.

        Returns:
            A copy of ``request_context`` with the redacted message list,
            or ``request_context`` itself when nothing was redacted.
        """
        outgoing = request_context.messages
        scrubbed = _redact_messages(outgoing, self.config.compiled_patterns)
        # _redact_messages hands back the same list object when nothing matched.
        if scrubbed is outgoing:
            return request_context
        return replace(request_context, messages=scrubbed)

    async def after_model_request(
        self,
        ctx: RunContext[Any],
        *,
        request_context: ModelRequestContext,
        response: ModelResponse,
    ) -> ModelResponse:
        """Redact PII from the text parts of the model response.

        Args:
            ctx: The current run context (unused).
            request_context: The model request context (unused).
            response: The model response to process.

        Returns:
            Whatever ``_redact_response`` yields for ``response``: a
            redacted copy, or the original when nothing was redacted.
        """
        return _redact_response(response, self.config.compiled_patterns)
class RedactionConfig(ConfigurationBase):
    """Configuration for PII redaction with regex-based rules.

    Rules are validated and compiled at construction time. Both the
    regex pattern and the replacement template of every rule are
    checked, so a misconfigured rule raises a ``ValueError`` immediately
    instead of failing later while a message is being redacted.

    Attributes:
        rules: Ordered list of redaction rules applied sequentially.
        case_sensitive: When False, patterns are compiled with
            ``re.IGNORECASE``. Defaults to False.
    """

    rules: list[RedactionRule] = Field(
        default_factory=list,
        title="Redaction rules",
        description="Ordered list of PII redaction rules",
    )
    case_sensitive: bool = Field(
        False,
        title="Case sensitive",
        description=("When False, patterns are compiled with re.IGNORECASE"),
    )

    _compiled_patterns: CompiledPatterns = PrivateAttr(
        default_factory=list,
    )

    @model_validator(mode="after")
    def compile_patterns(self) -> Self:
        """Compile regex patterns and reject invalid rules.

        Per-rule ``case_sensitive`` overrides the global flag when set.

        Raises:
            ValueError: If any rule contains an invalid regex pattern, or
                a replacement template that cannot be expanded for its
                pattern (bad escape, reference to a missing group).

        Returns:
            The validated configuration instance.
        """
        global_case_sensitive = self.case_sensitive
        compiled: CompiledPatterns = []
        for rule in self.rules:
            effective = (
                rule.case_sensitive
                if rule.case_sensitive is not None
                else global_case_sensitive
            )
            flags = 0 if effective else re.IGNORECASE
            try:
                pattern = re.compile(rule.pattern, flags)
            except re.error as e:
                raise ValueError(f"Invalid regex pattern: {rule.pattern}: {e}") from e

            # A replacement such as r"\1" for a pattern without groups only
            # blows up when the substitution runs. Substituting into an empty
            # string makes CPython parse the template right now, so the
            # mistake surfaces here rather than on the first redacted message.
            # Unknown named groups are reported as IndexError, not re.error.
            try:
                pattern.sub(rule.replacement, "")
            except (re.error, IndexError) as e:
                raise ValueError(
                    f"Invalid replacement template {rule.replacement!r} "
                    f"for pattern {rule.pattern!r}: {e}"
                ) from e

            compiled.append((pattern, rule.replacement))
        self._compiled_patterns = compiled
        return self

    @property
    def compiled_patterns(self) -> CompiledPatterns:
        """Pre-compiled (regex, replacement) pairs.

        Returns a shallow copy to prevent mutation of internal state.
        """
        return list(self._compiled_patterns)
+ + Returns: + A RedactionResult with the redacted content, a flag indicating + whether any substitution occurred, and the total substitution + count. + """ + result = content + total_count = 0 + + for pattern, replacement in compiled_patterns: + result, count = pattern.subn(replacement, result) + total_count += count + + return RedactionResult( + content=result, + redacted=total_count > 0, + redaction_count=total_count, + ) diff --git a/tests/unit/pydantic_ai_lightspeed/capabilities/__init__.py b/tests/unit/pydantic_ai_lightspeed/capabilities/__init__.py new file mode 100644 index 000000000..711e813cd --- /dev/null +++ b/tests/unit/pydantic_ai_lightspeed/capabilities/__init__.py @@ -0,0 +1 @@ +"""Tests for pydantic_ai_lightspeed.capabilities package.""" diff --git a/tests/unit/pydantic_ai_lightspeed/capabilities/redaction/__init__.py b/tests/unit/pydantic_ai_lightspeed/capabilities/redaction/__init__.py new file mode 100644 index 000000000..9ea075adc --- /dev/null +++ b/tests/unit/pydantic_ai_lightspeed/capabilities/redaction/__init__.py @@ -0,0 +1 @@ +"""Tests for pydantic_ai_lightspeed.capabilities.redaction package.""" diff --git a/tests/unit/pydantic_ai_lightspeed/capabilities/redaction/test_capability.py b/tests/unit/pydantic_ai_lightspeed/capabilities/redaction/test_capability.py new file mode 100644 index 000000000..1376c76e1 --- /dev/null +++ b/tests/unit/pydantic_ai_lightspeed/capabilities/redaction/test_capability.py @@ -0,0 +1,316 @@ +"""Unit tests for pydantic_ai_lightspeed.capabilities.redaction.capability module.""" + +import re + +import pytest +from pydantic_ai.messages import ( + ModelRequest, + ModelResponse, + TextContent, + TextPart, + UserPromptPart, +) +from pydantic_ai.models import ModelRequestContext +from pytest_mock import MockerFixture + +from pydantic_ai_lightspeed.capabilities.redaction.capability import ( + PiiRedactionCapability, + _redact_content_item, + _redact_content_list, + _redact_message_parts, + _redact_messages, + 
class TestRedactContentItem:
    """Unit tests covering the _redact_content_item helper."""

    def test_text_content_redacted(self) -> None:
        """A TextContent containing an email comes back as a redacted copy."""
        original = TextContent(content="email: user@test.com")
        item, was_changed = _redact_content_item(original, EMAIL_PATTERNS)
        assert was_changed is True
        assert isinstance(item, TextContent)
        assert item.content == "email: [REDACTED_EMAIL]"

    def test_text_content_no_match(self) -> None:
        """A TextContent without PII is returned as the same object."""
        original = TextContent(content="safe text")
        item, was_changed = _redact_content_item(original, EMAIL_PATTERNS)
        assert was_changed is False
        assert item is original

    def test_string_content_redacted(self) -> None:
        """A plain string containing an email is replaced."""
        item, was_changed = _redact_content_item("user@test.com", EMAIL_PATTERNS)
        assert was_changed is True
        assert item == "[REDACTED_EMAIL]"

    def test_string_content_no_match(self) -> None:
        """A plain string without PII keeps its value."""
        item, was_changed = _redact_content_item("safe text", EMAIL_PATTERNS)
        assert was_changed is False
        assert item == "safe text"

    def test_other_type_passes_through(self) -> None:
        """Items that are neither TextContent nor str are left alone."""
        opaque = object()
        item, was_changed = _redact_content_item(opaque, EMAIL_PATTERNS)
        assert was_changed is False
        assert item is opaque
class TestRedactUserPromptPart:
    """Unit tests covering the _redact_user_prompt_part helper."""

    def test_string_content_redacted(self) -> None:
        """String content with an email yields a new, redacted part."""
        prompt = UserPromptPart(content="email: user@test.com")
        outcome = _redact_user_prompt_part(prompt, EMAIL_PATTERNS)
        assert outcome is not prompt
        assert outcome.content == "email: [REDACTED_EMAIL]"

    def test_string_content_no_match(self) -> None:
        """String content without PII yields the very same part."""
        prompt = UserPromptPart(content="no emails here")
        outcome = _redact_user_prompt_part(prompt, EMAIL_PATTERNS)
        assert outcome is prompt
        assert outcome.content == "no emails here"

    def test_text_content_sequence_redacted(self) -> None:
        """TextContent items inside a sequence are redacted."""
        entry = TextContent(content="contact admin@corp.com")
        prompt = UserPromptPart(content=[entry])
        outcome = _redact_user_prompt_part(prompt, EMAIL_PATTERNS)
        assert outcome is not prompt
        assert outcome.content[0].content == "contact [REDACTED_EMAIL]"

    def test_text_content_sequence_no_match(self) -> None:
        """A sequence without PII yields the very same part."""
        entry = TextContent(content="safe text")
        prompt = UserPromptPart(content=[entry])
        outcome = _redact_user_prompt_part(prompt, EMAIL_PATTERNS)
        assert outcome is prompt
class TestRedactModelRequest:
    """Unit tests covering the _redact_model_request helper."""

    def test_returns_new_request_when_redacted(self) -> None:
        """A request whose prompt contains an email is copied and redacted."""
        request = ModelRequest(parts=[UserPromptPart(content="a@b.com")])
        outcome = _redact_model_request(request, EMAIL_PATTERNS)
        assert outcome is not None
        assert outcome is not request
        assert outcome.parts[0].content == "[REDACTED_EMAIL]"

    def test_returns_none_when_no_match(self) -> None:
        """A request without PII is reported as unchanged via None."""
        request = ModelRequest(parts=[UserPromptPart(content="safe")])
        outcome = _redact_model_request(request, EMAIL_PATTERNS)
        assert outcome is None
class TestPiiRedactionCapability:
    """Tests for the PiiRedactionCapability lifecycle hooks."""

    @pytest.fixture(name="capability")
    def capability_fixture(self) -> PiiRedactionCapability:
        """Provide a capability configured with one email rule.

        Returns:
            A PiiRedactionCapability with a single case-sensitive
            email-redaction rule.
        """
        email_rule = RedactionRule(
            pattern=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}",
            replacement="[REDACTED_EMAIL]",
        )
        return PiiRedactionCapability(
            config=RedactionConfig(rules=[email_rule], case_sensitive=True)
        )

    @staticmethod
    def _build_context(
        mocker: MockerFixture, messages: list
    ) -> ModelRequestContext:
        """Assemble a ModelRequestContext around the given messages.

        The model and the request parameters are plain mocks; the hooks
        under test only look at ``messages``.
        """
        return ModelRequestContext(
            model=mocker.Mock(),
            messages=messages,
            model_settings=None,
            model_request_parameters=mocker.Mock(),
        )

    @pytest.mark.asyncio()
    async def test_before_model_request_redacts_user_messages(
        self,
        capability: PiiRedactionCapability,
        mocker: MockerFixture,
    ) -> None:
        """before_model_request redacts PII found in user messages."""
        request = ModelRequest(parts=[UserPromptPart(content="email: a@b.com")])
        context = self._build_context(mocker, [request])
        outcome = await capability.before_model_request(mocker.Mock(), context)
        assert outcome is not context
        assert outcome.messages[0].parts[0].content == "email: [REDACTED_EMAIL]"

    @pytest.mark.asyncio()
    async def test_before_model_request_no_match(
        self,
        capability: PiiRedactionCapability,
        mocker: MockerFixture,
    ) -> None:
        """before_model_request returns the original context when clean."""
        request = ModelRequest(parts=[UserPromptPart(content="safe text")])
        context = self._build_context(mocker, [request])
        outcome = await capability.before_model_request(mocker.Mock(), context)
        assert outcome is context
        assert request.parts[0].content == "safe text"

    @pytest.mark.asyncio()
    async def test_after_model_request_redacts_response(
        self,
        capability: PiiRedactionCapability,
        mocker: MockerFixture,
    ) -> None:
        """after_model_request redacts PII found in response text."""
        reply = ModelResponse(parts=[TextPart(content="leaked a@b.com")])
        context = self._build_context(mocker, [])
        outcome = await capability.after_model_request(
            mocker.Mock(),
            request_context=context,
            response=reply,
        )
        assert outcome is not reply
        first_part = outcome.parts[0]
        assert isinstance(first_part, TextPart)
        assert first_part.content == "leaked [REDACTED_EMAIL]"

    @pytest.mark.asyncio()
    async def test_after_model_request_no_match(
        self,
        capability: PiiRedactionCapability,
        mocker: MockerFixture,
    ) -> None:
        """after_model_request returns the original response when clean."""
        reply = ModelResponse(parts=[TextPart(content="clean response")])
        context = self._build_context(mocker, [])
        outcome = await capability.after_model_request(
            mocker.Mock(),
            request_context=context,
            response=reply,
        )
        assert outcome is reply
        assert reply.parts[0].content == "clean response"
class TestRedactionConfigCompilation:
    """Tests for how RedactionConfig compiles its rules."""

    def test_empty_rules(self) -> None:
        """No rules means no compiled patterns."""
        assert not RedactionConfig(rules=[]).compiled_patterns

    def test_single_rule_compiles(self) -> None:
        """One rule yields exactly one usable (regex, replacement) pair."""
        phone_rule = RedactionRule(
            pattern=r"\d{3}-\d{4}",
            replacement="[PHONE]",
            case_sensitive=False,
        )
        compiled = RedactionConfig(rules=[phone_rule]).compiled_patterns
        assert len(compiled) == 1
        regex, substitute = compiled[0]
        assert substitute == "[PHONE]"
        assert regex.search("call 555-1234")

    def test_multiple_rules_compile(self) -> None:
        """Two rules yield two compiled patterns."""
        rule_set = [
            RedactionRule(pattern=r"foo", replacement="[A]", case_sensitive=False),
            RedactionRule(pattern=r"bar", replacement="[B]", case_sensitive=True),
        ]
        assert len(RedactionConfig(rules=rule_set).compiled_patterns) == 2

    def test_invalid_regex_raises(self) -> None:
        """A malformed regex is rejected with a ValueError."""
        broken_rule = RedactionRule(
            pattern=r"[invalid",
            replacement="x",
            case_sensitive=False,
        )
        with pytest.raises(ValueError, match="Invalid regex pattern"):
            RedactionConfig(rules=[broken_rule])
class TestRedactionConfigCompiledPatternsProperty:
    """Tests for the compiled_patterns property."""

    @staticmethod
    def _single_rule_config() -> RedactionConfig:
        """Build a config holding one trivial rule."""
        only_rule = RedactionRule(pattern=r"x", replacement="y", case_sensitive=False)
        return RedactionConfig(rules=[only_rule])

    def test_returns_list(self) -> None:
        """The property hands out a list."""
        config = self._single_rule_config()
        assert isinstance(config.compiled_patterns, list)

    def test_returns_copy(self) -> None:
        """Each access yields an equal but distinct list object."""
        config = self._single_rule_config()
        first_read = config.compiled_patterns
        second_read = config.compiled_patterns
        assert first_read == second_read
        assert first_read is not second_read
class TestRedactionResult:
    """Tests for the RedactionResult model."""

    def test_construction(self) -> None:
        """All three fields are stored as given."""
        outcome = RedactionResult(content="redacted", redacted=True, redaction_count=1)
        assert outcome.content == "redacted"
        assert outcome.redacted is True
        assert outcome.redaction_count == 1

    def test_frozen(self) -> None:
        """Assigning to a field of a RedactionResult is rejected."""
        outcome = RedactionResult(content="text", redacted=False, redaction_count=0)
        with pytest.raises(ValidationError):
            outcome.content = "modified"
string input returns empty string.""" + patterns: CompiledPatterns = [(EMAIL_PATTERN, "[REDACTED_EMAIL]")] + result = redact_text("", patterns) + assert result.content == "" + assert result.redacted is False + assert result.redaction_count == 0 + + def test_does_not_mutate_input(self) -> None: + """Test that the original content string is not mutated.""" + original = "user@example.com" + original_copy = original + patterns: CompiledPatterns = [(EMAIL_PATTERN, "[REDACTED_EMAIL]")] + redact_text(original, patterns) + assert original == original_copy + assert original == "user@example.com" + + +class TestRedactTextMatching: + """Tests for redact_text with matching patterns.""" + + def test_single_match(self) -> None: + """Test that a single match is redacted.""" + patterns: CompiledPatterns = [(EMAIL_PATTERN, "[REDACTED_EMAIL]")] + result = redact_text("contact user@example.com please", patterns) + assert result.content == "contact [REDACTED_EMAIL] please" + assert result.redacted is True + assert result.redaction_count == 1 + + def test_multiple_matches_same_pattern(self) -> None: + """Test that multiple occurrences of the same pattern are all redacted.""" + patterns: CompiledPatterns = [(EMAIL_PATTERN, "[REDACTED_EMAIL]")] + result = redact_text("from a@b.com to c@d.com", patterns) + assert result.content == ("from [REDACTED_EMAIL] to [REDACTED_EMAIL]") + assert result.redacted is True + assert result.redaction_count == 2 + + def test_sequential_rule_application(self) -> None: + """Test that rules are applied sequentially; earlier rules affect later matches.""" + patterns: CompiledPatterns = [ + (re.compile(r"foo"), "bar"), + (re.compile(r"bar"), "baz"), + ] + result = redact_text("foo", patterns) + assert result.content == "baz" + assert result.redacted is True + assert result.redaction_count == 2 + + def test_multiple_different_patterns(self) -> None: + """Test redaction with multiple different pattern types.""" + patterns: CompiledPatterns = [ + (EMAIL_PATTERN, 
"[REDACTED_EMAIL]"), + (PASSWORD_PATTERN, "[REDACTED_PASSWORD]"), + ] + result = redact_text( + "email: user@test.com password: s3cret123", + patterns, + ) + assert "[REDACTED_EMAIL]" in result.content + assert "[REDACTED_PASSWORD]" in result.content + assert result.redacted is True + assert result.redaction_count == 2 + + def test_redaction_count_accumulates_across_rules(self) -> None: + """Test that redaction_count sums substitutions from all rules.""" + patterns: CompiledPatterns = [ + (EMAIL_PATTERN, "[REDACTED_EMAIL]"), + (PASSWORD_PATTERN, "[REDACTED_PASSWORD]"), + ] + text = "a@b.com c@d.com password: secret" + result = redact_text(text, patterns) + assert result.redaction_count == 3 + + +class TestRedactTextCaseSensitivity: + """Tests for case sensitivity behavior in redact_text.""" + + def test_case_insensitive(self) -> None: + """Test that case-insensitive patterns match mixed case.""" + patterns: CompiledPatterns = [ + ( + re.compile(r"password[\s:=]+[^\s]+", re.IGNORECASE), + "[REDACTED_PASSWORD]", + ) + ] + result = redact_text("PASSWORD: secret123", patterns) + assert result.content == "[REDACTED_PASSWORD]" + assert result.redacted is True + + def test_case_sensitive(self) -> None: + """Test that case-sensitive patterns only match exact case.""" + patterns: CompiledPatterns = [ + ( + re.compile(r"password[\s:=]+[^\s]+"), + "[REDACTED_PASSWORD]", + ) + ] + result = redact_text("PASSWORD: secret123", patterns) + assert result.content == "PASSWORD: secret123" + assert result.redacted is False