Skip to content

Commit 2f771af

Browse files
committed
Stream response body in ASGITransport
Fixes #2186
1 parent 10b7295 commit 2f771af

File tree

2 files changed

+191
-8
lines changed

2 files changed

+191
-8
lines changed

httpx/_transports/asgi.py

Lines changed: 77 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import types
34
import typing
45

56
from .._models import Request, Response
@@ -52,12 +53,75 @@ def create_event() -> Event:
5253
return asyncio.Event()
5354

5455

56+
class _AwaitableRunner:
57+
def __init__(self, awaitable: typing.Awaitable[typing.Any]):
58+
self._generator = awaitable.__await__()
59+
self._started = False
60+
self._next_item: typing.Any = None
61+
self._finished = False
62+
63+
@types.coroutine
64+
def __call__(
65+
self, *, until: typing.Optional[typing.Callable[[], bool]] = None
66+
) -> typing.Generator[typing.Any, typing.Any, typing.Any]:
67+
while not self._finished and (until is None or not until()):
68+
send_value, throw_value = None, None
69+
if self._started:
70+
try:
71+
send_value = yield self._next_item
72+
except BaseException as e:
73+
throw_value = e
74+
75+
self._started = True
76+
try:
77+
if throw_value is not None:
78+
self._next_item = self._generator.throw(throw_value)
79+
else:
80+
self._next_item = self._generator.send(send_value)
81+
except StopIteration as e:
82+
self._finished = True
83+
return e.value
84+
except BaseException:
85+
self._generator.close()
86+
self._finished = True
87+
raise
88+
89+
5590
class ASGIResponseStream(AsyncByteStream):
56-
def __init__(self, body: list[bytes]) -> None:
91+
def __init__(
92+
self,
93+
body: list[bytes],
94+
raise_app_exceptions: bool,
95+
response_complete: "Event",
96+
app_runner: _AwaitableRunner,
97+
) -> None:
5798
self._body = body
99+
self._raise_app_exceptions = raise_app_exceptions
100+
self._response_complete = response_complete
101+
self._app_runner = app_runner
58102

59103
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
60-
yield b"".join(self._body)
104+
try:
105+
while bool(self._body) or not self._response_complete.is_set():
106+
if self._body:
107+
yield b"".join(self._body)
108+
self._body.clear()
109+
await self._app_runner(
110+
until=lambda: bool(self._body) or self._response_complete.is_set()
111+
)
112+
except Exception: # noqa: PIE786
113+
if self._raise_app_exceptions:
114+
raise
115+
finally:
116+
await self.aclose()
117+
118+
async def aclose(self) -> None:
119+
self._response_complete.set()
120+
try:
121+
await self._app_runner()
122+
except Exception: # noqa: PIE786
123+
if self._raise_app_exceptions:
124+
raise
61125

62126

63127
class ASGITransport(AsyncBaseTransport):
@@ -155,8 +219,10 @@ async def send(message: typing.MutableMapping[str, typing.Any]) -> None:
155219
response_headers = message.get("headers", [])
156220
response_started = True
157221

158-
elif message["type"] == "http.response.body":
159-
assert not response_complete.is_set()
222+
elif (
223+
message["type"] == "http.response.body"
224+
and not response_complete.is_set()
225+
):
160226
body = message.get("body", b"")
161227
more_body = message.get("more_body", False)
162228

@@ -166,9 +232,11 @@ async def send(message: typing.MutableMapping[str, typing.Any]) -> None:
166232
if not more_body:
167233
response_complete.set()
168234

235+
app_runner = _AwaitableRunner(self.app(scope, receive, send))
236+
169237
try:
170-
await self.app(scope, receive, send)
171-
except Exception: # noqa: PIE-786
238+
await app_runner(until=lambda: response_started)
239+
except Exception: # noqa: PIE786
172240
if self.raise_app_exceptions:
173241
raise
174242

@@ -178,10 +246,11 @@ async def send(message: typing.MutableMapping[str, typing.Any]) -> None:
178246
if response_headers is None:
179247
response_headers = {}
180248

181-
assert response_complete.is_set()
182249
assert status_code is not None
183250
assert response_headers is not None
184251

185-
stream = ASGIResponseStream(body_parts)
252+
stream = ASGIResponseStream(
253+
body_parts, self.raise_app_exceptions, response_complete, app_runner
254+
)
186255

187256
return Response(status_code, headers=response_headers, stream=stream)

tests/test_asgi.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22

3+
import anyio
34
import pytest
45

56
import httpx
@@ -60,13 +61,24 @@ async def raise_exc(scope, receive, send):
6061
raise RuntimeError()
6162

6263

