Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions plugboard/component/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,13 +395,16 @@ async def _io_read_with_status_check(self) -> None:
otherwise another read attempt is made.
"""
read_timeout = 1e-3 if self._has_outputs and not self._has_inputs else None
done, pending = await asyncio.wait(
(
asyncio.create_task(self._periodic_status_check()),
asyncio.create_task(self.io.read(timeout=read_timeout)),
),
return_when=asyncio.FIRST_COMPLETED,
)
status_task = asyncio.create_task(self._periodic_status_check())
io_task = asyncio.create_task(self.io.read(timeout=read_timeout))
try:
done, pending = await asyncio.wait(
(status_task, io_task),
return_when=asyncio.FIRST_COMPLETED,
)
except BaseException:
status_task.cancel()
raise
for task in pending:
task.cancel()
for task in done:
Expand Down
82 changes: 81 additions & 1 deletion tests/integration/test_process_with_components_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pathlib import Path
from tempfile import NamedTemporaryFile
import typing as _t
from unittest.mock import patch

from aiofile import async_open
from pydantic import BaseModel
Expand All @@ -21,7 +22,7 @@
RayConnector,
)
from plugboard.events import Event
from plugboard.exceptions import NotInitialisedError, ProcessStatusError
from plugboard.exceptions import ConstraintError, NotInitialisedError, ProcessStatusError
from plugboard.process import LocalProcess, Process, RayProcess
from plugboard.schemas import ConnectorSpec, Status
from tests.conftest import ComponentTestHelper, zmq_connector_cls
Expand Down Expand Up @@ -456,3 +457,82 @@ async def test_event_driven_process_shutdown(
assert actuator.actions == [f"do_{i}" for i in range(ticks)]

await process.destroy()


_SHORT_TIMEOUT = 0.1


class ConstraintErrorComponent(ComponentTestHelper):
"""Component that raises a ConstraintError on the first step."""

io = IO(outputs=["out_1"])

async def step(self) -> None:
raise ConstraintError("Constraint violated")


class BackgroundTaskTracker(ComponentTestHelper):
"""Component that counts how many times _periodic_status_check loops after process ends.

Overrides _periodic_status_check without calling super() to avoid early termination
via ProcessStatusError, so we can detect if the task leaks after process failure.
"""

io = IO(inputs=["in_1"])
exports = ["background_run_count"]

def __init__(self, *args: _t.Any, **kwargs: _t.Any) -> None:
super().__init__(*args, **kwargs)
self.background_run_count: int = 0

async def step(self) -> None:
await super().step()

async def _periodic_status_check(self) -> None:
while True:
await asyncio.sleep(_SHORT_TIMEOUT)
self.background_run_count += 1


@pytest.mark.asyncio
async def test_constraint_error_stops_background_status_check() -> None:
"""Test that background status check tasks are cancelled when ConstraintError is raised.

Regression test for: a bug where the periodic status check task was not cancelled
when the process was cancelled due to a ConstraintError raised by another component.
"""
with patch("plugboard.component.component.IO_READ_TIMEOUT_SECONDS", _SHORT_TIMEOUT):
producer = ConstraintErrorComponent(name="producer")
consumer = BackgroundTaskTracker(name="consumer")

connector = AsyncioConnector(
spec=ConnectorSpec(source="producer.out_1", target="consumer.in_1")
)
process = LocalProcess(
components=[producer, consumer],
connectors=[connector],
)

await process.init()

with pytest.raises(ExceptionGroup) as exc_info:
await process.run()

# Verify the ConstraintError propagated
exceptions = exc_info.value.exceptions
assert any(isinstance(e, ConstraintError) for e in exceptions)

# Record background task count immediately after the process ends
count_after_failure = consumer.background_run_count

# Wait longer than the patched IO_READ_TIMEOUT_SECONDS to ensure any leaked
# background tasks would have had time to run
await asyncio.sleep(_SHORT_TIMEOUT * 5)

# Background tasks should NOT have run again after the process ended
assert consumer.background_run_count == count_after_failure, (
f"Background status check ran {consumer.background_run_count - count_after_failure} "
f"extra time(s) after process ended, indicating a task leak"
)

await process.destroy()
Loading