Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions pyrit/backend/services/converter_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from pyrit.models import PromptDataType
from pyrit.models.data_type_serializer import data_serializer_factory
from pyrit.prompt_converter import PromptConverter
from pyrit.prompt_target import PromptChatTarget
from pyrit.prompt_target import PromptTarget
from pyrit.registry.object_registries import ConverterRegistry

_DATA_TYPE_EXTENSION: dict[str, str] = {
Expand Down Expand Up @@ -184,7 +184,16 @@ def _extract_parameters(converter_class: type) -> list[ConverterParameterSchema]


def _is_llm_based(converter_class: type) -> bool:
"""Return True if the converter requires an LLM target parameter."""
"""
Check if the converter requires a target parameter.

Matches any converter whose ``__init__`` accepts
a ``PromptTarget`` (or subclass) parameter.
These converters perform LLM-based transformations and should not automatically be applied

Returns:
bool: True if the converter is LLM-based, False otherwise.
"""
try:
sig = inspect.signature(converter_class.__init__)
except (ValueError, TypeError):
Expand All @@ -197,7 +206,7 @@ def _is_llm_based(converter_class: type) -> bool:
if ann is inspect.Parameter.empty:
continue
try:
if isinstance(ann, type) and issubclass(ann, PromptChatTarget):
if isinstance(ann, type) and issubclass(ann, PromptTarget):
return True
except TypeError:
continue
Expand Down
31 changes: 30 additions & 1 deletion tests/unit/backend/test_converter_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,20 @@
ConverterPreviewRequest,
CreateConverterRequest,
)
from pyrit.backend.services.converter_service import ConverterService, get_converter_service
from pyrit.backend.services.converter_service import ConverterService, _is_llm_based, get_converter_service
from pyrit.identifiers import ComponentIdentifier
from pyrit.prompt_converter import (
Base64Converter,
CaesarConverter,
LLMGenericTextConverter,
NoiseConverter,
PersuasionConverter,
RepeatTokenConverter,
SuffixAppendConverter,
TenseConverter,
ToneConverter,
TranslationConverter,
VariationConverter,
)
from pyrit.prompt_converter.prompt_converter import get_converter_modalities
from pyrit.registry.object_registries import ConverterRegistry
Expand Down Expand Up @@ -607,3 +614,25 @@ def test_base64_converter_default_params(self) -> None:
# Verify type info is populated from identifier
assert isinstance(result.supported_input_types, list)
assert isinstance(result.supported_output_types, list)


class TestIsLlmBased:
"""Tests for the _is_llm_based introspection helper"""

def test_detects_llm_text_converter(self) -> None:
# Test that _is_llm_based correctly identifies converters that use LLMS as LLM-based.
for cls in (
LLMGenericTextConverter,
NoiseConverter,
PersuasionConverter,
ToneConverter,
TenseConverter,
TranslationConverter,
VariationConverter,
):
assert _is_llm_based(cls) is True, f"{cls.__name__} should be detected as LLM-based"

def test_does_not_flag_non_target_converters(self) -> None:
# Test that _is_llm_based does not incorrectly flag non-LLM converters.
assert _is_llm_based(Base64Converter) is False
assert _is_llm_based(CaesarConverter) is False
Loading