diff --git a/strands-py/src/strands/models/openai.py b/strands-py/src/strands/models/openai.py index 09452ec77..46aed908f 100644 --- a/strands-py/src/strands/models/openai.py +++ b/strands-py/src/strands/models/openai.py @@ -64,10 +64,12 @@ class OpenAIConfig(BaseModelConfig, total=False): params: Model parameters (e.g., max_tokens). For a complete list of supported parameters, see https://platform.openai.com/docs/api-reference/chat/create. + stream: Whether to use OpenAI chat completion streaming. Defaults to True. """ model_id: str params: dict[str, Any] | None + stream: bool def __init__( self, @@ -500,13 +502,16 @@ def format_request( TypeError: If a message contains a content block type that cannot be converted to an OpenAI-compatible format. """ - return { + params = dict(cast(dict[str, Any], self.config.get("params") or {})) + stream = bool(self.config.get("stream", params.pop("stream", True))) + stream_options = params.pop("stream_options", {"include_usage": True}) + + request = { "messages": self.format_request_messages( messages, system_prompt, system_prompt_content=system_prompt_content ), "model": self.config["model_id"], - "stream": True, - "stream_options": {"include_usage": True}, + "stream": stream, "tools": [ { "type": "function", @@ -519,9 +524,14 @@ def format_request( for tool_spec in tool_specs or [] ], **(self._format_request_tool_choice(tool_choice)), - **cast(dict[str, Any], self.config.get("params", {})), + **params, } + if stream: + request["stream_options"] = stream_options + + return request + def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent: """Format an OpenAI response event into a standardized message chunk. @@ -601,6 +611,44 @@ def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent: case _: raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type") + def _format_non_streaming_response(self, response: Any) -> list[StreamEvent]: + """Convert a non-streaming OpenAI chat completion into Strands stream events.""" + chunks = [self.format_chunk({"chunk_type": "message_start"})] + choices = getattr(response, "choices", None) or [] + choice = choices[0] if choices else None + message = getattr(choice, "message", None) + + reasoning_content = getattr(message, "reasoning_content", None) or getattr(message, "reasoning", None) + if reasoning_content: + chunks.append(self.format_chunk({"chunk_type": "content_start", "data_type": "reasoning_content"})) + chunks.append( + self.format_chunk( + {"chunk_type": "content_delta", "data_type": "reasoning_content", "data": reasoning_content} + ) + ) + chunks.append(self.format_chunk({"chunk_type": "content_stop", "data_type": "reasoning_content"})) + + if content := getattr(message, "content", None): + chunks.append(self.format_chunk({"chunk_type": "content_start", "data_type": "text"})) + chunks.append(self.format_chunk({"chunk_type": "content_delta", "data_type": "text", "data": content})) + chunks.append(self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})) + + for tool_call in getattr(message, "tool_calls", None) or []: + chunks.append(self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_call})) + chunks.append(self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_call})) + chunks.append(self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})) + + chunks.append( + self.format_chunk( + {"chunk_type": "message_stop", "data": getattr(choice, "finish_reason", None) or "end_turn"} + ) + ) + + if usage := getattr(response, "usage", None): + chunks.append(self.format_chunk({"chunk_type": "metadata", "data": usage})) + + return chunks + @asynccontextmanager async def _get_client(self) -> AsyncIterator[Any]: """Get an OpenAI client for making requests. @@ -687,6 +735,11 @@ async def stream( # Re-raise other APIError exceptions raise + if not request["stream"]: + for chunk in self._format_non_streaming_response(response): + yield chunk + return + logger.debug("got response from model") yield self.format_chunk({"chunk_type": "message_start"}) tool_calls: dict[int, list[Any]] = {} diff --git a/strands-py/tests/strands/models/test_openai.py b/strands-py/tests/strands/models/test_openai.py index 3746ec334..7fe7d2c5b 100644 --- a/strands-py/tests/strands/models/test_openai.py +++ b/strands-py/tests/strands/models/test_openai.py @@ -661,6 +661,31 @@ def test_format_request(model, messages, tool_specs, system_prompt): assert tru_request == exp_request +def test_format_request_can_disable_stream(openai_client, model_id, messages): + _ = openai_client + model = OpenAIModel(model_id=model_id, stream=False, params={"max_tokens": 1}) + + tru_request = model.format_request(messages) + exp_request = { + "messages": [{"role": "user", "content": [{"text": "test", "type": "text"}]}], + "model": model_id, + "stream": False, + "tools": [], + "max_tokens": 1, + } + assert tru_request == exp_request + + +def test_format_request_respects_legacy_stream_param(openai_client, model_id, messages): + _ = openai_client + model = OpenAIModel(model_id=model_id, params={"max_tokens": 1, "stream": False}) + + tru_request = model.format_request(messages) + + assert tru_request["stream"] is False + assert "stream_options" not in tru_request + + def test_format_request_with_tool_choice_auto(model, messages, tool_specs, system_prompt): tool_choice = {"auto": {}} tru_request = model.format_request(messages, tool_specs, system_prompt, tool_choice) @@ -1143,6 +1168,41 @@ async def test_stream_with_empty_choices(openai_client, model, agenerator, alist openai_client.chat.completions.create.assert_called_once_with(**expected_request) +@pytest.mark.asyncio +async def test_stream_can_use_non_streaming_chat_completion(openai_client, model_id, messages, alist): + mock_usage = unittest.mock.Mock(prompt_tokens=7, completion_tokens=3, total_tokens=10, prompt_tokens_details=None) + mock_message = unittest.mock.Mock(content="done", tool_calls=None, reasoning_content=None, reasoning=None) + mock_choice = unittest.mock.Mock(message=mock_message, finish_reason="stop") + mock_response = unittest.mock.Mock(choices=[mock_choice], usage=mock_usage) + + openai_client.chat.completions.create = unittest.mock.AsyncMock(return_value=mock_response) + model = OpenAIModel(model_id=model_id, stream=False, params={"max_tokens": 1}) + + tru_events = await alist(model.stream(messages)) + exp_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"text": "done"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + { + "metadata": { + "usage": {"inputTokens": 7, "outputTokens": 3, "totalTokens": 10}, + "metrics": {"latencyMs": 0}, + } + }, + ] + + assert tru_events == exp_events + openai_client.chat.completions.create.assert_called_once_with( + max_tokens=1, + model=model_id, + messages=[{"role": "user", "content": [{"text": "test", "type": "text"}]}], + stream=False, + tools=[], + ) + + @pytest.mark.asyncio async def test_structured_output(openai_client, model, test_output_model_cls, alist): messages = [{"role": "user", "content": [{"text": "Generate a person"}]}]