From e2206498fe215cb357dbda9d3708dd604102e03e Mon Sep 17 00:00:00 2001 From: Jan Buchar Date: Fri, 27 Mar 2026 15:41:05 +0100 Subject: [PATCH] fix: Defer page object cleanup to make it accessible in error handlers --- src/crawlee/_types.py | 3 + .../_adaptive_playwright_crawler.py | 1 + src/crawlee/crawlers/_basic/_basic_crawler.py | 10 ++ .../_playwright/_playwright_crawler.py | 117 ++++++++++-------- .../crawlers/_basic/test_context_pipeline.py | 7 ++ .../_playwright/test_playwright_crawler.py | 36 +++++- tests/unit/test_router.py | 1 + 7 files changed, 121 insertions(+), 54 deletions(-) diff --git a/src/crawlee/_types.py b/src/crawlee/_types.py index f48b545a82..9df2acba59 100644 --- a/src/crawlee/_types.py +++ b/src/crawlee/_types.py @@ -661,6 +661,9 @@ class BasicCrawlingContext: log: logging.Logger """Logger instance.""" + register_deferred_cleanup: Callable[[Callable[[], Coroutine[None, None, None]]], None] + """Register an async callback to be called after request processing completes (including error handlers).""" + async def get_snapshot(self) -> PageSnapshot: """Get snapshot of crawled page.""" return PageSnapshot() diff --git a/src/crawlee/crawlers/_adaptive_playwright/_adaptive_playwright_crawler.py b/src/crawlee/crawlers/_adaptive_playwright/_adaptive_playwright_crawler.py index ef64d1ba58..8a814c1b3f 100644 --- a/src/crawlee/crawlers/_adaptive_playwright/_adaptive_playwright_crawler.py +++ b/src/crawlee/crawlers/_adaptive_playwright/_adaptive_playwright_crawler.py @@ -320,6 +320,7 @@ async def get_input_state( get_key_value_store=result.get_key_value_store, use_state=use_state_function, log=context.log, + register_deferred_cleanup=context.register_deferred_cleanup, ) try: diff --git a/src/crawlee/crawlers/_basic/_basic_crawler.py b/src/crawlee/crawlers/_basic/_basic_crawler.py index 53c37416e0..b1e0268cc3 100644 --- a/src/crawlee/crawlers/_basic/_basic_crawler.py +++ b/src/crawlee/crawlers/_basic/_basic_crawler.py @@ -1413,6 +1413,8 @@ async def __run_task_function(self) -> None: proxy_info = await self._get_proxy_info(request, session) result = RequestHandlerRunResult(key_value_store_getter=self.get_key_value_store, request=request) + deferred_cleanup: list[Callable[[], Awaitable[None]]] = [] + context = BasicCrawlingContext( request=result.request, session=session, @@ -1423,6 +1425,7 @@ async def __run_task_function(self) -> None: get_key_value_store=result.get_key_value_store, use_state=self._use_state, log=self._logger, + register_deferred_cleanup=deferred_cleanup.append, ) self._context_result_map[context] = result @@ -1509,6 +1512,13 @@ async def __run_task_function(self) -> None: ) raise + finally: + for cleanup in deferred_cleanup: + try: + await cleanup() + except Exception: # noqa: PERF203 + self._logger.exception('Error in deferred cleanup') + async def _run_request_handler(self, context: BasicCrawlingContext) -> None: context.request.state = RequestState.BEFORE_NAV await self._context_pipeline( diff --git a/src/crawlee/crawlers/_playwright/_playwright_crawler.py b/src/crawlee/crawlers/_playwright/_playwright_crawler.py index 0aaffd7100..0c9cad65a4 100644 --- a/src/crawlee/crawlers/_playwright/_playwright_crawler.py +++ b/src/crawlee/crawlers/_playwright/_playwright_crawler.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import contextlib import logging import warnings from datetime import timedelta @@ -236,6 +237,7 @@ async def _open_page( proxy_info=context.proxy_info, get_key_value_store=context.get_key_value_store, log=context.log, + register_deferred_cleanup=context.register_deferred_cleanup, page=crawlee_page.page, block_requests=partial(block_requests, page=crawlee_page.page), goto_options=GotoOptions(**self._goto_options), @@ -296,63 +298,73 @@ async def _navigate( The enhanced crawling context with the Playwright-specific features (page, response, enqueue_links, infinite_scroll and block_requests). """ - async with context.page: - if context.session: - session_cookies = context.session.cookies.get_cookies_as_playwright_format() - await self._update_cookies(context.page, session_cookies) - - if context.request.headers: - await context.page.set_extra_http_headers(context.request.headers.model_dump()) - # Navigate to the URL and get response. - if context.request.method != 'GET': - # Call the notification only once - warnings.warn( - 'Using other request methods than GET or adding payloads has a high impact on performance' - ' in recent versions of Playwright. Use only when necessary.', - category=UserWarning, - stacklevel=2, - ) + # Enter the page context manager, but defer its cleanup (page.close()) so the page stays open + # during error handler execution. + await context.page.__aenter__() - route_handler = self._prepare_request_interceptor( - method=context.request.method, - headers=context.request.headers, - payload=context.request.payload, - ) + async def _close_page() -> None: + with contextlib.suppress(Exception): + await context.page.__aexit__(None, None, None) - # Set route_handler only for current request - await context.page.route(context.request.url, route_handler) + context.register_deferred_cleanup(_close_page) - try: - async with self._shared_navigation_timeouts[id(context)] as remaining_timeout: - response = await context.page.goto( - context.request.url, timeout=remaining_timeout.total_seconds() * 1000, **context.goto_options - ) - context.request.state = RequestState.AFTER_NAV - except playwright.async_api.TimeoutError as exc: - raise asyncio.TimeoutError from exc - - if response is None: - raise SessionError(f'Failed to load the URL: {context.request.url}') - - # Set the loaded URL to the actual URL after redirection. - context.request.loaded_url = context.page.url - - yield PlaywrightPostNavCrawlingContext( - request=context.request, - session=context.session, - add_requests=context.add_requests, - send_request=context.send_request, - push_data=context.push_data, - use_state=context.use_state, - proxy_info=context.proxy_info, - get_key_value_store=context.get_key_value_store, - log=context.log, - page=context.page, - block_requests=context.block_requests, - goto_options=context.goto_options, - response=response, + if context.session: + session_cookies = context.session.cookies.get_cookies_as_playwright_format() + await self._update_cookies(context.page, session_cookies) + + if context.request.headers: + await context.page.set_extra_http_headers(context.request.headers.model_dump()) + # Navigate to the URL and get response. + if context.request.method != 'GET': + # Call the notification only once + warnings.warn( + 'Using other request methods than GET or adding payloads has a high impact on performance' + ' in recent versions of Playwright. Use only when necessary.', + category=UserWarning, + stacklevel=2, ) + route_handler = self._prepare_request_interceptor( + method=context.request.method, + headers=context.request.headers, + payload=context.request.payload, + ) + + # Set route_handler only for current request + await context.page.route(context.request.url, route_handler) + + try: + async with self._shared_navigation_timeouts[id(context)] as remaining_timeout: + response = await context.page.goto( + context.request.url, timeout=remaining_timeout.total_seconds() * 1000, **context.goto_options + ) + context.request.state = RequestState.AFTER_NAV + except playwright.async_api.TimeoutError as exc: + raise asyncio.TimeoutError from exc + + if response is None: + raise SessionError(f'Failed to load the URL: {context.request.url}') + + # Set the loaded URL to the actual URL after redirection. + context.request.loaded_url = context.page.url + + yield PlaywrightPostNavCrawlingContext( + request=context.request, + session=context.session, + add_requests=context.add_requests, + send_request=context.send_request, + push_data=context.push_data, + use_state=context.use_state, + proxy_info=context.proxy_info, + get_key_value_store=context.get_key_value_store, + log=context.log, + register_deferred_cleanup=context.register_deferred_cleanup, + page=context.page, + block_requests=context.block_requests, + goto_options=context.goto_options, + response=response, + ) + def _create_extract_links_function(self, context: PlaywrightPreNavCrawlingContext) -> ExtractLinksFunction: """Create a callback function for extracting links from context. @@ -508,6 +520,7 @@ async def _create_crawling_context( proxy_info=context.proxy_info, get_key_value_store=context.get_key_value_store, log=context.log, + register_deferred_cleanup=context.register_deferred_cleanup, page=context.page, goto_options=context.goto_options, response=context.response, diff --git a/tests/unit/crawlers/_basic/test_context_pipeline.py b/tests/unit/crawlers/_basic/test_context_pipeline.py index 51f5556cac..23f3686d16 100644 --- a/tests/unit/crawlers/_basic/test_context_pipeline.py +++ b/tests/unit/crawlers/_basic/test_context_pipeline.py @@ -41,6 +41,7 @@ async def test_calls_consumer_without_middleware() -> None: use_state=AsyncMock(), get_key_value_store=AsyncMock(), log=logging.getLogger(), + register_deferred_cleanup=lambda _: None, ) await pipeline(context, consumer) @@ -68,6 +69,7 @@ async def middleware_a(context: BasicCrawlingContext) -> AsyncGenerator[Enhanced use_state=AsyncMock(), get_key_value_store=AsyncMock(), log=logging.getLogger(), + register_deferred_cleanup=context.register_deferred_cleanup, ) events.append('middleware_a_out') @@ -85,6 +87,7 @@ async def middleware_b(context: EnhancedCrawlingContext) -> AsyncGenerator[MoreE use_state=AsyncMock(), get_key_value_store=AsyncMock(), log=logging.getLogger(), + register_deferred_cleanup=context.register_deferred_cleanup, ) events.append('middleware_b_out') @@ -100,6 +103,7 @@ async def middleware_b(context: EnhancedCrawlingContext) -> AsyncGenerator[MoreE use_state=AsyncMock(), get_key_value_store=AsyncMock(), log=logging.getLogger(), + register_deferred_cleanup=lambda _: None, ) await pipeline(context, consumer) @@ -126,6 +130,7 @@ async def test_wraps_consumer_errors() -> None: use_state=AsyncMock(), get_key_value_store=AsyncMock(), log=logging.getLogger(), + register_deferred_cleanup=lambda _: None, ) with pytest.raises(RequestHandlerError): @@ -155,6 +160,7 @@ async def step_2(context: BasicCrawlingContext) -> AsyncGenerator[BasicCrawlingC use_state=AsyncMock(), get_key_value_store=AsyncMock(), log=logging.getLogger(), + register_deferred_cleanup=lambda _: None, ) with pytest.raises(ContextPipelineInitializationError): @@ -187,6 +193,7 @@ async def step_2(context: BasicCrawlingContext) -> AsyncGenerator[BasicCrawlingC use_state=AsyncMock(), get_key_value_store=AsyncMock(), log=logging.getLogger(), + register_deferred_cleanup=lambda _: None, ) with pytest.raises(ContextPipelineFinalizationError): diff --git a/tests/unit/crawlers/_playwright/test_playwright_crawler.py b/tests/unit/crawlers/_playwright/test_playwright_crawler.py index 4ad886b384..d409e98483 100644 --- a/tests/unit/crawlers/_playwright/test_playwright_crawler.py +++ b/tests/unit/crawlers/_playwright/test_playwright_crawler.py @@ -21,7 +21,10 @@ service_locator, ) from crawlee.configuration import Configuration -from crawlee.crawlers import PlaywrightCrawler +from crawlee.crawlers import ( + PlaywrightCrawler, + PlaywrightCrawlingContext, +) from crawlee.fingerprint_suite import ( DefaultFingerprintGenerator, FingerprintGenerator, @@ -49,7 +52,6 @@ from crawlee.browsers._types import BrowserType from crawlee.crawlers import ( BasicCrawlingContext, - PlaywrightCrawlingContext, PlaywrightPostNavCrawlingContext, PlaywrightPreNavCrawlingContext, ) @@ -1203,3 +1205,33 @@ async def post_nav_hook_2(_context: PlaywrightPostNavCrawlingContext) -> None: 'post-navigation-hook 2', 'final handler', ] + + +async def test_error_handler_can_access_page(server_url: URL) -> None: + """Test that the error handler can access the Page object via PlaywrightCrawlingContext.""" + + crawler = PlaywrightCrawler(max_request_retries=2) + + request_handler = mock.AsyncMock(side_effect=RuntimeError('Intentional crash')) + crawler.router.default_handler(request_handler) + + error_handler_calls: list[str | None] = [] + + @crawler.error_handler + async def error_handler(context: BasicCrawlingContext | PlaywrightCrawlingContext, _error: Exception) -> None: + error_handler_calls.append( + await context.page.content() if isinstance(context, PlaywrightCrawlingContext) else None + ) + + failed_handler_calls: list[str | None] = [] + + @crawler.failed_request_handler + async def failed_handler(context: BasicCrawlingContext | PlaywrightCrawlingContext, _error: Exception) -> None: + failed_handler_calls.append( + await context.page.content() if isinstance(context, PlaywrightCrawlingContext) else None + ) + + await crawler.run([str(server_url / 'hello-world')]) + + assert error_handler_calls == [HELLO_WORLD.decode(), HELLO_WORLD.decode()] + assert failed_handler_calls == [HELLO_WORLD.decode()] diff --git a/tests/unit/test_router.py b/tests/unit/test_router.py index 343f9ffb23..c87a44d323 100644 --- a/tests/unit/test_router.py +++ b/tests/unit/test_router.py @@ -23,6 +23,7 @@ def __init__(self, *, label: str | None) -> None: use_state=AsyncMock(), get_key_value_store=AsyncMock(), log=logging.getLogger(), + register_deferred_cleanup=lambda _: None, )