Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 57 additions & 4 deletions strands-py/src/strands/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,12 @@ class OpenAIConfig(BaseModelConfig, total=False):
params: Model parameters (e.g., max_tokens).
For a complete list of supported parameters, see
https://platform.openai.com/docs/api-reference/chat/create.
stream: Whether to use OpenAI chat completion streaming. Defaults to True.
"""

model_id: str
params: dict[str, Any] | None
stream: bool

def __init__(
self,
Expand Down Expand Up @@ -500,13 +502,16 @@ def format_request(
TypeError: If a message contains a content block type that cannot be converted to an OpenAI-compatible
format.
"""
return {
params = dict(cast(dict[str, Any], self.config.get("params") or {}))
stream = bool(self.config.get("stream", params.pop("stream", True)))
stream_options = params.pop("stream_options", {"include_usage": True})

request = {
"messages": self.format_request_messages(
messages, system_prompt, system_prompt_content=system_prompt_content
),
"model": self.config["model_id"],
"stream": True,
"stream_options": {"include_usage": True},
"stream": stream,
"tools": [
{
"type": "function",
Expand All @@ -519,9 +524,14 @@ def format_request(
for tool_spec in tool_specs or []
],
**(self._format_request_tool_choice(tool_choice)),
**cast(dict[str, Any], self.config.get("params", {})),
**params,
}

if stream:
request["stream_options"] = stream_options

return request

def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent:
"""Format an OpenAI response event into a standardized message chunk.

Expand Down Expand Up @@ -601,6 +611,44 @@ def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent:
case _:
raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type")

def _format_non_streaming_response(self, response: Any) -> list[StreamEvent]:
    """Replay a complete (non-streaming) chat completion as Strands stream events.

    Only the first entry of ``response.choices`` is inspected. Events are produced in a fixed order:
    message start, an optional reasoning block, an optional text block, one block per tool call,
    message stop and, when ``response.usage`` is truthy, a trailing metadata event.

    Args:
        response: Chat completion object. Every attribute is read with ``getattr`` and a ``None``
            default, so absent attributes are treated as "nothing to emit".

    Returns:
        The formatted events, in emission order.
    """
    events = [self.format_chunk({"chunk_type": "message_start"})]

    def emit_block(*raw_events: dict[str, Any]) -> None:
        # Format the raw events one by one, keeping their relative order.
        events.extend(self.format_chunk(raw) for raw in raw_events)

    available = getattr(response, "choices", None)
    first_choice = available[0] if available else None
    reply = getattr(first_choice, "message", None)

    # Prefer ``reasoning_content``; fall back to ``reasoning`` when the former is missing or empty.
    reasoning = getattr(reply, "reasoning_content", None)
    if not reasoning:
        reasoning = getattr(reply, "reasoning", None)
    if reasoning:
        emit_block(
            {"chunk_type": "content_start", "data_type": "reasoning_content"},
            {"chunk_type": "content_delta", "data_type": "reasoning_content", "data": reasoning},
            {"chunk_type": "content_stop", "data_type": "reasoning_content"},
        )

    text = getattr(reply, "content", None)
    if text:
        emit_block(
            {"chunk_type": "content_start", "data_type": "text"},
            {"chunk_type": "content_delta", "data_type": "text", "data": text},
            {"chunk_type": "content_stop", "data_type": "text"},
        )

    # The same tool call object backs both the start and the delta event of its block.
    for call in getattr(reply, "tool_calls", None) or []:
        emit_block(
            {"chunk_type": "content_start", "data_type": "tool", "data": call},
            {"chunk_type": "content_delta", "data_type": "tool", "data": call},
            {"chunk_type": "content_stop", "data_type": "tool"},
        )

    # A missing or empty finish reason is reported as "end_turn".
    stop_reason = getattr(first_choice, "finish_reason", None) or "end_turn"
    emit_block({"chunk_type": "message_stop", "data": stop_reason})

    usage = getattr(response, "usage", None)
    if usage:
        emit_block({"chunk_type": "metadata", "data": usage})

    return events

@asynccontextmanager
async def _get_client(self) -> AsyncIterator[Any]:
"""Get an OpenAI client for making requests.
Expand Down Expand Up @@ -687,6 +735,11 @@ async def stream(
# Re-raise other APIError exceptions
raise

if not request["stream"]:
for chunk in self._format_non_streaming_response(response):
yield chunk
return

logger.debug("got response from model")
yield self.format_chunk({"chunk_type": "message_start"})
tool_calls: dict[int, list[Any]] = {}
Expand Down
60 changes: 60 additions & 0 deletions strands-py/tests/strands/models/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,31 @@ def test_format_request(model, messages, tool_specs, system_prompt):
assert tru_request == exp_request


def test_format_request_can_disable_stream(openai_client, model_id, messages):
    """With ``stream=False`` in the model config, the request carries ``stream: False`` and no ``stream_options``."""
    _ = openai_client  # fixture is requested, but its value is not used directly here

    non_streaming_model = OpenAIModel(model_id=model_id, stream=False, params={"max_tokens": 1})
    actual = non_streaming_model.format_request(messages)

    expected = {
        "messages": [{"role": "user", "content": [{"text": "test", "type": "text"}]}],
        "model": model_id,
        "stream": False,
        "tools": [],
        "max_tokens": 1,
    }
    assert actual == expected


def test_format_request_respects_legacy_stream_param(openai_client, model_id, messages):
    """Passing ``stream=False`` through ``params`` disables streaming and omits ``stream_options``."""
    _ = openai_client  # fixture is requested, but its value is not used directly here

    legacy_params = {"max_tokens": 1, "stream": False}
    built_request = OpenAIModel(model_id=model_id, params=legacy_params).format_request(messages)

    assert built_request["stream"] is False
    assert "stream_options" not in built_request


def test_format_request_with_tool_choice_auto(model, messages, tool_specs, system_prompt):
tool_choice = {"auto": {}}
tru_request = model.format_request(messages, tool_specs, system_prompt, tool_choice)
Expand Down Expand Up @@ -1143,6 +1168,41 @@ async def test_stream_with_empty_choices(openai_client, model, agenerator, alist
openai_client.chat.completions.create.assert_called_once_with(**expected_request)


@pytest.mark.asyncio
async def test_stream_can_use_non_streaming_chat_completion(openai_client, model_id, messages, alist):
    """``stream()`` with ``stream=False`` turns a single completion object into the usual event sequence."""
    make_mock = unittest.mock.Mock

    # Build the fake completion from the inside out: usage, message, choice, response.
    usage = make_mock(prompt_tokens=7, completion_tokens=3, total_tokens=10, prompt_tokens_details=None)
    reply = make_mock(content="done", tool_calls=None, reasoning_content=None, reasoning=None)
    completion = make_mock(choices=[make_mock(message=reply, finish_reason="stop")], usage=usage)

    openai_client.chat.completions.create = unittest.mock.AsyncMock(return_value=completion)
    non_streaming_model = OpenAIModel(model_id=model_id, stream=False, params={"max_tokens": 1})

    actual_events = await alist(non_streaming_model.stream(messages))

    expected_events = [
        {"messageStart": {"role": "assistant"}},
        {"contentBlockStart": {"start": {}}},
        {"contentBlockDelta": {"delta": {"text": "done"}}},
        {"contentBlockStop": {}},
        {"messageStop": {"stopReason": "end_turn"}},
        {
            "metadata": {
                "usage": {"inputTokens": 7, "outputTokens": 3, "totalTokens": 10},
                "metrics": {"latencyMs": 0},
            }
        },
    ]
    assert actual_events == expected_events

    # The client must be called exactly once, with streaming off and no ``stream_options`` key.
    expected_request = {
        "max_tokens": 1,
        "model": model_id,
        "messages": [{"role": "user", "content": [{"text": "test", "type": "text"}]}],
        "stream": False,
        "tools": [],
    }
    openai_client.chat.completions.create.assert_called_once_with(**expected_request)


@pytest.mark.asyncio
async def test_structured_output(openai_client, model, test_output_model_cls, alist):
messages = [{"role": "user", "content": [{"text": "Generate a person"}]}]
Expand Down