Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion httpcore/_async/http_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,11 @@ async def handle_async_request(self, request: Request) -> Response:
"timeout": timeout,
}
async with Trace("start_tls", logger, request, kwargs) as trace:
stream = await stream.start_tls(**kwargs)
try:
stream = await stream.start_tls(**kwargs)
except Exception:
await self._connection.aclose()
raise
trace.return_value = stream

# Determine if we should be using HTTP/1.1 or HTTP/2
Expand Down
6 changes: 5 additions & 1 deletion httpcore/_sync/http_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,11 @@ def handle_request(self, request: Request) -> Response:
"timeout": timeout,
}
with Trace("start_tls", logger, request, kwargs) as trace:
stream = stream.start_tls(**kwargs)
try:
stream = stream.start_tls(**kwargs)
except Exception:
self._connection.close()
raise
trace.return_value = stream

# Determine if we should be using HTTP/1.1 or HTTP/2
Expand Down
63 changes: 63 additions & 0 deletions tests/_async/test_http_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,69 @@ async def test_proxy_tunneling_with_auth():
assert response.content == b"Hello, world!"


@pytest.mark.anyio
async def test_proxy_tunneling_tls_failure_cleans_up():
"""
When start_tls raises after a successful CONNECT through a proxy,
the tunnel connection is closed and removed from the pool, so
subsequent requests do not hit PoolTimeout.
"""

class FailingTLSStream(AsyncMockStream):
async def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: typing.Optional[str] = None,
timeout: typing.Optional[float] = None,
) -> AsyncNetworkStream:
raise OSError("TLS handshake failed")

class FailOnceBackend(AsyncMockBackend):
def __init__(self, buffer: list[bytes]) -> None:
super().__init__(buffer)
self._should_fail = True

async def connect_tcp(
self,
host: str,
port: int,
timeout: typing.Optional[float] = None,
local_address: typing.Optional[str] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
) -> AsyncNetworkStream:
if self._should_fail:
self._should_fail = False
return FailingTLSStream(list(self._buffer))
return AsyncMockStream(list(self._buffer))

buffer = [
b"HTTP/1.1 200 OK\r\n\r\n",
b"HTTP/1.1 200 OK\r\n",
b"Content-Type: plain/text\r\n",
b"Content-Length: 13\r\n",
b"\r\n",
b"Hello, world!",
]
backend = FailOnceBackend(buffer)

async with AsyncConnectionPool(
proxy=Proxy("http://localhost:8080/"),
max_connections=1,
network_backend=backend,
) as proxy:
# First request: CONNECT succeeds, start_tls raises OSError.
with pytest.raises(OSError, match="TLS handshake failed"):
await proxy.request("GET", "https://example.com/")

# The poisoned tunnel connection must have been cleaned up.
assert not proxy.connections

# Second request: pool is clean, so it succeeds normally.
response = await proxy.request("GET", "https://example.com/")
assert response.status == 200
assert response.content == b"Hello, world!"


def test_proxy_headers():
proxy = Proxy(
url="http://localhost:8080/",
Expand Down
62 changes: 62 additions & 0 deletions tests/_sync/test_http_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,3 +276,65 @@ def test_proxy_headers():
assert proxy.headers == [
(b"Proxy-Authorization", b"Basic dXNlcm5hbWU6cGFzc3dvcmQ=")
]


def test_proxy_tunneling_tls_failure_cleans_up():
"""
When start_tls raises after a successful CONNECT through a proxy,
the tunnel connection is closed and removed from the pool, so
subsequent requests do not hit PoolTimeout.
"""

class FailingTLSStream(MockStream):
def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: typing.Optional[str] = None,
timeout: typing.Optional[float] = None,
) -> NetworkStream:
raise OSError("TLS handshake failed")

class FailOnceBackend(MockBackend):
def __init__(self, buffer: list[bytes]) -> None:
super().__init__(buffer)
self._should_fail = True

def connect_tcp(
self,
host: str,
port: int,
timeout: typing.Optional[float] = None,
local_address: typing.Optional[str] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
) -> NetworkStream:
if self._should_fail:
self._should_fail = False
return FailingTLSStream(list(self._buffer))
return MockStream(list(self._buffer))

buffer = [
b"HTTP/1.1 200 OK\r\n\r\n",
b"HTTP/1.1 200 OK\r\n",
b"Content-Type: plain/text\r\n",
b"Content-Length: 13\r\n",
b"\r\n",
b"Hello, world!",
]
backend = FailOnceBackend(buffer)

with ConnectionPool(
proxy=Proxy("http://localhost:8080/"),
max_connections=1,
network_backend=backend,
) as proxy:
# First request: CONNECT succeeds, start_tls raises OSError.
with pytest.raises(OSError, match="TLS handshake failed"):
proxy.request("GET", "https://example.com/")

# The poisoned tunnel connection must have been cleaned up.
assert not proxy.connections

# Second request: pool is clean, so it succeeds normally.
response = proxy.request("GET", "https://example.com/")
assert response.status == 200
assert response.content == b"Hello, world!"
Loading