diff --git a/docs/realtime/guide.md b/docs/realtime/guide.md index 8aac3d620d..94c31c3c57 100644 --- a/docs/realtime/guide.md +++ b/docs/realtime/guide.md @@ -244,7 +244,11 @@ Bare `RealtimeAgent` handoffs are auto-wrapped, and `realtime_handoff(...)` lets ### Guardrails -Realtime agents support output guardrails on agent responses and input guardrails on function-tool calls. Output guardrails run on debounced transcript accumulation rather than on every partial token, and they emit `guardrail_tripped` instead of raising an exception. +Realtime agents support output guardrails on agent responses and input guardrails on the user's +transcribed audio. (Function-tool calls have their own, separate tool input guardrails, which are a +distinct feature from the transcript input guardrails described here.) Output guardrails run on +debounced transcript accumulation rather than on every partial token, and they emit +`guardrail_tripped` instead of raising an exception. ```python from agents.guardrail import GuardrailFunctionOutput, OutputGuardrail @@ -270,6 +274,36 @@ triggered guardrail so the model can produce a replacement response. Your audio listen for `audio_interrupted` and stop local playback immediately, because guardrails run on debounced transcript text and some audio may already be buffered when the tripwire fires. +Realtime agents also support **input guardrails** that run on the user's transcribed audio. Configure +them via `RealtimeAgent.input_guardrails` or `RealtimeRunConfig["input_guardrails"]`; the two lists +are combined and de-duplicated per turn. They run once on the completed user transcript (the +`input_audio_transcription_completed` event), and when one trips the session emits an +`input_guardrail_tripped` event, forces `response.cancel`, and sends a follow-up user message that +names the triggered guardrail. 
+ +```python +from agents.guardrail import GuardrailFunctionOutput, InputGuardrail + + +def no_jailbreak(context, agent, user_input): + return GuardrailFunctionOutput( + tripwire_triggered="jailbreak" in user_input.lower(), + output_info=None, + ) + + +agent = RealtimeAgent( + name="Assistant", + instructions="...", + input_guardrails=[InputGuardrail(guardrail_function=no_jailbreak)], +) +``` + +Two limitations are worth noting. Input guardrails only run on transcribed audio, so text sent +through `session.send_message()` is not checked. And because guardrails run in a background task, +the forced cancel reliably interrupts a response that is already in flight, but a response created +in the narrow window after the guardrail resolves may not be cancelled. + ## SIP and telephony The Python SDK includes a first-class SIP attach flow via [`OpenAIRealtimeSIPModel`][agents.realtime.openai_realtime.OpenAIRealtimeSIPModel]. diff --git a/docs/ref/realtime/events.md b/docs/ref/realtime/events.md index 137d9a6434..cc047b2a58 100644 --- a/docs/ref/realtime/events.md +++ b/docs/ref/realtime/events.md @@ -24,6 +24,7 @@ ### Guardrail Events ::: agents.realtime.events.RealtimeGuardrailTripped +::: agents.realtime.events.RealtimeInputGuardrailTripped ### History Events ::: agents.realtime.events.RealtimeHistoryAdded diff --git a/examples/realtime/app/server.py b/examples/realtime/app/server.py index dc280bc62d..f2ebb0d629 100644 --- a/examples/realtime/app/server.py +++ b/examples/realtime/app/server.py @@ -194,6 +194,10 @@ async def _serialize_event(self, event: RealtimeSessionEvent) -> dict[str, Any]: base_event["guardrail_results"] = [ {"name": result.guardrail.name} for result in event.guardrail_results ] + elif event.type == "input_guardrail_tripped": + base_event["guardrail_results"] = [ + {"name": result.guardrail.name} for result in event.guardrail_results + ] elif event.type == "raw_model_event": base_event["raw_model_event"] = { "type": event.data.type, diff --git 
a/src/agents/realtime/__init__.py b/src/agents/realtime/__init__.py index 8e3db27c25..079616d5d5 100644 --- a/src/agents/realtime/__init__.py +++ b/src/agents/realtime/__init__.py @@ -29,6 +29,7 @@ RealtimeHandoffEvent, RealtimeHistoryAdded, RealtimeHistoryUpdated, + RealtimeInputGuardrailTripped, RealtimeRawModelEvent, RealtimeSessionEvent, RealtimeToolApprovalRequired, @@ -132,6 +133,7 @@ "RealtimeHandoffEvent", "RealtimeHistoryAdded", "RealtimeHistoryUpdated", + "RealtimeInputGuardrailTripped", "RealtimeRawModelEvent", "RealtimeSessionEvent", "RealtimeToolApprovalRequired", diff --git a/src/agents/realtime/agent.py b/src/agents/realtime/agent.py index a226e518f7..b1f72ee8ea 100644 --- a/src/agents/realtime/agent.py +++ b/src/agents/realtime/agent.py @@ -9,7 +9,7 @@ from agents.prompts import Prompt from ..agent import AgentBase -from ..guardrail import OutputGuardrail +from ..guardrail import InputGuardrail, OutputGuardrail from ..handoffs import Handoff from ..lifecycle import AgentHooksBase, RunHooksBase from ..logger import logger @@ -79,6 +79,13 @@ class RealtimeAgent(AgentBase, Generic[TContext]): """A class that receives callbacks on various lifecycle events for this agent. """ + input_guardrails: list[InputGuardrail[TContext]] = field(default_factory=list) + """A list of checks that run on the user's transcribed audio input. They run once on the + completed user transcript and, when tripped, force a cancel of the in-progress response. This + reliably interrupts a response that is already in flight, but a response created after the + guardrail resolves may not be interrupted. Text input sent via `send_message` is not checked. 
+ """ + def __post_init__(self) -> None: if not isinstance(self.name, str): raise TypeError(f"RealtimeAgent name must be a string, got {type(self.name).__name__}") diff --git a/src/agents/realtime/config.py b/src/agents/realtime/config.py index 3df0606bdb..a203da7ab7 100644 --- a/src/agents/realtime/config.py +++ b/src/agents/realtime/config.py @@ -10,7 +10,7 @@ from agents.prompts import Prompt -from ..guardrail import OutputGuardrail +from ..guardrail import InputGuardrail, OutputGuardrail from ..handoffs import Handoff from ..model_settings import ToolChoice from ..run_config import ToolErrorFormatter @@ -279,6 +279,9 @@ class RealtimeRunConfig(TypedDict): tool_error_formatter: NotRequired[ToolErrorFormatter] """Optional callback that formats tool error messages returned to the model.""" + input_guardrails: NotRequired[list[InputGuardrail[Any]]] + """List of input guardrails to run on the user's transcribed audio input.""" + # TODO (rm) Add history audio storage config diff --git a/src/agents/realtime/events.py b/src/agents/realtime/events.py index 388dac37e8..3572b7a2b7 100644 --- a/src/agents/realtime/events.py +++ b/src/agents/realtime/events.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from typing import Any, Literal, TypeAlias -from ..guardrail import OutputGuardrailResult +from ..guardrail import InputGuardrailResult, OutputGuardrailResult from ..run_context import RunContextWrapper from ..tool import Tool from .agent import RealtimeAgent @@ -243,6 +243,28 @@ class RealtimeGuardrailTripped: type: Literal["guardrail_tripped"] = "guardrail_tripped" +@dataclass +class RealtimeInputGuardrailTripped: + """An input guardrail has been tripped on the user's transcribed input. + + When a guardrail trips, the session forces a cancel of the in-progress response. This + reliably interrupts a response that is already in flight. 
Because guardrails run in a + background task, a response that is created in the narrow window after the guardrail + resolves but before the cancel can take effect may not be interrupted. + """ + + guardrail_results: list[InputGuardrailResult] + """The results from all triggered input guardrails.""" + + message: str + """The user transcript that triggered the guardrail.""" + + info: RealtimeEventInfo + """Common info for all events, such as the context.""" + + type: Literal["input_guardrail_tripped"] = "input_guardrail_tripped" + + @dataclass class RealtimeInputAudioTimeoutTriggered: """Called when the model detects a period of inactivity/silence from the user.""" @@ -268,6 +290,7 @@ class RealtimeInputAudioTimeoutTriggered: | RealtimeHistoryUpdated | RealtimeHistoryAdded | RealtimeGuardrailTripped + | RealtimeInputGuardrailTripped | RealtimeInputAudioTimeoutTriggered ) """An event emitted by the realtime session.""" diff --git a/src/agents/realtime/session.py b/src/agents/realtime/session.py index 3b186e5502..93906b6d88 100644 --- a/src/agents/realtime/session.py +++ b/src/agents/realtime/session.py @@ -17,6 +17,7 @@ ) from ..agent import Agent from ..exceptions import ToolInputGuardrailTripwireTriggered, UserError +from ..guardrail import InputGuardrail, InputGuardrailResult from ..handoffs import Handoff from ..items import ToolApprovalItem from ..logger import logger @@ -43,6 +44,7 @@ RealtimeHistoryAdded, RealtimeHistoryUpdated, RealtimeInputAudioTimeoutTriggered, + RealtimeInputGuardrailTripped, RealtimeRawModelEvent, RealtimeSessionEvent, RealtimeToolApprovalRequired, @@ -202,6 +204,8 @@ def __init__( # Guardrails state tracking self._interrupted_response_ids: set[str] = set() + # User item_ids for which an input guardrail has already interrupted the response. 
+ self._interrupted_input_item_ids: set[str] = set() self._item_transcripts: dict[str, str] = {} # item_id -> accumulated transcript self._item_guardrail_run_counts: dict[str, int] = {} # item_id -> run count self._debounce_text_length = self._run_config.get("guardrails_settings", {}).get( @@ -365,6 +369,10 @@ async def on_event(self, event: RealtimeModelEvent) -> None: await self._put_event( RealtimeHistoryUpdated(info=self._event_info, history=self._history) ) + # Run input guardrails on the finalized user transcript. The transcription completes + # around the time the server begins generating a response, so we mirror the + # output-guardrail trip behavior and force a response cancel when a guardrail trips. + self._enqueue_input_guardrail_task(event.transcript, event.item_id) elif event.type == "input_audio_timeout_triggered": await self._put_event( RealtimeInputAudioTimeoutTriggered( @@ -1263,6 +1271,81 @@ async def _run_output_guardrails(self, text: str, response_id: str) -> bool: return False + async def _run_input_guardrails( + self, + text: str, + item_id: str, + agent: RealtimeAgent, + input_guardrails: list[InputGuardrail[Any]], + ) -> bool: + """Run input guardrails on the user's transcribed input. Returns True if any guardrail was + triggered. + + ``agent`` and ``input_guardrails`` are snapshotted when the transcription event is handled + so that a concurrent ``update_agent()`` or handoff cannot swap in a different agent's + guardrails before this background task runs. + """ + # If we've already interrupted the response for this user item, skip. 
+ if not input_guardrails or item_id in self._interrupted_input_item_ids: + return False + + async def _run_one(guardrail: InputGuardrail[Any]) -> InputGuardrailResult | None: + try: + return await guardrail.run( + # TODO (rm) Remove this cast, it's wrong + cast(Agent[Any], agent), + text, + self._context_wrapper, + ) + except Exception as exc: + logger.warning( + "Input guardrail %r raised %s: %s; skipping it.", + guardrail.get_name(), + type(exc).__name__, + exc, + ) + logger.debug("Input guardrail failure details.", exc_info=True) + return None + + # Run the guardrails concurrently so the forced cancel waits for the slowest guardrail + # rather than their sum. Note gather() still awaits every guardrail before the cancel. + results = await asyncio.gather(*(_run_one(guardrail) for guardrail in input_guardrails)) + triggered_results = [ + result for result in results if result is not None and result.output.tripwire_triggered + ] + + if triggered_results: + # Double-check: bail if already interrupted for this user item. + if item_id in self._interrupted_input_item_ids: + return False + + # Mark as interrupted immediately (before any awaits) to minimize the race window. + self._interrupted_input_item_ids.add(item_id) + + # Emit input guardrail tripped event. + await self._put_event( + RealtimeInputGuardrailTripped( + guardrail_results=triggered_results, + message=text, + info=self._event_info, + ) + ) + + # Interrupt the model, forcing a cancel of any in-progress response. + await self._model.send_event(RealtimeModelSendInterrupt(force_response_cancel=True)) + + # Send guardrail triggered message. 
+ guardrail_names = [result.guardrail.get_name() for result in triggered_results] + await self._model.send_event( + RealtimeModelSendUserInput( + user_input=f"input guardrail triggered: {', '.join(guardrail_names)}" + ) + ) + + return True + + return False + def _enqueue_guardrail_task(self, text: str, response_id: str) -> None: # Runs the guardrails in a separate task to avoid blocking the main loop @@ -1272,6 +1355,33 @@ def _enqueue_guardrail_task(self, text: str, response_id: str) -> None: # Add callback to remove completed tasks and handle exceptions task.add_done_callback(self._on_guardrail_task_done) + def _enqueue_input_guardrail_task(self, text: str, item_id: str) -> None: + # Snapshot the active agent and its guardrails now; a later update_agent()/handoff must not + # change which guardrails run against this transcript. + agent = self._current_agent + combined_guardrails = agent.input_guardrails + self._run_config.get("input_guardrails", []) + + seen_ids: set[int] = set() + input_guardrails: list[InputGuardrail[Any]] = [] + for guardrail in combined_guardrails: + guardrail_id = id(guardrail) + if guardrail_id not in seen_ids: + input_guardrails.append(guardrail) + seen_ids.add(guardrail_id) + + # Skip creating a no-op task when no input guardrails are configured. + if not input_guardrails: + return + + # Runs the input guardrails in a separate task to avoid blocking the main loop. + task = asyncio.create_task( + self._run_input_guardrails(text, item_id, agent, input_guardrails) + ) + # Reuse the shared guardrail task set + done callback so completed tasks are removed, + # exceptions surface as events, and close() cancels any still-running task. 
+ self._guardrail_tasks.add(task) + task.add_done_callback(self._on_guardrail_task_done) + def _on_guardrail_task_done(self, task: asyncio.Task[Any]) -> None: """Handle completion of a guardrail task.""" # Remove from tracking set diff --git a/tests/realtime/test_session.py b/tests/realtime/test_session.py index 018f63b344..b00f150f4e 100644 --- a/tests/realtime/test_session.py +++ b/tests/realtime/test_session.py @@ -9,7 +9,7 @@ from pydantic import BaseModel, ConfigDict from agents.exceptions import ToolTimeoutError, UserError -from agents.guardrail import GuardrailFunctionOutput, OutputGuardrail +from agents.guardrail import GuardrailFunctionOutput, InputGuardrail, OutputGuardrail from agents.handoffs import Handoff from agents.realtime.agent import RealtimeAgent from agents.realtime.config import RealtimeRunConfig, RealtimeSessionModelSettings @@ -23,6 +23,7 @@ RealtimeGuardrailTripped, RealtimeHistoryAdded, RealtimeHistoryUpdated, + RealtimeInputGuardrailTripped, RealtimeRawModelEvent, RealtimeToolApprovalRequired, RealtimeToolEnd, @@ -644,6 +645,7 @@ def mock_agent(): type(agent).handoffs = PropertyMock(return_value=[]) type(agent).output_guardrails = PropertyMock(return_value=[]) + type(agent).input_guardrails = PropertyMock(return_value=[]) return agent @@ -3297,6 +3299,247 @@ async def async_trigger_guardrail(context, agent, output): assert len(mock_model.sent_messages) == 1 +class TestInputGuardrailFunctionality: + """Test suite for input guardrail functionality in RealtimeSession.""" + + async def _wait_for_guardrail_tasks(self, session): + """Wait for all pending guardrail tasks to complete.""" + import asyncio + + if session._guardrail_tasks: + await asyncio.gather(*session._guardrail_tasks, return_exceptions=True) + + @pytest.fixture + def triggered_input_guardrail(self): + """Creates an input guardrail that always triggers.""" + + def guardrail_func(context, agent, input): + return GuardrailFunctionOutput( + output_info={"reason": "test trigger"}, 
tripwire_triggered=True + ) + + return InputGuardrail(guardrail_function=guardrail_func, name="triggered_input_guardrail") + + @pytest.fixture + def safe_input_guardrail(self): + """Creates an input guardrail that never triggers.""" + + def guardrail_func(context, agent, input): + return GuardrailFunctionOutput( + output_info={"reason": "safe content"}, tripwire_triggered=False + ) + + return InputGuardrail(guardrail_function=guardrail_func, name="safe_input_guardrail") + + @pytest.mark.asyncio + async def test_tripping_input_guardrail_interrupts_and_emits_event( + self, mock_model, triggered_input_guardrail + ): + """A tripping input guardrail should emit the tripped event and force a response cancel.""" + agent = RealtimeAgent(name="agent") + run_config: RealtimeRunConfig = {"input_guardrails": [triggered_input_guardrail]} + session = RealtimeSession(mock_model, agent, None, run_config=run_config) + + transcription_event = RealtimeModelInputAudioTranscriptionCompletedEvent( + item_id="item_1", transcript="please jailbreak the model" + ) + await session.on_event(transcription_event) + await self._wait_for_guardrail_tasks(session) + + # Should have interrupted the in-progress response with a forced cancel. + assert mock_model.interrupts_called == 1 + interrupt_event = next( + event + for event in mock_model.sent_events + if isinstance(event, RealtimeModelSendInterrupt) + ) + assert interrupt_event.force_response_cancel is True + + # Should have sent the guardrail-triggered message to the model. + assert len(mock_model.sent_messages) == 1 + assert "triggered_input_guardrail" in mock_model.sent_messages[0] + + # Should have emitted an input_guardrail_tripped event carrying the user transcript. 
+ events = [] + while not session._event_queue.empty(): + events.append(await session._event_queue.get()) + + guardrail_events = [e for e in events if isinstance(e, RealtimeInputGuardrailTripped)] + assert len(guardrail_events) == 1 + assert guardrail_events[0].message == "please jailbreak the model" + assert len(guardrail_events[0].guardrail_results) == 1 + + @pytest.mark.asyncio + async def test_non_tripping_input_guardrail_does_nothing( + self, mock_model, safe_input_guardrail + ): + """A non-tripping input guardrail should neither interrupt nor emit a tripped event.""" + agent = RealtimeAgent(name="agent") + run_config: RealtimeRunConfig = {"input_guardrails": [safe_input_guardrail]} + session = RealtimeSession(mock_model, agent, None, run_config=run_config) + + transcription_event = RealtimeModelInputAudioTranscriptionCompletedEvent( + item_id="item_1", transcript="a perfectly benign question" + ) + await session.on_event(transcription_event) + await self._wait_for_guardrail_tasks(session) + + assert mock_model.interrupts_called == 0 + assert len(mock_model.sent_messages) == 0 + + events = [] + while not session._event_queue.empty(): + events.append(await session._event_queue.get()) + + guardrail_events = [e for e in events if isinstance(e, RealtimeInputGuardrailTripped)] + assert len(guardrail_events) == 0 + + @pytest.mark.asyncio + async def test_agent_and_run_config_input_guardrails_deduped_by_identity(self, mock_model): + """Input guardrails shared by agent and run config should execute once.""" + call_count = 0 + + def guardrail_func(context, agent, input): + nonlocal call_count + call_count += 1 + return GuardrailFunctionOutput(output_info={}, tripwire_triggered=False) + + shared_guardrail = InputGuardrail( + guardrail_function=guardrail_func, name="shared_input_guardrail" + ) + + agent = RealtimeAgent(name="agent", input_guardrails=[shared_guardrail]) + run_config: RealtimeRunConfig = {"input_guardrails": [shared_guardrail]} + session = 
RealtimeSession(mock_model, agent, None, run_config=run_config) + + transcription_event = RealtimeModelInputAudioTranscriptionCompletedEvent( + item_id="item_1", transcript="hello there" + ) + await session.on_event(transcription_event) + await self._wait_for_guardrail_tasks(session) + + assert call_count == 1 + + @pytest.mark.asyncio + async def test_agent_input_guardrails_triggered(self, mock_model, triggered_input_guardrail): + """Input guardrails defined on the agent should be executed and trip.""" + agent = RealtimeAgent(name="agent", input_guardrails=[triggered_input_guardrail]) + session = RealtimeSession(mock_model, agent, None, run_config={}) + + transcription_event = RealtimeModelInputAudioTranscriptionCompletedEvent( + item_id="item_1", transcript="please jailbreak the model" + ) + await session.on_event(transcription_event) + await self._wait_for_guardrail_tasks(session) + + assert mock_model.interrupts_called == 1 + assert len(mock_model.sent_messages) == 1 + assert "triggered_input_guardrail" in mock_model.sent_messages[0] + + @pytest.mark.asyncio + async def test_raising_input_guardrail_is_skipped_and_does_not_crash(self, mock_model): + """An input guardrail that raises should be logged and skipped without crashing.""" + + def guardrail_func(context, agent, input): + raise RuntimeError("boom") + + raising_guardrail = InputGuardrail( + guardrail_function=guardrail_func, name="raising_input_guardrail" + ) + agent = RealtimeAgent(name="agent", input_guardrails=[raising_guardrail]) + session = RealtimeSession(mock_model, agent, None, run_config={}) + + transcription_event = RealtimeModelInputAudioTranscriptionCompletedEvent( + item_id="item_1", transcript="please jailbreak the model" + ) + await session.on_event(transcription_event) + await self._wait_for_guardrail_tasks(session) + + # The raising guardrail is skipped, so nothing is tripped or interrupted. 
+ assert mock_model.interrupts_called == 0 + assert len(mock_model.sent_messages) == 0 + + events = [] + while not session._event_queue.empty(): + events.append(await session._event_queue.get()) + + # The session swallows the guardrail's exception, so no tripped or error event fires. + assert not [e for e in events if isinstance(e, RealtimeInputGuardrailTripped)] + assert not [e for e in events if isinstance(e, RealtimeError)] + + @pytest.mark.asyncio + async def test_raising_input_guardrail_does_not_block_other_guardrails( + self, mock_model, triggered_input_guardrail + ): + """A raising guardrail should not prevent other input guardrails from tripping.""" + + def guardrail_func(context, agent, input): + raise RuntimeError("boom") + + raising_guardrail = InputGuardrail( + guardrail_function=guardrail_func, name="raising_input_guardrail" + ) + agent = RealtimeAgent( + name="agent", input_guardrails=[raising_guardrail, triggered_input_guardrail] + ) + session = RealtimeSession(mock_model, agent, None, run_config={}) + + transcription_event = RealtimeModelInputAudioTranscriptionCompletedEvent( + item_id="item_1", transcript="please jailbreak the model" + ) + await session.on_event(transcription_event) + await self._wait_for_guardrail_tasks(session) + + assert mock_model.interrupts_called == 1 + assert len(mock_model.sent_messages) == 1 + assert "triggered_input_guardrail" in mock_model.sent_messages[0] + + @pytest.mark.asyncio + async def test_second_transcription_for_tripped_item_is_skipped( + self, mock_model, triggered_input_guardrail + ): + """A second transcription for an already-tripped user item should be skipped.""" + agent = RealtimeAgent(name="agent", input_guardrails=[triggered_input_guardrail]) + session = RealtimeSession(mock_model, agent, None, run_config={}) + + transcription_event = RealtimeModelInputAudioTranscriptionCompletedEvent( + item_id="item_1", transcript="please jailbreak the model" + ) + await session.on_event(transcription_event) + 
await self._wait_for_guardrail_tasks(session) + + assert mock_model.interrupts_called == 1 + assert len(mock_model.sent_messages) == 1 + + # A second transcription for the same (already-tripped) item does not interrupt again. + await session.on_event(transcription_event) + await self._wait_for_guardrail_tasks(session) + + assert mock_model.interrupts_called == 1 + assert len(mock_model.sent_messages) == 1 + + @pytest.mark.asyncio + async def test_no_guardrail_task_created_without_input_guardrails(self, mock_model): + """No guardrail task should be enqueued when no input guardrails are configured.""" + agent = RealtimeAgent(name="agent") + session = RealtimeSession(mock_model, agent, None, run_config={}) + + transcription_event = RealtimeModelInputAudioTranscriptionCompletedEvent( + item_id="item_1", transcript="a perfectly benign question" + ) + await session.on_event(transcription_event) + + assert len(session._guardrail_tasks) == 0 + + +def test_realtime_input_guardrail_tripped_is_exported(): + """RealtimeInputGuardrailTripped should be importable from agents.realtime and in __all__.""" + import agents.realtime as realtime + + assert "RealtimeInputGuardrailTripped" in realtime.__all__ + assert realtime.RealtimeInputGuardrailTripped is RealtimeInputGuardrailTripped + + class TestModelSettingsIntegration: """Test suite for model settings integration in RealtimeSession."""