64+
async def raise_exc_after_response_start(scope, receive, send):
65+
status = 200
66+
output = b"Hello, World!"
67+
headers = [(b"content-type", "text/plain"), (b"content-length", str(len(output)))]
68+
69+
await send({"type": "http.response.start", "status": status, "headers": headers})
70+
await anyio.sleep(0)
71+
raise RuntimeError()
72+
73+
6374
async def raise_exc_after_response(scope, receive, send):
6475
status = 200
6576
output = b"Hello, World!"
6677
headers = [(b"content-type", "text/plain"), (b"content-length", str(len(output)))]
6778

6879
await send({"type": "http.response.start", "status": status, "headers": headers})
6980
await send({"type": "http.response.body", "body": output})
81+
await anyio.sleep(0)
7082
raise RuntimeError()
7183

7284

@@ -172,6 +184,14 @@ async def test_asgi_exc():
172184
await client.get("http://www.example.org/")
173185

174186

187+
@pytest.mark.anyio
188+
async def test_asgi_exc_after_response_start():
189+
transport = httpx.ASGITransport(app=raise_exc_after_response_start)
190+
async with httpx.AsyncClient(transport=transport) as client:
191+
with pytest.raises(RuntimeError):
192+
await client.get("http://www.example.org/")
193+
194+
175195
@pytest.mark.anyio
176196
async def test_asgi_exc_after_response():
177197
transport = httpx.ASGITransport(app=raise_exc_after_response)
@@ -222,3 +242,97 @@ async def test_asgi_exc_no_raise():
222242
response = await client.get("http://www.example.org/")
223243

224244
assert response.status_code == 500
245+
246+
247+
@pytest.mark.anyio
248+
async def test_asgi_exc_no_raise_after_response_start():
249+
transport = httpx.ASGITransport(
250+
app=raise_exc_after_response_start, raise_app_exceptions=False
251+
)
252+
async with httpx.AsyncClient(transport=transport) as client:
253+
response = await client.get("http://www.example.org/")
254+
255+
assert response.status_code == 200
256+
257+
258+
@pytest.mark.anyio
259+
async def test_asgi_exc_no_raise_after_response():
260+
transport = httpx.ASGITransport(
261+
app=raise_exc_after_response, raise_app_exceptions=False
262+
)
263+
async with httpx.AsyncClient(transport=transport) as client:
264+
response = await client.get("http://www.example.org/")
265+
266+
assert response.status_code == 200
267+
268+
269+
@pytest.mark.anyio
270+
async def test_asgi_stream_returns_before_waiting_for_body():
271+
start_response_body = anyio.Event()
272+
273+
async def send_response_body_after_event(scope, receive, send):
274+
status = 200
275+
headers = [(b"content-type", b"text/plain")]
276+
await send(
277+
{"type": "http.response.start", "status": status, "headers": headers}
278+
)
279+
await start_response_body.wait()
280+
await send({"type": "http.response.body", "body": b"body", "more_body": False})
281+
282+
transport = httpx.ASGITransport(app=send_response_body_after_event)
283+
async with httpx.AsyncClient(transport=transport) as client:
284+
async with client.stream("GET", "http://www.example.org/") as response:
285+
assert response.status_code == 200
286+
start_response_body.set()
287+
await response.aread()
288+
assert response.text == "body"
289+
290+
291+
@pytest.mark.anyio
292+
async def test_asgi_stream_allows_iterative_streaming():
293+
stream_events = [anyio.Event() for i in range(4)]
294+
295+
async def send_response_body_after_event(scope, receive, send):
296+
status = 200
297+
headers = [(b"content-type", b"text/plain")]
298+
await send(
299+
{"type": "http.response.start", "status": status, "headers": headers}
300+
)
301+
for e in stream_events:
302+
await e.wait()
303+
await send(
304+
{
305+
"type": "http.response.body",
306+
"body": b"chunk",
307+
"more_body": e is not stream_events[-1],
308+
}
309+
)
310+
311+
transport = httpx.ASGITransport(app=send_response_body_after_event)
312+
async with httpx.AsyncClient(transport=transport) as client:
313+
async with client.stream("GET", "http://www.example.org/") as response:
314+
assert response.status_code == 200
315+
iterator = response.aiter_raw()
316+
for e in stream_events:
317+
e.set()
318+
assert await iterator.__anext__() == b"chunk"
319+
with pytest.raises(StopAsyncIteration):
320+
await iterator.__anext__()
321+
322+
323+
@pytest.mark.anyio
324+
async def test_asgi_can_be_canceled():
325+
# This test exists to cover transmission of the cancellation exception through
326+
# _AwaitableRunner
327+
app_started = anyio.Event()
328+
329+
async def never_return(scope, receive, send):
330+
app_started.set()
331+
await anyio.sleep_forever()
332+
333+
transport = httpx.ASGITransport(app=never_return)
334+
async with httpx.AsyncClient(transport=transport) as client:
335+
async with anyio.create_task_group() as task_group:
336+
task_group.start_soon(client.get, "http://www.example.org/")
337+
await app_started.wait()
338+
task_group.cancel_scope.cancel()

0 commit comments

Comments
 (0)