Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 0 additions & 30 deletions src/openai/_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,21 +63,6 @@ def __stream__(self) -> Iterator[_T]:
# we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
if sse.event and sse.event.startswith("thread."):
data = sse.json()

if sse.event == "error" and is_mapping(data) and data.get("error"):
message = None
error = data.get("error")
if is_mapping(error):
message = error.get("message")
if not message or not isinstance(message, str):
message = "An error occurred during streaming"

raise APIError(
message=message,
request=self.response.request,
body=data["error"],
)

yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
else:
data = sse.json()
Expand Down Expand Up @@ -166,21 +151,6 @@ async def __stream__(self) -> AsyncIterator[_T]:
# we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
if sse.event and sse.event.startswith("thread."):
data = sse.json()

if sse.event == "error" and is_mapping(data) and data.get("error"):
message = None
error = data.get("error")
if is_mapping(error):
message = error.get("message")
if not message or not isinstance(message, str):
message = "An error occurred during streaming"

raise APIError(
message=message,
request=self.response.request,
body=data["error"],
)

yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
else:
data = sse.json()
Expand Down
53 changes: 49 additions & 4 deletions tests/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from openai import OpenAI, AsyncOpenAI
from openai._streaming import Stream, AsyncStream, ServerSentEvent
from openai._exceptions import APIError


@pytest.mark.asyncio
Expand Down Expand Up @@ -104,6 +105,37 @@ def body() -> Iterator[bytes]:
await assert_empty_iter(iterator)


@pytest.mark.asyncio
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
async def test_thread_events_keep_event_wrapper(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
def body() -> Iterator[bytes]:
yield b"event: thread.message.delta\n"
yield b'data: {"id":"msg_123","delta":{"content":[]}}\n'
yield b"\n"

stream = make_stream(content=body(), sync=sync, client=client, async_client=async_client)

assert await iter_next(stream) == {
"data": {"id": "msg_123", "delta": {"content": []}},
"event": "thread.message.delta",
}
await assert_empty_iter(stream)


@pytest.mark.asyncio
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
async def test_error_events_raise_api_error(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
def body() -> Iterator[bytes]:
yield b"event: error\n"
yield b'data: {"error":{"message":"boom"}}\n'
yield b"\n"

stream = make_stream(content=body(), sync=sync, client=client, async_client=async_client)

with pytest.raises(APIError, match="boom"):
await iter_next(stream)


@pytest.mark.asyncio
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
async def test_multiple_data_lines_with_empty_line(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
Expand Down Expand Up @@ -240,9 +272,22 @@ def make_event_iterator(
client: OpenAI,
async_client: AsyncOpenAI,
) -> Iterator[ServerSentEvent] | AsyncIterator[ServerSentEvent]:
response = httpx.Response(200, content=content if sync else to_aiter(content), request=httpx.Request("GET", "https://example.com"))
if sync:
return Stream(cast_to=object, client=client, response=response)._iter_events()

return AsyncStream(cast_to=object, client=async_client, response=response)._iter_events()


def make_stream(
content: Iterator[bytes],
*,
sync: bool,
client: OpenAI,
async_client: AsyncOpenAI,
) -> Iterator[object] | AsyncIterator[object]:
response = httpx.Response(200, content=content if sync else to_aiter(content), request=httpx.Request("GET", "https://example.com"))
if sync:
return Stream(cast_to=object, client=client, response=httpx.Response(200, content=content))._iter_events()
return iter(Stream(cast_to=object, client=client, response=response))

return AsyncStream(
cast_to=object, client=async_client, response=httpx.Response(200, content=to_aiter(content))
)._iter_events()
return AsyncStream(cast_to=object, client=async_client, response=response).__stream__()