From 8d233a079a7f71ee408ca37fea353ea1e6bdb689 Mon Sep 17 00:00:00 2001
From: Tim Conley
Date: Thu, 19 Feb 2026 09:20:45 -0800
Subject: [PATCH 1/2] Add batch-based streaming support for OpenAI Agents
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Implements a streaming API using a list-based approach where:
- Stream events are collected during activity execution
- The complete list is returned when the activity finishes
- Workflows can iterate over events deterministically
- No real-time signaling, to maintain workflow determinism

Changes:
- Add batch_stream_model activity for collecting streaming events
- Implement stream_response method in _TemporalModelStub
- Update TemporalOpenAIRunner to support run_streamed
- Add streaming documentation and update feature support
- Refactor common code into helper functions

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude
---
 temporalio/contrib/openai_agents/README.md    |  31 ++-
 .../openai_agents/_invoke_model_activity.py   | 214 +++++++++++-------
 .../contrib/openai_agents/_openai_runner.py   | 132 ++++++-----
 .../openai_agents/_temporal_model_stub.py     |  69 +++++-
 .../openai_agents/_temporal_openai_agents.py  |   6 +-
 5 files changed, 301 insertions(+), 151 deletions(-)

diff --git a/temporalio/contrib/openai_agents/README.md b/temporalio/contrib/openai_agents/README.md
index 45aa51fb6..ba8077767 100644
--- a/temporalio/contrib/openai_agents/README.md
+++ b/temporalio/contrib/openai_agents/README.md
@@ -451,10 +451,35 @@ To recover from such failures, you need to implement your own application-level
 
 For network-accessible MCP servers, you can also use `HostedMCPTool` from the OpenAI Agents SDK, which uses an MCP client hosted by OpenAI.
 
+## Streaming
+
+Streaming can be enabled by using the Agents SDK's `Runner.run_streamed` API. Rather than delivering events in real time, this integration collects all streaming events from the LLM into a list and delivers them when the activity completes, allowing workflows to iterate over the stream events.
+
+```python
+from agents import Runner
+from openai.types.responses import ResponseTextDeltaEvent
+
+# In your workflow
+result = Runner.run_streamed(
+    starting_agent=my_agent,
+    input="Hello, stream this response!",
+)
+async for event in result.stream_events():
+    # Process each streaming event
+    if event.type == "raw_response_event" and isinstance(
+        event.data, ResponseTextDeltaEvent
+    ):
+        print(f"Streamed content: {event.data.delta}")
+```
+
+The streaming implementation:
+- Collects all stream events during the activity execution
+- Returns the complete list when the activity finishes
+- Allows workflows to iterate over events deterministically
+- Supports the same model configurations and tools as non-streaming calls
+
+Note that stream events are only delivered to the workflow after the entire LLM response is complete, ensuring deterministic execution in Temporal workflows; a workflow can still expose the collected events to clients, for example via a query (see the sketch below).
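+
+Because the full event list is recorded as an activity result, the workflow can surface incremental text to clients through a query over the collected deltas. A minimal sketch (continuing the imports above; `my_agent`, the class name, and the query name are illustrative, not part of this integration's API):
+
+```python
+from temporalio import workflow
+
+
+@workflow.defn
+class StreamingWorkflow:
+    def __init__(self) -> None:
+        self._chunks: list[str] = []
+
+    @workflow.run
+    async def run(self, prompt: str) -> str:
+        result = Runner.run_streamed(starting_agent=my_agent, input=prompt)
+        async for event in result.stream_events():
+            # Record text deltas so clients can query partial output
+            if event.type == "raw_response_event" and isinstance(
+                event.data, ResponseTextDeltaEvent
+            ):
+                self._chunks.append(event.data.delta)
+        return result.final_output
+
+    @workflow.query
+    def partial_text(self) -> str:
+        return "".join(self._chunks)
+```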
+
 ## Feature Support
 
 This integration is presently subject to certain limitations.
-Streaming and voice agents are not supported.
+Voice agents are not supported.
 Certain tools are not suitable for a distributed computing environment, so these have been disabled as well.
 
 ### Model Providers
 
@@ -466,12 +491,10 @@ Certain tools are not suitable for a distributed computing environment, so these
 
 ### Model Response format
 
-This integration does not presently support streaming.
-
 | Model Response | Supported |
 | :------------- | :-------: |
 | Get Response   |    Yes    |
-| Streaming      |    No     |
+| Streaming      |    Yes    |
 
 ### Tools
 
diff --git a/temporalio/contrib/openai_agents/_invoke_model_activity.py b/temporalio/contrib/openai_agents/_invoke_model_activity.py
index 945a05ec6..04f15949d 100644
--- a/temporalio/contrib/openai_agents/_invoke_model_activity.py
+++ b/temporalio/contrib/openai_agents/_invoke_model_activity.py
@@ -6,7 +6,7 @@
 import enum
 from dataclasses import dataclass
 from datetime import timedelta
-from typing import Any
+from typing import Any, NoReturn
 
 from agents import (
     AgentOutputSchemaBase,
@@ -27,6 +27,7 @@
     UserError,
     WebSearchTool,
 )
+from agents.items import TResponseStreamEvent
 from openai import (
     APIStatusError,
     AsyncOpenAI,
@@ -163,54 +164,8 @@ async def invoke_model_activity(self, input: ActivityModelInput) -> ModelRespons
         """Activity that invokes a model with the given input."""
         model = self._model_provider.get_model(input.get("model_name"))
 
-        async def empty_on_invoke_tool(
-            _ctx: RunContextWrapper[Any], _input: str
-        ) -> str:
-            return ""
-
-        async def empty_on_invoke_handoff(
-            _ctx: RunContextWrapper[Any], _input: str
-        ) -> Any:
-            return None
-
-        def make_tool(tool: ToolInput) -> Tool:
-            if isinstance(
-                tool,
-                (
-                    FileSearchTool,
-                    WebSearchTool,
-                    ImageGenerationTool,
-                    CodeInterpreterTool,
-                ),
-            ):
-                return tool
-            elif isinstance(tool, HostedMCPToolInput):
-                return HostedMCPTool(
-                    tool_config=tool.tool_config,
-                )
-            elif isinstance(tool, FunctionToolInput):
-                return FunctionTool(
-                    name=tool.name,
-                    description=tool.description,
-                    params_json_schema=tool.params_json_schema,
-                    on_invoke_tool=empty_on_invoke_tool,
-                    strict_json_schema=tool.strict_json_schema,
-                )
-            else:
-                raise UserError(f"Unknown tool type: {tool.name}")  # type:ignore[reportUnreachable]
-
-        tools = [make_tool(x) for x in input.get("tools", [])]
-        handoffs: list[Handoff[Any, Any]] = [
-            Handoff(
-                tool_name=x.tool_name,
-                tool_description=x.tool_description,
-                input_json_schema=x.input_json_schema,
-                agent_name=x.agent_name,
-                strict_json_schema=x.strict_json_schema,
-                on_invoke_handoff=empty_on_invoke_handoff,
-            )
-            for x in input.get("handoffs", [])
-        ]
+        tools = _make_tools(input)
+        handoffs = _make_handoffs(input)
 
         try:
             return await model.get_response(
@@ -226,40 +181,127 @@ def make_tool(tool: ToolInput) -> Tool:
                 prompt=input.get("prompt"),
             )
         except APIStatusError as e:
-            # Listen to server hints
-            retry_after = None
-            retry_after_ms_header = e.response.headers.get("retry-after-ms")
-            if retry_after_ms_header is not None:
-                retry_after = timedelta(milliseconds=float(retry_after_ms_header))
-
-            if retry_after is None:
-                retry_after_header = e.response.headers.get("retry-after")
-                if retry_after_header is not None:
-                    retry_after = timedelta(seconds=float(retry_after_header))
-
-            should_retry_header = e.response.headers.get("x-should-retry")
-            if should_retry_header == "true":
-                raise e
-            if should_retry_header == "false":
-                raise ApplicationError(
-                    "Non retryable OpenAI error",
-                    non_retryable=True,
-                    next_retry_delay=retry_after,
-                ) from e
-
-            # Specifically retryable status codes
-            if (
-                e.response.status_code in [408, 409, 429]
-                or e.response.status_code >= 500
-            ):
-                raise ApplicationError(
-                    f"Retryable OpenAI status code: {e.response.status_code}",
-                    non_retryable=False,
-                    next_retry_delay=retry_after,
-                ) from e
-
-            raise ApplicationError(
-                f"Non retryable OpenAI status code: {e.response.status_code}",
-                non_retryable=True,
-                next_retry_delay=retry_after,
-            ) from e
+            _handle_error(e)
+
+    @activity.defn
+    @_auto_heartbeater
+    async def batch_stream_model(
+        self, input: ActivityModelInput
+    ) -> list[TResponseStreamEvent]:
+        """Activity that streams a model with the given input, returning all events as a list."""
+        model = self._model_provider.get_model(input.get("model_name"))
+
+        tools = _make_tools(input)
+        handoffs = _make_handoffs(input)
+
+        try:
+            events = model.stream_response(
+                system_instructions=input.get("system_instructions"),
+                input=input["input"],
+                model_settings=input["model_settings"],
+                tools=tools,
+                output_schema=input.get("output_schema"),
+                handoffs=handoffs,
+                tracing=ModelTracing(input["tracing"]),
+                previous_response_id=input.get("previous_response_id"),
+                conversation_id=input.get("conversation_id"),
+                prompt=input.get("prompt"),
+            )
+            result = []
+            async for event in events:
+                event.model_rebuild()
+                result.append(event)
+
+            return result
+        except APIStatusError as e:
+            _handle_error(e)
+
+
+async def _empty_on_invoke_tool(_ctx: RunContextWrapper[Any], _input: str) -> str:
+    return ""
+
+
+async def _empty_on_invoke_handoff(_ctx: RunContextWrapper[Any], _input: str) -> Any:
+    return None
+
+
+def _make_tool(tool: ToolInput) -> Tool:
+    if isinstance(
+        tool,
+        (
+            FileSearchTool,
+            WebSearchTool,
+            ImageGenerationTool,
+            CodeInterpreterTool,
+        ),
+    ):
+        return tool
+    elif isinstance(tool, HostedMCPToolInput):
+        return HostedMCPTool(
+            tool_config=tool.tool_config,
+        )
+    elif isinstance(tool, FunctionToolInput):
+        return FunctionTool(
+            name=tool.name,
+            description=tool.description,
+            params_json_schema=tool.params_json_schema,
+            on_invoke_tool=_empty_on_invoke_tool,
+            strict_json_schema=tool.strict_json_schema,
+        )
+    else:
+        raise UserError(f"Unknown tool type: {tool.name}")  # type: ignore[reportUnreachable]
+
+
+def _make_tools(input: ActivityModelInput) -> list[Tool]:
+    return [_make_tool(x) for x in input.get("tools", [])]
+
+
+def _make_handoffs(input: ActivityModelInput) -> list[Handoff[Any, Any]]:
+    return [
+        Handoff(
+            tool_name=x.tool_name,
+            tool_description=x.tool_description,
+            input_json_schema=x.input_json_schema,
+            agent_name=x.agent_name,
+            strict_json_schema=x.strict_json_schema,
+            on_invoke_handoff=_empty_on_invoke_handoff,
+        )
+        for x in input.get("handoffs", [])
+    ]
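+
+
+# Summary of the server retry hints handled below (descriptive comment only;
+# header values are examples):
+#   retry-after-ms: "250"   -> next_retry_delay of 250 milliseconds
+#   retry-after: "2"        -> next_retry_delay of 2 seconds
+#   x-should-retry: "true"  -> re-raise and defer to the activity retry policy
+#   x-should-retry: "false" -> non-retryable ApplicationError
+#   status 408/409/429/5xx  -> retryable ApplicationError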
+def _handle_error(e: APIStatusError) -> NoReturn:
+    # Listen to server hints
+    retry_after = None
+    retry_after_ms_header = e.response.headers.get("retry-after-ms")
+    if retry_after_ms_header is not None:
+        retry_after = timedelta(milliseconds=float(retry_after_ms_header))
+
+    if retry_after is None:
+        retry_after_header = e.response.headers.get("retry-after")
+        if retry_after_header is not None:
+            retry_after = timedelta(seconds=float(retry_after_header))
+
+    should_retry_header = e.response.headers.get("x-should-retry")
+    if should_retry_header == "true":
+        raise e
+    if should_retry_header == "false":
+        raise ApplicationError(
+            "Non retryable OpenAI error",
+            non_retryable=True,
+            next_retry_delay=retry_after,
+        ) from e
+
+    # Specifically retryable status codes
+    if e.response.status_code in [408, 409, 429] or e.response.status_code >= 500:
+        raise ApplicationError(
+            f"Retryable OpenAI status code: {e.response.status_code}",
+            non_retryable=False,
+            next_retry_delay=retry_after,
+        ) from e
+
+    raise ApplicationError(
+        f"Non retryable OpenAI status code: {e.response.status_code}",
+        non_retryable=True,
+        next_retry_delay=retry_after,
+    ) from e
diff --git a/temporalio/contrib/openai_agents/_openai_runner.py b/temporalio/contrib/openai_agents/_openai_runner.py
index 30e27f061..601b70c26 100644
--- a/temporalio/contrib/openai_agents/_openai_runner.py
+++ b/temporalio/contrib/openai_agents/_openai_runner.py
@@ -14,7 +14,7 @@
     TContext,
     TResponseInputItem,
 )
-from agents.run import DEFAULT_AGENT_RUNNER, DEFAULT_MAX_TURNS, AgentRunner
+from agents.run import DEFAULT_AGENT_RUNNER, AgentRunner
 
 from temporalio import workflow
 from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
@@ -104,65 +104,15 @@ async def run(
                 **kwargs,
             )
 
-        for t in starting_agent.tools:
-            if callable(t):
-                raise ValueError(
-                    "Provided tool is not a tool type. If using an activity, make sure to wrap it with openai_agents.workflow.activity_as_tool."
-                )
-
-        if starting_agent.mcp_servers:
-            from temporalio.contrib.openai_agents._mcp import (
-                _StatefulMCPServerReference,
-                _StatelessMCPServerReference,
-            )
-
-            for s in starting_agent.mcp_servers:
-                if not isinstance(
-                    s,
-                    (
-                        _StatelessMCPServerReference,
-                        _StatefulMCPServerReference,
-                    ),
-                ):
-                    raise ValueError(
-                        f"Unknown mcp_server type {type(s)} may not work durably."
-                    )
-
-        context = kwargs.get("context")
-        max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS)
-        hooks = kwargs.get("hooks")
-        run_config = kwargs.get("run_config")
-        previous_response_id = kwargs.get("previous_response_id")
-        session = kwargs.get("session")
-
-        if isinstance(session, SQLiteSession):
-            raise ValueError("Temporal workflows don't support SQLite sessions.")
-
-        if run_config is None:
-            run_config = RunConfig()
+        _check_preconditions(starting_agent, **kwargs)
 
-        if run_config.model:
-            if not isinstance(run_config.model, str):
-                raise ValueError(
-                    "Temporal workflows require a model name to be a string in the run config."
-                )
-            run_config = dataclasses.replace(
-                run_config,
-                model=_TemporalModelStub(
-                    run_config.model, model_params=self.model_params, agent=None
-                ),
-            )
+        kwargs["run_config"] = self._process_run_config(kwargs.get("run_config"))
 
         try:
             return await self._runner.run(
                 starting_agent=_convert_agent(self.model_params, starting_agent, None),
                 input=input,
-                context=context,
-                max_turns=max_turns,
-                hooks=hooks,
-                run_config=run_config,
-                previous_response_id=previous_response_id,
-                session=session,
+                **kwargs,
             )
         except AgentsException as e:
             # In order for workflow failures to properly fail the workflow, we need to rewrap them in
@@ -176,6 +126,25 @@ async def run(
             else:
                 raise e
 
+    def _process_run_config(self, run_config: RunConfig | None) -> RunConfig:
+        if run_config is None:
+            run_config = RunConfig()
+
+        if run_config.model:
+            if not isinstance(run_config.model, str):
+                raise ValueError(
+                    "Temporal workflows require a model name to be a string in the run config."
+                )
+            run_config = dataclasses.replace(
+                run_config,
+                model=_TemporalModelStub(
+                    run_config.model,
+                    model_params=self.model_params,
+                    agent=None,
+                ),
+            )
+        return run_config
+
     def run_sync(
         self,
         starting_agent: Agent[TContext],
@@ -197,14 +166,35 @@ def run_streamed(
         input: str | list[TResponseInputItem],
         **kwargs: Any,
     ) -> RunResultStreaming:
-        """Run the agent with streaming responses (not supported in Temporal workflows)."""
+        """Run the agent with streaming responses."""
         if not workflow.in_workflow():
             return self._runner.run_streamed(
                 starting_agent,
                 input,
                 **kwargs,
             )
-        raise RuntimeError("Temporal workflows do not support streaming.")
+
+        _check_preconditions(starting_agent, **kwargs)
+
+        kwargs["run_config"] = self._process_run_config(kwargs.get("run_config"))
+
+        try:
+            return self._runner.run_streamed(
+                starting_agent=_convert_agent(self.model_params, starting_agent, None),
+                input=input,
+                **kwargs,
+            )
+        except AgentsException as e:
+            # In order for workflow failures to properly fail the workflow, we need to rewrap them in
+            # a Temporal error
+            if e.__cause__ and workflow.is_failure_exception(e.__cause__):
+                reraise = AgentsWorkflowError(
+                    f"Workflow failure exception in Agents Framework: {e}"
+                )
+                reraise.__traceback__ = e.__traceback__
+                raise reraise from e.__cause__
+            else:
+                raise e
 
 
 def _model_name(agent: Agent[Any]) -> str | None:
@@ -214,3 +204,33 @@ def _model_name(agent: Agent[Any]) -> str | None:
         "Temporal workflows require a model name to be a string in the agent."
     )
     return name
+
+
+def _check_preconditions(starting_agent: Agent[TContext], **kwargs: Any) -> None:
+    for t in starting_agent.tools:
+        if callable(t):
+            raise ValueError(
+                "Provided tool is not a tool type. If using an activity, make sure to wrap it with openai_agents.workflow.activity_as_tool."
+            )
+
+    if starting_agent.mcp_servers:
+        from temporalio.contrib.openai_agents._mcp import (
+            _StatefulMCPServerReference,
+            _StatelessMCPServerReference,
+        )
+
+        for s in starting_agent.mcp_servers:
+            if not isinstance(
+                s,
+                (
+                    _StatelessMCPServerReference,
+                    _StatefulMCPServerReference,
+                ),
+            ):
+                raise ValueError(
+                    f"Unknown mcp_server type {type(s)} may not work durably."
+                )
+
+    session = kwargs.get("session")
+    if isinstance(session, SQLiteSession):
+        raise ValueError("Temporal workflows don't support SQLite sessions.")
diff --git a/temporalio/contrib/openai_agents/_temporal_model_stub.py b/temporalio/contrib/openai_agents/_temporal_model_stub.py
index f55821309..d1b66e4dd 100644
--- a/temporalio/contrib/openai_agents/_temporal_model_stub.py
+++ b/temporalio/contrib/openai_agents/_temporal_model_stub.py
@@ -71,6 +71,38 @@ async def get_response(
         conversation_id: str | None,
         prompt: ResponsePromptParam | None,
     ) -> ModelResponse:
+        activity_input, summary = self._prepare_activity_input(
+            system_instructions=system_instructions,
+            input=input,
+            model_settings=model_settings,
+            tools=tools,
+            output_schema=output_schema,
+            handoffs=handoffs,
+            tracing=tracing,
+            previous_response_id=previous_response_id,
+            conversation_id=conversation_id,
+            prompt=prompt,
+        )
+
+        return await self._execute_activity(
+            ModelActivity.invoke_model_activity,
+            activity_input,
+            summary,
+        )
+
+    def _prepare_activity_input(
+        self,
+        system_instructions: str | None,
+        input: str | list[TResponseInputItem],
+        model_settings: ModelSettings,
+        tools: list[Tool],
+        output_schema: AgentOutputSchemaBase | None,
+        handoffs: list[Handoff],
+        tracing: ModelTracing,
+        previous_response_id: str | None,
+        conversation_id: str | None,
+        prompt: ResponsePromptParam | None,
+    ) -> tuple[ActivityModelInput, str | None]:
         def make_tool_info(tool: Tool) -> ToolInput:
             if isinstance(
                 tool,
@@ -154,9 +186,17 @@ def make_tool_info(tool: Tool) -> ToolInput:
         else:
             summary = None
 
+        return activity_input, summary
+
+    async def _execute_activity(
+        self,
+        activity_method: Any,
+        activity_input: ActivityModelInput,
+        summary: str | None,
+    ) -> Any:
         if self.model_params.use_local_activity:
             return await workflow.execute_local_activity_method(
-                ModelActivity.invoke_model_activity,
+                activity_method,
                 activity_input,
                 summary=summary,
                 schedule_to_close_timeout=self.model_params.schedule_to_close_timeout,
@@ -167,7 +207,7 @@ def make_tool_info(tool: Tool) -> ToolInput:
             )
         else:
             return await workflow.execute_activity_method(
-                ModelActivity.invoke_model_activity,
+                activity_method,
                 activity_input,
                 summary=summary,
                 task_queue=self.model_params.task_queue,
@@ -181,7 +221,7 @@ def make_tool_info(tool: Tool) -> ToolInput:
                 priority=self.model_params.priority,
             )
 
-    def stream_response(
+    async def stream_response(
         self,
         system_instructions: str | None,
         input: str | list[TResponseInputItem],
@@ -195,4 +235,25 @@ def stream_response(
         conversation_id: str | None,
         prompt: ResponsePromptParam | None,
     ) -> AsyncIterator[TResponseStreamEvent]:
-        raise NotImplementedError("Temporal model doesn't support streams yet")
+        activity_input, summary = self._prepare_activity_input(
+            system_instructions=system_instructions,
+            input=input,
+            model_settings=model_settings,
+            tools=tools,
+            output_schema=output_schema,
+            handoffs=handoffs,
+            tracing=tracing,
+            previous_response_id=previous_response_id,
+            conversation_id=conversation_id,
+            prompt=prompt,
+        )
+
+        events = await self._execute_activity(
+            ModelActivity.batch_stream_model,
+            activity_input,
+            summary,
+        )
+
+        # Convert the list of events into an async iterator
+        for event in events:
+            yield event
diff --git a/temporalio/contrib/openai_agents/_temporal_openai_agents.py b/temporalio/contrib/openai_agents/_temporal_openai_agents.py
index 16a1403ef..a0c7a7c08 100644
--- a/temporalio/contrib/openai_agents/_temporal_openai_agents.py
+++ b/temporalio/contrib/openai_agents/_temporal_openai_agents.py
@@ -229,7 +229,11 @@ def add_activities(
         if not register_activities:
             return activities or []
 
-        new_activities = [ModelActivity(model_provider).invoke_model_activity]
+        model_activity = ModelActivity(model_provider)
+        new_activities = [
+            model_activity.invoke_model_activity,
+            model_activity.batch_stream_model,
+        ]
 
         server_names = [server.name for server in mcp_server_providers]
         if len(server_names) != len(set(server_names)):

From 49378add79212b914119f46a39ee32ee7adc85e2 Mon Sep 17 00:00:00 2001
From: Tim Conley
Date: Thu, 19 Feb 2026 09:28:32 -0800
Subject: [PATCH 2/2] Add tests for batch streaming functionality
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Add EventBuilders class for creating test streaming events
- Update TestModel to support streaming with streaming_fn parameter
- Add streaming factory methods: streaming_events and streaming_events_with_ending
- Add test_batch_streaming to verify streaming works end-to-end
- Update testing module exports to include EventBuilders

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude
---
 temporalio/contrib/openai_agents/testing.py | 145 +++++++++++++++++++-
 tests/contrib/openai_agents/test_openai.py  |  69 ++++++++++
 2 files changed, 211 insertions(+), 3 deletions(-)

diff --git a/temporalio/contrib/openai_agents/testing.py b/temporalio/contrib/openai_agents/testing.py
index 45a7a1465..abac16cb2 100644
--- a/temporalio/contrib/openai_agents/testing.py
+++ b/temporalio/contrib/openai_agents/testing.py
@@ -17,9 +17,14 @@
 )
 from agents.items import TResponseOutputItem, TResponseStreamEvent
 from openai.types.responses import (
+    Response,
+    ResponseCompletedEvent,
+    ResponseContentPartDoneEvent,
     ResponseFunctionToolCall,
+    ResponseOutputItemDoneEvent,
     ResponseOutputMessage,
     ResponseOutputText,
+    ResponseTextDeltaEvent,
 )
 
 from temporalio.client import Client
@@ -32,6 +37,7 @@
 
 __all__ = [
     "AgentEnvironment",
+    "EventBuilders",
     "ResponseBuilders",
     "TestModel",
     "TestModelProvider",
@@ -109,6 +115,100 @@ def output_message(text: str) -> ModelResponse:
         )
 
 
+class EventBuilders:
+    """Builders for creating stream events for testing.
+
+    .. warning::
+        This API is experimental and may change in the future.
+    """
+
+    @staticmethod
+    def text_delta(text: str) -> ResponseTextDeltaEvent:
+        """Create a TResponseStreamEvent with a text delta.
+
+        .. warning::
+            This API is experimental and may change in the future.
+        """
+        return ResponseTextDeltaEvent(
+            content_index=0,
+            delta=text,
+            item_id="",
+            logprobs=[],
+            output_index=0,
+            sequence_number=0,
+            type="response.output_text.delta",
+        )
+
+    @staticmethod
+    def content_part_done(text: str) -> TResponseStreamEvent:
+        """Create a TResponseStreamEvent for content part completion.
+
+        .. warning::
+            This API is experimental and may change in the future.
+        """
+        return ResponseContentPartDoneEvent(
+            content_index=0,
+            item_id="",
+            output_index=0,
+            sequence_number=0,
+            type="response.content_part.done",
+            part=ResponseOutputText(
+                text=text,
+                annotations=[],
+                type="output_text",
+            ),
+        )
+
+    @staticmethod
+    def output_item_done(text: str) -> TResponseStreamEvent:
+        """Create a TResponseStreamEvent for output item completion.
+
+        .. warning::
+            This API is experimental and may change in the future.
+        """
+        return ResponseOutputItemDoneEvent(
+            output_index=0,
+            sequence_number=0,
+            type="response.output_item.done",
+            item=ResponseBuilders.response_output_message(text),
+        )
+
+    @staticmethod
+    def response_completion(text: str) -> TResponseStreamEvent:
+        """Create a TResponseStreamEvent for response completion.
+
+        .. warning::
+            This API is experimental and may change in the future.
+        """
+        return ResponseCompletedEvent(
+            response=Response(
+                id="",
+                created_at=0.0,
+                object="response",
+                model="",
+                parallel_tool_calls=False,
+                tool_choice="none",
+                tools=[],
+                output=[ResponseBuilders.response_output_message(text)],
+            ),
+            sequence_number=0,
+            type="response.completed",
+        )
+
+    @staticmethod
+    def ending(text: str) -> list[TResponseStreamEvent]:
+        """Create the list of TResponseStreamEvents that ends a stream.
+
+        .. warning::
+            This API is experimental and may change in the future.
+        """
+        return [
+            EventBuilders.content_part_done(text),
+            EventBuilders.output_item_done(text),
+            EventBuilders.response_completion(text),
+        ]
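+
+
+# For reference, TestModel.streaming_events_with_ending (defined below) is
+# roughly equivalent to this illustrative composition:
+#
+#     TestModel.streaming_events(
+#         [EventBuilders.text_delta("Hi"), *EventBuilders.ending("Hi")]
+#     )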
+ """ + return ResponseOutputItemDoneEvent( + output_index=0, + sequence_number=0, + type="response.output_item.done", + item=ResponseBuilders.response_output_message(text), + ) + + @staticmethod + def response_completion(text: str) -> TResponseStreamEvent: + """Create a TResponseStreamEvent for response completion. + + .. warning:: + This API is experimental and may change in the future. + """ + return ResponseCompletedEvent( + response=Response( + id="", + created_at=0.0, + object="response", + model="", + parallel_tool_calls=False, + tool_choice="none", + tools=[], + output=[ResponseBuilders.response_output_message(text)], + ), + sequence_number=0, + type="response.completed", + ) + + @staticmethod + def ending(text: str) -> list[TResponseStreamEvent]: + """Create a list of TResponseStreamEvent for the end of a stream. + + .. warning:: + This API is experimental and may change in the future. + """ + return [ + EventBuilders.content_part_done(text), + EventBuilders.output_item_done(text), + EventBuilders.response_completion(text), + ] + + class TestModelProvider(ModelProvider): """Test model provider which simply returns the given module. @@ -144,13 +244,19 @@ class TestModel(Model): __test__ = False - def __init__(self, fn: Callable[[], ModelResponse]) -> None: + def __init__( + self, + fn: Callable[[], ModelResponse] | None, + *, + streaming_fn: Callable[[], AsyncIterator[TResponseStreamEvent]] | None = None, + ) -> None: """Initialize a test model with a callable. .. warning:: This API is experimental and may change in the future. """ self.fn = fn + self.streaming_fn = streaming_fn async def get_response( self, @@ -164,6 +270,8 @@ async def get_response( **kwargs: Any, ) -> ModelResponse: """Get a response from the mocked model, by calling the callable passed to the constructor.""" + if self.fn is None: + raise ValueError("No non-streaming function provided") return self.fn() def stream_response( @@ -177,8 +285,10 @@ def stream_response( tracing: ModelTracing, **kwargs: Any, ) -> AsyncIterator[TResponseStreamEvent]: - """Get a streamed response from the model. Unimplemented.""" - raise NotImplementedError() + """Get a streamed response from the model.""" + if self.streaming_fn is None: + raise ValueError("No streaming function provided") + return self.streaming_fn() @staticmethod def returning_responses(responses: list[ModelResponse]) -> "TestModel": @@ -190,6 +300,35 @@ def returning_responses(responses: list[ModelResponse]) -> "TestModel": i = iter(responses) return TestModel(lambda: next(i)) + @staticmethod + def streaming_events(events: list[TResponseStreamEvent]) -> "TestModel": + """Create a mock model which sequentially returns responses from a list. + + .. warning:: + This API is experimental and may change in the future. + """ + + async def generator(): + for event in events: + yield event + + return TestModel(None, streaming_fn=lambda: generator()) + + @staticmethod + def streaming_events_with_ending( + events: list[ResponseTextDeltaEvent], + ) -> "TestModel": + """Create a mock model which sequentially returns responses from a list. Appends ending markers + + .. warning:: + This API is experimental and may change in the future. + """ + content = "" + for event in events: + content += event.delta + + return TestModel.streaming_events(events + EventBuilders.ending(content)) + class AgentEnvironment: """Testing environment for OpenAI agents with Temporal integration. 
diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py
index 7a19afcc8..8dfc32903 100644
--- a/tests/contrib/openai_agents/test_openai.py
+++ b/tests/contrib/openai_agents/test_openai.py
@@ -63,6 +63,7 @@
     ResponseCodeInterpreterToolCall,
     ResponseFileSearchToolCall,
     ResponseFunctionWebSearch,
+    ResponseTextDeltaEvent,
 )
 from openai.types.responses.response_file_search_tool_call import Result
 from openai.types.responses.response_function_web_search import ActionSearch
@@ -90,6 +91,7 @@
 )
 from temporalio.contrib.openai_agents.testing import (
     AgentEnvironment,
+    EventBuilders,
     ResponseBuilders,
     TestModel,
     TestModelProvider,
@@ -2690,3 +2692,70 @@ async def test_multiple_handoffs_workflow(client: Client):
         "I'll analyze the requirements and create a plan."
         in planner_response_data
     )
+
+
+@workflow.defn
+class StreamingBatchTestWorkflow:
+    def __init__(self):
+        self.events = []
+
+    @workflow.run
+    async def run(self, prompt: str) -> str | None:
+        agent = Agent[None](
+            name="Assistant",
+            instructions="You are a helpful assistant.",
+        )
+
+        result = Runner.run_streamed(
+            starting_agent=agent,
+            input=prompt,
+        )
+        async for event in result.stream_events():
+            if event.type == "raw_response_event" and isinstance(
+                event.data, ResponseTextDeltaEvent
+            ):
+                self.events.append(event.data.delta)
+
+        return result.final_output if result else None
+
+    @workflow.query
+    def get_events(self) -> list[str]:
+        return self.events
+
+
+def streaming_hello_model():
+    return TestModel.streaming_events_with_ending(
+        [
+            EventBuilders.text_delta("Hello"),
+            EventBuilders.text_delta(" there"),
+            EventBuilders.text_delta("!"),
+        ]
+    )
+
+
+async def test_batch_streaming(client: Client):
+    async with AgentEnvironment(
+        model=streaming_hello_model(),
+        model_params=ModelActivityParameters(
+            start_to_close_timeout=timedelta(seconds=30),
+        ),
+    ) as env:
+        client = env.applied_on_client(client)
+
+        async with new_worker(
+            client, StreamingBatchTestWorkflow, max_cached_workflows=0
+        ) as worker:
+            handle = await client.start_workflow(
+                StreamingBatchTestWorkflow.run,
+                args=["Say hello."],
+                id=f"streaming-workflow-{uuid.uuid4()}",
+                task_queue=worker.task_queue,
+                execution_timeout=timedelta(seconds=50),
+            )
+            result = await handle.result()
+            assert result == "Hello there!"
+
+            # Verify we collected the streaming events
+            events = await handle.query(StreamingBatchTestWorkflow.get_events)
+            assert len(events) == 3
+            assert events == ["Hello", " there", "!"]
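+
+
+# Sketch (not exercised by the test above): the same stream can be built
+# without the convenience wrapper by passing explicit ending events.
+def streaming_explicit_ending_model():
+    deltas = [EventBuilders.text_delta("Hello"), EventBuilders.text_delta("!")]
+    return TestModel.streaming_events(deltas + EventBuilders.ending("Hello!"))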