31 changes: 27 additions & 4 deletions temporalio/contrib/openai_agents/README.md
@@ -451,10 +451,35 @@ To recover from such failures, you need to implement your own application-level

For network-accessible MCP servers, you can also use `HostedMCPTool` from the OpenAI Agents SDK, which uses an MCP client hosted by OpenAI.
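
For example, a hosted MCP server can be attached to an agent along these lines; this is a minimal sketch, and the `server_label` and `server_url` values are illustrative placeholders rather than values from this repository:

```python
from agents import Agent, HostedMCPTool

agent = Agent(
    name="Docs assistant",
    instructions="Answer questions using the documentation MCP server.",
    tools=[
        HostedMCPTool(
            tool_config={
                "type": "mcp",
                "server_label": "docs",  # placeholder label
                "server_url": "https://example.com/mcp",  # placeholder URL
                "require_approval": "never",
            }
        )
    ],
)
```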

## Streaming

Streaming can be enabled by using the Agents SDK's `Runner.run_streamed` API. Rather than forwarding tokens as they arrive, this integration collects all streaming events from the LLM inside the model activity and delivers them as a single list when the activity completes, allowing workflows to iterate over the stream events.

```python
from agents import Runner
from openai.types.responses import ResponseTextDeltaEvent

# In your workflow
result = Runner.run_streamed(
    starting_agent=my_agent,
    input="Hello, stream this response!",
)
async for event in result.stream_events():
    # Process each text-delta event from the raw response stream
    if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent):
        print(f"Streamed content: {event.data.delta}")
```

The streaming implementation:
- Collects all stream events during the activity execution
- Returns the complete list when the activity finishes
- Allows workflows to iterate over events deterministically
- Supports the same model configurations and tools as non-streaming calls

Note that stream events are only delivered to the workflow after the entire LLM response is complete, ensuring deterministic execution in Temporal workflows.
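
The example below is a minimal sketch of how this fits into a workflow, assuming a worker configured with this integration; the `StreamingAgentWorkflow` name and the agent definition are illustrative:

```python
from temporalio import workflow

with workflow.unsafe.imports_passed_through():
    from agents import Agent, Runner
    from openai.types.responses import ResponseTextDeltaEvent


@workflow.defn
class StreamingAgentWorkflow:
    @workflow.run
    async def run(self, prompt: str) -> str:
        agent = Agent(name="Assistant", instructions="You are a helpful assistant.")
        result = Runner.run_streamed(starting_agent=agent, input=prompt)
        chunks: list[str] = []
        async for event in result.stream_events():
            # Events are delivered from the completed activity result, so this
            # loop is deterministic across workflow replays.
            if event.type == "raw_response_event" and isinstance(
                event.data, ResponseTextDeltaEvent
            ):
                chunks.append(event.data.delta)
        return "".join(chunks)
```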

## Feature Support

This integration is presently subject to certain limitations.
Voice agents are not supported.
Certain tools are not suitable for a distributed computing environment, so these have been disabled as well.

### Model Providers
@@ -466,12 +491,10 @@ Certain tools are not suitable for a distributed computing environment, so these

### Model Response format

| Model Response | Supported |
| :------------- | :-------: |
| Get Response | Yes |
| Streaming | Yes |

### Tools

214 changes: 128 additions & 86 deletions temporalio/contrib/openai_agents/_invoke_model_activity.py
@@ -6,7 +6,7 @@
import enum
from dataclasses import dataclass
from datetime import timedelta
from typing import Any, NoReturn

from agents import (
    AgentOutputSchemaBase,
@@ -27,6 +27,7 @@
    UserError,
    WebSearchTool,
)
from agents.items import TResponseStreamEvent
from openai import (
    APIStatusError,
    AsyncOpenAI,
@@ -163,54 +164,8 @@ async def invoke_model_activity(self, input: ActivityModelInput) -> ModelResponse
"""Activity that invokes a model with the given input."""
model = self._model_provider.get_model(input.get("model_name"))

        tools = _make_tools(input)
        handoffs = _make_handoffs(input)

        try:
            return await model.get_response(
@@ -226,40 +181,127 @@ def make_tool(tool: ToolInput) -> Tool:
prompt=input.get("prompt"),
)
except APIStatusError as e:
            _handle_error(e)

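    # As with invoke_model_activity, _auto_heartbeater presumably keeps the
    # activity heartbeating during long model calls so heartbeat timeouts can
    # detect a hung call (inferred from the decorator's name and usage above).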
    @activity.defn
    @_auto_heartbeater
    async def batch_stream_model(
        self, input: ActivityModelInput
    ) -> list[TResponseStreamEvent]:
        """Activity that streams a model with the given input, returning all events as a list."""
        model = self._model_provider.get_model(input.get("model_name"))

        tools = _make_tools(input)
        handoffs = _make_handoffs(input)

        try:
            events = model.stream_response(
                system_instructions=input.get("system_instructions"),
                input=input["input"],
                model_settings=input["model_settings"],
                tools=tools,
                output_schema=input.get("output_schema"),
                handoffs=handoffs,
                tracing=ModelTracing(input["tracing"]),
                previous_response_id=input.get("previous_response_id"),
                conversation_id=input.get("conversation_id"),
                prompt=input.get("prompt"),
            )
            result = []
            async for event in events:
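                # model_rebuild() refreshes the event's pydantic model class; we
                # assume this is here so each event can round-trip through
                # Temporal's payload converter when the activity result is returned.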
                event.model_rebuild()
                result.append(event)

            return result
        except APIStatusError as e:
            _handle_error(e)


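# These stub callbacks exist only to satisfy the FunctionTool and Handoff
# constructors: the model activity needs the tool and handoff schemas for the
# LLM request, while (as we understand the design) actual invocation happens
# back in the workflow.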
async def _empty_on_invoke_tool(_ctx: RunContextWrapper[Any], _input: str) -> str:
    return ""


async def _empty_on_invoke_handoff(_ctx: RunContextWrapper[Any], _input: str) -> Any:
    return None


def _make_tool(tool: ToolInput) -> Tool:
    if isinstance(
        tool,
        (
            FileSearchTool,
            WebSearchTool,
            ImageGenerationTool,
            CodeInterpreterTool,
        ),
    ):
        return tool
    elif isinstance(tool, HostedMCPToolInput):
        return HostedMCPTool(
            tool_config=tool.tool_config,
        )
    elif isinstance(tool, FunctionToolInput):
        return FunctionTool(
            name=tool.name,
            description=tool.description,
            params_json_schema=tool.params_json_schema,
            on_invoke_tool=_empty_on_invoke_tool,
            strict_json_schema=tool.strict_json_schema,
        )
    else:
        raise UserError(f"Unknown tool type: {tool.name}")  # type: ignore[reportUnreachable]


def _make_tools(input: ActivityModelInput) -> list[Tool]:
    return [_make_tool(x) for x in input.get("tools", [])]


def _make_handoffs(input: ActivityModelInput) -> list[Handoff[Any, Any]]:
    return [
        Handoff(
            tool_name=x.tool_name,
            tool_description=x.tool_description,
            input_json_schema=x.input_json_schema,
            agent_name=x.agent_name,
            strict_json_schema=x.strict_json_schema,
            on_invoke_handoff=_empty_on_invoke_handoff,
        )
        for x in input.get("handoffs", [])
    ]


def _handle_error(e: APIStatusError) -> NoReturn:
    # Listen to server hints
    retry_after = None
    retry_after_ms_header = e.response.headers.get("retry-after-ms")
    if retry_after_ms_header is not None:
        retry_after = timedelta(milliseconds=float(retry_after_ms_header))

    if retry_after is None:
        retry_after_header = e.response.headers.get("retry-after")
        if retry_after_header is not None:
            retry_after = timedelta(seconds=float(retry_after_header))

    should_retry_header = e.response.headers.get("x-should-retry")
    if should_retry_header == "true":
        raise e
    if should_retry_header == "false":
        raise ApplicationError(
            "Non retryable OpenAI error",
            non_retryable=True,
            next_retry_delay=retry_after,
        ) from e

    # Specifically retryable status codes
    if e.response.status_code in [408, 409, 429] or e.response.status_code >= 500:
        raise ApplicationError(
            f"Retryable OpenAI status code: {e.response.status_code}",
            non_retryable=False,
            next_retry_delay=retry_after,
        ) from e

    raise ApplicationError(
        f"Non retryable OpenAI status code: {e.response.status_code}",
        non_retryable=True,
        next_retry_delay=retry_after,
    ) from e