diff --git a/CHANGES/10847.feature.rst b/CHANGES/10847.feature.rst new file mode 100644 index 00000000000..bfa7f6d498a --- /dev/null +++ b/CHANGES/10847.feature.rst @@ -0,0 +1,5 @@ +Implemented shared DNS resolver management to fix excessive resolver object creation +when using multiple client sessions. The new ``_DNSResolverManager`` singleton ensures +only one ``DNSResolver`` object is created for default configurations, significantly +reducing resource usage and improving performance for applications using multiple +client sessions simultaneously -- by :user:`bdraco`. diff --git a/CHANGES/10902.feature.rst b/CHANGES/10902.feature.rst new file mode 120000 index 00000000000..b565aa68ee0 --- /dev/null +++ b/CHANGES/10902.feature.rst @@ -0,0 +1 @@ +9732.feature.rst \ No newline at end of file diff --git a/CHANGES/9212.breaking.rst b/CHANGES/9212.breaking.rst new file mode 120000 index 00000000000..b6deef3c9bf --- /dev/null +++ b/CHANGES/9212.breaking.rst @@ -0,0 +1 @@ +9212.packaging.rst \ No newline at end of file diff --git a/CHANGES/9212.packaging.rst b/CHANGES/9212.packaging.rst new file mode 100644 index 00000000000..a5cf325fc23 --- /dev/null +++ b/CHANGES/9212.packaging.rst @@ -0,0 +1 @@ +Removed remaining `make_mocked_coro` in the test suite -- by :user:`polkapolka`. diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index e54e97e8ce2..1e5f0da2684 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -295,6 +295,7 @@ Pavol Vargovčík Pawel Kowalski Pawel Miech Pepe Osca +Phebe Polk Philipp A. Pierre-Louis Peeters Pieter van Beek diff --git a/aiohttp/client.py b/aiohttp/client.py index 30a29c36c01..d615ec181c5 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -26,6 +26,7 @@ List, Mapping, Optional, + Sequence, Set, Tuple, Type, @@ -194,7 +195,7 @@ class _RequestOptions(TypedDict, total=False): auto_decompress: Union[bool, None] max_line_size: Union[int, None] max_field_size: Union[int, None] - middlewares: Optional[Tuple[ClientMiddlewareType, ...]] + middlewares: Optional[Sequence[ClientMiddlewareType]] @frozen_dataclass_decorator @@ -295,7 +296,7 @@ def __init__( max_line_size: int = 8190, max_field_size: int = 8190, fallback_charset_resolver: _CharsetResolver = lambda r, b: "utf-8", - middlewares: Optional[Tuple[ClientMiddlewareType, ...]] = None, + middlewares: Optional[Sequence[ClientMiddlewareType]] = None, ) -> None: # We initialise _connector to None immediately, as it's referenced in __del__() # and could cause issues if an exception occurs during initialisation. @@ -455,7 +456,7 @@ async def _request( auto_decompress: Optional[bool] = None, max_line_size: Optional[int] = None, max_field_size: Optional[int] = None, - middlewares: Optional[Tuple[ClientMiddlewareType, ...]] = None, + middlewares: Optional[Sequence[ClientMiddlewareType]] = None, ) -> ClientResponse: # NOTE: timeout clamps existing connect and read timeouts. We cannot # set the default to None because we need to detect if the user wants @@ -648,7 +649,6 @@ async def _request( trust_env=self.trust_env, ) - # Core request handler - now includes connection logic async def _connect_and_send_request( req: ClientRequest, ) -> ClientResponse: diff --git a/aiohttp/client_middlewares.py b/aiohttp/client_middlewares.py index 6be353c3a40..3ca2cb202ad 100644 --- a/aiohttp/client_middlewares.py +++ b/aiohttp/client_middlewares.py @@ -1,6 +1,6 @@ """Client middleware support.""" -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable, Callable, Sequence from .client_reqrep import ClientRequest, ClientResponse @@ -17,7 +17,7 @@ def build_client_middlewares( handler: ClientHandlerType, - middlewares: tuple[ClientMiddlewareType, ...], + middlewares: Sequence[ClientMiddlewareType], ) -> ClientHandlerType: """ Apply middlewares to request handler. @@ -28,9 +28,6 @@ def build_client_middlewares( This implementation avoids using partial/update_wrapper to minimize overhead and doesn't cache to avoid holding references to stateful middleware. """ - if not middlewares: - return handler - # Optimize for single middleware case if len(middlewares) == 1: middleware = middlewares[0] diff --git a/aiohttp/resolver.py b/aiohttp/resolver.py index d99970f3952..0a646b0c189 100644 --- a/aiohttp/resolver.py +++ b/aiohttp/resolver.py @@ -1,6 +1,7 @@ import asyncio import socket -from typing import Any, List, Tuple, Type, Union +import weakref +from typing import Any, List, Optional, Tuple, Type, Union from .abc import AbstractResolver, ResolveResult @@ -88,7 +89,17 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: if aiodns is None: raise RuntimeError("Resolver requires aiodns library") - self._resolver = aiodns.DNSResolver(*args, **kwargs) + self._loop = asyncio.get_running_loop() + self._manager: Optional[_DNSResolverManager] = None + # If custom args are provided, create a dedicated resolver instance + # This means each AsyncResolver with custom args gets its own + # aiodns.DNSResolver instance + if args or kwargs: + self._resolver = aiodns.DNSResolver(*args, **kwargs) + return + # Use the shared resolver from the manager for default arguments + self._manager = _DNSResolverManager() + self._resolver = self._manager.get_resolver(self, self._loop) async def resolve( self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET @@ -142,7 +153,78 @@ async def resolve( return hosts async def close(self) -> None: + if self._manager: + # Release the resolver from the manager if using the shared resolver + self._manager.release_resolver(self, self._loop) + self._manager = None # Clear reference to manager + self._resolver = None # type: ignore[assignment] # Clear reference to resolver + return + # Otherwise cancel our dedicated resolver self._resolver.cancel() + self._resolver = None # type: ignore[assignment] # Clear reference + + +class _DNSResolverManager: + """Manager for aiodns.DNSResolver objects. + + This class manages shared aiodns.DNSResolver instances + with no custom arguments across different event loops. + """ + + _instance: Optional["_DNSResolverManager"] = None + + def __new__(cls) -> "_DNSResolverManager": + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._init() + return cls._instance + + def _init(self) -> None: + # Use WeakKeyDictionary to allow event loops to be garbage collected + self._loop_data: weakref.WeakKeyDictionary[ + asyncio.AbstractEventLoop, + tuple["aiodns.DNSResolver", weakref.WeakSet["AsyncResolver"]], + ] = weakref.WeakKeyDictionary() + + def get_resolver( + self, client: "AsyncResolver", loop: asyncio.AbstractEventLoop + ) -> "aiodns.DNSResolver": + """Get or create the shared aiodns.DNSResolver instance for a specific event loop. + + Args: + client: The AsyncResolver instance requesting the resolver. + This is required to track resolver usage. + loop: The event loop to use for the resolver. + """ + # Create a new resolver and client set for this loop if it doesn't exist + if loop not in self._loop_data: + resolver = aiodns.DNSResolver(loop=loop) + client_set: weakref.WeakSet["AsyncResolver"] = weakref.WeakSet() + self._loop_data[loop] = (resolver, client_set) + else: + # Get the existing resolver and client set + resolver, client_set = self._loop_data[loop] + + # Register this client with the loop + client_set.add(client) + return resolver + + def release_resolver( + self, client: "AsyncResolver", loop: asyncio.AbstractEventLoop + ) -> None: + """Release the resolver for an AsyncResolver client when it's closed. + + Args: + client: The AsyncResolver instance to release. + loop: The event loop the resolver was using. + """ + # Remove client from its loop's tracking + resolver, client_set = self._loop_data[loop] + client_set.discard(client) + # If no more clients for this loop, cancel and remove its resolver + if not client_set: + resolver.cancel() + del self._loop_data[loop] _DefaultType = Type[Union[AsyncResolver, ThreadedResolver]] diff --git a/aiohttp/test_utils.py b/aiohttp/test_utils.py index 569baf7c32d..28b3e21df0a 100644 --- a/aiohttp/test_utils.py +++ b/aiohttp/test_utils.py @@ -3,7 +3,6 @@ import asyncio import contextlib import gc -import inspect import ipaddress import os import socket @@ -42,7 +41,6 @@ from .abc import AbstractCookieJar, AbstractStreamWriter from .client_reqrep import ClientResponse from .client_ws import ClientWebSocketResponse -from .helpers import sentinel from .http import HttpVersion, RawRequestMessage from .streams import EMPTY_PAYLOAD, StreamReader from .typedefs import LooseHeaders, StrOrURL @@ -682,10 +680,10 @@ def make_mocked_request( if writer is None: writer = mock.Mock() - writer.write_headers = make_mocked_coro(None) - writer.write = make_mocked_coro(None) - writer.write_eof = make_mocked_coro(None) - writer.drain = make_mocked_coro(None) + writer.write_headers = mock.AsyncMock(return_value=None) + writer.write = mock.AsyncMock(return_value=None) + writer.write_eof = mock.AsyncMock(return_value=None) + writer.drain = mock.AsyncMock(return_value=None) writer.transport = transport protocol.transport = transport @@ -701,18 +699,3 @@ def make_mocked_request( req._match_info = match_info return req - - -def make_mocked_coro( - return_value: Any = sentinel, raise_exception: Any = sentinel -) -> Any: - """Creates a coroutine mock.""" - - async def mock_coro(*args: Any, **kwargs: Any) -> Any: - if raise_exception is not sentinel: - raise raise_exception - if not inspect.isawaitable(return_value): - return return_value - await return_value - - return mock.Mock(wraps=mock_coro) diff --git a/docs/client_advanced.rst b/docs/client_advanced.rst index 823ac45a8c7..10105249a6a 100644 --- a/docs/client_advanced.rst +++ b/docs/client_advanced.rst @@ -130,32 +130,18 @@ background. Client Middleware ----------------- -aiohttp client supports middleware to intercept requests and responses. This can be +The client supports middleware to intercept requests and responses. This can be useful for authentication, logging, request/response modification, and retries. -To create a middleware, you need to define an async function that accepts the request -and a handler function, and returns the response. The middleware must match the -:type:`ClientMiddlewareType` type signature:: - - import logging - from aiohttp import ClientSession, ClientRequest, ClientResponse, ClientHandlerType - - _LOGGER = logging.getLogger(__name__) - - async def my_middleware( - request: ClientRequest, - handler: ClientHandlerType - ) -> ClientResponse: - # Process request before sending - _LOGGER.debug(f"Request: {request.method} {request.url}") - - # Call the next handler - response = await handler(request) +Creating Middleware +^^^^^^^^^^^^^^^^^^^ - # Process response after receiving - _LOGGER.debug(f"Response: {response.status}") +To create a middleware, define an async function (or callable class) that accepts a request +and a handler function, and returns a response. Middleware must follow the +:type:`ClientMiddlewareType` signature (see :ref:`aiohttp-client-reference` for details). - return response +Using Middleware +^^^^^^^^^^^^^^^^ You can apply middleware to a client session or to individual requests:: @@ -167,175 +153,189 @@ You can apply middleware to a client session or to individual requests:: async with ClientSession() as session: resp = await session.get('http://example.com', middlewares=(my_middleware,)) -Middleware Examples +Middleware Chaining ^^^^^^^^^^^^^^^^^^^ -Here's a simple example showing request modification:: +Multiple middlewares are applied in the order they are listed:: - async def add_api_key_middleware( - request: ClientRequest, - handler: ClientHandlerType - ) -> ClientResponse: - # Add API key to all requests - request.headers['X-API-Key'] = 'my-secret-key' - return await handler(request) + # Middlewares are applied in order: logging -> auth -> request + async with ClientSession(middlewares=(logging_middleware, auth_middleware)) as session: + resp = await session.get('http://example.com') + +A key aspect to understand about the flat middleware structure is that the execution flow follows this pattern: + +1. The first middleware in the list is called first and executes its code before calling the handler +2. The handler is the next middleware in the chain (or the actual request handler if there are no more middleware) +3. When the handler returns a response, execution continues in the first middleware after the handler call +4. This creates a nested "onion-like" pattern for execution + +For example, with ``middlewares=(middleware1, middleware2)``, the execution order would be: + +1. Enter ``middleware1`` (pre-request code) +2. Enter ``middleware2`` (pre-request code) +3. Execute the actual request handler +4. Exit ``middleware2`` (post-response code) +5. Exit ``middleware1`` (post-response code) + +This flat structure means that middleware is applied on each retry attempt inside the client's retry loop, not just once before all retries. This allows middleware to modify requests freshly on each retry attempt. + +.. note:: + + Client middleware is a powerful feature but should be used judiciously. + Each middleware adds overhead to request processing. For simple use cases + like adding static headers, you can often use request parameters + (e.g., ``headers``) or session configuration instead. + +Common Middleware Patterns +^^^^^^^^^^^^^^^^^^^^^^^^^^ .. _client-middleware-retry: -Middleware Retry Pattern -^^^^^^^^^^^^^^^^^^^^^^^^ +Authentication and Retry +"""""""""""""""""""""""" -Client middleware can implement retry logic internally using a ``while`` loop. This allows the middleware to: +There are two recommended approaches for implementing retry logic: -- Retry requests based on response status codes or other conditions -- Modify the request between retries (e.g., refreshing tokens) -- Maintain state across retry attempts -- Control when to stop retrying and return the response +1. **For Loop Pattern (Simple Cases)** -This pattern is particularly useful for: + Use a bounded ``for`` loop when the number of retry attempts is known and fixed:: -- Refreshing authentication tokens after a 401 response -- Switching to fallback servers or authentication methods -- Adding or modifying headers based on error responses -- Implementing back-off strategies with increasing delays + import hashlib + from aiohttp import ClientSession, ClientRequest, ClientResponse, ClientHandlerType -The middleware can maintain state between retries to track which strategies have been tried and modify the request accordingly for the next attempt. + async def auth_retry_middleware( + request: ClientRequest, + handler: ClientHandlerType + ) -> ClientResponse: + # Try up to 3 authentication methods + for attempt in range(3): + if attempt == 0: + # First attempt: use API key + request.headers["X-API-Key"] = "my-api-key" + elif attempt == 1: + # Second attempt: use Bearer token + request.headers["Authorization"] = "Bearer fallback-token" + else: + # Third attempt: use hash-based signature + secret_key = "my-secret-key" + url_path = str(request.url.path) + signature = hashlib.sha256(f"{url_path}{secret_key}".encode()).hexdigest() + request.headers["X-Signature"] = signature -Example: Retrying requests with middleware -"""""""""""""""""""""""""""""""""""""""""" + # Send the request + response = await handler(request) -:: + # If successful or not an auth error, return immediately + if response.status != 401: + return response - import logging - import aiohttp + # Return the last response if all retries are exhausted + return response - _LOGGER = logging.getLogger(__name__) +2. **While Loop Pattern (Complex Cases)** - class RetryMiddleware: - def __init__(self, max_retries: int = 3): - self.max_retries = max_retries + For more complex scenarios, use a ``while`` loop with strict exit conditions:: - async def __call__( - self, - request: ClientRequest, - handler: ClientHandlerType - ) -> ClientResponse: - retry_count = 0 - use_fallback_auth = False + import logging - while True: - # Modify request based on retry state - if use_fallback_auth: - request.headers['Authorization'] = 'Bearer fallback-token' + _LOGGER = logging.getLogger(__name__) - response = await handler(request) + class RetryMiddleware: + def __init__(self, max_retries: int = 3): + self.max_retries = max_retries - # Retry on 401 errors with different authentication - if response.status == 401 and retry_count < self.max_retries: - retry_count += 1 - use_fallback_auth = True - _LOGGER.debug(f"Retrying with fallback auth (attempt {retry_count})") - continue + async def __call__( + self, + request: ClientRequest, + handler: ClientHandlerType + ) -> ClientResponse: + retry_count = 0 - # Retry on 5xx errors - if response.status >= 500 and retry_count < self.max_retries: - retry_count += 1 - _LOGGER.debug(f"Retrying request (attempt {retry_count})") - continue + # Always have clear exit conditions + while retry_count <= self.max_retries: + # Send the request + response = await handler(request) - return response + # Exit conditions + if 200 <= response.status < 400 or retry_count >= self.max_retries: + return response -Middleware Chaining -^^^^^^^^^^^^^^^^^^^ + # Retry logic for different status codes + if response.status in (401, 429, 500, 502, 503, 504): + retry_count += 1 + _LOGGER.debug(f"Retrying request (attempt {retry_count}/{self.max_retries})") + continue -Multiple middlewares are applied in the order they are listed:: + # For any other status code, don't retry + return response - import logging + # Safety return (should never reach here) + return response - _LOGGER = logging.getLogger(__name__) +Request Modification +"""""""""""""""""""" - async def logging_middleware( - request: ClientRequest, - handler: ClientHandlerType - ) -> ClientResponse: - _LOGGER.debug(f"[LOG] {request.method} {request.url}") - return await handler(request) +Modify request properties based on request content:: - async def auth_middleware( + async def content_type_middleware( request: ClientRequest, handler: ClientHandlerType ) -> ClientResponse: - request.headers['Authorization'] = 'Bearer token123' - return await handler(request) + # Examine URL path to determine content-type + if request.url.path.endswith('.json'): + request.headers['Content-Type'] = 'application/json' + elif request.url.path.endswith('.xml'): + request.headers['Content-Type'] = 'application/xml' - # Middlewares are applied in order: logging -> auth -> request - async with ClientSession(middlewares=(logging_middleware, auth_middleware)) as session: - resp = await session.get('http://example.com') + # Add custom headers based on HTTP method + if request.method == 'POST': + request.headers['X-Request-ID'] = f"post-{id(request)}" -.. note:: + return await handler(request) - Client middleware is a powerful feature but should be used judiciously. - Each middleware adds overhead to request processing. For simple use cases - like adding static headers, you can often use request parameters - (e.g., ``headers``) or session configuration instead. +Avoiding Infinite Recursion +^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. warning:: Using the same session from within middleware can cause infinite recursion if the middleware makes HTTP requests using the same session that has the middleware - applied. - - To avoid recursion, use one of these approaches: - - **Recommended:** Pass ``middlewares=()`` to requests made inside the middleware to - disable middleware for those specific requests:: - - async def log_middleware( - request: ClientRequest, - handler: ClientHandlerType - ) -> ClientResponse: - async with request.session.post( - "https://logapi.example/log", - json={"url": str(request.url)}, - middlewares=() # This prevents infinite recursion - ) as resp: - pass + applied. This is especially risky in token refresh middleware or retry logic. - return await handler(request) + When implementing retry or refresh logic, always use bounded loops + (e.g., ``for _ in range(2):`` instead of ``while True:``) to prevent infinite recursion. - **Alternative:** Check the request contents (URL, path, host) to avoid applying - middleware to certain requests:: +To avoid recursion when making requests inside middleware, use one of these approaches: - async def log_middleware( - request: ClientRequest, - handler: ClientHandlerType - ) -> ClientResponse: - if request.url.host != "logapi.example": # Avoid infinite recursion - async with request.session.post( - "https://logapi.example/log", - json={"url": str(request.url)} - ) as resp: - pass +**Option 1:** Disable middleware for internal requests:: - return await handler(request) - -Middleware Type -^^^^^^^^^^^^^^^ - -.. type:: ClientMiddlewareType - - Type alias for client middleware functions. Middleware functions must have this signature:: + async def log_middleware( + request: ClientRequest, + handler: ClientHandlerType + ) -> ClientResponse: + async with request.session.post( + "https://logapi.example/log", + json={"url": str(request.url)}, + middlewares=() # This prevents infinite recursion + ) as resp: + pass - Callable[ - [ClientRequest, ClientHandlerType], - Awaitable[ClientResponse] - ] + return await handler(request) -.. type:: ClientHandlerType +**Option 2:** Check request details to avoid recursive application:: - Type alias for client request handler functions:: + async def log_middleware( + request: ClientRequest, + handler: ClientHandlerType + ) -> ClientResponse: + if request.url.host != "logapi.example": # Avoid infinite recursion + async with request.session.post( + "https://logapi.example/log", + json={"url": str(request.url)} + ) as resp: + pass - Callable[ClientRequest, Awaitable[ClientResponse]] + return await handler(request) Custom Cookies -------------- diff --git a/docs/client_reference.rst b/docs/client_reference.rst index 030e07d9ef4..8101f25a872 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -214,7 +214,7 @@ The client session supports the context manager protocol for self closing. disabling. See :ref:`aiohttp-client-tracing-reference` for more information. - :param middlewares: A tuple of middleware instances to apply to all session requests. + :param middlewares: A sequence of middleware instances to apply to all session requests. Each middleware must match the :type:`ClientMiddlewareType` signature. ``None`` (default) is used when no middleware is needed. See :ref:`aiohttp-client-middleware` for more information. @@ -528,7 +528,7 @@ The client session supports the context manager protocol for self closing. .. versionadded:: 3.0 - :param middlewares: A tuple of middleware instances to apply to this request only. + :param middlewares: A sequence of middleware instances to apply to this request only. Each middleware must match the :type:`ClientMiddlewareType` signature. ``None`` by default which uses session middlewares. See :ref:`aiohttp-client-middleware` for more information. @@ -2606,3 +2606,22 @@ Hierarchy of exceptions * :exc:`InvalidUrlRedirectClientError` * :exc:`NonHttpUrlRedirectClientError` + + +Client Types +------------ + +.. type:: ClientMiddlewareType + + Type alias for client middleware functions. Middleware functions must have this signature:: + + Callable[ + [ClientRequest, ClientHandlerType], + Awaitable[ClientResponse] + ] + +.. type:: ClientHandlerType + + Type alias for client request handler functions:: + + Callable[[ClientRequest], Awaitable[ClientResponse]] diff --git a/docs/testing.rst b/docs/testing.rst index 1d29f335460..1e3a12e2302 100644 --- a/docs/testing.rst +++ b/docs/testing.rst @@ -793,25 +793,6 @@ Test Client Utilities ~~~~~~~~~ -.. function:: make_mocked_coro(return_value) - - Creates a coroutine mock. - - Behaves like a coroutine which returns *return_value*. But it is - also a mock object, you might test it as usual - :class:`~unittest.mock.Mock`:: - - mocked = make_mocked_coro(1) - assert 1 == await mocked(1, 2) - mocked.assert_called_with(1, 2) - - - :param return_value: A value that the mock object will return when - called. - :returns: A mock object that behaves as a coroutine which returns - *return_value* when called. - - .. function:: unused_port() Return an unused port number for IPv4 TCP protocol. diff --git a/tests/test_client_middleware.py b/tests/test_client_middleware.py index 2f79e4fd774..5894795dc21 100644 --- a/tests/test_client_middleware.py +++ b/tests/test_client_middleware.py @@ -74,13 +74,12 @@ async def handler(request: web.Request) -> web.Response: async def retry_middleware( request: ClientRequest, handler: ClientHandlerType ) -> ClientResponse: - retry_count = 0 - while True: + response = None + for _ in range(2): # pragma: no branch response = await handler(request) - if response.status == 503 and retry_count < 1: - retry_count += 1 - continue - return response + if response.ok: + return response + assert False, "not reachable in test" app = web.Application() app.router.add_get("/", handler) @@ -244,30 +243,28 @@ async def handler(request: web.Request) -> web.Response: async def challenge_auth_middleware( request: ClientRequest, handler: ClientHandlerType ) -> ClientResponse: - challenge_data: Dict[str, Union[bool, str, None]] = { - "nonce": None, - "attempted": False, - } + nonce: Optional[str] = None + attempted: bool = False while True: # If we have challenge data from previous attempt, add auth header - if challenge_data["nonce"] and challenge_data["attempted"]: - request.headers["Authorization"] = ( - f'Custom response="{challenge_data["nonce"]}-secret"' - ) + if nonce and attempted: + request.headers["Authorization"] = f'Custom response="{nonce}-secret"' response = await handler(request) # If we get a 401 with challenge, store it and retry - if response.status == 401 and not challenge_data["attempted"]: + if response.status == 401 and not attempted: www_auth = response.headers.get("WWW-Authenticate") - if www_auth and "nonce=" in www_auth: # pragma: no branch + if www_auth and "nonce=" in www_auth: # Extract nonce from authentication header nonce_start = www_auth.find('nonce="') + 7 nonce_end = www_auth.find('"', nonce_start) - challenge_data["nonce"] = www_auth[nonce_start:nonce_end] - challenge_data["attempted"] = True + nonce = www_auth[nonce_start:nonce_end] + attempted = True continue + else: + assert False, "Should not reach here" return response @@ -324,7 +321,7 @@ async def multi_step_auth_middleware( ) -> ClientResponse: request.headers["X-Client-ID"] = "test-client" - while True: + for _ in range(3): # Apply auth based on current state if middleware_state["step"] == 1 and middleware_state["session"]: request.headers["Authorization"] = ( @@ -347,13 +344,17 @@ async def multi_step_auth_middleware( middleware_state["step"] = 1 continue - elif auth_step == "2": # pragma: no branch + elif auth_step == "2": # Second step: store challenge middleware_state["challenge"] = response.headers.get("X-Challenge") middleware_state["step"] = 2 continue + else: + assert False, "Should not reach here" return response + # This should not be reached but keeps mypy happy + assert False, "Should not reach here" app = web.Application() app.router.add_get("/", handler) @@ -396,7 +397,7 @@ async def handler(request: web.Request) -> web.Response: async def token_refresh_middleware( request: ClientRequest, handler: ClientHandlerType ) -> ClientResponse: - while True: + for _ in range(2): # Add token to request request.headers["X-Auth-Token"] = str(token_state["token"]) @@ -407,13 +408,17 @@ async def token_refresh_middleware( data = await response.json() if data.get("error") == "token_expired" and data.get( "refresh_required" - ): # pragma: no branch + ): # Simulate token refresh token_state["token"] = "refreshed-token" token_state["refreshed"] = True continue + else: + assert False, "Should not reach here" return response + # This should not be reached but keeps mypy happy + assert False, "Should not reach here" app = web.Application() app.router.add_get("/", handler) @@ -490,7 +495,6 @@ class RetryMiddleware: def __init__(self, max_retries: int = 3) -> None: self.max_retries = max_retries - self.retry_counts: Dict[int, int] = {} # Track retries per request async def __call__( self, request: ClientRequest, handler: ClientHandlerType @@ -576,10 +580,55 @@ async def handler(request: web.Request) -> web.Response: assert headers_received.get("X-Custom-2") == "value2" -async def test_client_middleware_disable_with_empty_tuple( +async def test_request_middleware_overrides_session_middleware_with_empty( aiohttp_server: AiohttpServer, ) -> None: - """Test that passing middlewares=() to a request disables session-level middlewares.""" + """Test that passing empty middlewares tuple to a request disables session-level middlewares.""" + session_middleware_called = False + + async def handler(request: web.Request) -> web.Response: + auth_header = request.headers.get("Authorization") + if auth_header: + return web.Response(text=f"Auth: {auth_header}") + return web.Response(text="No auth") + + async def session_middleware( + request: ClientRequest, handler: ClientHandlerType + ) -> ClientResponse: + nonlocal session_middleware_called + session_middleware_called = True + request.headers["Authorization"] = "Bearer session-token" + response = await handler(request) + return response + + app = web.Application() + app.router.add_get("/", handler) + server = await aiohttp_server(app) + + # Create session with middleware + async with ClientSession(middlewares=(session_middleware,)) as session: + # First request uses session middleware + async with session.get(server.make_url("/")) as resp: + assert resp.status == 200 + text = await resp.text() + assert text == "Auth: Bearer session-token" + assert session_middleware_called is True + + # Reset flags + session_middleware_called = False + + # Second request explicitly disables middlewares with empty tuple + async with session.get(server.make_url("/"), middlewares=()) as resp: + assert resp.status == 200 + text = await resp.text() + assert text == "No auth" + assert session_middleware_called is False + + +async def test_request_middleware_overrides_session_middleware_with_specific( + aiohttp_server: AiohttpServer, +) -> None: + """Test that passing specific middlewares to a request overrides session-level middlewares.""" session_middleware_called = False request_middleware_called = False @@ -625,19 +674,7 @@ async def request_middleware( session_middleware_called = False request_middleware_called = False - # Second request explicitly disables middlewares - async with session.get(server.make_url("/"), middlewares=()) as resp: - assert resp.status == 200 - text = await resp.text() - assert text == "No auth" - assert session_middleware_called is False - assert request_middleware_called is False - - # Reset flags - session_middleware_called = False - request_middleware_called = False - - # Third request uses request-specific middleware + # Second request uses request-specific middleware async with session.get( server.make_url("/"), middlewares=(request_middleware,) ) as resp: @@ -745,9 +782,13 @@ async def blocking_middleware( # Verify that connections were attempted in the correct order assert len(connection_attempts) == 3 - assert allowed_url.host and allowed_url.host in connection_attempts[0] - assert "blocked.example.com" in connection_attempts[1] - assert "evil.com" in connection_attempts[2] + assert allowed_url.host + + assert connection_attempts == [ + str(server.make_url("/")), + "https://blocked.example.com/", + "https://evil.com/path", + ] # Check that no connections were leaked assert len(connector._conns) == 0 @@ -1042,8 +1083,7 @@ def get_hash(self, request: ClientRequest) -> str: data = "{}" # Simulate authentication hash without using real crypto - signature = f"SIGNATURE-{self.secretkey}-{len(data)}-{data[:10]}" - return signature + return f"SIGNATURE-{self.secretkey}-{len(data)}-{data[:10]}" async def __call__( self, request: ClientRequest, handler: ClientHandlerType diff --git a/tests/test_client_request.py b/tests/test_client_request.py index e1e8e3d9992..8458c376b78 100644 --- a/tests/test_client_request.py +++ b/tests/test_client_request.py @@ -34,7 +34,6 @@ from aiohttp.compression_utils import ZLibBackend from aiohttp.connector import Connection from aiohttp.http import HttpVersion10, HttpVersion11 -from aiohttp.test_utils import make_mocked_coro from aiohttp.typedefs import LooseCookies @@ -831,7 +830,7 @@ async def test_content_encoding( "post", URL("http://python.org/"), data="foo", compress="deflate", loop=loop ) with mock.patch("aiohttp.client_reqrep.StreamWriter") as m_writer: - m_writer.return_value.write_headers = make_mocked_coro() + m_writer.return_value.write_headers = mock.AsyncMock() resp = await req.send(conn) assert req.headers["TRANSFER-ENCODING"] == "chunked" assert req.headers["CONTENT-ENCODING"] == "deflate" @@ -867,7 +866,7 @@ async def test_content_encoding_header( loop=loop, ) with mock.patch("aiohttp.client_reqrep.StreamWriter") as m_writer: - m_writer.return_value.write_headers = make_mocked_coro() + m_writer.return_value.write_headers = mock.AsyncMock() resp = await req.send(conn) assert not m_writer.return_value.enable_compression.called @@ -940,7 +939,7 @@ async def test_chunked_explicit( ) -> None: req = ClientRequest("post", URL("http://python.org/"), chunked=True, loop=loop) with mock.patch("aiohttp.client_reqrep.StreamWriter") as m_writer: - m_writer.return_value.write_headers = make_mocked_coro() + m_writer.return_value.write_headers = mock.AsyncMock() resp = await req.send(conn) assert "chunked" == req.headers["TRANSFER-ENCODING"] diff --git a/tests/test_client_response.py b/tests/test_client_response.py index 75d3ee0d4b3..b3811973630 100644 --- a/tests/test_client_response.py +++ b/tests/test_client_response.py @@ -16,7 +16,6 @@ from aiohttp.client_reqrep import ClientResponse, RequestInfo from aiohttp.connector import Connection from aiohttp.helpers import TimerNoop -from aiohttp.test_utils import make_mocked_coro class WriterMock(mock.AsyncMock): @@ -1142,7 +1141,7 @@ async def test_response_read_triggers_callback( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: trace = mock.Mock() - trace.send_response_chunk_received = make_mocked_coro() + trace.send_response_chunk_received = mock.AsyncMock() response_method = "get" response_url = URL("http://def-cl-resp.org") response_body = b"This is response" diff --git a/tests/test_client_session.py b/tests/test_client_session.py index 974d330a3c9..1fc05b04a4e 100644 --- a/tests/test_client_session.py +++ b/tests/test_client_session.py @@ -33,7 +33,6 @@ from aiohttp.cookiejar import CookieJar from aiohttp.http import RawResponseMessage from aiohttp.pytest_plugin import AiohttpClient, AiohttpServer -from aiohttp.test_utils import make_mocked_coro from aiohttp.tracing import Trace @@ -791,9 +790,9 @@ async def handler(request: web.Request) -> web.Response: trace_config_ctx = mock.Mock() body = "This is request body" gathered_req_headers: CIMultiDict[str] = CIMultiDict() - on_request_start = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) - on_request_redirect = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) - on_request_end = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) + on_request_start = mock.AsyncMock() + on_request_redirect = mock.AsyncMock() + on_request_end = mock.AsyncMock() with io.BytesIO() as gathered_req_body, io.BytesIO() as gathered_res_body: @@ -872,7 +871,7 @@ async def redirect_handler(request: web.Request) -> NoReturn: app.router.add_get("/", root_handler) app.router.add_get("/redirect", redirect_handler) - mocks = [mock.Mock(side_effect=make_mocked_coro(mock.Mock())) for _ in range(7)] + mocks = [mock.AsyncMock() for _ in range(7)] ( on_request_start, on_request_redirect, @@ -963,8 +962,8 @@ def to_url(path: str) -> URL: async def test_request_tracing_exception() -> None: - on_request_end = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) - on_request_exception = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) + on_request_end = mock.AsyncMock() + on_request_exception = mock.AsyncMock() trace_config = aiohttp.TraceConfig() trace_config.on_request_end.append(on_request_end) diff --git a/tests/test_client_ws.py b/tests/test_client_ws.py index 3b1bd00b093..500564896c8 100644 --- a/tests/test_client_ws.py +++ b/tests/test_client_ws.py @@ -18,7 +18,6 @@ from aiohttp.http import WS_KEY from aiohttp.http_websocket import WSMessageClose from aiohttp.streams import EofStream -from aiohttp.test_utils import make_mocked_coro async def test_ws_connect( @@ -383,7 +382,7 @@ async def test_close( m_req.return_value.set_result(mresp) writer = mock.Mock() WebSocketWriter.return_value = writer - writer.close = make_mocked_coro() + writer.close = mock.AsyncMock() session = aiohttp.ClientSession() resp = await session.ws_connect("http://test.org") @@ -492,7 +491,7 @@ async def test_close_exc( m_req.return_value.set_result(mresp) writer = mock.Mock() WebSocketWriter.return_value = writer - writer.close = make_mocked_coro() + writer.close = mock.AsyncMock() session = aiohttp.ClientSession() resp = await session.ws_connect("http://test.org") @@ -628,7 +627,7 @@ async def test_reader_read_exception( writer = mock.Mock() WebSocketWriter.return_value = writer - writer.close = make_mocked_coro() + writer.close = mock.AsyncMock() session = aiohttp.ClientSession() resp = await session.ws_connect("http://test.org") @@ -780,7 +779,7 @@ async def test_ws_connect_deflate_per_message( m_req.return_value = loop.create_future() m_req.return_value.set_result(mresp) writer = WebSocketWriter.return_value = mock.Mock() - send_frame = writer.send_frame = make_mocked_coro() + send_frame = writer.send_frame = mock.AsyncMock() session = aiohttp.ClientSession() resp = await session.ws_connect("http://test.org") diff --git a/tests/test_connector.py b/tests/test_connector.py index 1315dbcd485..5a342ef9641 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -49,7 +49,7 @@ _DNSCacheTable, ) from aiohttp.pytest_plugin import AiohttpClient, AiohttpServer -from aiohttp.test_utils import make_mocked_coro, unused_port +from aiohttp.test_utils import unused_port from aiohttp.tracing import Trace @@ -290,7 +290,6 @@ async def test_create_conn() -> None: conn = aiohttp.BaseConnector() with pytest.raises(NotImplementedError): await conn._create_connection(object(), [], object()) # type: ignore[arg-type] - await conn.close() @@ -318,80 +317,90 @@ async def test_close(key: ConnectionKey) -> None: async def test_get(loop: asyncio.AbstractEventLoop, key: ConnectionKey) -> None: conn = aiohttp.BaseConnector() - assert await conn._get(key, []) is None + try: + assert await conn._get(key, []) is None - proto = create_mocked_conn(loop) - conn._conns[key] = deque([(proto, loop.time())]) - connection = await conn._get(key, []) - assert connection is not None - assert connection.protocol == proto - connection.close() - await conn.close() + proto = create_mocked_conn(loop) + conn._conns[key] = deque([(proto, loop.time())]) + connection = await conn._get(key, []) + assert connection is not None + assert connection.protocol == proto + connection.close() + finally: + await conn.close() async def test_get_unconnected_proto(loop: asyncio.AbstractEventLoop) -> None: conn = aiohttp.BaseConnector() key = ConnectionKey("localhost", 80, False, False, None, None, None) - assert await conn._get(key, []) is None + try: + assert await conn._get(key, []) is None - proto = create_mocked_conn(loop) - conn._conns[key] = deque([(proto, loop.time())]) - connection = await conn._get(key, []) - assert connection is not None - assert connection.protocol == proto - connection.close() + proto = create_mocked_conn(loop) + conn._conns[key] = deque([(proto, loop.time())]) + connection = await conn._get(key, []) + assert connection is not None + assert connection.protocol == proto + connection.close() - assert await conn._get(key, []) is None - conn._conns[key] = deque([(proto, loop.time())]) - proto.is_connected = lambda *args: False - assert await conn._get(key, []) is None - await conn.close() + assert await conn._get(key, []) is None + conn._conns[key] = deque([(proto, loop.time())]) + proto.is_connected = lambda *args: False + assert await conn._get(key, []) is None + finally: + await conn.close() async def test_get_unconnected_proto_ssl(loop: asyncio.AbstractEventLoop) -> None: conn = aiohttp.BaseConnector() key = ConnectionKey("localhost", 80, True, False, None, None, None) - assert await conn._get(key, []) is None + try: + assert await conn._get(key, []) is None - proto = create_mocked_conn(loop) - conn._conns[key] = deque([(proto, loop.time())]) - connection = await conn._get(key, []) - assert connection is not None - assert connection.protocol == proto - connection.close() + proto = create_mocked_conn(loop) + conn._conns[key] = deque([(proto, loop.time())]) + connection = await conn._get(key, []) + assert connection is not None + assert connection.protocol == proto + connection.close() - assert await conn._get(key, []) is None - conn._conns[key] = deque([(proto, loop.time())]) - proto.is_connected = lambda *args: False - assert await conn._get(key, []) is None - await conn.close() + assert await conn._get(key, []) is None + conn._conns[key] = deque([(proto, loop.time())]) + proto.is_connected = lambda *args: False + assert await conn._get(key, []) is None + finally: + await conn.close() async def test_get_expired(loop: asyncio.AbstractEventLoop) -> None: conn = aiohttp.BaseConnector() key = ConnectionKey("localhost", 80, False, False, None, None, None) - assert await conn._get(key, []) is None + try: + assert await conn._get(key, []) is None - proto = create_mocked_conn(loop) - conn._conns[key] = deque([(proto, loop.time() - 1000)]) - assert await conn._get(key, []) is None - assert not conn._conns - await conn.close() + proto = create_mocked_conn(loop) + conn._conns[key] = deque([(proto, loop.time() - 1000)]) + assert await conn._get(key, []) is None + assert not conn._conns + finally: + await conn.close() @pytest.mark.usefixtures("enable_cleanup_closed") async def test_get_expired_ssl(loop: asyncio.AbstractEventLoop) -> None: conn = aiohttp.BaseConnector(enable_cleanup_closed=True) key = ConnectionKey("localhost", 80, True, False, None, None, None) - assert await conn._get(key, []) is None + try: + assert await conn._get(key, []) is None - proto = create_mocked_conn(loop) - transport = proto.transport - conn._conns[key] = deque([(proto, loop.time() - 1000)]) - assert await conn._get(key, []) is None - assert not conn._conns - assert conn._cleanup_closed_transports == [transport] - await conn.close() + proto = create_mocked_conn(loop) + transport = proto.transport + conn._conns[key] = deque([(proto, loop.time() - 1000)]) + assert await conn._get(key, []) is None + assert not conn._conns + assert conn._cleanup_closed_transports == [transport] + finally: + await conn.close() async def test_release_acquired(key: ConnectionKey) -> None: @@ -1414,10 +1423,10 @@ async def test_tcp_connector_dns_tracing( ) -> None: session = mock.Mock() trace_config_ctx = mock.Mock() - on_dns_resolvehost_start = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) - on_dns_resolvehost_end = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) - on_dns_cache_hit = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) - on_dns_cache_miss = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) + on_dns_resolvehost_start = mock.AsyncMock() + on_dns_resolvehost_end = mock.AsyncMock() + on_dns_cache_hit = mock.AsyncMock() + on_dns_cache_miss = mock.AsyncMock() trace_config = aiohttp.TraceConfig( trace_config_ctx_factory=mock.Mock(return_value=trace_config_ctx) @@ -1461,8 +1470,8 @@ async def test_tcp_connector_dns_tracing_cache_disabled( ) -> None: session = mock.Mock() trace_config_ctx = mock.Mock() - on_dns_resolvehost_start = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) - on_dns_resolvehost_end = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) + on_dns_resolvehost_start = mock.AsyncMock() + on_dns_resolvehost_end = mock.AsyncMock() trace_config = aiohttp.TraceConfig( trace_config_ctx_factory=mock.Mock(return_value=trace_config_ctx) @@ -1518,8 +1527,8 @@ async def test_tcp_connector_dns_tracing_throttle_requests( ) -> None: session = mock.Mock() trace_config_ctx = mock.Mock() - on_dns_cache_hit = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) - on_dns_cache_miss = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) + on_dns_cache_hit = mock.AsyncMock() + on_dns_cache_miss = mock.AsyncMock() trace_config = aiohttp.TraceConfig( trace_config_ctx_factory=mock.Mock(return_value=trace_config_ctx) @@ -1662,8 +1671,8 @@ async def test_connect(loop: asyncio.AbstractEventLoop, key: ConnectionKey) -> N async def test_connect_tracing(loop: asyncio.AbstractEventLoop) -> None: session = mock.Mock() trace_config_ctx = mock.Mock() - on_connection_create_start = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) - on_connection_create_end = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) + on_connection_create_start = mock.AsyncMock() + on_connection_create_end = mock.AsyncMock() trace_config = aiohttp.TraceConfig( trace_config_ctx_factory=mock.Mock(return_value=trace_config_ctx) @@ -2668,8 +2677,8 @@ async def test_connect_queued_operation_tracing( ) -> None: session = mock.Mock() trace_config_ctx = mock.Mock() - on_connection_queued_start = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) - on_connection_queued_end = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) + on_connection_queued_start = mock.AsyncMock() + on_connection_queued_end = mock.AsyncMock() trace_config = aiohttp.TraceConfig( trace_config_ctx_factory=mock.Mock(return_value=trace_config_ctx) @@ -2715,7 +2724,7 @@ async def test_connect_reuseconn_tracing( ) -> None: session = mock.Mock() trace_config_ctx = mock.Mock() - on_connection_reuseconn = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) + on_connection_reuseconn = mock.AsyncMock() trace_config = aiohttp.TraceConfig( trace_config_ctx_factory=mock.Mock(return_value=trace_config_ctx) @@ -3219,7 +3228,7 @@ async def test_unix_connector_not_found(loop: asyncio.AbstractEventLoop) -> None @pytest.mark.skipif(not hasattr(socket, "AF_UNIX"), reason="requires UNIX sockets") async def test_unix_connector_permission(loop: asyncio.AbstractEventLoop) -> None: - m = make_mocked_coro(raise_exception=PermissionError()) + m = mock.AsyncMock(side_effect=PermissionError()) with mock.patch.object(loop, "create_unix_connection", m): connector = aiohttp.UnixConnector("/" + uuid.uuid4().hex) @@ -3258,7 +3267,7 @@ async def test_named_pipe_connector_not_found( async def test_named_pipe_connector_permission( proactor_loop: asyncio.AbstractEventLoop, pipe_name: str ) -> None: - m = make_mocked_coro(raise_exception=PermissionError()) + m = mock.AsyncMock(side_effect=PermissionError()) with mock.patch.object(proactor_loop, "create_pipe_connection", m): asyncio.set_event_loop(proactor_loop) connector = aiohttp.NamedPipeConnector(pipe_name) diff --git a/tests/test_http_writer.py b/tests/test_http_writer.py index 4e0ca4b13ea..3add5a0a073 100644 --- a/tests/test_http_writer.py +++ b/tests/test_http_writer.py @@ -12,7 +12,6 @@ from aiohttp.base_protocol import BaseProtocol from aiohttp.compression_utils import ZLibBackend from aiohttp.http_writer import _serialize_headers -from aiohttp.test_utils import make_mocked_coro @pytest.fixture @@ -787,7 +786,7 @@ async def test_write_calls_callback( transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: - on_chunk_sent = make_mocked_coro() + on_chunk_sent = mock.AsyncMock() msg = http.StreamWriter(protocol, loop, on_chunk_sent=on_chunk_sent) chunk = b"1" await msg.write(chunk) @@ -800,7 +799,7 @@ async def test_write_eof_calls_callback( transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: - on_chunk_sent = make_mocked_coro() + on_chunk_sent = mock.AsyncMock() msg = http.StreamWriter(protocol, loop, on_chunk_sent=on_chunk_sent) chunk = b"1" await msg.write_eof(chunk=chunk) diff --git a/tests/test_proxy.py b/tests/test_proxy.py index a4baabb4047..906e9128995 100644 --- a/tests/test_proxy.py +++ b/tests/test_proxy.py @@ -12,7 +12,6 @@ from aiohttp.client_reqrep import ClientRequest, ClientResponse, Fingerprint from aiohttp.connector import _SSL_CONTEXT_VERIFIED from aiohttp.helpers import TimerNoop -from aiohttp.test_utils import make_mocked_coro class TestProxy(unittest.TestCase): @@ -21,7 +20,9 @@ class TestProxy(unittest.TestCase): } mocked_response = mock.Mock(**response_mock_attrs) clientrequest_mock_attrs = { - "return_value.send.return_value.start": make_mocked_coro(mocked_response), + "return_value.send.return_value.start": mock.AsyncMock( + return_value=mocked_response + ), } def setUp(self) -> None: diff --git a/tests/test_proxy_functional.py b/tests/test_proxy_functional.py index 404b7ed1ae4..247b2daae13 100644 --- a/tests/test_proxy_functional.py +++ b/tests/test_proxy_functional.py @@ -255,16 +255,22 @@ async def test_uvloop_secure_https_proxy( uvloop_loop: asyncio.AbstractEventLoop, ) -> None: """Ensure HTTPS sites are accessible through a secure proxy without warning when using uvloop.""" - conn = aiohttp.TCPConnector() + conn = aiohttp.TCPConnector(force_close=True) sess = aiohttp.ClientSession(connector=conn) - url = URL("https://example.com") - - async with sess.get(url, proxy=secure_proxy_url, ssl=client_ssl_ctx) as response: - assert response.status == 200 - - await sess.close() - await conn.close() - await asyncio.sleep(0.1) + try: + url = URL("https://example.com") + + async with sess.get( + url, proxy=secure_proxy_url, ssl=client_ssl_ctx + ) as response: + assert response.status == 200 + # Ensure response body is read to completion + await response.read() + finally: + await sess.close() + await conn.close() + await asyncio.sleep(0) + await asyncio.sleep(0.1) @pytest.fixture diff --git a/tests/test_resolver.py b/tests/test_resolver.py index b2c8645e835..1bc779c1ecf 100644 --- a/tests/test_resolver.py +++ b/tests/test_resolver.py @@ -1,6 +1,8 @@ import asyncio +import gc import ipaddress import socket +from collections.abc import Generator from ipaddress import ip_address from typing import ( Any, @@ -22,6 +24,7 @@ AsyncResolver, DefaultResolver, ThreadedResolver, + _DNSResolverManager, ) try: @@ -45,6 +48,48 @@ ] +@pytest.fixture() +def check_no_lingering_resolvers() -> Generator[None, None, None]: + """Verify no resolvers remain after the test. + + This fixture should be used in any test that creates instances of + AsyncResolver or directly uses _DNSResolverManager. + """ + manager = _DNSResolverManager() + before = len(manager._loop_data) + yield + after = len(manager._loop_data) + if after > before: # pragma: no branch + # Force garbage collection to ensure weak references are updated + gc.collect() # pragma: no cover + after = len(manager._loop_data) # pragma: no cover + if after > before: # pragma: no cover + pytest.fail( # pragma: no cover + f"Lingering resolvers found: {(after - before)} " + "new AsyncResolver instances were not properly closed." + ) + + +@pytest.fixture() +def dns_resolver_manager() -> Generator[_DNSResolverManager, None, None]: + """Create a fresh _DNSResolverManager instance for testing. + + Saves and restores the singleton state to avoid affecting other tests. + """ + # Save the original instance + original_instance = _DNSResolverManager._instance + + # Reset the singleton + _DNSResolverManager._instance = None + + # Create and yield a fresh instance + try: + yield _DNSResolverManager() + finally: + # Clean up and restore the original instance + _DNSResolverManager._instance = original_instance + + class FakeAIODNSAddrInfoNode(NamedTuple): family: int @@ -139,6 +184,7 @@ async def fake(*args: Any, **kwargs: Any) -> Tuple[str, int]: @pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") +@pytest.mark.usefixtures("check_no_lingering_resolvers") async def test_async_resolver_positive_ipv4_lookup( loop: asyncio.AbstractEventLoop, ) -> None: @@ -156,9 +202,11 @@ async def test_async_resolver_positive_ipv4_lookup( port=0, type=socket.SOCK_STREAM, ) + await resolver.close() @pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") +@pytest.mark.usefixtures("check_no_lingering_resolvers") async def test_async_resolver_positive_link_local_ipv6_lookup( loop: asyncio.AbstractEventLoop, ) -> None: @@ -180,9 +228,11 @@ async def test_async_resolver_positive_link_local_ipv6_lookup( type=socket.SOCK_STREAM, ) mock().getnameinfo.assert_called_with(("fe80::1", 0, 0, 3), _NAME_SOCKET_FLAGS) + await resolver.close() @pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") +@pytest.mark.usefixtures("check_no_lingering_resolvers") async def test_async_resolver_multiple_replies(loop: asyncio.AbstractEventLoop) -> None: with patch("aiodns.DNSResolver") as mock: ips = ["127.0.0.1", "127.0.0.2", "127.0.0.3", "127.0.0.4"] @@ -191,18 +241,22 @@ async def test_async_resolver_multiple_replies(loop: asyncio.AbstractEventLoop) real = await resolver.resolve("www.google.com") ipaddrs = [ipaddress.ip_address(x["host"]) for x in real] assert len(ipaddrs) > 3, "Expecting multiple addresses" + await resolver.close() @pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") +@pytest.mark.usefixtures("check_no_lingering_resolvers") async def test_async_resolver_negative_lookup(loop: asyncio.AbstractEventLoop) -> None: with patch("aiodns.DNSResolver") as mock: mock().getaddrinfo.side_effect = aiodns.error.DNSError() resolver = AsyncResolver() with pytest.raises(OSError): await resolver.resolve("doesnotexist.bla") + await resolver.close() @pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") +@pytest.mark.usefixtures("check_no_lingering_resolvers") async def test_async_resolver_no_hosts_in_getaddrinfo( loop: asyncio.AbstractEventLoop, ) -> None: @@ -211,6 +265,7 @@ async def test_async_resolver_no_hosts_in_getaddrinfo( resolver = AsyncResolver() with pytest.raises(OSError): await resolver.resolve("doesnotexist.bla") + await resolver.close() async def test_threaded_resolver_positive_lookup() -> None: @@ -314,6 +369,7 @@ async def test_close_for_threaded_resolver(loop: asyncio.AbstractEventLoop) -> N @pytest.mark.skipif(aiodns is None, reason="aiodns required") +@pytest.mark.usefixtures("check_no_lingering_resolvers") async def test_close_for_async_resolver(loop: asyncio.AbstractEventLoop) -> None: resolver = AsyncResolver() await resolver.close() @@ -328,6 +384,7 @@ async def test_default_loop_for_threaded_resolver( @pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") +@pytest.mark.usefixtures("check_no_lingering_resolvers") async def test_async_resolver_ipv6_positive_lookup( loop: asyncio.AbstractEventLoop, ) -> None: @@ -343,9 +400,11 @@ async def test_async_resolver_ipv6_positive_lookup( port=0, type=socket.SOCK_STREAM, ) + await resolver.close() @pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") +@pytest.mark.usefixtures("check_no_lingering_resolvers") async def test_async_resolver_error_messages_passed( loop: asyncio.AbstractEventLoop, ) -> None: @@ -357,9 +416,11 @@ async def test_async_resolver_error_messages_passed( await resolver.resolve("x.org") assert excinfo.value.strerror == "Test error message" + await resolver.close() @pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") +@pytest.mark.usefixtures("check_no_lingering_resolvers") async def test_async_resolver_error_messages_passed_no_hosts( loop: asyncio.AbstractEventLoop, ) -> None: @@ -371,8 +432,10 @@ async def test_async_resolver_error_messages_passed_no_hosts( await resolver.resolve("x.org") assert excinfo.value.strerror == "DNS lookup failed" + await resolver.close() +@pytest.mark.usefixtures("check_no_lingering_resolvers") async def test_async_resolver_aiodns_not_present( loop: asyncio.AbstractEventLoop, monkeypatch: pytest.MonkeyPatch ) -> None: @@ -382,6 +445,7 @@ async def test_async_resolver_aiodns_not_present( @pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") +@pytest.mark.usefixtures("check_no_lingering_resolvers") def test_aio_dns_is_default() -> None: assert DefaultResolver is AsyncResolver @@ -389,3 +453,164 @@ def test_aio_dns_is_default() -> None: @pytest.mark.skipif(getaddrinfo, reason="aiodns <3.2.0 required") def test_threaded_resolver_is_default() -> None: assert DefaultResolver is ThreadedResolver + + +@pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") +async def test_dns_resolver_manager_sharing( + dns_resolver_manager: _DNSResolverManager, +) -> None: + """Test that the DNSResolverManager shares a resolver among AsyncResolver instances.""" + # Create two default AsyncResolver instances + resolver1 = AsyncResolver() + resolver2 = AsyncResolver() + + # Check that they share the same underlying resolver + assert resolver1._resolver is resolver2._resolver + + # Create an AsyncResolver with custom args + resolver3 = AsyncResolver(nameservers=["8.8.8.8"]) + + # Check that it has its own resolver + assert resolver1._resolver is not resolver3._resolver + + # Cleanup + await resolver1.close() + await resolver2.close() + await resolver3.close() + + +@pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") +async def test_dns_resolver_manager_singleton( + dns_resolver_manager: _DNSResolverManager, +) -> None: + """Test that DNSResolverManager is a singleton.""" + # Create a second manager and check it's the same instance + manager1 = dns_resolver_manager + manager2 = _DNSResolverManager() + + assert manager1 is manager2 + + +@pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") +async def test_dns_resolver_manager_resolver_lifecycle( + dns_resolver_manager: _DNSResolverManager, +) -> None: + """Test that DNSResolverManager creates and destroys resolver correctly.""" + manager = dns_resolver_manager + + # Initially there should be no resolvers + assert not manager._loop_data + + # Create a mock AsyncResolver for testing + mock_client = Mock(spec=AsyncResolver) + mock_client._loop = asyncio.get_running_loop() + + # Getting resolver should create one + mock_loop = mock_client._loop + resolver = manager.get_resolver(mock_client, mock_loop) + assert resolver is not None + assert manager._loop_data[mock_loop][0] is resolver + + # Getting it again should return the same instance + assert manager.get_resolver(mock_client, mock_loop) is resolver + + # Clean up + manager.release_resolver(mock_client, mock_loop) + assert not manager._loop_data + + +@pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") +async def test_dns_resolver_manager_client_registration( + dns_resolver_manager: _DNSResolverManager, +) -> None: + """Test client registration and resolver release logic.""" + with patch("aiodns.DNSResolver") as mock: + # Create resolver instances + resolver1 = AsyncResolver() + resolver2 = AsyncResolver() + + # Both should use the same resolver from the manager + assert resolver1._resolver is resolver2._resolver + + # The manager should be tracking both clients + assert resolver1._manager is resolver2._manager + manager = resolver1._manager + assert manager is not None + loop = asyncio.get_running_loop() + _, client_set = manager._loop_data[loop] + assert len(client_set) == 2 + + # Close one resolver + await resolver1.close() + _, client_set = manager._loop_data[loop] + assert len(client_set) == 1 + + # Resolver should still exist + assert manager._loop_data # Not empty + + # Close the second resolver + await resolver2.close() + assert not manager._loop_data # Should be empty after closing all clients + + # Now all resolvers should be canceled and removed + assert not manager._loop_data # Should be empty + mock().cancel.assert_called_once() + + +@pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") +async def test_dns_resolver_manager_multiple_event_loops( + dns_resolver_manager: _DNSResolverManager, +) -> None: + """Test that DNSResolverManager correctly manages resolvers across different event loops.""" + # Create separate resolvers for each loop + resolver1 = Mock(name="resolver1") + resolver2 = Mock(name="resolver2") + + # Create a patch that returns different resolvers based on the loop argument + mock_resolver = Mock() + mock_resolver.side_effect = lambda loop=None, **kwargs: ( + resolver1 if loop is asyncio.get_running_loop() else resolver2 + ) + + with patch("aiodns.DNSResolver", mock_resolver): + manager = dns_resolver_manager + + # Create two mock clients on different loops + mock_client1 = Mock(spec=AsyncResolver) + mock_client1._loop = asyncio.get_running_loop() + + # Create a second event loop + loop2 = Mock(spec=asyncio.AbstractEventLoop) + mock_client2 = Mock(spec=AsyncResolver) + mock_client2._loop = loop2 + + # Get resolvers for both clients + loop1 = mock_client1._loop + loop2 = mock_client2._loop + + # Get the resolvers through the manager + manager_resolver1 = manager.get_resolver(mock_client1, loop1) + manager_resolver2 = manager.get_resolver(mock_client2, loop2) + + # Should be different resolvers for different loops + assert manager_resolver1 is resolver1 + assert manager_resolver2 is resolver2 + assert manager._loop_data[loop1][0] is resolver1 + assert manager._loop_data[loop2][0] is resolver2 + + # Release the first resolver + manager.release_resolver(mock_client1, loop1) + + # First loop's resolver should be gone, but second should remain + assert loop1 not in manager._loop_data + assert loop2 in manager._loop_data + + # Release the second resolver + manager.release_resolver(mock_client2, loop2) + + # Both resolvers should be gone + assert not manager._loop_data + + # Verify resolver cleanup + resolver1.cancel.assert_called_once() + resolver2.cancel.assert_called_once() diff --git a/tests/test_run_app.py b/tests/test_run_app.py index af71612ae19..bea5b1fa1ef 100644 --- a/tests/test_run_app.py +++ b/tests/test_run_app.py @@ -31,7 +31,6 @@ from aiohttp import ClientConnectorError, ClientSession, ClientTimeout, WSCloseCode, web from aiohttp.log import access_logger -from aiohttp.test_utils import make_mocked_coro from aiohttp.web_protocol import RequestHandler from aiohttp.web_runner import BaseRunner @@ -108,9 +107,9 @@ def f(*args: object) -> None: def test_run_app_http(patched_loop: asyncio.AbstractEventLoop) -> None: app = web.Application() - startup_handler = make_mocked_coro() + startup_handler = mock.AsyncMock() app.on_startup.append(startup_handler) - cleanup_handler = make_mocked_coro() + cleanup_handler = mock.AsyncMock() app.on_cleanup.append(cleanup_handler) web.run_app(app, print=stopper(patched_loop), loop=patched_loop) @@ -734,9 +733,9 @@ def test_startup_cleanup_signals_even_on_failure( patched_loop.create_server.side_effect = RuntimeError() # type: ignore[attr-defined] app = web.Application() - startup_handler = make_mocked_coro() + startup_handler = mock.AsyncMock() app.on_startup.append(startup_handler) - cleanup_handler = make_mocked_coro() + cleanup_handler = mock.AsyncMock() app.on_cleanup.append(cleanup_handler) with pytest.raises(RuntimeError): @@ -752,9 +751,9 @@ def test_run_app_coro(patched_loop: asyncio.AbstractEventLoop) -> None: async def make_app() -> web.Application: nonlocal startup_handler, cleanup_handler app = web.Application() - startup_handler = make_mocked_coro() + startup_handler = mock.AsyncMock() app.on_startup.append(startup_handler) - cleanup_handler = make_mocked_coro() + cleanup_handler = mock.AsyncMock() app.on_cleanup.append(cleanup_handler) return app diff --git a/tests/test_tracing.py b/tests/test_tracing.py index d884b4e753e..d989dacf57f 100644 --- a/tests/test_tracing.py +++ b/tests/test_tracing.py @@ -1,10 +1,10 @@ from types import SimpleNamespace from typing import Any, Tuple +from unittest import mock from unittest.mock import Mock import pytest -from aiohttp.test_utils import make_mocked_coro from aiohttp.tracing import ( Trace, TraceConfig, @@ -107,7 +107,7 @@ async def test_send( # type: ignore[misc] ) -> None: session = Mock() trace_request_ctx = Mock() - callback = Mock(side_effect=make_mocked_coro(Mock())) + callback = mock.AsyncMock() trace_config = TraceConfig() getattr(trace_config, "on_%s" % signal).append(callback) diff --git a/tests/test_web_app.py b/tests/test_web_app.py index 62d69efb528..41a4fcba7ff 100644 --- a/tests/test_web_app.py +++ b/tests/test_web_app.py @@ -6,7 +6,6 @@ from aiohttp import log, web from aiohttp.pytest_plugin import AiohttpClient -from aiohttp.test_utils import make_mocked_coro from aiohttp.typedefs import Handler @@ -22,8 +21,8 @@ def test_app_call() -> None: async def test_app_register_on_finish() -> None: app = web.Application() - cb1 = make_mocked_coro(None) - cb2 = make_mocked_coro(None) + cb1 = mock.AsyncMock(return_value=None) + cb2 = mock.AsyncMock(return_value=None) app.on_cleanup.append(cb1) app.on_cleanup.append(cb2) app.freeze() diff --git a/tests/test_web_functional.py b/tests/test_web_functional.py index ffa27ec8acf..40fddc3aaf0 100644 --- a/tests/test_web_functional.py +++ b/tests/test_web_functional.py @@ -36,7 +36,6 @@ from aiohttp.compression_utils import ZLibBackend, ZLibCompressObjProtocol from aiohttp.hdrs import CONTENT_LENGTH, CONTENT_TYPE, TRANSFER_ENCODING from aiohttp.pytest_plugin import AiohttpClient, AiohttpServer -from aiohttp.test_utils import make_mocked_coro from aiohttp.typedefs import Handler, Middleware from aiohttp.web_protocol import RequestHandler @@ -2018,13 +2017,13 @@ async def handler(request: web.Request) -> web.Response: async def test_request_tracing(aiohttp_server: AiohttpServer) -> None: - on_request_start = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) - on_request_end = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) - on_dns_resolvehost_start = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) - on_dns_resolvehost_end = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) - on_request_redirect = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) - on_connection_create_start = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) - on_connection_create_end = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) + on_request_start = mock.AsyncMock() + on_request_end = mock.AsyncMock() + on_dns_resolvehost_start = mock.AsyncMock() + on_dns_resolvehost_end = mock.AsyncMock() + on_request_redirect = mock.AsyncMock() + on_connection_create_start = mock.AsyncMock() + on_connection_create_end = mock.AsyncMock() async def redirector(request: web.Request) -> NoReturn: raise web.HTTPFound(location=URL("/redirected")) diff --git a/tests/test_web_request_handler.py b/tests/test_web_request_handler.py index 4837cab030e..ee30e485f1b 100644 --- a/tests/test_web_request_handler.py +++ b/tests/test_web_request_handler.py @@ -1,7 +1,6 @@ from unittest import mock from aiohttp import web -from aiohttp.test_utils import make_mocked_coro async def serve(request: web.BaseRequest) -> web.Response: @@ -37,7 +36,7 @@ async def test_shutdown_no_timeout() -> None: handler = mock.Mock(spec_set=web.RequestHandler) handler._task_handler = None - handler.shutdown = make_mocked_coro(mock.Mock()) + handler.shutdown = mock.AsyncMock(return_value=mock.Mock()) transport = mock.Mock() manager.connection_made(handler, transport) @@ -52,7 +51,7 @@ async def test_shutdown_timeout() -> None: manager = web.Server(serve) handler = mock.Mock() - handler.shutdown = make_mocked_coro(mock.Mock()) + handler.shutdown = mock.AsyncMock(return_value=mock.Mock()) transport = mock.Mock() manager.connection_made(handler, transport) diff --git a/tests/test_web_response.py b/tests/test_web_response.py index 046bb89ac53..98384d7eabf 100644 --- a/tests/test_web_response.py +++ b/tests/test_web_response.py @@ -19,7 +19,7 @@ from aiohttp.http_writer import StreamWriter, _serialize_headers from aiohttp.multipart import BodyPartReader, MultipartWriter from aiohttp.payload import BytesPayload, StringPayload -from aiohttp.test_utils import make_mocked_coro, make_mocked_request +from aiohttp.test_utils import make_mocked_request from aiohttp.typedefs import LooseHeaders @@ -899,7 +899,7 @@ async def test_prepare_twice() -> None: async def test_prepare_calls_signal() -> None: app = mock.create_autospec(web.Application, spec_set=True) - sig = make_mocked_coro() + sig = mock.AsyncMock() app.on_response_prepare = aiosignal.Signal(app) app.on_response_prepare.append(sig) req = make_request("GET", "/", app=app) @@ -1171,8 +1171,8 @@ async def test_send_set_cookie_header( async def test_consecutive_write_eof() -> None: writer = mock.Mock() - writer.write_eof = make_mocked_coro() - writer.write_headers = make_mocked_coro() + writer.write_eof = mock.AsyncMock() + writer.write_headers = mock.AsyncMock() req = make_request("GET", "/", writer=writer) data = b"data" resp = web.Response(body=data) diff --git a/tests/test_web_sendfile.py b/tests/test_web_sendfile.py index 3d919f44c1e..f7f35a0b388 100644 --- a/tests/test_web_sendfile.py +++ b/tests/test_web_sendfile.py @@ -4,7 +4,7 @@ from unittest import mock from aiohttp import hdrs -from aiohttp.test_utils import make_mocked_coro, make_mocked_request +from aiohttp.test_utils import make_mocked_request from aiohttp.web_fileresponse import FileResponse MOCK_MODE = S_IFREG | S_IRUSR | S_IWUSR @@ -31,7 +31,7 @@ def test_using_gzip_if_header_present_and_file_available( file_sender = FileResponse(filepath) file_sender._path = filepath - file_sender._sendfile = make_mocked_coro(None) # type: ignore[method-assign] + file_sender._sendfile = mock.AsyncMock(return_value=None) # type: ignore[method-assign] loop.run_until_complete(file_sender.prepare(request)) @@ -58,7 +58,7 @@ def test_gzip_if_header_not_present_and_file_available( file_sender = FileResponse(filepath) file_sender._path = filepath - file_sender._sendfile = make_mocked_coro(None) # type: ignore[method-assign] + file_sender._sendfile = mock.AsyncMock(return_value=None) # type: ignore[method-assign] loop.run_until_complete(file_sender.prepare(request)) @@ -83,7 +83,7 @@ def test_gzip_if_header_not_present_and_file_not_available( file_sender = FileResponse(filepath) file_sender._path = filepath - file_sender._sendfile = make_mocked_coro(None) # type: ignore[method-assign] + file_sender._sendfile = mock.AsyncMock(return_value=None) # type: ignore[method-assign] loop.run_until_complete(file_sender.prepare(request)) @@ -110,7 +110,7 @@ def test_gzip_if_header_present_and_file_not_available( file_sender = FileResponse(filepath) file_sender._path = filepath - file_sender._sendfile = make_mocked_coro(None) # type: ignore[method-assign] + file_sender._sendfile = mock.AsyncMock(return_value=None) # type: ignore[method-assign] loop.run_until_complete(file_sender.prepare(request)) @@ -129,7 +129,7 @@ def test_status_controlled_by_user(loop: asyncio.AbstractEventLoop) -> None: file_sender = FileResponse(filepath, status=203) file_sender._path = filepath - file_sender._sendfile = make_mocked_coro(None) # type: ignore[method-assign] + file_sender._sendfile = mock.AsyncMock(return_value=None) # type: ignore[method-assign] loop.run_until_complete(file_sender.prepare(request)) diff --git a/tests/test_web_websocket.py b/tests/test_web_websocket.py index 139d5fa073e..45719ce6012 100644 --- a/tests/test_web_websocket.py +++ b/tests/test_web_websocket.py @@ -12,7 +12,7 @@ from aiohttp.http import WS_CLOSED_MESSAGE, WS_CLOSING_MESSAGE from aiohttp.http_websocket import WSMessageClose from aiohttp.streams import EofStream -from aiohttp.test_utils import make_mocked_coro, make_mocked_request +from aiohttp.test_utils import make_mocked_request from aiohttp.web_ws import WebSocketReady @@ -421,9 +421,7 @@ async def test_receive_eofstream_in_reader( ws._reader = mock.Mock() exc = EofStream() - res = loop.create_future() - res.set_exception(exc) - ws._reader.read = make_mocked_coro(res) + ws._reader.read = mock.AsyncMock(side_effect=exc) assert ws._payload_writer is not None f = loop.create_future() f.set_result(True) @@ -442,9 +440,7 @@ async def test_receive_exception_in_reader( ws._reader = mock.Mock() exc = Exception() - res = loop.create_future() - res.set_exception(exc) - ws._reader.read = make_mocked_coro(res) + ws._reader.read = mock.AsyncMock(side_effect=exc) f = loop.create_future() assert ws._payload_writer is not None @@ -545,9 +541,7 @@ async def test_receive_timeouterror( assert len(req.transport.close.mock_calls) == 0 # type: ignore[attr-defined] ws._reader = mock.Mock() - res = loop.create_future() - res.set_exception(asyncio.TimeoutError()) - ws._reader.read = make_mocked_coro(res) + ws._reader.read = mock.AsyncMock(side_effect=asyncio.TimeoutError()) with pytest.raises(asyncio.TimeoutError): await ws.receive() diff --git a/tests/test_websocket_writer.py b/tests/test_websocket_writer.py index ec22ac0b5eb..467ed8e94ce 100644 --- a/tests/test_websocket_writer.py +++ b/tests/test_websocket_writer.py @@ -9,13 +9,12 @@ from aiohttp._websocket.reader import WebSocketDataQueue from aiohttp.base_protocol import BaseProtocol from aiohttp.http import WebSocketReader, WebSocketWriter -from aiohttp.test_utils import make_mocked_coro @pytest.fixture def protocol() -> mock.Mock: ret = mock.Mock() - ret._drain_helper = make_mocked_coro() + ret._drain_helper = mock.AsyncMock() return ret