diff --git a/plugboard/component/component.py b/plugboard/component/component.py index bb2ec58e..be750b2d 100644 --- a/plugboard/component/component.py +++ b/plugboard/component/component.py @@ -395,13 +395,16 @@ async def _io_read_with_status_check(self) -> None: otherwise another read attempt is made. """ read_timeout = 1e-3 if self._has_outputs and not self._has_inputs else None - done, pending = await asyncio.wait( - ( - asyncio.create_task(self._periodic_status_check()), - asyncio.create_task(self.io.read(timeout=read_timeout)), - ), - return_when=asyncio.FIRST_COMPLETED, - ) + status_task = asyncio.create_task(self._periodic_status_check()) + io_task = asyncio.create_task(self.io.read(timeout=read_timeout)) + try: + done, pending = await asyncio.wait( + (status_task, io_task), + return_when=asyncio.FIRST_COMPLETED, + ) + except BaseException: + status_task.cancel() + raise for task in pending: task.cancel() for task in done: diff --git a/tests/integration/test_process_with_components_run.py b/tests/integration/test_process_with_components_run.py index b941cccf..fe047ae8 100644 --- a/tests/integration/test_process_with_components_run.py +++ b/tests/integration/test_process_with_components_run.py @@ -5,6 +5,7 @@ from pathlib import Path from tempfile import NamedTemporaryFile import typing as _t +from unittest.mock import patch from aiofile import async_open from pydantic import BaseModel @@ -21,7 +22,7 @@ RayConnector, ) from plugboard.events import Event -from plugboard.exceptions import NotInitialisedError, ProcessStatusError +from plugboard.exceptions import ConstraintError, NotInitialisedError, ProcessStatusError from plugboard.process import LocalProcess, Process, RayProcess from plugboard.schemas import ConnectorSpec, Status from tests.conftest import ComponentTestHelper, zmq_connector_cls @@ -456,3 +457,82 @@ async def test_event_driven_process_shutdown( assert actuator.actions == [f"do_{i}" for i in range(ticks)] await process.destroy() + + 
+# Interval used both as the tracker's status-check sleep and as the patched
+# value of IO_READ_TIMEOUT_SECONDS, keeping the regression test fast.
+_SHORT_TIMEOUT = 0.1
+
+
+class ConstraintErrorComponent(ComponentTestHelper):
+    """Component whose `step` always raises a `ConstraintError`."""
+
+    io = IO(outputs=["out_1"])
+
+    async def step(self) -> None:
+        """Fail immediately with a `ConstraintError`."""
+        raise ConstraintError("Constraint violated")
+
+
+class BackgroundTaskTracker(ComponentTestHelper):
+    """Component that counts iterations of its `_periodic_status_check` loop.
+
+    Overrides `_periodic_status_check` without calling `super()` to avoid early
+    termination via `ProcessStatusError`. The loop therefore only stops when its
+    task is cancelled, so a counter that keeps growing after the process has
+    failed indicates a leaked background task.
+    """
+
+    io = IO(inputs=["in_1"])
+    exports = ["background_run_count"]
+
+    def __init__(self, *args: _t.Any, **kwargs: _t.Any) -> None:
+        super().__init__(*args, **kwargs)
+        # Total number of completed status-check iterations since creation.
+        self.background_run_count: int = 0
+
+    async def step(self) -> None:
+        """Delegate to the helper's default step."""
+        await super().step()
+
+    async def _periodic_status_check(self) -> None:
+        """Loop forever, incrementing the counter every `_SHORT_TIMEOUT` seconds."""
+        while True:
+            await asyncio.sleep(_SHORT_TIMEOUT)
+            self.background_run_count += 1
+
+
+@pytest.mark.asyncio
+async def test_constraint_error_stops_background_status_check() -> None:
+    """Test that background status check tasks are cancelled when ConstraintError is raised.
+
+    Regression test for: a bug where the periodic status check task was not cancelled
+    when the process was cancelled due to a ConstraintError raised by another component.
+    """
+    with patch("plugboard.component.component.IO_READ_TIMEOUT_SECONDS", _SHORT_TIMEOUT):
+        producer = ConstraintErrorComponent(name="producer")
+        consumer = BackgroundTaskTracker(name="consumer")
+
+        connector = AsyncioConnector(
+            spec=ConnectorSpec(source="producer.out_1", target="consumer.in_1")
+        )
+        process = LocalProcess(
+            components=[producer, consumer],
+            connectors=[connector],
+        )
+
+        await process.init()
+
+        try:
+            with pytest.raises(ExceptionGroup) as exc_info:
+                await process.run()
+
+            # Verify the ConstraintError propagated
+            exceptions = exc_info.value.exceptions
+            assert any(isinstance(e, ConstraintError) for e in exceptions)
+
+            # Record background task count immediately after the process ends
+            count_after_failure = consumer.background_run_count
+
+            # Wait longer than the patched IO_READ_TIMEOUT_SECONDS to ensure any leaked
+            # background tasks would have had time to run
+            await asyncio.sleep(_SHORT_TIMEOUT * 5)
+
+            # Background tasks should NOT have run again after the process ended
+            extra_runs = consumer.background_run_count - count_after_failure
+            assert extra_runs == 0, (
+                f"Background status check ran {extra_runs} "
+                f"extra time(s) after process ended, indicating a task leak"
+            )
+        finally:
+            # Destroy the process even when an assertion above fails, so a failing
+            # regression check does not leave the initialised process behind.
+            await process.destroy()